├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── code_formatter.yml │ ├── nextjs.yml │ └── unittest.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── docs ├── components │ ├── counters.module.css │ └── counters.tsx ├── favicon.ico ├── next-env.d.ts ├── next.config.js ├── package.json ├── pages │ ├── Usage │ │ ├── _meta.json │ │ ├── rvc.mdx │ │ └── uvr.mdx │ ├── _meta.json │ ├── index.mdx │ └── installation.mdx ├── pnpm-lock.yaml ├── theme.config.tsx └── tsconfig.json ├── install.bat ├── install.sh ├── logs ├── mute │ ├── extracted │ │ └── mute.npy │ ├── f0 │ │ └── mute.wav.npy │ ├── f0_voiced │ │ └── mute.wav.npy │ ├── sliced_audios │ │ ├── mute32000.wav │ │ ├── mute40000.wav │ │ ├── mute44100.wav │ │ └── mute48000.wav │ └── sliced_audios_16k │ │ └── mute.wav └── reference │ ├── ref32000.wav │ ├── ref32000_f0c.npy │ ├── ref32000_f0f.npy │ ├── ref32000_feats.npy │ ├── ref40000.wav │ ├── ref40000_f0c.npy │ ├── ref40000_f0f.npy │ ├── ref40000_feats.npy │ ├── ref48000.wav │ ├── ref48000_f0c.npy │ ├── ref48000_f0f.npy │ └── ref48000_feats.npy ├── requirements.txt ├── rvc ├── configs │ ├── 32000.json │ ├── 40000.json │ ├── 48000.json │ └── config.py ├── infer │ ├── infer.py │ └── pipeline.py ├── lib │ ├── algorithm │ │ ├── __init__.py │ │ ├── attentions.py │ │ ├── commons.py │ │ ├── discriminators.py │ │ ├── encoders.py │ │ ├── generators │ │ │ ├── __init__.py │ │ │ ├── hifigan.py │ │ │ ├── hifigan_mrf.py │ │ │ ├── hifigan_nsf.py │ │ │ └── refinegan.py │ │ ├── modules.py │ │ ├── normalization.py │ │ ├── residuals.py │ │ └── synthesizers.py │ ├── predictors │ │ ├── F0Extractor.py │ │ ├── FCPE.py │ │ └── RMVPE.py │ ├── tools │ │ ├── analyzer.py │ │ ├── gdown.py │ │ ├── launch_tensorboard.py │ │ ├── model_download.py │ │ ├── prerequisites_download.py │ │ ├── pretrained_selector.py │ │ ├── split_audio.py │ │ ├── tts.py │ │ └── tts_voices.json │ ├── utils.py │ └── zluda.py ├── models │ ├── embedders │ │ ├── .gitkeep │ │ └── embedders_custom │ │ │ └── .gitkeep │ ├── formant │ │ └── .gitkeep │ ├── predictors │ │ └── .gitkeep │ └── pretraineds │ │ ├── .gitkeep │ │ ├── custom │ │ └── .gitkeep │ │ └── hifi-gan │ │ └── .gitkeep └── train │ ├── data_utils.py │ ├── extract │ ├── extract.py │ └── preparing_files.py │ ├── losses.py │ ├── mel_processing.py │ ├── preprocess │ ├── preprocess.py │ └── slicer.py │ ├── process │ ├── change_info.py │ ├── extract_index.py │ ├── extract_model.py │ ├── model_blender.py │ └── model_information.py │ ├── train.py │ └── utils.py ├── rvc_cli.py ├── uvr ├── __init__.py ├── architectures │ ├── __init__.py │ ├── demucs_separator.py │ ├── mdx_separator.py │ ├── mdxc_separator.py │ └── vr_separator.py ├── common_separator.py ├── separator.py └── uvr_lib_v5 │ ├── __init__.py │ ├── attend.py │ ├── bs_roformer.py │ ├── demucs │ ├── __init__.py │ ├── __main__.py │ ├── apply.py │ ├── demucs.py │ ├── filtering.py │ ├── hdemucs.py │ ├── htdemucs.py │ ├── model.py │ ├── model_v2.py │ ├── pretrained.py │ ├── repo.py │ ├── spec.py │ ├── states.py │ ├── tasnet.py │ ├── tasnet_v2.py │ ├── transformer.py │ └── utils.py │ ├── mdxnet.py │ ├── mel_band_roformer.py │ ├── mixer.ckpt │ ├── modules.py │ ├── playsound.py │ ├── pyrb.py │ ├── results.py │ ├── spec_utils.py │ ├── stft.py │ ├── tfc_tdf_v3.py │ └── vr_network │ ├── __init__.py │ ├── layers.py │ ├── layers_new.py │ ├── model_param_init.py │ ├── modelparams │ ├── 1band_sr16000_hl512.json │ ├── 1band_sr32000_hl512.json │ ├── 1band_sr33075_hl384.json │ ├── 
1band_sr44100_hl1024.json │ ├── 1band_sr44100_hl256.json │ ├── 1band_sr44100_hl512.json │ ├── 1band_sr44100_hl512_cut.json │ ├── 1band_sr44100_hl512_nf1024.json │ ├── 2band_32000.json │ ├── 2band_44100_lofi.json │ ├── 2band_48000.json │ ├── 3band_44100.json │ ├── 3band_44100_mid.json │ ├── 3band_44100_msb2.json │ ├── 4band_44100.json │ ├── 4band_44100_mid.json │ ├── 4band_44100_msb.json │ ├── 4band_44100_msb2.json │ ├── 4band_44100_reverse.json │ ├── 4band_44100_sw.json │ ├── 4band_v2.json │ ├── 4band_v2_sn.json │ ├── 4band_v3.json │ ├── 4band_v3_sn.json │ └── ensemble.json │ ├── nets.py │ └── nets_new.py └── uvr_cli.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: iahispano 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Before You Report a Bug** 11 | Reporting a bug is essential for us to improve our service, but we need detailed information to address the issue effectively. Since every computer setup is unique, there can be various reasons behind a bug. Before reporting, consider potential causes and provide as much detail as possible to help us understand the problem. 12 | 13 | **Bug Description** 14 | Please provide a clear and concise description of the bug. 15 | 16 | **Steps to Reproduce** 17 | Outline the steps to replicate the issue: 18 | 1. Go to '...' 19 | 2. Click on '....' 20 | 3. Scroll down to '....' 21 | 4. Observe the error. 22 | 23 | **Expected Behavior** 24 | Describe what you expected to happen. 25 | 26 | **Assets** 27 | Include screenshots or videos if they can illustrate the issue. 28 | 29 | **Desktop Details:** 30 | - Operating System: [e.g., Windows 11] 31 | - Browser: [e.g., Chrome, Safari] 32 | 33 | **Additional Context** 34 | Any additional information that might be relevant to the issue. 
35 | -------------------------------------------------------------------------------- /.github/workflows/code_formatter.yml: -------------------------------------------------------------------------------- 1 | name: Code Formatter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | push_format: 10 | runs-on: ubuntu-latest 11 | 12 | permissions: 13 | contents: write 14 | pull-requests: write 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | with: 19 | ref: ${{github.ref_name}} 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install Black 27 | run: pip install "black[jupyter]" 28 | 29 | - name: Run Black 30 | # run: black $(git ls-files '*.py') 31 | run: black . --exclude=".*\.ipynb$" 32 | 33 | - name: Commit Back 34 | continue-on-error: true 35 | id: commitback 36 | run: | 37 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 38 | git config --local user.name "github-actions[bot]" 39 | git add --all 40 | git commit -m "chore(format): run black on ${{github.ref_name}}" 41 | 42 | - name: Create Pull Request 43 | if: steps.commitback.outcome == 'success' 44 | continue-on-error: true 45 | uses: peter-evans/create-pull-request@v5 46 | with: 47 | delete-branch: true 48 | body: "Automatically apply code formatter change" 49 | title: "chore(format): run black on ${{github.ref_name}}" 50 | commit-message: "chore(format): run black on ${{github.ref_name}}" 51 | branch: formatter-${{github.ref_name}} 52 | -------------------------------------------------------------------------------- /.github/workflows/nextjs.yml: -------------------------------------------------------------------------------- 1 | # Sample workflow for building and deploying a Next.js site to GitHub Pages 2 | # 3 | # To get started with Next.js see: https://nextjs.org/docs/getting-started 4 | # 5 | name: Deploy Next.js site to Pages 6 | 7 | on: 8 | # Runs on pushes targeting the default branch 9 | push: 10 | branches: ["main"] 11 | 12 | # Allows you to run this workflow manually from the Actions tab 13 | workflow_dispatch: 14 | 15 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 16 | permissions: 17 | contents: read 18 | pages: write 19 | id-token: write 20 | 21 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 22 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
23 | concurrency: 24 | group: "pages" 25 | cancel-in-progress: false 26 | 27 | jobs: 28 | # Build job 29 | build: 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v4 34 | - name: Detect package manager 35 | id: detect-package-manager 36 | run: | 37 | if [ -f "${{ github.workspace }}/yarn.lock" ]; then 38 | echo "manager=yarn" >> $GITHUB_OUTPUT 39 | echo "command=install" >> $GITHUB_OUTPUT 40 | echo "runner=yarn" >> $GITHUB_OUTPUT 41 | exit 0 42 | elif [ -f "${{ github.workspace }}/package.json" ]; then 43 | echo "manager=npm" >> $GITHUB_OUTPUT 44 | echo "command=ci" >> $GITHUB_OUTPUT 45 | echo "runner=npx --no-install" >> $GITHUB_OUTPUT 46 | exit 0 47 | else 48 | echo "Unable to determine package manager" 49 | exit 1 50 | fi 51 | - name: Setup Node 52 | uses: actions/setup-node@v4 53 | with: 54 | node-version: "20" 55 | cache: ${{ steps.detect-package-manager.outputs.manager }} 56 | - name: Setup Pages 57 | uses: actions/configure-pages@v5 58 | with: 59 | # Automatically inject basePath in your Next.js configuration file and disable 60 | # server side image optimization (https://nextjs.org/docs/api-reference/next/image#unoptimized). 61 | # 62 | # You may remove this line if you want to manage the configuration yourself. 63 | static_site_generator: next 64 | - name: Restore cache 65 | uses: actions/cache@v4 66 | with: 67 | path: | 68 | .next/cache 69 | # Generate a new cache whenever packages or source files change. 70 | key: ${{ runner.os }}-nextjs-${{ hashFiles('**/package-lock.json', '**/yarn.lock') }}-${{ hashFiles('**.[jt]s', '**.[jt]sx') }} 71 | # If source files changed but packages didn't, rebuild from a prior cache. 72 | restore-keys: | 73 | ${{ runner.os }}-nextjs-${{ hashFiles('**/package-lock.json', '**/yarn.lock') }}- 74 | - name: Install dependencies 75 | run: ${{ steps.detect-package-manager.outputs.manager }} ${{ steps.detect-package-manager.outputs.command }} 76 | - name: Build with Next.js 77 | run: ${{ steps.detect-package-manager.outputs.runner }} next build 78 | - name: Upload artifact 79 | uses: actions/upload-pages-artifact@v3 80 | with: 81 | path: ./out 82 | 83 | # Deployment job 84 | deploy: 85 | environment: 86 | name: github-pages 87 | url: ${{ steps.deployment.outputs.page_url }} 88 | runs-on: ubuntu-latest 89 | needs: build 90 | steps: 91 | - name: Deploy to GitHub Pages 92 | id: deployment 93 | uses: actions/deploy-pages@v4 94 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yml: -------------------------------------------------------------------------------- 1 | name: Test preprocess and extract 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | runs-on: ${{ matrix.os }} 6 | strategy: 7 | matrix: 8 | python-version: ["3.9", "3.10"] 9 | os: [ubuntu-latest] 10 | fail-fast: true 11 | 12 | steps: 13 | - uses: actions/checkout@master 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: | 20 | sudo apt update 21 | sudo apt -y install ffmpeg 22 | python -m pip install --upgrade pip 23 | python -m pip install --upgrade setuptools 24 | python -m pip install --upgrade wheel 25 | pip install torch torchvision torchaudio 26 | pip install -r requirements.txt 27 | python rvc_cli.py prerequisites --models True 28 | - name: Test Preprocess 29 | run: | 30 | python rvc_cli.py preprocess --model_name "Evaluate" --dataset_path 
"logs/mute/sliced_audios" --sample_rate 48000 31 | - name: Test Extract 32 | run: | 33 | python rvc_cli.py extract --model_name "Evaluate" --sample_rate 48000 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore logs folder 2 | # logs 3 | 4 | # Ignore compiled executables 5 | *.exe 6 | 7 | # Ignore model files 8 | *.pt 9 | *.onnx 10 | *.pth 11 | *.index 12 | *.bin 13 | *.ckpt 14 | *.txt 15 | 16 | # Ignore Python bytecode files 17 | *.pyc 18 | 19 | # Ignore environment and virtual environment directories 20 | env/ 21 | venv/ 22 | 23 | # Ignore cached files 24 | .cache/ 25 | docs/.next 26 | node_modules 27 | 28 | # Ignore specific project directories 29 | /tracks/ 30 | /lyrics/ 31 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.languageServer": "Pylance", 3 | "python.analysis.diagnosticSeverityOverrides": { 4 | "reportShadowedImports": "none" 5 | }, 6 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 🚀 RVC + UVR = A perfect set of tools for voice cloning, easily and free! 2 | 3 | [![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/iahispano/applio/blob/master/assets/Applio_NoUI.ipynb) 4 | 5 | > [!NOTE] 6 | > Issues are not handled in this repository due to time constraints. For questions or discussions, feel free to join [AI Hispano on Discord](https://discord.gg/iahispano). If you're experiencing a technical issue, please report it on [Applio's Issues page](https://github.com/IAHispano/Applio/issues) and specify that the problem occurs via CLI. 7 | > This repository is partially abandoned, use Applio for more constant updates (You can use CLI through the `core.py` file). 8 | 9 | ### Installation 10 | 11 | Ensure that you have the necessary Python packages installed by following these steps (Python 3.9 is recommended): 12 | 13 | #### Windows 14 | 15 | Execute the [install.bat](./install.bat) file to activate a Conda environment. Afterward, launch the application using `env/python.exe rvc_cli.py` instead of the conventional `python rvc_cli.py` command. 16 | 17 | #### Linux 18 | 19 | ```bash 20 | chmod +x install.sh 21 | ./install.sh 22 | ``` 23 | 24 | ### Getting Started 25 | 26 | For detailed information and command-line options, refer to the help command: 27 | 28 | ```bash 29 | python rvc_cli.py -h 30 | python uvr_cli.py -h 31 | ``` 32 | 33 | This command provides a clear overview of the available modes and their corresponding parameters, facilitating effective utilization of the RVC CLI, but if you need more information, you can check the [documentation](https://rvc-cli.pages.dev/). 
34 | 35 | ### References 36 | 37 | The RVC CLI builds upon the foundations of the following projects: 38 | 39 | - **Vocoders:** 40 | 41 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) by jik876 42 | - [Vocos](https://github.com/gemelo-ai/vocos) by gemelo-ai 43 | - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA 44 | - [BigVSAN](https://github.com/sony/bigvsan) by sony 45 | - [vocoders](https://github.com/reppy4620/vocoders) by reppy4620 46 | - [vocoder](https://github.com/fishaudio/vocoder) by fishaudio 47 | 48 | - **VC Clients:** 49 | 50 | - [Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) by RVC-Project 51 | - [So-Vits-SVC](https://github.com/svc-develop-team/so-vits-svc) by svc-develop-team 52 | - [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork) by Mangio621 53 | - [VITS](https://github.com/jaywalnut310/vits) by jaywalnut310 54 | - [Harmonify](https://huggingface.co/Eempostor/Harmonify) by Eempostor 55 | - [rvc-trainer](https://github.com/thepowerfuldeez/rvc-trainer) by thepowerfuldeez 56 | 57 | - **Pitch Extractors:** 58 | 59 | - [RMVPE](https://github.com/Dream-High/RMVPE) by Dream-High 60 | - [torchfcpe](https://github.com/CNChTu/FCPE) by CNChTu 61 | - [torchcrepe](https://github.com/maxrmorrison/torchcrepe) by maxrmorrison 62 | - [anyf0](https://github.com/SoulMelody/anyf0) by SoulMelody 63 | 64 | - **Other:** 65 | - [FAIRSEQ](https://github.com/facebookresearch/fairseq) by facebookresearch 66 | - [FAISS](https://github.com/facebookresearch/faiss) by facebookresearch 67 | - [ContentVec](https://github.com/auspicious3000/contentvec/) by auspicious3000 68 | - [audio-slicer](https://github.com/openvpi/audio-slicer) by openvpi 69 | - [python-audio-separator](https://github.com/karaokenerds/python-audio-separator) by karaokenerds 70 | - [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui) by Anjok07 71 | 72 | We acknowledge and appreciate the contributions of the respective authors and communities involved in these projects. 73 | -------------------------------------------------------------------------------- /docs/components/counters.module.css: -------------------------------------------------------------------------------- 1 | .counter { 2 | border: 1px solid #ccc; 3 | border-radius: 5px; 4 | padding: 2px 6px; 5 | margin: 12px 0 0; 6 | } 7 | -------------------------------------------------------------------------------- /docs/components/counters.tsx: -------------------------------------------------------------------------------- 1 | // Example from https://beta.reactjs.org/learn 2 | 3 | import { useState } from 'react' 4 | import styles from './counters.module.css' 5 | 6 | function MyButton() { 7 | const [count, setCount] = useState(0) 8 | 9 | function handleClick() { 10 | setCount(count + 1) 11 | } 12 | 13 | return ( 14 |
15 | <button onClick={handleClick} className={styles.counter}> 16 | Clicked {count} times 17 | </button> 18 | </div>
19 | ) 20 | } 21 | 22 | export default function MyApp() { 23 | return 24 | } 25 | -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/docs/favicon.ico -------------------------------------------------------------------------------- /docs/next-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | /// 3 | 4 | // NOTE: This file should not be edited 5 | // see https://nextjs.org/docs/basic-features/typescript for more information. 6 | -------------------------------------------------------------------------------- /docs/next.config.js: -------------------------------------------------------------------------------- 1 | const withNextra = require('nextra')({ 2 | theme: 'nextra-theme-docs', 3 | themeConfig: './theme.config.tsx', 4 | }) 5 | 6 | module.exports = withNextra() 7 | -------------------------------------------------------------------------------- /docs/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "rvc_cli", 3 | "version": "0.0.2", 4 | "description": "🚀 RVC + UVR = A perfect set of tools for voice cloning, easily and free!", 5 | "scripts": { 6 | "dev": "next dev", 7 | "build": "next build", 8 | "start": "next start" 9 | }, 10 | "dependencies": { 11 | "next": "^13.0.6", 12 | "nextra": "latest", 13 | "nextra-theme-docs": "latest", 14 | "react": "^18.2.0", 15 | "react-dom": "^18.2.0" 16 | }, 17 | "devDependencies": { 18 | "@types/node": "18.11.10", 19 | "typescript": "^4.9.3" 20 | } 21 | } -------------------------------------------------------------------------------- /docs/pages/Usage/_meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "rvc": "RVC", 3 | "uvr": "UVR" 4 | } -------------------------------------------------------------------------------- /docs/pages/_meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "index": "Introduction", 3 | "installation": "Installation", 4 | "contact": { 5 | "title": "Contact ↗", 6 | "type": "page", 7 | "href": "https://twitter.com/blaisewf", 8 | "newWindow": true 9 | } 10 | } -------------------------------------------------------------------------------- /docs/pages/index.mdx: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | ### References 4 | 5 | The RVC CLI builds upon the foundations of the following projects: 6 | 7 | - **Vocoders:** 8 | 9 | - [HiFi-GAN](https://github.com/jik876/hifi-gan) by jik876 10 | - [Vocos](https://github.com/gemelo-ai/vocos) by gemelo-ai 11 | - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA 12 | - [BigVSAN](https://github.com/sony/bigvsan) by sony 13 | - [vocoders](https://github.com/reppy4620/vocoders) by reppy4620 14 | - [vocoder](https://github.com/fishaudio/vocoder) by fishaudio 15 | 16 | - **VC Clients:** 17 | 18 | - [Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) by RVC-Project 19 | - [So-Vits-SVC](https://github.com/svc-develop-team/so-vits-svc) by svc-develop-team 20 | - [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork) by Mangio621 21 | - [VITS](https://github.com/jaywalnut310/vits) by jaywalnut310 22 | - 
[Harmonify](https://huggingface.co/Eempostor/Harmonify) by Eempostor 23 | - [rvc-trainer](https://github.com/thepowerfuldeez/rvc-trainer) by thepowerfuldeez 24 | 25 | - **Pitch Extractors:** 26 | 27 | - [RMVPE](https://github.com/Dream-High/RMVPE) by Dream-High 28 | - [torchfcpe](https://github.com/CNChTu/FCPE) by CNChTu 29 | - [torchcrepe](https://github.com/maxrmorrison/torchcrepe) by maxrmorrison 30 | - [anyf0](https://github.com/SoulMelody/anyf0) by SoulMelody 31 | 32 | - **Other:** 33 | - [FAIRSEQ](https://github.com/facebookresearch/fairseq) by facebookresearch 34 | - [FAISS](https://github.com/facebookresearch/faiss) by facebookresearch 35 | - [ContentVec](https://github.com/auspicious3000/contentvec/) by auspicious3000 36 | - [audio-slicer](https://github.com/openvpi/audio-slicer) by openvpi 37 | - [python-audio-separator](https://github.com/karaokenerds/python-audio-separator) by karaokenerds 38 | - [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui) by Anjok07 39 | 40 | We acknowledge and appreciate the contributions of the respective authors and communities involved in these projects. 41 | -------------------------------------------------------------------------------- /docs/pages/installation.mdx: -------------------------------------------------------------------------------- 1 | ## Installation Guides 2 | 3 | ### Windows 4 | 5 | #### 1. **Install Dependencies:** 6 | - Open a Command Prompt or PowerShell window. 7 | - Navigate to the directory containing the `install.bat` file. 8 | - Run the `install.bat` file. This will create a Conda environment and install all necessary dependencies. 9 | #### 2. **Run the RVC CLI:** 10 | - After installing requirements, you can start using the CLI. However, make sure you run the `prerequisites` command first to download additional models and executables. 11 | - Launch the application using: 12 | ```bash 13 | env/python.exe rvc_cli.py prerequisites 14 | ``` 15 | - Once the prerequisites are downloaded, you can use the RVC CLI as usual. 16 | ### macOS & Linux 17 | 18 | #### 1. **Install Python (3.9 or 3.10):** 19 | - If you don't have Python installed, use your system's package manager. For example, on Ubuntu: 20 | ```bash 21 | sudo apt update 22 | sudo apt install python3 23 | ``` 24 | #### 2. **Create a Virtual Environment:** 25 | - Open a terminal window. 26 | - Navigate to the directory containing the `rvc_cli.py` file. 27 | - Run the following command to create a virtual environment: 28 | ```bash 29 | python3 -m venv venv 30 | ``` 31 | #### 3. **Activate the Environment:** 32 | - Run the following command to activate the virtual environment: 33 | ```bash 34 | source venv/bin/activate 35 | ``` 36 | #### 4. **Install Dependencies:** 37 | - Run the following command to install all necessary dependencies: 38 | ```bash 39 | pip install -r requirements.txt 40 | ``` 41 | #### 5. **Run the RVC CLI with `prerequisites`:** 42 | - Launch the application and run the `prerequisites` command to download additional models and executables. 43 | ```bash 44 | python rvc_cli.py prerequisites 45 | ``` 46 | #### 6. **Run the RVC CLI:** 47 | - After the prerequisites are downloaded, you can use the RVC CLI as usual. 48 | 49 | **Note:** 50 | 51 | - For Linux, you may need to install additional packages depending on your system. 52 | - If you face any errors during the installation process, consult the respective package managers' documentation for further instructions.
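For a quick copy-paste reference, the sequence below collects the macOS & Linux steps above into a single run. It assumes the repository root (the directory containing `rvc_cli.py`) is the current working directory; the individual commands are taken from the steps above.

```bash
# From the repository root (the directory containing rvc_cli.py):
python3 -m venv venv              # step 2: create the virtual environment
source venv/bin/activate          # step 3: activate it
pip install -r requirements.txt   # step 4: install all necessary dependencies
python rvc_cli.py prerequisites   # step 5: download additional models and executables
python rvc_cli.py -h              # step 6: confirm the CLI runs and list available modes
```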
53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /docs/theme.config.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { DocsThemeConfig, useConfig } from "nextra-theme-docs"; 3 | 4 | const config: DocsThemeConfig = { 5 | logo: 'RVC CLI', 6 | search: { 7 | placeholder: "What are you looking for? 🧐", 8 | }, 9 | project: { 10 | link: "https://github.com/blaisewf/rvc_cli", 11 | }, 12 | chat: { 13 | link: "https://discord.gg/iahispano", 14 | }, 15 | docsRepositoryBase: "https://github.com/blaisewf/rvc_cli/tree/main/docs", 16 | footer: { 17 | text: ( 18 | 19 | made w ❤️ by blaisewf 20 | 21 | ), 22 | }, 23 | nextThemes: { 24 | defaultTheme: "dark", 25 | }, 26 | feedback: { 27 | content: "Do you think we should improve something? Let us know!", 28 | 29 | }, 30 | editLink: { 31 | component: null, 32 | }, 33 | faviconGlyph: "favicon.ico", 34 | logoLink: "/", 35 | primaryHue: 317, 36 | head: () => { 37 | const { frontMatter } = useConfig(); 38 | 39 | return ( 40 | <> 41 | 42 | 43 | 44 | 45 | 49 | 53 | 54 | 55 | 56 | 60 | 61 | 62 | 63 | ); 64 | }, 65 | useNextSeoProps() { 66 | return { 67 | titleTemplate: `%s - RVC CLI`, 68 | }; 69 | }, 70 | }; 71 | 72 | export default config; -------------------------------------------------------------------------------- /docs/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": ["dom", "dom.iterable", "esnext"], 5 | "allowJs": true, 6 | "skipLibCheck": true, 7 | "strict": false, 8 | "forceConsistentCasingInFileNames": true, 9 | "noEmit": true, 10 | "incremental": true, 11 | "esModuleInterop": true, 12 | "module": "esnext", 13 | "moduleResolution": "node", 14 | "resolveJsonModule": true, 15 | "isolatedModules": true, 16 | "jsx": "preserve" 17 | }, 18 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"], 19 | "exclude": ["node_modules"] 20 | } 21 | -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | title RVC CLI Installer 4 | 5 | echo Welcome to the RVC CLI Installer! 6 | echo. 
7 | 8 | set "INSTALL_DIR=%cd%" 9 | set "MINICONDA_DIR=%UserProfile%\Miniconda3" 10 | set "ENV_DIR=%INSTALL_DIR%\env" 11 | set "MINICONDA_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_24.7.1-0-Windows-x86_64.exe" 12 | set "CONDA_EXE=%MINICONDA_DIR%\Scripts\conda.exe" 13 | 14 | set "startTime=%TIME%" 15 | set "startHour=%TIME:~0,2%" 16 | set "startMin=%TIME:~3,2%" 17 | set "startSec=%TIME:~6,2%" 18 | set /a startHour=1%startHour% - 100 19 | set /a startMin=1%startMin% - 100 20 | set /a startSec=1%startSec% - 100 21 | set /a startTotal = startHour*3600 + startMin*60 + startSec 22 | 23 | call :cleanup 24 | call :install_miniconda 25 | call :create_conda_env 26 | call :install_dependencies 27 | 28 | set "endTime=%TIME%" 29 | set "endHour=%TIME:~0,2%" 30 | set "endMin=%TIME:~3,2%" 31 | set "endSec=%TIME:~6,2%" 32 | set /a endHour=1%endHour% - 100 33 | set /a endMin=1%endMin% - 100 34 | set /a endSec=1%endSec% - 100 35 | set /a endTotal = endHour*3600 + endMin*60 + endSec 36 | set /a elapsed = endTotal - startTotal 37 | if %elapsed% lss 0 set /a elapsed += 86400 38 | set /a hours = elapsed / 3600 39 | set /a minutes = (elapsed %% 3600) / 60 40 | set /a seconds = elapsed %% 60 41 | 42 | echo Installation time: %hours% hours, %minutes% minutes, %seconds% seconds. 43 | echo. 44 | 45 | echo RVC CLI has been installed successfully! 46 | echo. 47 | pause 48 | exit /b 0 49 | 50 | :cleanup 51 | echo Cleaning up unnecessary files... 52 | for %%F in (Makefile Dockerfile docker-compose.yaml *.sh) do if exist "%%F" del "%%F" 53 | echo Cleanup complete. 54 | echo. 55 | exit /b 0 56 | 57 | :install_miniconda 58 | if exist "%CONDA_EXE%" ( 59 | echo Miniconda already installed. Skipping installation. 60 | exit /b 0 61 | ) 62 | 63 | echo Miniconda not found. Starting download and installation... 64 | powershell -Command "& {Invoke-WebRequest -Uri '%MINICONDA_URL%' -OutFile 'miniconda.exe'}" 65 | if not exist "miniconda.exe" goto :download_error 66 | 67 | start /wait "" miniconda.exe /InstallationType=JustMe /RegisterPython=0 /S /D=%MINICONDA_DIR% 68 | if errorlevel 1 goto :install_error 69 | 70 | del miniconda.exe 71 | echo Miniconda installation complete. 72 | echo. 73 | exit /b 0 74 | 75 | :create_conda_env 76 | echo Creating Conda environment... 77 | call "%MINICONDA_DIR%\_conda.exe" create --no-shortcuts -y -k --prefix "%ENV_DIR%" python=3.10 78 | if errorlevel 1 goto :error 79 | echo Conda environment created successfully. 80 | echo. 81 | 82 | if exist "%ENV_DIR%\python.exe" ( 83 | echo Installing uv package installer... 84 | "%ENV_DIR%\python.exe" -m pip install uv 85 | if errorlevel 1 goto :error 86 | echo uv installation complete. 87 | echo. 88 | ) 89 | exit /b 0 90 | 91 | :install_dependencies 92 | echo Installing dependencies... 93 | call "%MINICONDA_DIR%\condabin\conda.bat" activate "%ENV_DIR%" || goto :error 94 | uv pip install --upgrade setuptools || goto :error 95 | uv pip install torch==2.7.0 torchvision torchaudio==2.7.0 --upgrade --index-url https://download.pytorch.org/whl/cu128 || goto :error 96 | uv pip install -r "%INSTALL_DIR%\requirements.txt" || goto :error 97 | call "%MINICONDA_DIR%\condabin\conda.bat" deactivate 98 | echo Dependencies installation complete. 99 | echo. 100 | exit /b 0 101 | 102 | :download_error 103 | echo Download failed. Please check your internet connection and try again. 104 | goto :error 105 | 106 | :install_error 107 | echo Miniconda installation failed. 108 | goto :error 109 | 110 | :error 111 | echo An error occurred during installation. 
Please check the output above for details. 112 | pause 113 | exit /b 1 -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 -m venv venv 4 | source venv/bin/activate 5 | 6 | pip install -r requirements.txt 7 | -------------------------------------------------------------------------------- /logs/mute/extracted/mute.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/extracted/mute.npy -------------------------------------------------------------------------------- /logs/mute/f0/mute.wav.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/f0/mute.wav.npy -------------------------------------------------------------------------------- /logs/mute/f0_voiced/mute.wav.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/f0_voiced/mute.wav.npy -------------------------------------------------------------------------------- /logs/mute/sliced_audios/mute32000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute32000.wav -------------------------------------------------------------------------------- /logs/mute/sliced_audios/mute40000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute40000.wav -------------------------------------------------------------------------------- /logs/mute/sliced_audios/mute44100.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute44100.wav -------------------------------------------------------------------------------- /logs/mute/sliced_audios/mute48000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios/mute48000.wav -------------------------------------------------------------------------------- /logs/mute/sliced_audios_16k/mute.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/mute/sliced_audios_16k/mute.wav -------------------------------------------------------------------------------- /logs/reference/ref32000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000.wav -------------------------------------------------------------------------------- /logs/reference/ref32000_f0c.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_f0c.npy -------------------------------------------------------------------------------- /logs/reference/ref32000_f0f.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_f0f.npy -------------------------------------------------------------------------------- /logs/reference/ref32000_feats.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref32000_feats.npy -------------------------------------------------------------------------------- /logs/reference/ref40000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000.wav -------------------------------------------------------------------------------- /logs/reference/ref40000_f0c.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_f0c.npy -------------------------------------------------------------------------------- /logs/reference/ref40000_f0f.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_f0f.npy -------------------------------------------------------------------------------- /logs/reference/ref40000_feats.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref40000_feats.npy -------------------------------------------------------------------------------- /logs/reference/ref48000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000.wav -------------------------------------------------------------------------------- /logs/reference/ref48000_f0c.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_f0c.npy -------------------------------------------------------------------------------- /logs/reference/ref48000_f0f.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_f0f.npy -------------------------------------------------------------------------------- /logs/reference/ref48000_feats.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/logs/reference/ref48000_feats.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies 2 | pip>=23.3; sys_platform 
== 'darwin' 3 | wheel; sys_platform == 'darwin' 4 | PyYAML; sys_platform == 'darwin' 5 | numpy==1.26.4 6 | requests>=2.31.0,<2.32.0 7 | tqdm 8 | wget 9 | 10 | # Audio processing 11 | ffmpeg-python>=0.2.0 12 | faiss-cpu==1.7.3 13 | librosa==0.9.2 14 | scipy==1.11.1 15 | soundfile==0.12.1 16 | noisereduce 17 | pedalboard 18 | stftpitchshift 19 | soxr 20 | 21 | # Machine learning and deep learning 22 | omegaconf>=2.0.6; sys_platform == 'darwin' 23 | numba; sys_platform == 'linux' 24 | numba==0.61.0; sys_platform == 'darwin' or sys_platform == 'win32' 25 | torch==2.7.0 26 | torchaudio==2.7.0 27 | torchvision 28 | torchcrepe==0.0.23 29 | torchfcpe 30 | einops 31 | transformers==4.44.2 32 | 33 | # Visualization and UI 34 | matplotlib==3.7.2 35 | tensorboard 36 | gradio==5.23.1 37 | 38 | # Miscellaneous utilities 39 | certifi>=2023.07.22; sys_platform == 'darwin' 40 | antlr4-python3-runtime==4.8; sys_platform == 'darwin' 41 | tensorboardX 42 | edge-tts==6.1.9 43 | pypresence 44 | beautifulsoup4 45 | 46 | # UVR 47 | samplerate==0.1.0 48 | six>=1.16 49 | pydub>=0.25 50 | onnx>=1.14 51 | onnx2torch>=1.5 52 | onnxruntime>=1.17; sys_platform != 'darwin' 53 | onnxruntime-gpu>=1.17; sys_platform != 'darwin' 54 | julius>=0.2 55 | diffq>=0.2 56 | ml_collections 57 | resampy>=0.4 58 | beartype==0.18.5 59 | rotary-embedding-torch==0.6.1 -------------------------------------------------------------------------------- /rvc/configs/32000.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "seed": 1234, 5 | "learning_rate": 1e-4, 6 | "betas": [0.8, 0.99], 7 | "eps": 1e-9, 8 | "lr_decay": 0.999875, 9 | "segment_size": 12800, 10 | "c_mel": 45, 11 | "c_kl": 1.0 12 | }, 13 | "data": { 14 | "max_wav_value": 32768.0, 15 | "sample_rate": 32000, 16 | "filter_length": 1024, 17 | "hop_length": 320, 18 | "win_length": 1024, 19 | "n_mel_channels": 80, 20 | "mel_fmin": 0.0, 21 | "mel_fmax": null 22 | }, 23 | "model": { 24 | "inter_channels": 192, 25 | "hidden_channels": 192, 26 | "filter_channels": 768, 27 | "text_enc_hidden_dim": 768, 28 | "n_heads": 2, 29 | "n_layers": 6, 30 | "kernel_size": 3, 31 | "p_dropout": 0, 32 | "resblock": "1", 33 | "resblock_kernel_sizes": [3,7,11], 34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 35 | "upsample_rates": [10,8,2,2], 36 | "upsample_initial_channel": 512, 37 | "upsample_kernel_sizes": [20,16,4,4], 38 | "use_spectral_norm": false, 39 | "gin_channels": 256, 40 | "spk_embed_dim": 109 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /rvc/configs/40000.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "seed": 1234, 5 | "learning_rate": 1e-4, 6 | "betas": [0.8, 0.99], 7 | "eps": 1e-9, 8 | "lr_decay": 0.999875, 9 | "segment_size": 12800, 10 | "c_mel": 45, 11 | "c_kl": 1.0 12 | }, 13 | "data": { 14 | "max_wav_value": 32768.0, 15 | "sample_rate": 40000, 16 | "filter_length": 2048, 17 | "hop_length": 400, 18 | "win_length": 2048, 19 | "n_mel_channels": 125, 20 | "mel_fmin": 0.0, 21 | "mel_fmax": null 22 | }, 23 | "model": { 24 | "inter_channels": 192, 25 | "hidden_channels": 192, 26 | "filter_channels": 768, 27 | "text_enc_hidden_dim": 768, 28 | "n_heads": 2, 29 | "n_layers": 6, 30 | "kernel_size": 3, 31 | "p_dropout": 0, 32 | "resblock": "1", 33 | "resblock_kernel_sizes": [3,7,11], 34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 35 | 
"upsample_rates": [10,10,2,2], 36 | "upsample_initial_channel": 512, 37 | "upsample_kernel_sizes": [16,16,4,4], 38 | "use_spectral_norm": false, 39 | "gin_channels": 256, 40 | "spk_embed_dim": 109 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /rvc/configs/48000.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "seed": 1234, 5 | "learning_rate": 1e-4, 6 | "betas": [0.8, 0.99], 7 | "eps": 1e-9, 8 | "lr_decay": 0.999875, 9 | "segment_size": 17280, 10 | "c_mel": 45, 11 | "c_kl": 1.0 12 | }, 13 | "data": { 14 | "max_wav_value": 32768.0, 15 | "sample_rate": 48000, 16 | "filter_length": 2048, 17 | "hop_length": 480, 18 | "win_length": 2048, 19 | "n_mel_channels": 128, 20 | "mel_fmin": 0.0, 21 | "mel_fmax": null 22 | }, 23 | "model": { 24 | "inter_channels": 192, 25 | "hidden_channels": 192, 26 | "filter_channels": 768, 27 | "text_enc_hidden_dim": 768, 28 | "n_heads": 2, 29 | "n_layers": 6, 30 | "kernel_size": 3, 31 | "p_dropout": 0, 32 | "resblock": "1", 33 | "resblock_kernel_sizes": [3,7,11], 34 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 35 | "upsample_rates": [12,10,2,2], 36 | "upsample_initial_channel": 512, 37 | "upsample_kernel_sizes": [24,20,4,4], 38 | "use_spectral_norm": false, 39 | "gin_channels": 256, 40 | "spk_embed_dim": 109 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /rvc/configs/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | 5 | version_config_paths = [ 6 | os.path.join("48000.json"), 7 | os.path.join("40000.json"), 8 | os.path.join("32000.json"), 9 | ] 10 | 11 | 12 | def singleton(cls): 13 | instances = {} 14 | 15 | def get_instance(*args, **kwargs): 16 | if cls not in instances: 17 | instances[cls] = cls(*args, **kwargs) 18 | return instances[cls] 19 | 20 | return get_instance 21 | 22 | 23 | @singleton 24 | class Config: 25 | def __init__(self): 26 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu" 27 | self.gpu_name = ( 28 | torch.cuda.get_device_name(int(self.device.split(":")[-1])) 29 | if self.device.startswith("cuda") 30 | else None 31 | ) 32 | self.json_config = self.load_config_json() 33 | self.gpu_mem = None 34 | self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() 35 | 36 | def load_config_json(self): 37 | configs = {} 38 | for config_file in version_config_paths: 39 | config_path = os.path.join("rvc", "configs", config_file) 40 | with open(config_path, "r") as f: 41 | configs[config_file] = json.load(f) 42 | return configs 43 | 44 | def device_config(self): 45 | if self.device.startswith("cuda"): 46 | self.set_cuda_config() 47 | else: 48 | self.device = "cpu" 49 | 50 | # Configuration for 6GB GPU memory 51 | x_pad, x_query, x_center, x_max = (1, 6, 38, 41) 52 | if self.gpu_mem is not None and self.gpu_mem <= 4: 53 | # Configuration for 5GB GPU memory 54 | x_pad, x_query, x_center, x_max = (1, 5, 30, 32) 55 | 56 | return x_pad, x_query, x_center, x_max 57 | 58 | def set_cuda_config(self): 59 | i_device = int(self.device.split(":")[-1]) 60 | self.gpu_name = torch.cuda.get_device_name(i_device) 61 | self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // ( 62 | 1024**3 63 | ) 64 | 65 | 66 | def max_vram_gpu(gpu): 67 | if torch.cuda.is_available(): 68 | gpu_properties = torch.cuda.get_device_properties(gpu) 69 | 
total_memory_gb = round(gpu_properties.total_memory / 1024 / 1024 / 1024) 70 | return total_memory_gb 71 | else: 72 | return "8" 73 | 74 | 75 | def get_gpu_info(): 76 | ngpu = torch.cuda.device_count() 77 | gpu_infos = [] 78 | if torch.cuda.is_available() or ngpu != 0: 79 | for i in range(ngpu): 80 | gpu_name = torch.cuda.get_device_name(i) 81 | mem = int( 82 | torch.cuda.get_device_properties(i).total_memory / 1024 / 1024 / 1024 83 | + 0.4 84 | ) 85 | gpu_infos.append(f"{i}: {gpu_name} ({mem} GB)") 86 | if len(gpu_infos) > 0: 87 | gpu_info = "\n".join(gpu_infos) 88 | else: 89 | gpu_info = "Unfortunately, there is no compatible GPU available to support your training." 90 | return gpu_info 91 | 92 | 93 | def get_number_of_gpus(): 94 | if torch.cuda.is_available(): 95 | num_gpus = torch.cuda.device_count() 96 | return "-".join(map(str, range(num_gpus))) 97 | else: 98 | return "-" 99 | -------------------------------------------------------------------------------- /rvc/lib/algorithm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/lib/algorithm/__init__.py -------------------------------------------------------------------------------- /rvc/lib/algorithm/commons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | 5 | def init_weights(m, mean=0.0, std=0.01): 6 | """ 7 | Initialize the weights of a module. 8 | 9 | Args: 10 | m: The module to initialize. 11 | mean: The mean of the normal distribution. 12 | std: The standard deviation of the normal distribution. 13 | """ 14 | classname = m.__class__.__name__ 15 | if classname.find("Conv") != -1: 16 | m.weight.data.normal_(mean, std) 17 | 18 | 19 | def get_padding(kernel_size, dilation=1): 20 | """ 21 | Calculate the padding needed for a convolution. 22 | 23 | Args: 24 | kernel_size: The size of the kernel. 25 | dilation: The dilation of the convolution. 26 | """ 27 | return int((kernel_size * dilation - dilation) / 2) 28 | 29 | 30 | def convert_pad_shape(pad_shape): 31 | """ 32 | Convert the pad shape to a list of integers. 33 | 34 | Args: 35 | pad_shape: The pad shape.. 36 | """ 37 | l = pad_shape[::-1] 38 | pad_shape = [item for sublist in l for item in sublist] 39 | return pad_shape 40 | 41 | 42 | def slice_segments( 43 | x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2 44 | ): 45 | """ 46 | Slice segments from a tensor, handling tensors with different numbers of dimensions. 47 | 48 | Args: 49 | x (torch.Tensor): The tensor to slice. 50 | ids_str (torch.Tensor): The starting indices of the segments. 51 | segment_size (int, optional): The size of each segment. Defaults to 4. 52 | dim (int, optional): The dimension to slice across (2D or 3D tensors). Defaults to 2. 53 | """ 54 | if dim == 2: 55 | ret = torch.zeros_like(x[:, :segment_size]) 56 | elif dim == 3: 57 | ret = torch.zeros_like(x[:, :, :segment_size]) 58 | 59 | for i in range(x.size(0)): 60 | idx_str = ids_str[i].item() 61 | idx_end = idx_str + segment_size 62 | if dim == 2: 63 | ret[i] = x[i, idx_str:idx_end] 64 | else: 65 | ret[i] = x[i, :, idx_str:idx_end] 66 | 67 | return ret 68 | 69 | 70 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 71 | """ 72 | Randomly slice segments from a tensor. 73 | 74 | Args: 75 | x: The tensor to slice. 76 | x_lengths: The lengths of the sequences. 
77 | segment_size: The size of each segment. 78 | """ 79 | b, d, t = x.size() 80 | if x_lengths is None: 81 | x_lengths = t 82 | ids_str_max = x_lengths - segment_size + 1 83 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) 84 | ret = slice_segments(x, ids_str, segment_size, dim=3) 85 | return ret, ids_str 86 | 87 | 88 | @torch.jit.script 89 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 90 | """ 91 | Fused add tanh sigmoid multiply operation. 92 | 93 | Args: 94 | input_a: The first input tensor. 95 | input_b: The second input tensor. 96 | n_channels: The number of channels. 97 | """ 98 | n_channels_int = n_channels[0] 99 | in_act = input_a + input_b 100 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 101 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 102 | acts = t_act * s_act 103 | return acts 104 | 105 | 106 | def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): 107 | """ 108 | Generate a sequence mask. 109 | 110 | Args: 111 | length: The lengths of the sequences. 112 | max_length: The maximum length of the sequences. 113 | """ 114 | if max_length is None: 115 | max_length = length.max() 116 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 117 | return x.unsqueeze(0) < length.unsqueeze(1) 118 | 119 | 120 | def grad_norm(parameters, norm_type: float = 2.0): 121 | """ 122 | Calculates norm of parameter gradients 123 | 124 | Args: 125 | parameters: The list of parameters to clip. 126 | norm_type: The type of norm to use for clipping. 127 | """ 128 | if isinstance(parameters, torch.Tensor): 129 | parameters = [parameters] 130 | 131 | parameters = [p for p in parameters if p.grad is not None] 132 | 133 | if not parameters: 134 | return 0.0 135 | 136 | return torch.linalg.vector_norm( 137 | torch.stack([p.grad.norm(norm_type) for p in parameters]), ord=norm_type 138 | ).item() 139 | -------------------------------------------------------------------------------- /rvc/lib/algorithm/discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import checkpoint 3 | from torch.nn.utils.parametrizations import spectral_norm, weight_norm 4 | 5 | from rvc.lib.algorithm.commons import get_padding 6 | from rvc.lib.algorithm.residuals import LRELU_SLOPE 7 | 8 | 9 | class MultiPeriodDiscriminator(torch.nn.Module): 10 | """ 11 | Multi-period discriminator. 12 | 13 | This class implements a multi-period discriminator, which is used to 14 | discriminate between real and fake audio signals. The discriminator 15 | is composed of a series of convolutional layers that are applied to 16 | the input signal at different periods. 17 | 18 | Args: 19 | use_spectral_norm (bool): Whether to use spectral normalization. 20 | Defaults to False. 
21 | """ 22 | 23 | def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False): 24 | super().__init__() 25 | periods = [2, 3, 5, 7, 11, 17, 23, 37] 26 | self.checkpointing = checkpointing 27 | self.discriminators = torch.nn.ModuleList( 28 | [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 29 | + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods] 30 | ) 31 | 32 | def forward(self, y, y_hat): 33 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] 34 | for d in self.discriminators: 35 | if self.training and self.checkpointing: 36 | y_d_r, fmap_r = checkpoint(d, y, use_reentrant=False) 37 | y_d_g, fmap_g = checkpoint(d, y_hat, use_reentrant=False) 38 | else: 39 | y_d_r, fmap_r = d(y) 40 | y_d_g, fmap_g = d(y_hat) 41 | y_d_rs.append(y_d_r) 42 | y_d_gs.append(y_d_g) 43 | fmap_rs.append(fmap_r) 44 | fmap_gs.append(fmap_g) 45 | 46 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 47 | 48 | 49 | class DiscriminatorS(torch.nn.Module): 50 | """ 51 | Discriminator for the short-term component. 52 | 53 | This class implements a discriminator for the short-term component 54 | of the audio signal. The discriminator is composed of a series of 55 | convolutional layers that are applied to the input signal. 56 | """ 57 | 58 | def __init__(self, use_spectral_norm: bool = False): 59 | super().__init__() 60 | 61 | norm_f = spectral_norm if use_spectral_norm else weight_norm 62 | self.convs = torch.nn.ModuleList( 63 | [ 64 | norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), 65 | norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), 66 | norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), 67 | norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 68 | norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 69 | norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2)), 70 | ] 71 | ) 72 | self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1)) 73 | self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE) 74 | 75 | def forward(self, x): 76 | fmap = [] 77 | for conv in self.convs: 78 | x = self.lrelu(conv(x)) 79 | fmap.append(x) 80 | x = self.conv_post(x) 81 | fmap.append(x) 82 | x = torch.flatten(x, 1, -1) 83 | return x, fmap 84 | 85 | 86 | class DiscriminatorP(torch.nn.Module): 87 | """ 88 | Discriminator for the long-term component. 89 | 90 | This class implements a discriminator for the long-term component 91 | of the audio signal. The discriminator is composed of a series of 92 | convolutional layers that are applied to the input signal at a given 93 | period. 94 | 95 | Args: 96 | period (int): Period of the discriminator. 97 | kernel_size (int): Kernel size of the convolutional layers. Defaults to 5. 98 | stride (int): Stride of the convolutional layers. Defaults to 3. 99 | use_spectral_norm (bool): Whether to use spectral normalization. Defaults to False. 
100 | """ 101 | 102 | def __init__( 103 | self, 104 | period: int, 105 | kernel_size: int = 5, 106 | stride: int = 3, 107 | use_spectral_norm: bool = False, 108 | ): 109 | super().__init__() 110 | self.period = period 111 | norm_f = spectral_norm if use_spectral_norm else weight_norm 112 | 113 | in_channels = [1, 32, 128, 512, 1024] 114 | out_channels = [32, 128, 512, 1024, 1024] 115 | 116 | self.convs = torch.nn.ModuleList( 117 | [ 118 | norm_f( 119 | torch.nn.Conv2d( 120 | in_ch, 121 | out_ch, 122 | (kernel_size, 1), 123 | (stride, 1), 124 | padding=(get_padding(kernel_size, 1), 0), 125 | ) 126 | ) 127 | for in_ch, out_ch in zip(in_channels, out_channels) 128 | ] 129 | ) 130 | 131 | self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 132 | self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE) 133 | 134 | def forward(self, x): 135 | fmap = [] 136 | b, c, t = x.shape 137 | if t % self.period != 0: 138 | n_pad = self.period - (t % self.period) 139 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect") 140 | x = x.view(b, c, -1, self.period) 141 | 142 | for conv in self.convs: 143 | x = self.lrelu(conv(x)) 144 | fmap.append(x) 145 | x = self.conv_post(x) 146 | fmap.append(x) 147 | x = torch.flatten(x, 1, -1) 148 | return x, fmap 149 | -------------------------------------------------------------------------------- /rvc/lib/algorithm/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Optional 4 | 5 | from rvc.lib.algorithm.commons import sequence_mask 6 | from rvc.lib.algorithm.modules import WaveNet 7 | from rvc.lib.algorithm.normalization import LayerNorm 8 | from rvc.lib.algorithm.attentions import FFN, MultiHeadAttention 9 | 10 | 11 | class Encoder(torch.nn.Module): 12 | """ 13 | Encoder module for the Transformer model. 14 | 15 | Args: 16 | hidden_channels (int): Number of hidden channels in the encoder. 17 | filter_channels (int): Number of filter channels in the feed-forward network. 18 | n_heads (int): Number of attention heads. 19 | n_layers (int): Number of encoder layers. 20 | kernel_size (int, optional): Kernel size of the convolution layers in the feed-forward network. Defaults to 1. 21 | p_dropout (float, optional): Dropout probability. Defaults to 0.0. 22 | window_size (int, optional): Window size for relative positional encoding. Defaults to 10. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | hidden_channels: int, 28 | filter_channels: int, 29 | n_heads: int, 30 | n_layers: int, 31 | kernel_size: int = 1, 32 | p_dropout: float = 0.0, 33 | window_size: int = 10, 34 | ): 35 | super().__init__() 36 | 37 | self.hidden_channels = hidden_channels 38 | self.n_layers = n_layers 39 | self.drop = torch.nn.Dropout(p_dropout) 40 | 41 | self.attn_layers = torch.nn.ModuleList( 42 | [ 43 | MultiHeadAttention( 44 | hidden_channels, 45 | hidden_channels, 46 | n_heads, 47 | p_dropout=p_dropout, 48 | window_size=window_size, 49 | ) 50 | for _ in range(n_layers) 51 | ] 52 | ) 53 | self.norm_layers_1 = torch.nn.ModuleList( 54 | [LayerNorm(hidden_channels) for _ in range(n_layers)] 55 | ) 56 | self.ffn_layers = torch.nn.ModuleList( 57 | [ 58 | FFN( 59 | hidden_channels, 60 | hidden_channels, 61 | filter_channels, 62 | kernel_size, 63 | p_dropout=p_dropout, 64 | ) 65 | for _ in range(n_layers) 66 | ] 67 | ) 68 | self.norm_layers_2 = torch.nn.ModuleList( 69 | [LayerNorm(hidden_channels) for _ in range(n_layers)] 70 | ) 71 | 72 | def forward(self, x, x_mask): 73 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 74 | x = x * x_mask 75 | 76 | for i in range(self.n_layers): 77 | y = self.attn_layers[i](x, x, attn_mask) 78 | y = self.drop(y) 79 | x = self.norm_layers_1[i](x + y) 80 | 81 | y = self.ffn_layers[i](x, x_mask) 82 | y = self.drop(y) 83 | x = self.norm_layers_2[i](x + y) 84 | 85 | return x * x_mask 86 | 87 | 88 | class TextEncoder(torch.nn.Module): 89 | """ 90 | Text Encoder with configurable embedding dimension. 91 | 92 | Args: 93 | out_channels (int): Output channels of the encoder. 94 | hidden_channels (int): Hidden channels of the encoder. 95 | filter_channels (int): Filter channels of the encoder. 96 | n_heads (int): Number of attention heads. 97 | n_layers (int): Number of encoder layers. 98 | kernel_size (int): Kernel size of the convolutional layers. 99 | p_dropout (float): Dropout probability. 100 | embedding_dim (int): Embedding dimension for phone embeddings (v1 = 256, v2 = 768). 101 | f0 (bool, optional): Whether to use F0 embedding. Defaults to True. 
102 | """ 103 | 104 | def __init__( 105 | self, 106 | out_channels: int, 107 | hidden_channels: int, 108 | filter_channels: int, 109 | n_heads: int, 110 | n_layers: int, 111 | kernel_size: int, 112 | p_dropout: float, 113 | embedding_dim: int, 114 | f0: bool = True, 115 | ): 116 | super().__init__() 117 | self.hidden_channels = hidden_channels 118 | self.out_channels = out_channels 119 | self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels) 120 | self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True) 121 | self.emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None 122 | 123 | self.encoder = Encoder( 124 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout 125 | ) 126 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1) 127 | 128 | def forward( 129 | self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor 130 | ): 131 | x = self.emb_phone(phone) 132 | if pitch is not None and self.emb_pitch: 133 | x += self.emb_pitch(pitch) 134 | 135 | x *= math.sqrt(self.hidden_channels) 136 | x = self.lrelu(x) 137 | x = x.transpose(1, -1) # [B, H, T] 138 | 139 | x_mask = sequence_mask(lengths, x.size(2)).unsqueeze(1).to(x.dtype) 140 | x = self.encoder(x, x_mask) 141 | stats = self.proj(x) * x_mask 142 | 143 | m, logs = torch.split(stats, self.out_channels, dim=1) 144 | return m, logs, x_mask 145 | 146 | 147 | class PosteriorEncoder(torch.nn.Module): 148 | """ 149 | Posterior Encoder for inferring latent representation. 150 | 151 | Args: 152 | in_channels (int): Number of channels in the input. 153 | out_channels (int): Number of channels in the output. 154 | hidden_channels (int): Number of hidden channels in the encoder. 155 | kernel_size (int): Kernel size of the convolutional layers. 156 | dilation_rate (int): Dilation rate of the convolutional layers. 157 | n_layers (int): Number of layers in the encoder. 158 | gin_channels (int, optional): Number of channels for the global conditioning input. Defaults to 0. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels: int, 164 | out_channels: int, 165 | hidden_channels: int, 166 | kernel_size: int, 167 | dilation_rate: int, 168 | n_layers: int, 169 | gin_channels: int = 0, 170 | ): 171 | super().__init__() 172 | self.out_channels = out_channels 173 | self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1) 174 | self.enc = WaveNet( 175 | hidden_channels, 176 | kernel_size, 177 | dilation_rate, 178 | n_layers, 179 | gin_channels=gin_channels, 180 | ) 181 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1) 182 | 183 | def forward( 184 | self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None 185 | ): 186 | x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype) 187 | 188 | x = self.pre(x) * x_mask 189 | x = self.enc(x, x_mask, g=g) 190 | 191 | stats = self.proj(x) * x_mask 192 | m, logs = torch.split(stats, self.out_channels, dim=1) 193 | 194 | z = m + torch.randn_like(m) * torch.exp(logs) 195 | z *= x_mask 196 | 197 | return z, m, logs, x_mask 198 | 199 | def remove_weight_norm(self): 200 | self.enc.remove_weight_norm() 201 | 202 | def __prepare_scriptable__(self): 203 | for hook in self.enc._forward_pre_hooks.values(): 204 | if ( 205 | hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" 206 | and hook.__class__.__name__ == "WeightNorm" 207 | ): 208 | torch.nn.utils.remove_weight_norm(self.enc) 209 | return self 210 | -------------------------------------------------------------------------------- /rvc/lib/algorithm/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/lib/algorithm/generators/__init__.py -------------------------------------------------------------------------------- /rvc/lib/algorithm/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rvc.lib.algorithm.commons import fused_add_tanh_sigmoid_multiply 3 | 4 | 5 | class WaveNet(torch.nn.Module): 6 | """ 7 | WaveNet residual blocks as used in WaveGlow. 8 | 9 | Args: 10 | hidden_channels (int): Number of hidden channels. 11 | kernel_size (int): Size of the convolutional kernel. 12 | dilation_rate (int): Dilation rate of the convolution. 13 | n_layers (int): Number of convolutional layers. 14 | gin_channels (int, optional): Number of conditioning channels. Defaults to 0. 15 | p_dropout (float, optional): Dropout probability. Defaults to 0. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | hidden_channels: int, 21 | kernel_size: int, 22 | dilation_rate, 23 | n_layers: int, 24 | gin_channels: int = 0, 25 | p_dropout: int = 0, 26 | ): 27 | super().__init__() 28 | assert kernel_size % 2 == 1, "Kernel size must be odd for proper padding." 
29 | 30 | self.hidden_channels = hidden_channels 31 | self.kernel_size = (kernel_size,) 32 | self.dilation_rate = dilation_rate 33 | self.n_layers = n_layers 34 | self.gin_channels = gin_channels 35 | self.p_dropout = p_dropout 36 | self.n_channels_tensor = torch.IntTensor([hidden_channels]) # Static tensor 37 | 38 | self.in_layers = torch.nn.ModuleList() 39 | self.res_skip_layers = torch.nn.ModuleList() 40 | self.drop = torch.nn.Dropout(p_dropout) 41 | 42 | # Conditional layer for global conditioning 43 | if gin_channels: 44 | self.cond_layer = torch.nn.utils.parametrizations.weight_norm( 45 | torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), 46 | name="weight", 47 | ) 48 | 49 | # Precompute dilations and paddings 50 | dilations = [dilation_rate**i for i in range(n_layers)] 51 | paddings = [(kernel_size * d - d) // 2 for d in dilations] 52 | 53 | # Initialize layers 54 | for i in range(n_layers): 55 | self.in_layers.append( 56 | torch.nn.utils.parametrizations.weight_norm( 57 | torch.nn.Conv1d( 58 | hidden_channels, 59 | 2 * hidden_channels, 60 | kernel_size, 61 | dilation=dilations[i], 62 | padding=paddings[i], 63 | ), 64 | name="weight", 65 | ) 66 | ) 67 | 68 | res_skip_channels = ( 69 | hidden_channels if i == n_layers - 1 else 2 * hidden_channels 70 | ) 71 | self.res_skip_layers.append( 72 | torch.nn.utils.parametrizations.weight_norm( 73 | torch.nn.Conv1d(hidden_channels, res_skip_channels, 1), 74 | name="weight", 75 | ) 76 | ) 77 | 78 | def forward(self, x, x_mask, g=None): 79 | output = x.clone().zero_() 80 | 81 | # Apply conditional layer if global conditioning is provided 82 | g = self.cond_layer(g) if g is not None else None 83 | 84 | for i in range(self.n_layers): 85 | x_in = self.in_layers[i](x) 86 | g_l = ( 87 | g[ 88 | :, 89 | i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, 90 | :, 91 | ] 92 | if g is not None 93 | else 0 94 | ) 95 | 96 | # Activation with fused Tanh-Sigmoid 97 | acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.n_channels_tensor) 98 | acts = self.drop(acts) 99 | 100 | # Residual and skip connections 101 | res_skip_acts = self.res_skip_layers[i](acts) 102 | if i < self.n_layers - 1: 103 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 104 | x = (x + res_acts) * x_mask 105 | output = output + res_skip_acts[:, self.hidden_channels :, :] 106 | else: 107 | output = output + res_skip_acts 108 | 109 | return output * x_mask 110 | 111 | def remove_weight_norm(self): 112 | if self.gin_channels: 113 | torch.nn.utils.remove_weight_norm(self.cond_layer) 114 | for layer in self.in_layers: 115 | torch.nn.utils.remove_weight_norm(layer) 116 | for layer in self.res_skip_layers: 117 | torch.nn.utils.remove_weight_norm(layer) 118 | -------------------------------------------------------------------------------- /rvc/lib/algorithm/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LayerNorm(torch.nn.Module): 5 | """ 6 | Layer normalization module. 7 | 8 | Args: 9 | channels (int): Number of channels. 10 | eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-5. 
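
    Example (illustrative; the channel count and length are arbitrary):

        norm = LayerNorm(192)
        x = torch.randn(2, 192, 50)   # (batch, channels, time)
        y = norm(x)                   # same shape, normalized over the channel axis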
11 | """ 12 | 13 | def __init__(self, channels: int, eps: float = 1e-5): 14 | super().__init__() 15 | self.eps = eps 16 | self.gamma = torch.nn.Parameter(torch.ones(channels)) 17 | self.beta = torch.nn.Parameter(torch.zeros(channels)) 18 | 19 | def forward(self, x): 20 | # Transpose to (batch_size, time_steps, channels) for layer_norm 21 | x = x.transpose(1, -1) 22 | x = torch.nn.functional.layer_norm( 23 | x, (x.size(-1),), self.gamma, self.beta, self.eps 24 | ) 25 | # Transpose back to (batch_size, channels, time_steps) 26 | return x.transpose(1, -1) 27 | -------------------------------------------------------------------------------- /rvc/lib/predictors/F0Extractor.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pathlib 3 | import librosa 4 | import numpy as np 5 | import resampy 6 | import torch 7 | import torchcrepe 8 | import torchfcpe 9 | import os 10 | 11 | # from tools.anyf0.rmvpe import RMVPE 12 | from rvc.lib.predictors.RMVPE import RMVPE0Predictor 13 | from rvc.configs.config import Config 14 | 15 | config = Config() 16 | 17 | 18 | @dataclasses.dataclass 19 | class F0Extractor: 20 | wav_path: pathlib.Path 21 | sample_rate: int = 44100 22 | hop_length: int = 512 23 | f0_min: int = 50 24 | f0_max: int = 1600 25 | method: str = "rmvpe" 26 | x: np.ndarray = dataclasses.field(init=False) 27 | 28 | def __post_init__(self): 29 | self.x, self.sample_rate = librosa.load(self.wav_path, sr=self.sample_rate) 30 | 31 | @property 32 | def hop_size(self): 33 | return self.hop_length / self.sample_rate 34 | 35 | @property 36 | def wav16k(self): 37 | return resampy.resample(self.x, self.sample_rate, 16000) 38 | 39 | def extract_f0(self): 40 | f0 = None 41 | method = self.method 42 | if method == "crepe": 43 | wav16k_torch = torch.FloatTensor(self.wav16k).unsqueeze(0).to(config.device) 44 | f0 = torchcrepe.predict( 45 | wav16k_torch, 46 | sample_rate=16000, 47 | hop_length=160, 48 | batch_size=512, 49 | fmin=self.f0_min, 50 | fmax=self.f0_max, 51 | device=config.device, 52 | ) 53 | f0 = f0[0].cpu().numpy() 54 | elif method == "fcpe": 55 | audio = librosa.to_mono(self.x) 56 | audio_length = len(audio) 57 | f0_target_length = (audio_length // self.hop_length) + 1 58 | audio = ( 59 | torch.from_numpy(audio) 60 | .float() 61 | .unsqueeze(0) 62 | .unsqueeze(-1) 63 | .to(config.device) 64 | ) 65 | model = torchfcpe.spawn_bundled_infer_model(device=config.device) 66 | 67 | f0 = model.infer( 68 | audio, 69 | sr=self.sample_rate, 70 | decoder_mode="local_argmax", 71 | threshold=0.006, 72 | f0_min=self.f0_min, 73 | f0_max=self.f0_max, 74 | interp_uv=False, 75 | output_interp_target_length=f0_target_length, 76 | ) 77 | f0 = f0.squeeze().cpu().numpy() 78 | elif method == "rmvpe": 79 | model_rmvpe = RMVPE0Predictor( 80 | os.path.join("rvc", "models", "predictors", "rmvpe.pt"), 81 | device=config.device, 82 | # hop_length=80 83 | ) 84 | f0 = model_rmvpe.infer_from_audio(self.wav16k, thred=0.03) 85 | 86 | else: 87 | raise ValueError(f"Unknown method: {self.method}") 88 | return self.hz_to_cents(f0, librosa.midi_to_hz(0)) 89 | 90 | def plot_f0(self, f0): 91 | from matplotlib import pyplot as plt 92 | 93 | plt.figure(figsize=(10, 4)) 94 | plt.plot(f0) 95 | plt.title(self.method) 96 | plt.xlabel("Time (frames)") 97 | plt.ylabel("F0 (cents)") 98 | plt.show() 99 | 100 | @staticmethod 101 | def hz_to_cents(F, F_ref=55.0): 102 | F_temp = np.array(F).astype(float) 103 | F_temp[F_temp == 0] = np.nan 104 | F_cents = 1200 * np.log2(F_temp / F_ref) 105 
| return F_cents 106 | -------------------------------------------------------------------------------- /rvc/lib/tools/analyzer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import librosa.display 4 | import librosa 5 | 6 | 7 | def calculate_features(y, sr): 8 | stft = np.abs(librosa.stft(y)) 9 | duration = librosa.get_duration(y=y, sr=sr) 10 | cent = librosa.feature.spectral_centroid(S=stft, sr=sr)[0] 11 | bw = librosa.feature.spectral_bandwidth(S=stft, sr=sr)[0] 12 | rolloff = librosa.feature.spectral_rolloff(S=stft, sr=sr)[0] 13 | return stft, duration, cent, bw, rolloff 14 | 15 | 16 | def plot_title(title): 17 | plt.suptitle(title, fontsize=16, fontweight="bold") 18 | 19 | 20 | def plot_spectrogram(y, sr, stft, duration, cmap="inferno"): 21 | plt.subplot(3, 1, 1) 22 | plt.imshow( 23 | librosa.amplitude_to_db(stft, ref=np.max), 24 | origin="lower", 25 | extent=[0, duration, 0, sr / 1000], 26 | aspect="auto", 27 | cmap=cmap, # Change the colormap here 28 | ) 29 | plt.colorbar(format="%+2.0f dB") 30 | plt.xlabel("Time (s)") 31 | plt.ylabel("Frequency (kHz)") 32 | plt.title("Spectrogram") 33 | 34 | 35 | def plot_waveform(y, sr, duration): 36 | plt.subplot(3, 1, 2) 37 | librosa.display.waveshow(y, sr=sr) 38 | plt.xlabel("Time (s)") 39 | plt.ylabel("Amplitude") 40 | plt.title("Waveform") 41 | 42 | 43 | def plot_features(times, cent, bw, rolloff, duration): 44 | plt.subplot(3, 1, 3) 45 | plt.plot(times, cent, label="Spectral Centroid (kHz)", color="b") 46 | plt.plot(times, bw, label="Spectral Bandwidth (kHz)", color="g") 47 | plt.plot(times, rolloff, label="Spectral Rolloff (kHz)", color="r") 48 | plt.xlabel("Time (s)") 49 | plt.title("Spectral Features") 50 | plt.legend() 51 | 52 | 53 | def analyze_audio(audio_file, save_plot_path="logs/audio_analysis.png"): 54 | y, sr = librosa.load(audio_file) 55 | stft, duration, cent, bw, rolloff = calculate_features(y, sr) 56 | 57 | plt.figure(figsize=(12, 10)) 58 | 59 | plot_title("Audio Analysis" + " - " + audio_file.split("/")[-1]) 60 | plot_spectrogram(y, sr, stft, duration) 61 | plot_waveform(y, sr, duration) 62 | plot_features(librosa.times_like(cent), cent, bw, rolloff, duration) 63 | 64 | plt.tight_layout() 65 | 66 | if save_plot_path: 67 | plt.savefig(save_plot_path, bbox_inches="tight", dpi=300) 68 | plt.close() 69 | 70 | audio_info = f"""Sample Rate: {sr}\nDuration: {( 71 | str(round(duration, 2)) + " seconds" 72 | if duration < 60 73 | else str(round(duration / 60, 2)) + " minutes" 74 | )}\nNumber of Samples: {len(y)}\nBits per Sample: {librosa.get_samplerate(audio_file)}\nChannels: {"Mono (1)" if y.ndim == 1 else "Stereo (2)"}""" 75 | 76 | return audio_info, save_plot_path 77 | -------------------------------------------------------------------------------- /rvc/lib/tools/launch_tensorboard.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from tensorboard import program 4 | 5 | log_path = "logs" 6 | 7 | 8 | def launch_tensorboard_pipeline(): 9 | logging.getLogger("root").setLevel(logging.WARNING) 10 | logging.getLogger("tensorboard").setLevel(logging.WARNING) 11 | 12 | tb = program.TensorBoard() 13 | tb.configure(argv=[None, "--logdir", log_path]) 14 | url = tb.launch() 15 | 16 | print( 17 | f"Access the tensorboard using the following 
link:\n{url}?pinnedCards=%5B%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Ftotal%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fd%2Ftotal%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Fkl%22%7D%2C%7B%22plugin%22%3A%22scalars%22%2C%22tag%22%3A%22loss%2Fg%2Fmel%22%7D%5D" 18 | ) 19 | 20 | while True: 21 | time.sleep(600) 22 | -------------------------------------------------------------------------------- /rvc/lib/tools/model_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import shutil 5 | import zipfile 6 | import requests 7 | from bs4 import BeautifulSoup 8 | from urllib.parse import unquote 9 | from tqdm import tqdm 10 | 11 | now_dir = os.getcwd() 12 | sys.path.append(now_dir) 13 | 14 | from rvc.lib.utils import format_title 15 | from rvc.lib.tools import gdown 16 | 17 | 18 | file_path = os.path.join(now_dir, "logs") 19 | zips_path = os.path.join(file_path, "zips") 20 | os.makedirs(zips_path, exist_ok=True) 21 | 22 | 23 | def search_pth_index(folder): 24 | pth_paths = [ 25 | os.path.join(folder, file) 26 | for file in os.listdir(folder) 27 | if os.path.isfile(os.path.join(folder, file)) and file.endswith(".pth") 28 | ] 29 | index_paths = [ 30 | os.path.join(folder, file) 31 | for file in os.listdir(folder) 32 | if os.path.isfile(os.path.join(folder, file)) and file.endswith(".index") 33 | ] 34 | return pth_paths, index_paths 35 | 36 | 37 | def download_from_url(url): 38 | os.chdir(zips_path) 39 | 40 | try: 41 | if "drive.google.com" in url: 42 | file_id = extract_google_drive_id(url) 43 | if file_id: 44 | gdown.download( 45 | url=f"https://drive.google.com/uc?id={file_id}", 46 | quiet=False, 47 | fuzzy=True, 48 | ) 49 | elif "/blob/" in url or "/resolve/" in url: 50 | download_blob_or_resolve(url) 51 | elif "/tree/main" in url: 52 | download_from_huggingface(url) 53 | else: 54 | download_file(url) 55 | 56 | rename_downloaded_files() 57 | return "downloaded" 58 | except Exception as error: 59 | print(f"An error occurred downloading the file: {error}") 60 | return None 61 | finally: 62 | os.chdir(now_dir) 63 | 64 | 65 | def extract_google_drive_id(url): 66 | if "file/d/" in url: 67 | return url.split("file/d/")[1].split("/")[0] 68 | if "id=" in url: 69 | return url.split("id=")[1].split("&")[0] 70 | return None 71 | 72 | 73 | def download_blob_or_resolve(url): 74 | if "/blob/" in url: 75 | url = url.replace("/blob/", "/resolve/") 76 | response = requests.get(url, stream=True) 77 | if response.status_code == 200: 78 | save_response_content(response) 79 | else: 80 | raise ValueError( 81 | "Download failed with status code: " + str(response.status_code) 82 | ) 83 | 84 | 85 | def save_response_content(response): 86 | content_disposition = unquote(response.headers.get("Content-Disposition", "")) 87 | file_name = ( 88 | re.search(r'filename="([^"]+)"', content_disposition) 89 | .groups()[0] 90 | .replace(os.path.sep, "_") 91 | if content_disposition 92 | else "downloaded_file" 93 | ) 94 | 95 | total_size = int(response.headers.get("Content-Length", 0)) 96 | chunk_size = 1024 97 | 98 | with open(os.path.join(zips_path, file_name), "wb") as file, tqdm( 99 | total=total_size, unit="B", unit_scale=True, desc=file_name 100 | ) as progress_bar: 101 | for data in response.iter_content(chunk_size): 102 | file.write(data) 103 | progress_bar.update(len(data)) 104 | 105 | 106 | def download_from_huggingface(url): 107 | response = requests.get(url) 108 | soup = 
BeautifulSoup(response.content, "html.parser") 109 | temp_url = next( 110 | ( 111 | link["href"] 112 | for link in soup.find_all("a", href=True) 113 | if link["href"].endswith(".zip") 114 | ), 115 | None, 116 | ) 117 | if temp_url: 118 | url = temp_url.replace("blob", "resolve") 119 | if "huggingface.co" not in url: 120 | url = "https://huggingface.co" + url 121 | download_file(url) 122 | else: 123 | raise ValueError("No zip file found in Huggingface URL") 124 | 125 | 126 | def download_file(url): 127 | response = requests.get(url, stream=True) 128 | if response.status_code == 200: 129 | save_response_content(response) 130 | else: 131 | raise ValueError( 132 | "Download failed with status code: " + str(response.status_code) 133 | ) 134 | 135 | 136 | def rename_downloaded_files(): 137 | for currentPath, _, zipFiles in os.walk(zips_path): 138 | for file in zipFiles: 139 | file_name, extension = os.path.splitext(file) 140 | real_path = os.path.join(currentPath, file) 141 | os.rename(real_path, file_name.replace(os.path.sep, "_") + extension) 142 | 143 | 144 | def extract(zipfile_path, unzips_path): 145 | try: 146 | with zipfile.ZipFile(zipfile_path, "r") as zip_ref: 147 | zip_ref.extractall(unzips_path) 148 | os.remove(zipfile_path) 149 | return True 150 | except Exception as error: 151 | print(f"An error occurred extracting the zip file: {error}") 152 | return False 153 | 154 | 155 | def unzip_file(zip_path, zip_file_name): 156 | zip_file_path = os.path.join(zip_path, zip_file_name + ".zip") 157 | extract_path = os.path.join(file_path, zip_file_name) 158 | with zipfile.ZipFile(zip_file_path, "r") as zip_ref: 159 | zip_ref.extractall(extract_path) 160 | os.remove(zip_file_path) 161 | 162 | 163 | def model_download_pipeline(url: str): 164 | try: 165 | result = download_from_url(url) 166 | if result == "downloaded": 167 | return handle_extraction_process() 168 | else: 169 | return "Error" 170 | except Exception as error: 171 | print(f"An unexpected error occurred: {error}") 172 | return "Error" 173 | 174 | 175 | def handle_extraction_process(): 176 | extract_folder_path = "" 177 | for filename in os.listdir(zips_path): 178 | if filename.endswith(".zip"): 179 | zipfile_path = os.path.join(zips_path, filename) 180 | model_name = format_title(os.path.basename(zipfile_path).split(".zip")[0]) 181 | extract_folder_path = os.path.join("logs", os.path.normpath(model_name)) 182 | success = extract(zipfile_path, extract_folder_path) 183 | clean_extracted_files(extract_folder_path, model_name) 184 | 185 | if success: 186 | print(f"Model {model_name} downloaded!") 187 | else: 188 | print(f"Error downloading {model_name}") 189 | return "Error" 190 | if not extract_folder_path: 191 | print("Zip file was not found.") 192 | return "Error" 193 | return search_pth_index(extract_folder_path) 194 | 195 | 196 | def clean_extracted_files(extract_folder_path, model_name): 197 | macosx_path = os.path.join(extract_folder_path, "__MACOSX") 198 | if os.path.exists(macosx_path): 199 | shutil.rmtree(macosx_path) 200 | 201 | subfolders = [ 202 | f 203 | for f in os.listdir(extract_folder_path) 204 | if os.path.isdir(os.path.join(extract_folder_path, f)) 205 | ] 206 | if len(subfolders) == 1: 207 | subfolder_path = os.path.join(extract_folder_path, subfolders[0]) 208 | for item in os.listdir(subfolder_path): 209 | shutil.move( 210 | os.path.join(subfolder_path, item), 211 | os.path.join(extract_folder_path, item), 212 | ) 213 | os.rmdir(subfolder_path) 214 | 215 | for item in os.listdir(extract_folder_path): 216 | 
source_path = os.path.join(extract_folder_path, item) 217 | if ".pth" in item: 218 | new_file_name = model_name + ".pth" 219 | elif ".index" in item: 220 | new_file_name = model_name + ".index" 221 | else: 222 | continue 223 | 224 | destination_path = os.path.join(extract_folder_path, new_file_name) 225 | if not os.path.exists(destination_path): 226 | os.rename(source_path, destination_path) 227 | -------------------------------------------------------------------------------- /rvc/lib/tools/prerequisites_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | from tqdm import tqdm 4 | import requests 5 | 6 | url_base = "https://huggingface.co/IAHispano/Applio/resolve/main/Resources" 7 | 8 | pretraineds_hifigan_list = [ 9 | ( 10 | "pretrained_v2/", 11 | [ 12 | "f0D32k.pth", 13 | "f0D40k.pth", 14 | "f0D48k.pth", 15 | "f0G32k.pth", 16 | "f0G40k.pth", 17 | "f0G48k.pth", 18 | ], 19 | ) 20 | ] 21 | models_list = [("predictors/", ["rmvpe.pt", "fcpe.pt"])] 22 | embedders_list = [("embedders/contentvec/", ["pytorch_model.bin", "config.json"])] 23 | executables_list = [ 24 | ("", ["ffmpeg.exe", "ffprobe.exe"]), 25 | ] 26 | 27 | folder_mapping_list = { 28 | "pretrained_v2/": "rvc/models/pretraineds/hifi-gan/", 29 | "embedders/contentvec/": "rvc/models/embedders/contentvec/", 30 | "predictors/": "rvc/models/predictors/", 31 | "formant/": "rvc/models/formant/", 32 | } 33 | 34 | 35 | def get_file_size_if_missing(file_list): 36 | """ 37 | Calculate the total size of files to be downloaded only if they do not exist locally. 38 | """ 39 | total_size = 0 40 | for remote_folder, files in file_list: 41 | local_folder = folder_mapping_list.get(remote_folder, "") 42 | for file in files: 43 | destination_path = os.path.join(local_folder, file) 44 | if not os.path.exists(destination_path): 45 | url = f"{url_base}/{remote_folder}{file}" 46 | response = requests.head(url) 47 | total_size += int(response.headers.get("content-length", 0)) 48 | return total_size 49 | 50 | 51 | def download_file(url, destination_path, global_bar): 52 | """ 53 | Download a file from the given URL to the specified destination path, 54 | updating the global progress bar as data is downloaded. 55 | """ 56 | 57 | dir_name = os.path.dirname(destination_path) 58 | if dir_name: 59 | os.makedirs(dir_name, exist_ok=True) 60 | response = requests.get(url, stream=True) 61 | block_size = 1024 62 | with open(destination_path, "wb") as file: 63 | for data in response.iter_content(block_size): 64 | file.write(data) 65 | global_bar.update(len(data)) 66 | 67 | 68 | def download_mapping_files(file_mapping_list, global_bar): 69 | """ 70 | Download all files in the provided file mapping list using a thread pool executor, 71 | and update the global progress bar as downloads progress. 
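
    Example (illustrative sketch; it reuses the module-level models_list mapping,
    assumes network access, and sizes the bar with get_file_size_if_missing):

        with tqdm(total=get_file_size_if_missing(models_list), unit="iB",
                  unit_scale=True, desc="Downloading predictors") as bar:
            download_mapping_files(models_list, bar)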
72 | """ 73 | with ThreadPoolExecutor() as executor: 74 | futures = [] 75 | for remote_folder, file_list in file_mapping_list: 76 | local_folder = folder_mapping_list.get(remote_folder, "") 77 | for file in file_list: 78 | destination_path = os.path.join(local_folder, file) 79 | if not os.path.exists(destination_path): 80 | url = f"{url_base}/{remote_folder}{file}" 81 | futures.append( 82 | executor.submit( 83 | download_file, url, destination_path, global_bar 84 | ) 85 | ) 86 | for future in futures: 87 | future.result() 88 | 89 | 90 | def split_pretraineds(pretrained_list): 91 | f0_list = [] 92 | non_f0_list = [] 93 | for folder, files in pretrained_list: 94 | f0_files = [f for f in files if f.startswith("f0")] 95 | non_f0_files = [f for f in files if not f.startswith("f0")] 96 | if f0_files: 97 | f0_list.append((folder, f0_files)) 98 | if non_f0_files: 99 | non_f0_list.append((folder, non_f0_files)) 100 | return f0_list, non_f0_list 101 | 102 | 103 | pretraineds_hifigan_list, _ = split_pretraineds(pretraineds_hifigan_list) 104 | 105 | 106 | def calculate_total_size( 107 | pretraineds_hifigan, 108 | models, 109 | exe, 110 | ): 111 | """ 112 | Calculate the total size of all files to be downloaded based on selected categories. 113 | """ 114 | total_size = 0 115 | if models: 116 | total_size += get_file_size_if_missing(models_list) 117 | total_size += get_file_size_if_missing(embedders_list) 118 | if exe and os.name == "nt": 119 | total_size += get_file_size_if_missing(executables_list) 120 | total_size += get_file_size_if_missing(pretraineds_hifigan) 121 | return total_size 122 | 123 | 124 | def prequisites_download_pipeline( 125 | pretraineds_hifigan, 126 | models, 127 | exe, 128 | ): 129 | """ 130 | Manage the download pipeline for different categories of files. 131 | """ 132 | total_size = calculate_total_size( 133 | pretraineds_hifigan_list if pretraineds_hifigan else [], 134 | models, 135 | exe, 136 | ) 137 | 138 | if total_size > 0: 139 | with tqdm( 140 | total=total_size, unit="iB", unit_scale=True, desc="Downloading all files" 141 | ) as global_bar: 142 | if models: 143 | download_mapping_files(models_list, global_bar) 144 | download_mapping_files(embedders_list, global_bar) 145 | if exe: 146 | if os.name == "nt": 147 | download_mapping_files(executables_list, global_bar) 148 | else: 149 | print("No executables needed") 150 | if pretraineds_hifigan: 151 | download_mapping_files(pretraineds_hifigan_list, global_bar) 152 | else: 153 | pass 154 | -------------------------------------------------------------------------------- /rvc/lib/tools/pretrained_selector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def pretrained_selector(vocoder, sample_rate): 5 | base_path = os.path.join("rvc", "models", "pretraineds", f"{vocoder.lower()}") 6 | 7 | path_g = os.path.join(base_path, f"f0G{str(sample_rate)[:2]}k.pth") 8 | path_d = os.path.join(base_path, f"f0D{str(sample_rate)[:2]}k.pth") 9 | 10 | if os.path.exists(path_g) and os.path.exists(path_d): 11 | return path_g, path_d 12 | else: 13 | return "", "" 14 | -------------------------------------------------------------------------------- /rvc/lib/tools/split_audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | 4 | 5 | def process_audio(audio, sr=16000, silence_thresh=-60, min_silence_len=250): 6 | """ 7 | Splits an audio signal into segments using a fixed frame size and hop size. 
8 | 9 | Parameters: 10 | - audio (np.ndarray): The audio signal to split. 11 | - sr (int): The sample rate of the input audio (default is 16000). 12 | - silence_thresh (int): Silence threshold (default =-60dB) 13 | - min_silence_len (int): Minimum silence duration (default 250ms). 14 | 15 | Returns: 16 | - list of np.ndarray: A list of audio segments. 17 | - np.ndarray: The intervals where the audio was split. 18 | """ 19 | frame_length = int(min_silence_len / 1000 * sr) 20 | hop_length = frame_length // 2 21 | intervals = librosa.effects.split( 22 | audio, top_db=-silence_thresh, frame_length=frame_length, hop_length=hop_length 23 | ) 24 | audio_segments = [audio[start:end] for start, end in intervals] 25 | 26 | return audio_segments, intervals 27 | 28 | 29 | def merge_audio(audio_segments_org, audio_segments_new, intervals, sr_orig, sr_new): 30 | """ 31 | Merges audio segments back into a single audio signal, filling gaps with silence. 32 | Assumes audio segments are already at sr_new. 33 | 34 | Parameters: 35 | - audio_segments_org (list of np.ndarray): The non-silent audio segments (at sr_orig). 36 | - audio_segments_new (list of np.ndarray): The non-silent audio segments (at sr_new). 37 | - intervals (np.ndarray): The intervals used for splitting the original audio. 38 | - sr_orig (int): The sample rate of the original audio 39 | - sr_new (int): The sample rate of the model 40 | Returns: 41 | - np.ndarray: The merged audio signal with silent gaps restored. 42 | """ 43 | merged_audio = np.array([], dtype=audio_segments_new[0].dtype) 44 | sr_ratio = sr_new / sr_orig 45 | 46 | for i, (start, end) in enumerate(intervals): 47 | 48 | start_new = int(start * sr_ratio) 49 | end_new = int(end * sr_ratio) 50 | 51 | original_duration = len(audio_segments_org[i]) / sr_orig 52 | new_duration = len(audio_segments_new[i]) / sr_new 53 | duration_diff = new_duration - original_duration 54 | 55 | silence_samples = int(abs(duration_diff) * sr_new) 56 | silence_compensation = np.zeros( 57 | silence_samples, dtype=audio_segments_new[0].dtype 58 | ) 59 | 60 | if i == 0 and start_new > 0: 61 | initial_silence = np.zeros(start_new, dtype=audio_segments_new[0].dtype) 62 | merged_audio = np.concatenate((merged_audio, initial_silence)) 63 | 64 | if duration_diff > 0: 65 | merged_audio = np.concatenate((merged_audio, silence_compensation)) 66 | 67 | merged_audio = np.concatenate((merged_audio, audio_segments_new[i])) 68 | 69 | if duration_diff < 0: 70 | merged_audio = np.concatenate((merged_audio, silence_compensation)) 71 | 72 | if i < len(intervals) - 1: 73 | next_start_new = int(intervals[i + 1][0] * sr_ratio) 74 | silence_duration = next_start_new - end_new 75 | if silence_duration > 0: 76 | silence = np.zeros(silence_duration, dtype=audio_segments_new[0].dtype) 77 | merged_audio = np.concatenate((merged_audio, silence)) 78 | 79 | return merged_audio 80 | -------------------------------------------------------------------------------- /rvc/lib/tools/tts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import asyncio 3 | import edge_tts 4 | import os 5 | 6 | 7 | async def main(): 8 | # Parse command line arguments 9 | tts_file = str(sys.argv[1]) 10 | text = str(sys.argv[2]) 11 | voice = str(sys.argv[3]) 12 | rate = int(sys.argv[4]) 13 | output_file = str(sys.argv[5]) 14 | 15 | rates = f"+{rate}%" if rate >= 0 else f"{rate}%" 16 | if tts_file and os.path.exists(tts_file): 17 | text = "" 18 | try: 19 | with open(tts_file, "r", encoding="utf-8") as file: 20 
| text = file.read() 21 | except UnicodeDecodeError: 22 | with open(tts_file, "r") as file: 23 | text = file.read() 24 | await edge_tts.Communicate(text, voice, rate=rates).save(output_file) 25 | # print(f"TTS with {voice} completed. Output TTS file: '{output_file}'") 26 | 27 | 28 | if __name__ == "__main__": 29 | asyncio.run(main()) 30 | -------------------------------------------------------------------------------- /rvc/lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import soxr 4 | import librosa 5 | import soundfile as sf 6 | import numpy as np 7 | import re 8 | import unicodedata 9 | import wget 10 | from torch import nn 11 | 12 | import logging 13 | from transformers import HubertModel 14 | import warnings 15 | 16 | # Remove this to see warnings about transformers models 17 | warnings.filterwarnings("ignore") 18 | 19 | logging.getLogger("fairseq").setLevel(logging.ERROR) 20 | logging.getLogger("faiss.loader").setLevel(logging.ERROR) 21 | logging.getLogger("transformers").setLevel(logging.ERROR) 22 | logging.getLogger("torch").setLevel(logging.ERROR) 23 | 24 | now_dir = os.getcwd() 25 | sys.path.append(now_dir) 26 | 27 | base_path = os.path.join(now_dir, "rvc", "models", "formant", "stftpitchshift") 28 | stft = base_path + ".exe" if sys.platform == "win32" else base_path 29 | 30 | 31 | class HubertModelWithFinalProj(HubertModel): 32 | def __init__(self, config): 33 | super().__init__(config) 34 | self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) 35 | 36 | 37 | def load_audio(file, sample_rate): 38 | try: 39 | file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") 40 | audio, sr = sf.read(file) 41 | if len(audio.shape) > 1: 42 | audio = librosa.to_mono(audio.T) 43 | if sr != sample_rate: 44 | audio = librosa.resample( 45 | audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq" 46 | ) 47 | except Exception as error: 48 | raise RuntimeError(f"An error occurred loading the audio: {error}") 49 | 50 | return audio.flatten() 51 | 52 | 53 | def load_audio_infer( 54 | file, 55 | sample_rate, 56 | **kwargs, 57 | ): 58 | formant_shifting = kwargs.get("formant_shifting", False) 59 | try: 60 | file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") 61 | if not os.path.isfile(file): 62 | raise FileNotFoundError(f"File not found: {file}") 63 | audio, sr = sf.read(file) 64 | if len(audio.shape) > 1: 65 | audio = librosa.to_mono(audio.T) 66 | if sr != sample_rate: 67 | audio = librosa.resample( 68 | audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq" 69 | ) 70 | if formant_shifting: 71 | formant_qfrency = kwargs.get("formant_qfrency", 0.8) 72 | formant_timbre = kwargs.get("formant_timbre", 0.8) 73 | 74 | from stftpitchshift import StftPitchShift 75 | 76 | pitchshifter = StftPitchShift(1024, 32, sample_rate) 77 | audio = pitchshifter.shiftpitch( 78 | audio, 79 | factors=1, 80 | quefrency=formant_qfrency * 1e-3, 81 | distortion=formant_timbre, 82 | ) 83 | except Exception as error: 84 | raise RuntimeError(f"An error occurred loading the audio: {error}") 85 | return np.array(audio).flatten() 86 | 87 | 88 | def format_title(title): 89 | formatted_title = unicodedata.normalize("NFC", title) 90 | formatted_title = re.sub(r"[\u2500-\u257F]+", "", formatted_title) 91 | formatted_title = re.sub(r"[^\w\s.-]", "", formatted_title, flags=re.UNICODE) 92 | formatted_title = re.sub(r"\s+", "_", formatted_title) 93 | return formatted_title 94 | 95 | 96 | def 
load_embedding(embedder_model, custom_embedder=None): 97 | embedder_root = os.path.join(now_dir, "rvc", "models", "embedders") 98 | embedding_list = { 99 | "contentvec": os.path.join(embedder_root, "contentvec"), 100 | "chinese-hubert-base": os.path.join(embedder_root, "chinese_hubert_base"), 101 | "japanese-hubert-base": os.path.join(embedder_root, "japanese_hubert_base"), 102 | "korean-hubert-base": os.path.join(embedder_root, "korean_hubert_base"), 103 | } 104 | 105 | online_embedders = { 106 | "contentvec": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/pytorch_model.bin", 107 | "chinese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/pytorch_model.bin", 108 | "japanese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/pytorch_model.bin", 109 | "korean-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/pytorch_model.bin", 110 | } 111 | 112 | config_files = { 113 | "contentvec": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/config.json", 114 | "chinese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/config.json", 115 | "japanese-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/config.json", 116 | "korean-hubert-base": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/config.json", 117 | } 118 | 119 | if embedder_model == "custom": 120 | if os.path.exists(custom_embedder): 121 | model_path = custom_embedder 122 | else: 123 | print(f"Custom embedder not found: {custom_embedder}, using contentvec") 124 | model_path = embedding_list["contentvec"] 125 | else: 126 | model_path = embedding_list[embedder_model] 127 | bin_file = os.path.join(model_path, "pytorch_model.bin") 128 | json_file = os.path.join(model_path, "config.json") 129 | os.makedirs(model_path, exist_ok=True) 130 | if not os.path.exists(bin_file): 131 | url = online_embedders[embedder_model] 132 | print(f"Downloading {url} to {model_path}...") 133 | wget.download(url, out=bin_file) 134 | if not os.path.exists(json_file): 135 | url = config_files[embedder_model] 136 | print(f"Downloading {url} to {model_path}...") 137 | wget.download(url, out=json_file) 138 | 139 | models = HubertModelWithFinalProj.from_pretrained(model_path) 140 | return models 141 | -------------------------------------------------------------------------------- /rvc/lib/zluda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"): 4 | 5 | class STFT: 6 | def __init__(self): 7 | self.device = "cuda" 8 | self.fourier_bases = {} # Cache for Fourier bases 9 | 10 | def _get_fourier_basis(self, n_fft): 11 | # Check if the basis for this n_fft is already cached 12 | if n_fft in self.fourier_bases: 13 | return self.fourier_bases[n_fft] 14 | fourier_basis = torch.fft.fft(torch.eye(n_fft, device="cpu")).to( 15 | self.device 16 | ) 17 | # stack separated real and imaginary components and convert to torch tensor 18 | cutoff = n_fft // 2 + 1 19 | fourier_basis = torch.cat( 20 | [fourier_basis.real[:cutoff], fourier_basis.imag[:cutoff]], dim=0 21 | ) 22 | # cache the tensor and return 23 | self.fourier_bases[n_fft] = fourier_basis 24 
| return fourier_basis 25 | 26 | def transform(self, input, n_fft, hop_length, window): 27 | # fetch cached Fourier basis 28 | fourier_basis = self._get_fourier_basis(n_fft) 29 | # apply hann window to Fourier basis 30 | fourier_basis = fourier_basis * window 31 | # pad input to center with reflect 32 | pad_amount = n_fft // 2 33 | input = torch.nn.functional.pad( 34 | input, (pad_amount, pad_amount), mode="reflect" 35 | ) 36 | # separate input into n_fft-sized frames 37 | input_frames = input.unfold(1, n_fft, hop_length).permute(0, 2, 1) 38 | # apply fft to each frame 39 | fourier_transform = torch.matmul(fourier_basis, input_frames) 40 | cutoff = n_fft // 2 + 1 41 | return torch.complex( 42 | fourier_transform[:, :cutoff, :], fourier_transform[:, cutoff:, :] 43 | ) 44 | 45 | stft = STFT() 46 | _torch_stft = torch.stft 47 | 48 | def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): 49 | # only optimizing a specific call from rvc.train.mel_processing.MultiScaleMelSpectrogramLoss 50 | if ( 51 | kwargs.get("win_length") == None 52 | and kwargs.get("center") == None 53 | and kwargs.get("return_complex") == True 54 | ): 55 | # use GPU accelerated calculation 56 | return stft.transform( 57 | input, kwargs.get("n_fft"), kwargs.get("hop_length"), window 58 | ) 59 | else: 60 | # simply do the operation on CPU 61 | return _torch_stft( 62 | input=input.cpu(), window=window.cpu(), *args, **kwargs 63 | ).to(input.device) 64 | 65 | def z_jit(f, *_, **__): 66 | f.graph = torch._C.Graph() 67 | return f 68 | 69 | # hijacks 70 | torch.stft = z_stft 71 | torch.jit.script = z_jit 72 | # disabling unsupported cudnn 73 | torch.backends.cudnn.enabled = False 74 | torch.backends.cuda.enable_flash_sdp(False) 75 | torch.backends.cuda.enable_math_sdp(True) 76 | torch.backends.cuda.enable_mem_efficient_sdp(False) 77 | -------------------------------------------------------------------------------- /rvc/models/embedders/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rvc/models/embedders/embedders_custom/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rvc/models/formant/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rvc/models/predictors/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/predictors/.gitkeep -------------------------------------------------------------------------------- /rvc/models/pretraineds/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/pretraineds/.gitkeep -------------------------------------------------------------------------------- /rvc/models/pretraineds/custom/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rvc/models/pretraineds/hifi-gan/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/rvc/models/pretraineds/hifi-gan/.gitkeep -------------------------------------------------------------------------------- /rvc/train/extract/preparing_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from random import shuffle 4 | from rvc.configs.config import Config 5 | import json 6 | 7 | config = Config() 8 | current_directory = os.getcwd() 9 | 10 | 11 | def generate_config(sample_rate: int, model_path: str): 12 | config_path = os.path.join("rvc", "configs", f"{sample_rate}.json") 13 | config_save_path = os.path.join(model_path, "config.json") 14 | if not os.path.exists(config_save_path): 15 | shutil.copyfile(config_path, config_save_path) 16 | 17 | 18 | def generate_filelist(model_path: str, sample_rate: int, include_mutes: int = 2): 19 | gt_wavs_dir = os.path.join(model_path, "sliced_audios") 20 | feature_dir = os.path.join(model_path, f"extracted") 21 | 22 | f0_dir, f0nsf_dir = None, None 23 | f0_dir = os.path.join(model_path, "f0") 24 | f0nsf_dir = os.path.join(model_path, "f0_voiced") 25 | 26 | gt_wavs_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)) 27 | feature_files = set(name.split(".")[0] for name in os.listdir(feature_dir)) 28 | 29 | f0_files = set(name.split(".")[0] for name in os.listdir(f0_dir)) 30 | f0nsf_files = set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) 31 | names = gt_wavs_files & feature_files & f0_files & f0nsf_files 32 | 33 | options = [] 34 | mute_base_path = os.path.join(current_directory, "logs", "mute") 35 | sids = [] 36 | for name in names: 37 | sid = name.split("_")[0] 38 | if sid not in sids: 39 | sids.append(sid) 40 | options.append( 41 | f"{os.path.join(gt_wavs_dir, name)}.wav|{os.path.join(feature_dir, name)}.npy|{os.path.join(f0_dir, name)}.wav.npy|{os.path.join(f0nsf_dir, name)}.wav.npy|{sid}" 42 | ) 43 | 44 | if include_mutes > 0: 45 | mute_audio_path = os.path.join( 46 | mute_base_path, "sliced_audios", f"mute{sample_rate}.wav" 47 | ) 48 | mute_feature_path = os.path.join(mute_base_path, f"extracted", "mute.npy") 49 | mute_f0_path = os.path.join(mute_base_path, "f0", "mute.wav.npy") 50 | mute_f0nsf_path = os.path.join(mute_base_path, "f0_voiced", "mute.wav.npy") 51 | 52 | # adding x files per sid 53 | for sid in sids * include_mutes: 54 | options.append( 55 | f"{mute_audio_path}|{mute_feature_path}|{mute_f0_path}|{mute_f0nsf_path}|{sid}" 56 | ) 57 | 58 | file_path = os.path.join(model_path, "model_info.json") 59 | if os.path.exists(file_path): 60 | with open(file_path, "r") as f: 61 | data = json.load(f) 62 | else: 63 | data = {} 64 | data.update( 65 | { 66 | "speakers_id": len(sids), 67 | } 68 | ) 69 | with open(file_path, "w") as f: 70 | json.dump(data, f, indent=4) 71 | 72 | shuffle(options) 73 | 74 | with open(os.path.join(model_path, "filelist.txt"), "w") as f: 75 | f.write("\n".join(options)) 76 | -------------------------------------------------------------------------------- /rvc/train/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def feature_loss(fmap_r, fmap_g): 5 | """ 6 | Compute the feature loss between reference and generated feature maps. 7 | 8 | Args: 9 | fmap_r (list of torch.Tensor): List of reference feature maps. 10 | fmap_g (list of torch.Tensor): List of generated feature maps. 
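
    Example (illustrative; the feature-map shapes are arbitrary):

        fmap_r = [[torch.randn(1, 32, 100)], [torch.randn(1, 64, 50)]]
        fmap_g = [[torch.randn(1, 32, 100)], [torch.randn(1, 64, 50)]]
        loss = feature_loss(fmap_r, fmap_g)   # scalar: 2 * mean L1 over all map pairs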
11 | """ 12 | return 2 * sum( 13 | torch.mean(torch.abs(rl - gl)) 14 | for dr, dg in zip(fmap_r, fmap_g) 15 | for rl, gl in zip(dr, dg) 16 | ) 17 | 18 | 19 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 20 | """ 21 | Compute the discriminator loss for real and generated outputs. 22 | 23 | Args: 24 | disc_real_outputs (list of torch.Tensor): List of discriminator outputs for real samples. 25 | disc_generated_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. 26 | """ 27 | loss = 0 28 | r_losses = [] 29 | g_losses = [] 30 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 31 | r_loss = torch.mean((1 - dr.float()) ** 2) 32 | g_loss = torch.mean(dg.float() ** 2) 33 | 34 | # r_losses.append(r_loss.item()) 35 | # g_losses.append(g_loss.item()) 36 | loss += r_loss + g_loss 37 | 38 | return loss, r_losses, g_losses 39 | 40 | 41 | def generator_loss(disc_outputs): 42 | """ 43 | Compute the generator loss based on discriminator outputs. 44 | 45 | Args: 46 | disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. 47 | """ 48 | loss = 0 49 | gen_losses = [] 50 | for dg in disc_outputs: 51 | l = torch.mean((1 - dg.float()) ** 2) 52 | # gen_losses.append(l.item()) 53 | loss += l 54 | 55 | return loss, gen_losses 56 | 57 | 58 | def discriminator_loss_scaled(disc_real, disc_fake, scale=1.0): 59 | loss = 0 60 | for i, (d_real, d_fake) in enumerate(zip(disc_real, disc_fake)): 61 | real_loss = torch.mean((1 - d_real) ** 2) 62 | fake_loss = torch.mean(d_fake**2) 63 | _loss = real_loss + fake_loss 64 | loss += _loss if i < len(disc_real) / 2 else scale * _loss 65 | return loss, None, None 66 | 67 | 68 | def generator_loss_scaled(disc_outputs, scale=1.0): 69 | loss = 0 70 | for i, d_fake in enumerate(disc_outputs): 71 | d_fake = d_fake.float() 72 | _loss = torch.mean((1 - d_fake) ** 2) 73 | loss += _loss if i < len(disc_outputs) / 2 else scale * _loss 74 | return loss, None, None 75 | 76 | 77 | def discriminator_loss_scaled(disc_real, disc_fake, scale=1.0): 78 | """ 79 | Compute the scaled discriminator loss for real and generated outputs. 80 | 81 | Args: 82 | disc_real (list of torch.Tensor): List of discriminator outputs for real samples. 83 | disc_fake (list of torch.Tensor): List of discriminator outputs for generated samples. 84 | scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0. 85 | """ 86 | midpoint = len(disc_real) // 2 87 | losses = [] 88 | for i, (d_real, d_fake) in enumerate(zip(disc_real, disc_fake)): 89 | real_loss = (1 - d_real).pow(2).mean() 90 | fake_loss = d_fake.pow(2).mean() 91 | total_loss = real_loss + fake_loss 92 | if i >= midpoint: 93 | total_loss *= scale 94 | losses.append(total_loss) 95 | loss = sum(losses) 96 | return loss, None, None 97 | 98 | 99 | def generator_loss_scaled(disc_outputs, scale=1.0): 100 | """ 101 | Compute the scaled generator loss based on discriminator outputs. 102 | 103 | Args: 104 | disc_outputs (list of torch.Tensor): List of discriminator outputs for generated samples. 105 | scale (float, optional): Scaling factor applied to losses beyond the midpoint. Default is 1.0. 
106 | """ 107 | midpoint = len(disc_outputs) // 2 108 | losses = [] 109 | for i, d_fake in enumerate(disc_outputs): 110 | loss_value = (1 - d_fake).pow(2).mean() 111 | if i >= midpoint: 112 | loss_value *= scale 113 | losses.append(loss_value) 114 | loss = sum(losses) 115 | return loss, None, None 116 | 117 | 118 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 119 | """ 120 | Compute the Kullback-Leibler divergence loss. 121 | 122 | Args: 123 | z_p (torch.Tensor): Latent variable z_p [b, h, t_t]. 124 | logs_q (torch.Tensor): Log variance of q [b, h, t_t]. 125 | m_p (torch.Tensor): Mean of p [b, h, t_t]. 126 | logs_p (torch.Tensor): Log variance of p [b, h, t_t]. 127 | z_mask (torch.Tensor): Mask for the latent variables [b, h, t_t]. 128 | """ 129 | kl = logs_p - logs_q - 0.5 + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2 * logs_p) 130 | kl = (kl * z_mask).sum() 131 | loss = kl / z_mask.sum() 132 | return loss 133 | -------------------------------------------------------------------------------- /rvc/train/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | 6 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 7 | """ 8 | Dynamic range compression using log10. 9 | 10 | Args: 11 | x (torch.Tensor): Input tensor. 12 | C (float, optional): Scaling factor. Defaults to 1. 13 | clip_val (float, optional): Minimum value for clamping. Defaults to 1e-5. 14 | """ 15 | return torch.log(torch.clamp(x, min=clip_val) * C) 16 | 17 | 18 | def dynamic_range_decompression_torch(x, C=1): 19 | """ 20 | Dynamic range decompression using exp. 21 | 22 | Args: 23 | x (torch.Tensor): Input tensor. 24 | C (float, optional): Scaling factor. Defaults to 1. 25 | """ 26 | return torch.exp(x) / C 27 | 28 | 29 | def spectral_normalize_torch(magnitudes): 30 | """ 31 | Spectral normalization using dynamic range compression. 32 | 33 | Args: 34 | magnitudes (torch.Tensor): Magnitude spectrogram. 35 | """ 36 | return dynamic_range_compression_torch(magnitudes) 37 | 38 | 39 | def spectral_de_normalize_torch(magnitudes): 40 | """ 41 | Spectral de-normalization using dynamic range decompression. 42 | 43 | Args: 44 | magnitudes (torch.Tensor): Normalized spectrogram. 45 | """ 46 | return dynamic_range_decompression_torch(magnitudes) 47 | 48 | 49 | mel_basis = {} 50 | hann_window = {} 51 | 52 | 53 | def spectrogram_torch(y, n_fft, hop_size, win_size, center=False): 54 | """ 55 | Compute the spectrogram of a signal using STFT. 56 | 57 | Args: 58 | y (torch.Tensor): Input signal. 59 | n_fft (int): FFT window size. 60 | hop_size (int): Hop size between frames. 61 | win_size (int): Window size. 62 | center (bool, optional): Whether to center the window. Defaults to False. 
63 | """ 64 | global hann_window 65 | dtype_device = str(y.dtype) + "_" + str(y.device) 66 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 67 | if wnsize_dtype_device not in hann_window: 68 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 69 | dtype=y.dtype, device=y.device 70 | ) 71 | 72 | y = torch.nn.functional.pad( 73 | y.unsqueeze(1), 74 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 75 | mode="reflect", 76 | ) 77 | y = y.squeeze(1) 78 | 79 | spec = torch.stft( 80 | y, 81 | n_fft=n_fft, 82 | hop_length=hop_size, 83 | win_length=win_size, 84 | window=hann_window[wnsize_dtype_device], 85 | center=center, 86 | pad_mode="reflect", 87 | normalized=False, 88 | onesided=True, 89 | return_complex=True, 90 | ) 91 | 92 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) 93 | 94 | return spec 95 | 96 | 97 | def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax): 98 | """ 99 | Convert a spectrogram to a mel-spectrogram. 100 | 101 | Args: 102 | spec (torch.Tensor): Magnitude spectrogram. 103 | n_fft (int): FFT window size. 104 | num_mels (int): Number of mel frequency bins. 105 | sample_rate (int): Sampling rate of the audio signal. 106 | fmin (float): Minimum frequency. 107 | fmax (float): Maximum frequency. 108 | """ 109 | global mel_basis 110 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 111 | fmax_dtype_device = str(fmax) + "_" + dtype_device 112 | if fmax_dtype_device not in mel_basis: 113 | mel = librosa_mel_fn( 114 | sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 115 | ) 116 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 117 | dtype=spec.dtype, device=spec.device 118 | ) 119 | 120 | melspec = torch.matmul(mel_basis[fmax_dtype_device], spec) 121 | melspec = spectral_normalize_torch(melspec) 122 | return melspec 123 | 124 | 125 | def mel_spectrogram_torch( 126 | y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False 127 | ): 128 | """ 129 | Compute the mel-spectrogram of a signal. 130 | 131 | Args: 132 | y (torch.Tensor): Input signal. 133 | n_fft (int): FFT window size. 134 | num_mels (int): Number of mel frequency bins. 135 | sample_rate (int): Sampling rate of the audio signal. 136 | hop_size (int): Hop size between frames. 137 | win_size (int): Window size. 138 | fmin (float): Minimum frequency. 139 | fmax (float): Maximum frequency. 140 | center (bool, optional): Whether to center the window. Defaults to False. 
141 | """ 142 | spec = spectrogram_torch(y, n_fft, hop_size, win_size, center) 143 | 144 | melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax) 145 | 146 | return melspec 147 | 148 | 149 | def compute_window_length(n_mels: int, sample_rate: int): 150 | f_min = 0 151 | f_max = sample_rate / 2 152 | window_length_seconds = 8 * n_mels / (f_max - f_min) 153 | window_length = int(window_length_seconds * sample_rate) 154 | return 2 ** (window_length.bit_length() - 1) 155 | 156 | 157 | class MultiScaleMelSpectrogramLoss(torch.nn.Module): 158 | 159 | def __init__( 160 | self, 161 | sample_rate: int = 24000, 162 | n_mels: list[int] = [5, 10, 20, 40, 80, 160, 320, 480], 163 | loss_fn=torch.nn.L1Loss(), 164 | ): 165 | super().__init__() 166 | self.sample_rate = sample_rate 167 | self.loss_fn = loss_fn 168 | self.log_base = torch.log(torch.tensor(10.0)) 169 | self.stft_params: list[tuple] = [] 170 | self.hann_window: dict[int, torch.Tensor] = {} 171 | self.mel_banks: dict[int, torch.Tensor] = {} 172 | 173 | self.stft_params = [ 174 | (mel, compute_window_length(mel, sample_rate), self.sample_rate // 100) 175 | for mel in n_mels 176 | ] 177 | 178 | def mel_spectrogram( 179 | self, 180 | wav: torch.Tensor, 181 | n_mels: int, 182 | window_length: int, 183 | hop_length: int, 184 | ): 185 | # IDs for caching 186 | dtype_device = str(wav.dtype) + "_" + str(wav.device) 187 | win_dtype_device = str(window_length) + "_" + dtype_device 188 | mel_dtype_device = str(n_mels) + "_" + dtype_device 189 | # caching hann window 190 | if win_dtype_device not in self.hann_window: 191 | self.hann_window[win_dtype_device] = torch.hann_window( 192 | window_length, device=wav.device, dtype=torch.float32 193 | ) 194 | 195 | wav = wav.squeeze(1) # -> torch(B, T) 196 | 197 | stft = torch.stft( 198 | wav.float(), 199 | n_fft=window_length, 200 | hop_length=hop_length, 201 | window=self.hann_window[win_dtype_device], 202 | return_complex=True, 203 | ) # -> torch (B, window_length // 2 + 1, (T - window_length)/hop_length + 1) 204 | 205 | magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6) 206 | 207 | # caching mel filter 208 | if mel_dtype_device not in self.mel_banks: 209 | self.mel_banks[mel_dtype_device] = torch.from_numpy( 210 | librosa_mel_fn( 211 | sr=self.sample_rate, 212 | n_mels=n_mels, 213 | n_fft=window_length, 214 | fmin=0, 215 | fmax=None, 216 | ) 217 | ).to(device=wav.device, dtype=torch.float32) 218 | 219 | mel_spectrogram = torch.matmul( 220 | self.mel_banks[mel_dtype_device], magnitude 221 | ) # torch(B, n_mels, stft.frames) 222 | return mel_spectrogram 223 | 224 | def forward( 225 | self, real: torch.Tensor, fake: torch.Tensor 226 | ): # real: torch(B, 1, T) , fake: torch(B, 1, T) 227 | loss = 0.0 228 | for p in self.stft_params: 229 | real_mels = self.mel_spectrogram(real, *p) 230 | fake_mels = self.mel_spectrogram(fake, *p) 231 | real_logmels = torch.log(real_mels.clamp(min=1e-5)) / self.log_base 232 | fake_logmels = torch.log(fake_mels.clamp(min=1e-5)) / self.log_base 233 | loss += self.loss_fn(real_logmels, fake_logmels) 234 | return loss 235 | -------------------------------------------------------------------------------- /rvc/train/process/change_info.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def change_info(path, info, name): 6 | try: 7 | ckpt = torch.load(path, map_location="cpu", weights_only=True) 8 | ckpt["info"] = info 9 | 10 | if not name: 11 | name = 
os.path.splitext(os.path.basename(path))[0] 12 | 13 | target_dir = os.path.join("logs", name) 14 | os.makedirs(target_dir, exist_ok=True) 15 | 16 | torch.save(ckpt, os.path.join(target_dir, f"{name}.pth")) 17 | 18 | return "Success." 19 | 20 | except Exception as error: 21 | print(f"An error occurred while changing the info: {error}") 22 | return f"Error: {error}" 23 | -------------------------------------------------------------------------------- /rvc/train/process/extract_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from multiprocessing import cpu_count 4 | 5 | import faiss 6 | import numpy as np 7 | from sklearn.cluster import MiniBatchKMeans 8 | 9 | # Parse command line arguments 10 | exp_dir = str(sys.argv[1]) 11 | index_algorithm = str(sys.argv[2]) 12 | 13 | try: 14 | feature_dir = os.path.join(exp_dir, f"extracted") 15 | model_name = os.path.basename(exp_dir) 16 | 17 | if not os.path.exists(feature_dir): 18 | print( 19 | f"Feature to generate index file not found at {feature_dir}. Did you run preprocessing and feature extraction steps?" 20 | ) 21 | sys.exit(1) 22 | 23 | index_filename_added = f"{model_name}.index" 24 | index_filepath_added = os.path.join(exp_dir, index_filename_added) 25 | 26 | if os.path.exists(index_filepath_added): 27 | pass 28 | else: 29 | npys = [] 30 | listdir_res = sorted(os.listdir(feature_dir)) 31 | 32 | for name in listdir_res: 33 | file_path = os.path.join(feature_dir, name) 34 | phone = np.load(file_path) 35 | npys.append(phone) 36 | 37 | big_npy = np.concatenate(npys, axis=0) 38 | 39 | big_npy_idx = np.arange(big_npy.shape[0]) 40 | np.random.shuffle(big_npy_idx) 41 | big_npy = big_npy[big_npy_idx] 42 | 43 | if big_npy.shape[0] > 2e5 and ( 44 | index_algorithm == "Auto" or index_algorithm == "KMeans" 45 | ): 46 | big_npy = ( 47 | MiniBatchKMeans( 48 | n_clusters=10000, 49 | verbose=True, 50 | batch_size=256 * cpu_count(), 51 | compute_labels=False, 52 | init="random", 53 | ) 54 | .fit(big_npy) 55 | .cluster_centers_ 56 | ) 57 | 58 | n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) 59 | 60 | # index_added 61 | index_added = faiss.index_factory(768, f"IVF{n_ivf},Flat") 62 | index_ivf_added = faiss.extract_index_ivf(index_added) 63 | index_ivf_added.nprobe = 1 64 | index_added.train(big_npy) 65 | 66 | batch_size_add = 8192 67 | for i in range(0, big_npy.shape[0], batch_size_add): 68 | index_added.add(big_npy[i : i + batch_size_add]) 69 | 70 | faiss.write_index(index_added, index_filepath_added) 71 | print(f"Saved index file '{index_filepath_added}'") 72 | 73 | except Exception as error: 74 | print(f"An error occurred extracting the index: {error}") 75 | print( 76 | "If you are running this code in a virtual environment, make sure you have enough GPU available to generate the Index file." 
77 | ) 78 | -------------------------------------------------------------------------------- /rvc/train/process/extract_model.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import hashlib 3 | import json 4 | import os 5 | import sys 6 | from collections import OrderedDict 7 | 8 | import torch 9 | 10 | now_dir = os.getcwd() 11 | sys.path.append(now_dir) 12 | 13 | 14 | def replace_keys_in_dict(d, old_key_part, new_key_part): 15 | if isinstance(d, OrderedDict): 16 | updated_dict = OrderedDict() 17 | else: 18 | updated_dict = {} 19 | for key, value in d.items(): 20 | new_key = key.replace(old_key_part, new_key_part) 21 | if isinstance(value, dict): 22 | value = replace_keys_in_dict(value, old_key_part, new_key_part) 23 | updated_dict[new_key] = value 24 | return updated_dict 25 | 26 | 27 | def extract_model( 28 | ckpt, 29 | sr, 30 | name, 31 | model_path, 32 | epoch, 33 | step, 34 | hps, 35 | overtrain_info, 36 | vocoder, 37 | pitch_guidance=True, 38 | version="v2", 39 | ): 40 | try: 41 | model_dir = os.path.dirname(model_path) 42 | os.makedirs(model_dir, exist_ok=True) 43 | 44 | if os.path.exists(os.path.join(model_dir, "model_info.json")): 45 | with open(os.path.join(model_dir, "model_info.json"), "r") as f: 46 | data = json.load(f) 47 | dataset_length = data.get("total_dataset_duration", None) 48 | embedder_model = data.get("embedder_model", None) 49 | speakers_id = data.get("speakers_id", 1) 50 | else: 51 | dataset_length = None 52 | 53 | with open(os.path.join(now_dir, "assets", "config.json"), "r") as f: 54 | data = json.load(f) 55 | model_author = data.get("model_author", None) 56 | 57 | opt = OrderedDict( 58 | weight={ 59 | key: value.half() for key, value in ckpt.items() if "enc_q" not in key 60 | } 61 | ) 62 | opt["config"] = [ 63 | hps.data.filter_length // 2 + 1, 64 | 32, 65 | hps.model.inter_channels, 66 | hps.model.hidden_channels, 67 | hps.model.filter_channels, 68 | hps.model.n_heads, 69 | hps.model.n_layers, 70 | hps.model.kernel_size, 71 | hps.model.p_dropout, 72 | hps.model.resblock, 73 | hps.model.resblock_kernel_sizes, 74 | hps.model.resblock_dilation_sizes, 75 | hps.model.upsample_rates, 76 | hps.model.upsample_initial_channel, 77 | hps.model.upsample_kernel_sizes, 78 | hps.model.spk_embed_dim, 79 | hps.model.gin_channels, 80 | hps.data.sample_rate, 81 | ] 82 | 83 | opt["epoch"] = epoch 84 | opt["step"] = step 85 | opt["sr"] = sr 86 | opt["f0"] = pitch_guidance 87 | opt["version"] = version 88 | opt["creation_date"] = datetime.datetime.now().isoformat() 89 | 90 | hash_input = f"{name}-{epoch}-{step}-{sr}-{version}-{opt['config']}" 91 | opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest() 92 | opt["overtrain_info"] = overtrain_info 93 | opt["dataset_length"] = dataset_length 94 | opt["model_name"] = name 95 | opt["author"] = model_author 96 | opt["embedder_model"] = embedder_model 97 | opt["speakers_id"] = speakers_id 98 | opt["vocoder"] = vocoder 99 | 100 | torch.save( 101 | replace_keys_in_dict( 102 | replace_keys_in_dict( 103 | opt, ".parametrizations.weight.original1", ".weight_v" 104 | ), 105 | ".parametrizations.weight.original0", 106 | ".weight_g", 107 | ), 108 | model_path, 109 | ) 110 | 111 | print(f"Saved model '{model_path}' (epoch {epoch} and step {step})") 112 | 113 | except Exception as error: 114 | print(f"An error occurred extracting the model: {error}") 115 | -------------------------------------------------------------------------------- /rvc/train/process/model_blender.py: 
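A minimal usage sketch for the blender defined below (illustrative only; it assumes the module is importable as `rvc.train.process.model_blender` and that the two checkpoint paths point to real models with matching sample rates and architectures):

    from rvc.train.process.model_blender import model_blender

    # Blend two compatible .pth checkpoints 50/50; the result is written to logs/blend.pth.
    result = model_blender("blend", "logs/a.pth", "logs/b.pth", ratio=0.5)
    if isinstance(result, tuple):
        message, output_path = result  # success: (info message, path of the saved model)
    else:
        print(result)  # mismatch message or exception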
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | 5 | 6 | def extract(ckpt): 7 | a = ckpt["model"] 8 | opt = OrderedDict() 9 | opt["weight"] = {} 10 | for key in a.keys(): 11 | if "enc_q" in key: 12 | continue 13 | opt["weight"][key] = a[key] 14 | return opt 15 | 16 | 17 | def model_blender(name, path1, path2, ratio): 18 | try: 19 | message = f"Model {path1} and {path2} are merged with alpha {ratio}." 20 | ckpt1 = torch.load(path1, map_location="cpu", weights_only=True) 21 | ckpt2 = torch.load(path2, map_location="cpu", weights_only=True) 22 | 23 | if ckpt1["sr"] != ckpt2["sr"]: 24 | return "The sample rates of the two models are not the same." 25 | 26 | cfg = ckpt1["config"] 27 | cfg_f0 = ckpt1["f0"] 28 | cfg_version = ckpt1["version"] 29 | cfg_sr = ckpt1["sr"] 30 | vocoder = ckpt1.get("vocoder", "HiFi-GAN") 31 | 32 | if "model" in ckpt1: 33 | ckpt1 = extract(ckpt1) 34 | else: 35 | ckpt1 = ckpt1["weight"] 36 | if "model" in ckpt2: 37 | ckpt2 = extract(ckpt2) 38 | else: 39 | ckpt2 = ckpt2["weight"] 40 | 41 | if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())): 42 | return "Fail to merge the models. The model architectures are not the same." 43 | 44 | opt = OrderedDict() 45 | opt["weight"] = {} 46 | for key in ckpt1.keys(): 47 | if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape: 48 | min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0]) 49 | opt["weight"][key] = ( 50 | ratio * (ckpt1[key][:min_shape0].float()) 51 | + (1 - ratio) * (ckpt2[key][:min_shape0].float()) 52 | ).half() 53 | else: 54 | opt["weight"][key] = ( 55 | ratio * (ckpt1[key].float()) + (1 - ratio) * (ckpt2[key].float()) 56 | ).half() 57 | 58 | opt["config"] = cfg 59 | opt["sr"] = cfg_sr 60 | opt["f0"] = cfg_f0 61 | opt["version"] = cfg_version 62 | opt["info"] = message 63 | opt["vocoder"] = vocoder 64 | 65 | torch.save(opt, os.path.join("logs", f"{name}.pth")) 66 | print(message) 67 | return message, os.path.join("logs", f"{name}.pth") 68 | except Exception as error: 69 | print(f"An error occurred blending the models: {error}") 70 | return error 71 | -------------------------------------------------------------------------------- /rvc/train/process/model_information.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | 4 | 5 | def prettify_date(date_str): 6 | if date_str is None: 7 | return "None" 8 | try: 9 | date_time_obj = datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f") 10 | return date_time_obj.strftime("%Y-%m-%d %H:%M:%S") 11 | except ValueError: 12 | return "Invalid date format" 13 | 14 | 15 | def model_information(path): 16 | model_data = torch.load(path, map_location="cpu", weights_only=True) 17 | 18 | print(f"Loaded model from {path}") 19 | 20 | model_name = model_data.get("model_name", "None") 21 | epochs = model_data.get("epoch", "None") 22 | steps = model_data.get("step", "None") 23 | sr = model_data.get("sr", "None") 24 | f0 = model_data.get("f0", "None") 25 | dataset_length = model_data.get("dataset_length", "None") 26 | vocoder = model_data.get("vocoder", "None") 27 | creation_date = model_data.get("creation_date", "None") 28 | model_hash = model_data.get("model_hash", None) 29 | overtrain_info = model_data.get("overtrain_info", "None") 30 | model_author = model_data.get("author", "None") 31 | embedder_model = model_data.get("embedder_model", "None") 32 | speakers_id = 
model_data.get("speakers_id", 0) 33 | 34 | creation_date_str = prettify_date(creation_date) if creation_date else "None" 35 | 36 | return ( 37 | f"Model Name: {model_name}\n" 38 | f"Model Creator: {model_author}\n" 39 | f"Epochs: {epochs}\n" 40 | f"Steps: {steps}\n" 41 | f"Vocoder: {vocoder}\n" 42 | f"Sampling Rate: {sr}\n" 43 | f"Dataset Length: {dataset_length}\n" 44 | f"Creation Date: {creation_date_str}\n" 45 | f"Overtrain Info: {overtrain_info}\n" 46 | f"Embedder Model: {embedder_model}\n" 47 | f"Max Speakers ID: {speakers_id}" 48 | f"Hash: {model_hash}\n" 49 | ) 50 | -------------------------------------------------------------------------------- /uvr/__init__.py: -------------------------------------------------------------------------------- 1 | from .separator import Separator 2 | -------------------------------------------------------------------------------- /uvr/architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/architectures/__init__.py -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/uvr_lib_v5/__init__.py -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, reduce 10 | 11 | # constants 12 | 13 | FlashAttentionConfig = namedtuple( 14 | "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] 15 | ) 16 | 17 | # helpers 18 | 19 | 20 | def exists(val): 21 | return val is not None 22 | 23 | 24 | def once(fn): 25 | called = False 26 | 27 | @wraps(fn) 28 | def inner(x): 29 | nonlocal called 30 | if called: 31 | return 32 | called = True 33 | return fn(x) 34 | 35 | return inner 36 | 37 | 38 | print_once = once(print) 39 | 40 | # main class 41 | 42 | 43 | class Attend(nn.Module): 44 | def __init__(self, dropout=0.0, flash=False): 45 | super().__init__() 46 | self.dropout = dropout 47 | self.attn_dropout = nn.Dropout(dropout) 48 | 49 | self.flash = flash 50 | assert not ( 51 | flash and version.parse(torch.__version__) < version.parse("2.0.0") 52 | ), "in order to use flash attention, you must be using pytorch 2.0 or above" 53 | 54 | # determine efficient attention configs for cuda and cpu 55 | 56 | self.cpu_config = FlashAttentionConfig(True, True, True) 57 | self.cuda_config = None 58 | 59 | if not torch.cuda.is_available() or not flash: 60 | return 61 | 62 | device_properties = torch.cuda.get_device_properties(torch.device("cuda")) 63 | 64 | if device_properties.major == 8 and device_properties.minor == 0: 65 | print_once( 66 | "A100 GPU detected, using flash attention if input tensor is on cuda" 67 | ) 68 | self.cuda_config = FlashAttentionConfig(True, False, False) 69 | else: 70 | self.cuda_config = FlashAttentionConfig(False, True, True) 71 | 72 | def flash_attn(self, q, k, v): 73 | _, heads, q_len, _, k_len, is_cuda, device = ( 74 | *q.shape, 75 | k.shape[-2], 76 | q.is_cuda, 77 | q.device, 78 | ) 79 | 80 | # 
Check if there is a compatible device for flash attention 81 | 82 | config = self.cuda_config if is_cuda else self.cpu_config 83 | 84 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale 85 | 86 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 87 | out = F.scaled_dot_product_attention( 88 | q, k, v, dropout_p=self.dropout if self.training else 0.0 89 | ) 90 | 91 | return out 92 | 93 | def forward(self, q, k, v): 94 | """ 95 | einstein notation 96 | b - batch 97 | h - heads 98 | n, i, j - sequence length (base sequence length, source, target) 99 | d - feature dimension 100 | """ 101 | 102 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 103 | 104 | scale = q.shape[-1] ** -0.5 105 | 106 | if self.flash: 107 | return self.flash_attn(q, k, v) 108 | 109 | # similarity 110 | 111 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale 112 | 113 | # attention 114 | 115 | attn = sim.softmax(dim=-1) 116 | attn = self.attn_dropout(attn) 117 | 118 | # aggregate values 119 | 120 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v) 121 | 122 | return out 123 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/model_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
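# A minimal construction sketch for the models defined below (illustrative only; it
# assumes this module is importable as `uvr.uvr_lib_v5.demucs.model_v2`):
#
#     from uvr.uvr_lib_v5.demucs.model_v2 import auto_load_demucs_model_v2
#
#     sources = ["drums", "bass", "other", "vocals"]
#     # "48" in the model name selects the 48-channel Demucs variant; a name containing
#     # "tasnet" would select ConvTasNet instead.
#     model = auto_load_demucs_model_v2(sources, "demucs48_hq")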
6 | 7 | import math 8 | 9 | import julius 10 | from torch import nn 11 | from .tasnet_v2 import ConvTasNet 12 | 13 | from .utils import capture_init, center_trim 14 | 15 | 16 | class BLSTM(nn.Module): 17 | def __init__(self, dim, layers=1): 18 | super().__init__() 19 | self.lstm = nn.LSTM( 20 | bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim 21 | ) 22 | self.linear = nn.Linear(2 * dim, dim) 23 | 24 | def forward(self, x): 25 | x = x.permute(2, 0, 1) 26 | x = self.lstm(x)[0] 27 | x = self.linear(x) 28 | x = x.permute(1, 2, 0) 29 | return x 30 | 31 | 32 | def rescale_conv(conv, reference): 33 | std = conv.weight.std().detach() 34 | scale = (std / reference) ** 0.5 35 | conv.weight.data /= scale 36 | if conv.bias is not None: 37 | conv.bias.data /= scale 38 | 39 | 40 | def rescale_module(module, reference): 41 | for sub in module.modules(): 42 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): 43 | rescale_conv(sub, reference) 44 | 45 | 46 | def auto_load_demucs_model_v2(sources, demucs_model_name): 47 | 48 | if "48" in demucs_model_name: 49 | channels = 48 50 | elif "unittest" in demucs_model_name: 51 | channels = 4 52 | else: 53 | channels = 64 54 | 55 | if "tasnet" in demucs_model_name: 56 | init_demucs_model = ConvTasNet(sources, X=10) 57 | else: 58 | init_demucs_model = Demucs(sources, channels=channels) 59 | 60 | return init_demucs_model 61 | 62 | 63 | class Demucs(nn.Module): 64 | @capture_init 65 | def __init__( 66 | self, 67 | sources, 68 | audio_channels=2, 69 | channels=64, 70 | depth=6, 71 | rewrite=True, 72 | glu=True, 73 | rescale=0.1, 74 | resample=True, 75 | kernel_size=8, 76 | stride=4, 77 | growth=2.0, 78 | lstm_layers=2, 79 | context=3, 80 | normalize=False, 81 | samplerate=44100, 82 | segment_length=4 * 10 * 44100, 83 | ): 84 | """ 85 | Args: 86 | sources (list[str]): list of source names 87 | audio_channels (int): stereo or mono 88 | channels (int): first convolution channels 89 | depth (int): number of encoder/decoder layers 90 | rewrite (bool): add 1x1 convolution to each encoder layer 91 | and a convolution to each decoder layer. 92 | For the decoder layer, `context` gives the kernel size. 93 | glu (bool): use glu instead of ReLU 94 | resample_input (bool): upsample x2 the input and downsample /2 the output. 95 | rescale (int): rescale initial weights of convolutions 96 | to get their standard deviation closer to `rescale` 97 | kernel_size (int): kernel size for convolutions 98 | stride (int): stride for convolutions 99 | growth (float): multiply (resp divide) number of channels by that 100 | for each layer of the encoder (resp decoder) 101 | lstm_layers (int): number of lstm layers, 0 = no lstm 102 | context (int): kernel size of the convolution in the 103 | decoder before the transposed convolution. If > 1, 104 | will provide some context from neighboring time 105 | steps. 106 | samplerate (int): stored as meta information for easing 107 | future evaluations of the model. 108 | segment_length (int): stored as meta information for easing 109 | future evaluations of the model. Length of the segments on which 110 | the model was trained. 
111 | """ 112 | 113 | super().__init__() 114 | self.audio_channels = audio_channels 115 | self.sources = sources 116 | self.kernel_size = kernel_size 117 | self.context = context 118 | self.stride = stride 119 | self.depth = depth 120 | self.resample = resample 121 | self.channels = channels 122 | self.normalize = normalize 123 | self.samplerate = samplerate 124 | self.segment_length = segment_length 125 | 126 | self.encoder = nn.ModuleList() 127 | self.decoder = nn.ModuleList() 128 | 129 | if glu: 130 | activation = nn.GLU(dim=1) 131 | ch_scale = 2 132 | else: 133 | activation = nn.ReLU() 134 | ch_scale = 1 135 | in_channels = audio_channels 136 | for index in range(depth): 137 | encode = [] 138 | encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] 139 | if rewrite: 140 | encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] 141 | self.encoder.append(nn.Sequential(*encode)) 142 | 143 | decode = [] 144 | if index > 0: 145 | out_channels = in_channels 146 | else: 147 | out_channels = len(self.sources) * audio_channels 148 | if rewrite: 149 | decode += [ 150 | nn.Conv1d(channels, ch_scale * channels, context), 151 | activation, 152 | ] 153 | decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] 154 | if index > 0: 155 | decode.append(nn.ReLU()) 156 | self.decoder.insert(0, nn.Sequential(*decode)) 157 | in_channels = channels 158 | channels = int(growth * channels) 159 | 160 | channels = in_channels 161 | 162 | if lstm_layers: 163 | self.lstm = BLSTM(channels, lstm_layers) 164 | else: 165 | self.lstm = None 166 | 167 | if rescale: 168 | rescale_module(self, reference=rescale) 169 | 170 | def valid_length(self, length): 171 | """ 172 | Return the nearest valid length to use with the model so that 173 | there is no time steps left over in a convolutions, e.g. for all 174 | layers, size of the input - kernel_size % stride = 0. 175 | 176 | If the mixture has a valid length, the estimated sources 177 | will have exactly the same length when context = 1. If context > 1, 178 | the two signals can be center trimmed to match. 179 | 180 | For training, extracts should have a valid length.For evaluation 181 | on full tracks we recommend passing `pad = True` to :method:`forward`. 
182 | """ 183 | if self.resample: 184 | length *= 2 185 | for _ in range(self.depth): 186 | length = math.ceil((length - self.kernel_size) / self.stride) + 1 187 | length = max(1, length) 188 | length += self.context - 1 189 | for _ in range(self.depth): 190 | length = (length - 1) * self.stride + self.kernel_size 191 | 192 | if self.resample: 193 | length = math.ceil(length / 2) 194 | return int(length) 195 | 196 | def forward(self, mix): 197 | x = mix 198 | 199 | if self.normalize: 200 | mono = mix.mean(dim=1, keepdim=True) 201 | mean = mono.mean(dim=-1, keepdim=True) 202 | std = mono.std(dim=-1, keepdim=True) 203 | else: 204 | mean = 0 205 | std = 1 206 | 207 | x = (x - mean) / (1e-5 + std) 208 | 209 | if self.resample: 210 | x = julius.resample_frac(x, 1, 2) 211 | 212 | saved = [] 213 | for encode in self.encoder: 214 | x = encode(x) 215 | saved.append(x) 216 | if self.lstm: 217 | x = self.lstm(x) 218 | for decode in self.decoder: 219 | skip = center_trim(saved.pop(-1), x) 220 | x = x + skip 221 | x = decode(x) 222 | 223 | if self.resample: 224 | x = julius.resample_frac(x, 2, 1) 225 | x = x * std + mean 226 | x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) 227 | return x 228 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Loading pretrained models. 7 | """ 8 | 9 | import logging 10 | from pathlib import Path 11 | import typing as tp 12 | 13 | # from dora.log import fatal 14 | 15 | import logging 16 | 17 | from diffq import DiffQuantizer 18 | import torch.hub 19 | 20 | from .model import Demucs 21 | from .tasnet_v2 import ConvTasNet 22 | from .utils import set_state 23 | 24 | from .hdemucs import HDemucs 25 | from .repo import ( 26 | RemoteRepo, 27 | LocalRepo, 28 | ModelOnlyRepo, 29 | BagOnlyRepo, 30 | AnyModelRepo, 31 | ModelLoadingError, 32 | ) # noqa 33 | 34 | logger = logging.getLogger(__name__) 35 | ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" 36 | REMOTE_ROOT = Path(__file__).parent / "remote" 37 | 38 | SOURCES = ["drums", "bass", "other", "vocals"] 39 | 40 | 41 | def demucs_unittest(): 42 | model = HDemucs(channels=4, sources=SOURCES) 43 | return model 44 | 45 | 46 | def add_model_flags(parser): 47 | group = parser.add_mutually_exclusive_group(required=False) 48 | group.add_argument("-s", "--sig", help="Locally trained XP signature.") 49 | group.add_argument( 50 | "-n", 51 | "--name", 52 | default="mdx_extra_q", 53 | help="Pretrained model name or signature. 
Default is mdx_extra_q.", 54 | ) 55 | parser.add_argument( 56 | "--repo", 57 | type=Path, 58 | help="Folder containing all pre-trained models for use with -n.", 59 | ) 60 | 61 | 62 | def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: 63 | root: str = "" 64 | models: tp.Dict[str, str] = {} 65 | for line in remote_file_list.read_text().split("\n"): 66 | line = line.strip() 67 | if line.startswith("#"): 68 | continue 69 | elif line.startswith("root:"): 70 | root = line.split(":", 1)[1].strip() 71 | else: 72 | sig = line.split("-", 1)[0] 73 | assert sig not in models 74 | models[sig] = ROOT_URL + root + line 75 | return models 76 | 77 | 78 | def get_model(name: str, repo: tp.Optional[Path] = None): 79 | """`name` must be a bag of models name or a pretrained signature 80 | from the remote AWS model repo or the specified local repo if `repo` is not None. 81 | """ 82 | if name == "demucs_unittest": 83 | return demucs_unittest() 84 | model_repo: ModelOnlyRepo 85 | if repo is None: 86 | models = _parse_remote_files(REMOTE_ROOT / "files.txt") 87 | model_repo = RemoteRepo(models) 88 | bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) 89 | else: 90 | if not repo.is_dir(): 91 | fatal(f"{repo} must exist and be a directory.") 92 | model_repo = LocalRepo(repo) 93 | bag_repo = BagOnlyRepo(repo, model_repo) 94 | any_repo = AnyModelRepo(model_repo, bag_repo) 95 | model = any_repo.get_model(name) 96 | model.eval() 97 | return model 98 | 99 | 100 | def get_model_from_args(args): 101 | """ 102 | Load local model package or pre-trained model. 103 | """ 104 | return get_model(name=args.name, repo=args.repo) 105 | 106 | 107 | logger = logging.getLogger(__name__) 108 | ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" 109 | 110 | PRETRAINED_MODELS = { 111 | "demucs": "e07c671f", 112 | "demucs48_hq": "28a1282c", 113 | "demucs_extra": "3646af93", 114 | "demucs_quantized": "07afea75", 115 | "tasnet": "beb46fac", 116 | "tasnet_extra": "df3777b2", 117 | "demucs_unittest": "09ebc15f", 118 | } 119 | 120 | SOURCES = ["drums", "bass", "other", "vocals"] 121 | 122 | 123 | def get_url(name): 124 | sig = PRETRAINED_MODELS[name] 125 | return ROOT + name + "-" + sig[:8] + ".th" 126 | 127 | 128 | def is_pretrained(name): 129 | return name in PRETRAINED_MODELS 130 | 131 | 132 | def load_pretrained(name): 133 | if name == "demucs": 134 | return demucs(pretrained=True) 135 | elif name == "demucs48_hq": 136 | return demucs(pretrained=True, hq=True, channels=48) 137 | elif name == "demucs_extra": 138 | return demucs(pretrained=True, extra=True) 139 | elif name == "demucs_quantized": 140 | return demucs(pretrained=True, quantized=True) 141 | elif name == "demucs_unittest": 142 | return demucs_unittest(pretrained=True) 143 | elif name == "tasnet": 144 | return tasnet(pretrained=True) 145 | elif name == "tasnet_extra": 146 | return tasnet(pretrained=True, extra=True) 147 | else: 148 | raise ValueError(f"Invalid pretrained name {name}") 149 | 150 | 151 | def _load_state(name, model, quantizer=None): 152 | url = get_url(name) 153 | state = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True) 154 | set_state(model, quantizer, state) 155 | if quantizer: 156 | quantizer.detach() 157 | 158 | 159 | def demucs_unittest(pretrained=True): 160 | model = Demucs(channels=4, sources=SOURCES) 161 | if pretrained: 162 | _load_state("demucs_unittest", model) 163 | return model 164 | 165 | 166 | def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): 167 | if not pretrained and (extra or 
quantized or hq): 168 | raise ValueError("if extra or quantized is True, pretrained must be True.") 169 | model = Demucs(sources=SOURCES, channels=channels) 170 | if pretrained: 171 | name = "demucs" 172 | if channels != 64: 173 | name += str(channels) 174 | quantizer = None 175 | if sum([extra, quantized, hq]) > 1: 176 | raise ValueError("Only one of extra, quantized, hq, can be True.") 177 | if quantized: 178 | quantizer = DiffQuantizer(model, group_size=8, min_size=1) 179 | name += "_quantized" 180 | if extra: 181 | name += "_extra" 182 | if hq: 183 | name += "_hq" 184 | _load_state(name, model, quantizer) 185 | return model 186 | 187 | 188 | def tasnet(pretrained=True, extra=False): 189 | if not pretrained and extra: 190 | raise ValueError("if extra is True, pretrained must be True.") 191 | model = ConvTasNet(X=10, sources=SOURCES) 192 | if pretrained: 193 | name = "tasnet" 194 | if extra: 195 | name = "tasnet_extra" 196 | _load_state(name, model) 197 | return model 198 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/repo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Represents a model repository, including pre-trained models and bags of models. 7 | A repo can either be the main remote repository stored in AWS, or a local repository 8 | with your own models. 9 | """ 10 | 11 | from hashlib import sha256 12 | from pathlib import Path 13 | import typing as tp 14 | 15 | import torch 16 | import yaml 17 | 18 | from .apply import BagOfModels, Model 19 | from .states import load_model 20 | 21 | 22 | AnyModel = tp.Union[Model, BagOfModels] 23 | 24 | 25 | class ModelLoadingError(RuntimeError): 26 | pass 27 | 28 | 29 | def check_checksum(path: Path, checksum: str): 30 | sha = sha256() 31 | with open(path, "rb") as file: 32 | while True: 33 | buf = file.read(2**20) 34 | if not buf: 35 | break 36 | sha.update(buf) 37 | actual_checksum = sha.hexdigest()[: len(checksum)] 38 | if actual_checksum != checksum: 39 | raise ModelLoadingError( 40 | f"Invalid checksum for file {path}, " 41 | f"expected {checksum} but got {actual_checksum}" 42 | ) 43 | 44 | 45 | class ModelOnlyRepo: 46 | """Base class for all model only repos.""" 47 | 48 | def has_model(self, sig: str) -> bool: 49 | raise NotImplementedError() 50 | 51 | def get_model(self, sig: str) -> Model: 52 | raise NotImplementedError() 53 | 54 | 55 | class RemoteRepo(ModelOnlyRepo): 56 | def __init__(self, models: tp.Dict[str, str]): 57 | self._models = models 58 | 59 | def has_model(self, sig: str) -> bool: 60 | return sig in self._models 61 | 62 | def get_model(self, sig: str) -> Model: 63 | try: 64 | url = self._models[sig] 65 | except KeyError: 66 | raise ModelLoadingError( 67 | f"Could not find a pre-trained model with signature {sig}." 
68 | ) 69 | pkg = torch.hub.load_state_dict_from_url( 70 | url, map_location="cpu", check_hash=True 71 | ) 72 | return load_model(pkg) 73 | 74 | 75 | class LocalRepo(ModelOnlyRepo): 76 | def __init__(self, root: Path): 77 | self.root = root 78 | self.scan() 79 | 80 | def scan(self): 81 | self._models = {} 82 | self._checksums = {} 83 | for file in self.root.iterdir(): 84 | if file.suffix == ".th": 85 | if "-" in file.stem: 86 | xp_sig, checksum = file.stem.split("-") 87 | self._checksums[xp_sig] = checksum 88 | else: 89 | xp_sig = file.stem 90 | if xp_sig in self._models: 91 | print("Whats xp? ", xp_sig) 92 | raise ModelLoadingError( 93 | f"Duplicate pre-trained model exist for signature {xp_sig}. " 94 | "Please delete all but one." 95 | ) 96 | self._models[xp_sig] = file 97 | 98 | def has_model(self, sig: str) -> bool: 99 | return sig in self._models 100 | 101 | def get_model(self, sig: str) -> Model: 102 | try: 103 | file = self._models[sig] 104 | except KeyError: 105 | raise ModelLoadingError( 106 | f"Could not find pre-trained model with signature {sig}." 107 | ) 108 | if sig in self._checksums: 109 | check_checksum(file, self._checksums[sig]) 110 | return load_model(file) 111 | 112 | 113 | class BagOnlyRepo: 114 | """Handles only YAML files containing bag of models, leaving the actual 115 | model loading to some Repo. 116 | """ 117 | 118 | def __init__(self, root: Path, model_repo: ModelOnlyRepo): 119 | self.root = root 120 | self.model_repo = model_repo 121 | self.scan() 122 | 123 | def scan(self): 124 | self._bags = {} 125 | for file in self.root.iterdir(): 126 | if file.suffix == ".yaml": 127 | self._bags[file.stem] = file 128 | 129 | def has_model(self, name: str) -> bool: 130 | return name in self._bags 131 | 132 | def get_model(self, name: str) -> BagOfModels: 133 | try: 134 | yaml_file = self._bags[name] 135 | except KeyError: 136 | raise ModelLoadingError( 137 | f"{name} is neither a single pre-trained model or " "a bag of models." 138 | ) 139 | bag = yaml.safe_load(open(yaml_file)) 140 | signatures = bag["models"] 141 | models = [self.model_repo.get_model(sig) for sig in signatures] 142 | weights = bag.get("weights") 143 | segment = bag.get("segment") 144 | return BagOfModels(models, weights, segment) 145 | 146 | 147 | class AnyModelRepo: 148 | def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): 149 | self.model_repo = model_repo 150 | self.bag_repo = bag_repo 151 | 152 | def has_model(self, name_or_sig: str) -> bool: 153 | return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model( 154 | name_or_sig 155 | ) 156 | 157 | def get_model(self, name_or_sig: str) -> AnyModel: 158 | # print('name_or_sig: ', name_or_sig) 159 | if self.model_repo.has_model(name_or_sig): 160 | return self.model_repo.get_model(name_or_sig) 161 | else: 162 | return self.bag_repo.get_model(name_or_sig) 163 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/spec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
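# A minimal round-trip sketch for the helpers below (illustrative only; it assumes this
# module is importable as `uvr.uvr_lib_v5.demucs.spec`):
#
#     import torch
#     from uvr.uvr_lib_v5.demucs.spec import spectro, ispectro
#
#     x = torch.randn(2, 2, 44100)                           # (batch, channels, time)
#     z = spectro(x, n_fft=4096, hop_length=1024)            # complex (batch, channels, 2049, frames)
#     y = ispectro(z, hop_length=1024, length=x.shape[-1])   # back to (2, 2, 44100)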
6 | """Conveniance wrapper to perform STFT and iSTFT""" 7 | 8 | import torch as th 9 | 10 | 11 | def spectro(x, n_fft=512, hop_length=None, pad=0): 12 | *other, length = x.shape 13 | x = x.reshape(-1, length) 14 | 15 | device_type = x.device.type 16 | is_other_gpu = not device_type in ["cuda", "cpu"] 17 | 18 | if is_other_gpu: 19 | x = x.cpu() 20 | z = th.stft( 21 | x, 22 | n_fft * (1 + pad), 23 | hop_length or n_fft // 4, 24 | window=th.hann_window(n_fft).to(x), 25 | win_length=n_fft, 26 | normalized=True, 27 | center=True, 28 | return_complex=True, 29 | pad_mode="reflect", 30 | ) 31 | _, freqs, frame = z.shape 32 | return z.view(*other, freqs, frame) 33 | 34 | 35 | def ispectro(z, hop_length=None, length=None, pad=0): 36 | *other, freqs, frames = z.shape 37 | n_fft = 2 * freqs - 2 38 | z = z.view(-1, freqs, frames) 39 | win_length = n_fft // (1 + pad) 40 | 41 | device_type = z.device.type 42 | is_other_gpu = not device_type in ["cuda", "cpu"] 43 | 44 | if is_other_gpu: 45 | z = z.cpu() 46 | x = th.istft( 47 | z, 48 | n_fft, 49 | hop_length, 50 | window=th.hann_window(win_length).to(z.real), 51 | win_length=win_length, 52 | normalized=True, 53 | length=length, 54 | center=True, 55 | ) 56 | _, length = x.shape 57 | return x.view(*other, length) 58 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/demucs/states.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Utilities to save and load models. 8 | """ 9 | from contextlib import contextmanager 10 | 11 | import functools 12 | import hashlib 13 | import inspect 14 | import io 15 | from pathlib import Path 16 | import warnings 17 | 18 | from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state 19 | import torch 20 | 21 | 22 | def get_quantizer(model, args, optimizer=None): 23 | """Return the quantizer given the XP quantization args.""" 24 | quantizer = None 25 | if args.diffq: 26 | quantizer = DiffQuantizer( 27 | model, min_size=args.min_size, group_size=args.group_size 28 | ) 29 | if optimizer is not None: 30 | quantizer.setup_optimizer(optimizer) 31 | elif args.qat: 32 | quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size) 33 | return quantizer 34 | 35 | 36 | def load_model(path_or_package, strict=False): 37 | """Load a model from the given serialized model, either given as a dict (already loaded) 38 | or a path to a file on disk.""" 39 | if isinstance(path_or_package, dict): 40 | package = path_or_package 41 | elif isinstance(path_or_package, (str, Path)): 42 | with warnings.catch_warnings(): 43 | warnings.simplefilter("ignore") 44 | path = path_or_package 45 | package = torch.load(path, "cpu") 46 | else: 47 | raise ValueError(f"Invalid type for {path_or_package}.") 48 | 49 | klass = package["klass"] 50 | args = package["args"] 51 | kwargs = package["kwargs"] 52 | 53 | if strict: 54 | model = klass(*args, **kwargs) 55 | else: 56 | sig = inspect.signature(klass) 57 | for key in list(kwargs): 58 | if key not in sig.parameters: 59 | warnings.warn("Dropping inexistant parameter " + key) 60 | del kwargs[key] 61 | model = klass(*args, **kwargs) 62 | 63 | state = package["state"] 64 | 65 | set_state(model, state) 66 | return model 67 | 68 | 69 | def get_state(model, quantizer, half=False): 70 | """Get 
the state from a model, potentially with quantization applied. 71 | If `half` is True, model are stored as half precision, which shouldn't impact performance 72 | but half the state size.""" 73 | if quantizer is None: 74 | dtype = torch.half if half else None 75 | state = { 76 | k: p.data.to(device="cpu", dtype=dtype) 77 | for k, p in model.state_dict().items() 78 | } 79 | else: 80 | state = quantizer.get_quantized_state() 81 | state["__quantized"] = True 82 | return state 83 | 84 | 85 | def set_state(model, state, quantizer=None): 86 | """Set the state on a given model.""" 87 | if state.get("__quantized"): 88 | if quantizer is not None: 89 | quantizer.restore_quantized_state(model, state["quantized"]) 90 | else: 91 | restore_quantized_state(model, state) 92 | else: 93 | model.load_state_dict(state) 94 | return state 95 | 96 | 97 | def save_with_checksum(content, path): 98 | """Save the given value on disk, along with a sha256 hash. 99 | Should be used with the output of either `serialize_model` or `get_state`.""" 100 | buf = io.BytesIO() 101 | torch.save(content, buf) 102 | sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] 103 | 104 | path = path.parent / (path.stem + "-" + sig + path.suffix) 105 | path.write_bytes(buf.getvalue()) 106 | 107 | 108 | def copy_state(state): 109 | return {k: v.cpu().clone() for k, v in state.items()} 110 | 111 | 112 | @contextmanager 113 | def swap_state(model, state): 114 | """ 115 | Context manager that swaps the state of a model, e.g: 116 | 117 | # model is in old state 118 | with swap_state(model, new_state): 119 | # model in new state 120 | # model back to old state 121 | """ 122 | old_state = copy_state(model.state_dict()) 123 | model.load_state_dict(state, strict=False) 124 | try: 125 | yield 126 | finally: 127 | model.load_state_dict(old_state) 128 | 129 | 130 | def capture_init(init): 131 | @functools.wraps(init) 132 | def __init__(self, *args, **kwargs): 133 | self._init_args_kwargs = (args, kwargs) 134 | init(self, *args, **kwargs) 135 | 136 | return __init__ 137 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/mdxnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .modules import TFC_TDF 4 | from pytorch_lightning import LightningModule 5 | 6 | dim_s = 4 7 | 8 | 9 | class AbstractMDXNet(LightningModule): 10 | def __init__( 11 | self, 12 | target_name, 13 | lr, 14 | optimizer, 15 | dim_c, 16 | dim_f, 17 | dim_t, 18 | n_fft, 19 | hop_length, 20 | overlap, 21 | ): 22 | super().__init__() 23 | self.target_name = target_name 24 | self.lr = lr 25 | self.optimizer = optimizer 26 | self.dim_c = dim_c 27 | self.dim_f = dim_f 28 | self.dim_t = dim_t 29 | self.n_fft = n_fft 30 | self.n_bins = n_fft // 2 + 1 31 | self.hop_length = hop_length 32 | self.window = nn.Parameter( 33 | torch.hann_window(window_length=self.n_fft, periodic=True), 34 | requires_grad=False, 35 | ) 36 | self.freq_pad = nn.Parameter( 37 | torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), 38 | requires_grad=False, 39 | ) 40 | 41 | def get_optimizer(self): 42 | if self.optimizer == "rmsprop": 43 | return torch.optim.RMSprop(self.parameters(), self.lr) 44 | 45 | if self.optimizer == "adamw": 46 | return torch.optim.AdamW(self.parameters(), self.lr) 47 | 48 | 49 | class ConvTDFNet(AbstractMDXNet): 50 | def __init__( 51 | self, 52 | target_name, 53 | lr, 54 | optimizer, 55 | dim_c, 56 | dim_f, 57 | dim_t, 58 | n_fft, 59 | hop_length, 60 | 
num_blocks, 61 | l, 62 | g, 63 | k, 64 | bn, 65 | bias, 66 | overlap, 67 | ): 68 | 69 | super(ConvTDFNet, self).__init__( 70 | target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap 71 | ) 72 | # self.save_hyperparameters() 73 | 74 | self.num_blocks = num_blocks 75 | self.l = l 76 | self.g = g 77 | self.k = k 78 | self.bn = bn 79 | self.bias = bias 80 | 81 | if optimizer == "rmsprop": 82 | norm = nn.BatchNorm2d 83 | 84 | if optimizer == "adamw": 85 | norm = lambda input: nn.GroupNorm(2, input) 86 | 87 | self.n = num_blocks // 2 88 | scale = (2, 2) 89 | 90 | self.first_conv = nn.Sequential( 91 | nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)), 92 | norm(g), 93 | nn.ReLU(), 94 | ) 95 | 96 | f = self.dim_f 97 | c = g 98 | self.encoding_blocks = nn.ModuleList() 99 | self.ds = nn.ModuleList() 100 | for i in range(self.n): 101 | self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) 102 | self.ds.append( 103 | nn.Sequential( 104 | nn.Conv2d( 105 | in_channels=c, 106 | out_channels=c + g, 107 | kernel_size=scale, 108 | stride=scale, 109 | ), 110 | norm(c + g), 111 | nn.ReLU(), 112 | ) 113 | ) 114 | f = f // 2 115 | c += g 116 | 117 | self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm) 118 | 119 | self.decoding_blocks = nn.ModuleList() 120 | self.us = nn.ModuleList() 121 | for i in range(self.n): 122 | self.us.append( 123 | nn.Sequential( 124 | nn.ConvTranspose2d( 125 | in_channels=c, 126 | out_channels=c - g, 127 | kernel_size=scale, 128 | stride=scale, 129 | ), 130 | norm(c - g), 131 | nn.ReLU(), 132 | ) 133 | ) 134 | f = f * 2 135 | c -= g 136 | 137 | self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) 138 | 139 | self.final_conv = nn.Sequential( 140 | nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)), 141 | ) 142 | 143 | def forward(self, x): 144 | 145 | x = self.first_conv(x) 146 | 147 | x = x.transpose(-1, -2) 148 | 149 | ds_outputs = [] 150 | for i in range(self.n): 151 | x = self.encoding_blocks[i](x) 152 | ds_outputs.append(x) 153 | x = self.ds[i](x) 154 | 155 | x = self.bottleneck_block(x) 156 | 157 | for i in range(self.n): 158 | x = self.us[i](x) 159 | x *= ds_outputs[-i - 1] 160 | x = self.decoding_blocks[i](x) 161 | 162 | x = x.transpose(-1, -2) 163 | 164 | x = self.final_conv(x) 165 | 166 | return x 167 | 168 | 169 | class Mixer(nn.Module): 170 | def __init__(self, device, mixer_path): 171 | 172 | super(Mixer, self).__init__() 173 | 174 | self.linear = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False) 175 | 176 | self.load_state_dict(torch.load(mixer_path, map_location=device)) 177 | 178 | def forward(self, x): 179 | x = x.reshape(1, (dim_s + 1) * 2, -1).transpose(-1, -2) 180 | x = self.linear(x) 181 | return x.transpose(-1, -2).reshape(dim_s, 2, -1) 182 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/mixer.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaisewf/rvc-cli/a732f87a54780e26e4a0e3c91b2d93f848eef6d2/uvr/uvr_lib_v5/mixer.ckpt -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TFC(nn.Module): 6 | def __init__(self, c, l, k, norm): 7 | super(TFC, self).__init__() 8 | 9 | self.H = nn.ModuleList() 10 | for i in range(l): 11 | 
self.H.append( 12 | nn.Sequential( 13 | nn.Conv2d( 14 | in_channels=c, 15 | out_channels=c, 16 | kernel_size=k, 17 | stride=1, 18 | padding=k // 2, 19 | ), 20 | norm(c), 21 | nn.ReLU(), 22 | ) 23 | ) 24 | 25 | def forward(self, x): 26 | for h in self.H: 27 | x = h(x) 28 | return x 29 | 30 | 31 | class DenseTFC(nn.Module): 32 | def __init__(self, c, l, k, norm): 33 | super(DenseTFC, self).__init__() 34 | 35 | self.conv = nn.ModuleList() 36 | for i in range(l): 37 | self.conv.append( 38 | nn.Sequential( 39 | nn.Conv2d( 40 | in_channels=c, 41 | out_channels=c, 42 | kernel_size=k, 43 | stride=1, 44 | padding=k // 2, 45 | ), 46 | norm(c), 47 | nn.ReLU(), 48 | ) 49 | ) 50 | 51 | def forward(self, x): 52 | for layer in self.conv[:-1]: 53 | x = torch.cat([layer(x), x], 1) 54 | return self.conv[-1](x) 55 | 56 | 57 | class TFC_TDF(nn.Module): 58 | def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d): 59 | 60 | super(TFC_TDF, self).__init__() 61 | 62 | self.use_tdf = bn is not None 63 | 64 | self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm) 65 | 66 | if self.use_tdf: 67 | if bn == 0: 68 | self.tdf = nn.Sequential(nn.Linear(f, f, bias=bias), norm(c), nn.ReLU()) 69 | else: 70 | self.tdf = nn.Sequential( 71 | nn.Linear(f, f // bn, bias=bias), 72 | norm(c), 73 | nn.ReLU(), 74 | nn.Linear(f // bn, f, bias=bias), 75 | norm(c), 76 | nn.ReLU(), 77 | ) 78 | 79 | def forward(self, x): 80 | x = self.tfc(x) 81 | return x + self.tdf(x) if self.use_tdf else x 82 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/pyrb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tempfile 4 | import six 5 | import numpy as np 6 | import soundfile as sf 7 | import sys 8 | 9 | if getattr(sys, "frozen", False): 10 | BASE_PATH_RUB = sys._MEIPASS 11 | else: 12 | BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | __all__ = ["time_stretch", "pitch_shift"] 15 | 16 | __RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, "rubberband") 17 | 18 | if six.PY2: 19 | DEVNULL = open(os.devnull, "w") 20 | else: 21 | DEVNULL = subprocess.DEVNULL 22 | 23 | 24 | def __rubberband(y, sr, **kwargs): 25 | 26 | assert sr > 0 27 | 28 | # Get the input and output tempfile 29 | fd, infile = tempfile.mkstemp(suffix=".wav") 30 | os.close(fd) 31 | fd, outfile = tempfile.mkstemp(suffix=".wav") 32 | os.close(fd) 33 | 34 | # dump the audio 35 | sf.write(infile, y, sr) 36 | 37 | try: 38 | # Execute rubberband 39 | arguments = [__RUBBERBAND_UTIL, "-q"] 40 | 41 | for key, value in six.iteritems(kwargs): 42 | arguments.append(str(key)) 43 | arguments.append(str(value)) 44 | 45 | arguments.extend([infile, outfile]) 46 | 47 | subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL) 48 | 49 | # Load the processed audio. 50 | y_out, _ = sf.read(outfile, always_2d=True) 51 | 52 | # make sure that output dimensions matches input 53 | if y.ndim == 1: 54 | y_out = np.squeeze(y_out) 55 | 56 | except OSError as exc: 57 | six.raise_from( 58 | RuntimeError( 59 | "Failed to execute rubberband. " 60 | "Please verify that rubberband-cli " 61 | "is installed." 
62 | ), 63 | exc, 64 | ) 65 | 66 | finally: 67 | # Remove temp files 68 | os.unlink(infile) 69 | os.unlink(outfile) 70 | 71 | return y_out 72 | 73 | 74 | def time_stretch(y, sr, rate, rbargs=None): 75 | if rate <= 0: 76 | raise ValueError("rate must be strictly positive") 77 | 78 | if rate == 1.0: 79 | return y 80 | 81 | if rbargs is None: 82 | rbargs = dict() 83 | 84 | rbargs.setdefault("--tempo", rate) 85 | 86 | return __rubberband(y, sr, **rbargs) 87 | 88 | 89 | def pitch_shift(y, sr, n_steps, rbargs=None): 90 | 91 | if n_steps == 0: 92 | return y 93 | 94 | if rbargs is None: 95 | rbargs = dict() 96 | 97 | rbargs.setdefault("--pitch", n_steps) 98 | 99 | return __rubberband(y, sr, **rbargs) 100 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/results.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Matchering - Audio Matching and Mastering Python Library 5 | Copyright (C) 2016-2022 Sergree 6 | 7 | This program is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This program is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this program. If not, see . 19 | """ 20 | 21 | import os 22 | import soundfile as sf 23 | 24 | 25 | class Result: 26 | def __init__( 27 | self, file: str, subtype: str, use_limiter: bool = True, normalize: bool = True 28 | ): 29 | _, file_ext = os.path.splitext(file) 30 | file_ext = file_ext[1:].upper() 31 | if not sf.check_format(file_ext): 32 | raise TypeError(f"{file_ext} format is not supported") 33 | if not sf.check_format(file_ext, subtype): 34 | raise TypeError(f"{file_ext} format does not have {subtype} subtype") 35 | self.file = file 36 | self.subtype = subtype 37 | self.use_limiter = use_limiter 38 | self.normalize = normalize 39 | 40 | 41 | def pcm16(file: str) -> Result: 42 | return Result(file, "PCM_16") 43 | 44 | 45 | def pcm24(file: str) -> Result: 46 | return Result(file, "FLOAT") 47 | 48 | 49 | def save_audiofile(file: str, wav_set="PCM_16") -> Result: 50 | return Result(file, wav_set) 51 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class STFT: 5 | """ 6 | This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT). 7 | These functions are essential for converting the audio between the time domain and the frequency domain, 8 | which is a crucial aspect of audio processing in neural networks. 9 | """ 10 | 11 | def __init__(self, logger, n_fft, hop_length, dim_f, device): 12 | self.logger = logger 13 | self.n_fft = n_fft 14 | self.hop_length = hop_length 15 | self.dim_f = dim_f 16 | self.device = device 17 | # Create a Hann window tensor for use in the STFT. 
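# The window is built once with periodic=True (the form typically used for spectral
# analysis); __call__ and inverse() move it onto the input tensor's device as needed.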
18 | self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True) 19 | 20 | def __call__(self, input_tensor): 21 | # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA). 22 | is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"] 23 | 24 | # If on a non-standard device, temporarily move the tensor to CPU for processing. 25 | if is_non_standard_device: 26 | input_tensor = input_tensor.cpu() 27 | 28 | # Transfer the pre-defined window tensor to the same device as the input tensor. 29 | stft_window = self.hann_window.to(input_tensor.device) 30 | 31 | # Extract batch dimensions (all dimensions except the last two which are channel and time). 32 | batch_dimensions = input_tensor.shape[:-2] 33 | 34 | # Extract channel and time dimensions (last two dimensions of the tensor). 35 | channel_dim, time_dim = input_tensor.shape[-2:] 36 | 37 | # Reshape the tensor to merge batch and channel dimensions for STFT processing. 38 | reshaped_tensor = input_tensor.reshape([-1, time_dim]) 39 | 40 | # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor. 41 | stft_output = torch.stft( 42 | reshaped_tensor, 43 | n_fft=self.n_fft, 44 | hop_length=self.hop_length, 45 | window=stft_window, 46 | center=True, 47 | return_complex=False, 48 | ) 49 | 50 | # Rearrange the dimensions of the STFT output to bring the frequency dimension forward. 51 | permuted_stft_output = stft_output.permute([0, 3, 1, 2]) 52 | 53 | # Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions. 54 | final_output = permuted_stft_output.reshape( 55 | [*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]] 56 | ).reshape( 57 | [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]] 58 | ) 59 | 60 | # If the original tensor was on a non-standard device, move the processed tensor back to that device. 61 | if is_non_standard_device: 62 | final_output = final_output.to(self.device) 63 | 64 | # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`). 65 | return final_output[..., : self.dim_f, :] 66 | 67 | def pad_frequency_dimension( 68 | self, 69 | input_tensor, 70 | batch_dimensions, 71 | channel_dim, 72 | freq_dim, 73 | time_dim, 74 | num_freq_bins, 75 | ): 76 | """ 77 | Adds zero padding to the frequency dimension of the input tensor. 78 | """ 79 | # Create a padding tensor for the frequency dimension 80 | freq_padding = torch.zeros( 81 | [*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim] 82 | ).to(input_tensor.device) 83 | 84 | # Concatenate the padding to the input tensor along the frequency dimension. 85 | padded_tensor = torch.cat([input_tensor, freq_padding], -2) 86 | 87 | return padded_tensor 88 | 89 | def calculate_inverse_dimensions(self, input_tensor): 90 | # Extract batch dimensions and frequency-time dimensions. 91 | batch_dimensions = input_tensor.shape[:-3] 92 | channel_dim, freq_dim, time_dim = input_tensor.shape[-3:] 93 | 94 | # Calculate the number of frequency bins for the inverse STFT. 
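# For real-valued input, torch.stft keeps only the one-sided spectrum, so each frame
# of n_fft samples yields n_fft // 2 + 1 frequency bins.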
95 | num_freq_bins = self.n_fft // 2 + 1 96 | 97 | return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins 98 | 99 | def prepare_for_istft( 100 | self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim 101 | ): 102 | """ 103 | Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping 104 | and creating a complex tensor from the real and imaginary parts. 105 | """ 106 | # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT. 107 | reshaped_tensor = padded_tensor.reshape( 108 | [*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim] 109 | ) 110 | 111 | # Flatten batch dimensions and rearrange for ISTFT. 112 | flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim]) 113 | 114 | # Rearrange the dimensions of the tensor to bring the frequency dimension forward. 115 | permuted_tensor = flattened_tensor.permute([0, 2, 3, 1]) 116 | 117 | # Combine real and imaginary parts into a complex tensor. 118 | complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j 119 | 120 | return complex_tensor 121 | 122 | def inverse(self, input_tensor): 123 | # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA). 124 | is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"] 125 | 126 | # If on a non-standard device, temporarily move the tensor to CPU for processing. 127 | if is_non_standard_device: 128 | input_tensor = input_tensor.cpu() 129 | 130 | # Transfer the pre-defined Hann window tensor to the same device as the input tensor. 131 | stft_window = self.hann_window.to(input_tensor.device) 132 | 133 | batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = ( 134 | self.calculate_inverse_dimensions(input_tensor) 135 | ) 136 | 137 | padded_tensor = self.pad_frequency_dimension( 138 | input_tensor, 139 | batch_dimensions, 140 | channel_dim, 141 | freq_dim, 142 | time_dim, 143 | num_freq_bins, 144 | ) 145 | 146 | complex_tensor = self.prepare_for_istft( 147 | padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim 148 | ) 149 | 150 | # Perform the Inverse Short-Time Fourier Transform (ISTFT). 151 | istft_result = torch.istft( 152 | complex_tensor, 153 | n_fft=self.n_fft, 154 | hop_length=self.hop_length, 155 | window=stft_window, 156 | center=True, 157 | ) 158 | 159 | # Reshape ISTFT result to restore original batch and channel dimensions. 160 | final_output = istft_result.reshape([*batch_dimensions, 2, -1]) 161 | 162 | # If the original tensor was on a non-standard device, move the processed tensor back to that device. 163 | if is_non_standard_device: 164 | final_output = final_output.to(self.device) 165 | 166 | return final_output 167 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/__init__.py: -------------------------------------------------------------------------------- 1 | # VR init. 2 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/layers_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from uvr.uvr_lib_v5 import spec_utils 6 | 7 | 8 | class Conv2DBNActiv(nn.Module): 9 | """ 10 | Conv2DBNActiv Class: 11 | This class implements a convolutional layer followed by batch normalization and an activation function. 
12 | It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks. 13 | The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation. 14 | """ 15 | 16 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): 17 | super(Conv2DBNActiv, self).__init__() 18 | 19 | # Sequential model combining Conv2D, BatchNorm, and activation function into a single module 20 | self.conv = nn.Sequential( 21 | nn.Conv2d( 22 | nin, 23 | nout, 24 | kernel_size=ksize, 25 | stride=stride, 26 | padding=pad, 27 | dilation=dilation, 28 | bias=False, 29 | ), 30 | nn.BatchNorm2d(nout), 31 | activ(), 32 | ) 33 | 34 | def __call__(self, input_tensor): 35 | # Forward pass through the sequential model 36 | return self.conv(input_tensor) 37 | 38 | 39 | class Encoder(nn.Module): 40 | """ 41 | Encoder Class: 42 | This class defines an encoder module typically used in autoencoder architectures. 43 | It consists of two convolutional layers, each followed by batch normalization and an activation function. 44 | """ 45 | 46 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): 47 | super(Encoder, self).__init__() 48 | 49 | # First convolutional layer of the encoder 50 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) 51 | # Second convolutional layer of the encoder 52 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) 53 | 54 | def __call__(self, input_tensor): 55 | # Applying the first and then the second convolutional layers 56 | hidden = self.conv1(input_tensor) 57 | hidden = self.conv2(hidden) 58 | 59 | return hidden 60 | 61 | 62 | class Decoder(nn.Module): 63 | """ 64 | Decoder Class: 65 | This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures. 66 | It applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization. 67 | """ 68 | 69 | def __init__( 70 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False 71 | ): 72 | super(Decoder, self).__init__() 73 | # Convolutional layer with optional dropout for regularization 74 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) 75 | # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) 76 | self.dropout = nn.Dropout2d(0.1) if dropout else None 77 | 78 | def __call__(self, input_tensor, skip=None): 79 | # Forward pass through the convolutional layer and optional dropout 80 | input_tensor = F.interpolate( 81 | input_tensor, scale_factor=2, mode="bilinear", align_corners=True 82 | ) 83 | 84 | if skip is not None: 85 | skip = spec_utils.crop_center(skip, input_tensor) 86 | input_tensor = torch.cat([input_tensor, skip], dim=1) 87 | 88 | hidden = self.conv1(input_tensor) 89 | # hidden = self.conv2(hidden) 90 | 91 | if self.dropout is not None: 92 | hidden = self.dropout(hidden) 93 | 94 | return hidden 95 | 96 | 97 | class ASPPModule(nn.Module): 98 | """ 99 | ASPPModule Class: 100 | This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, which is useful for semantic image segmentation tasks. 101 | It captures multi-scale contextual information by applying convolutions at multiple dilation rates. 
102 | """ 103 | 104 | def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): 105 | super(ASPPModule, self).__init__() 106 | 107 | # Global context convolution captures the overall context 108 | self.conv1 = nn.Sequential( 109 | nn.AdaptiveAvgPool2d((1, None)), 110 | Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ), 111 | ) 112 | self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) 113 | self.conv3 = Conv2DBNActiv( 114 | nin, nout, 3, 1, dilations[0], dilations[0], activ=activ 115 | ) 116 | self.conv4 = Conv2DBNActiv( 117 | nin, nout, 3, 1, dilations[1], dilations[1], activ=activ 118 | ) 119 | self.conv5 = Conv2DBNActiv( 120 | nin, nout, 3, 1, dilations[2], dilations[2], activ=activ 121 | ) 122 | self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) 123 | self.dropout = nn.Dropout2d(0.1) if dropout else None 124 | 125 | def forward(self, input_tensor): 126 | _, _, h, w = input_tensor.size() 127 | 128 | # Upsample global context to match input size and combine with local and multi-scale features 129 | feat1 = F.interpolate( 130 | self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True 131 | ) 132 | feat2 = self.conv2(input_tensor) 133 | feat3 = self.conv3(input_tensor) 134 | feat4 = self.conv4(input_tensor) 135 | feat5 = self.conv5(input_tensor) 136 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 137 | out = self.bottleneck(out) 138 | 139 | if self.dropout is not None: 140 | out = self.dropout(out) 141 | 142 | return out 143 | 144 | 145 | class LSTMModule(nn.Module): 146 | """ 147 | LSTMModule Class: 148 | This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling. 149 | It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing. 
150 | """ 151 | 152 | def __init__(self, nin_conv, nin_lstm, nout_lstm): 153 | super(LSTMModule, self).__init__() 154 | # Convolutional layer for initial feature extraction 155 | self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) 156 | 157 | # Bidirectional LSTM for capturing temporal dynamics 158 | self.lstm = nn.LSTM( 159 | input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True 160 | ) 161 | 162 | # Dense layer for output dimensionality matching 163 | self.dense = nn.Sequential( 164 | nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU() 165 | ) 166 | 167 | def forward(self, input_tensor): 168 | N, _, nbins, nframes = input_tensor.size() 169 | 170 | # Extract features and prepare for LSTM 171 | hidden = self.conv(input_tensor)[:, 0] # N, nbins, nframes 172 | hidden = hidden.permute(2, 0, 1) # nframes, N, nbins 173 | hidden, _ = self.lstm(hidden) 174 | 175 | # Apply dense layer and reshape to match expected output format 176 | hidden = self.dense(hidden.reshape(-1, hidden.size()[-1])) # nframes * N, nbins 177 | hidden = hidden.reshape(nframes, N, 1, nbins) 178 | hidden = hidden.permute(1, 2, 3, 0) 179 | 180 | return hidden 181 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/model_param_init.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | default_param = {} 4 | default_param["bins"] = -1 5 | default_param["unstable_bins"] = -1 # training only 6 | default_param["stable_bins"] = -1 # training only 7 | default_param["sr"] = 44100 8 | default_param["pre_filter_start"] = -1 9 | default_param["pre_filter_stop"] = -1 10 | default_param["band"] = {} 11 | 12 | N_BINS = "n_bins" 13 | 14 | 15 | def int_keys(d): 16 | """ 17 | Converts string keys that represent integers into actual integer keys in a list. 18 | 19 | This function is particularly useful when dealing with JSON data that may represent 20 | integer keys as strings due to the nature of JSON encoding. By converting these keys 21 | back to integers, it ensures that the data can be used in a manner consistent with 22 | its original representation, especially in contexts where the distinction between 23 | string and integer keys is important. 24 | 25 | Args: 26 | input_list (list of tuples): A list of (key, value) pairs where keys are strings 27 | that may represent integers. 28 | 29 | Returns: 30 | dict: A dictionary with keys converted to integers where applicable. 31 | """ 32 | # Initialize an empty dictionary to hold the converted key-value pairs. 33 | result_dict = {} 34 | # Iterate through each key-value pair in the input list. 35 | for key, value in d: 36 | # Check if the key is a digit (i.e., represents an integer). 37 | if key.isdigit(): 38 | # Convert the key from a string to an integer. 39 | key = int(key) 40 | result_dict[key] = value 41 | return result_dict 42 | 43 | 44 | class ModelParameters(object): 45 | """ 46 | A class to manage model parameters, including loading from a configuration file. 47 | 48 | Attributes: 49 | param (dict): Dictionary holding all parameters for the model. 50 | """ 51 | 52 | def __init__(self, config_path=""): 53 | """ 54 | Initializes the ModelParameters object by loading parameters from a JSON configuration file. 55 | 56 | Args: 57 | config_path (str): Path to the JSON configuration file. 58 | """ 59 | 60 | # Load parameters from the given configuration file path. 
61 | with open(config_path, "r") as f: 62 | self.param = json.loads(f.read(), object_pairs_hook=int_keys) 63 | 64 | # Ensure certain parameters are set to False if not specified in the configuration. 65 | for k in [ 66 | "mid_side", 67 | "mid_side_b", 68 | "mid_side_b2", 69 | "stereo_w", 70 | "stereo_n", 71 | "reverse", 72 | ]: 73 | if not k in self.param: 74 | self.param[k] = False 75 | 76 | # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'. 77 | if N_BINS in self.param: 78 | self.param["bins"] = self.param[N_BINS] 79 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 16000, 8 | "hl": 512, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 1024, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 16000, 17 | "pre_filter_start": 1023, 18 | "pre_filter_stop": 1024 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 32000, 8 | "hl": 512, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 1024, 12 | "hpf_start": -1, 13 | "res_type": "kaiser_fast" 14 | } 15 | }, 16 | "sr": 32000, 17 | "pre_filter_start": 1000, 18 | "pre_filter_stop": 1021 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 33075, 8 | "hl": 384, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 1024, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 33075, 17 | "pre_filter_start": 1000, 18 | "pre_filter_stop": 1021 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 44100, 8 | "hl": 1024, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 1024, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 44100, 17 | "pre_filter_start": 1023, 18 | "pre_filter_stop": 1024 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl256.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 256, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 44100, 8 | "hl": 256, 9 | "n_fft": 512, 10 | "crop_start": 0, 11 | "crop_stop": 256, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 44100, 17 | "pre_filter_start": 256, 18 | "pre_filter_stop": 256 19 | } -------------------------------------------------------------------------------- 
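The band-split parameter files in this directory are consumed by the ModelParameters class defined in model_param_init.py above. The short sketch below is not part of the repository; it only illustrates how one of these files might be loaded and inspected, assuming the repository root is on the Python path (the chosen file name is just an example).

from uvr.uvr_lib_v5.vr_network.model_param_init import ModelParameters

# Load a multi-band configuration; digit-string band keys ("1", "2", ...) are
# converted to integer keys by int_keys via the object_pairs_hook.
mp = ModelParameters("uvr/uvr_lib_v5/vr_network/modelparams/4band_v2.json")

for band, settings in sorted(mp.param["band"].items()):
    print(band, settings["sr"], settings["hl"], settings["n_fft"])

# Flags missing from the JSON (mid_side, reverse, ...) default to False.
print(mp.param["mid_side"], mp.param["reverse"])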
/uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 44100, 8 | "hl": 512, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 1024, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 44100, 17 | "pre_filter_start": 1023, 18 | "pre_filter_stop": 1024 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 1024, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 44100, 8 | "hl": 512, 9 | "n_fft": 2048, 10 | "crop_start": 0, 11 | "crop_stop": 700, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 44100, 17 | "pre_filter_start": 1023, 18 | "pre_filter_stop": 700 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 512, 3 | "unstable_bins": 0, 4 | "reduction_bins": 0, 5 | "band": { 6 | "1": { 7 | "sr": 44100, 8 | "hl": 512, 9 | "n_fft": 1024, 10 | "crop_start": 0, 11 | "crop_stop": 512, 12 | "hpf_start": -1, 13 | "res_type": "sinc_best" 14 | } 15 | }, 16 | "sr": 44100, 17 | "pre_filter_start": 511, 18 | "pre_filter_stop": 512 19 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/2band_32000.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 768, 3 | "unstable_bins": 7, 4 | "reduction_bins": 705, 5 | "band": { 6 | "1": { 7 | "sr": 6000, 8 | "hl": 66, 9 | "n_fft": 512, 10 | "crop_start": 0, 11 | "crop_stop": 240, 12 | "lpf_start": 60, 13 | "lpf_stop": 118, 14 | "res_type": "sinc_fastest" 15 | }, 16 | "2": { 17 | "sr": 32000, 18 | "hl": 352, 19 | "n_fft": 1024, 20 | "crop_start": 22, 21 | "crop_stop": 505, 22 | "hpf_start": 44, 23 | "hpf_stop": 23, 24 | "res_type": "sinc_medium" 25 | } 26 | }, 27 | "sr": 32000, 28 | "pre_filter_start": 710, 29 | "pre_filter_stop": 731 30 | } 31 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/2band_44100_lofi.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 512, 3 | "unstable_bins": 7, 4 | "reduction_bins": 510, 5 | "band": { 6 | "1": { 7 | "sr": 11025, 8 | "hl": 160, 9 | "n_fft": 768, 10 | "crop_start": 0, 11 | "crop_stop": 192, 12 | "lpf_start": 41, 13 | "lpf_stop": 139, 14 | "res_type": "sinc_fastest" 15 | }, 16 | "2": { 17 | "sr": 44100, 18 | "hl": 640, 19 | "n_fft": 1024, 20 | "crop_start": 10, 21 | "crop_stop": 320, 22 | "hpf_start": 47, 23 | "hpf_stop": 15, 24 | "res_type": "sinc_medium" 25 | } 26 | }, 27 | "sr": 44100, 28 | "pre_filter_start": 510, 29 | "pre_filter_stop": 512 30 | } 31 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/2band_48000.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 768, 3 | "unstable_bins": 7, 4 | "reduction_bins": 705, 5 | "band": { 6 | "1": { 7 | "sr": 6000, 8 
| "hl": 66, 9 | "n_fft": 512, 10 | "crop_start": 0, 11 | "crop_stop": 240, 12 | "lpf_start": 60, 13 | "lpf_stop": 240, 14 | "res_type": "sinc_fastest" 15 | }, 16 | "2": { 17 | "sr": 48000, 18 | "hl": 528, 19 | "n_fft": 1536, 20 | "crop_start": 22, 21 | "crop_stop": 505, 22 | "hpf_start": 82, 23 | "hpf_stop": 22, 24 | "res_type": "sinc_medium" 25 | } 26 | }, 27 | "sr": 48000, 28 | "pre_filter_start": 710, 29 | "pre_filter_stop": 731 30 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/3band_44100.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 768, 3 | "unstable_bins": 5, 4 | "reduction_bins": 733, 5 | "band": { 6 | "1": { 7 | "sr": 11025, 8 | "hl": 128, 9 | "n_fft": 768, 10 | "crop_start": 0, 11 | "crop_stop": 278, 12 | "lpf_start": 28, 13 | "lpf_stop": 140, 14 | "res_type": "polyphase" 15 | }, 16 | "2": { 17 | "sr": 22050, 18 | "hl": 256, 19 | "n_fft": 768, 20 | "crop_start": 14, 21 | "crop_stop": 322, 22 | "hpf_start": 70, 23 | "hpf_stop": 14, 24 | "lpf_start": 283, 25 | "lpf_stop": 314, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 44100, 30 | "hl": 512, 31 | "n_fft": 768, 32 | "crop_start": 131, 33 | "crop_stop": 313, 34 | "hpf_start": 154, 35 | "hpf_stop": 141, 36 | "res_type": "sinc_medium" 37 | } 38 | }, 39 | "sr": 44100, 40 | "pre_filter_start": 757, 41 | "pre_filter_stop": 768 42 | } 43 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/3band_44100_mid.json: -------------------------------------------------------------------------------- 1 | { 2 | "mid_side": true, 3 | "bins": 768, 4 | "unstable_bins": 5, 5 | "reduction_bins": 733, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 768, 11 | "crop_start": 0, 12 | "crop_stop": 278, 13 | "lpf_start": 28, 14 | "lpf_stop": 140, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 22050, 19 | "hl": 256, 20 | "n_fft": 768, 21 | "crop_start": 14, 22 | "crop_stop": 322, 23 | "hpf_start": 70, 24 | "hpf_stop": 14, 25 | "lpf_start": 283, 26 | "lpf_stop": 314, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 44100, 31 | "hl": 512, 32 | "n_fft": 768, 33 | "crop_start": 131, 34 | "crop_stop": 313, 35 | "hpf_start": 154, 36 | "hpf_stop": 141, 37 | "res_type": "sinc_medium" 38 | } 39 | }, 40 | "sr": 44100, 41 | "pre_filter_start": 757, 42 | "pre_filter_stop": 768 43 | } 44 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/3band_44100_msb2.json: -------------------------------------------------------------------------------- 1 | { 2 | "mid_side_b2": true, 3 | "bins": 640, 4 | "unstable_bins": 7, 5 | "reduction_bins": 565, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 108, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 187, 13 | "lpf_start": 92, 14 | "lpf_stop": 186, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 22050, 19 | "hl": 216, 20 | "n_fft": 768, 21 | "crop_start": 0, 22 | "crop_stop": 212, 23 | "hpf_start": 68, 24 | "hpf_stop": 34, 25 | "lpf_start": 174, 26 | "lpf_stop": 209, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 44100, 31 | "hl": 432, 32 | "n_fft": 640, 33 | "crop_start": 66, 34 | "crop_stop": 307, 35 | "hpf_start": 86, 36 | "hpf_stop": 72, 37 | "res_type": "kaiser_fast" 38 | } 39 | }, 40 | "sr": 44100, 41 | "pre_filter_start": 639, 42 | "pre_filter_stop": 640 43 | 
} 44 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 768, 3 | "unstable_bins": 7, 4 | "reduction_bins": 668, 5 | "band": { 6 | "1": { 7 | "sr": 11025, 8 | "hl": 128, 9 | "n_fft": 1024, 10 | "crop_start": 0, 11 | "crop_stop": 186, 12 | "lpf_start": 37, 13 | "lpf_stop": 73, 14 | "res_type": "polyphase" 15 | }, 16 | "2": { 17 | "sr": 11025, 18 | "hl": 128, 19 | "n_fft": 512, 20 | "crop_start": 4, 21 | "crop_stop": 185, 22 | "hpf_start": 36, 23 | "hpf_stop": 18, 24 | "lpf_start": 93, 25 | "lpf_stop": 185, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 22050, 30 | "hl": 256, 31 | "n_fft": 512, 32 | "crop_start": 46, 33 | "crop_stop": 186, 34 | "hpf_start": 93, 35 | "hpf_stop": 46, 36 | "lpf_start": 164, 37 | "lpf_stop": 186, 38 | "res_type": "polyphase" 39 | }, 40 | "4": { 41 | "sr": 44100, 42 | "hl": 512, 43 | "n_fft": 768, 44 | "crop_start": 121, 45 | "crop_stop": 382, 46 | "hpf_start": 138, 47 | "hpf_stop": 123, 48 | "res_type": "sinc_medium" 49 | } 50 | }, 51 | "sr": 44100, 52 | "pre_filter_start": 740, 53 | "pre_filter_stop": 768 54 | } 55 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_mid.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 768, 3 | "unstable_bins": 7, 4 | "mid_side": true, 5 | "reduction_bins": 668, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 186, 13 | "lpf_start": 37, 14 | "lpf_stop": 73, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 11025, 19 | "hl": 128, 20 | "n_fft": 512, 21 | "crop_start": 4, 22 | "crop_stop": 185, 23 | "hpf_start": 36, 24 | "hpf_stop": 18, 25 | "lpf_start": 93, 26 | "lpf_stop": 185, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 22050, 31 | "hl": 256, 32 | "n_fft": 512, 33 | "crop_start": 46, 34 | "crop_stop": 186, 35 | "hpf_start": 93, 36 | "hpf_stop": 46, 37 | "lpf_start": 164, 38 | "lpf_stop": 186, 39 | "res_type": "polyphase" 40 | }, 41 | "4": { 42 | "sr": 44100, 43 | "hl": 512, 44 | "n_fft": 768, 45 | "crop_start": 121, 46 | "crop_stop": 382, 47 | "hpf_start": 138, 48 | "hpf_stop": 123, 49 | "res_type": "sinc_medium" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 740, 54 | "pre_filter_stop": 768 55 | } 56 | -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_msb.json: -------------------------------------------------------------------------------- 1 | { 2 | "mid_side_b": true, 3 | "bins": 768, 4 | "unstable_bins": 7, 5 | "reduction_bins": 668, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 186, 13 | "lpf_start": 37, 14 | "lpf_stop": 73, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 11025, 19 | "hl": 128, 20 | "n_fft": 512, 21 | "crop_start": 4, 22 | "crop_stop": 185, 23 | "hpf_start": 36, 24 | "hpf_stop": 18, 25 | "lpf_start": 93, 26 | "lpf_stop": 185, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 22050, 31 | "hl": 256, 32 | "n_fft": 512, 33 | "crop_start": 46, 34 | "crop_stop": 186, 35 | "hpf_start": 93, 36 | "hpf_stop": 46, 37 | "lpf_start": 164, 38 | "lpf_stop": 186, 39 | "res_type": "polyphase" 40 | }, 41 | "4": { 42 | 
"sr": 44100, 43 | "hl": 512, 44 | "n_fft": 768, 45 | "crop_start": 121, 46 | "crop_stop": 382, 47 | "hpf_start": 138, 48 | "hpf_stop": 123, 49 | "res_type": "sinc_medium" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 740, 54 | "pre_filter_stop": 768 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_msb2.json: -------------------------------------------------------------------------------- 1 | { 2 | "mid_side_b": true, 3 | "bins": 768, 4 | "unstable_bins": 7, 5 | "reduction_bins": 668, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 186, 13 | "lpf_start": 37, 14 | "lpf_stop": 73, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 11025, 19 | "hl": 128, 20 | "n_fft": 512, 21 | "crop_start": 4, 22 | "crop_stop": 185, 23 | "hpf_start": 36, 24 | "hpf_stop": 18, 25 | "lpf_start": 93, 26 | "lpf_stop": 185, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 22050, 31 | "hl": 256, 32 | "n_fft": 512, 33 | "crop_start": 46, 34 | "crop_stop": 186, 35 | "hpf_start": 93, 36 | "hpf_stop": 46, 37 | "lpf_start": 164, 38 | "lpf_stop": 186, 39 | "res_type": "polyphase" 40 | }, 41 | "4": { 42 | "sr": 44100, 43 | "hl": 512, 44 | "n_fft": 768, 45 | "crop_start": 121, 46 | "crop_stop": 382, 47 | "hpf_start": 138, 48 | "hpf_stop": 123, 49 | "res_type": "sinc_medium" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 740, 54 | "pre_filter_stop": 768 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_reverse.json: -------------------------------------------------------------------------------- 1 | { 2 | "reverse": true, 3 | "bins": 768, 4 | "unstable_bins": 7, 5 | "reduction_bins": 668, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 186, 13 | "lpf_start": 37, 14 | "lpf_stop": 73, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 11025, 19 | "hl": 128, 20 | "n_fft": 512, 21 | "crop_start": 4, 22 | "crop_stop": 185, 23 | "hpf_start": 36, 24 | "hpf_stop": 18, 25 | "lpf_start": 93, 26 | "lpf_stop": 185, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 22050, 31 | "hl": 256, 32 | "n_fft": 512, 33 | "crop_start": 46, 34 | "crop_stop": 186, 35 | "hpf_start": 93, 36 | "hpf_stop": 46, 37 | "lpf_start": 164, 38 | "lpf_stop": 186, 39 | "res_type": "polyphase" 40 | }, 41 | "4": { 42 | "sr": 44100, 43 | "hl": 512, 44 | "n_fft": 768, 45 | "crop_start": 121, 46 | "crop_stop": 382, 47 | "hpf_start": 138, 48 | "hpf_stop": 123, 49 | "res_type": "sinc_medium" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 740, 54 | "pre_filter_stop": 768 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_44100_sw.json: -------------------------------------------------------------------------------- 1 | { 2 | "stereo_w": true, 3 | "bins": 768, 4 | "unstable_bins": 7, 5 | "reduction_bins": 668, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 128, 10 | "n_fft": 1024, 11 | "crop_start": 0, 12 | "crop_stop": 186, 13 | "lpf_start": 37, 14 | "lpf_stop": 73, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 11025, 19 | "hl": 128, 20 | "n_fft": 512, 21 | "crop_start": 4, 22 | "crop_stop": 185, 23 | "hpf_start": 36, 24 | "hpf_stop": 18, 25 | "lpf_start": 93, 26 | "lpf_stop": 185, 27 | 
"res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 22050, 31 | "hl": 256, 32 | "n_fft": 512, 33 | "crop_start": 46, 34 | "crop_stop": 186, 35 | "hpf_start": 93, 36 | "hpf_stop": 46, 37 | "lpf_start": 164, 38 | "lpf_stop": 186, 39 | "res_type": "polyphase" 40 | }, 41 | "4": { 42 | "sr": 44100, 43 | "hl": 512, 44 | "n_fft": 768, 45 | "crop_start": 121, 46 | "crop_stop": 382, 47 | "hpf_start": 138, 48 | "hpf_stop": 123, 49 | "res_type": "sinc_medium" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 740, 54 | "pre_filter_stop": 768 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 672, 3 | "unstable_bins": 8, 4 | "reduction_bins": 637, 5 | "band": { 6 | "1": { 7 | "sr": 7350, 8 | "hl": 80, 9 | "n_fft": 640, 10 | "crop_start": 0, 11 | "crop_stop": 85, 12 | "lpf_start": 25, 13 | "lpf_stop": 53, 14 | "res_type": "polyphase" 15 | }, 16 | "2": { 17 | "sr": 7350, 18 | "hl": 80, 19 | "n_fft": 320, 20 | "crop_start": 4, 21 | "crop_stop": 87, 22 | "hpf_start": 25, 23 | "hpf_stop": 12, 24 | "lpf_start": 31, 25 | "lpf_stop": 62, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 14700, 30 | "hl": 160, 31 | "n_fft": 512, 32 | "crop_start": 17, 33 | "crop_stop": 216, 34 | "hpf_start": 48, 35 | "hpf_stop": 24, 36 | "lpf_start": 139, 37 | "lpf_stop": 210, 38 | "res_type": "polyphase" 39 | }, 40 | "4": { 41 | "sr": 44100, 42 | "hl": 480, 43 | "n_fft": 960, 44 | "crop_start": 78, 45 | "crop_stop": 383, 46 | "hpf_start": 130, 47 | "hpf_stop": 86, 48 | "res_type": "kaiser_fast" 49 | } 50 | }, 51 | "sr": 44100, 52 | "pre_filter_start": 668, 53 | "pre_filter_stop": 672 54 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_v2_sn.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 672, 3 | "unstable_bins": 8, 4 | "reduction_bins": 637, 5 | "band": { 6 | "1": { 7 | "sr": 7350, 8 | "hl": 80, 9 | "n_fft": 640, 10 | "crop_start": 0, 11 | "crop_stop": 85, 12 | "lpf_start": 25, 13 | "lpf_stop": 53, 14 | "res_type": "polyphase" 15 | }, 16 | "2": { 17 | "sr": 7350, 18 | "hl": 80, 19 | "n_fft": 320, 20 | "crop_start": 4, 21 | "crop_stop": 87, 22 | "hpf_start": 25, 23 | "hpf_stop": 12, 24 | "lpf_start": 31, 25 | "lpf_stop": 62, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 14700, 30 | "hl": 160, 31 | "n_fft": 512, 32 | "crop_start": 17, 33 | "crop_stop": 216, 34 | "hpf_start": 48, 35 | "hpf_stop": 24, 36 | "lpf_start": 139, 37 | "lpf_stop": 210, 38 | "res_type": "polyphase" 39 | }, 40 | "4": { 41 | "sr": 44100, 42 | "hl": 480, 43 | "n_fft": 960, 44 | "crop_start": 78, 45 | "crop_stop": 383, 46 | "hpf_start": 130, 47 | "hpf_stop": 86, 48 | "convert_channels": "stereo_n", 49 | "res_type": "kaiser_fast" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 668, 54 | "pre_filter_stop": 672 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bins": 672, 3 | "unstable_bins": 8, 4 | "reduction_bins": 530, 5 | "band": { 6 | "1": { 7 | "sr": 7350, 8 | "hl": 80, 9 | "n_fft": 640, 10 | "crop_start": 0, 11 | "crop_stop": 85, 12 | "lpf_start": 25, 13 | "lpf_stop": 53, 14 | "res_type": "polyphase" 15 | }, 
16 | "2": { 17 | "sr": 7350, 18 | "hl": 80, 19 | "n_fft": 320, 20 | "crop_start": 4, 21 | "crop_stop": 87, 22 | "hpf_start": 25, 23 | "hpf_stop": 12, 24 | "lpf_start": 31, 25 | "lpf_stop": 62, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 14700, 30 | "hl": 160, 31 | "n_fft": 512, 32 | "crop_start": 17, 33 | "crop_stop": 216, 34 | "hpf_start": 48, 35 | "hpf_stop": 24, 36 | "lpf_start": 139, 37 | "lpf_stop": 210, 38 | "res_type": "polyphase" 39 | }, 40 | "4": { 41 | "sr": 44100, 42 | "hl": 480, 43 | "n_fft": 960, 44 | "crop_start": 78, 45 | "crop_stop": 383, 46 | "hpf_start": 130, 47 | "hpf_stop": 86, 48 | "res_type": "kaiser_fast" 49 | } 50 | }, 51 | "sr": 44100, 52 | "pre_filter_start": 668, 53 | "pre_filter_stop": 672 54 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/4band_v3_sn.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_bins": 672, 3 | "unstable_bins": 8, 4 | "stable_bins": 530, 5 | "band": { 6 | "1": { 7 | "sr": 7350, 8 | "hl": 80, 9 | "n_fft": 640, 10 | "crop_start": 0, 11 | "crop_stop": 85, 12 | "lpf_start": 25, 13 | "lpf_stop": 53, 14 | "res_type": "polyphase" 15 | }, 16 | "2": { 17 | "sr": 7350, 18 | "hl": 80, 19 | "n_fft": 320, 20 | "crop_start": 4, 21 | "crop_stop": 87, 22 | "hpf_start": 25, 23 | "hpf_stop": 12, 24 | "lpf_start": 31, 25 | "lpf_stop": 62, 26 | "res_type": "polyphase" 27 | }, 28 | "3": { 29 | "sr": 14700, 30 | "hl": 160, 31 | "n_fft": 512, 32 | "crop_start": 17, 33 | "crop_stop": 216, 34 | "hpf_start": 48, 35 | "hpf_stop": 24, 36 | "lpf_start": 139, 37 | "lpf_stop": 210, 38 | "res_type": "polyphase" 39 | }, 40 | "4": { 41 | "sr": 44100, 42 | "hl": 480, 43 | "n_fft": 960, 44 | "crop_start": 78, 45 | "crop_stop": 383, 46 | "hpf_start": 130, 47 | "hpf_stop": 86, 48 | "convert_channels": "stereo_n", 49 | "res_type": "kaiser_fast" 50 | } 51 | }, 52 | "sr": 44100, 53 | "pre_filter_start": 668, 54 | "pre_filter_stop": 672 55 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/modelparams/ensemble.json: -------------------------------------------------------------------------------- 1 | { 2 | "mid_side_b2": true, 3 | "bins": 1280, 4 | "unstable_bins": 7, 5 | "reduction_bins": 565, 6 | "band": { 7 | "1": { 8 | "sr": 11025, 9 | "hl": 108, 10 | "n_fft": 2048, 11 | "crop_start": 0, 12 | "crop_stop": 374, 13 | "lpf_start": 92, 14 | "lpf_stop": 186, 15 | "res_type": "polyphase" 16 | }, 17 | "2": { 18 | "sr": 22050, 19 | "hl": 216, 20 | "n_fft": 1536, 21 | "crop_start": 0, 22 | "crop_stop": 424, 23 | "hpf_start": 68, 24 | "hpf_stop": 34, 25 | "lpf_start": 348, 26 | "lpf_stop": 418, 27 | "res_type": "polyphase" 28 | }, 29 | "3": { 30 | "sr": 44100, 31 | "hl": 432, 32 | "n_fft": 1280, 33 | "crop_start": 132, 34 | "crop_stop": 614, 35 | "hpf_start": 172, 36 | "hpf_stop": 144, 37 | "res_type": "polyphase" 38 | } 39 | }, 40 | "sr": 44100, 41 | "pre_filter_start": 1280, 42 | "pre_filter_stop": 1280 43 | } -------------------------------------------------------------------------------- /uvr/uvr_lib_v5/vr_network/nets_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from . import layers_new as layers 5 | 6 | 7 | class BaseNet(nn.Module): 8 | """ 9 | BaseNet Class: 10 | This class defines the base network architecture for vocal removal. 
It includes a series of encoders for feature extraction, 11 | an ASPP module for capturing multi-scale context, and a series of decoders for reconstructing the output. Additionally, 12 | it incorporates an LSTM module for capturing temporal dependencies. 13 | """ 14 | 15 | def __init__( 16 | self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6)) 17 | ): 18 | super(BaseNet, self).__init__() 19 | # Initialize the encoder layers with increasing output channels for hierarchical feature extraction. 20 | self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) 21 | self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) 22 | self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) 23 | self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) 24 | self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) 25 | 26 | # ASPP module for capturing multi-scale features with different dilation rates. 27 | self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) 28 | 29 | # Decoder layers for upscaling and merging features from different levels of the encoder and ASPP module. 30 | self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) 31 | self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) 32 | self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) 33 | 34 | # LSTM module for capturing temporal dependencies in the sequence of features. 35 | self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) 36 | self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) 37 | 38 | def __call__(self, input_tensor): 39 | # Sequentially pass the input through the encoder layers. 40 | encoded1 = self.enc1(input_tensor) 41 | encoded2 = self.enc2(encoded1) 42 | encoded3 = self.enc3(encoded2) 43 | encoded4 = self.enc4(encoded3) 44 | encoded5 = self.enc5(encoded4) 45 | 46 | # Pass the deepest encoder output through the ASPP module. 47 | bottleneck = self.aspp(encoded5) 48 | 49 | # Sequentially upscale and merge the features using the decoder layers. 50 | bottleneck = self.dec4(bottleneck, encoded4) 51 | bottleneck = self.dec3(bottleneck, encoded3) 52 | bottleneck = self.dec2(bottleneck, encoded2) 53 | # Concatenate the LSTM module output for temporal feature enhancement. 54 | bottleneck = torch.cat([bottleneck, self.lstm_dec2(bottleneck)], dim=1) 55 | bottleneck = self.dec1(bottleneck, encoded1) 56 | 57 | return bottleneck 58 | 59 | 60 | class CascadedNet(nn.Module): 61 | """ 62 | CascadedNet Class: 63 | This class defines a cascaded network architecture that processes input in multiple stages, each stage focusing on different frequency bands. 64 | It utilizes the BaseNet for processing, and combines outputs from different stages to produce the final mask for vocal removal. 65 | """ 66 | 67 | def __init__(self, n_fft, nn_arch_size=51000, nout=32, nout_lstm=128): 68 | super(CascadedNet, self).__init__() 69 | # Calculate frequency bins based on FFT size. 70 | self.max_bin = n_fft // 2 71 | self.output_bin = n_fft // 2 + 1 72 | self.nin_lstm = self.max_bin // 2 73 | self.offset = 64 74 | # Adjust output channels based on the architecture size. 75 | nout = 64 if nn_arch_size == 218409 else nout 76 | 77 | # print(nout, nout_lstm, n_fft) 78 | 79 | # Initialize the network stages, each focusing on different frequency bands and progressively refining the output. 
80 | self.stg1_low_band_net = nn.Sequential( 81 | BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm), 82 | layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0), 83 | ) 84 | self.stg1_high_band_net = BaseNet( 85 | 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2 86 | ) 87 | 88 | self.stg2_low_band_net = nn.Sequential( 89 | BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm), 90 | layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0), 91 | ) 92 | self.stg2_high_band_net = BaseNet( 93 | nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2 94 | ) 95 | 96 | self.stg3_full_band_net = BaseNet( 97 | 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm 98 | ) 99 | 100 | # Output layer for generating the final mask. 101 | self.out = nn.Conv2d(nout, 2, 1, bias=False) 102 | # Auxiliary output layer for intermediate supervision during training. 103 | self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False) 104 | 105 | def forward(self, input_tensor): 106 | # Preprocess input tensor to match the maximum frequency bin. 107 | input_tensor = input_tensor[:, :, : self.max_bin] 108 | 109 | # Split the input into low and high frequency bands. 110 | bandw = input_tensor.size()[2] // 2 111 | l1_in = input_tensor[:, :, :bandw] 112 | h1_in = input_tensor[:, :, bandw:] 113 | 114 | # Process each band through the first stage networks. 115 | l1 = self.stg1_low_band_net(l1_in) 116 | h1 = self.stg1_high_band_net(h1_in) 117 | 118 | # Combine the outputs for auxiliary supervision. 119 | aux1 = torch.cat([l1, h1], dim=2) 120 | 121 | # Prepare inputs for the second stage by concatenating the original and processed bands. 122 | l2_in = torch.cat([l1_in, l1], dim=1) 123 | h2_in = torch.cat([h1_in, h1], dim=1) 124 | 125 | # Process through the second stage networks. 126 | l2 = self.stg2_low_band_net(l2_in) 127 | h2 = self.stg2_high_band_net(h2_in) 128 | 129 | # Combine the outputs for auxiliary supervision. 130 | aux2 = torch.cat([l2, h2], dim=2) 131 | 132 | # Prepare input for the third stage by concatenating all previous outputs with the original input. 133 | f3_in = torch.cat([input_tensor, aux1, aux2], dim=1) 134 | 135 | # Process through the third stage network. 136 | f3 = self.stg3_full_band_net(f3_in) 137 | 138 | # Apply the output layer to generate the final mask and apply sigmoid for normalization. 139 | mask = torch.sigmoid(self.out(f3)) 140 | 141 | # Pad the mask to match the output frequency bin size. 142 | mask = F.pad( 143 | input=mask, 144 | pad=(0, 0, 0, self.output_bin - mask.size()[2]), 145 | mode="replicate", 146 | ) 147 | 148 | # During training, generate and pad the auxiliary output for additional supervision. 149 | if self.training: 150 | aux = torch.cat([aux1, aux2], dim=1) 151 | aux = torch.sigmoid(self.aux_out(aux)) 152 | aux = F.pad( 153 | input=aux, 154 | pad=(0, 0, 0, self.output_bin - aux.size()[2]), 155 | mode="replicate", 156 | ) 157 | return mask, aux 158 | else: 159 | return mask 160 | 161 | # Method for predicting the mask given an input tensor. 162 | def predict_mask(self, input_tensor): 163 | mask = self.forward(input_tensor) 164 | 165 | # If an offset is specified, crop the mask to remove edge artifacts. 166 | if self.offset > 0: 167 | mask = mask[:, :, :, self.offset : -self.offset] 168 | assert mask.size()[3] > 0 169 | 170 | return mask 171 | 172 | # Method for applying the predicted mask to the input tensor to obtain the predicted magnitude. 
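# Unlike predict_mask, this multiplies the input spectrogram by the mask before cropping the time offset.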
173 | def predict(self, input_tensor): 174 | mask = self.forward(input_tensor) 175 | pred_mag = input_tensor * mask 176 | 177 | # If an offset is specified, crop the predicted magnitude to remove edge artifacts. 178 | if self.offset > 0: 179 | pred_mag = pred_mag[:, :, :, self.offset : -self.offset] 180 | assert pred_mag.size()[3] > 0 181 | 182 | return pred_mag 183 | --------------------------------------------------------------------------------
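As a closing illustration of how the CascadedNet above is shaped, the following sketch (not part of the repository) feeds a randomly generated stereo magnitude spectrogram with layout [batch, 2, frequency bins, time frames] through a freshly initialized network. The n_fft value, frame count, and constructor arguments are illustrative assumptions, not values taken from a shipped model.

import torch

from uvr.uvr_lib_v5.vr_network.nets_new import CascadedNet

n_fft = 2048
model = CascadedNet(n_fft=n_fft, nn_arch_size=51000, nout=32, nout_lstm=128)
model.eval()  # eval mode so forward() returns only the mask, not (mask, aux)

# Random stand-in for a stereo magnitude spectrogram: [batch, channels, bins, frames].
spec = torch.rand(1, 2, n_fft // 2 + 1, 256)

with torch.no_grad():
    mask = model.predict_mask(spec)  # mask only, offset frames cropped from both ends
    pred = model.predict(spec)       # input spectrogram multiplied by the mask

print(mask.shape)  # e.g. torch.Size([1, 2, 1025, 128]); 256 - 2 * model.offset frames remain
print(pred.shape)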