├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .eslintrc.json ├── .github └── workflows │ └── unittest-backend.yml ├── .gitignore ├── .media ├── architecture.drawio.xml ├── architecture.png └── hero.png ├── .pylintrc ├── Dockerfile ├── LICENSE ├── README.md ├── assets ├── icon.ico ├── icon.png └── word_lists │ ├── english_abbreviations.txt │ ├── english_initialisms.txt │ ├── german_abbreviations.txt │ ├── german_initialisms.txt │ ├── russian_abbreviations.txt │ ├── russian_initialisms.txt │ ├── spanish_abbreviations.txt │ └── spanish_initialisms.txt ├── backend ├── environment.yml ├── global_test.py ├── setup.py └── voice_smith │ ├── __init__.py │ ├── acoustic_training.py │ ├── cleaning_run.py │ ├── config │ ├── configs.py │ ├── file_extensions.py │ ├── globals.py │ ├── langs.py │ └── symbols.py │ ├── create_splits.py │ ├── finetune_acoustic.py │ ├── g2p │ ├── LICENSE │ ├── dp │ │ ├── __init__.py │ │ ├── configs │ │ │ ├── __init__.py │ │ │ ├── autoreg_config.yaml │ │ │ ├── forward_config.yaml │ │ │ └── logging.yaml │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── predictor.py │ │ │ └── utils.py │ │ ├── phonemizer.py │ │ ├── preprocess.py │ │ ├── preprocessing │ │ │ ├── __init__.py │ │ │ ├── text.py │ │ │ └── utils.py │ │ ├── result.py │ │ ├── train.py │ │ ├── training │ │ │ ├── __init__.py │ │ │ ├── dataset.py │ │ │ ├── decorators.py │ │ │ ├── evaluation.py │ │ │ ├── losses.py │ │ │ ├── metrics.py │ │ │ └── trainer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── infer.py │ │ │ ├── io.py │ │ │ ├── logging.py │ │ │ └── model.py │ ├── lang2transctibe.txt │ ├── parse_dictionary.py │ ├── phones.txt │ ├── run_prediction.py │ ├── run_training.py │ └── text_symbols.txt │ ├── handlers │ └── v2_0_0.py │ ├── inference.py │ ├── model │ ├── acoustic_model.py │ ├── attention.py │ ├── ecapa_tdnn.py │ ├── layers.py │ ├── natural_speech.py │ ├── position_encoding.py │ ├── reference_encoder.py │ └── univnet.py │ ├── preprocessing │ ├── align.py │ ├── copy_files.py │ ├── extract_data.py │ ├── g2p.py │ ├── gen_speaker_embeddings.py │ ├── generate_vocab.py │ ├── get_txt_from_files.py │ ├── ground_truth_alignment.py │ ├── merge_lexika.py │ ├── sample_splitting.py │ ├── text_normalization.py │ ├── transcribe.py │ └── vad.py │ ├── pretrain_acoustic.py │ ├── sample_splitting_run.py │ ├── scripts │ ├── crawl_wiktionary.py │ ├── get_statistics.py │ ├── prep_acoustic_for_fine.py │ └── speaker_encoder_to_torch.py │ ├── server.py │ ├── sql.py │ ├── text_normalization_run.py │ ├── training_run.py │ ├── utils │ ├── audio.py │ ├── audio_test.py │ ├── currencies.py │ ├── dataset.py │ ├── ds_utils.py │ ├── exceptions.py │ ├── export.py │ ├── loggers.py │ ├── loss.py │ ├── metrics.py │ ├── mfa.py │ ├── model.py │ ├── number_normalization.py │ ├── number_normalization_test.py │ ├── optimizer.py │ ├── punctuation.py │ ├── runs.py │ ├── shell.py │ ├── soft_dtw.py │ ├── sql_logger.py │ ├── ssml_parser.py │ ├── text.py │ ├── tokenization.py │ ├── tools.py │ └── wandb_logger.py │ └── vocoder_training.py ├── package.json ├── src ├── App.test.js ├── App.tsx ├── app │ └── store.ts ├── channels.ts ├── components │ ├── audio_player │ │ ├── AudioBottomBar.tsx │ │ └── AudioPlayer.tsx │ ├── breadcrumb │ │ └── BreadcrumbItem.tsx │ ├── cards │ │ ├── DocumentationCard.tsx │ │ └── RunCard.tsx │ ├── charts │ │ ├── LineChart.tsx │ │ └── PieChart.tsx │ ├── help │ │ ├── HelpButton.tsx │ │ └── HelpIcon.tsx │ ├── image │ │ └── Image.tsx │ ├── inputs │ │ ├── AcousticModelTypeInput.tsx │ │ ├── 
AlignmentBatchSizeInput.tsx │ │ ├── BatchSizeInput.tsx │ │ ├── DatasetInput.tsx │ │ ├── DeviceInput.tsx │ │ ├── GradientAccumulationInput.tsx │ │ ├── GradientAccumulationStepsInput.tsx │ │ ├── LanguageSelect.tsx │ │ ├── LearningRateInput.tsx │ │ ├── MaximumWorkersInput.tsx │ │ ├── NameInput.tsx │ │ ├── RunValidationEveryInput.tsx │ │ ├── SkipOnErrorInput.tsx │ │ ├── TrainOnlySpeakerEmbedsUntilInput.tsx │ │ └── TrainingStepsInput.tsx │ ├── log_printer │ │ ├── LogPrinter.tsx │ │ └── Terminal.tsx │ ├── modals │ │ └── NoCloseModal.tsx │ ├── run_management │ │ └── RunManager.tsx │ ├── runs │ │ ├── ProcessingSteps.tsx │ │ ├── RunConfiguration.tsx │ │ └── RunConfigurationForm.tsx │ └── usage_stats │ │ └── UsageStatsRow.tsx ├── config.ts ├── electron │ ├── electron.ts │ ├── handles │ │ ├── cleaningRuns.ts │ │ ├── datasets.ts │ │ ├── docker.ts │ │ ├── files.ts │ │ ├── install.ts │ │ ├── models.ts │ │ ├── preprocessingRuns.ts │ │ ├── sampleSplittingRuns.ts │ │ ├── settings.ts │ │ ├── synthesis.ts │ │ ├── textNormalizationRuns.ts │ │ └── trainingRuns.ts │ └── utils │ │ ├── db.ts │ │ ├── docker.ts │ │ ├── files.ts │ │ ├── globals.ts │ │ └── processes.ts ├── features │ ├── appInfoSlice.ts │ ├── importSettings.ts │ ├── navigationSettingsSlice.ts │ ├── runManagerSlice.ts │ └── usageStatsSlice.ts ├── fonts │ └── atmospheric.ttf ├── global.css ├── index.html ├── interfaces.ts ├── pages │ ├── datasets │ │ ├── Dataset.tsx │ │ ├── DatasetSelection.tsx │ │ ├── Datasets.tsx │ │ ├── ImportSettingsDialog.tsx │ │ └── Speaker.tsx │ ├── documentation │ │ └── Introduction.tsx │ ├── main_loading │ │ ├── InstallerOptions.tsx │ │ └── MainLoading.tsx │ ├── models │ │ ├── Models.tsx │ │ └── Synthesize.tsx │ ├── preprocessing_runs │ │ ├── PreprocessingRunSelection.tsx │ │ ├── PreprocessingRuns.tsx │ │ ├── dataset_cleaning │ │ │ ├── ApplyChanges.tsx │ │ │ ├── ChooseSamples.tsx │ │ │ ├── Configuration.tsx │ │ │ ├── DatasetCleaning.tsx │ │ │ └── Preprocessing.tsx │ │ ├── sample_splitting │ │ │ ├── ApplyChanges.tsx │ │ │ ├── ChooseSamples.tsx │ │ │ ├── Configuration.tsx │ │ │ ├── Preprocessing.tsx │ │ │ └── SampleSplitting.tsx │ │ └── text_normalization │ │ │ ├── ChooseSamples.tsx │ │ │ ├── Configuration.tsx │ │ │ ├── Preprocessing.tsx │ │ │ └── TextNormalization.tsx │ ├── run_queue │ │ └── RunQueue.tsx │ ├── settings │ │ └── Settings.tsx │ └── training_runs │ │ ├── AcousticModelFinetuning.tsx │ │ ├── AcousticStatistics.tsx │ │ ├── AudioStatistic.css │ │ ├── AudioStatistic.tsx │ │ ├── Configuration.tsx │ │ ├── CreateModel.tsx │ │ ├── GroundTruthAlignment.tsx │ │ ├── ImageStatistic.css │ │ ├── ImageStatistic.tsx │ │ ├── Preprocessing.tsx │ │ ├── RunSelection.tsx │ │ ├── SaveModel.tsx │ │ ├── TrainingRuns.tsx │ │ ├── VocoderFineTuning.tsx │ │ └── VocoderStatistics.tsx ├── react-app-env.d.ts ├── renderer.tsx ├── reportWebVitals.js ├── routes.ts ├── setupTests.js └── utils.tsx ├── tsconfig.json ├── webpack.main.config.js ├── webpack.plugins.js ├── webpack.renderer.config.js ├── webpack.rules.js └── yarn.lock /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.241.1/containers/docker-existing-dockerfile 3 | { 4 | "name": "Existing Dockerfile", 5 | 6 | // Sets the run context to one level up instead of the .devcontainer folder. 
7 | "context": "..", 8 | 9 | // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. 10 | "dockerFile": "../Dockerfile" 11 | 12 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 13 | // "forwardPorts": [], 14 | 15 | // Uncomment the next line to run commands after the container is created - for example installing curl. 16 | // "postCreateCommand": "apt-get update && apt-get install -y curl", 17 | 18 | // Uncomment when using a ptrace-based debugger like C++, Go, and Rust 19 | // "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ], 20 | 21 | // Uncomment to use the Docker CLI from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker. 22 | // "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ], 23 | 24 | // Uncomment to connect as a non-root user if you've added one. See https://aka.ms/vscode-remote/containers/non-root. 25 | // "remoteUser": "vscode" 26 | } 27 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | /backend/ 2 | /storage/ 3 | /node_modules/ 4 | /out/ 5 | /src/ 6 | /pretrained_models/ -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es6": true, 5 | "node": true 6 | }, 7 | "extends": [ 8 | "eslint:recommended", 9 | "plugin:@typescript-eslint/eslint-recommended", 10 | "plugin:@typescript-eslint/recommended", 11 | "plugin:import/recommended", 12 | "plugin:import/electron", 13 | "plugin:import/typescript" 14 | ], 15 | "parser": "@typescript-eslint/parser" 16 | } 17 | -------------------------------------------------------------------------------- /.github/workflows/unittest-backend.yml: -------------------------------------------------------------------------------- 1 | name: Unit-test backend 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' 7 | 8 | pull_request: 9 | branches: 10 | - '*' 11 | 12 | schedule: 13 | - cron: '0 0 * * 0' # “At 00:00 on Sunday.” 14 | 15 | env: 16 | CONDA_ENV_NAME: voice_smith 17 | CACHE_NUMBER: 0 # increase to reset cache manually 18 | 19 | jobs: 20 | build: 21 | strategy: 22 | matrix: 23 | include: 24 | - 25 | os: ubuntu-latest 26 | label: linux-64 27 | prefix: /usr/share/miniconda3/envs/voice_smith 28 | 29 | name: ${{ matrix.label }} 30 | runs-on: ${{ matrix.os }} 31 | steps: 32 | - 33 | name: Checkout 34 | uses: actions/checkout@v2 35 | 36 | - name: Setup Mambaforge 37 | uses: conda-incubator/setup-miniconda@v2 38 | with: 39 | miniforge-variant: Mambaforge 40 | miniforge-version: latest 41 | activate-environment: voice_smith 42 | use-mamba: true 43 | 44 | - name: Set cache date 45 | run: echo "DATE=$(date +'%Y%m%d')" >> $GITHUB_ENV 46 | 47 | - uses: actions/cache@v2 48 | with: 49 | path: ${{ matrix.prefix }} 50 | key: ${{ matrix.label }}-conda-${{ hashFiles('./backend/environment.yml') }}-${{ env.DATE }}-${{ env.CACHE_NUMBER }} 51 | id: cache 52 | 53 | - name: Update environment 54 | run: mamba env update -n $CONDA_ENV_NAME -f ./backend/environment.yml 55 | if: steps.cache.outputs.cache-hit != 'true' 56 | 57 | - name: Linting using Pylint 58 | run: pylint ./backend 59 | continue-on-error: true 60 | 61 | - name: Run pytest in backend 62 | shell: bash -l {0} 63 | run: conda run -n $CONDA_ENV_NAME 
python -m pytest ./backend 64 | 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | 9 | # Diagnostic reports (https://nodejs.org/api/report.html) 10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 11 | 12 | # Runtime data 13 | pids 14 | *.pid 15 | *.seed 16 | *.pid.lock 17 | 18 | # Directory for instrumented libs generated by jscoverage/JSCover 19 | lib-cov 20 | 21 | # Coverage directory used by tools like istanbul 22 | coverage 23 | *.lcov 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 | # Bower dependency directory (https://bower.io/) 32 | bower_components 33 | 34 | # node-waf configuration 35 | .lock-wscript 36 | 37 | # Compiled binary addons (https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # TypeScript cache 48 | *.tsbuildinfo 49 | 50 | # Optional npm cache directory 51 | .npm 52 | 53 | # Optional eslint cache 54 | .eslintcache 55 | 56 | # Microbundle cache 57 | .rpt2_cache/ 58 | .rts2_cache_cjs/ 59 | .rts2_cache_es/ 60 | .rts2_cache_umd/ 61 | 62 | # Optional REPL history 63 | .node_repl_history 64 | 65 | # Output of 'npm pack' 66 | *.tgz 67 | 68 | # Yarn Integrity file 69 | .yarn-integrity 70 | 71 | # dotenv environment variables file 72 | .env 73 | .env.test 74 | 75 | # parcel-bundler cache (https://parceljs.org/) 76 | .cache 77 | 78 | # Next.js build output 79 | .next 80 | 81 | # Nuxt.js build / generate output 82 | .nuxt 83 | dist 84 | 85 | # Gatsby files 86 | .cache/ 87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js 88 | # https://nextjs.org/blog/next-9-1#public-directory-support 89 | # public 90 | 91 | # vuepress build output 92 | .vuepress/dist 93 | 94 | # Serverless directories 95 | .serverless/ 96 | 97 | # FuseBox cache 98 | .fusebox/ 99 | 100 | # pytest cache 101 | .pytest_cache/ 102 | 103 | # DynamoDB Local files 104 | .dynamodb/ 105 | 106 | # TernJS port file 107 | .tern-port 108 | 109 | # Python Auto-Generated 110 | __pycache__/ 111 | 112 | # Build directories 113 | /build/ 114 | /backend_dist/ 115 | /dist/ 116 | /entry.build/ 117 | /entry.dist/ 118 | /assets/*.pt 119 | *.egg-info/ 120 | 121 | # Enviroment 122 | /env/ 123 | 124 | # User data 125 | /voice_smith.db 126 | wandb/ 127 | /out/ 128 | /.webpack/ 129 | /backup/ 130 | /data/ 131 | /assets/tiny_bert/ 132 | /pretrained_models/ 133 | /new_storage_path_plus/ 134 | /backend/voice_smith/g2p/checkpoints*/ 135 | /backend/voice_smith/g2p/datasets*/ 136 | /backend/voice_smith/g2p/dictionaries*/ 137 | /assets/g2p/ -------------------------------------------------------------------------------- /.media/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/.media/architecture.png -------------------------------------------------------------------------------- /.media/hero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/.media/hero.png 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3 2 | RUN apt-get update 3 | RUN apt-get install build-essential -y 4 | RUN useradd -ms /bin/bash voice_smith 5 | WORKDIR /home/voice_smith 6 | COPY ./assets /home/voice_smith/assets 7 | USER voice_smith 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoiceSmith [Work in Progress] 2 | 3 | VoiceSmith makes it possible to train and infer on both single and multispeaker models without any coding experience. It fine-tunes a pretty solid text to speech pipeline based on a modified version of [DelightfulTTS](https://arxiv.org/abs/2110.12612) and [UnivNet](https://arxiv.org/abs/2106.07889) on your dataset. Both models were pretrained on a proprietary 5000 speaker dataset. It also provides some tools for dataset preprocessing like automatic text normalization. 4 | 5 | If you want to play around with a model trained on a highly emotional 60 speaker dataset using an earlier version of this software [click here](https://colab.research.google.com/drive/1zh6w_TpEAyr_UIojiLmt4ZdYLWeap9mn#scrollTo=vQCA50dao0Mt). 6 | 7 | 8 | 9 | ## Requirements 10 | 11 | #### Hardware 12 | * OS: Windows (only CPU supported currently) or any Linux-based operating system. If you want to run this on macOS you have to follow the steps in Build from source in order to create the installer. This is untested since I don't currently own a Mac. 13 | * Graphics: An NVIDIA GPU with [CUDA support](https://developer.nvidia.com/cuda-gpus) is highly recommended; you can train on the CPU otherwise, but it will take days if not weeks. 14 | * RAM: 8GB of RAM; you can try with less, but it may not work. 15 | 16 | #### Software 17 | * Docker; you can [download it here](https://docs.docker.com/get-docker/). If you are on Linux, it is advised to install [Docker Engine](https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository) instead of Docker Desktop, as Docker Desktop likes to make the world complicated. 18 | 19 | ## How to install 20 | 21 | 1. Download the latest installer from the [releases page](https://github.com/dunky11/voicesmith/releases). 22 | 2. Double-click to run the installer. 23 | 24 | ## How to develop 25 | 26 | 1. Make sure you have the latest version of [Node.js](https://nodejs.org/) installed. 27 | 2. Clone the repository: 28 | 29 | ``` 30 | git clone https://github.com/dunky11/voicesmith 31 | ``` 32 | 3. Install dependencies; this can take a minute: 33 | 34 | ``` 35 | cd voicesmith 36 | npm install 37 | ``` 38 | 4. [Click here](https://drive.google.com/drive/folders/15VQgRxGO_Z_RUNMyuJreg9O5Ckcit2vh?usp=sharing), select the folder with the latest version, download all files and place them inside the repository's assets folder. 39 | 40 | 5. Start the project: 41 | 42 | ``` 43 | npm start 44 | ``` 45 | 46 | ## Build from source 47 | 48 | 1. Follow steps 1 - 4 from above. 49 | 2. Run the make script; this will create a folder named out/make with an installer inside. The installer will be different based on your operating system. 50 | 51 | ``` 52 | npm run make 53 | ``` 54 | ## Architecture 55 | 56 | VoiceSmith currently uses a two-stage modified DelightfulTTS and UnivNet pipeline. 57 | 58 | 59 | 60 | 61 | ## Contribute 62 | 63 | Show your support by ⭐ the project.
Pull requests are always welcome. 64 | 65 | ## License 66 | 67 | This project is licensed under the Apache-2.0 license - see the [LICENSE.md](https://github.com/dunky11/voicesmith/blob/master/LICENSE) file for details. 68 | -------------------------------------------------------------------------------- /assets/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/assets/icon.ico -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/assets/icon.png -------------------------------------------------------------------------------- /assets/word_lists/german_initialisms.txt: -------------------------------------------------------------------------------- 1 | ABC 2 | ABS 3 | ACE 4 | ADAC 5 | ADHS 6 | ADS 7 | AG 8 | AGB 9 | AMS 10 | AN 11 | ARD 12 | AfD 13 | BA 14 | BASF 15 | BDM 16 | BGA 17 | BGB 18 | BGE 19 | BGH 20 | BGS 21 | BH 22 | BKA 23 | BLZ 24 | BMVIT 25 | BMW 26 | BND 27 | BNE 28 | BR 29 | BRV 30 | BVG 31 | BWL 32 | BtMG 33 | CBD 34 | CCC 35 | CDU 36 | CFK 37 | CSD 38 | CSEL 39 | CSU 40 | CeBIT 41 | DB 42 | DBVS 43 | DDR 44 | DFB 45 | DGB 46 | DI 47 | DKP 48 | DLR 49 | DM 50 | DNS 51 | DPA 52 | DRK 53 | DTB 54 | DWB 55 | DaF 56 | DaZ 57 | EDU 58 | EDV 59 | EG 60 | EJRM 61 | EL 62 | ELGA 63 | EM 64 | EPO 65 | ESM 66 | ESP 67 | ETW 68 | EU 69 | EVP 70 | EWG 71 | EWSA 72 | EWSD 73 | EZB 74 | FC 75 | FCKW 76 | FDP 77 | FH 78 | FKK 79 | FOC 80 | FWF 81 | GAU 82 | GDI 83 | GEZ 84 | GG 85 | GH 86 | GPS 87 | GSG 88 | GSZ 89 | GUS 90 | GmbH 91 | HB 92 | HGW 93 | HJ 94 | HNO 95 | HQ 96 | HWB 97 | IBN 98 | IBS 99 | IGBCE 100 | IHS 101 | IO 102 | IOK 103 | IPA 104 | IPR 105 | JKU 106 | KAGB 107 | KBC 108 | KFZ 109 | KGV 110 | KI 111 | KMU 112 | KP 113 | LDVH 114 | LSD 115 | MAN 116 | MEZ 117 | MG 118 | MP 119 | MPI 120 | MfS 121 | NDH 122 | NDW 123 | NHN 124 | NLW 125 | NN 126 | NPD 127 | Nm 128 | OGH 129 | OHG 130 | OLG 131 | ONS 132 | OOS 133 | OPCW 134 | ORB 135 | OSZE 136 | OVCW 137 | OWL 138 | PLK 139 | PM 140 | PTBS 141 | PZE 142 | PdA 143 | Pz 144 | PzB 145 | RAF 146 | RIF 147 | RKL 148 | RPzB 149 | RWE 150 | RZ 151 | SED 152 | SGDI 153 | SGP 154 | SP 155 | SPD 156 | SPDL 157 | SPS 158 | SUV 159 | SVG 160 | SVP 161 | StVO 162 | TTS 163 | TU 164 | UKW 165 | UPN 166 | USSR 167 | UVP 168 | VAE 169 | VDA 170 | VDP 171 | VHS 172 | VN 173 | VO 174 | VPG 175 | VSV 176 | VWL 177 | VfB 178 | WBF 179 | WDR 180 | WKN 181 | a.a.O. 182 | glp 183 | m.E. 184 | mfG 185 | mnl. 186 | n.V. 187 | v.l.n.r. 188 | ÖAAB 189 | ÖIF 190 | -------------------------------------------------------------------------------- /assets/word_lists/russian_initialisms.txt: -------------------------------------------------------------------------------- 1 | АБК 2 | АДД 3 | АЗС 4 | АК-47 5 | АКПП 6 | АНБ 7 | АРГК 8 | АТР 9 | АУЕ 10 | Б.М. 11 | БМП 12 | БНО 13 | БСБ 14 | БТР 15 | БЧБ 16 | Би-би-си 17 | ВВП 18 | ВВС 19 | ВДНХ 20 | ВМФ 21 | ВНЖ 22 | ВНП 23 | ВРК 24 | ВСУ 25 | ВТО 26 | ВЦИК 27 | ВЧК 28 | ГД 29 | ГКЧП 30 | ГОЭЛРО 31 | ГТО 32 | ДЗ 33 | ДНК 34 | ДНР 35 | ДТП 36 | ЕАО 37 | ЕАС 38 | ЕС 39 | ЕСПЧ 40 | ЖКХ 41 | ЗРК 42 | ЗС 43 | И.О. 
44 | ИК 45 | ИП 46 | КВР 47 | КГБ 48 | ККК 49 | КНДР 50 | КНР 51 | КПК 52 | КПП 53 | КПСС 54 | ЛГБТ 55 | ЛНР 56 | ЛПЗП 57 | ЛУ 58 | МВД 59 | МВФ 60 | МГУ 61 | МК 62 | МКС 63 | МОБ 64 | МТКК 65 | МФА 66 | МЧС 67 | Мб 68 | НДС 69 | НКВД 70 | НКГБ 71 | НКО 72 | НКР 73 | НЛО 74 | НПО 75 | ОАЭ 76 | ОБЖ 77 | ОБСЕ 78 | ОДКБ 79 | ОЗХО 80 | ОКБ 81 | ООО 82 | ОПК 83 | ОРВИ 84 | ПВО 85 | ПГМ 86 | ПЖиВ 87 | ПМЖ 88 | ПМР 89 | ПО 90 | ПРО 91 | ПТУ 92 | РБ 93 | РЖД 94 | РКК 95 | РНК 96 | РНР 97 | РПЦ 98 | РСФСР 99 | РФ 100 | СК 101 | СКФО 102 | СНВ 103 | СНГ 104 | СНК 105 | СНХ 106 | СП 107 | СПб 108 | СССР 109 | СФСР 110 | США 111 | ТВ 112 | ТОФСЭ 113 | ТРКИ 114 | ТТС 115 | УЗИ 116 | Ф.И.О. 117 | ФБК 118 | ФБР 119 | ФИО 120 | ФМС 121 | ФНС 122 | ФСБ 123 | ФСИН 124 | ФТС 125 | ХВ 126 | ЦБК 127 | ЦК 128 | ЦРУ 129 | ЧВК 130 | ЧК 131 | ЭВМ 132 | ЮВА 133 | б.г. 134 | б.м. 135 | б.у. 136 | б/г 137 | б/м 138 | б/у 139 | бн/о 140 | д.-в.-н. 141 | д.-в.-нем. 142 | дз 143 | драп 144 | и.о. 145 | м.б. 146 | м.г. 147 | мл 148 | п.г. 149 | п.т.ч. 150 | п/г 151 | п/м 152 | п/пр. 153 | пгт 154 | с.г. 155 | с.м. 156 | с.ч. 157 | с.ш. 158 | с/г 159 | с/м 160 | с/ч 161 | т.г. 162 | т.е. 163 | т.к. 164 | т.м. 165 | т.н. 166 | т/г 167 | т/м 168 | хз 169 | эсер 170 | -------------------------------------------------------------------------------- /assets/word_lists/spanish_abbreviations.txt: -------------------------------------------------------------------------------- 1 | & 2 | 1.º 3 | 2.º 4 | 2a 5 | 2o 6 | 2ª 7 | 2º 8 | 3ª 9 | 3º 10 | 4º 11 | 5ª 12 | 5º 13 | AEUMC 14 | AG 15 | ARA 16 | ARC 17 | Ags 18 | Ags. 19 | Aguas 20 | Atte. 21 | BCN 22 | C's 23 | CAN 24 | CDMX 25 | CE 26 | CH 27 | CI 28 | CL 29 | CM 30 | CS 31 | Calz. 32 | Cdad. 33 | Crta. 34 | Cs 35 | D. 36 | DCM 37 | DU 38 | Dr. 39 | Dra. 40 | Drs. 41 | FAB 42 | FAC 43 | Fdez 44 | G. 45 | GR 46 | GT 47 | Glez 48 | Gral. 49 | Gzlez 50 | HG 51 | IVA 52 | Juº 53 | LATAM 54 | M.ª 55 | MC 56 | ML 57 | MX 58 | Ma. 59 | Mª 60 | NE 61 | NL 62 | NLE 63 | NO 64 | OA 65 | ONT 66 | PGR 67 | PMA 68 | PNN 69 | PU 70 | Play 71 | QE 72 | RFA 73 | Rdguez. 74 | SE 75 | SI 76 | SL 77 | SN 78 | Sin. 79 | Sr. 80 | Sra. 81 | Srta. 82 | Sta. 83 | Stgo. 84 | TA 85 | TB 86 | TCM 87 | TL 88 | TQM 89 | TUA 90 | Ud. 91 | Uds. 92 | VC 93 | Val. 94 | Vd. 95 | Vds. 96 | Vmd. 97 | Vzla. 98 | YC 99 | ZA 100 | Zac 101 | aC 102 | ab. 103 | adj. 104 | admón. 105 | adv. 106 | ant. 107 | aprox 108 | astrol. 109 | aum. 110 | av 111 | avda 112 | avda. 113 | ayto. 114 | bpd 115 | c/ 116 | cdla 117 | cdo. 118 | clás. 119 | coloq. 120 | coop 121 | cs 122 | cta 123 | ctvo. 124 | ctvs. 125 | d.C. 126 | dcha. 127 | der. 128 | dim. 129 | dls 130 | dpto 131 | drcha 132 | dud. 133 | endocrino 134 | etc. 135 | ext. 136 | faj. 137 | fol 138 | fr. 139 | frs. 140 | gral 141 | gral. 142 | gén. 143 | hb 144 | hisp. 145 | hno. 146 | inf. 147 | int. 148 | it. 149 | izda 150 | izq 151 | izq. 152 | lat. 153 | mts. 154 | n. 155 | n.º 156 | no 157 | núm. 158 | p. 159 | pasada 160 | post. 161 | pp. 162 | pq 163 | pta 164 | pto 165 | pulg 166 | pza 167 | pág 168 | qn 169 | qq 170 | q̃ 171 | rep. 172 | s/n 173 | servo 174 | señá 175 | sida 176 | slds 177 | sma. 178 | sup. 179 | tamb 180 | tb 181 | tmb 182 | vdd 183 | vds 184 | velc 185 | vuecencia 186 | wn 187 | xq 188 | ár. 189 | -------------------------------------------------------------------------------- /assets/word_lists/spanish_initialisms.txt: -------------------------------------------------------------------------------- 1 | A.T. 
2 | ACS 3 | ACV 4 | ADE 5 | ADN 6 | AEC 7 | AL 8 | ALDF 9 | ALV 10 | ANHQV 11 | ANSV 12 | ARN 13 | AT 14 | AUE 15 | AUH 16 | AXJ 17 | BCB 18 | BCE 19 | BID 20 | BM 21 | BOE 22 | BOPE 23 | BS 24 | BTT 25 | CA 26 | CAE 27 | CAISS 28 | CCD 29 | CCNCC 30 | CDS 31 | CEI 32 | CEPAL 33 | CES 34 | CF 35 | CFC 36 | CGPJ 37 | CHE 38 | CI 39 | CJNG 40 | CNI 41 | CNMC 42 | CNT 43 | CSD 44 | CyL 45 | D.F. 46 | DANA 47 | DCV 48 | DEP 49 | DF 50 | DGT 51 | DLE 52 | DNI 53 | DPN 54 | DRAE 55 | DT 56 | EAU 57 | ECG 58 | EDAR 59 | EGDE 60 | ELN 61 | EPD 62 | EPOC 63 | ERC 64 | ETS 65 | ETT 66 | EU 67 | EUA 68 | FA 69 | FCF 70 | FMI 71 | FMLN 72 | FRA 73 | FSLN 74 | GC 75 | GH 76 | GNL 77 | HBP 78 | HDA 79 | HDB 80 | HDLGP 81 | HDP 82 | IA 83 | IBEX 84 | ICEX 85 | ICFT 86 | IDH 87 | IES 88 | IFE 89 | IGN 90 | IMAO 91 | IMC 92 | IME 93 | IMSS 94 | INAH 95 | INSS 96 | IRA 97 | IRPF 98 | ISRS 99 | ITS 100 | ITU 101 | IU 102 | JCE 103 | JLB 104 | JRG 105 | LATAM 106 | LQSA 107 | LSE 108 | MDP 109 | MDQ 110 | MIR 111 | MPR 112 | MYHYV 113 | NIE 114 | NMC 115 | NNA 116 | NOM 117 | NPI 118 | OEA 119 | OGM 120 | OMC 121 | OMI 122 | OMM 123 | OMS 124 | OMT 125 | PA 126 | PAN 127 | PBC 128 | PCC 129 | PCR 130 | PCUS 131 | PFM 132 | PIN 133 | PNL 134 | PNV 135 | PP 136 | PPE 137 | PPK 138 | PPP 139 | PRI 140 | PRM 141 | PSUV 142 | PUC 143 | PVP 144 | QR 145 | RAAN 146 | RAAS 147 | RAE 148 | RAP 149 | RCN 150 | RCP 151 | RDSI 152 | RPC 153 | RU 154 | S.A. 155 | S.T.D. 156 | SAG 157 | SAI 158 | SARM 159 | SCA 160 | SD 161 | SHCP 162 | SMI 163 | SRL 164 | TAC 165 | TAPO 166 | TC 167 | TCA 168 | TDAH 169 | TEDH 170 | TEP 171 | THC 172 | TIC 173 | TKM 174 | TLC 175 | TLCAN 176 | TMA 177 | TOC 178 | TU 179 | TV3 180 | TVE 181 | alv 182 | bdd 183 | bpd 184 | dana 185 | e.d. 186 | mcd 187 | mcm 188 | mdd 189 | mde 190 | mdp 191 | msnm 192 | -------------------------------------------------------------------------------- /backend/environment.yml: -------------------------------------------------------------------------------- 1 | name: voice_smith 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - pytorch 6 | - anaconda 7 | - photosynthesis-team 8 | 9 | dependencies: 10 | - python=3.8 11 | - flask=2.1.2 12 | - flask-cors=3.0.10 13 | - pytorch::torchvision=0.12.0 14 | - pytorch::torchaudio=0.11.0 15 | - pytorch::pytorch=1.11.0 16 | - pytorch::torchtext==0.12.0 17 | - conda-forge::cudatoolkit=11.3.1 18 | - psutil=5.9.0 19 | - einops=0.4.1 20 | - fire=0.4.0 21 | - waitress=2.1.1 22 | - pysoundfile 23 | - conda-forge::montreal-forced-aligner=2.0.5 24 | - numpy=1.22.4 25 | - anaconda::cython=0.29.28 26 | - conda-forge::torchmetrics=0.9.1 27 | - conda-forge::spacy=3.3.1 28 | - photosynthesis-team::piq=0.7.0 29 | - conda-forge::monkeytype=22.2.0 30 | - conda-forge::librosa=0.9.2 31 | - pytest 32 | - pip 33 | - pip: 34 | - -e . 
35 | - tgt==1.4.1 36 | - nemo-toolkit[all]==1.8.2 37 | - pythainlp==3.0.8 38 | - git+https://github.com/savoirfairelinux/num2words.git # Package is not up to date in both pip and conda 39 | -------------------------------------------------------------------------------- /backend/global_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import glob 4 | 5 | 6 | def test_should_import_modules(): 7 | in_dir = str(Path(os.path.dirname(__file__)) / "voice_smith") 8 | for module in glob.iglob(f"{in_dir}/**/*.py"): 9 | if module == "__init__.py" or module[-3:] != ".py": 10 | continue 11 | __import__(module[:-3].replace("/", "."), locals(), globals()) 12 | 13 | -------------------------------------------------------------------------------- /backend/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="voice_smith", 5 | version="1.0.0", 6 | description="VoiceSmith makes training TTS models easy. It also provides utilities for dataset preprocessing.", 7 | author="dunky11", 8 | packages=["voice_smith"], 9 | ) 10 | -------------------------------------------------------------------------------- /backend/voice_smith/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/config/file_extensions.py: -------------------------------------------------------------------------------- 1 | SUPPORTED_AUDIO_EXTENSIONS = [ 2 | ".wav", 3 | ".flac", 4 | ".ogg", 5 | ".sph", 6 | ] 7 | SUPPORTED_TEXT_EXTENSIONS = [".txt"] 8 | -------------------------------------------------------------------------------- /backend/voice_smith/config/globals.py: -------------------------------------------------------------------------------- 1 | DB_PATH = "/home/voice_smith/db/voice_smith.db" 2 | ASSETS_PATH = "/home/voice_smith/assets" 3 | USER_DATA_PATH = "/home/data" 4 | TRAINING_RUNS_PATH = "/home/voice_smith/data/training_runs" 5 | TEXT_NORMALIZATION_RUNS_PATH = "/home/voice_smith/data/text_normalization_runs" 6 | SAMPLE_SPLITTING_RUNS_PATH = "/home/voice_smith/data/sample_splitting_runs" 7 | CLEANING_RUNS_PATH = "/home/voice_smith/data/cleaning_runs" 8 | MODELS_PATH = "/home/voice_smith/data/models" 9 | DATASETS_PATH = "/home/voice_smith/data/datasets" 10 | AUDIO_SYNTH_PATH = "/home/voice_smith/data/audio_synth" 11 | ENVIRONMENT_NAME = "voice_smith" 12 | -------------------------------------------------------------------------------- /backend/voice_smith/config/langs.py: -------------------------------------------------------------------------------- 1 | SUPPORTED_LANGUAGES = [ 2 | "bg", 3 | "cs", 4 | "de", 5 | "en", 6 | "es", 7 | "fr", 8 | "ha", 9 | "hr", 10 | "ko", 11 | "pl", 12 | "pt", 13 | "ru", 14 | "sv", 15 | "sw", 16 | "th", 17 | "tr", 18 | "uk", 19 | "vi", 20 | "zh", 21 | ] 22 | 23 | # Mappings from symbol to numeric ID and vice versa: 24 | lang2id = {s: i for i, s in enumerate(SUPPORTED_LANGUAGES)} 25 | id2lang = {i: s for i, s in enumerate(SUPPORTED_LANGUAGES)} 26 | -------------------------------------------------------------------------------- /backend/voice_smith/config/symbols.py: -------------------------------------------------------------------------------- 1 | symbols = 
[str(el) for el in range(256)] 2 | symbol2id = {s: i for i, s in enumerate(symbols)} 3 | id2symbol = {i: s for i, s in enumerate(symbols)} 4 | -------------------------------------------------------------------------------- /backend/voice_smith/create_splits.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | from voice_smith.config.globals import TRAINING_RUNS_PATH  # defined as a plain string, so wrap it in Path() below 4 | 5 | VAL_SIZE = 2000 6 | 7 | paths = [ 8 | Path(TRAINING_RUNS_PATH) / "pretraining_two_stage_ac" / "data" / "train.txt", 9 | Path(TRAINING_RUNS_PATH) / "pretraining_two_stage_ac" / "data" / "val.txt", 10 | ] 11 | 12 | lines = [] 13 | 14 | for path in paths: 15 | with open(path, "r") as f: 16 | for line in f: 17 | lines.append(line) 18 | 19 | random.shuffle(lines) 20 | 21 | train_lines, val_lines = lines[VAL_SIZE:], lines[:VAL_SIZE] 22 | 23 | with open( 24 | Path(TRAINING_RUNS_PATH) 25 | / "pretraining_two_stage_ac" 26 | / "data" 27 | / "train_finetuning.txt", 28 | "w", 29 | ) as f: 30 | for line in train_lines: 31 | f.write(line) 32 | 33 | with open( 34 | Path(TRAINING_RUNS_PATH) 35 | / "pretraining_two_stage_ac" 36 | / "data" 37 | / "val_finetuning.txt", 38 | "w", 39 | ) as f: 40 | for line in val_lines: 41 | f.write(line) 42 | -------------------------------------------------------------------------------- /backend/voice_smith/finetune_acoustic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fire 3 | from pathlib import Path 4 | from voice_smith.acoustic_training import train_acoustic 5 | 6 | if __name__ == "__main__": 7 | from voice_smith.config.preprocess_config import preprocess_config 8 | from voice_smith.config.acoustic_fine_tuning_config import ( 9 | acoustic_fine_tuning_config, 10 | ) 11 | from voice_smith.config.acoustic_model_config import acoustic_model_config 12 | from voice_smith.utils.wandb_logger import WandBLogger 13 | import wandb 14 | 15 | def pass_args(training_run_name, checkpoint_path=None): 16 | training_run_name = str(training_run_name) 17 | if checkpoint_path is None: 18 | reset = True 19 | checkpoint_path = str(Path(".") / "assets" / "acoustic_pretrained.pt") 20 | else: 21 | checkpoint_path = str(checkpoint_path) 22 | reset = False 23 | device = ( 24 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 25 | ) 26 | logger = WandBLogger(training_run_name) 27 | wandb.config.update( 28 | { 29 | "preprocess_config": preprocess_config, 30 | "model_config": acoustic_model_config, 31 | "training_config": acoustic_fine_tuning_config, 32 | }, 33 | allow_val_change=True, 34 | ) 35 | train_acoustic( 36 | training_run_name=training_run_name, 37 | preprocess_config=preprocess_config, 38 | model_config=acoustic_model_config, 39 | train_config=acoustic_fine_tuning_config, 40 | checkpoint_path=checkpoint_path, 41 | logger=logger, 42 | device=device, 43 | fine_tuning=True, 44 | reset=reset, 45 | ) 46 | 47 | fire.Fire(pass_args) 48 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Axel Springer News Media & Tech GmbH & Co.
KG - Ideas Engineering 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/__init__.py: -------------------------------------------------------------------------------- 1 | from voice_smith.g2p.dp.result import Prediction, PhonemizerResult -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/g2p/dp/configs/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/configs/forward_config.yaml: -------------------------------------------------------------------------------- 1 | paths: 2 | checkpoint_dir: checkpoints # Directory to store model checkpoints and tensorboard, will be created if not existing. 3 | data_dir: datasets # Directory to store processed data, will be created if not existing. 4 | 5 | preprocessing: 6 | languages: ["en_us"] # All languages in the dataset. 7 | 8 | # Text (grapheme) and phoneme symbols, either provide a string or list of strings. 9 | # Symbols in the dataset will be filtered according to these lists! 10 | text_symbols: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZäöüÄÖÜß'" 11 | phoneme_symbols: 12 | [ 13 | "a", 14 | "b", 15 | "d", 16 | "e", 17 | "f", 18 | "g", 19 | "h", 20 | "i", 21 | "j", 22 | "k", 23 | "l", 24 | "m", 25 | "n", 26 | "o", 27 | "p", 28 | "r", 29 | "s", 30 | "t", 31 | "u", 32 | "v", 33 | "w", 34 | "x", 35 | "y", 36 | "z", 37 | "æ", 38 | "ç", 39 | "ð", 40 | "ø", 41 | "ŋ", 42 | "œ", 43 | "ɐ", 44 | "ɑ", 45 | "ɔ", 46 | "ə", 47 | "ɛ", 48 | "ɝ", 49 | "ɹ", 50 | "ɡ", 51 | "ɪ", 52 | "ʁ", 53 | "ʃ", 54 | "ʊ", 55 | "ʌ", 56 | "ʏ", 57 | "ʒ", 58 | "ʔ", 59 | "ˈ", 60 | "ˌ", 61 | "ː", 62 | "̃", 63 | "̍", 64 | "̥", 65 | "̩", 66 | "̯", 67 | "͡", 68 | "θ", 69 | ] 70 | 71 | char_repeats: 72 | 3 # Number of grapheme character repeats to allow for mapping to longer phoneme sequences. 73 | # Set to 1 for autoreg_transformer. 74 | lowercase: true # Whether to lowercase the grapheme input. 75 | n_val: 5000 # Default number of validation data points if no explicit validation data is provided. 
76 | 77 | model: 78 | type: 79 | "autoreg_transformer" # Whether to use a forward transformer or autoregressive transformer model. 80 | # Choices: ['transformer', 'autoreg_transformer'] 81 | d_model: 384 82 | d_fft: 1536 83 | layers: 5 84 | dropout: 0.1 85 | heads: 8 86 | 87 | training: 88 | # Hyperparams for learning rate and scheduler. 89 | # The scheduler is reducing the lr on plateau of phoneme error rate (tested every n_generate_steps). 90 | 91 | learning_rate: 0.0001 # Learning rate of Adam. 92 | warmup_steps: 10000 # Linear increase of the lr from zero to the given lr within the given number of steps. 93 | scheduler_plateau_factor: 0.5 # Factor to multiply learning rate on plateau. 94 | scheduler_plateau_patience: 100 # Number of text generations with no improvement to tolerate. 95 | batch_size: 192 # Training batch size. 96 | batch_size_val: 192 # Validation batch size. 97 | epochs: 500 # Number of epochs to train. 98 | generate_steps: 99 | 1000 # Interval of training steps to generate sample outputs. Also, at this step the phoneme and word 100 | # error rates are calculated for the scheduler. 101 | validate_steps: 102 | 1000 # Interval of training steps to validate the model 103 | # (for the autoregressive model this is teacher-forced). 104 | checkpoint_steps: 2000 # Interval of training steps to save the model. 105 | n_generate_samples: 10 # Number of result samples to show on tensorboard. 106 | store_phoneme_dict_in_model: 107 | true # Whether to store the raw phoneme dict in the model. 108 | # It will be loaded by the phonemizer object. 109 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/configs/logging.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | extended: 5 | format: "%(asctime)s.%(msecs)01d %(levelname)s %(module)s: %(message).1066s" 6 | 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | formatter: extended 11 | 12 | root: 13 | handlers: [console] 14 | level: DEBUG 15 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/g2p/dp/model/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/model/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | 7 | 8 | class PositionalEncoding(torch.nn.Module): 9 | 10 | def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None: 11 | """ 12 | Initializes positional encoding. 13 | 14 | Args: 15 | d_model (int): Dimension of model. 16 | dropout (float): Dropout after positional encoding. 17 | max_len: Max length of precalculated position sequence. 
18 | """ 19 | 20 | super().__init__() 21 | self.dropout = torch.nn.Dropout(p=dropout) 22 | self.scale = torch.nn.Parameter(torch.ones(1)) 23 | 24 | pe = torch.zeros(max_len, d_model) 25 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 26 | div_term = torch.exp(torch.arange( 27 | 0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 28 | pe[:, 0::2] = torch.sin(position * div_term) 29 | pe[:, 1::2] = torch.cos(position * div_term) 30 | pe = pe.unsqueeze(0).transpose(0, 1) 31 | self.register_buffer('pe', pe) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N] 34 | x = x + self.scale * self.pe[:x.size(0), :] 35 | return self.dropout(x) 36 | 37 | 38 | def get_dedup_tokens(logits_batch: torch.Tensor) \ 39 | -> Tuple[torch.Tensor, torch.Tensor]: 40 | """Converts a batch of logits into the batch most probable tokens and their probabilities. 41 | 42 | Args: 43 | logits_batch (Tensor): Batch of logits (N x T x V). 44 | 45 | Returns: 46 | Tuple: Deduplicated tokens. The first element is a tensor (token indices) and the second element 47 | is a tensor (token probabilities) 48 | 49 | """ 50 | 51 | logits_batch = logits_batch.softmax(-1) 52 | out_tokens, out_probs = [], [] 53 | for i in range(logits_batch.size(0)): 54 | logits = logits_batch[i] 55 | max_logits, max_indices = torch.max(logits, dim=-1) 56 | max_logits = max_logits[max_indices!=0] 57 | max_indices = max_indices[max_indices!=0] 58 | cons_tokens, counts = torch.unique_consecutive( 59 | max_indices, return_counts=True) 60 | out_probs_i = torch.zeros(len(counts), device=logits.device) 61 | ind = 0 62 | for i, c in enumerate(counts): 63 | max_logit = max_logits[ind:ind + c].max() 64 | out_probs_i[i] = max_logit 65 | ind = ind + c 66 | out_tokens.append(cons_tokens) 67 | out_probs.append(out_probs_i) 68 | 69 | out_tokens = pad_sequence(out_tokens, batch_first=True, padding_value=0.).long() 70 | out_probs = pad_sequence(out_probs, batch_first=True, padding_value=0.) 71 | 72 | return out_tokens, out_probs 73 | 74 | 75 | def _generate_square_subsequent_mask(sz: int) -> torch.Tensor: 76 | mask = torch.triu(torch.ones(sz, sz), 1) 77 | mask = mask.masked_fill(mask == 1, float('-inf')) 78 | return mask 79 | 80 | 81 | def _make_len_mask(inp: torch.Tensor) -> torch.Tensor: 82 | return (inp == 0).transpose(0, 1) 83 | 84 | 85 | def _get_len_util_stop(sequence: torch.Tensor, end_index: int) -> int: 86 | for i, val in enumerate(sequence): 87 | if val == end_index: 88 | return i + 1 89 | return len(sequence) 90 | 91 | 92 | def _trim_util_stop(sequence: torch.Tensor, end_index: int) -> torch.Tensor: 93 | seq_len = _get_len_util_stop(sequence, end_index) 94 | return sequence[:seq_len] 95 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/g2p/dp/preprocessing/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Union, Any 3 | 4 | 5 | def _product(probs: Union[None, List[float]]) -> float: 6 | if probs is None or len(probs) == 0: 7 | return 0. 
8 | if 0 in probs: 9 | return 0 10 | prob = math.exp(sum([math.log(p) for p in probs])) 11 | return prob 12 | 13 | 14 | def _batchify(input: List[Any], batch_size: int) -> List[List[Any]]: 15 | l = len(input) 16 | output = [] 17 | for i in range(0, l, batch_size): 18 | batch = input[i:min(i + batch_size, l)] 19 | output.append(batch) 20 | return output 21 | 22 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/result.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | 4 | class Prediction: 5 | """ 6 | Container for single word prediction result. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | word: str, 12 | phonemes: str, 13 | phonemes_list: List[str], 14 | phoneme_tokens: List[str], 15 | confidence: float, 16 | token_probs: List[float], 17 | ) -> None: 18 | """ 19 | Initializes a Prediction object. 20 | 21 | Args: 22 | word (str): Original word to predict. 23 | phonemes (str): Predicted phonemes (without special tokens). 24 | phonemes_list (List[str]): Predicted phoneme tokens (withou special tokens). 25 | phoneme_tokens (List[str]): Predicted phoneme tokens (including special tokens). 26 | confidence (float): Total confidence of result. 27 | token_probs (List[float]): Probability of each phoneme token. 28 | """ 29 | 30 | self.word = word 31 | self.phonemes = phonemes 32 | self.phonemes_list = phonemes_list 33 | self.phoneme_tokens = phoneme_tokens 34 | self.confidence = confidence 35 | self.token_probs = token_probs 36 | 37 | 38 | class PhonemizerResult: 39 | """ 40 | Container for phonemizer output. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | text: List[str], 46 | split_text: List[List[str]], 47 | split_phonemes: List[List[str]], 48 | predictions: Dict[str, Prediction], 49 | ) -> None: 50 | """ 51 | Initializes a PhonemizerResult object. 52 | 53 | Args: 54 | text (List[str]): List of input texts. 55 | split_text (List[List[str]]): List of texts, where each text is split into words and special chars. 56 | split_phonemes (List[List[str]]): List of phonemes corresponding to split_text. 57 | predictions (Dict[str, Prediction]): Dictionary with entries word to Tuple (phoneme, probability). 58 | """ 59 | 60 | self.text = text 61 | self.split_text = split_text 62 | self.split_phonemes = split_phonemes 63 | self.predictions = predictions 64 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List 3 | from voice_smith.g2p.dp.model.model import load_checkpoint, ModelType, create_model 4 | from voice_smith.g2p.dp.preprocessing.text import Preprocessor 5 | from voice_smith.g2p.dp.training.trainer import Trainer 6 | from voice_smith.g2p.dp.utils.logging import get_logger 7 | from voice_smith.utils.model import get_param_num 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def train( 13 | config: str, 14 | name: str, 15 | lang_to_word_to_gold: Dict[str, Dict[str, List[List[str]]]], 16 | checkpoint_file: str = None, 17 | ) -> None: 18 | """ 19 | Runs training of a transformer model. 20 | 21 | Args: 22 | config_file (str): Path to the config.yaml that stores all necessary parameters. 23 | checkpoint_file (str, optional): Path to a model checkpoint to resume training for (e.g. 
latest_model.pt) 24 | 25 | Returns: 26 | None: The model checkpoints are stored in a folder provided by the config. 27 | 28 | """ 29 | 30 | if checkpoint_file is not None: 31 | logger.info(f"Restoring model from checkpoint: {checkpoint_file}") 32 | model, checkpoint = load_checkpoint(checkpoint_file) 33 | model.train() 34 | step = checkpoint["step"] 35 | logger.info(f"Loaded model with step: {step}") 36 | for key, val in config["training"].items(): 37 | val_orig = checkpoint["config"]["training"][key] 38 | if val_orig != val: 39 | logger.info(f"Overwriting training param: {key} {val_orig} --> {val}") 40 | checkpoint["config"]["training"][key] = val 41 | config = checkpoint["config"] 42 | model_type = config["model"]["type"] 43 | model_type = ModelType(model_type) 44 | else: 45 | logger.info("Initializing new model from config...") 46 | preprocessor = Preprocessor.from_config(config) 47 | model_type = config["model"]["type"] 48 | model_type = ModelType(model_type) 49 | model = create_model(model_type, config=config) 50 | checkpoint = { 51 | "config": config, 52 | } 53 | 54 | print(f"Total number of parameters: {get_param_num(model)}") 55 | 56 | if "preprocessor" in checkpoint.keys(): 57 | del checkpoint["preprocessor"] 58 | 59 | checkpoint_dir = Path(config["paths"]["checkpoint_dir"]) 60 | logger.info(f"Checkpoints will be stored at {checkpoint_dir.absolute()}") 61 | loss_type = "cross_entropy" if model_type.is_autoregressive() else "ctc" 62 | trainer = Trainer( 63 | checkpoint_dir=checkpoint_dir, loss_type=loss_type, name=name, config=config 64 | ) 65 | trainer.train( 66 | model=model, 67 | checkpoint=checkpoint, 68 | store_phoneme_dict_in_model=config["training"]["store_phoneme_dict_in_model"], 69 | lang_to_word_to_gold=lang_to_word_to_gold, 70 | ) 71 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/g2p/dp/training/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/training/decorators.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | 4 | def ignore_exception(f): 5 | 6 | def apply_func(*args, **kwargs): 7 | try: 8 | result = f(*args, **kwargs) 9 | return result 10 | except Exception: 11 | print(f'Catched exception in {f}:') 12 | traceback.print_exc() 13 | return None 14 | return apply_func -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/training/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Any 2 | from voice_smith.g2p.dp.training.metrics import phoneme_error, word_error 3 | 4 | 5 | def evaluate_samples(lang_samples: Dict[str, List[Tuple[List[str], List[str], List[str]]]], lang_to_word_to_gold) -> Dict[str, Any]: 6 | """Calculates word and phoneme error rates per language and their mean across languages 7 | Args: 8 | lang_samples (Dict): Data to evaluate. Contains languages as keys and list of result samples as values. 9 | Prediction samples is given as a List of Tuples, where each Tuple is a tokenized representation of 10 | (text, result, target). 
11 | Returns: 12 | Dict: Evaluation result carrying word and phoneme error rates per language. 13 | """ 14 | 15 | evaluation_result = dict() 16 | lang_phon_err, lang_phon_count, lang_word_err = dict(), dict(), dict() 17 | languages = sorted(lang_samples.keys()) 18 | for lang in languages: 19 | for word, generated, target in lang_samples[lang]: 20 | word = ''.join(word) 21 | phon_err, phon_count = phoneme_error(generated, target) 22 | word_err = word_error(generated, target) 23 | phon_err_dict = lang_phon_err.setdefault(lang, dict()) 24 | phon_count_dict = lang_phon_count.setdefault(lang, dict()) 25 | word_err_dict = lang_word_err.setdefault(lang, dict()) 26 | best_phon_err, best_phon_count = phon_err_dict.get(word, None), phon_count_dict.get(word, None) 27 | if best_phon_err is None or phon_err / phon_count < best_phon_err / best_phon_count: 28 | phon_err_dict[word] = phon_err 29 | phon_count_dict[word] = phon_count 30 | word_err_dict[word] = word_err 31 | 32 | phon_errors, phon_counts, word_errors, word_counts = [], [], [], [] 33 | for lang in languages: 34 | phon_err = sum(lang_phon_err[lang].values()) 35 | phon_errors.append(phon_err) 36 | phon_count = sum(lang_phon_count[lang].values()) 37 | phon_counts.append(phon_count) 38 | word_err = sum(lang_word_err[lang].values()) 39 | word_errors.append(word_err) 40 | word_count = len(lang_word_err[lang]) 41 | word_counts.append(word_count) 42 | per = phon_err / phon_count 43 | wer = word_err / word_count 44 | evaluation_result.setdefault(lang, {}).update({'per': per}) 45 | evaluation_result.setdefault(lang, {}).update({'wer': wer}) 46 | mean_per = sum(phon_errors) / sum(phon_counts) 47 | mean_wer = sum(word_errors) / sum(word_counts) 48 | evaluation_result['mean_per'] = mean_per 49 | evaluation_result['mean_wer'] = mean_wer 50 | 51 | return evaluation_result -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/training/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | 6 | class CrossEntropyLoss(torch.nn.Module): 7 | """ """ 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 12 | 13 | def forward(self, 14 | pred: torch.Tensor, 15 | batch: Dict[str, torch.Tensor]) -> torch.Tensor: 16 | """Forward pass of the CrossEntropyLoss module on a batch. 17 | 18 | Args: 19 | pred: Batch of model predictions. 20 | batch: Dictionary of a training data batch, containing 'phonemes': target phonemes. 21 | pred: torch.Tensor: 22 | batch: Dict[str: 23 | torch.Tensor]: 24 | 25 | Returns: 26 | Loss as tensor. 27 | 28 | """ 29 | 30 | phonemes = batch['phonemes'] 31 | loss = self.criterion(pred.transpose(1, 2), phonemes[:, 1:]) 32 | return loss 33 | 34 | 35 | class CTCLoss(torch.nn.Module): 36 | """ """ 37 | 38 | def __init__(self): 39 | super().__init__() 40 | self.criterion = torch.nn.CTCLoss() 41 | 42 | def forward(self, 43 | pred: torch.Tensor, 44 | batch: Dict[str, torch.Tensor]) -> torch.Tensor: 45 | """Forward pass of the CTCLoss module on a batch. 46 | 47 | Args: 48 | pred: Batch of model predictions. 49 | batch: Dictionary of a training data batch, containing 'phonemes': target phonemes, 50 | 'text_len': input text lengths, 'phonemes_len': target phoneme lengths 51 | pred: torch.Tensor: 52 | batch: Dict[str: 53 | torch.Tensor]: 54 | 55 | Returns: 56 | Loss as tensor. 
57 | 58 | """ 59 | 60 | pred = pred.transpose(0, 1).log_softmax(2) 61 | phonemes = batch['phonemes'] 62 | text_len = batch['text_len'] 63 | phon_len = batch['phonemes_len'] 64 | loss = self.criterion(pred, phonemes, text_len, phon_len) 65 | return loss 66 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/backend/voice_smith/g2p/dp/utils/__init__.py -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/utils/infer.py: -------------------------------------------------------------------------------- 1 | from voice_smith.g2p.dp.phonemizer import Phonemizer 2 | from typing import List 3 | 4 | 5 | def batched_predict( 6 | model: Phonemizer, texts: List[str], langs: List[str], batch_size=32 7 | ): 8 | assert len(texts) == len(langs) 9 | return model.phonemise_list(texts, langs=langs, batch_size=batch_size) 10 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/utils/io.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | from typing import Dict, List, Any, Union 4 | 5 | import torch 6 | import yaml 7 | 8 | 9 | def read_config(path: str) -> Dict[str, Any]: 10 | """ 11 | Reads the config dictionary from the yaml file. 12 | 13 | Args: 14 | path (str): Path to the .yaml file. 15 | 16 | Returns: 17 | Dict[str, Any]: Configuration. 18 | 19 | """ 20 | 21 | with open(path, 'r', encoding='utf-8') as stream: 22 | config = yaml.load(stream, Loader=yaml.FullLoader) 23 | return config 24 | 25 | 26 | def save_config(config: Dict[str, Any], path: str) -> None: 27 | """ 28 | Saves the config as a yaml file. 29 | 30 | Args: 31 | config (Dict[str, Any]): Configuration. 32 | path (str): Path to save the dictionary to (.yaml). 33 | """ 34 | 35 | with open(path, 'w+', encoding='utf-8') as stream: 36 | yaml.dump(config, stream, default_flow_style=False) 37 | 38 | 39 | def get_files(path: str, extension: str = '.wav') -> List[Path]: 40 | """ 41 | Recursively retrieves all files with a given extension from a folder. 42 | 43 | Args: 44 | path (str): Path to the folder to retrieve files from. 45 | extension (str): Extension of files to be retrieved (Default value = '.wav'). 46 | 47 | Returns: 48 | List[Path]: List of paths to the found files. 49 | """ 50 | 51 | return list(Path(path).expanduser().resolve().rglob(f'*{extension}')) 52 | 53 | 54 | def pickle_binary(data: object, file: Union[str, Path]) -> None: 55 | """ 56 | Pickles a given object to a binary file. 57 | 58 | Args: 59 | data (object): Object to be pickled. 60 | file (Union[str, Path]): Path to destination file (use the .pkl extension). 61 | """ 62 | 63 | with open(str(file), 'wb') as f: 64 | pickle.dump(data, f) 65 | 66 | 67 | def unpickle_binary(file: Union[str, Path]) -> object: 68 | """ 69 | Unpickles a given binary file to an object 70 | 71 | Args: 72 | file (nion[str, Path]): Path to the file. 73 | 74 | Returns: 75 | object: Unpickled object. 76 | 77 | """ 78 | 79 | with open(str(file), 'rb') as f: 80 | return pickle.load(f) 81 | 82 | 83 | def to_device(batch: Dict[str, torch.Tensor], device: torch.device) -> Dict[str, torch.Tensor]: 84 | """ 85 | Sends a batch of data to the given torch devicee (cpu or cuda). 
86 | 87 | Args: 88 | batch (Dict[str, torch.Tensor]): Batch to be sent to the device. 89 | device (torch.device): Device (either torch.device('cpu') or torch.device('cuda')). 90 | 91 | Returns: 92 | Dict[str, torch.Tensor]: The batch on the given device. 93 | 94 | """ 95 | 96 | return {key: val.to(device) for key, val in batch.items()} -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging.config 2 | import logging.config 3 | import os 4 | from logging import Logger, getLogger 5 | from pathlib import Path 6 | 7 | import voice_smith.g2p.dp.configs 8 | from voice_smith.g2p.dp.utils.io import read_config 9 | 10 | main_dir = os.path.dirname(os.path.abspath(voice_smith.g2p.dp.configs.__file__)) 11 | config_file_path = Path(main_dir) / 'logging.yaml' 12 | config = read_config(config_file_path) 13 | logging.config.dictConfig(config) 14 | 15 | 16 | def get_logger(name: str) -> Logger: 17 | """ 18 | Creates a logger object for a given name. 19 | 20 | Args: 21 | name (str): Name of the logger. 22 | 23 | Returns: 24 | Logger: Logger object with given name. 25 | """ 26 | 27 | logger = getLogger(name) 28 | return logger -------------------------------------------------------------------------------- /backend/voice_smith/g2p/dp/utils/model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | from voice_smith.g2p.dp.phonemizer import Phonemizer 4 | 5 | def get_g2p(assets_path: str, device: torch.device) -> Phonemizer: 6 | checkpoint_path = Path(assets_path) / "g2p" / "en" / "g2p.pt" 7 | phonemizer = Phonemizer.from_checkpoint(str(checkpoint_path), device=device) 8 | return phonemizer 9 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/lang2transctibe.txt: -------------------------------------------------------------------------------- 1 | bg: https://huggingface.co/anuragshas/wav2vec2-large-xls-r-300m-bg 2 | dataset: Common Voice Bulgarian 8 3 | Test WER: 21.195 4 | 5 | cs: https://huggingface.co/comodoro/wav2vec2-xls-r-300m-cs-250 6 | dataset: Common Voice Czech 8 7 | Test WER: 0.1475 8 | 9 | de: 10 | dataset: -------------------------------------------------------------------------------- /backend/voice_smith/g2p/parse_dictionary.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | from typing import List, Tuple, Dict 4 | 5 | 6 | 7 | def parse_dictionary( 8 | dictionary_path: str, name: str 9 | ) -> Tuple[List[Tuple[str, str, List[str]]], List[str], List[str], Dict[str, List[List[str]]]]: 10 | word_to_gold: Dict[str, List[List[str]]] = {} 11 | out = [] 12 | all_phones = [] 13 | words_preprocessed = {} 14 | with open(str(dictionary_path), "r", encoding="utf-8") as f: 15 | for line in tqdm(f.readlines()): 16 | line = line.strip().split() 17 | word = line[0] 18 | phones = line[1 + 4 :] 19 | word = word.lower() 20 | words_preprocessed[word] = 0 21 | out.append((name, word, phones)) 22 | all_phones.extend(phones) 23 | if word in word_to_gold: 24 | word_to_gold[word].append(phones) 25 | else: 26 | word_to_gold[word] = [phones] 27 | unique_phones = list(set(all_phones)) 28 | text_symbols = list(set("".join(words_preprocessed))) 29 | return out, unique_phones, text_symbols, word_to_gold 30 |
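parse_dictionary appears to assume MFA-style pronunciation dictionaries in which each row holds the word, four numeric columns (pronunciation and silence probabilities), and then the phone sequence, which is why the phones are sliced from the fifth column onwards (line[1 + 4 :]). A minimal usage sketch under that assumption; the file name "toy.dict", the example row, and its ARPA-style phones are made up for illustration:

from voice_smith.g2p.parse_dictionary import parse_dictionary

# Hypothetical MFA-style dictionary row: word, four probability columns, then phones.
with open("toy.dict", "w", encoding="utf-8") as f:
    f.write("hello\t0.99\t0.01\t1.23\t0.5\tHH AH0 L OW1\n")

entries, unique_phones, text_symbols, word_to_gold = parse_dictionary("toy.dict", name="en")
print(entries)       # [('en', 'hello', ['HH', 'AH0', 'L', 'OW1'])]
print(word_to_gold)  # {'hello': [['HH', 'AH0', 'L', 'OW1']]}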
-------------------------------------------------------------------------------- /backend/voice_smith/g2p/phones.txt: -------------------------------------------------------------------------------- 1 | ['a˦˨', 'ɛ˩˩˦', 'tʃː', 't̪s̪', 'oː˥˧', 'vʲː', 'o', 'l', 'lː', 'øː˧˩', 'ʑ', 'rʲ', 'tɕ', 'ɯ˧', 'aː˥˩', 'm̩', 'iː˦˨', 'ua˧', 'hː', 'ɪː', 'ɔ˥˩', 'ɚ', 'd̪z̪', 'vː', 'ɰ', 'tsʲ', 'ɛ˨˩', 'ʂː', 'zʲ', 'tʃ', 'i˨˦', 'uː˧', 'y', 'ɛː˩˩˦', 'oː˩', 'ia˨˩', 'b', 'o˩˩˦', 'j̃', 'ɛː˧˩', 'ɯa˧', 'ɤː˧', 'ɵ˧˩', 'eː˧', 'ɔʏ', 'iː', 'bʲ', 'n̪', 'uː˩˩˦', 'ũ', 'iː˦˥', 'aj', 'ʒʲ', 'fʲː', 'ɐ', 'ua˦˥', 'sʲː', 'ɪ', 'd̪ː', 'eː˦˥', 'aː˨˦', 'dʐ', 'ð', 'v', 'e˩˩˦', 'bʲː', 'oː˨˩', 'i˧', 'e˦˥', 'ɔː˧', 'u˨˦', 'ɛː˧', 'pʲː', 'ɐ̃', 'θ', 'ua˥˩', 'u', 'ɔ˥˧', 'ɤ˧', 'ɫ̩', 'a˥˧', 'dʒ', 'ʂ', 'dʲː', 'd̪', 'ia˦˥', 'n̩', 'ʉː˥˩', 'dʑ', 'ɔ˩', 'ʉ', 'ɛ˧', 'e˨˦', 'z̪', 't', 'm', 'eː˧˩', 'ɯ', 'ʒ', 'rː', 'dzʲː', 'r̩˦˨', 'ɛː˥˩', 't̪ʰ', 'uː˨˩', 'ʉː', 'tʰ', 'o˨˩', 'ɔː˩˩˦', 'ʐ', 'xː', 'ʎː', 'ɟʝ', 'yː˧˩', 'd̪z̪ː', 'ɛ˥˩', 'z', 'ia˩˩˦', 'ə', 'i˦˨', 's̪', 'ɛ˩', 'ɑ̃', 't̪', 'ɕː', 'yː', 'u˥˩', 'ʃʲː', 'ŋ', 'z̪ː', 'eː', 'ɔː˨˩', 'ʏ˩', 'œ˥˩', 'f', 'ɵ', 'ɝ', 'ɛ˥˧', 'ia˥˩', 'ɯ˦˥', 'ɔ˧˩', 'ẽ', 'β', 'ɡ', 'ɔ˨˩', 'dʲ', 'ʊ˥˩', 'uː˨˦', 'eː˩', 'uː˧˩', 'yː˥˧', 'ɔ̃', 'oː˨˦', 'uː˥˩', 'ɤ', 'x', 'ʋʲ', 'ɤː˥˩', 'a', 'ɧ', 'ɲː', 'ɯa˩˩˦', 'eː˦˨', 'w̃', 'aː˩˩˦', 'ʊ', 'ʏ', 'tɕʰ', 'øː˩', 'ɯː˥˩', 'ʈ', 'ɟː', 'a˩', 'ɭ', 'ɾʲː', 'ɯː˩˩˦', 'n̪ː', 'ɤː˩˩˦', 'ɤ˦˥', 'uː˦˨', 'dzʲ', 'eː˩˩˦', 'ɤ˩˩˦', 'ʏ˥˩', 'ɔ', 'œ˧˩', 'ɡː', 'ɑː', 'a˦˥', 'ew', 'ĩ', 'fː', 'ia˧', 'tɕː', 'kʰ', 'bː', 'i˥˩', 'ɯa˨˩', 'ɔ˩˩˦', 'ɵ˥˧', 'eː˨˩', 'oː˦˨', 'l̩', 'ʏ˧˩', 'øː', 'ɪ˥˧', 'ɯː˦˥', 'o˦˥', 'ɔ˧', 'i˨˩', 'j', 'õ', 'ɯː˨˩', 'aː', 'tsʲː', 'ɵ˥˩', 'p', 'oː˧', 'pʰ', 'ʉː˥˧', 'tʂː', 'aː˦˥', 'h', 'ɨ', 'ɯ˨˩', 'ɖ', 't̚', 'a˥˩', 'ɤ˨˩', 'ɯa˥˩', 'ɛː˨˩', 'ʒʲː', 'ɾ', 'ɑː˩', 'ʋː', 'i˩˩˦', 'ɤ˥˩', 'ɤː˦˥', 'u˦˥', 'ɑ', 'ɲ', 'ʃː', 'cʰ', 'æ', 'ɯ˥˩', 'ʝ', 'ʏ˥˧', 'yː˥˩', 'ow', 'ʊ˩', 'eː˥˩', 'u˧', 'k', 'yː˩', 'ɛː', 'oː˩˩˦', 'ɫ', 'eː˨˦', 'ɯː˧', 'ɕ', 'ɒ', 'w', 'oː', 'a˨˦', 'aː˨˩', 'ʑː', 'ɑː˥˧', 'iː˨˦', 'ʋ', 'ʋʲː', 'r̩', 'a˨˩', 'ɔj', 'u˨˩', 't̪ː', 'a˩˩˦', 'p̚', 's', 'ʎ', 'ʐː', 'ɒː', 'e', 'ɣ', 'dʒː', 'ç', 'o˨˦', 'tʂ', 'oː˧˩', 'ɱ', 'œ', 'pʲ', 'ɛː˦˥', 'ɑː˥˩', 's̪ː', 'fʲ', 'a˧˩', 'uː˦˥', 'a˧', 'ɦ', 'ɪ˥˩', 'zʲː', 'ɫː', 'r̝', 'i˦˥', 'r̩ː˦˨', 'ɛ̃', 'uː˥˧', 'ɔː˦˥', 'ua˨˩', 'e˦˨', 'o˥˩', 'ɯː', 'øː˥˩', 'e˨˩', 'sʲ', 'øː˥˧', 'ɦː', 'ɛ', 't̪s̪ː', 'ɑː˧˩', 'pː', 'c', 'iː˥˩', 'jː', 'd', 'n', 'ʊ˧˩', 'vʲ', 'ʃʲ', 'ɪ˧˩', 'ɯa˦˥', 'pf', 'o˧', 'ts', 'e˥˩', 'ʈʰ', 'iː˥˧', 'u˩˩˦', 'tʲ', 'ua˩˩˦', 'r̩˨˦', 'ʁ', 'iː˩˩˦', 'mː', 'ɛ˦˥', 'ɥ', 'ɟ', 'ɵ˩', 'ɔː˥˩', 'ʃ', 'mʲ', 'tʃʲː', 'ɹ', 'ø', 'cː', 'ʔ', 'iː˧', 'aː˧', 'aː˦˨', 'ɳ', 'tʲː', 'r', 'ɛː˥˧', 'aw', 'oː˦˥', 'ɛ˧˩', 'rʲː', 'kː', 'tʃʲ', 'eː˥˧', 'mʲː', 'oː˥˩', 'i', 'ɯ˩˩˦', 'uː', 'iː˧˩', 'r̩ː˨˦', 'o˦˨', 'dʐː', 'ɪ˩', 'k̚', 'e˧', 'ɾː', 'ɔ˦˥', 'iː˨˩', 'ʉː˧˩', 'çː', 'ɾʲ', 'ej', 'ɤː˨˩', 'u˦˨'] -------------------------------------------------------------------------------- /backend/voice_smith/g2p/run_prediction.py: -------------------------------------------------------------------------------- 1 | from voice_smith.g2p.dp.phonemizer import Phonemizer 2 | from voice_smith.utils.model import get_param_num 3 | 4 | if __name__ == "__main__": 5 | 6 | checkpoint_path = "checkpoints/best_model_no_optim.pt" 7 | phonemizer = Phonemizer.from_checkpoint(checkpoint_path) 8 | print(get_param_num(phonemizer.predictor.model)) 9 | text = "young" 10 | 11 | result = phonemizer.phonemise_list([text], lang="en_us") 12 | 13 | print(result.phonemes) 14 | for text, pred in result.predictions.items(): 15 | tokens, probs = pred.phoneme_tokens, pred.token_probs 16 | 
for o, p in zip(tokens, probs): 17 | print(f"{o} {p}") 18 | tokens = "".join(tokens) 19 | print(f"{text} | {tokens} | {pred.confidence}") 20 | 21 | -------------------------------------------------------------------------------- /backend/voice_smith/g2p/run_training.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | import argparse 4 | from voice_smith.g2p.dp.preprocess import preprocess 5 | from voice_smith.g2p.dp.train import train 6 | from voice_smith.g2p.dp.utils.io import read_config 7 | from voice_smith.g2p.parse_dictionary import parse_dictionary 8 | import json 9 | 10 | 11 | perform_benchmark = False 12 | 13 | name = "G2P Byte MFA training 6x6 transformer (384, 1536), [15 MFA langs]" 14 | 15 | if perform_benchmark: 16 | SPLIT_SIZE = 12753 17 | else: 18 | SPLIT_SIZE = 10000 19 | 20 | if __name__ == "__main__": 21 | if perform_benchmark: 22 | print("Benchmarking on CMUDict ...") 23 | else: 24 | print("Training model for production ...") 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--checkpoint", type=str, default=None) 28 | args = parser.parse_args() 29 | 30 | data, phones, text_symbols = [], [], [] 31 | lang_to_word_to_gold = {} 32 | config = read_config(Path(".") / "dp" / "configs" / "autoreg_config.yaml") 33 | if args.checkpoint is None: 34 | for dictionary_path in [ 35 | Path(".") / "dictionaries" / "bg" / "bulgarian_mfa.dict", 36 | Path(".") / "dictionaries" / "cs" / "czech_mfa.dict", 37 | Path(".") / "dictionaries" / "de" / "german_mfa.dict", 38 | Path(".") / "dictionaries" / "en" / "english_us_mfa.dict", 39 | Path(".") / "dictionaries" / "es" / "spanish_mfa.dict", 40 | Path(".") / "dictionaries" / "fr" / "french_mfa.dict", 41 | Path(".") / "dictionaries" / "hr" / "croatian_mfa.dict", 42 | Path(".") / "dictionaries" / "pl" / "polish_mfa.dict", 43 | Path(".") / "dictionaries" / "pt" / "portuguese_portugal_mfa.dict", 44 | Path(".") / "dictionaries" / "ru" / "russian_mfa.dict", 45 | Path(".") / "dictionaries" / "sv" / "swedish_mfa.dict", 46 | Path(".") / "dictionaries" / "th" / "thai_mfa.dict", 47 | Path(".") / "dictionaries" / "tr" / "turkish_mfa.dict", 48 | Path(".") / "dictionaries" / "uk" / "ukrainian_mfa.dict", 49 | ]: 50 | lang = dictionary_path.parent.name 51 | d, p, t, word_to_gold = parse_dictionary(dictionary_path, lang) 52 | data.extend(d) 53 | phones.extend(p) 54 | text_symbols.extend(t) 55 | lang_to_word_to_gold[lang] = word_to_gold 56 | 57 | phones = list(set(phones)) 58 | text_symbols = list(str(el) for el in range(256)) + [""] 59 | 60 | config["preprocessing"]["phoneme_symbols"] = phones 61 | config["preprocessing"]["text_symbols"] = text_symbols 62 | if not perform_benchmark: 63 | random.shuffle(data) 64 | train_data, val_data = data[SPLIT_SIZE:], data[:SPLIT_SIZE] 65 | 66 | preprocess( 67 | config=config, 68 | train_data=train_data, 69 | val_data=val_data, 70 | deduplicate_train_data=False, 71 | ) 72 | 73 | with open( 74 | Path(".") / "datasets" / "lang_to_word_to_gold.json", "w", encoding="utf-8" 75 | ) as f: 76 | f.write(json.dumps(lang_to_word_to_gold)) 77 | 78 | else: 79 | with open( 80 | Path(".") / "datasets" / "lang_to_word_to_gold.json", "r", encoding="utf-8" 81 | ) as f: 82 | lang_to_word_to_gold = json.load(f) 83 | 84 | train( 85 | config=config, 86 | checkpoint_file=args.checkpoint, 87 | name=name, 88 | lang_to_word_to_gold=lang_to_word_to_gold, 89 | ) 90 | -------------------------------------------------------------------------------- 
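In run_training.py the characters gathered from the dictionaries are replaced with the 256 possible byte values rendered as string tokens plus an empty padding symbol (text_symbols = list(str(el) for el in range(256)) + [""]), which is the same inventory stored in text_symbols.txt below. Presumably each input word is therefore encoded as its UTF-8 byte values, so the grapheme vocabulary stays fixed across languages and scripts. A hypothetical illustration of that encoding; word_to_byte_tokens is not part of the repository:

def word_to_byte_tokens(word: str):
    # Encode the word as UTF-8 and render each byte as a string token ("0".."255").
    return [str(b) for b in word.encode("utf-8")]

print(word_to_byte_tokens("young"))  # ['121', '111', '117', '110', '103']
print(word_to_byte_tokens("héllo"))  # ['104', '195', '169', '108', '108', '111']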
/backend/voice_smith/g2p/text_symbols.txt: -------------------------------------------------------------------------------- 1 | ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', '249', '250', '251', '252', '253', '254', '255', ''] -------------------------------------------------------------------------------- /backend/voice_smith/model/position_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def positional_encoding( 6 | d_model: int, length: int, device: torch.device 7 | ) -> torch.Tensor: 8 | pe = torch.zeros(length, d_model, device=device) 9 | position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1) 10 | div_term = torch.exp( 11 | torch.arange(0, d_model, 2, device=device).float() 12 | * -(math.log(10000.0) / d_model) 13 | ) 14 | pe[:, 0::2] = torch.sin(position * div_term) 15 | pe[:, 1::2] = torch.cos(position * div_term) 16 | pe = pe.unsqueeze(0) 17 | return pe 18 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/copy_files.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import torch 3 | from typing import List, Callable, Literal, Union 4 | from pathlib import Path 5 | from voice_smith.utils.audio import safe_load, save_audio 6 | from voice_smith.utils.tools import iter_logger 7 | 8 | 9 | def copy_sample( 10 | sample_id: int, 11 | audio_src: str, 12 | text: str, 13 | out_dir: str, 14 | skip_on_error: bool, 15 | lowercase: bool, 16 | name_by: Union[Literal["id"], Literal["name"]], 17 | ) -> None: 18 | if not Path(audio_src).exists(): 19 | raise Exception(f"File {audio_src} does not exist ...") 20 | out_path = Path(out_dir) 21 | out_path.mkdir(exist_ok=True, parents=True) 22 | audio_out_path = ( 23 | out_path / f"{sample_id if name_by == 'id' else 
Path(audio_src).stem}.flac" 24 | ) 25 | txt_out_path = ( 26 | out_path / f"{sample_id if name_by == 'id' else Path(audio_src).stem}.txt" 27 | ) 28 | if txt_out_path.exists(): 29 | return 30 | try: 31 | audio, sr = safe_load(audio_src, sr=None) 32 | save_audio(str(audio_out_path), torch.FloatTensor(audio), sr) 33 | with open(txt_out_path, "w", encoding="utf-8") as f: 34 | f.write(text.lower() if lowercase else text) 35 | except Exception as e: 36 | if skip_on_error: 37 | if audio_out_path.exists(): 38 | audio_out_path.unlink() 39 | print(e) 40 | return 41 | else: 42 | raise e 43 | 44 | 45 | def copy_files( 46 | data_path: str, 47 | sample_ids: List[int], 48 | texts: List[str], 49 | audio_paths: List[str], 50 | names: List[str], 51 | langs: List[str], 52 | workers: int, 53 | skip_on_error: bool, 54 | name_by: Union[Literal["id"], Literal["name"]], 55 | progress_cb: Callable[[float], None], 56 | lowercase: bool = True, 57 | log_every: int = 200, 58 | ) -> None: 59 | assert len(sample_ids) == len(audio_paths) == len(names) == len(texts) == len(langs) 60 | 61 | def callback(index: int): 62 | if index % log_every == 0: 63 | progress = index / len(sample_ids) 64 | progress_cb(progress) 65 | 66 | print("Copying files ...") 67 | Parallel(n_jobs=workers)( 68 | delayed(copy_sample)( 69 | sample_id, 70 | audio_path, 71 | text, 72 | Path(data_path) / "raw_data" / lang / name, 73 | skip_on_error, 74 | lowercase, 75 | name_by, 76 | ) 77 | for sample_id, audio_path, text, name, lang in iter_logger( 78 | zip(sample_ids, audio_paths, texts, names, langs), 79 | cb=callback, 80 | total=len(sample_ids), 81 | ) 82 | ) 83 | progress_cb(1.0) 84 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/g2p.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | from voice_smith.g2p.dp.utils.model import get_g2p 4 | from voice_smith.g2p.dp.utils.infer import batched_predict 5 | 6 | 7 | def grapheme_to_phonemes( 8 | texts: List[str], langs: List[str], assets_path: str, device: torch.device 9 | ) -> List[List[str]]: 10 | assert len(texts) == len(langs) 11 | model = get_g2p(assets_path=assets_path, device=device) 12 | phonemes_list = batched_predict(model, texts, langs=langs) 13 | return phonemes_list 14 | 15 | 16 | if __name__ == "__main__": 17 | texts = ["This", "is", "a", "test", "hehehehee", "!"] 18 | assets_path = "/home/media/main_volume/datasets/voice-smith/assets" 19 | phones = grapheme_to_phonemes( 20 | texts=texts, langs=["en" for _ in texts], assets_path=assets_path, device=torch.device("cuda") 21 | ) 22 | print(phones) 23 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/generate_vocab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Dict 3 | from voice_smith.config.configs import PreprocessLangType 4 | from voice_smith.utils.shell import run_conda_in_shell 5 | from voice_smith.preprocessing.g2p import batched_predict, get_g2p 6 | from voice_smith.utils.tokenization import WordTokenizer 7 | from voice_smith.utils.mfa import lang_to_mfa_g2p 8 | 9 | 10 | def generate_vocab( 11 | texts: List[str], 12 | lang: str, 13 | assets_path: str, 14 | language_type: PreprocessLangType, 15 | device: torch.device, 16 | ) -> Dict[str, List[str]]: 17 | tokenizer = WordTokenizer(lang=lang, remove_punct=False) 18 | words_to_tokenize = set() 19 | for text in texts: 20 | for word in
tokenizer.tokenize(text): 21 | words_to_tokenize.add(word) 22 | words_to_tokenize = list(words_to_tokenize) 23 | g2p = get_g2p(assets_path=assets_path, device=device) 24 | predicted_phones = batched_predict( 25 | model=g2p, 26 | texts=words_to_tokenize, 27 | langs=[lang for _ in range(len(words_to_tokenize))], 28 | ) 29 | vocab = {word: phones for word, phones in zip(words_to_tokenize, predicted_phones)} 30 | return vocab 31 | 32 | 33 | def generate_vocab_mfa( 34 | lexicon_path: str, 35 | n_workers: int, 36 | lang: str, 37 | corpus_path: str, 38 | environment_name: str, 39 | language_type: PreprocessLangType, 40 | ): 41 | cmd = f"mfa g2p --clean -j {n_workers} {lang_to_mfa_g2p(lang, language_type)} {corpus_path} {lexicon_path}" 42 | run_conda_in_shell(cmd, environment_name, stderr_to_stdout=True) 43 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/get_txt_from_files.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import multiprocessing as mp 3 | import shutil 4 | from pathlib import Path 5 | from voice_smith.utils.audio import safe_load 6 | from typing import List, Callable, Optional, Union 7 | from voice_smith.utils.tools import iter_logger 8 | 9 | 10 | def get_txt_from_file(src: str) -> Union[str, None]: 11 | if not Path(src).exists(): 12 | print(f"Text file {src} doesn't exist, skipping ...") 13 | return None 14 | with open(src, "r", encoding="utf-8") as f: 15 | text = f.read() 16 | return text 17 | 18 | 19 | def get_txt_from_files( 20 | db_id: int, 21 | table_name: str, 22 | txt_paths: List[str], 23 | get_logger: Optional[Callable], 24 | log_every: int = 200, 25 | ) -> List[str]: 26 | def callback(index: int): 27 | if index % log_every == 0: 28 | logger = get_logger() 29 | progress = index / len(txt_paths) 30 | logger.query( 31 | f"UPDATE {table_name} SET get_txt_progress=? WHERE id=?", 32 | [progress, db_id], 33 | ) 34 | 35 | print("Fetching text from files ...") 36 | texts = Parallel(n_jobs=max(1, mp.cpu_count() - 1))( 37 | delayed(get_txt_from_file)(file_path) 38 | for file_path in iter_logger(txt_paths, cb=callback) 39 | ) 40 | logger = get_logger() 41 | logger.query( 42 | f"UPDATE {table_name} SET get_txt_progress=? 
WHERE id=?", 43 | [1.0, db_id], 44 | ) 45 | return texts 46 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/merge_lexika.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def lex_to_dict(path: str): 5 | with open(path, "r", encoding="utf-8") as f: 6 | lines = f.readlines() 7 | dic = {} 8 | for line in lines: 9 | if line.strip() == "": 10 | continue 11 | split = line.split() 12 | key = split[0].strip().lower() 13 | value = " ".join(split[1:]) 14 | # Some dicts contain lower case phones, which is incorrect 15 | value = value.upper() 16 | dic[key] = value.strip() 17 | return dic 18 | 19 | 20 | def merge_lexica(base_lexica_path: str, lang: str, assets_path: str, out_path: str): 21 | dic_final = {} 22 | for lex in list((Path(assets_path) / "lexica" / lang).iterdir()) + [ 23 | base_lexica_path 24 | ]: 25 | dic = lex_to_dict(str(lex)) 26 | for key in dic.keys(): 27 | if not key in dic_final: 28 | dic_final[key] = dic[key] 29 | 30 | with open(out_path, "w", encoding="utf-8") as f: 31 | for key in dic_final.keys(): 32 | f.write(key + " " + dic_final[key] + "\n") 33 | 34 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/transcribe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from typing import Dict, List, Callable 4 | from transformers import pipeline 5 | from voice_smith.utils.tools import iter_logger 6 | 7 | lang2model: Dict[str, str] = { 8 | "bg": "anuragshas/wav2vec2-large-xls-r-300m-bg", 9 | "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", 10 | "de": "jonatasgrosman/wav2vec2-large-xlsr-53-german", 11 | "en": "facebook/wav2vec2-base-960h", 12 | "es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish", 13 | "fr": "jonatasgrosman/wav2vec2-large-xlsr-53-french", 14 | "hr": "classla/wav2vec2-xls-r-parlaspeech-hr", 15 | "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", 16 | "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", 17 | "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", 18 | "sv": "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram", 19 | "th": "airesearch/wav2vec2-large-xlsr-53-th", 20 | "tr": "mpoyraz/wav2vec2-xls-r-300m-cv6-turkish", 21 | "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", 22 | } 23 | 24 | 25 | class STTDataset(Dataset): 26 | def __init__(self, files): 27 | self.files = files 28 | 29 | def __len__(self): 30 | return len(self.files) 31 | 32 | def __getitem__(self, idx): 33 | return self.files[idx] 34 | 35 | 36 | def transcribe( 37 | audio_files: List[str], 38 | lang: str, 39 | device: torch.device, 40 | progress_cb: Callable[[float], None], 41 | callback_every: int = 25, 42 | batch_size: int = 10, 43 | ) -> List[str]: 44 | 45 | dataset = STTDataset(audio_files) 46 | 47 | pipe = pipeline( 48 | model=lang2model[lang], 49 | device=0 if "cuda" in device.type else -1, 50 | chunk_length_s=10, 51 | stride_length_s=(4, 2), 52 | framework="pt", 53 | ) 54 | 55 | def cb(progress): 56 | progress_cb(progress / len(audio_files)) 57 | 58 | transcriptions = [] 59 | 60 | with torch.no_grad(): 61 | for transcription in iter_logger( 62 | pipe(dataset, batch_size=batch_size), 63 | total=len(audio_files), 64 | print_every=callback_every, 65 | callback_every=callback_every, 66 | cb=cb, 67 | ): 68 | transcriptions.append(transcription["text"]) 69 | 70 | return transcriptions 71 
| 72 | -------------------------------------------------------------------------------- /backend/voice_smith/preprocessing/vad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from typing import List 4 | from tqdm import tqdm 5 | from joblib import Parallel, delayed 6 | from voice_smith.utils.audio import safe_load, resample, save_audio 7 | 8 | 9 | def remove_silence( 10 | in_paths: List[str], 11 | out_paths: List[str], 12 | before_silence_sec: float = 0.0, 13 | after_silence_sec: float = 0.0, 14 | vad_sr=16000, 15 | ) -> None: 16 | assert len(in_paths) == len(out_paths) 17 | model, utils = torch.hub.load( 18 | repo_or_dir="snakers4/silero-vad", 19 | model="silero_vad", 20 | force_reload=False, 21 | onnx=False, 22 | ) 23 | 24 | (get_speech_timestamps, _, _, _, _) = utils 25 | for in_path, out_path in zip(in_paths, out_paths): 26 | wav, sr_orig = safe_load(in_path, sr=None) 27 | wav_16k = resample(wav, orig_sr=sr_orig, target_sr=vad_sr) 28 | tstamps = get_speech_timestamps( 29 | torch.from_numpy(wav_16k), model, sampling_rate=vad_sr, return_seconds=True 30 | ) 31 | if len(tstamps) > 0: 32 | start = max(tstamps[0]["start"] - before_silence_sec, 0) 33 | end = tstamps[-1]["end"] + after_silence_sec 34 | else: 35 | start = 0 36 | end = -1 37 | wav = wav[int(start * sr_orig) : int(end * sr_orig if end != -1 else end)] 38 | Path(out_path).parent.mkdir(exist_ok=True, parents=True) 39 | save_audio(out_path, torch.from_numpy(wav), sr_orig) 40 | 41 | 42 | def batched_remove_silence( 43 | in_paths: List[str], 44 | out_paths: List[str], 45 | workers: int, 46 | before_silence_sec: float = 0.0, 47 | after_silence_sec: float = 0.0, 48 | vad_sr=16000, 49 | chunk_size=500, 50 | ): 51 | Parallel(n_jobs=workers)( 52 | delayed(remove_silence)( 53 | in_paths[chunk_start_idx : chunk_start_idx + chunk_size], 54 | out_paths[chunk_start_idx : chunk_start_idx + chunk_size], 55 | before_silence_sec, 56 | after_silence_sec, 57 | vad_sr, 58 | ) 59 | for chunk_start_idx in tqdm(range(0, len(in_paths), chunk_size)) 60 | ) 61 | -------------------------------------------------------------------------------- /backend/voice_smith/pretrain_acoustic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from voice_smith.acoustic_training import train_acoustic 4 | from voice_smith.config.configs import ( 5 | PreprocessingConfig, 6 | AcousticPretrainingConfig, 7 | AcousticENModelConfig, 8 | ) 9 | from voice_smith.utils.wandb_logger import WandBLogger 10 | import argparse 11 | from voice_smith.config.globals import TRAINING_RUNS_PATH, ASSETS_PATH 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--run_id", type=int, required=True) 16 | parser.add_argument("--checkpoint", type=str, default=None) 17 | args = parser.parse_args() 18 | 19 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 20 | logger = WandBLogger("DelightfulTTS 120M parameters with UnsupDurAligner") 21 | p_config = PreprocessingConfig(language="english_only") 22 | m_config = AcousticENModelConfig() 23 | t_config = AcousticPretrainingConfig() 24 | wandb.config.update( 25 | { 26 | "preprocess_config": p_config, 27 | "model_config": m_config, 28 | "training_config": t_config, 29 | }, 30 | allow_val_change=True, 31 | ) 32 | train_acoustic( 33 | db_id=args.run_id, 34 | training_run_name=str(args.run_id), 35 | 
preprocess_config=p_config, 36 | model_config=m_config, 37 | train_config=t_config, 38 | logger=logger, 39 | device=device, 40 | reset=False, 41 | checkpoint_acoustic=args.checkpoint, 42 | fine_tuning=False, 43 | overwrite_saves=True, 44 | assets_path=ASSETS_PATH, 45 | training_runs_path=TRAINING_RUNS_PATH, 46 | ) 47 | 48 | -------------------------------------------------------------------------------- /backend/voice_smith/scripts/get_statistics.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import librosa 3 | from joblib import Parallel, delayed 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from typing import List 8 | from voice_smith.utils.tokenization import WordTokenizer 9 | 10 | def get_duration(audio_path): 11 | audio, sr = librosa.load(audio_path, sr=None) 12 | duration = audio.shape[0] / sr 13 | return duration 14 | 15 | def get_text(text_path): 16 | with open(text_path, "r", encoding="utf-8") as f: 17 | text = f.read().strip() 18 | return text 19 | 20 | if __name__ == "__main__": 21 | 22 | tokenizer = WordTokenizer(lang="en", remove_punct=False) 23 | 24 | words = set() 25 | 26 | with open("en.dict", "r", encoding="utf-8") as f: 27 | for line in f.readlines(): 28 | words.add(line.split("\t")[0]) 29 | 30 | for speaker_path in (Path(".") / "in").iterdir(): 31 | speaker_name = speaker_path.name 32 | durations = Parallel(n_jobs=12)( 33 | delayed(get_duration)(audio_path) 34 | for audio_path in tqdm(speaker_path.glob("*.flac")) 35 | ) 36 | texts = Parallel(n_jobs=12)( 37 | delayed(get_text)(text_path) 38 | for text_path in tqdm(speaker_path.glob("*.txt")) 39 | ) 40 | 41 | words_found = 0 42 | words_total = 0 43 | for text in texts: 44 | for token in tokenizer.tokenize(text): 45 | if token.lower() in words: 46 | words_found += 1 47 | words_total += 1 48 | 49 | 50 | np.random.seed(42) 51 | x = np.random.normal(size=1000) 52 | 53 | plt.hist(x, density=True, bins=20) # density=False would make counts 54 | plt.xlabel('duration') 55 | plt.savefig(f"{speaker_name}_durations.png") 56 | 57 | print("-" * 20) 58 | print(f"Speaker: {speaker_name}") 59 | print(f"Number of files: {len(durations)}") 60 | print(f"Mean duration: {sum(durations) / len(durations)}") 61 | print(f"Words in dict: {(words_found / words_total) * 100}%") 62 | print("-" * 20) -------------------------------------------------------------------------------- /backend/voice_smith/scripts/prep_acoustic_for_fine.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import torch 4 | from voice_smith.utils.model import get_acoustic_models 5 | from voice_smith.config.globals import ASSETS_PATH 6 | from voice_smith.config.configs import PreprocessingConfig, AcousticFinetuningConfig, AcousticModelConfig 7 | 8 | def prep_acoustic_for_fine(checkpoint: str, data_path: str): 9 | device = torch.device("cpu") 10 | gen, optim, step = get_acoustic_models( 11 | data_path=data_path, 12 | checkpoint_acoustic=checkpoint, 13 | train_config=AcousticFinetuningConfig(), 14 | preprocess_config=PreprocessingConfig(), 15 | model_config=AcousticModelConfig(), 16 | fine_tuning=True, 17 | device=device, 18 | reset=True, 19 | assets_path=ASSETS_PATH, 20 | ) 21 | torch.save({"gen": gen.state_dict(), "steps": 0}, Path(ASSETS_PATH) / "acoustic_pretrained.pt") 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | 
parser.add_argument('--checkpoint', type=str, required=True) 27 | parser.add_argument('--data_path', type=str, required=True) 28 | args = parser.parse_args() 29 | prep_acoustic_for_fine(checkpoint=args.checkpoint, data_path=args.data_path) 30 | -------------------------------------------------------------------------------- /backend/voice_smith/scripts/speaker_encoder_to_torch.py: -------------------------------------------------------------------------------- 1 | from speechbrain.pretrained import EncoderClassifier 2 | from torch.jit._trace import trace_module 3 | from pathlib import Path 4 | import torch 5 | from voice_smith.config.globals import ASSETS_PATH 6 | 7 | if __name__ == "__main__": 8 | classifier = EncoderClassifier.from_hparams(source=str(Path(ASSETS_PATH) / "ecapa_tdnn")) 9 | classifier.eval() 10 | classifier.device = torch.device("cpu") 11 | classifier.mods.embedding_model.cpu() 12 | input = torch.randn((6, 49440)) 13 | relative_lens = torch.ones((6,)) 14 | output_pre = classifier.encode_batch(input, relative_lens) 15 | classifier_torch = trace_module( 16 | classifier, {"encode_batch": (input, relative_lens)} 17 | ) 18 | classifier_torch.save(Path(ASSETS_PATH) / "ecapa_tdnn.pt") 19 | output_post = classifier_torch.encode_batch(input, relative_lens) 20 | assert torch.allclose(output_pre, output_post) 21 | -------------------------------------------------------------------------------- /backend/voice_smith/sql.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import pathlib 3 | from pathlib import Path 4 | import os 5 | 6 | 7 | def get_con(db_path: str) -> sqlite3.Connection: 8 | con = sqlite3.connect(db_path) 9 | return con 10 | 11 | 12 | def save_current_pid(con: sqlite3.Connection, cur: sqlite3.Cursor): 13 | cur.execute("UPDATE settings SET pid=?", (os.getpid(), )) 14 | con.commit() 15 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/audio_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from voice_smith.utils.audio import save_audio, safe_load, resample 3 | from voice_smith.config.file_extensions import SUPPORTED_AUDIO_EXTENSIONS 4 | 5 | _SAMPLING_RATES_TO_TEST = [ 6 | 800, 7 | 2000, 8 | 1999, 9 | 22050, 10 | 22001, 11 | 22057, 12 | 44100, 13 | 48000, 14 | 79999, 15 | 80000, 16 | ] 17 | 18 | 19 | def test_should_save_audio(tmp_path): 20 | for n_samples in _SAMPLING_RATES_TO_TEST: 21 | for sr in _SAMPLING_RATES_TO_TEST: 22 | audio = torch.randn((n_samples,)) 23 | for audio_extension in SUPPORTED_AUDIO_EXTENSIONS: 24 | audio_path = tmp_path / f"audio{audio_extension}" 25 | save_audio(file_path=audio_path, audio=audio, sr=sr) 26 | assert ( 27 | audio_path.exists() 28 | ), f"Failed to write audio for extension {audio_extension}" 29 | 30 | 31 | def test_should_load_audio(tmp_path): 32 | for n_samples in _SAMPLING_RATES_TO_TEST: 33 | for sr_in in _SAMPLING_RATES_TO_TEST: 34 | audio_shape = (n_samples,) 35 | audio_in = torch.randn(audio_shape) 36 | for audio_extension in SUPPORTED_AUDIO_EXTENSIONS: 37 | audio_path = tmp_path / f"audio{audio_extension}" 38 | save_audio(file_path=audio_path, audio=audio_in, sr=sr_in) 39 | audio_out, sr_out = safe_load(path=str(audio_path), sr=None) 40 | assert ( 41 | sr_in == sr_out 42 | ), f"Invalid sampling rate loaded for extension {audio_extension}. 
sr_in: {sr_in}, sr_out: {sr_out}" 43 | assert ( 44 | audio_in.shape == audio_out.shape 45 | ), f"Invalid shape loaded for extension {audio_extension}. shape_in: {audio_in.shape}, shape_out: {audio_out.shape}" 46 | 47 | 48 | def test_should_resample(): 49 | for n_samples in [1000, 22049, 22050, 22051, 44100]: 50 | for orig_sr in _SAMPLING_RATES_TO_TEST: 51 | for target_sr in _SAMPLING_RATES_TO_TEST: 52 | audio_shape = (n_samples,) 53 | audio_in = torch.randn(audio_shape) 54 | resample(wav=audio_in.numpy(), orig_sr=orig_sr, target_sr=target_sr) 55 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/currencies.py: -------------------------------------------------------------------------------- 1 | # FROM https://gist.githubusercontent.com/ksafranski/2973986/raw/5fda5e87189b066e11c1bf80bbfbecb556cf2cc1/Common-Currency.json 2 | iso_4217_to_symbols = { 3 | "AED": ["د.إ", "AED"], 4 | "AFN": ["؋", "Af"], 5 | "ALL": ["ALL", "Lek"], 6 | "AMD": ["դր.", "AMD"], 7 | "ANG": ["ƒ", "NAf", "NAƒ"], 8 | "AOA": ["Kz"], 9 | "ARS": ["$", "AR$"], 10 | "AUD": ["$", "AU$"], 11 | "AZN": ["man.", "ман."], 12 | "BAM": ["KM"], 13 | "BDT": ["৳", "Tk"], 14 | "BGN": ["лв.", "BGN"], 15 | "BHD": ["BD", "৳"], 16 | "BIF": ["FBu"], 17 | "BND": ["BN$", "$"], 18 | "BOB": ["Bs"], 19 | "BRL": ["R$"], 20 | "BWP": ["BWP", "P"], 21 | "BYN": ["Br", "руб."], 22 | "BZD": ["$", "BZ$"], 23 | "CAD": ["$", "CA$"], 24 | "CDF": ["FrCD", "CDF"], 25 | "CHF": ["CHF"], 26 | "CLP": ["CL$", "$"], 27 | "CNY": ["CN¥"], 28 | "COP": ["CO$", "$"], 29 | "CRC": ["₡"], 30 | "CVE": ["CV$"], 31 | "CZK": ["Kč"], 32 | "DEM": ["DM", "DEM"], 33 | "DJF": ["Fdj"], 34 | "DKK": ["Dkr", "kr"], 35 | "DOP": ["RD$"], 36 | "DZD": ["دج", "DA"], 37 | "EEK": ["Ekr", "kr"], 38 | "EGP": ["EGP", "ج.م", "E£"], 39 | "ERN": ["Nfk"], 40 | "ESP": ["Pta", "Pt", "Pts", "Ptas"], 41 | "ETB": ["Br"], 42 | "EUR": ["€"], 43 | "FRF": ["FRF", "₣"], 44 | "GBP": ["£"], 45 | "GEL": ["GEL"], 46 | "GHS": ["GH₵"], 47 | "GNF": ["FG"], 48 | "GTQ": ["Q", "GTQ"], 49 | "HKD": ["$", "HK$"], 50 | "HNL": ["HNL", "L"], 51 | "HRK": ["kn"], 52 | "HUF": ["Ft"], 53 | "IDR": ["Rp"], 54 | "ILS": ["₪"], 55 | "INR": ["টকা", "Rs"], 56 | "IQD": ["IQD", "ع.د"], 57 | "IRR": ["IRR", "﷼"], 58 | "ISK": ["Ikr", "kr"], 59 | "JMD": ["J$", "$"], 60 | "JOD": ["د.ا", "JD"], 61 | "JPY": ["¥", "¥"], 62 | "KES": ["Ksh"], 63 | "KHR": ["៛", "KHR"], 64 | "KMF": ["CF", "FC"], 65 | "KPW": ["₩"], 66 | "KRW": ["₩"], 67 | "KWD": ["KD", "د.ك"], 68 | "KZT": ["KZT", "тңг."], 69 | "LBP": ["L.L.", "ل.ل"], 70 | "LKR": ["SLRs", "SL Re"], 71 | "LTL": ["Lt"], 72 | "LVL": ["Ls"], 73 | "LYD": ["LD", "ل.د"], 74 | "MAD": ["MAD"], 75 | "MDL": ["MDL"], 76 | "MGA": ["MGA"], 77 | "MKD": ["MKD"], 78 | "MMK": ["K", "MMK"], 79 | "MOP": ["MOP$"], 80 | "MUR": ["MURs"], 81 | "MXN": ["MX$", "$"], 82 | "MYR": ["RM"], 83 | "MZN": ["MTn"], 84 | "NAD": ["N$"], 85 | "NGN": ["₦"], 86 | "NIO": ["C$"], 87 | "NOK": ["Nkr", "kr"], 88 | "NPR": ["NPRs", "नेरू"], 89 | "NZD": ["$", "NZ$"], 90 | "OMR": ["ر.ع.", "OMR"], 91 | "PAB": ["B/."], 92 | "PEN": ["S/."], 93 | "PHP": ["₱"], 94 | "PKR": ["₨", "PKRs"], 95 | "PLN": ["zł"], 96 | "PYG": ["₲"], 97 | "QAR": ["QR", "ر.ق"], 98 | "RON": ["RON"], 99 | "RSD": ["дин.", "din."], 100 | "RUB": ["RUB", "₽."], 101 | "RWF": ["FR", "RWF"], 102 | "SAR": ["SR"], 103 | "SDG": ["SDG"], 104 | "SEK": ["Skr", "kr"], 105 | "SGD": ["$", "S$"], 106 | "SOS": ["Ssh"], 107 | "SYP": ["LS", "£S", "SY£"], 108 | "THB": ["฿"], 109 | "TND": ["د.ت", "DT"], 110 | "TOP": ["T$"], 111 | "TRY": ["TL"], 112 
| "TTD": ["$", "TT$"], 113 | "TWD": ["NT$"], 114 | "TZS": ["TSh"], 115 | "UAH": ["₴"], 116 | "UGX": ["USh"], 117 | "USD": ["$"], 118 | "UYU": ["$", "$U"], 119 | "UZS": ["UZS"], 120 | "VEF": ["Bs.F."], 121 | "VES": ["Bs.D", "Bs."], 122 | "VND": ["₫"], 123 | "XAF": ["FCFA"], 124 | "XOF": ["CFA"], 125 | "YER": ["YR", "﷼"], 126 | "ZAR": ["R"], 127 | "ZMK": ["ZK"], 128 | "ZWL": ["ZWL$"], 129 | } 130 | 131 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/ds_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Any, Tuple 3 | import torch 4 | 5 | def unison_shuffled_copies(a: List[Any], b: List[Any]) -> Tuple[List[Any], List[Any]]: 6 | assert len(a) == len(b) 7 | p = np.random.permutation(len(a)) 8 | a_out = [a[idx] for idx in p] 9 | b_out = [b[idx] for idx in p] 10 | return a_out, b_out 11 | 12 | def stratified_train_test_split(x: List[Any], y: List[Any], train_size: float) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: 13 | label2samples = {} 14 | for x, y in zip(x, y): 15 | if y in label2samples: 16 | label2samples[y].append(x) 17 | else: 18 | label2samples[y] = [x] 19 | train_x_out, train_y_out = [], [] 20 | val_x_out, val_y_out = [], [] 21 | for label, samples in label2samples.items(): 22 | split_at = int(np.round(len(samples) * train_size)) 23 | x_split_train, x_split_val = samples[:split_at], samples[split_at:] 24 | y_split_train, y_split_val = [label] * len(x_split_train), [label] * len(x_split_val) 25 | train_x_out.extend(x_split_train) 26 | train_y_out.extend(y_split_train) 27 | val_x_out.extend(x_split_val) 28 | val_y_out.extend(y_split_val) 29 | train_x_out, train_y_out = unison_shuffled_copies(train_x_out, train_y_out) 30 | val_x_out, val_y_out = unison_shuffled_copies(val_x_out, val_y_out) 31 | return train_x_out, val_x_out, train_y_out, val_y_out 32 | 33 | class OnlineScaler(): 34 | """ Online mean and variance computation, see 35 | http://www.cs.yale.edu/publications/techreports/tr222.pdf 36 | equation 1.5a and 1.5b 37 | """ 38 | t_1_m = None 39 | s_1_m = None 40 | m = 0 41 | 42 | def partial_fit(self, x: torch.Tensor) -> None: 43 | assert(len(x.shape) > 1), "First dimension to partial_fit must be batch size" 44 | if self.m == 0: 45 | self.t_1_m = x[0] 46 | self.s_1_m = 0.0 47 | if x.shape[0] > 1: 48 | self.m += 1 49 | self.partial_fit(x[1:]) 50 | else: 51 | n = x.shape[0] 52 | x_sum = x.sum(0) 53 | self.s_1_m = self.s_1_m + x.var(0) + (self.m / (n * (self.m + n))) * ((n / self.m) * self.t_1_m - x_sum) ** 2 54 | self.t_1_m = self.t_1_m + x_sum 55 | self.m += n 56 | 57 | def get_mean_std(self) -> Tuple[torch.Tensor, torch.Tensor]: 58 | return self.t_1_m / self.m, torch.sqrt(self.s_1_m) 59 | 60 | if __name__ == "__main__": 61 | def count_labels(xs): 62 | label2count = {} 63 | for x in xs: 64 | if x in label2count: 65 | label2count[x] += 1 66 | else: 67 | label2count[x] = 1 68 | return label2count 69 | 70 | samples = (["0_sample"] * 800000) + (["1_sample"] * 180000) + (["2_sample"] * 20000) + (["3_sample"] * 10000) 71 | labels = (["0_label"] * 800000) + (["1_label"] * 180000) + (["2_label"] * 20000) + (["3_label "] * 10000) 72 | train_x_out, val_x_out, train_y_out, val_y_out = stratified_train_test_split(samples, labels, train_size=0.9) 73 | 74 | print(count_labels(train_x_out)) 75 | print(count_labels(train_y_out)) 76 | print(count_labels(val_x_out)) 77 | print(count_labels(val_y_out)) 78 | 79 | a = torch.randn((100000, 10)) 80 
| scaler = OnlineScaler() 81 | scaler.partial_fit(a) 82 | 83 | # compare the online estimate against the full-batch statistics 84 | 85 | print(a.mean(0), a.var(0)) 86 | print(scaler.get_mean_std()) -------------------------------------------------------------------------------- /backend/voice_smith/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | class InvalidLangException(Exception): 2 | """ Raised when a language was passed as an argument which is 3 | not supported. 4 | """ 5 | 6 | pass 7 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, Union 3 | from torch.jit._script import script, ScriptModule 4 | from torch.jit._trace import trace 5 | from voice_smith.utils.model import get_acoustic_models, get_vocoder 6 | from voice_smith.model.univnet import TracedGenerator 7 | from voice_smith.config.configs import ( 8 | AcousticPretrainingConfig, 9 | AcousticFinetuningConfig, 10 | PreprocessingConfig, 11 | AcousticModelConfigType, 12 | VocoderPretrainingConfig, 13 | VocoderFinetuningConfig, 14 | VocoderModelConfig, 15 | ) 16 | 17 | 18 | def acoustic_to_torchscript( 19 | checkpoint_acoustic: str, 20 | data_path: str, 21 | train_config: Union[AcousticPretrainingConfig, AcousticFinetuningConfig], 22 | preprocess_config: PreprocessingConfig, 23 | model_config: AcousticModelConfigType, 24 | assets_path: str, 25 | ) -> ScriptModule: 26 | device = torch.device("cpu") 27 | acoustic, _, _ = get_acoustic_models( 28 | checkpoint_acoustic=checkpoint_acoustic, 29 | data_path=data_path, 30 | train_config=train_config, 31 | preprocess_config=preprocess_config, 32 | model_config=model_config, 33 | fine_tuning=False, 34 | device=device, 35 | reset=False, 36 | assets_path=assets_path, 37 | ) 38 | acoustic.prepare_for_export() 39 | acoustic.eval() 40 | acoustic_torch = script(acoustic,) 41 | return acoustic_torch 42 | 43 | 44 | def vocoder_to_torchscript( 45 | ckpt_path: str, 46 | data_path: str, 47 | train_config: Union[VocoderPretrainingConfig, VocoderFinetuningConfig], 48 | preprocess_config: PreprocessingConfig, 49 | model_config: VocoderModelConfig, 50 | ) -> ScriptModule: 51 | device = torch.device("cpu") 52 | vocoder, _, _, _, _, _, _ = get_vocoder( 53 | checkpoint=ckpt_path, 54 | train_config=train_config, 55 | reset=False, 56 | device=device, 57 | preprocess_config=preprocess_config, 58 | model_config=model_config, 59 | ) 60 | vocoder.eval(True) 61 | mels = torch.randn((2, preprocess_config.stft.n_mel_channels, 50)) 62 | vocoder = TracedGenerator(vocoder, example_inputs=(mels,)) 63 | mel_lens = torch.tensor([mels.shape[2]], dtype=torch.int64) 64 | vocoder_torch = script(vocoder, example_inputs=[(mels, mel_lens)]) 65 | return vocoder_torch 66 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/loggers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | from typing import Union, Tuple 4 | import matplotlib 5 | 6 | 7 | class DualLogger(object): 8 | # https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting 9 | # Writes to both the terminal and a log file 10 | def __init__(self, location: str): 11 | self.terminal = sys.stdout 12 | # "w" is not working as mode 13 | print(location)
14 | self.log = open(location, "a", encoding="utf-8") 15 | 16 | def write(self, message): 17 | self.terminal.write(message) 18 | self.log.write(message) 19 | 20 | def flush(self): 21 | # this flush method is needed for python 3 compatibility. 22 | # this handles the flush command by doing nothing. 23 | # you might want to specify some extra behavior here. 24 | pass 25 | 26 | 27 | def set_stream_location(location: str, log_console: bool) -> None: 28 | if log_console: 29 | pass 30 | else: 31 | sys.stdout = open(location, "w", encoding="utf-8") 32 | sys.stderr = sys.stdout 33 | 34 | 35 | class Logger: 36 | def __init__(self): 37 | self.cm = matplotlib.cm.get_cmap("plasma") 38 | 39 | def map_image_color(self, image: np.ndarray) -> np.ndarray: 40 | normed_data = (image - np.min(image)) / (np.max(image) - np.min(image)) 41 | mapped_data = self.cm(normed_data) 42 | return mapped_data 43 | 44 | def log_image(self, name: str, image: np.ndarray, step: int) -> None: 45 | raise NotImplementedError 46 | 47 | def log_graph(self, name: str, value: float, step: int) -> None: 48 | raise NotImplementedError 49 | 50 | def log_audio(self, name: str, audio: np.ndarray, step: int, sr: int) -> None: 51 | raise NotImplementedError 52 | 53 | def query(self, query: str, args: Tuple[Union[int, str, float], ...]) -> None: 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from fastdtw import fastdtw 2 | import numpy as np 3 | from scipy.spatial.distance import euclidean 4 | from typing import Tuple 5 | import torch 6 | 7 | # from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility 8 | # from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality 9 | from voice_smith.utils.audio import resample 10 | 11 | 12 | def mcd(a: np.ndarray, b: np.ndarray) -> float: 13 | """Computes the mel cepstral distortion (MCD) between two mel spectrograms. 14 | 15 | :param a: np.ndarray of shape (timesteps, n_mels) 16 | :param b: np.ndarray of shape (timesteps, n_mels) 17 | :return: float, the mean mel cepstral distortion between a and b 18 | 19 | """ 20 | K = 10 / np.log(10) * np.sqrt(2) 21 | return K * np.mean(np.sqrt(np.sum((a - b) ** 2, axis=1))) 22 | 23 | 24 | def dtw_align(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 25 | """Uses dynamic time warping to align two 2-dimensional numpy arrays. 26 | Returns the aligned versions of both a and b.
27 | 28 | :param a: np.ndarray of shape (timesteps_1, n_mels) 29 | :param b: np.ndarray of shape (timesteps_2, n_mels) 30 | :return: (np.ndarray, np.ndarray), which are the aligned versions of a and b, 31 | both of shape (max(timesteps_1, timesteps_2), n_mels) 32 | """ 33 | _, warp_path = fastdtw(a, b, dist=euclidean) 34 | a_aligned = np.zeros((len(warp_path), max(a.shape[1], b.shape[1]))) 35 | b_aligned = np.zeros((len(warp_path), max(a.shape[1], b.shape[1]))) 36 | for i, (a_index, b_index) in enumerate(warp_path): 37 | a_aligned[i] = a[min(a_index, a.shape[0])] 38 | b_aligned[i] = b[min(b_index, b.shape[0])] 39 | return a_aligned, b_aligned 40 | 41 | 42 | def mcd_dtw(a: np.ndarray, b: np.ndarray) -> float: 43 | a_aligned, b_aligned = dtw_align(a, b) 44 | distortion = mcd(a_aligned, b_aligned) 45 | return distortion 46 | 47 | 48 | def calc_estoi(audio_real, audio_fake, sampling_rate): 49 | return torch.tensor([0.0], device=audio_fake.device, dtype=torch.float32) 50 | """return torch.mean( 51 | short_time_objective_intelligibility( 52 | audio_fake, audio_real, sampling_rate 53 | ).float() 54 | )""" 55 | 56 | 57 | def calc_pesq(audio_real_16k, audio_fake_16k): 58 | return torch.tensor([0.0], device=audio_fake_16k.device, dtype=torch.float32) 59 | """return torch.mean( 60 | perceptual_evaluation_speech_quality( 61 | audio_fake_16k, audio_real_16k, 16000, "wb" 62 | ) 63 | )""" 64 | 65 | 66 | def calc_rmse(audio_real, audio_fake, stft): 67 | spec_real = stft.linear_spectrogram(audio_real.squeeze(1)) 68 | spec_fake = stft.linear_spectrogram(audio_fake.squeeze(1)) 69 | mse = torch.nn.functional.mse_loss(spec_fake, spec_real) 70 | rmse = torch.sqrt(mse) 71 | return rmse 72 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/mfa.py: -------------------------------------------------------------------------------- 1 | from voice_smith.config.configs import PreprocessLangType 2 | 3 | 4 | def lang_to_mfa_acoustic(lang: str, language_type: PreprocessLangType): 5 | if lang == "bg": 6 | return "bulgarian_mfa" 7 | elif lang == "cs": 8 | return "czech_mfa" 9 | elif lang == "de": 10 | return "german_mfa" 11 | elif lang == "en": 12 | if language_type == "english_only": 13 | return "english_us_arpa" 14 | else: 15 | return "english_mfa" 16 | elif lang == "es": 17 | return "spanish_mfa" 18 | elif lang == "fr": 19 | return "french_mfa" 20 | elif lang == "hr": 21 | return "croatian_mfa" 22 | elif lang == "pl": 23 | return "polish_mfa" 24 | elif lang == "pt": 25 | return "portuguese_mfa" 26 | elif lang == "ru": 27 | return "russian_mfa" 28 | elif lang == "sv": 29 | return "swedish_mfa" 30 | elif lang == "th": 31 | return "thai_mfa" 32 | elif lang == "tr": 33 | return "turkish_mfa" 34 | elif lang == "uk": 35 | return "ukrainian_mfa" 36 | raise Exception( 37 | f"No case selected in switch-statement - language '{lang}' is not supported ..." 
38 | ) 39 | 40 | 41 | def lang_to_mfa_g2p(lang, language_type: PreprocessLangType): 42 | if lang == "bg": 43 | return "bulgarian_mfa" 44 | elif lang == "cs": 45 | return "czech_mfa" 46 | elif lang == "de": 47 | return "german_mfa" 48 | elif lang == "en": 49 | if language_type == "english_only": 50 | return "english_us_arpa" 51 | else: 52 | return "english_mfa" 53 | elif lang == "es": 54 | return "spanish_mfa" 55 | elif lang == "fr": 56 | return "french_mfa" 57 | elif lang == "hr": 58 | return "croatian_mfa" 59 | elif lang == "pl": 60 | return "polish_mfa" 61 | elif lang == "pt": 62 | return "portuguese_mfa" 63 | elif lang == "ru": 64 | return "russian_mfa" 65 | elif lang == "sv": 66 | return "swedish_mfa" 67 | elif lang == "th": 68 | return "thai_mfa" 69 | elif lang == "tr": 70 | return "turkish_mfa" 71 | elif lang == "uk": 72 | return "ukrainian_mfa" 73 | raise Exception( 74 | f"No case selected in switch-statement - language '{lang}' is not supported ..." 75 | ) 76 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/number_normalization_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from voice_smith.utils.number_normalization import ( 3 | get_number_normalizer, 4 | NumberNormLangType, 5 | ) 6 | 7 | _LANGUAGES_TO_CHECK: List[NumberNormLangType] = [ 8 | "cz", 9 | "de", 10 | "en", 11 | "es", 12 | "fr", 13 | "pl", 14 | "pt", 15 | "ru", 16 | "sv", 17 | "th", 18 | "tr", 19 | "uk", 20 | ] 21 | 22 | _NUMBER_INPUTS = [ 23 | "100.00", 24 | "22", 25 | "10000.0", 26 | "9.00100", 27 | "00.00", 28 | "987654210.123456789", 29 | "00.921", 30 | "100", 31 | ] 32 | 33 | _NO_NUMBER_INPUTS = [ 34 | "Parent00.1Safa", 35 | "0000Mother", 36 | "Mike Tyson", 37 | "?-!;%21+", 38 | ] 39 | 40 | 41 | def test_should_get_number_normalizers(): 42 | for lang in _LANGUAGES_TO_CHECK: 43 | get_number_normalizer(lang) 44 | 45 | 46 | def test_should_normalize_number(): 47 | for lang in _LANGUAGES_TO_CHECK: 48 | for number in _NUMBER_INPUTS: 49 | normalizer = get_number_normalizer(lang) 50 | output = normalizer.normalize(None, number, None) 51 | assert ( 52 | output.has_normalized 53 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 54 | assert ( 55 | output.word != number 56 | ), f"Failed for language {lang}, input: {number}, output: {output}" 57 | assert ( 58 | not output.collapsed_prev 59 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 60 | assert ( 61 | not output.collapsed_next 62 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 63 | 64 | 65 | def test_should_not_normalize_number(): 66 | for lang in _LANGUAGES_TO_CHECK: 67 | for number in _NO_NUMBER_INPUTS: 68 | normalizer = get_number_normalizer(lang) 69 | output = normalizer.normalize(None, number, None) 70 | assert ( 71 | not output.has_normalized 72 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 73 | assert ( 74 | not output.collapsed_prev 75 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 76 | assert ( 77 | not output.collapsed_next 78 | ), f"Failed check for language {lang}, input: {number}, output: {output}" 79 | 80 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Iterable, Dict, Any 
4 | from voice_smith.config.configs import ( 5 | AcousticFinetuningConfig, 6 | AcousticPretrainingConfig, 7 | AcousticModelConfigType, 8 | ) 9 | 10 | 11 | class ScheduledOptimPretraining: 12 | def __init__( 13 | self, 14 | parameters: Iterable, 15 | train_config: AcousticPretrainingConfig, 16 | model_config: AcousticModelConfigType, 17 | current_step: int, 18 | ): 19 | self._optimizer = torch.optim.Adam( 20 | parameters, 21 | betas=train_config.optimizer_config.betas, 22 | eps=train_config.optimizer_config.eps, 23 | ) 24 | self.n_warmup_steps = train_config.optimizer_config.warm_up_step 25 | self.anneal_steps = train_config.optimizer_config.anneal_steps 26 | self.anneal_rate = train_config.optimizer_config.anneal_rate 27 | self.current_step = current_step 28 | self.init_lr = model_config.encoder.n_hidden ** -0.5 29 | 30 | def step_and_update_lr(self, step: int) -> None: 31 | self._update_learning_rate(step) 32 | self._optimizer.step() 33 | 34 | def zero_grad(self) -> None: 35 | self._optimizer.zero_grad() 36 | 37 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 38 | self._optimizer.load_state_dict(state_dict) 39 | 40 | def _get_lr_scale(self) -> float: 41 | lr_scale = np.min( 42 | [ 43 | np.power(1 if self.current_step == 0 else self.current_step, -0.5), 44 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 45 | ] 46 | ) 47 | for s in self.anneal_steps: 48 | if self.current_step > s: 49 | lr_scale = lr_scale * self.anneal_rate 50 | return lr_scale 51 | 52 | def _update_learning_rate(self, step: int) -> None: 53 | """Learning rate scheduling per step""" 54 | self.current_step = step 55 | lr = self.init_lr * self._get_lr_scale() 56 | for param_group in self._optimizer.param_groups: 57 | param_group["lr"] = lr 58 | 59 | 60 | class ScheduledOptimFinetuning: 61 | def __init__( 62 | self, 63 | parameters: Iterable, 64 | train_config: AcousticFinetuningConfig, 65 | current_step: int, 66 | ): 67 | self._optimizer = torch.optim.AdamW( 68 | parameters, 69 | betas=train_config.optimizer_config.betas, 70 | eps=train_config.optimizer_config.eps, 71 | ) 72 | self.current_step = current_step 73 | self.init_lr = train_config.optimizer_config.learning_rate 74 | self.lr_decay = train_config.optimizer_config.lr_decay 75 | 76 | def step_and_update_lr(self, step: int) -> None: 77 | self._update_learning_rate(step) 78 | self._optimizer.step() 79 | 80 | def zero_grad(self) -> None: 81 | self._optimizer.zero_grad() 82 | 83 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 84 | self._optimizer.load_state_dict(state_dict) 85 | 86 | def _get_lr_scale(self) -> float: 87 | lr_scale = self.lr_decay ** self.current_step 88 | return lr_scale 89 | 90 | def _update_learning_rate(self, step: int) -> None: 91 | """Learning rate scheduling per step""" 92 | self.current_step = step 93 | lr = self.init_lr * self._get_lr_scale() 94 | for param_group in self._optimizer.param_groups: 95 | param_group["lr"] = lr 96 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/punctuation.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | _PUNCTUATION_CHARS_SET = set( 4 | [ 5 | "!", 6 | ".", 7 | "?", 8 | "։", 9 | "؟", 10 | "۔", 11 | "܀", 12 | "܁", 13 | "܂", 14 | "߹", 15 | "।", 16 | "॥", 17 | "၊", 18 | "။", 19 | "።", 20 | "፧", 21 | "፨", 22 | "᙮", 23 | "᜵", 24 | "᜶", 25 | "᠃", 26 | "᠉", 27 | "᥄", 28 | "᥅", 29 | "᪨", 30 | "᪩", 31 | "᪪", 32 | "᪫", 33 | "᭚", 34 | "᭛", 35 | "᭞", 36 | 
"᭟", 37 | "᰻", 38 | "᰼", 39 | "᱾", 40 | "᱿", 41 | "‼", 42 | "‽", 43 | "⁇", 44 | "⁈", 45 | "⁉", 46 | "⸮", 47 | "⸼", 48 | "꓿", 49 | "꘎", 50 | "꘏", 51 | "꛳", 52 | "꛷", 53 | "꡶", 54 | "꡷", 55 | "꣎", 56 | "꣏", 57 | "꤯", 58 | "꧈", 59 | "꧉", 60 | "꩝", 61 | "꩞", 62 | "꩟", 63 | "꫰", 64 | "꫱", 65 | "꯫", 66 | "﹒", 67 | "﹖", 68 | "﹗", 69 | "!", 70 | ".", 71 | "?", 72 | "𐩖", 73 | "𐩗", 74 | "𑁇", 75 | "𑁈", 76 | "𑂾", 77 | "𑂿", 78 | "𑃀", 79 | "𑃁", 80 | "𑅁", 81 | "𑅂", 82 | "𑅃", 83 | "𑇅", 84 | "𑇆", 85 | "𑇍", 86 | "𑇞", 87 | "𑇟", 88 | "𑈸", 89 | "𑈹", 90 | "𑈻", 91 | "𑈼", 92 | "𑊩", 93 | "𑑋", 94 | "𑑌", 95 | "𑗂", 96 | "𑗃", 97 | "𑗉", 98 | "𑗊", 99 | "𑗋", 100 | "𑗌", 101 | "𑗍", 102 | "𑗎", 103 | "𑗏", 104 | "𑗐", 105 | "𑗑", 106 | "𑗒", 107 | "𑗓", 108 | "𑗔", 109 | "𑗕", 110 | "𑗖", 111 | "𑗗", 112 | "𑙁", 113 | "𑙂", 114 | "𑜼", 115 | "𑜽", 116 | "𑜾", 117 | "𑩂", 118 | "𑩃", 119 | "𑪛", 120 | "𑪜", 121 | "𑱁", 122 | "𑱂", 123 | "𖩮", 124 | "𖩯", 125 | "𖫵", 126 | "𖬷", 127 | "𖬸", 128 | "𖭄", 129 | "𛲟", 130 | "𝪈", 131 | "。", 132 | "。", 133 | "…", 134 | "..", 135 | "...", 136 | "....", 137 | ".....", 138 | "??", 139 | "???", 140 | "????", 141 | "?????", 142 | "!!", 143 | "!!!", 144 | "!!!!", 145 | "!!!!!", 146 | ] 147 | ) 148 | 149 | 150 | def get_punct(lang: str) -> Set[str]: 151 | return _PUNCTUATION_CHARS_SET 152 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/runs.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from typing import Optional, Callable, List, Tuple, Any, Dict 3 | 4 | 5 | class StageRunner: 6 | def __init__( 7 | self, 8 | cur: sqlite3.Cursor, 9 | con: sqlite3.Connection, 10 | get_stage_name: Callable[[Any], str], 11 | stages: List[Tuple[str, Callable[[Any], bool]]], 12 | before_stage: Optional[Callable[[Any], None]] = None, 13 | after_stage: Optional[Callable[[Any], None]] = None, 14 | before_run: Optional[Callable[[Any], None]] = None, 15 | after_run: Optional[Callable[[Any], None]] = None, 16 | ): 17 | self.cur = cur 18 | self.con = con 19 | self.get_stage_name = get_stage_name 20 | self.stages = stages 21 | self.before_stage = before_stage 22 | self.after_stage = after_stage 23 | self.before_run = before_run 24 | self.after_run = after_run 25 | 26 | def run(self, **kwargs): 27 | finished = False 28 | if self.before_run is not None: 29 | self.before_run(**kwargs) 30 | while not finished: 31 | stage_name = self.get_stage_name(**kwargs) 32 | for n, stage in self.stages: 33 | if n == stage_name: 34 | if self.before_stage is not None: 35 | self.before_stage(**{**kwargs, "stage_name": stage_name}) 36 | finished = stage(**kwargs) 37 | if self.after_stage is not None: 38 | self.after_stage(**{**kwargs, "stage_name": stage_name}) 39 | break 40 | if self.after_run is not None: 41 | self.after_run(**kwargs) 42 | 43 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/shell.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import time 4 | import signal 5 | import atexit 6 | import sys 7 | 8 | 9 | def run_conda_in_shell( 10 | cmd: str, environment_name: str, stderr_to_stdout: bool, sleep_time: float = 0.25 11 | ) -> bool: 12 | with subprocess.Popen( 13 | f"conda run -n {environment_name} --no-capture-output {cmd}", 14 | universal_newlines=True, 15 | env={**os.environ, "PYTHONNOUSERSITE": "True"}, 16 | shell=True, 17 | stdout=subprocess.PIPE, 18 | stderr=subprocess.STDOUT if 
stderr_to_stdout else None, 19 | preexec_fn=os.setsid, 20 | ) as process: 21 | handler_orig = signal.getsignal(signal.SIGTERM) 22 | 23 | def sigterm_handler(sig, frame): 24 | os.killpg(os.getpgid(process.pid), signal.SIGKILL) 25 | sys.exit() 26 | 27 | signal.signal(signal.SIGTERM, sigterm_handler) 28 | 29 | while True: 30 | output = process.stdout.readline() 31 | if output != "": 32 | print(output, flush=True) 33 | return_code = process.poll() 34 | if return_code is not None: 35 | signal.signal(signal.SIGTERM, handler_orig) 36 | return return_code == 0 37 | time.sleep(sleep_time) 38 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/sql_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | from pathlib import Path 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from voice_smith.utils.audio import save_audio 7 | from voice_smith.utils.loggers import Logger 8 | 9 | 10 | class SQLLogger(Logger): 11 | def __init__(self, training_run_id: int, con, cursor, out_dir: str, stage: str): 12 | super().__init__() 13 | self.training_run_id = training_run_id 14 | self.con = con 15 | self.cur = cursor 16 | self.out_dir = Path(out_dir) 17 | self.stage = stage 18 | 19 | def log_image(self, name: str, image: np.ndarray, step: int) -> None: 20 | image = self.map_image_color(image) 21 | out_dir = self.out_dir / "image_logs" / name 22 | out_dir.mkdir(exist_ok=True, parents=True) 23 | pil_img = Image.fromarray((image * 255).astype(np.uint8)) 24 | pil_img.save(str(out_dir / f"{step}.png")) 25 | self.cur.execute( 26 | "INSERT INTO image_statistic(name, step, stage, training_run_id) VALUES(?, ?, ?, ?)", 27 | [name, step, self.stage, self.training_run_id], 28 | ) 29 | self.con.commit() 30 | 31 | def log_graph(self, name: str, value: float, step: int): 32 | self.cur.execute( 33 | "INSERT INTO graph_statistic(name, step, stage, value, training_run_id) VALUES(?, ?, ?, ?, ?)", 34 | [name, step, self.stage, value, self.training_run_id], 35 | ) 36 | self.con.commit() 37 | 38 | def log_audio(self, name: str, audio: np.ndarray, step: int, sr: int): 39 | out_dir = self.out_dir / "audio_logs" / name 40 | out_dir.mkdir(exist_ok=True, parents=True) 41 | save_audio(str(out_dir / f"{step}.flac"), torch.FloatTensor(audio), sr) 42 | self.cur.execute( 43 | "INSERT INTO audio_statistic(name, step, stage, training_run_id) VALUES(?, ?, ?, ?)", 44 | [name, step, self.stage, self.training_run_id], 45 | ) 46 | self.con.commit() 47 | 48 | def query(self, query: str, args: Tuple[Union[int, str, float], ...]) -> None: 49 | self.cur.execute(query, args) 50 | self.con.commit() 51 | 52 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/text.py: -------------------------------------------------------------------------------- 1 | def strip_cont_whitespaces(string: str) -> str: 2 | new_string = "" 3 | last_whitespace = False 4 | for char in string: 5 | if char == " " and last_whitespace: 6 | continue 7 | new_string += char 8 | last_whitespace = char == " " 9 | return new_string 10 | -------------------------------------------------------------------------------- /backend/voice_smith/utils/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | from typing import Tuple, Union, Dict, Any 4 | from voice_smith.utils.loggers import Logger 5 | 6 | 7 | 
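# Weights & Biases implementation of the Logger interface used during training.
# log_image/log_graph/log_audio forward to wandb.log(); query() is a no-op
# because this backend does not persist anything to the local SQL database.
# Rough usage sketch (run name and config are chosen freely by the caller):
#   logger = WandBLogger("my_training_run", config={"batch_size": 5})
#   logger.log_graph("train/loss", loss.item(), step)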
class WandBLogger(Logger): 8 | def __init__(self, training_run_name, config: Union[Dict[str, Any], None] = None): 9 | super().__init__() 10 | wandb.init(id=training_run_name, resume=None, config=config) 11 | 12 | def log_image(self, name: str, image: np.ndarray, step: int): 13 | image = self.map_image_color(image) 14 | wandb.log({name: wandb.Image(image)}, step=step) 15 | 16 | def log_graph(self, name: str, value: float, step: int): 17 | wandb.log({name: value}, step=step) 18 | 19 | def log_audio(self, name: str, audio: np.ndarray, step: int, sr: int): 20 | wandb.log({name: wandb.Audio(audio, sample_rate=sr)}, step=step) 21 | 22 | def query(self, query: str, args: Tuple[Union[int, str, float], ...]) -> None: 23 | pass 24 | -------------------------------------------------------------------------------- /src/App.test.js: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import {render, screen} from "@testing-library/react"; 3 | import App from "./App"; 4 | 5 | test("renders learn react link", () => { 6 | render(); 7 | const linkElement = screen.getByText(/learn react/i); 8 | expect(linkElement).toBeInTheDocument(); 9 | }); 10 | -------------------------------------------------------------------------------- /src/app/store.ts: -------------------------------------------------------------------------------- 1 | import { configureStore } from "@reduxjs/toolkit"; 2 | import appInfo from "../features/appInfoSlice"; 3 | import runManager from "../features/runManagerSlice"; 4 | import useStats from "../features/usageStatsSlice"; 5 | import importSettings from "../features/importSettings"; 6 | import navigationSettings from "../features/navigationSettingsSlice"; 7 | 8 | export const store = configureStore({ 9 | reducer: { 10 | appInfo, 11 | runManager, 12 | useStats, 13 | importSettings, 14 | navigationSettings, 15 | }, 16 | }); 17 | 18 | export type RootState = ReturnType; 19 | export type AppDispatch = typeof store.dispatch; 20 | -------------------------------------------------------------------------------- /src/components/audio_player/AudioPlayer.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement, useEffect, useRef } from "react"; 2 | // @ts-ignore 3 | import WaveSurfer from "wavesurfer.js"; 4 | import { createUseStyles } from "react-jss"; 5 | import { WaveSurverInterface } from "../../interfaces"; 6 | import { GET_AUDIO_DATA_URL_CHANNEL } from "../../channels"; 7 | const { ipcRenderer } = window.require("electron"); 8 | 9 | const useStyles = createUseStyles({ 10 | waveWrapper: { 11 | display: "block", 12 | width: "100%", 13 | }, 14 | }); 15 | 16 | export default function AudioPlayer({ 17 | id, 18 | path, 19 | onPlayStateChange, 20 | isPlaying, 21 | height, 22 | }: { 23 | id: string; 24 | path: string; 25 | onPlayStateChange: (state: boolean) => void; 26 | isPlaying: boolean; 27 | height: number; 28 | }): ReactElement { 29 | const classes = useStyles(); 30 | const isMounted = useRef(false); 31 | const wavesurfer = useRef(null); 32 | const initWaveSurfer = () => { 33 | wavesurfer.current = WaveSurfer.create({ 34 | container: `#${id}`, 35 | waveColor: "grey", 36 | progressColor: "#1890ff", 37 | cursorColor: "#1890ff", 38 | barWidth: 1, 39 | cursorWidth: 1, 40 | height: height, 41 | barGap: 1, 42 | }); 43 | if (wavesurfer.current != null) { 44 | wavesurfer.current.on("finish", stopAudio); 45 | } 46 | }; 47 | 48 | const getDataUrl = () => { 49 | if (path === null) { 
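          // No audio path set yet — skip the IPC round-trip that converts the
          // file into a data URL for WaveSurfer.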
50 | return; 51 | } 52 | ipcRenderer 53 | .invoke(GET_AUDIO_DATA_URL_CHANNEL.IN, path) 54 | .then((dataUrl: string) => { 55 | if (wavesurfer.current === null || !isMounted.current) { 56 | return; 57 | } 58 | wavesurfer.current.load(dataUrl); 59 | }); 60 | }; 61 | 62 | const stopAudio = () => { 63 | if (wavesurfer.current === null) { 64 | return; 65 | } 66 | wavesurfer.current.stop(); 67 | onPlayStateChange(false); 68 | }; 69 | 70 | useEffect(() => { 71 | getDataUrl(); 72 | }, [path]); 73 | 74 | useEffect(() => { 75 | if (wavesurfer.current === null) { 76 | return; 77 | } 78 | if (!isPlaying) { 79 | wavesurfer.current.pause(); 80 | } else { 81 | wavesurfer.current.play(); 82 | } 83 | }, [isPlaying]); 84 | 85 | useEffect(() => { 86 | isMounted.current = true; 87 | initWaveSurfer(); 88 | getDataUrl(); 89 | return () => { 90 | isMounted.current = false; 91 | }; 92 | }, []); 93 | 94 | return ; 95 | } 96 | 97 | AudioPlayer.defaultProps = { 98 | height: 80, 99 | }; 100 | -------------------------------------------------------------------------------- /src/components/breadcrumb/BreadcrumbItem.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Breadcrumb, Typography } from "antd"; 3 | import { Link } from "react-router-dom"; 4 | import { useSelector } from "react-redux"; 5 | import { RootState } from "../../app/store"; 6 | 7 | export default function BreadcrumbItem({ 8 | to, 9 | children, 10 | }: { 11 | to: string | null; 12 | children: string; 13 | }): ReactElement { 14 | const isDisabled = useSelector( 15 | (state: RootState) => state.navigationSettings.isDisabled 16 | ); 17 | return ( 18 | 19 | {to === null ? ( 20 | {children} 21 | ) : isDisabled ? ( 22 | {children} 23 | ) : ( 24 | {children} 25 | )} 26 | 27 | ); 28 | } 29 | 30 | BreadcrumbItem.defaultProps = { 31 | onClick: null, 32 | disabled: false, 33 | to: null, 34 | }; 35 | -------------------------------------------------------------------------------- /src/components/cards/DocumentationCard.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Card } from "antd"; 3 | 4 | export default function DocumentationCard({ 5 | title, 6 | children, 7 | }: { 8 | title: string; 9 | children: ReactElement; 10 | }): ReactElement { 11 | return {children}; 12 | } 13 | -------------------------------------------------------------------------------- /src/components/cards/RunCard.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement, ReactNode } from "react"; 2 | import { Card } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function RunCard({ 6 | title, 7 | buttons, 8 | children, 9 | disableFullHeight, 10 | docsUrl, 11 | }: { 12 | title: string | null; 13 | buttons: ReactNode[]; 14 | children: ReactNode; 15 | disableFullHeight: boolean; 16 | docsUrl: string | null; 17 | }): ReactElement { 18 | return ( 19 | 23 | {title} 24 | {docsUrl && ( 25 | 26 | )} 27 | 28 | ) 29 | } 30 | style={{ 31 | height: disableFullHeight ? null : "100%", 32 | display: "flex", 33 | flexDirection: "column", 34 | justifyContent: "space-between", 35 | }} 36 | bodyStyle={{ paddingTop: title === null ? 8 : null }} 37 | actions={ 38 | buttons.length === null 39 | ? 
null 40 | : [ 41 | 49 | {buttons.map((ButtonNode, index) => ( 50 | 56 | {ButtonNode} 57 | 58 | ))} 59 | , 60 | ] 61 | } 62 | > 63 | {children} 64 | 65 | ); 66 | } 67 | 68 | RunCard.defaultProps = { 69 | title: null, 70 | buttons: [], 71 | disableFullHeight: false, 72 | docsUrl: null, 73 | }; 74 | -------------------------------------------------------------------------------- /src/components/charts/PieChart.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Card } from "antd"; 3 | import { createUseStyles } from "react-jss"; 4 | import { RadialChart } from "react-vis"; 5 | import { CHART_BG_COLORS } from "../../config"; 6 | 7 | const useStyles = createUseStyles({ 8 | card: { 9 | width: "100%", 10 | }, 11 | cardInner: { 12 | display: "flex", 13 | justifyContent: "center", 14 | alignItems: "center", 15 | }, 16 | }); 17 | 18 | export default function PieChart({ 19 | labels, 20 | data, 21 | title, 22 | chartHeight, 23 | chartWidth, 24 | }: { 25 | labels: string[]; 26 | data: number[]; 27 | title: string; 28 | chartHeight: number; 29 | chartWidth: number; 30 | }): ReactElement { 31 | const classes = useStyles(); 32 | const radialData = data.map((el: number, index: number) => ({ 33 | angle: el, 34 | label: labels[index], 35 | color: CHART_BG_COLORS[index], 36 | })); 37 | return ( 38 | 39 | 40 | 46 | 47 | 48 | ); 49 | } 50 | -------------------------------------------------------------------------------- /src/components/help/HelpButton.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Button } from "antd"; 3 | import { QuestionOutlined } from "@ant-design/icons"; 4 | import { documentationUrl } from "../../config"; 5 | const { shell } = window.require("electron"); 6 | 7 | export default function HelpButton({ 8 | children, 9 | docsUrl, 10 | ...rest 11 | }: { 12 | children: ReactElement | string; 13 | docsUrl: string; 14 | [x: string]: any; 15 | }): ReactElement { 16 | return ( 17 | 18 | } 25 | onClick={() => { 26 | shell.openExternal(`${documentationUrl}${docsUrl}`); 27 | }} 28 | > 29 | {children} 30 | 31 | 32 | ); 33 | } 34 | -------------------------------------------------------------------------------- /src/components/help/HelpIcon.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Button } from "antd"; 3 | import { QuestionOutlined } from "@ant-design/icons"; 4 | import { documentationUrl } from "../../config"; 5 | const { shell } = window.require("electron"); 6 | 7 | export default function HelpIcon({ 8 | docsUrl, 9 | ...props 10 | }: { 11 | docsUrl: string; 12 | [x: string]: any; 13 | }): ReactElement { 14 | return ( 15 | { 19 | shell.openExternal(`${documentationUrl}${docsUrl}`); 20 | }} 21 | shape="circle" 22 | icon={} 23 | /> 24 | ); 25 | } 26 | -------------------------------------------------------------------------------- /src/components/image/Image.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { GET_IMAGE_DATA_URL_CHANNEL } from "../../channels"; 3 | const { ipcRenderer } = window.require("electron"); 4 | 5 | export default function Image({ 6 | path, 7 | style, 8 | }: { 9 | path: string; 10 | style: {}; 11 | }): ReactElement { 12 | const isMounted = useRef(false); 13 | 
const [dataUrl, setDataUrl] = useState(""); 14 | 15 | const getDataUrl = () => { 16 | if (path === "") { 17 | return; 18 | } 19 | ipcRenderer 20 | .invoke(GET_IMAGE_DATA_URL_CHANNEL.IN, path) 21 | .then((dataUrl: string) => { 22 | if (!isMounted.current) { 23 | return; 24 | } 25 | setDataUrl(dataUrl); 26 | }); 27 | }; 28 | 29 | useEffect(() => { 30 | getDataUrl(); 31 | }, [path]); 32 | 33 | useEffect(() => { 34 | isMounted.current = true; 35 | return () => { 36 | isMounted.current = false; 37 | }; 38 | }, []); 39 | 40 | return dataUrl ? : <>>; 41 | } 42 | 43 | Image.defaultProps = { 44 | style: {}, 45 | }; 46 | -------------------------------------------------------------------------------- /src/components/inputs/AcousticModelTypeInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, Select, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | import { Rule } from "antd/lib/form"; 5 | 6 | export default function AcousticModelTypeInput({ 7 | disabled, 8 | docsUrl, 9 | rules, 10 | name 11 | }: { 12 | disabled: boolean; 13 | docsUrl: string | null; 14 | rules: Rule[] | null; 15 | name: string; 16 | }): ReactElement { 17 | return ( 18 | 21 | Acoustic Model Type 22 | {docsUrl && } 23 | 24 | } 25 | name={name} 26 | rules={rules} 27 | > 28 | 29 | English Only 30 | 31 | Multilingual 32 | 33 | 34 | 35 | ); 36 | } 37 | 38 | AcousticModelTypeInput.defaultProps = { 39 | disabled: false, 40 | docsUrl: null, 41 | rules: null, 42 | }; 43 | -------------------------------------------------------------------------------- /src/components/inputs/AlignmentBatchSizeInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function AlignmentBatchSizeInput({ 6 | disabled, 7 | docsUrl, 8 | }: { 9 | disabled: boolean; 10 | docsUrl: string | null; 11 | }): ReactElement { 12 | return ( 13 | 16 | Forced Aligner Batch Size 17 | {docsUrl && } 18 | 19 | } 20 | name="forcedAlignmentBatchSize" 21 | > 22 | 28 | 29 | ); 30 | } 31 | 32 | AlignmentBatchSizeInput.defaultProps = { 33 | disabled: false, 34 | docsUrl: null, 35 | }; 36 | -------------------------------------------------------------------------------- /src/components/inputs/BatchSizeInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function BatchSizeInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Batch Size 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | > 24 | 25 | 26 | ); 27 | } 28 | 29 | BatchSizeInput.defaultProps = { 30 | disabled: false, 31 | docsUrl: null, 32 | }; 33 | -------------------------------------------------------------------------------- /src/components/inputs/DatasetInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { Form, Select, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | import { DatasetInterface } from 
"../../interfaces"; 5 | import { FETCH_DATASET_CANDIDATES_CHANNEL } from "../../channels"; 6 | 7 | const { ipcRenderer } = window.require("electron"); 8 | 9 | export default function DatasetInput({ 10 | disabled, 11 | docsUrl, 12 | }: { 13 | disabled: boolean; 14 | docsUrl: string | null; 15 | }): ReactElement { 16 | const isMounted = useRef(false); 17 | const [datasets, setDatasets] = useState([]); 18 | 19 | const fetchDatasets = () => { 20 | ipcRenderer 21 | .invoke(FETCH_DATASET_CANDIDATES_CHANNEL.IN) 22 | .then((datasets: DatasetInterface[]) => { 23 | console.log(datasets); 24 | if (!isMounted.current) { 25 | return; 26 | } 27 | setDatasets(datasets); 28 | }); 29 | }; 30 | 31 | useEffect(() => { 32 | isMounted.current = true; 33 | fetchDatasets(); 34 | return () => { 35 | isMounted.current = false; 36 | }; 37 | }, []); 38 | 39 | return ( 40 | ({ 43 | validator(_, value: string) { 44 | if (value === null) { 45 | return Promise.reject(new Error("Please select a dataset")); 46 | } 47 | return Promise.resolve(); 48 | }, 49 | }), 50 | ]} 51 | label={ 52 | 53 | Dataset 54 | {docsUrl && } 55 | 56 | } 57 | name="datasetID" 58 | > 59 | 60 | {datasets.map((dataset: DatasetInterface) => ( 61 | 66 | {dataset.name} 67 | 68 | ))} 69 | 70 | 71 | ); 72 | } 73 | 74 | DatasetInput.defaultProps = { 75 | disabled: false, 76 | docsUrl: null, 77 | }; 78 | -------------------------------------------------------------------------------- /src/components/inputs/DeviceInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { Form, Select, Alert, Typography } from "antd"; 3 | import { SERVER_URL } from "../../config"; 4 | import HelpIcon from "../help/HelpIcon"; 5 | const { shell } = window.require("electron"); 6 | 7 | export default function DeviceInput({ 8 | disabled, 9 | docsUrl, 10 | }: { 11 | disabled: boolean; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | const [cudaIsAvailable, setCudaIsAvailable] = useState(false); 15 | const [hasFetchedCuda, setHasFetchedCuda] = useState(false); 16 | const isMounted = useRef(false); 17 | 18 | const fetchIsCudaAvailable = () => { 19 | const ajax = new XMLHttpRequest(); 20 | ajax.open("GET", `${SERVER_URL}/is-cuda-available`); 21 | ajax.onload = () => { 22 | if (!isMounted.current) { 23 | return; 24 | } 25 | const response = JSON.parse(ajax.responseText); 26 | setCudaIsAvailable(response.available); 27 | setHasFetchedCuda(true); 28 | }; 29 | ajax.send(); 30 | }; 31 | 32 | useEffect(() => { 33 | isMounted.current = true; 34 | fetchIsCudaAvailable(); 35 | return () => { 36 | isMounted.current = false; 37 | }; 38 | }, []); 39 | 40 | return ( 41 | <> 42 | 45 | Device 46 | {docsUrl && ( 47 | 48 | )} 49 | 50 | } 51 | name="device" 52 | > 53 | 54 | CPU 55 | 56 | GPU 57 | 58 | 59 | 60 | {hasFetchedCuda && !cudaIsAvailable && ( 61 | 65 | No CUDA supported GPU was detected. While you can train on CPU, 66 | training on GPU is highly recommended since training on CPU will 67 | most likely take days. If you want to train on GPU{" "} 68 | { 70 | shell.openExternal("https://developer.nvidia.com/cuda-gpus"); 71 | }} 72 | > 73 | please make sure it has CUDA support 74 | {" "} 75 | and it's driver is up to date. Afterwards restart the app. 
76 | 77 | } 78 | type="warning" 79 | /> 80 | )} 81 | > 82 | ); 83 | } 84 | 85 | DeviceInput.defaultProps = { 86 | disabled: false, 87 | docsUrl: null, 88 | }; 89 | -------------------------------------------------------------------------------- /src/components/inputs/GradientAccumulationInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function GradientAccumulationStepsInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Gradient Accumulation Steps 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | > 24 | 25 | 26 | ); 27 | } 28 | 29 | GradientAccumulationStepsInput.defaultProps = { 30 | disabled: false, 31 | docsUrl: null, 32 | }; 33 | -------------------------------------------------------------------------------- /src/components/inputs/GradientAccumulationStepsInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function GradientAccumulationStepsInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Gradient Accumulation Steps 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | > 24 | 25 | 26 | ); 27 | } 28 | 29 | GradientAccumulationStepsInput.defaultProps = { 30 | disabled: false, 31 | docsUrl: null, 32 | }; 33 | -------------------------------------------------------------------------------- /src/components/inputs/LanguageSelect.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Select } from "antd"; 3 | import { LANGUAGES } from "../../config"; 4 | import { SpeakerInterface } from "../../interfaces"; 5 | 6 | export default function LanguageSelect({ 7 | value, 8 | className, 9 | onChange, 10 | disabled, 11 | }: { 12 | value: SpeakerInterface["language"] | null; 13 | className: string | null; 14 | onChange: ((lang: SpeakerInterface["language"]) => void) | null; 15 | disabled: boolean; 16 | }): ReactElement { 17 | if (value === null && onChange !== null) { 18 | throw new Error( 19 | `Invalid props received: value is null and onChange is not null, they both have to be null or both have to be not null ...` 20 | ); 21 | } else if (value !== null && onChange === null) { 22 | throw new Error( 23 | `Invalid props received: value is not null and onChange is null, they both have to be null or both have to be not null ...` 24 | ); 25 | } 26 | return ( 27 | 33 | {LANGUAGES.map((el) => ( 34 | 35 | {el.name} 36 | 37 | ))} 38 | 39 | ); 40 | } 41 | 42 | LanguageSelect.defaultProps = { 43 | className: null, 44 | disabled: false, 45 | value: null, 46 | onChange: null, 47 | }; 48 | -------------------------------------------------------------------------------- /src/components/inputs/LearningRateInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | 
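// Numeric form field for the learning rate. An antd validator rejects a value
// of 0 so a run cannot be started with a learning rate that would never update
// the weights. Typical usage inside a run configuration form (the form field
// name is chosen by the caller):
//   <LearningRateInput name="acousticLearningRate" disabled={hasStarted} />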
export default function LearningRateInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Learning Rate 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | rules={[ 24 | () => ({ 25 | validator(_, value) { 26 | if (value === 0) { 27 | return Promise.reject( 28 | new Error("Learning rate must be greater than zero") 29 | ); 30 | } 31 | return Promise.resolve(); 32 | }, 33 | }), 34 | ]} 35 | > 36 | 37 | 38 | ); 39 | } 40 | 41 | LearningRateInput.defaultProps = { 42 | disabled: false, 43 | docsUrl: null, 44 | }; 45 | -------------------------------------------------------------------------------- /src/components/inputs/MaximumWorkersInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, Select, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function MaximumWorkersInput({ 6 | disabled, 7 | docsUrl, 8 | }: { 9 | disabled: boolean; 10 | docsUrl: string | null; 11 | }): ReactElement { 12 | return ( 13 | 16 | Maximum Worker Count 17 | {docsUrl && } 18 | 19 | } 20 | name="maximumWorkers" 21 | > 22 | 23 | Auto 24 | {Array.from(Array(64 + 1).keys()) 25 | .slice(1) 26 | .map((el) => ( 27 | 28 | {el} 29 | 30 | ))} 31 | 32 | 33 | ); 34 | } 35 | 36 | MaximumWorkersInput.defaultProps = { 37 | disabled: false, 38 | docsUrl: null, 39 | }; 40 | -------------------------------------------------------------------------------- /src/components/inputs/NameInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { Form, Input } from "antd"; 3 | 4 | export default function DeviceInput({ 5 | disabled, 6 | fetchNames, 7 | }: { 8 | disabled: boolean; 9 | fetchNames: () => Promise; 10 | }): ReactElement { 11 | const [names, setNames] = useState([]); 12 | const isMounted = useRef(false); 13 | 14 | const fetchNamesInUse = async () => { 15 | const names = await fetchNames(); 16 | if (isMounted.current) { 17 | setNames(names); 18 | } 19 | }; 20 | 21 | useEffect(() => { 22 | isMounted.current = true; 23 | fetchNamesInUse(); 24 | return () => { 25 | isMounted.current = false; 26 | }; 27 | }, []); 28 | 29 | return ( 30 | ({ 35 | validator(_, value: string) { 36 | if (value.trim() === "") { 37 | return Promise.reject(new Error("Please enter a name")); 38 | } 39 | if (names.includes(value)) { 40 | return Promise.reject(new Error("This name is already in use")); 41 | } 42 | return Promise.resolve(); 43 | }, 44 | }), 45 | ]} 46 | > 47 | 48 | 49 | ); 50 | } 51 | 52 | DeviceInput.defaultProps = { 53 | disabled: false, 54 | }; 55 | -------------------------------------------------------------------------------- /src/components/inputs/RunValidationEveryInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function RunValidationEveryInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Run Validation Every 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | > 24 | 25 | 26 | ); 27 | } 28 | 29 | 
RunValidationEveryInput.defaultProps = { 30 | disabled: false, 31 | docsUrl: null, 32 | }; 33 | -------------------------------------------------------------------------------- /src/components/inputs/SkipOnErrorInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, Select, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function SkipOnErrorInput({ 6 | disabled, 7 | docsUrl, 8 | }: { 9 | disabled: boolean; 10 | docsUrl: string | null; 11 | }): ReactElement { 12 | return ( 13 | 16 | On Error Ignore Sample 17 | {docsUrl && } 18 | 19 | } 20 | name="skipOnError" 21 | > 22 | 23 | Yes 24 | No 25 | 26 | 27 | ); 28 | } 29 | 30 | SkipOnErrorInput.defaultProps = { 31 | disabled: false, 32 | docsUrl: null, 33 | }; 34 | -------------------------------------------------------------------------------- /src/components/inputs/TrainOnlySpeakerEmbedsUntilInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function TrainOnlySpeakerEmbedsUntilInput({ 6 | disabled, 7 | name, 8 | rules, 9 | docsUrl, 10 | }: { 11 | disabled: boolean; 12 | name: string; 13 | rules: any[]; 14 | docsUrl: string | null; 15 | }): ReactElement { 16 | return ( 17 | 21 | Train Only Speaker Embeds Until 22 | {docsUrl && } 23 | 24 | } 25 | name={name} 26 | > 27 | 28 | 29 | ); 30 | } 31 | 32 | TrainOnlySpeakerEmbedsUntilInput.defaultProps = { 33 | disabled: false, 34 | rules: [], 35 | docsUrl: null, 36 | }; 37 | -------------------------------------------------------------------------------- /src/components/inputs/TrainingStepsInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form, InputNumber, Typography } from "antd"; 3 | import HelpIcon from "../help/HelpIcon"; 4 | 5 | export default function TrainingStepsInput({ 6 | disabled, 7 | name, 8 | docsUrl, 9 | }: { 10 | disabled: boolean; 11 | name: string; 12 | docsUrl: string | null; 13 | }): ReactElement { 14 | return ( 15 | 18 | Training Steps 19 | {docsUrl && } 20 | 21 | } 22 | name={name} 23 | > 24 | 25 | 26 | ); 27 | } 28 | 29 | TrainingStepsInput.defaultProps = { 30 | disabled: false, 31 | docsUrl: null, 32 | }; 33 | -------------------------------------------------------------------------------- /src/components/log_printer/LogPrinter.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect, useRef, ReactElement } from "react"; 2 | import { FETCH_LOGFILE_CHANNEL } from "../../channels"; 3 | import { POLL_LOGFILE_INTERVALL } from "../../config"; 4 | import { useInterval } from "../../utils"; 5 | import Terminal from "./Terminal"; 6 | const { ipcRenderer } = window.require("electron"); 7 | 8 | export default function LogPrinter({ 9 | name, 10 | logFileName, 11 | type, 12 | }: { 13 | name: string | null; 14 | logFileName: string; 15 | type: 16 | | "trainingRun" 17 | | "model" 18 | | "cleaningRun" 19 | | "textNormalizationRun" 20 | | "sampleSplittingRun"; 21 | }): ReactElement { 22 | const [logLines, setLogLines] = useState([]); 23 | const isMounted = useRef(false); 24 | 25 | const pollLog = () => { 26 | if (name === null) { 27 | return; 28 | } 29 | ipcRenderer 
30 | .invoke(FETCH_LOGFILE_CHANNEL.IN, name, logFileName, type) 31 | .then((lines: string[]) => { 32 | if (!isMounted.current) { 33 | return; 34 | } 35 | if (lines.length !== logLines.length) { 36 | setLogLines(lines); 37 | } 38 | }); 39 | }; 40 | 41 | useInterval(pollLog, POLL_LOGFILE_INTERVALL); 42 | 43 | useEffect(() => { 44 | isMounted.current = true; 45 | return () => { 46 | isMounted.current = false; 47 | }; 48 | }, []); 49 | 50 | return ( 51 | ({ type: "message", message: el }))} 53 | /> 54 | ); 55 | } 56 | -------------------------------------------------------------------------------- /src/components/log_printer/Terminal.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Typography } from "antd"; 3 | import classNames from "classnames"; 4 | import { createUseStyles } from "react-jss"; 5 | import { TerminalMessage } from "../../interfaces"; 6 | 7 | const useStyles = createUseStyles({ 8 | wrapper: { 9 | width: "100%", 10 | maxHeight: 600, 11 | overflowY: "auto", 12 | padding: 16, 13 | backgroundColor: "#272727", 14 | borderRadius: 2, 15 | }, 16 | text: { fontFamily: "monospace", fontSize: 12 }, 17 | message: { 18 | color: "#9CD9F0", 19 | }, 20 | errorMessage: { 21 | color: "#E09690", 22 | }, 23 | }); 24 | 25 | export default function Terminal({ 26 | messages, 27 | maxLines, 28 | }: { 29 | messages: TerminalMessage[]; 30 | maxLines: number; 31 | }): ReactElement { 32 | const classes = useStyles(); 33 | const startIndex = Math.max(messages.length - maxLines, 0); 34 | 35 | return ( 36 | 37 | {messages.slice(startIndex).map((el, index) => ( 38 | 45 | {el.message} 46 | 47 | ))} 48 | 49 | ); 50 | } 51 | 52 | Terminal.defaultProps = { maxLines: 1000 }; 53 | -------------------------------------------------------------------------------- /src/components/modals/NoCloseModal.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Modal } from "antd"; 3 | import { createUseStyles } from "react-jss"; 4 | 5 | const useStyles = createUseStyles({ 6 | footerWrapper: { 7 | display: "flex", 8 | justifyContent: "flex-end", 9 | }, 10 | }); 11 | 12 | const getFooter = (buttons: React.ReactNode[]) => { 13 | const classes = useStyles(); 14 | return ( 15 | 16 | {buttons.map((button, index) => ( 17 | 21 | {button} 22 | 23 | ))} 24 | 25 | ); 26 | }; 27 | 28 | export default function NoCloseModal({ 29 | visible, 30 | title, 31 | children, 32 | buttons, 33 | }: { 34 | visible: boolean; 35 | title: string; 36 | children: React.ReactNode; 37 | buttons: React.ReactNode[] | null; 38 | }): ReactElement { 39 | return ( 40 | >} 47 | footer={buttons === null ? 
null : getFooter(buttons)} 48 | > 49 | {children} 50 | 51 | ); 52 | } 53 | 54 | NoCloseModal.defaultProps = { 55 | buttons: null, 56 | }; 57 | -------------------------------------------------------------------------------- /src/components/runs/RunConfiguration.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Button } from "antd"; 3 | import RunConfigurationForm from "../../components/runs/RunConfigurationForm"; 4 | import RunCard from "../../components/cards/RunCard"; 5 | 6 | export default function RunConfiguration({ 7 | title, 8 | forms, 9 | hasStarted, 10 | isDisabled, 11 | onBack, 12 | onDefaults, 13 | onSave, 14 | onNext, 15 | formRef, 16 | initialValues, 17 | onFinish, 18 | fetchNames, 19 | docsUrl, 20 | }: { 21 | title: string; 22 | forms: ReactElement; 23 | hasStarted: boolean; 24 | isDisabled: boolean; 25 | onBack: () => void; 26 | onDefaults: () => void; 27 | onSave: () => void; 28 | onNext: () => void; 29 | formRef: any; 30 | initialValues: { [key: string]: any }; 31 | onFinish: (values: any) => void; 32 | fetchNames: () => Promise; 33 | docsUrl: string | null; 34 | }): ReactElement { 35 | return ( 36 | Back, 41 | 42 | Reset to Default 43 | , 44 | 45 | Save 46 | , 47 | 48 | {hasStarted ? "Save and Next" : "Save and Start Run"} 49 | , 50 | ]} 51 | > 52 | 59 | {forms} 60 | 61 | 62 | ); 63 | } 64 | 65 | RunConfiguration.defaultProps = { 66 | docsUrl: null, 67 | }; 68 | -------------------------------------------------------------------------------- /src/components/runs/RunConfigurationForm.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Form } from "antd"; 3 | import NameInput from "../inputs/NameInput"; 4 | 5 | export default function RunConfigurationForm({ 6 | formRef, 7 | initialValues, 8 | onFinish, 9 | fetchNames, 10 | isDisabled, 11 | children, 12 | }: { 13 | // TODO find correct type 14 | formRef: any; 15 | initialValues: { [key: string]: any }; 16 | onFinish: (values: any) => void; 17 | fetchNames: () => Promise; 18 | isDisabled: boolean; 19 | children: ReactElement; 20 | }): ReactElement { 21 | return ( 22 | { 25 | formRef.current = node; 26 | }} 27 | initialValues={initialValues} 28 | onFinish={onFinish} 29 | > 30 | 31 | {children} 32 | 33 | ); 34 | } 35 | -------------------------------------------------------------------------------- /src/config.ts: -------------------------------------------------------------------------------- 1 | import { 2 | TrainingRunConfigInterface, 3 | SpeakerInterface, 4 | CleaningRunConfigInterface, 5 | TextNormalizationRunConfigInterface, 6 | SampleSplittingRunConfigInterface, 7 | } from "./interfaces"; 8 | export const SERVER_URL = "http://localhost:12118"; 9 | export const POLL_LOGFILE_INTERVALL = 1000; 10 | export const POLL_NULL_INTERVALL = 50; 11 | export const CHART_BG_COLORS = [ 12 | "rgb(255, 99, 132)", 13 | "rgb(54, 162, 235)", 14 | "rgb(255, 205, 86)", 15 | ]; 16 | export const CHART_BG_COLORS_FADED = [ 17 | "rgba(255, 99, 132, 0.5)", 18 | "rgba(54, 162, 235, 0.5)", 19 | "rgba(255, 205, 86, 0.5)", 20 | ]; 21 | export const TEXT_EXTENSIONS = ["txt"]; 22 | export const AUDIO_EXTENSIONS = ["wav", "flac"]; 23 | export const STATISTIC_HEIGHT = 200; 24 | export const DOCKER_IMAGE_NAME = "voicesmith/voicesmith:v0.2.2"; 25 | export const DOCKER_CONTAINER_NAME = "voice_smith"; 26 | export const CONDA_ENV_NAME = "voice_smith"; 27 | export const 
LANGUAGES: { 28 | name: string; 29 | iso6391: SpeakerInterface["language"]; 30 | }[] = [ 31 | { 32 | name: "Bulgarian", 33 | iso6391: "bg", 34 | }, 35 | { 36 | name: "Czech", 37 | iso6391: "cs", 38 | }, 39 | { 40 | name: "German", 41 | iso6391: "de", 42 | }, 43 | { 44 | name: "English", 45 | iso6391: "en", 46 | }, 47 | { 48 | name: "Spanish", 49 | iso6391: "es", 50 | }, 51 | { 52 | name: "French", 53 | iso6391: "fr", 54 | }, 55 | { 56 | name: "Croatian", 57 | iso6391: "hr", 58 | }, 59 | { 60 | name: "Polish", 61 | iso6391: "pl", 62 | }, 63 | { 64 | name: "Portuguese", 65 | iso6391: "pt", 66 | }, 67 | { 68 | name: "Russian", 69 | iso6391: "ru", 70 | }, 71 | { 72 | name: "Swedish", 73 | iso6391: "sv", 74 | }, 75 | { 76 | name: "Thai", 77 | iso6391: "th", 78 | }, 79 | { 80 | name: "Turkish", 81 | iso6391: "tr", 82 | }, 83 | { 84 | name: "Ukrainian", 85 | iso6391: "uk", 86 | }, 87 | ]; 88 | 89 | export const trainingRunInitialValues: TrainingRunConfigInterface = { 90 | name: "", 91 | maximumWorkers: -1, 92 | validationSize: 5.0, 93 | minSeconds: 0.5, 94 | maxSeconds: 10, 95 | useAudioNormalization: true, 96 | acousticLearningRate: 0.0002, 97 | acousticTrainingIterations: 30000, 98 | acousticValidateEvery: 2000, 99 | acousticBatchSize: 5, 100 | acousticGradAccumSteps: 3, 101 | vocoderLearningRate: 0.0002, 102 | vocoderTrainingIterations: 20000, 103 | vocoderValidateEvery: 2000, 104 | vocoderBatchSize: 5, 105 | vocoderGradAccumSteps: 3, 106 | device: "CPU", 107 | onlyTrainSpeakerEmbUntil: 5000, 108 | datasetID: null, 109 | datasetName: null, 110 | skipOnError: true, 111 | forcedAlignmentBatchSize: 200000, 112 | acousticModelType: "multilingual" 113 | }; 114 | 115 | export const cleaningRunInitialValues: CleaningRunConfigInterface = { 116 | name: "", 117 | datasetID: null, 118 | datasetName: null, 119 | skipOnError: true, 120 | device: "CPU", 121 | maximumWorkers: -1, 122 | }; 123 | 124 | export const textNormalizationRunInitialValues: TextNormalizationRunConfigInterface = 125 | { 126 | name: "", 127 | datasetID: null, 128 | datasetName: null, 129 | }; 130 | 131 | export const sampleSplittingRunInitialValues: SampleSplittingRunConfigInterface = 132 | { 133 | name: "", 134 | maximumWorkers: -1, 135 | datasetID: null, 136 | datasetName: null, 137 | device: "CPU", 138 | skipOnError: true, 139 | forcedAlignmentBatchSize: 200000, 140 | }; 141 | 142 | export const defaultPageOptions = { 143 | defaultPageSize: 100, 144 | pageSizeOptions: [50, 100, 250, 1000], 145 | }; 146 | 147 | export const documentationUrl = "https://docs.voicesmith.io"; 148 | -------------------------------------------------------------------------------- /src/electron/handles/docker.ts: -------------------------------------------------------------------------------- 1 | import { ipcMain } from "electron"; 2 | import { START_BACKEND_CHANNEL } from "../../channels"; 3 | import { startContainer } from "../utils/docker"; 4 | 5 | ipcMain.handle(START_BACKEND_CHANNEL.IN, async () => { 6 | await startContainer(null, null); 7 | }); 8 | -------------------------------------------------------------------------------- /src/electron/handles/install.ts: -------------------------------------------------------------------------------- 1 | import { ipcMain, IpcMainEvent } from "electron"; 2 | import { getInstalledPath } from "../utils/globals"; 3 | import { 4 | INSTALL_BACKEND_CHANNEL, 5 | FINISH_INSTALL_CHANNEL, 6 | FETCH_HAS_DOCKER_CHANNEL, 7 | FETCH_NEEDS_INSTALL_CHANNEL, 8 | } from "../../channels"; 9 | import { exists } from "../utils/files"; 
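// IPC handlers for the one-time backend installation flow:
//   FETCH_HAS_DOCKER    -> checks whether Docker is available on this machine
//   FETCH_NEEDS_INSTALL -> true while the INSTALLED marker file does not exist yet
//   FINISH_INSTALL      -> writes the marker file once setup has succeeded
//   INSTALL_BACKEND     -> resets Docker state, creates the container (with GPU
//                          support if the installer options request it), installs
//                          the conda environment, stops the container and replies
//                          to the renderer with a "finished" message.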
10 | import fsNative from "fs"; 11 | import { 12 | InstallBackendReplyInterface, 13 | InstallerOptionsInterface, 14 | } from "../../interfaces"; 15 | import { 16 | getHasDocker, 17 | resetDocker, 18 | createContainer, 19 | installEnvironment, 20 | stopContainer, 21 | } from "../utils/docker"; 22 | const fsPromises = fsNative.promises; 23 | 24 | ipcMain.handle(FETCH_HAS_DOCKER_CHANNEL.IN, async () => { 25 | return await getHasDocker(); 26 | }); 27 | 28 | ipcMain.handle(FETCH_NEEDS_INSTALL_CHANNEL.IN, async () => { 29 | return !(await exists(getInstalledPath())); 30 | }); 31 | 32 | ipcMain.handle(FINISH_INSTALL_CHANNEL.IN, async () => { 33 | const installedPath = getInstalledPath(); 34 | if (!(await exists(installedPath))) { 35 | await fsPromises.writeFile(installedPath, ""); 36 | } 37 | }); 38 | 39 | ipcMain.on( 40 | INSTALL_BACKEND_CHANNEL.IN, 41 | async (event: IpcMainEvent, installerOptions: InstallerOptionsInterface) => { 42 | await resetDocker(); 43 | const onData = (data: string) => { 44 | const reply: InstallBackendReplyInterface = { 45 | type: "message", 46 | message: data, 47 | }; 48 | event.reply(INSTALL_BACKEND_CHANNEL.REPLY, reply); 49 | }; 50 | const onError = (data: string) => { 51 | const reply: InstallBackendReplyInterface = { 52 | type: "error", 53 | message: data, 54 | }; 55 | event.reply(INSTALL_BACKEND_CHANNEL.REPLY, reply); 56 | }; 57 | await createContainer(onData, onError, installerOptions.device === "GPU"); 58 | await installEnvironment(onData, onError); 59 | const reply: InstallBackendReplyInterface = { 60 | type: "finished", 61 | message: "", 62 | success: true, 63 | }; 64 | await stopContainer(); 65 | event.reply(INSTALL_BACKEND_CHANNEL.REPLY, reply); 66 | } 67 | ); 68 | -------------------------------------------------------------------------------- /src/electron/handles/models.ts: -------------------------------------------------------------------------------- 1 | import { ipcMain, IpcMainInvokeEvent } from "electron"; 2 | import { DB } from "../utils/db"; 3 | import { FETCH_MODELS_CHANNEL, REMOVE_MODEL_CHANNEL } from "../../channels"; 4 | 5 | ipcMain.handle(FETCH_MODELS_CHANNEL.IN, async () => { 6 | const speakersStmt = DB.getInstance().prepare( 7 | "SELECT name, speaker_id AS speakerID FROM model_speaker WHERE model_id=@ID" 8 | ); 9 | const models = DB.getInstance() 10 | .prepare( 11 | "SELECT ID, name, type, description, created_at AS createdAt FROM model" 12 | ) 13 | .all(); 14 | for (const model of models) { 15 | model.speakers = speakersStmt.all({ ID: model.ID }); 16 | } 17 | return models; 18 | }); 19 | 20 | ipcMain.handle( 21 | REMOVE_MODEL_CHANNEL.IN, 22 | (event: IpcMainInvokeEvent, modelID: number) => { 23 | DB.getInstance() 24 | .prepare("DELETE FROM model WHERE ID=@modelID") 25 | .run({ modelID }); 26 | } 27 | ); 28 | -------------------------------------------------------------------------------- /src/electron/handles/settings.ts: -------------------------------------------------------------------------------- 1 | import { app, ipcMain, IpcMainEvent } from "electron"; 2 | import fsNative from "fs"; 3 | const fsPromises = fsNative.promises; 4 | import { 5 | GET_APP_INFO_CHANNEL, 6 | SAVE_SETTINGS_CHANNEL, 7 | FETCH_SETTINGS_CHANNEL, 8 | } from "../../channels"; 9 | import { SettingsInterface, AppInfoInterface } from "../../interfaces"; 10 | import { UserDataPath } from "../utils/globals"; 11 | import { DB } from "../utils/db"; 12 | import { copyDir, safeRmDir } from "../utils/files"; 13 | 14 | ipcMain.handle(GET_APP_INFO_CHANNEL.IN, (event: 
IpcMainEvent) => { 15 | const info: AppInfoInterface = { 16 | version: app.getVersion(), 17 | platform: process.platform, 18 | }; 19 | return info; 20 | }); 21 | 22 | ipcMain.on( 23 | SAVE_SETTINGS_CHANNEL.IN, 24 | async (event: IpcMainEvent, settings: SettingsInterface) => { 25 | const from = UserDataPath().getPath(); 26 | const updatePaths = from !== settings.dataPath; 27 | if (updatePaths) { 28 | await copyDir(from, settings.dataPath); 29 | } 30 | 31 | DB.getInstance() 32 | .prepare("UPDATE settings SET data_path=@dataPath WHERE ID=1") 33 | .run(settings); 34 | if (updatePaths) { 35 | await safeRmDir(from); 36 | UserDataPath().setPath(settings.dataPath); 37 | } 38 | event.reply(SAVE_SETTINGS_CHANNEL.REPLY, { type: "finished" }); 39 | } 40 | ); 41 | 42 | ipcMain.handle(FETCH_SETTINGS_CHANNEL.IN, () => { 43 | let settings = DB.getInstance() 44 | .prepare("SELECT data_path AS dataPath FROM settings WHERE ID=1") 45 | .get(); 46 | settings = { 47 | ...settings, 48 | dataPath: 49 | settings.dataPath === null ? UserDataPath().getPath() : settings.dataPath, 50 | }; 51 | return settings; 52 | }); 53 | -------------------------------------------------------------------------------- /src/electron/handles/synthesis.ts: -------------------------------------------------------------------------------- 1 | import { ipcMain, IpcMainInvokeEvent } from "electron"; 2 | import path from "path"; 3 | import { getAudioSynthDir } from "../utils/globals"; 4 | import { safeUnlink } from "../utils/files"; 5 | import { DB } from "../utils/db"; 6 | import { 7 | REMOVE_AUDIOS_SYNTH_CHANNEL, 8 | FETCH_AUDIOS_SYNTH_CHANNEL, 9 | } from "../../channels"; 10 | 11 | ipcMain.handle( 12 | FETCH_AUDIOS_SYNTH_CHANNEL.IN, 13 | async (event: IpcMainInvokeEvent) => { 14 | const audios = DB.getInstance() 15 | .prepare( 16 | ` 17 | SELECT 18 | ID, 19 | file_name AS fileName, 20 | text, 21 | speaker_name AS speakerName, 22 | model_name AS modelName, 23 | created_at AS createdAt, 24 | sampling_rate as samplingRate, 25 | dur_secs AS durSecs 26 | FROM audio_synth 27 | ORDER BY created_at DESC 28 | ` 29 | ) 30 | .all() 31 | .map((audio: any) => { 32 | audio.filePath = path.join(getAudioSynthDir(), audio.fileName); 33 | delete audio.fileName; 34 | return audio; 35 | }); 36 | return audios; 37 | } 38 | ); 39 | 40 | ipcMain.handle( 41 | REMOVE_AUDIOS_SYNTH_CHANNEL.IN, 42 | async (event: IpcMainInvokeEvent, audios: any[]) => { 43 | for (const audio of audios) { 44 | await safeUnlink(audio.filePath); 45 | } 46 | const stmt = DB.getInstance().prepare( 47 | "DELETE FROM audio_synth WHERE ID=@ID" 48 | ); 49 | const deleteMany = DB.getInstance().transaction((els: any) => { 50 | for (const el of els) stmt.run(el); 51 | }); 52 | deleteMany( 53 | audios.map((audio) => ({ 54 | ID: audio.ID, 55 | })) 56 | ); 57 | } 58 | ); 59 | -------------------------------------------------------------------------------- /src/electron/utils/files.ts: -------------------------------------------------------------------------------- 1 | import fs from "fs-extra"; 2 | import fsNative from "fs"; 3 | const fsPromises = fsNative.promises; 4 | 5 | export const exists = (file: string) => { 6 | return fsPromises 7 | .access(file, fs.constants.F_OK) 8 | .then(() => true) 9 | .catch(() => false); 10 | }; 11 | 12 | export const copyDir = (src: string, dest: string) => { 13 | return new Promise((resolve, reject) => { 14 | fs.copy(src, dest) 15 | .then(() => { 16 | resolve(null); 17 | }) 18 | .catch((err) => { 19 | reject(err); 20 | }); 21 | }); 22 | }; 23 | 24 | export const 
safeUnlink = async (path: string) => { 25 | try { 26 | await fsPromises.unlink(path); 27 | } catch (err) { 28 | if (err.code === "ENOENT") { 29 | return; 30 | } 31 | throw err; 32 | } 33 | }; 34 | 35 | export const safeMkdir = async (path: string) => { 36 | await fsPromises.mkdir(path, { recursive: true }); 37 | }; 38 | 39 | export const safeRmDir = async (path: string) => { 40 | try { 41 | await fsPromises.rm(path, { recursive: true, force: true }); 42 | } catch (err) { 43 | if (err.code === "ENOENT") { 44 | return; 45 | } 46 | throw err; 47 | } 48 | }; 49 | -------------------------------------------------------------------------------- /src/electron/utils/globals.ts: -------------------------------------------------------------------------------- 1 | import { app } from "electron"; 2 | import path from "path"; 3 | import isDev from "electron-is-dev"; 4 | import { DB } from "./db"; 5 | 6 | export const UserDataPath = function () { 7 | let dataPath: string | null = null; 8 | return { 9 | getPath: function () { 10 | if (dataPath === null) { 11 | const dataPathDB = DB.getInstance() 12 | .prepare("SELECT data_path AS dataPath FROM settings") 13 | .get().dataPath; 14 | if (dataPathDB === null) { 15 | dataPath = path.join(app.getPath("userData"), "data"); 16 | } else { 17 | dataPath = dataPathDB; 18 | } 19 | } 20 | return dataPath; 21 | }, 22 | setPath: function (path: string) { 23 | dataPath = path; 24 | }, 25 | }; 26 | }; 27 | 28 | const joinUserData = (pathToJoin: string) => () => { 29 | const userDataPath = UserDataPath().getPath(); 30 | return path.join(userDataPath, pathToJoin); 31 | }; 32 | 33 | export const BASE_PATH = app.getAppPath(); 34 | export const PORT = 12118; 35 | export const getModelsDir = joinUserData("models"); 36 | export const getTrainingRunsDir = joinUserData("training_runs"); 37 | export const getAudioSynthDir = joinUserData("audio_synth"); 38 | export const getDatasetsDir = joinUserData("datasets"); 39 | export const getCleaningRunsDir = joinUserData("cleaning_runs"); 40 | export const getTextNormalizationRunsDir = joinUserData( 41 | "text_normalization_runs" 42 | ); 43 | export const getSampleSplittingRunsDir = joinUserData("sample_splitting_runs"); 44 | export const getInstalledPath = joinUserData("INSTALLED"); 45 | export const PY_DIST_FOLDER = "backend_dist"; 46 | export const PY_FOLDER = "voice_smith"; 47 | export const RESSOURCES_PATH = isDev ? 
BASE_PATH : process.resourcesPath; 48 | export const CONDA_PATH = path.join(RESSOURCES_PATH, "backend"); 49 | export const ASSETS_PATH = path.join(RESSOURCES_PATH, "assets"); 50 | export const BACKEND_PATH = path.join(CONDA_PATH, "voice_smith"); 51 | export const DB_PATH = path.join( 52 | app.getPath("userData"), 53 | "db", 54 | "voice_smith.db" 55 | ); 56 | -------------------------------------------------------------------------------- /src/electron/utils/processes.ts: -------------------------------------------------------------------------------- 1 | import { IpcMainEvent } from "electron"; 2 | import { exec, ChildProcess } from "child_process"; 3 | import { stopContainer, spawnCondaCmd } from "./docker"; 4 | import { PORT } from "./globals"; 5 | import { DB } from "./db"; 6 | import { CONTINUE_TRAINING_RUN_CHANNEL } from "../../channels"; 7 | import { DOCKER_CONTAINER_NAME } from "../../config"; 8 | 9 | let serverProc: ChildProcess = null; 10 | let pyProc: ChildProcess = null; 11 | 12 | const killLastRun = () => { 13 | const pid = DB.getInstance().prepare("SELECT pid FROM settings").get().pid; 14 | if (pid !== null) { 15 | exec(`docker exec ${DOCKER_CONTAINER_NAME} kill -15 ${pid}`); 16 | } 17 | }; 18 | 19 | export const startRun = ( 20 | event: IpcMainEvent, 21 | scriptName: string, 22 | args: string[], 23 | logErr: boolean 24 | ): void => { 25 | pyProc = spawnCondaCmd( 26 | ["python", scriptName, ...args], 27 | null, 28 | logErr 29 | ? (data: string) => { 30 | event.reply(CONTINUE_TRAINING_RUN_CHANNEL.REPLY, { 31 | type: "error", 32 | errorMessage: data, 33 | }); 34 | } 35 | : null, 36 | (code: number) => { 37 | console.log("FINISHED RUN CALLED"); 38 | event.reply(CONTINUE_TRAINING_RUN_CHANNEL.REPLY, { 39 | type: "finishedRun", 40 | }); 41 | } 42 | ); 43 | 44 | event.reply(CONTINUE_TRAINING_RUN_CHANNEL.REPLY, { 45 | type: "startedRun", 46 | }); 47 | }; 48 | 49 | export const killServerProc = (): void => { 50 | if (serverProc === null) { 51 | return; 52 | } 53 | serverProc.kill("SIGKILL"); 54 | serverProc = null; 55 | }; 56 | 57 | export const killPyProc = (): void => { 58 | if (pyProc === null) { 59 | return; 60 | } 61 | pyProc.kill(); 62 | killLastRun(); 63 | pyProc = null; 64 | }; 65 | 66 | export const createServerProc = (): void => { 67 | // Make sure database object is created 68 | DB.getInstance(); 69 | serverProc = spawnCondaCmd( 70 | ["python", "./backend/voice_smith/server.py", "--port", "80"], 71 | null, 72 | (data: string) => { 73 | console.log(`stderr in server process: ${data}`); 74 | }, 75 | (code: number) => { 76 | if (code !== 0) { 77 | console.log( 78 | `Non zero exit code in server process, status code ${code}` 79 | ); 80 | } 81 | } 82 | ); 83 | 84 | if (serverProc != null) { 85 | console.log(`child process success on port ${PORT}`); 86 | } 87 | }; 88 | 89 | export const exit = (): void => { 90 | stopContainer(); 91 | }; 92 | -------------------------------------------------------------------------------- /src/features/appInfoSlice.ts: -------------------------------------------------------------------------------- 1 | import { createSlice, PayloadAction } from "@reduxjs/toolkit"; 2 | import { stat } from "original-fs"; 3 | import { AppInfoInterface } from "../interfaces"; 4 | 5 | const initialState: AppInfoInterface = { 6 | version: null, 7 | platform: null, 8 | }; 9 | 10 | export const appInfoSlice = createSlice({ 11 | name: "appInfo", 12 | initialState, 13 | reducers: { 14 | editAppInfo: (state, action: PayloadAction) => { 15 | state.platform = 
action.payload.platform; 16 | state.version = action.payload.version; 17 | }, 18 | }, 19 | }); 20 | 21 | export const { editAppInfo } = appInfoSlice.actions; 22 | export default appInfoSlice.reducer; 23 | -------------------------------------------------------------------------------- /src/features/importSettings.ts: -------------------------------------------------------------------------------- 1 | import { createSlice, PayloadAction } from "@reduxjs/toolkit"; 2 | import { ImportSettingsInterface } from "../interfaces"; 3 | 4 | const initialState: ImportSettingsInterface = { 5 | language: "en", 6 | }; 7 | 8 | export const importSettingsSlice = createSlice({ 9 | name: "importSettings", 10 | initialState, 11 | reducers: { 12 | editImportSettings: ( 13 | state, 14 | action: PayloadAction 15 | ) => { 16 | state.language = action.payload.language; 17 | }, 18 | }, 19 | }); 20 | 21 | export const { editImportSettings } = importSettingsSlice.actions; 22 | export default importSettingsSlice.reducer; 23 | -------------------------------------------------------------------------------- /src/features/navigationSettingsSlice.ts: -------------------------------------------------------------------------------- 1 | import { createSlice, PayloadAction } from "@reduxjs/toolkit"; 2 | import { NavigationSettingsInterface } from "../interfaces"; 3 | 4 | const initialState: NavigationSettingsInterface = { 5 | isDisabled: false, 6 | }; 7 | 8 | export const navigationSettingsSlice = createSlice({ 9 | name: "navigationSettings", 10 | initialState, 11 | reducers: { 12 | setNavIsDisabled: (state, action: PayloadAction) => { 13 | state.isDisabled = action.payload; 14 | }, 15 | }, 16 | }); 17 | 18 | export const { setNavIsDisabled } = navigationSettingsSlice.actions; 19 | export default navigationSettingsSlice.reducer; 20 | -------------------------------------------------------------------------------- /src/features/runManagerSlice.ts: -------------------------------------------------------------------------------- 1 | import { createSlice, PayloadAction } from "@reduxjs/toolkit"; 2 | import { notification } from "antd"; 3 | import { RunManagerInterface, RunInterface } from "../interfaces"; 4 | 5 | const initialState: RunManagerInterface = { 6 | isRunning: false, 7 | queue: [], 8 | }; 9 | 10 | export const runManagerSlice = createSlice({ 11 | name: "runManager", 12 | initialState, 13 | reducers: { 14 | setIsRunning: (state, action: PayloadAction) => { 15 | state.isRunning = action.payload; 16 | }, 17 | editQueue: (state, action: PayloadAction) => { 18 | console.log("EDIT QUEUE CALLED"); 19 | state.queue = action.payload; 20 | }, 21 | addToQueue: (state, action: PayloadAction) => { 22 | state.queue = [...state.queue, action.payload]; 23 | if (state.queue.length === 1) { 24 | state.isRunning = true; 25 | } else { 26 | notification["success"]({ 27 | message: "Your run has been added to the queue", 28 | placement: "top", 29 | }); 30 | } 31 | }, 32 | popFromQueue: (state) => { 33 | const newQueue = [...state.queue]; 34 | newQueue.shift(); 35 | state.queue = newQueue; 36 | }, 37 | }, 38 | }); 39 | 40 | export const { setIsRunning, addToQueue, popFromQueue, editQueue } = 41 | runManagerSlice.actions; 42 | export default runManagerSlice.reducer; 43 | -------------------------------------------------------------------------------- /src/features/usageStatsSlice.ts: -------------------------------------------------------------------------------- 1 | import { createSlice, PayloadAction } from "@reduxjs/toolkit"; 2 | import { 
UsageStatsInterface } from "../interfaces"; 3 | 4 | const initialState: UsageStatsInterface[] = []; 5 | const USAGE_STATS_MAX_LENGTH = 100; 6 | 7 | export const usageStatsSlice = createSlice({ 8 | name: "appInfo", 9 | initialState, 10 | reducers: { 11 | addStats: (state, action: PayloadAction) => { 12 | if (state.length >= USAGE_STATS_MAX_LENGTH) { 13 | state.shift(); 14 | } 15 | state.push(action.payload); 16 | }, 17 | }, 18 | }); 19 | 20 | export const { addStats } = usageStatsSlice.actions; 21 | export default usageStatsSlice.reducer; 22 | -------------------------------------------------------------------------------- /src/fonts/atmospheric.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dunky11/voicesmith/35b00c8619f5272b4ec2e087a7cc871c4f71aadf/src/fonts/atmospheric.ttf -------------------------------------------------------------------------------- /src/global.css: -------------------------------------------------------------------------------- 1 | @import "~react-vis/dist/style"; 2 | 3 | body { 4 | min-height: 100vh; 5 | min-width: 1600px; 6 | overflow-y: scroll; 7 | } 8 | 9 | /** Remove Chromes default outline on focus */ 10 | :focus { 11 | outline:none; 12 | } 13 | 14 | @font-face { 15 | font-family: 'atmospheric'; 16 | src: url('./fonts/atmospheric.ttf'); 17 | } 18 | -------------------------------------------------------------------------------- /src/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | VoiceSmith 6 | 7 | 8 | You need to enable JavaScript to run this app. 9 | 10 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/pages/datasets/Datasets.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { Switch, Route, useHistory } from "react-router-dom"; 3 | import Dataset from "./Dataset"; 4 | import DatasetSelection from "./DatasetSelection"; 5 | import { DATASETS_ROUTE } from "../../routes"; 6 | 7 | export default function Datasets(): ReactElement { 8 | const isMounted = useRef(false); 9 | const history = useHistory(); 10 | const [selectedDatasetID, setSelectedDatasetID] = useState( 11 | null 12 | ); 13 | 14 | const passSelectedSpeakerID = (ID: number | null) => { 15 | if (ID === selectedDatasetID && ID !== null) { 16 | history.push(DATASETS_ROUTE.EDIT.ROUTE); 17 | } else { 18 | setSelectedDatasetID(ID); 19 | } 20 | }; 21 | 22 | useEffect(() => { 23 | if (selectedDatasetID === null) { 24 | return; 25 | } 26 | history.push(DATASETS_ROUTE.EDIT.ROUTE); 27 | }, [selectedDatasetID]); 28 | 29 | useEffect(() => { 30 | isMounted.current = true; 31 | return () => { 32 | isMounted.current = false; 33 | }; 34 | }, []); 35 | 36 | return ( 37 | 38 | ( 40 | 43 | )} 44 | path={DATASETS_ROUTE.SELECTION.ROUTE} 45 | > 46 | } 48 | path={DATASETS_ROUTE.EDIT.ROUTE} 49 | > 50 | 51 | ); 52 | } 53 | -------------------------------------------------------------------------------- /src/pages/datasets/ImportSettingsDialog.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Modal, Form } from "antd"; 3 | import LanguageSelect from "../../components/inputs/LanguageSelect"; 4 | import { useDispatch, useSelector } from "react-redux"; 5 | import { editImportSettings } from "../../features/importSettings"; 6 | 
import { RootState } from "../../app/store"; 7 | 8 | export default function ImportSettingsDialog({ 9 | open, 10 | onClose, 11 | onOk, 12 | }: { 13 | open: boolean; 14 | onClose: () => void; 15 | onOk: () => void; 16 | }): ReactElement { 17 | const dispatch = useDispatch(); 18 | const importSettings = useSelector( 19 | (state: RootState) => state.importSettings 20 | ); 21 | 22 | return ( 23 | 30 | 31 | 32 | { 35 | dispatch( 36 | editImportSettings({ ...importSettings, language: lang }) 37 | ); 38 | }} 39 | /> 40 | 41 | 42 | 43 | ); 44 | } 45 | -------------------------------------------------------------------------------- /src/pages/documentation/Introduction.tsx: -------------------------------------------------------------------------------- 1 | import { Typography } from "antd"; 2 | import React, { ReactElement } from "react"; 3 | import DocumentationCard from "../../components/cards/DocumentationCard"; 4 | 5 | export default function Introduction(): ReactElement { 6 | return ( 7 | 8 | Introduction 9 | 10 | ); 11 | } 12 | -------------------------------------------------------------------------------- /src/pages/preprocessing_runs/dataset_cleaning/ApplyChanges.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Tabs, Steps, Button, Card } from "antd"; 3 | import { LoadingOutlined } from "@ant-design/icons"; 4 | import { useDispatch, useSelector } from "react-redux"; 5 | import { RootState } from "../../../app/store"; 6 | import UsageStatsRow from "../../../components/usage_stats/UsageStatsRow"; 7 | import LogPrinter from "../../../components/log_printer/LogPrinter"; 8 | import { CleaningRunInterface, RunInterface } from "../../../interfaces"; 9 | import { 10 | getProgressTitle, 11 | getStageIsRunning, 12 | getWouldContinueRun, 13 | } from "../../../utils"; 14 | import RunCard from "../../../components/cards/RunCard"; 15 | import { setIsRunning, addToQueue } from "../../../features/runManagerSlice"; 16 | 17 | export default function ApplyChanges({ 18 | onStepChange, 19 | run, 20 | }: { 21 | onStepChange: (current: number) => void; 22 | run: CleaningRunInterface | null; 23 | }): ReactElement { 24 | const dispatch = useDispatch(); 25 | const running: RunInterface = useSelector((state: RootState) => { 26 | if (!state.runManager.isRunning || state.runManager.queue.length === 0) { 27 | return null; 28 | } 29 | return state.runManager.queue[0]; 30 | }); 31 | const stageIsRunning = getStageIsRunning( 32 | ["apply_changes"], 33 | run.stage, 34 | running, 35 | "cleaningRun", 36 | run.ID 37 | ); 38 | 39 | const wouldContinueRun = getWouldContinueRun( 40 | ["apply_changes"], 41 | run.stage, 42 | running, 43 | "cleaningRun", 44 | run.ID 45 | ); 46 | 47 | const onNextClick = () => { 48 | if (stageIsRunning) { 49 | dispatch(setIsRunning(false)); 50 | } else if (run.stage !== "finished") { 51 | dispatch(addToQueue({ ID: run.ID, type: "cleaningRun", name: run.name })); 52 | } 53 | }; 54 | 55 | const onBackClick = () => { 56 | onStepChange(2); 57 | }; 58 | 59 | const getNextButtonText = () => { 60 | if (stageIsRunning) { 61 | return "Pause Run"; 62 | } 63 | if (wouldContinueRun) { 64 | return "Continue Run"; 65 | } 66 | return ""; 67 | }; 68 | 69 | const current = 0; 70 | 71 | // TODO progress for apply changes 72 | return ( 73 | Back, 77 | run.stage !== "finished" && ( 78 | 79 | {getNextButtonText()} 80 | 81 | ), 82 | ]} 83 | > 84 | 85 | 86 | 87 | 88 | 89 | 98 | ) : undefined 99 | } 100 | /> 101 | 102 | 103 
| 104 | 105 | 110 | 111 | 112 | 113 | ); 114 | } 115 | -------------------------------------------------------------------------------- /src/pages/preprocessing_runs/dataset_cleaning/Preprocessing.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import PreprocessingSteps from "../../../components/runs/ProcessingSteps"; 3 | import { CleaningRunInterface } from "../../../interfaces"; 4 | 5 | export default function Configuration({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (current: number) => void; 10 | run: CleaningRunInterface; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(0); 14 | }; 15 | 16 | const onNext = () => { 17 | onStepChange(2); 18 | }; 19 | 20 | return ( 21 | 45 | ); 46 | } 47 | -------------------------------------------------------------------------------- /src/pages/preprocessing_runs/sample_splitting/ApplyChanges.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import ProcessingSteps from "../../../components/runs/ProcessingSteps"; 3 | import { SampleSplittingRunInterface } from "../../../interfaces"; 4 | 5 | export default function ApplyChanges({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (current: number) => void; 10 | run: SampleSplittingRunInterface | null; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(2); 14 | }; 15 | 16 | return ( 17 | 43 | ); 44 | } 45 | -------------------------------------------------------------------------------- /src/pages/preprocessing_runs/sample_splitting/Preprocessing.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import PreprocessingSteps from "../../../components/runs/ProcessingSteps"; 3 | import { SampleSplittingRunInterface } from "../../../interfaces"; 4 | 5 | export default function Preprocessing({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (current: number) => void; 10 | run: SampleSplittingRunInterface; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(0); 14 | }; 15 | 16 | const onNext = () => { 17 | onStepChange(2); 18 | }; 19 | 20 | return ( 21 | 58 | ); 59 | } 60 | -------------------------------------------------------------------------------- /src/pages/preprocessing_runs/text_normalization/Preprocessing.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { TextNormalizationRunInterface } from "../../../interfaces"; 3 | import PreprocessingSteps from "../../../components/runs/ProcessingSteps"; 4 | 5 | export default function Preprocessing({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (current: number) => void; 10 | run: TextNormalizationRunInterface; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(0); 14 | }; 15 | 16 | const onNext = () => { 17 | onStepChange(2); 18 | }; 19 | 20 | return ( 21 | 40 | ); 41 | } 42 | -------------------------------------------------------------------------------- /src/pages/training_runs/AcousticModelFinetuning.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef, ReactElement } from "react"; 2 | import { Tabs, Card, Button } from "antd"; 3 | import { useDispatch, useSelector } from "react-redux"; 4 | import { RootState } from 
"../../app/store"; 5 | import AcousticStatistics from "./AcousticStatistics"; 6 | import UsageStatsRow from "../../components/usage_stats/UsageStatsRow"; 7 | import LogPrinter from "../../components/log_printer/LogPrinter"; 8 | import { RunInterface, TrainingRunInterface } from "../../interfaces"; 9 | import { getStageIsRunning, getWouldContinueRun } from "../../utils"; 10 | import RunCard from "../../components/cards/RunCard"; 11 | import { setIsRunning, addToQueue } from "../../features/runManagerSlice"; 12 | 13 | export default function AcousticModelFinetuning({ 14 | onStepChange, 15 | run, 16 | }: { 17 | onStepChange: (step: number) => void; 18 | run: TrainingRunInterface; 19 | }): ReactElement { 20 | const dispatch = useDispatch(); 21 | const running: RunInterface = useSelector((state: RootState) => { 22 | if (!state.runManager.isRunning || state.runManager.queue.length === 0) { 23 | return null; 24 | } 25 | return state.runManager.queue[0]; 26 | }); 27 | const isMounted = useRef(false); 28 | 29 | const stageIsRunning = getStageIsRunning( 30 | ["acoustic_fine_tuning"], 31 | run.stage, 32 | running, 33 | "trainingRun", 34 | run.ID 35 | ); 36 | const wouldContinueRun = getWouldContinueRun( 37 | ["acoustic_fine_tuning"], 38 | run.stage, 39 | running, 40 | "trainingRun", 41 | run.ID 42 | ); 43 | 44 | const onBackClick = () => { 45 | onStepChange(2); 46 | }; 47 | 48 | const onNextClick = () => { 49 | if (stageIsRunning) { 50 | dispatch(setIsRunning(false)); 51 | } else if (wouldContinueRun) { 52 | dispatch( 53 | addToQueue({ 54 | ID: run.ID, 55 | type: "trainingRun", 56 | name: run.name, 57 | }) 58 | ); 59 | } else { 60 | onStepChange(4); 61 | } 62 | }; 63 | 64 | const getNextButtonText = () => { 65 | if (stageIsRunning) { 66 | return "Pause Training"; 67 | } 68 | if (wouldContinueRun) { 69 | return "Continue Training"; 70 | } 71 | return "Next"; 72 | }; 73 | 74 | useEffect(() => { 75 | isMounted.current = true; 76 | return () => { 77 | isMounted.current = false; 78 | }; 79 | }, []); 80 | 81 | return ( 82 | Back, 87 | 94 | {getNextButtonText()} 95 | , 96 | ]} 97 | > 98 | 99 | 100 | 101 | 106 | 107 | 108 | 113 | 114 | 115 | 116 | ); 117 | } 118 | -------------------------------------------------------------------------------- /src/pages/training_runs/AudioStatistic.css: -------------------------------------------------------------------------------- 1 | .audio-slider { 2 | margin-top: 16px; 3 | } 4 | 5 | .audio-slider .ant-slider-mark { 6 | display: none; 7 | } -------------------------------------------------------------------------------- /src/pages/training_runs/GroundTruthAlignment.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import { Tabs, Card, Button, Steps } from "antd"; 3 | import { LoadingOutlined } from "@ant-design/icons"; 4 | import { useDispatch, useSelector } from "react-redux"; 5 | import { RootState } from "../../app/store"; 6 | import RunCard from "../../components/cards/RunCard"; 7 | import LogPrinter from "../../components/log_printer/LogPrinter"; 8 | import UsageStatsRow from "../../components/usage_stats/UsageStatsRow"; 9 | import { getStageIsRunning, getWouldContinueRun } from "../../utils"; 10 | import { RunInterface, TrainingRunInterface } from "../../interfaces"; 11 | import { setIsRunning, addToQueue } from "../../features/runManagerSlice"; 12 | 13 | export default function GroundTruthAlignment({ 14 | onStepChange, 15 | run, 16 | }: { 17 | onStepChange: (step: number) => 
void; 18 | run: TrainingRunInterface; 19 | }): ReactElement { 20 | const dispatch = useDispatch(); 21 | const running: RunInterface = useSelector((state: RootState) => { 22 | if (!state.runManager.isRunning || state.runManager.queue.length === 0) { 23 | return null; 24 | } 25 | return state.runManager.queue[0]; 26 | }); 27 | const stageIsRunning = getStageIsRunning( 28 | ["ground_truth_alignment"], 29 | run.stage, 30 | running, 31 | "trainingRun", 32 | run.ID 33 | ); 34 | const wouldContinueRun = getWouldContinueRun( 35 | ["ground_truth_alignment"], 36 | run.stage, 37 | running, 38 | "trainingRun", 39 | run.ID 40 | ); 41 | 42 | const onBackClick = () => { 43 | onStepChange(3); 44 | }; 45 | 46 | const onNextClick = () => { 47 | if (stageIsRunning) { 48 | dispatch(setIsRunning(false)); 49 | } else if (wouldContinueRun) { 50 | dispatch( 51 | addToQueue({ 52 | ID: run.ID, 53 | type: "trainingRun", 54 | name: run.name, 55 | }) 56 | ); 57 | } else { 58 | onStepChange(5); 59 | } 60 | }; 61 | 62 | const getNextButtonText = () => { 63 | if (stageIsRunning) { 64 | return "Pause Training"; 65 | } 66 | if (wouldContinueRun) { 67 | return "Continue Training"; 68 | } 69 | return "Next"; 70 | }; 71 | 72 | return ( 73 | 76 | Back 77 | , 78 | 79 | {getNextButtonText()} 80 | , 81 | ]} 82 | title="Generate Ground Truth Alignments" 83 | docsUrl="/usage/training#generate-ground-truth-alignments" 84 | > 85 | 86 | 87 | 88 | 89 | 90 | : undefined} 94 | /> 95 | 96 | 97 | 98 | 99 | 104 | 105 | 106 | 107 | ); 108 | } 109 | -------------------------------------------------------------------------------- /src/pages/training_runs/ImageStatistic.css: -------------------------------------------------------------------------------- 1 | .image-slider { 2 | margin-top: 16px; 3 | } 4 | 5 | .image-slider .ant-slider-mark { 6 | display: none; 7 | } -------------------------------------------------------------------------------- /src/pages/training_runs/ImageStatistic.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect, useRef, ReactElement } from "react"; 2 | import { Slider, Card, Empty, Form } from "antd"; 3 | import { STATISTIC_HEIGHT } from "../../config"; 4 | import "./ImageStatistic.css"; 5 | import Image from "../../components/image/Image"; 6 | import { FormInstance } from "rc-field-form"; 7 | 8 | export default function ImageStatistic({ 9 | name, 10 | steps, 11 | paths, 12 | }: { 13 | name: string; 14 | steps: number[]; 15 | paths: string[]; 16 | }): ReactElement { 17 | const [selectedPath, setSelectedPath] = useState(""); 18 | const formRef = useRef(); 19 | 20 | const onStepChange = (selectedStep: number) => { 21 | let i = 0; 22 | for (const step of steps) { 23 | if (step === selectedStep) { 24 | setSelectedPath(paths[i]); 25 | return; 26 | } 27 | i += 1; 28 | } 29 | }; 30 | 31 | const statisticToMarks = () => { 32 | const obj: { [step: number]: string } = {}; 33 | for (const step of steps) { 34 | obj[step] = String(step); 35 | } 36 | return obj; 37 | }; 38 | 39 | useEffect(() => { 40 | if (paths.length > 0 && selectedPath === "") { 41 | setSelectedPath(paths[paths.length - 1]); 42 | formRef.current?.setFieldsValue({ step: steps[steps.length - 1] }); 43 | } 44 | }, [paths]); 45 | 46 | return ( 47 | 48 | {selectedPath === "" ? 
( 49 | 53 | ) : ( 54 | 55 | 63 | 64 | )} 65 | { 67 | formRef.current = node; 68 | }} 69 | onValuesChange={(_, values) => { 70 | onStepChange(values.step); 71 | }} 72 | > 73 | 74 | 82 | 83 | 84 | 85 | ); 86 | } 87 | -------------------------------------------------------------------------------- /src/pages/training_runs/Preprocessing.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import PreprocessingSteps from "../../components/runs/ProcessingSteps"; 3 | import { TrainingRunInterface } from "../../interfaces"; 4 | 5 | export default function Preprocessing({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (step: number) => void; 10 | run: TrainingRunInterface; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(1); 14 | }; 15 | 16 | const onNext = () => { 17 | onStepChange(3); 18 | }; 19 | 20 | return ( 21 | 59 | ); 60 | } 61 | -------------------------------------------------------------------------------- /src/pages/training_runs/SaveModel.tsx: -------------------------------------------------------------------------------- 1 | import React, { ReactElement } from "react"; 2 | import PreprocessingSteps from "../../components/runs/ProcessingSteps"; 3 | import { TrainingRunInterface } from "../../interfaces"; 4 | 5 | export default function SaveModel({ 6 | onStepChange, 7 | run, 8 | }: { 9 | onStepChange: (step: number) => void; 10 | run: TrainingRunInterface; 11 | }): ReactElement { 12 | const onBack = () => { 13 | onStepChange(5); 14 | }; 15 | 16 | return ( 17 | 42 | ); 43 | } 44 | -------------------------------------------------------------------------------- /src/pages/training_runs/TrainingRuns.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect, useRef, ReactElement } from "react"; 2 | import { Route, Switch, useHistory } from "react-router-dom"; 3 | import { TRAINING_RUNS_ROUTE } from "../../routes"; 4 | import { REMOVE_TRAINING_RUN_CHANNEL } from "../../channels"; 5 | import { TrainingRunInterface } from "../../interfaces"; 6 | import CreateModel from "./CreateModel"; 7 | import RunSelection from "./RunSelection"; 8 | const { ipcRenderer } = window.require("electron"); 9 | 10 | export default function TrainingRuns(): ReactElement { 11 | const history = useHistory(); 12 | const trainingRunToRm = useRef(null); 13 | const [selectedTrainingRun, setSelectedTrainingRun] = 14 | useState(null); 15 | 16 | const selectTrainingRun = (run: TrainingRunInterface) => { 17 | trainingRunToRm.current = null; 18 | setSelectedTrainingRun(run); 19 | history.push(TRAINING_RUNS_ROUTE.CREATE_MODEL.ROUTE); 20 | }; 21 | 22 | const removeTrainingRun = (run: TrainingRunInterface) => { 23 | if (selectedTrainingRun !== null && run.ID === selectedTrainingRun.ID) { 24 | trainingRunToRm.current = run; 25 | setSelectedTrainingRun(null); 26 | } else { 27 | ipcRenderer.invoke(REMOVE_TRAINING_RUN_CHANNEL.IN, run.ID); 28 | } 29 | }; 30 | 31 | useEffect(() => { 32 | if (trainingRunToRm.current !== null) { 33 | ipcRenderer.invoke( 34 | REMOVE_TRAINING_RUN_CHANNEL.IN, 35 | trainingRunToRm.current.ID 36 | ); 37 | trainingRunToRm.current = null; 38 | } 39 | }, [selectedTrainingRun]); 40 | 41 | return ( 42 | 43 | 45 | selectedTrainingRun && ( 46 | 47 | ) 48 | } 49 | path={TRAINING_RUNS_ROUTE.CREATE_MODEL.ROUTE} 50 | > 51 | ( 53 | 57 | )} 58 | path={TRAINING_RUNS_ROUTE.RUN_SELECTION.ROUTE} 59 | > 60 | 61 | ); 62 | } 63 | 
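The run pages above (dataset cleaning, sample splitting, acoustic fine-tuning, ground-truth alignment, and so on) all gate their Pause/Continue buttons on two helpers imported from src/utils.tsx: getStageIsRunning and getWouldContinueRun. A minimal sketch of what they plausibly do, inferred only from how they are called (a list of stage names, the run's current stage, the head of the run-manager queue, the run type, and the run ID); the actual implementation may differ:

import { RunInterface } from "./interfaces";

// Sketch, not the repository's code: a stage counts as "running" when the run's
// current stage is one of the listed stages and this exact run (same type and ID)
// sits at the head of the run-manager queue.
export const getStageIsRunning = (
  stages: string[],
  stage: string,
  running: RunInterface | null,
  type: string,
  ID: number
): boolean =>
  stages.includes(stage) &&
  running !== null &&
  running.type === type &&
  running.ID === ID;

// Sketch: the run is parked at one of the listed stages but is not the run
// currently executing, so queueing it again would continue it rather than
// start something new.
export const getWouldContinueRun = (
  stages: string[],
  stage: string,
  running: RunInterface | null,
  type: string,
  ID: number
): boolean =>
  stages.includes(stage) &&
  !(running !== null && running.type === type && running.ID === ID);

Read this way, the onNextClick handlers in pages such as AcousticModelFinetuning pause the queue while the stage is active, re-queue the run when it is parked mid-stage, and otherwise advance to the next step.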
-------------------------------------------------------------------------------- /src/pages/training_runs/VocoderFineTuning.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useRef, ReactElement } from "react"; 2 | import { Tabs, Button } from "antd"; 3 | import { useDispatch, useSelector } from "react-redux"; 4 | import { RootState } from "../../app/store"; 5 | import VocoderStatistics from "./VocoderStatistics"; 6 | import LogPrinter from "../../components/log_printer/LogPrinter"; 7 | import { RunInterface, TrainingRunInterface } from "../../interfaces"; 8 | import { getStageIsRunning, getWouldContinueRun } from "../../utils"; 9 | import RunCard from "../../components/cards/RunCard"; 10 | import UsageStatsRow from "../../components/usage_stats/UsageStatsRow"; 11 | import { setIsRunning, addToQueue } from "../../features/runManagerSlice"; 12 | 13 | export default function VocoderFineTuning({ 14 | onStepChange, 15 | run, 16 | }: { 17 | onStepChange: (step: number) => void; 18 | run: TrainingRunInterface; 19 | }): ReactElement { 20 | const isMounted = useRef(false); 21 | const dispatch = useDispatch(); 22 | const running: RunInterface = useSelector((state: RootState) => { 23 | if (!state.runManager.isRunning || state.runManager.queue.length === 0) { 24 | return null; 25 | } 26 | return state.runManager.queue[0]; 27 | }); 28 | const stageIsRunning = getStageIsRunning( 29 | ["vocoder_fine_tuning"], 30 | run.stage, 31 | running, 32 | "trainingRun", 33 | run.ID 34 | ); 35 | const wouldContinueRun = getWouldContinueRun( 36 | ["vocoder_fine_tuning"], 37 | run.stage, 38 | running, 39 | "trainingRun", 40 | run.ID 41 | ); 42 | 43 | const onBackClick = () => { 44 | onStepChange(4); 45 | }; 46 | 47 | const onNextClick = () => { 48 | if (stageIsRunning) { 49 | dispatch(setIsRunning(false)); 50 | } else if (wouldContinueRun) { 51 | dispatch( 52 | addToQueue({ 53 | ID: run.ID, 54 | type: "trainingRun", 55 | name: run.name, 56 | }) 57 | ); 58 | } else { 59 | onStepChange(6); 60 | } 61 | }; 62 | 63 | const getNextButtonText = () => { 64 | if (stageIsRunning) { 65 | return "Pause Training"; 66 | } 67 | if (wouldContinueRun) { 68 | return "Continue Training"; 69 | } 70 | return "Next"; 71 | }; 72 | 73 | useEffect(() => { 74 | isMounted.current = true; 75 | return () => { 76 | isMounted.current = false; 77 | }; 78 | }, []); 79 | 80 | return ( 81 | Back, 84 | 85 | {getNextButtonText()} 86 | , 87 | ]} 88 | title="Vocoder Fine-Tuning" 89 | docsUrl="/usage/training#vocoder-fine-tuning" 90 | > 91 | 92 | 93 | 94 | 99 | 100 | 101 | 106 | 107 | 108 | 109 | ); 110 | } 111 | -------------------------------------------------------------------------------- /src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /src/renderer.tsx: -------------------------------------------------------------------------------- 1 | import "antd/dist/antd.variable.min.css"; 2 | import React from "react"; 3 | import ReactDOM from "react-dom"; 4 | import { HashRouter } from "react-router-dom"; 5 | import { Provider } from "react-redux"; 6 | import { store } from "./app/store"; 7 | import App from "./App"; 8 | import reportWebVitals from "./reportWebVitals"; 9 | import "./global.css"; 10 | 11 | import { ConfigProvider } from "antd"; 12 | 13 | ConfigProvider.config({ 14 | theme: { 15 | primaryColor: "#2f54eb", 16 | }, 17 | }); 18 | 19 | 
ReactDOM.render( 20 | 21 | 22 | 23 | 24 | 25 | 26 | , 27 | document.getElementById("root") 28 | ); 29 | 30 | // If you want to start measuring performance in your app, pass a function 31 | // to log results (for example: reportWebVitals(console.log)) 32 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals 33 | reportWebVitals(); 34 | -------------------------------------------------------------------------------- /src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = (onPerfEntry) => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import("web-vitals").then(({getCLS, getFID, getFCP, getLCP, getTTFB}) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /src/routes.ts: -------------------------------------------------------------------------------- 1 | export const DATASETS_ROUTE = { 2 | ROUTE: "/datasets", 3 | EDIT: { 4 | ROUTE: "/datasets/edit", 5 | }, 6 | SELECTION: { 7 | ROUTE: "/datasets/dataset-selection", 8 | }, 9 | }; 10 | 11 | export const MODELS_ROUTE = { 12 | ROUTE: "/models", 13 | SELECTION: { ROUTE: "/models/selection" }, 14 | SYNTHESIZE: { ROUTE: "/models/synthesize" }, 15 | }; 16 | 17 | export const TRAINING_RUNS_ROUTE = { 18 | ROUTE: "/training-runs", 19 | RUN_SELECTION: { 20 | ROUTE: "/training-runs/run-selection", 21 | }, 22 | CREATE_MODEL: { 23 | ROUTE: "/training-runs/create-model", 24 | CONFIGURATION: { 25 | ROUTE: "/training-runs/create-model/configuration", 26 | }, 27 | DATA_PREPROCESSING: { 28 | ROUTE: "/training-runs/create-model/data-preprocessing", 29 | }, 30 | ACOUSTIC_TRAINING: { 31 | ROUTE: "/training-runs/create-model/acoustic-training", 32 | }, 33 | GENERATE_GTA: { 34 | ROUTE: "/training-runs/create-model/generate-gta", 35 | }, 36 | VOCODER_TRAINING: { 37 | ROUTE: "/training-runs/create-model/vocoder-training", 38 | }, 39 | SAVE_MODEL: { 40 | ROUTE: "/training-runs/create-model/save-gta", 41 | }, 42 | }, 43 | }; 44 | 45 | export const PREPROCESSING_RUNS_ROUTE = { 46 | ROUTE: "/preprocessing-runs", 47 | TEXT_NORMALIZATION: { 48 | ROUTE: "/preprocessing-runs/text-normalization", 49 | CONFIGURATION: { 50 | ROUTE: "/preprocessing-runs/text-normalization/configuration", 51 | }, 52 | RUNNING: { 53 | ROUTE: "/preprocessing-runs/text-normalization/running", 54 | }, 55 | CHOOSE_SAMPLES: { 56 | ROUTE: "/preprocessing-runs/text-normalization/choose-samples", 57 | }, 58 | }, 59 | DATASET_CLEANING: { 60 | ROUTE: "/preprocessing-runs/dataset-cleaning", 61 | CONFIGURATION: { 62 | ROUTE: "/preprocessing-runs/dataset-cleaning/configuration", 63 | }, 64 | RUNNING: { 65 | ROUTE: "/preprocessing-runs/dataset-cleaning/running", 66 | }, 67 | CHOOSE_SAMPLES: { 68 | ROUTE: "/preprocessing-runs/dataset-cleaning/choose-samples", 69 | }, 70 | APPLY_CHANGES: { 71 | ROUTE: "/preprocessing-runs/dataset-cleaning/apply-changes", 72 | }, 73 | }, 74 | RUN_SELECTION: { ROUTE: "/preprocessing-runs/run-selection" }, 75 | SAMPLE_SPLITTING: { 76 | ROUTE: "/preprocessing-runs/sample-splitting", 77 | CONFIGURATION: { 78 | ROUTE: "/preprocessing-runs/sample-splitting/configuration", 79 | }, 80 | RUNNING: { 81 | ROUTE: "/preprocessing-runs/sample-splitting/running", 82 | }, 83 | CHOOSE_SAMPLES: { 84 | ROUTE: 
"/preprocessing-runs/sample-splitting/choose-samples", 85 | }, 86 | APPLY_CHANGES: { 87 | ROUTE: "/preprocessing-runs/sample-splitting/apply-changes", 88 | }, 89 | }, 90 | }; 91 | 92 | export const SETTINGS_ROUTE = { 93 | ROUTE: "/settings", 94 | }; 95 | 96 | export const RUN_QUEUE_ROUTE = { 97 | ROUTE: "/run-queue", 98 | }; 99 | 100 | export const DOCUMENTATION_ROUTE = { 101 | INTODUCTION: { ROUTE: "/introduction" }, 102 | DATASETS: { ROUTE: "/" }, 103 | }; 104 | -------------------------------------------------------------------------------- /src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import "@testing-library/jest-dom"; 6 | 7 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "jsx": "react", 4 | "allowJs": true, 5 | "module": "commonjs", 6 | "skipLibCheck": true, 7 | "esModuleInterop": true, 8 | "noImplicitAny": true, 9 | "sourceMap": true, 10 | "baseUrl": ".", 11 | "outDir": "dist", 12 | "moduleResolution": "node", 13 | "resolveJsonModule": true, 14 | "paths": { 15 | "*": ["node_modules/*"] 16 | } 17 | }, 18 | "include": ["src/**/*"] 19 | } 20 | -------------------------------------------------------------------------------- /webpack.main.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | /** 3 | * This is the main entry point for your application, it's the first file 4 | * that runs in the main process. 5 | */ 6 | entry: './src/electron/electron.ts', 7 | // Put your normal webpack config below here 8 | module: { 9 | rules: require('./webpack.rules'), 10 | }, 11 | resolve: { 12 | extensions: ['.js', '.ts', '.jsx', '.tsx', '.css', '.json'], 13 | fallback: { "assert": false } 14 | }, 15 | }; -------------------------------------------------------------------------------- /webpack.plugins.js: -------------------------------------------------------------------------------- 1 | const ForkTsCheckerWebpackPlugin = require('fork-ts-checker-webpack-plugin'); 2 | 3 | module.exports = [ 4 | new ForkTsCheckerWebpackPlugin() 5 | ]; 6 | -------------------------------------------------------------------------------- /webpack.renderer.config.js: -------------------------------------------------------------------------------- 1 | const rules = require('./webpack.rules'); 2 | const plugins = require('./webpack.plugins'); 3 | 4 | rules.push({ 5 | test: /\.css$/, 6 | use: [{ loader: 'style-loader' }, { loader: 'css-loader' }], 7 | }); 8 | 9 | module.exports = { 10 | module: { 11 | rules, 12 | }, 13 | plugins: plugins, 14 | resolve: { 15 | extensions: ['.js', '.ts', '.jsx', '.tsx', '.css'] 16 | }, 17 | }; 18 | -------------------------------------------------------------------------------- /webpack.rules.js: -------------------------------------------------------------------------------- 1 | module.exports = [ 2 | // Add support for native node modules 3 | { 4 | // We're specifying native_modules in the test because the asset relocator loader generates a 5 | // "fake" .node file which is really a cjs file. 
6 | test: /native_modules\/.+\.node$/, 7 | use: 'node-loader', 8 | }, 9 | { 10 | test: /\.(m?js|node)$/, 11 | parser: { amd: false }, 12 | use: { 13 | loader: '@vercel/webpack-asset-relocator-loader', 14 | options: { 15 | outputAssetBase: 'native_modules', 16 | }, 17 | }, 18 | }, 19 | { 20 | test: /\.tsx?$/, 21 | exclude: /(node_modules|\.webpack)/, 22 | use: { 23 | loader: 'ts-loader', 24 | options: { 25 | transpileOnly: true 26 | } 27 | } 28 | }, 29 | ]; 30 | --------------------------------------------------------------------------------
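The feature slices shown earlier (appInfoSlice, importSettings, navigationSettingsSlice, runManagerSlice, usageStatsSlice) are combined into the Redux store that the pages access through src/app/store.ts via store and RootState. A minimal sketch of that wiring, assuming the standard Redux Toolkit configureStore layout; the runManager and importSettings keys match the useSelector calls above, while the appInfo, navigationSettings, and usageStats key names are assumptions:

import { configureStore } from "@reduxjs/toolkit";
import appInfoReducer from "../features/appInfoSlice";
import importSettingsReducer from "../features/importSettings";
import navigationSettingsReducer from "../features/navigationSettingsSlice";
import runManagerReducer from "../features/runManagerSlice";
import usageStatsReducer from "../features/usageStatsSlice";

// Sketch of src/app/store.ts: each slice's default export (its reducer) is
// mounted under the key the components select against.
export const store = configureStore({
  reducer: {
    appInfo: appInfoReducer,
    importSettings: importSettingsReducer,
    navigationSettings: navigationSettingsReducer,
    runManager: runManagerReducer,
    usageStats: usageStatsReducer,
  },
});

// RootState is the type the useSelector callbacks in the pages above receive.
export type RootState = ReturnType<typeof store.getState>;
export type AppDispatch = typeof store.dispatch;

With this shape, a selector such as (state: RootState) => state.runManager.queue[0] type-checks against the queue defined in runManagerSlice, and the Provider in renderer.tsx only needs the exported store.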