├── .DS_Store ├── .gitignore ├── HebrewToEnglish.py ├── README.md ├── Tacotron_Synthesis_Notebook_contest_notebook.ipynb ├── data_preprocess.py ├── filelists ├── train_list.txt └── validation_list.txt ├── inference.py ├── requirements.txt └── train.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxmelichov/Text-To-speech/c368377446e8e3367362814af006b7ede3d21919/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | *.pt 3 | data/* 4 | outdir 5 | checkpoints/logs 6 | inference_results/* 7 | not_needed_nomore 8 | glow.py 9 | __pycache__ 10 | /*/__pycache__ -------------------------------------------------------------------------------- /HebrewToEnglish.py: -------------------------------------------------------------------------------- 1 | import re 2 | from hebrew import Hebrew 3 | 4 | ##deals with . and , in a normal string 5 | def break_to_letter_and_rebuild(string): 6 | lst=[] 7 | i=0 8 | flag=False 9 | while i2 and is_number_range(parts[1]+'-'+parts[2]): 405 | return parts[0],(str(parts[1]),str(parts[2])) , "is_number_range" 406 | 407 | 408 | 409 | def clean_number(word): 410 | 411 | ## '4,000,' / '10,000' / '1000,' 412 | if is_number_with_comma(word): 413 | return NumberToHebrew(int(clean_number_with_comma(word))) 414 | 415 | ## '2.9' / '3.4' 416 | elif is_number_with_decimal(word): 417 | list_heb=[] 418 | part1,part2=clean_decimal(word) 419 | list_heb+=NumberToHebrew(part1) 420 | list_heb+=NumberToHebrew('.') 421 | list_heb+=NumberToHebrew(part2) 422 | return list_heb 423 | 424 | ## '4.5%' / '9.25%.' 
/ '26.5%,' 425 | elif is_percentage(word): 426 | list_heb = [] 427 | part1, part2 = clean_decimal(clean_percentage(word)) 428 | 429 | if part2!=None: 430 | list_heb += NumberToHebrew(part1) 431 | list_heb += NumberToHebrew('.') 432 | list_heb += NumberToHebrew(part2) 433 | list_heb += NumberToHebrew('%') 434 | return list_heb 435 | else: 436 | list_heb += NumberToHebrew(part1) 437 | list_heb += NumberToHebrew('%') 438 | return list_heb 439 | 440 | ## '5-6' / '1971-1972,' / '2003-2005.' 441 | elif is_number_range(word): 442 | list_heb = [] 443 | part1, part2 = clean_number_range(word) 444 | list_heb += NumberToHebrew(part1) 445 | list_heb.append("עַד") 446 | list_heb += NumberToHebrew(part2) 447 | return list_heb 448 | 449 | ## 'בְּ-100,000' / בְּ-99.99% / הַ-1,100 / מִ-0.7% / לְ-1.9 / כְּ-22,000. 450 | elif is_pattern_number_with_heb(word): 451 | heb_letter,num,func=(clean_pattern_number_with_heb(word)) 452 | #arr_attr= (clean_pattern_number_with_heb(word)) 453 | list_heb = [] 454 | list_heb.append(heb_letter) 455 | 456 | if func=="is_number_with_comma": 457 | list_heb+=NumberToHebrew(int(num)) 458 | return list_heb 459 | 460 | elif func=="is_number_with_decimal": 461 | part1, part2 = num 462 | list_heb += NumberToHebrew(part1) 463 | list_heb += NumberToHebrew('.') 464 | list_heb += NumberToHebrew(part2) 465 | return list_heb 466 | 467 | elif func=="is_percentage": 468 | part1, part2 = clean_decimal(num) 469 | 470 | if part2 != None: 471 | list_heb += NumberToHebrew(part1) 472 | list_heb += NumberToHebrew('.') 473 | list_heb += NumberToHebrew(part2) 474 | list_heb += NumberToHebrew('%') 475 | return list_heb 476 | else: 477 | list_heb += NumberToHebrew(part1) 478 | list_heb += NumberToHebrew('%') 479 | return list_heb 480 | 481 | elif func == "is_number_range": 482 | part1, part2 = num 483 | list_heb += NumberToHebrew(int(part1)) 484 | list_heb.append("עַד") 485 | list_heb += NumberToHebrew(int(part2)) 486 | return list_heb 487 | 488 | 
####################################################### 489 | 490 | ## Map one Hebrew letter plus its nikud/dagesh points to an English phoneme string. tzuptzik=True selects the geresh sounds for ג, צ/ץ and ז (j, ch, zh); last_letter=True enables the word-final patah-ganav forms. 491 | def HebrewLetterToEnglishSound(obj,tzuptzik,last_letter=False): 492 | obj = Hebrew(obj).string 493 | # map each nikud (vowel point) character to its phoneme; sheva and dagesh map to '' (silent) 494 | nikud_map = {"ָ": "a", "ַ": "a", "ֶ": "e", "ֵ": "e", "ִ": "i", "ְ": "", "ֹ": "o", "ֻ": "oo", 'ּ': "", 'ֲ': 'a'} 495 | 496 | 497 | beged_kefet_shin_sin = { # letter+point combinations whose sound depends on the dagesh (e.g. b/v, k/h, p/f) or on the shin/sin dot 498 | ############ B 499 | "בּ": "b", 500 | "בְּ": "b", 501 | "בִּ": "bi", 502 | "בֹּ": "bo", 503 | "בֵּ": "be", 504 | "בֶּ": "be", 505 | "בַּ": "ba", 506 | "בָּ": "ba", 507 | "בֻּ": "boo", 508 | ############ G 509 | "גּ": "g", 510 | "גְּ": "g", 511 | "גִּ": "gi", 512 | "גֹּ": "go", 513 | "גֵּ": "ge", 514 | "גֶּ": "ge", 515 | "גַּ": "ga", 516 | "גָּ": "ga", 517 | "גֻּ": "goo", 518 | ########### D 519 | "דּ": "d", 520 | "דְּ": "d", 521 | "דִּ": "di", 522 | "דֹּ": "do", 523 | "דֵּ": "de", 524 | "דֶּ": "de", 525 | "דַּ": "da", 526 | "דָּ": "da", 527 | "דֻּ": "doo", 528 | ########### K 529 | "כּ": "k", 530 | "כְּ": "k", 531 | "כִּ": "ki", 532 | "כֹּ": "ko", 533 | "כֵּ": "ke", 534 | "כֶּ": "ke", 535 | "כַּ": "ka", 536 | "כָּ": "ka", 537 | "כֻּ": "koo", 538 | ############ P 539 | "פּ": "p", 540 | "פְּ": "p", 541 | "פִּ": "pi", 542 | "פֹּ": "po", 543 | "פֵּ": "pe", 544 | "פֶּ": "pe", 545 | "פַּ": "pa", 546 | "פָּ": "pa", 547 | "פֻּ": "poo", 548 | ############ T 549 | "תּ": "t", 550 | "תְּ": "t", 551 | "תִּ": "ti", 552 | "תֹּ": "to", 553 | "תֵּ": "te", 554 | "תֶּ": "te", 555 | "תַּ": "ta", 556 | "תָּ": "ta", 557 | "תֻּ": "too", 558 | ############ S 559 | "שׂ": "s", 560 | "שְׂ": "s", 561 | "שִׂ": "si", 562 | "שֹׂ": "so", 563 | "שֵׂ": "se", 564 | "שֶׂ": "se", 565 | "שַׂ": "sa", 566 | "שָׂ": "sa", 567 | "שֻׂ": "soo", 568 | ########### SH 569 | "שׁ": "sh", 570 | "שְׁ": "sh", 571 | "שִׁ": "shi", 572 | "שֹׁ": "sho", 573 | "שֵׁ": "she", 574 | "שֶׁ": "she", 575 | "שַׁ": "sha", 576 | "שָׁ": "sha", 577 | "שֻׁ": "shoo", 578 | } 579 | 580 | 
vav = { # vav as a vowel (holam 'o', shuruk/kubutz 'oo') or as the consonant 'v' with its vowel 581 | "וֵּו": "ve", 582 | "וּ": "oo", 583 | "וּו": "oo", 584 | "וֹ": "o", 585 | "וֹו": "oo", 586 | "וְ": "ve", 587 | "וֱו": "ve", 588 | "וִ": "vi", 589 | "וִו": "vi", 590 | "וַ": "va", 591 | "וַו": "va", 592 | "וֶ": "ve", 593 | "וֶו": "ve", 594 | "וָ": "va", 595 | "וָו": "va", 596 | "וֻ": "oo", 597 | "וֻו": "oo" 598 | } 599 | 600 | 601 | letters_map = { # fallback: one sound per bare letter, final forms included; א and ע are silent 602 | "א": "", 603 | "ב": "v", 604 | "ג": "g", 605 | "ד": "d", 606 | "ה": "hh", 607 | "ו": "v", 608 | "ז": "z", 609 | "ח": "h", 610 | "ט": "t", 611 | "י": "y", 612 | "כ": "h", 613 | "ל": "l", 614 | "מ": "m", 615 | "נ": "n", 616 | "ס": "s", 617 | "ע": "", 618 | "פ": "f", 619 | "צ": "ts", 620 | "ק": "k", 621 | "ר": "r", 622 | "ש": "sh", 623 | "ת": "t", 624 | "ן": "n", 625 | "ם": "m", 626 | "ף": "f", 627 | "ץ": "ts", 628 | "ך": "h", 629 | } 630 | 631 | patah_ganav={ # word-final forms, consulted only when last_letter is True (a final ח/ע with patah/kamatz is read vowel-first) 632 | "חַ": "ah", 633 | "חָ": "ah", 634 | "הַ": "hha", 635 | "הָ": "hha", 636 | "עַ": "a", 637 | "עָ": "a", 638 | 639 | } 640 | 641 | tzuptzik_letters={ # used when the caller saw a geresh (') right after this letter 642 | ##G 643 | "ג": "j", 644 | "גְ": "j", 645 | "גִ": "ji", 646 | "גֹ": "jo", 647 | "גֵ": "je", 648 | "גֶ": "je", 649 | "גַ": "ja", 650 | "גָ": "ja", 651 | "גֻ": "joo", 652 | "גּ": "j", 653 | "גְּ": "j", 654 | "גִּ": "ji", 655 | "גֹּ": "jo", 656 | "גֵּ": "je", 657 | "גֶּ": "je", 658 | "גַּ": "ja", 659 | "גָּ": "ja", 660 | "גֻּ": "joo", 661 | 662 | ##ch 663 | "צ": "ch", 664 | "צְ": "ch", 665 | "צִ": "chi", 666 | "צֹ": "cho", 667 | "צֵ": "che", 668 | "צֶ": "che", 669 | "צַ": "cha", 670 | "צָ": "cha", 671 | "צֻ": "choo", 672 | 673 | ##ch 674 | "ץ": "ch", 675 | "ץְ": "ch", 676 | "ץִ": "chi", 677 | "ץֹ": "cho", 678 | "ץֵ": "che", 679 | "ץֶ": "che", 680 | "ץַ": "cha", 681 | "ץָ": "cha", 682 | "ץֻ": "choo", 683 | 684 | ##Z 685 | "ז": "zh", 686 | "זְ": "zh", 687 | "זִ": "zhi", 688 | "זֹ": "zho", 689 | "זֵ": "zhe", 690 | "זֶ": "zhe", 691 | "זַ": "zha", 692 | "זָ": "zha", 693 | "זֻ": "zhoo", 694 | } 695 | 696 | if last_letter: # lookup order: word-final forms, geresh forms, dagesh/shin-sin forms, vav forms, then per-character fallback 697 | if obj in patah_ganav: 698 | return patah_ganav[obj] 699 | 700 | if
tzuptzik==True: 701 | if obj in tzuptzik_letters: 702 | return tzuptzik_letters[obj] 703 | 704 | if obj in beged_kefet_shin_sin: 705 | return beged_kefet_shin_sin[obj] 706 | elif obj in vav: 707 | return vav[obj] 708 | else: # no exact match: transliterate character by character; characters in neither map are dropped 709 | lst = break_to_list(obj) 710 | string = "" 711 | for item in lst: 712 | if item in letters_map: 713 | string += letters_map[item] 714 | if item in nikud_map: 715 | string += nikud_map[item] 716 | 717 | return string 718 | 719 | 720 | ## Transliterate one Hebrew word into English phoneme letters, one grapheme at a time. `index` is not used in this function. 721 | def HebrewWordToEnglishSound(word,index): 722 | new_sentence="" 723 | hs = Hebrew(word) 724 | hs = Hebrew(list(hs.graphemes)).string # NOTE(review): presumably a sequence of grapheme clusters (letter + its points), so each hs[i] is one lookup unit - confirm with the hebrew package 725 | for i, letter in enumerate(hs): 726 | 727 | tzuptzik = False 728 | if i < len(hs) - 1: 729 | if hs[i + 1] == '\'': # a following geresh (') selects the foreign sound for this letter 730 | tzuptzik = True 731 | 732 | tav = HebrewLetterToEnglishSound(letter, tzuptzik, i == len(hs) - 1) # the last grapheme enables the word-final lookups 733 | new_sentence += tav 734 | 735 | ## collapse 'yy' when the result ends in 'yy'. NOTE(review): replace() rewrites every 'yy' in the word, not only the trailing one, and the bare except below also hides the IndexError raised for an empty result or a lone 'y'. 736 | try: 737 | if new_sentence[-1] == 'y' and new_sentence[-2] == 'y': 738 | new_sentence = new_sentence.replace("yy", "y") 739 | except: 740 | pass 741 | return new_sentence 742 | 743 | ## Transliterate a Hebrew sentence into English phoneme words, each followed by a space. Numbers are first spelled out as Hebrew words; pieces equal to '.', ',' or ';' become the token 'q'. `index` is passed down and printed as the line number in error messages. 744 | def HebrewToEnglish(sentence,index=0): 745 | words = sentence.split() 746 | new_sentence = "" 747 | 748 | for word in words: 749 | ## no digits in the word: split off punctuation and transliterate each piece 750 | if not has_number(word): 751 | 752 | ## split the word so that '.' and ',' become separate pieces (see break_to_letter_and_rebuild) 753 | broken_word=break_to_letter_and_rebuild(word) 754 | 755 | for brk_word in broken_word: 756 | 757 | ## punctuation pieces become 'q', presumably a pause/silence token for the TTS model 758 | if brk_word=='.'
or brk_word==',' or brk_word==';': 759 | new_sentence += "q"+" " 760 | 761 | else: 762 | ret_sentence=HebrewWordToEnglishSound(brk_word,index) 763 | new_sentence+=ret_sentence+" " 764 | 765 | ## the word contains digits: spell the number out in Hebrew words, then transliterate them 766 | else: 767 | try: 768 | before_num,num,after_num=split_number_and_string(word) 769 | 770 | if has_number(after_num) or has_number(before_num): # digits outside the first digit run (e.g. '10,000', '2.9', '5-6'): let clean_number parse the whole token 771 | list_of_numbers=clean_number(word) 772 | for number in list_of_numbers: 773 | ret_sentence = HebrewWordToEnglishSound(number, index) 774 | new_sentence += ret_sentence + " " 775 | 776 | else: # exactly one digit run: prefix text, the number as Hebrew words, then suffix text 777 | ret_sentence = HebrewWordToEnglishSound(before_num, index) 778 | new_sentence += ret_sentence+" " 779 | 780 | num = [s for s in word if s.isdigit()] 781 | num="".join(num) 782 | num=int(num) 783 | list_of_numbers=NumberToHebrew(num) 784 | for number in list_of_numbers: 785 | ret_sentence=HebrewWordToEnglishSound(number,index) 786 | new_sentence += ret_sentence + " " 787 | 788 | ret_sentence = HebrewWordToEnglishSound(after_num, index) 789 | new_sentence += ret_sentence + " " 790 | 791 | 792 | 793 | except: # NOTE(review): bare except hides every error (e.g. a None returned by clean_number); the word is then partly or wholly missing from the output 794 | print("error from split_number_and_string in line:", index,"with word: ",word) 795 | 796 | 797 | 798 | return new_sentence 799 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-To-Speech (Robo-Shaul) 2 | 3 | Welcome to the Robo-Shaul repository! This project enables you to train your own Robo-Shaul or use pre-trained models to convert Hebrew text into speech using the Tacotron 2 TTS framework. 4 | 5 | Robo-Shaul was originally developed for a competition, where the winning model was trained for only 5k steps. After the competition, a more advanced model was trained for 90k steps using improved methodologies and a wider range of training data, resulting in significantly better performance.
6 | 7 | --- 8 | 9 | ## 🚀 Quick Start 10 | 11 | ### Prerequisites 12 | 13 | - Python 3.10 14 | 15 | ### Installation 16 | 17 | 1. **Clone the repository:** 18 | ```bash 19 | git clone https://github.com/maxmelichov/Text-To-speech.git 20 | cd Text-To-speech 21 | ``` 22 | 23 | 2. **Set up a virtual environment:** 24 | ```bash 25 | python3.10 -m venv venv 26 | source venv/bin/activate # Linux/Mac 27 | # or 28 | venv\Scripts\activate.bat # Windows 29 | ``` 30 | 31 | 3. **Install dependencies:** 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 4. **Clone required submodules and dependencies:** 37 | ```bash 38 | git clone https://github.com/maxmelichov/tacotron2.git 39 | git submodule init 40 | git submodule update 41 | git clone https://github.com/maxmelichov/waveglow.git 42 | cp waveglow/glow.py ./ 43 | ``` 44 | 45 | --- 46 | 47 | ## 📁 Project Structure 48 | 49 | The main directories used in this project are: 50 | 51 | ``` 52 | Text-To-speech/ 53 | ├── data/ # Place the SASPEECH dataset here 54 | ├── checkpoints/ # Stores Tacotron2 model checkpoints (*.pt files) 55 | ├── waveglow_weights/ # Stores WaveGlow model checkpoint (*.pt file) 56 | ├── tacotron2/ # Tacotron2 source code (cloned as submodule) 57 | ├── waveglow/ # WaveGlow source code (cloned as submodule) 58 | ├── ... 59 | ``` 60 | 61 | - **data/**: Put your downloaded and preprocessed dataset here. 62 | - **checkpoints/**: Save and load Tacotron2 model weights (e.g., `checkpoint_90000.pt`). 63 | - **waveglow_weights/**: Place the WaveGlow model checkpoint file (e.g., `waveglow_256channels.pt`).
64 | 65 | --- 66 | 67 | ## 📦 Download Pre-trained Models 68 | 69 | - **WaveGlow model:** [Download](https://drive.usercontent.google.com/download?id=19CVIL0TL_yyW-qC4jJ2vPht5cxc6VQpO&export=download&authuser=0) 70 | - **Model with 90K steps:** [Download](https://drive.google.com/uc?id=13B_NfAw8y-A9pg-xLcP5kQ_7dbObGc8S&export=download) 71 | - **Model with 5K steps:** [Download](https://drive.google.com/u/0/uc?id=1iE3VgeQsyZcIgAXYmwhk-FzWktwrT2Wo&export=download) 72 | 73 | --- 74 | 75 | ## 📚 Dataset 76 | 77 | - Download the SASPEECH dataset from [OpenSLR](https://openslr.org/134). 78 | 79 | --- 80 | 81 | ## 🛠️ Usage 82 | 83 | 1. **Preprocess the data:** 84 | ```bash 85 | python data_preprocess.py 86 | ``` 87 | After running the script, ensure you generate a `.txt` file in the same format as the examples in the `filelists` directory: 88 | 89 | ``` 90 | path/to/audio.wav|Hebrew transcript written in English letters 91 | ``` 92 | 93 | 2. **Train the model:** 94 | ```bash 95 | python train.py 96 | ``` 97 | 98 | 3.
**Generate speech (inference):** 99 | ```bash 100 | python inference.py 101 | ``` 102 | 103 | --- 104 | 105 | ## 💡 Demos & Resources 106 | 107 | - **Live Demo:** [Project Site](http://www.roboshaul.com/) 108 | - **Demo Page:** [here](https://maxmelichov.github.io/) 109 | - **Quick Start Notebook:** [Notebook](https://github.com/maxmelichov/Text-To-speech/blob/main/Tacotron_Synthesis_Notebook_contest_notebook.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1heUHKqCUwXGX_NRZUeN5J9UdB9UVV32m#scrollTo=IbrwoO0A1D0b) 110 | - **Project Podcast:** [חיות כיס episode](https://open.spotify.com/episode/7eM8KcpUGMxOk6X5WQYdh5?si=3xf0TNzwRTSHaCo8jIozOg) 111 | - **Training & Synthesis Videos:** [Part 1](https://www.youtube.com/watch?v=b1fzyM0VhhI) | [Part 2](https://www.youtube.com/watch?v=gVqSEIr2PD4&t=284s) 112 | 113 | --- 114 | 115 | ## 📝 Model Details 116 | 117 | - The system uses the SASPEECH dataset, a collection of unedited recordings from Shaul Amsterdamski for the 'Hayot Kis' podcast. 118 | - The TTS system is based on Nvidia's Tacotron 2, customized for Hebrew. 119 | 120 | **Note:** The model expects diacritized Hebrew (עברית מנוקדת). For diacritization, we recommend [Nakdimon](https://nakdimon.org) ([GitHub](https://github.com/elazarg/nakdimon)). 121 | 122 | 123 | --- 124 | 125 | ## 👥 Contact 126 | 127 | | Maxim Melichov | Tony Hasson | 128 | | -------------- | ----------- | 129 | | [LinkedIn](https://www.linkedin.com/in/max-melichov/) | [LinkedIn](https://www.linkedin.com/in/tony-hasson-a14402205/) | 130 | 131 | --- 132 | 133 | Feel free to reach out with questions or suggestions! 
134 | -------------------------------------------------------------------------------- /Tacotron_Synthesis_Notebook_contest_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# MAKE SURE YOU RUN THE CELL IN ORDER!" 9 | ] 10 | }, 11 | { 12 | "attachments": {}, 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "ToPD6zvQuXCE" 16 | }, 17 | "source": [ 18 | "# Check GPU\n", 19 | "\n", 20 | "First, if you want, you can check to see which GPU is currently being used. The best GPUs are P100, V100, and T4. If you get a K80 or P4, you can restart the runtime and try again if you'd like, but all GPUs will (probably) work with this notebook." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "id": "nXzIEAZRuR0Q" 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "!nvidia-smi -L" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "cellView": "form", 39 | "id": "NatDNxccuUmw" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "#@title Install Tacotron and Waveglow (click to see code)\n", 44 | "\n", 45 | "import os\n", 46 | "from os.path import exists, join, basename, splitext\n", 47 | "!pip install gdown\n", 48 | "!pip install Hebrew\n", 49 | "git_repo_url = 'https://github.com/maxmelichov/tacotron2.git'\n", 50 | "project_name = splitext(basename(git_repo_url))[0]\n", 51 | "if not exists(project_name):\n", 52 | " # clone and install\n", 53 | " !git clone -q --recursive {git_repo_url}\n", 54 | " !cd {project_name}/waveglow && git checkout 2fd4e63\n", 55 | " !pip install -q librosa unidecode\n", 56 | " !pip install Hebrew\n", 57 | " \n", 58 | "import sys\n", 59 | "sys.path.append(join(project_name, 'waveglow/'))\n", 60 | "sys.path.append(project_name)\n", 61 | "import time\n", 62 | "import matplotlib\n", 63 | 
"import matplotlib.pylab as plt\n", 64 | "import gdown\n", 65 | "from hebrew import Hebrew\n", 66 | "from hebrew.chars import HebrewChar, ALEPH\n", 67 | "from hebrew import GematriaTypes\n", 68 | "d = 'https://drive.google.com/uc?id='" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "TaPJp2QNut8U" 75 | }, 76 | "source": [ 77 | "The pre-trained model" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": { 84 | "cellView": "form", 85 | "id": "wgu8WXVkuqtD" 86 | }, 87 | "outputs": [ 88 | { 89 | "ename": "FileURLRetrievalError", 90 | "evalue": "Failed to retrieve file url:\n\n\tCannot retrieve the public link of the file. You may need to change\n\tthe permission to 'Anyone with the link', or have had many accesses.\n\tCheck FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.\n\nYou may still be able to access the file from the browser:\n\n\thttps://drive.google.com/uc?id=1ByWFnQsqIZuJpz7PgBgCo4ED63eFyKKp&export=download\n\nbut Gdown can't. 
Please check connections and permissions.", 91 | "output_type": "error", 92 | "traceback": [ 93 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 94 | "\u001b[0;31mFileURLRetrievalError\u001b[0m Traceback (most recent call last)", 95 | "File \u001b[0;32m~/Documents/Projects/Text-To-speech/venv/lib/python3.10/site-packages/gdown/download.py:267\u001b[0m, in \u001b[0;36mdownload\u001b[0;34m(url, output, quiet, proxy, speed, use_cookies, verify, id, fuzzy, resume, format, user_agent, log_messages)\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 267\u001b[0m url \u001b[38;5;241m=\u001b[39m \u001b[43mget_url_from_gdrive_confirmation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m FileURLRetrievalError \u001b[38;5;28;01mas\u001b[39;00m e:\n", 96 | "File \u001b[0;32m~/Documents/Projects/Text-To-speech/venv/lib/python3.10/site-packages/gdown/download.py:55\u001b[0m, in \u001b[0;36mget_url_from_gdrive_confirmation\u001b[0;34m(contents)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m url:\n\u001b[0;32m---> 55\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m FileURLRetrievalError(\n\u001b[1;32m 56\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot retrieve the public link of the file. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou may need to change the permission to \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAnyone with the link\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, or have had many accesses. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCheck FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 60\u001b[0m )\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m url\n", 97 | "\u001b[0;31mFileURLRetrievalError\u001b[0m: Cannot retrieve the public link of the file. You may need to change the permission to 'Anyone with the link', or have had many accesses. Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.", 98 | "\nDuring handling of the above exception, another exception occurred:\n", 99 | "\u001b[0;31mFileURLRetrievalError\u001b[0m Traceback (most recent call last)", 100 | "Cell \u001b[0;32mIn[4], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m tacotron2_pretrained_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mMLPTTS\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(tacotron2_pretrained_model) \u001b[38;5;129;01mor\u001b[39;00m force_download_TT2:\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ PUT MODEL HERE\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[43mgdown\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m1ByWFnQsqIZuJpz7PgBgCo4ED63eFyKKp&export=download\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtacotron2_pretrained_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquiet\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTacotron2 Model 
Downloaded\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ PUT MODEL HERE\u001b[39;00m\n", 101 | "File \u001b[0;32m~/Documents/Projects/Text-To-speech/venv/lib/python3.10/site-packages/gdown/download.py:278\u001b[0m, in \u001b[0;36mdownload\u001b[0;34m(url, output, quiet, proxy, speed, use_cookies, verify, id, fuzzy, resume, format, user_agent, log_messages)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m FileURLRetrievalError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 269\u001b[0m message \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 270\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to retrieve file url:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 271\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou may still be able to access the file from the browser:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 276\u001b[0m url_origin,\n\u001b[1;32m 277\u001b[0m )\n\u001b[0;32m--> 278\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m FileURLRetrievalError(message)\n\u001b[1;32m 280\u001b[0m filename_from_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 281\u001b[0m last_modified_time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", 102 | "\u001b[0;31mFileURLRetrievalError\u001b[0m: Failed to retrieve file url:\n\n\tCannot retrieve the public link of the file. You may need to change\n\tthe permission to 'Anyone with the link', or have had many accesses.\n\tCheck FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.\n\nYou may still be able to access the file from the browser:\n\n\thttps://drive.google.com/uc?id=1ByWFnQsqIZuJpz7PgBgCo4ED63eFyKKp&export=download\n\nbut Gdown can't. 
Please check connections and permissions." 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "#@title Install Tacotron pretrained model\n", 108 | "from os.path import exists, join, basename, splitext\n", 109 | "import gdown\n", 110 | "d = 'https://drive.google.com/uc?id='\n", 111 | "force_download_TT2 = True\n", 112 | "tacotron2_pretrained_model = 'MLPTTS'\n", 113 | "if not exists(tacotron2_pretrained_model) or force_download_TT2:\n", 114 | " # ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ PUT MODEL HERE\n", 115 | " gdown.download(d+r'1ByWFnQsqIZuJpz7PgBgCo4ED63eFyKKp&export=download', tacotron2_pretrained_model, quiet=False)\n", 116 | " print(\"Tacotron2 Model Downloaded\")\n", 117 | " # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ PUT MODEL HERE\n", 118 | "\n", 119 | "\n", 120 | "waveglow_pretrained_model = 'waveglow.pt'\n", 121 | "if not exists(waveglow_pretrained_model):\n", 122 | " gdown.download(d+r'1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF&export=download', waveglow_pretrained_model, quiet=False); print(\"WaveGlow Model Downloaded\")#1okuUstGoBe_qZ4qUEF8CcwEugHP7GM_b&export" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "id": "GKkUcNRsu9sr" 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "!python ./tacotron2/waveglow/convert_model.py /content/waveglow.pt /content/waveglow.pt" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "id": "uvqhtbjru_l-" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "import locale\n", 145 | "locale.getpreferredencoding = lambda: \"UTF-8\"\n", 146 | "locale.getpreferredencoding()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "cellView": "form", 154 | "id": "GAhZ-94-vC70" 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "#@title Initialize Tacotron and Waveglow and Hebrew to English fucntion\n", 159 | "%matplotlib inline\n", 160 | "import IPython.display as ipd\n", 
161 | "import numpy as np\n", 162 | "import torch\n", 163 | "\n", 164 | "from hparams import create_hparams\n", 165 | "from model import Tacotron2\n", 166 | "from layers import TacotronSTFT\n", 167 | "from audio_processing import griffin_lim\n", 168 | "from text import text_to_sequence\n", 169 | "from denoiser import Denoiser\n", 170 | "\n", 171 | "graph_width = 900\n", 172 | "graph_height = 360\n", 173 | "def plot_data(data, figsize=(int(graph_width/100), int(graph_height/100))):\n", 174 | " %matplotlib inline\n", 175 | " fig, axes = plt.subplots(1, len(data), figsize=figsize)\n", 176 | " for i in range(len(data)):\n", 177 | " axes[i].imshow(data[i], aspect='auto', origin='upper', \n", 178 | " interpolation='none', cmap='inferno')\n", 179 | " fig.canvas.draw()\n", 180 | " plt.show()\n", 181 | "\n", 182 | "\n", 183 | "import re\n", 184 | " \n", 185 | " \n", 186 | "##breaks down number to list\n", 187 | "def breakdown(number):\n", 188 | " digits = []\n", 189 | " for place, value in zip([100000000,10000000, 1000000,100000, 10000,1000, 100, 10,1], [100000000,10000000, 1000000,100000, 10000,1000, 100, 10,1]):\n", 190 | " digit = number // place * value\n", 191 | " digits.append(digit)\n", 192 | " number -= digit\n", 193 | " return digits\n", 194 | " \n", 195 | "##auxilary function for NumberToHebrew , helps break down arrays of 3's\n", 196 | "def build_three_num_heb(list_num,num_dict_below_20,num_dict_eq_above_20,last):\n", 197 | " \n", 198 | " flag_zero=1\n", 199 | " list_heb=[]\n", 200 | " for i, num in enumerate(list_num):\n", 201 | " if num == 0:\n", 202 | " continue\n", 203 | " else:\n", 204 | " flag_zero=0\n", 205 | " if i < 1:\n", 206 | " list_heb.append(num_dict_eq_above_20[num])\n", 207 | " \n", 208 | " elif i == 1:\n", 209 | " if list_num[0] != 0:\n", 210 | " if list_num[1] + list_num[2] < 20:\n", 211 | " list_heb.append(\"וְ\" + num_dict_below_20[list_num[1] + list_num[2]])\n", 212 | " break\n", 213 | " else:\n", 214 | " 
list_heb.append(num_dict_eq_above_20[num])\n", 215 | " else:\n", 216 | " if list_num[1] + list_num[2] < 20:\n", 217 | " list_heb.append(num_dict_below_20[list_num[1] + list_num[2]])\n", 218 | " break\n", 219 | " else:\n", 220 | " list_heb.append(num_dict_eq_above_20[num])\n", 221 | " \n", 222 | " elif i == 2:\n", 223 | " if list_num[0] != 0 or list_num[1] != 0 or last:\n", 224 | " list_heb.append(\"וְ\" + num_dict_below_20[num])\n", 225 | " else:\n", 226 | " list_heb.append(num_dict_below_20[num])\n", 227 | " \n", 228 | " return list_heb,flag_zero\n", 229 | " \n", 230 | "##gets number and turns it into hebrew with nikud\n", 231 | "def NumberToHebrew(number):\n", 232 | " \n", 233 | " if number==0:\n", 234 | " return [\"אֶפֶס\"]\n", 235 | " \n", 236 | " signs_dict = {\n", 237 | " '%': 'אָחוּז',\n", 238 | " ',': 'פְּסִיק',\n", 239 | " '.': 'נְקֻדָּה'\n", 240 | " \n", 241 | " }\n", 242 | " \n", 243 | " num_dict_below_20={\n", 244 | " 1: 'אֶחָד',\n", 245 | " 2: 'שְׁנַיִם',\n", 246 | " 3: 'שְׁלֹשָׁה',\n", 247 | " 4: 'אַרְבָּעָה',\n", 248 | " 5: 'חֲמִשָּׁה',\n", 249 | " 6: 'שִׁשָּׁה',\n", 250 | " 7: 'שִׁבְעָה',\n", 251 | " 8: 'שְׁמוֹנָה',\n", 252 | " 9: 'תִּשְׁעָה',\n", 253 | " 10: 'עֶשֶׂר',\n", 254 | " 11: 'אַחַד עָשָׂר',\n", 255 | " 12: 'שְׁנֵים עָשָׂר',\n", 256 | " 13: 'שְׁלֹשָׁה עָשָׂר',\n", 257 | " 14: 'אַרְבָּעָה עָשָׂר',\n", 258 | " 15: 'חֲמִשָּׁה עָשָׂר',\n", 259 | " 16: 'שִׁשָּׁה עָשָׂר',\n", 260 | " 17: \"שִׁבְעָה עֶשְׂרֵה\",\n", 261 | " 18: \"שְׁמוֹנָה עֶשְׂרֵה\",\n", 262 | " 19: \"תִּשְׁעָה עֶשְׂרֵה\"\n", 263 | " }\n", 264 | " \n", 265 | " num_dict_eq_above_20={\n", 266 | " 20: \"עֶשְׂרִים\",\n", 267 | " 30: \"שְׁלשִׁים\",\n", 268 | " 40: \"אַרְבָּעִים\",\n", 269 | " 50: \"חֲמִשִּׁים\",\n", 270 | " 60: \"שִׁשִּׁים\",\n", 271 | " 70: \"שִׁבְעִים\",\n", 272 | " 80: \"שְׁמוֹנִים\",\n", 273 | " 90: \"תִּשְׁעִים\",\n", 274 | " 100: \"מֵאָה\",\n", 275 | " 200: \"מָאתַיִם\",\n", 276 | " 300: \"שְׁלֹשׁ מֵאוֹת\",\n", 277 | " 400: \"אַרְבָּעִ מֵאוֹת\",\n", 278 | " 500: 
\"חֲמִשֶּׁ מֵאוֹת\",\n", 279 | " 600: \"שֵׁשׁ מֵאוֹת\",\n", 280 | " 700: \"שִׁבְעַ מֵאוֹת\",\n", 281 | " 800: \"שְׁמוֹנֶ מֵאוֹת\",\n", 282 | " 900: \"תִּשְׁעַ מֵאוֹת\",\n", 283 | " 1000: \"אֶלֶף\",\n", 284 | " 2000: \"אֲלַפַּיִם\",\n", 285 | " 3000: \"שְׁלֹשֶׁת אֲלָפִים\",\n", 286 | " 4000: \"אַרְבַּעַת אֲלָפִים\",\n", 287 | " 5000: \"חֲמֵשׁ אֲלָפִים\",\n", 288 | " 6000: \"שֵׁשׁ אֲלָפִים\",\n", 289 | " 7000: \"שִׁבְעָה אֲלָפִים\",\n", 290 | " 8000: \"שְׁמוֹנָה אֲלָפִים\",\n", 291 | " 9000: \"תִּשְׁעָה אֲלָפִים\"\n", 292 | " }\n", 293 | " \n", 294 | " if number in signs_dict:\n", 295 | " return [signs_dict[number]]\n", 296 | " \n", 297 | " if number<10000:\n", 298 | " \n", 299 | " list_heb=[]\n", 300 | " list_num=breakdown(number)\n", 301 | " list_num=list_num[5:]\n", 302 | " for i,num in enumerate(list_num):\n", 303 | " if num==0:\n", 304 | " continue\n", 305 | " else:\n", 306 | " if i<2:\n", 307 | " list_heb.append(num_dict_eq_above_20[num])\n", 308 | " \n", 309 | " elif i==2:\n", 310 | " if list_num[0]!=0 or list_num[1]!=0:\n", 311 | " if list_num[2]+list_num[3]<20:\n", 312 | " list_heb.append(\"וְ\"+num_dict_below_20[list_num[2]+list_num[3]])\n", 313 | " break\n", 314 | " else:\n", 315 | " list_heb.append(num_dict_eq_above_20[num])\n", 316 | " else:\n", 317 | " if list_num[2] + list_num[3] < 20:\n", 318 | " list_heb.append(num_dict_below_20[list_num[2] + list_num[3]])\n", 319 | " break\n", 320 | " else:\n", 321 | " list_heb.append(num_dict_eq_above_20[num])\n", 322 | " \n", 323 | " elif i==3:\n", 324 | " if list_num[0]!=0 or list_num[1]!=0 or list_num[2]!=0:\n", 325 | " list_heb.append(\"וְ\" + num_dict_below_20[num])\n", 326 | " else:\n", 327 | " list_heb.append(num_dict_below_20[num])\n", 328 | " \n", 329 | " return list_heb\n", 330 | " \n", 331 | " else:\n", 332 | " \n", 333 | " list_heb = []\n", 334 | " list_num = breakdown(number)\n", 335 | " s1,s2,s3=list_num[:3],list_num[3:6],list_num[6:]\n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " ##take care of 
millions\n", 340 | " \n", 341 | " # set them up for transcript\n", 342 | " for i in range(len(s1)):\n", 343 | " s1[i]=s1[i]/1000000\n", 344 | " \n", 345 | " ret_list,flag_zero=build_three_num_heb(s1,num_dict_below_20,num_dict_eq_above_20,False)\n", 346 | " if not flag_zero:\n", 347 | " for item in ret_list:\n", 348 | " list_heb.append(item)\n", 349 | " list_heb.append(\"מִילְיוֹן\")\n", 350 | " \n", 351 | " ##take care of thousands\n", 352 | " \n", 353 | " # set them up for transcript\n", 354 | " for i in range(len(s2)):\n", 355 | " s2[i] = s2[i] / 1000\n", 356 | " \n", 357 | " ret_list, flag_zero = build_three_num_heb(s2, num_dict_below_20, num_dict_eq_above_20,False)\n", 358 | " if not flag_zero:\n", 359 | " for item in ret_list:\n", 360 | " list_heb.append(item)\n", 361 | " list_heb.append(\"אֶלֶף\")\n", 362 | " \n", 363 | " ##take care of hundred and leftovers\n", 364 | " ret_list, flag_zero = build_three_num_heb(s3, num_dict_below_20, num_dict_eq_above_20,True)\n", 365 | " if not flag_zero:\n", 366 | " for item in ret_list:\n", 367 | " list_heb.append(item)\n", 368 | " \n", 369 | " return list_heb\n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | "##attempts to split string to number and string\n", 374 | "def split_number_and_string(input_string):\n", 375 | " # Use regular expression to find any number within the string\n", 376 | " match = re.search(r'\\d+', input_string)\n", 377 | " \n", 378 | " if match:\n", 379 | " # Extract the number from the string\n", 380 | " number = match.group()\n", 381 | " \n", 382 | " # Split the string into two parts: before and after the number\n", 383 | " index = match.start()\n", 384 | " string_before_number = input_string[:index]\n", 385 | " string_after_number = input_string[index + len(number):]\n", 386 | " \n", 387 | " return string_before_number, number, string_after_number\n", 388 | " else:\n", 389 | " # If no number is found, return None\n", 390 | " return None\n", 391 | " \n", 392 | 
"#######################################################################Auxilary functions\n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | "##check if string has number in it\n", 397 | "def has_number(input_string):\n", 398 | " # Use regular expression to search for any digits within the string\n", 399 | " return bool(re.search(r'\\d', input_string))\n", 400 | " \n", 401 | "##breaks text to list\n", 402 | "def break_to_list(text):\n", 403 | " \"\"\"\n", 404 | " This function receives a string and returns a list of strings with each word from the input text.\n", 405 | " \"\"\"\n", 406 | " lst = []\n", 407 | " for tav in text:\n", 408 | " lst.append(tav)\n", 409 | " return lst\n", 410 | " \n", 411 | "####################################\n", 412 | " \n", 413 | " \n", 414 | "########################## relevant for fixing text with numbers:\n", 415 | " \n", 416 | "def is_number_with_comma(string):\n", 417 | " # Remove trailing comma, if any\n", 418 | " if string.endswith(',') or string.endswith('.'):\n", 419 | " string = string[:-1]\n", 420 | " \n", 421 | " # Check if string matches pattern\n", 422 | " if ',' in string:\n", 423 | " parts = string.split(',')\n", 424 | " if len(parts) != 2:\n", 425 | " return False\n", 426 | " if not all(part.isdigit() for part in parts):\n", 427 | " return False\n", 428 | " elif not string.isdigit():\n", 429 | " return False\n", 430 | " \n", 431 | " return True\n", 432 | " \n", 433 | "def clean_number_with_comma(string):\n", 434 | " # Remove trailing comma, if any\n", 435 | " if string.endswith(',') or string.endswith('.'):\n", 436 | " string = string[:-1]\n", 437 | " \n", 438 | " # Remove commas from string\n", 439 | " string = string.replace(',', '')\n", 440 | " \n", 441 | " # Convert string to integer and return\n", 442 | " return int(string)\n", 443 | " \n", 444 | " \n", 445 | "def is_number_with_decimal(string):\n", 446 | " if ',' in string:\n", 447 | " if string[-1] != ',':\n", 448 | " return False\n", 449 | " string = 
string[:-1]\n", 450 | " if '.' not in string:\n", 451 | " return False\n", 452 | " try:\n", 453 | " float(string)\n", 454 | " except ValueError:\n", 455 | " return False\n", 456 | " return True\n", 457 | " \n", 458 | "def clean_decimal(string):\n", 459 | " if ',' in string:\n", 460 | " if string[-1] != ',':\n", 461 | " return None\n", 462 | " string = string[:-1]\n", 463 | " \n", 464 | " parts = string.split('.')\n", 465 | " try:\n", 466 | " return int(parts[0]),int(parts[1])\n", 467 | " except:\n", 468 | " return int(parts[0]),None\n", 469 | " \n", 470 | " \n", 471 | "def is_percentage(string):\n", 472 | " \n", 473 | " if string[-1] == ',' or string[-1] == '.':\n", 474 | " string = string[:-1]\n", 475 | " \n", 476 | " if not string.endswith('%'):\n", 477 | " return False\n", 478 | " string = string[:-1]\n", 479 | " if string.endswith(','):\n", 480 | " string = string[:-1]\n", 481 | " try:\n", 482 | " float(string)\n", 483 | " except ValueError:\n", 484 | " return False\n", 485 | " return True\n", 486 | " \n", 487 | " \n", 488 | "def clean_percentage(string):\n", 489 | " if string[-1] == ',' or string[-1] == '.':\n", 490 | " string = string[:-1]\n", 491 | " \n", 492 | " if string.endswith('%'):\n", 493 | " string = string[:-1]\n", 494 | " else:\n", 495 | " return None\n", 496 | " if ',' in string:\n", 497 | " if string[-1] != ',':\n", 498 | " return None\n", 499 | " string = string[:-1]\n", 500 | " try:\n", 501 | " number = float(string)\n", 502 | " except ValueError:\n", 503 | " return None\n", 504 | " return str(number).rstrip('0').rstrip('.')\n", 505 | " \n", 506 | "def is_number_range(string):\n", 507 | " \n", 508 | " \n", 509 | " if '-' not in string:\n", 510 | " return False\n", 511 | " \n", 512 | " if string[-1] == ',' or string[-1] == '.':\n", 513 | " string = string[:-1]\n", 514 | " \n", 515 | " \n", 516 | " parts = string.split('-')\n", 517 | " if len(parts) != 2:\n", 518 | " return False\n", 519 | " for part in parts:\n", 520 | " if not 
part.isdigit():\n", 521 | " return False\n", 522 | " return True\n", 523 | " \n", 524 | "def clean_number_range(string):\n", 525 | " if string[-1] == ',' or string[-1] == '.':\n", 526 | " string = string[:-1]\n", 527 | " \n", 528 | " parts = string.split('-')\n", 529 | " return (int(parts[0]), int(parts[1]))\n", 530 | " \n", 531 | " \n", 532 | "def is_pattern_number_with_heb(string):\n", 533 | " if '-' not in string:\n", 534 | " return False\n", 535 | " \n", 536 | " if string[-1] == ',' or string[-1] == '.':\n", 537 | " string = string[:-1]\n", 538 | " \n", 539 | " parts = string.split('-')\n", 540 | " \n", 541 | " if not parts[1].isdigit() and not is_number_range(parts[1]) and not is_number_with_comma(parts[1]) and not is_number_with_decimal(parts[1]) and not is_percentage(parts[1]):\n", 542 | " return False\n", 543 | " \n", 544 | " return True\n", 545 | " \n", 546 | "def clean_pattern_number_with_heb(string):\n", 547 | " if string[-1] == ',' or string[-1] == '.':\n", 548 | " string = string[:-1]\n", 549 | " \n", 550 | " parts = string.split('-')\n", 551 | " if len(parts)<=2:\n", 552 | " ## '4,000,' / '10,000' / '1000,'\n", 553 | " if is_number_with_comma(parts[1]):\n", 554 | " return parts[0],str(clean_number_with_comma(parts[1])) , \"is_number_with_comma\"\n", 555 | " \n", 556 | " ## '2.9' / '3.4'\n", 557 | " elif is_number_with_decimal(parts[1]):\n", 558 | " return parts[0],clean_decimal(parts[1]) , \"is_number_with_decimal\"\n", 559 | " \n", 560 | " ## '4.5%' / '9.25%.' 
/ '26.5%,'\n", 561 | " elif is_percentage(parts[1]):\n", 562 | " return parts[0],str(clean_percentage(parts[1])) , \"is_percentage\"\n", 563 | " \n", 564 | " ## '5-6' / '1971-1972,' / '2003-2005.'\n", 565 | " if len(parts)>2 and is_number_range(parts[1]+'-'+parts[2]):\n", 566 | " return parts[0],(str(parts[1]),str(parts[2])) , \"is_number_range\"\n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | "def clean_number(word):\n", 571 | " \n", 572 | " ## '4,000,' / '10,000' / '1000,'\n", 573 | " if is_number_with_comma(word):\n", 574 | " return NumberToHebrew(int(clean_number_with_comma(word)))\n", 575 | " \n", 576 | " ## '2.9' / '3.4'\n", 577 | " elif is_number_with_decimal(word):\n", 578 | " list_heb=[]\n", 579 | " part1,part2=clean_decimal(word)\n", 580 | " list_heb+=NumberToHebrew(part1)\n", 581 | " list_heb+=NumberToHebrew('.')\n", 582 | " list_heb+=NumberToHebrew(part2)\n", 583 | " return list_heb\n", 584 | " \n", 585 | " ## '4.5%' / '9.25%.' / '26.5%,'\n", 586 | " elif is_percentage(word):\n", 587 | " list_heb = []\n", 588 | " part1, part2 = clean_decimal(clean_percentage(word))\n", 589 | " \n", 590 | " if part2!=None:\n", 591 | " list_heb += NumberToHebrew(part1)\n", 592 | " list_heb += NumberToHebrew('.')\n", 593 | " list_heb += NumberToHebrew(part2)\n", 594 | " list_heb += NumberToHebrew('%')\n", 595 | " return list_heb\n", 596 | " else:\n", 597 | " list_heb += NumberToHebrew(part1)\n", 598 | " list_heb += NumberToHebrew('%')\n", 599 | " return list_heb\n", 600 | " \n", 601 | " ## '5-6' / '1971-1972,' / '2003-2005.'\n", 602 | " elif is_number_range(word):\n", 603 | " list_heb = []\n", 604 | " part1, part2 = clean_number_range(word)\n", 605 | " list_heb += NumberToHebrew(part1)\n", 606 | " list_heb.append(\"עַד\")\n", 607 | " list_heb += NumberToHebrew(part2)\n", 608 | " return list_heb\n", 609 | " \n", 610 | " ## 'בְּ-100,000' / בְּ-99.99% / הַ-1,100 / מִ-0.7% / לְ-1.9 / כְּ-22,000.\n", 611 | " elif is_pattern_number_with_heb(word):\n", 612 | " 
heb_letter,num,func=(clean_pattern_number_with_heb(word))\n", 613 | " #arr_attr= (clean_pattern_number_with_heb(word))\n", 614 | " list_heb = []\n", 615 | " list_heb.append(heb_letter)\n", 616 | " \n", 617 | " if func==\"is_number_with_comma\":\n", 618 | " list_heb+=NumberToHebrew(int(num))\n", 619 | " return list_heb\n", 620 | " \n", 621 | " elif func==\"is_number_with_decimal\":\n", 622 | " part1, part2 = num\n", 623 | " list_heb += NumberToHebrew(part1)\n", 624 | " list_heb += NumberToHebrew('.')\n", 625 | " list_heb += NumberToHebrew(part2)\n", 626 | " return list_heb\n", 627 | " \n", 628 | " elif func==\"is_percentage\":\n", 629 | " part1, part2 = clean_decimal(num)\n", 630 | " \n", 631 | " if part2 != None:\n", 632 | " list_heb += NumberToHebrew(part1)\n", 633 | " list_heb += NumberToHebrew('.')\n", 634 | " list_heb += NumberToHebrew(part2)\n", 635 | " list_heb += NumberToHebrew('%')\n", 636 | " return list_heb\n", 637 | " else:\n", 638 | " list_heb += NumberToHebrew(part1)\n", 639 | " list_heb += NumberToHebrew('%')\n", 640 | " return list_heb\n", 641 | " \n", 642 | " elif func == \"is_number_range\":\n", 643 | " part1, part2 = num\n", 644 | " list_heb += NumberToHebrew(int(part1))\n", 645 | " list_heb.append(\"עַד\")\n", 646 | " list_heb += NumberToHebrew(int(part2))\n", 647 | " return list_heb\n", 648 | " \n", 649 | "#######################################################\n", 650 | " \n", 651 | "##takes a letter in hebrew and returns the sound in english\n", 652 | "def English(obj,tzuptzik,last_letter=False):\n", 653 | " obj = Hebrew(obj).string\n", 654 | " # map the nikud symbols to their corresponding phenoms\n", 655 | " nikud_map = {\"ָ\": \"a\", \"ַ\": \"a\", \"ֶ\": \"e\", \"ֵ\": \"e\", \"ִ\": \"i\", \"ְ\": \"\", \"ֹ\": \"o\", \"ֻ\": \"oo\", 'ּ': \"\", 'ֲ': 'a'}\n", 656 | " # fixme need create more hebrew phenoms\n", 657 | " \n", 658 | " beged_kefet_shin_sin = {\n", 659 | " ############ B\n", 660 | " \"בּ\": \"b\",\n", 661 | " \"בְּ\": \"b\",\n", 662 | 
" \"בִּ\": \"bi\",\n", 663 | " \"בֹּ\": \"bo\",\n", 664 | " \"בֵּ\": \"be\",\n", 665 | " \"בֶּ\": \"be\",\n", 666 | " \"בַּ\": \"ba\",\n", 667 | " \"בָּ\": \"ba\",\n", 668 | " \"בֻּ\": \"boo\",\n", 669 | " ############ G\n", 670 | " \"גּ\": \"g\",\n", 671 | " \"גְּ\": \"g\",\n", 672 | " \"גִּ\": \"gi\",\n", 673 | " \"גֹּ\": \"go\",\n", 674 | " \"גֵּ\": \"ge\",\n", 675 | " \"גֶּ\": \"ge\",\n", 676 | " \"גַּ\": \"ga\",\n", 677 | " \"גָּ\": \"ga\",\n", 678 | " \"גֻּ\": \"goo\",\n", 679 | " ########### D\n", 680 | " \"דּ\": \"d\",\n", 681 | " \"דְּ\": \"d\",\n", 682 | " \"דִּ\": \"di\",\n", 683 | " \"דֹּ\": \"do\",\n", 684 | " \"דֵּ\": \"de\",\n", 685 | " \"דֶּ\": \"de\",\n", 686 | " \"דַּ\": \"da\",\n", 687 | " \"דָּ\": \"da\",\n", 688 | " \"דֻּ\": \"doo\",\n", 689 | " ########### K\n", 690 | " \"כּ\": \"k\",\n", 691 | " \"כְּ\": \"k\",\n", 692 | " \"כִּ\": \"ki\",\n", 693 | " \"כֹּ\": \"ko\",\n", 694 | " \"כֵּ\": \"ke\",\n", 695 | " \"כֶּ\": \"ke\",\n", 696 | " \"כַּ\": \"ka\",\n", 697 | " \"כָּ\": \"ka\",\n", 698 | " \"כֻּ\": \"koo\",\n", 699 | " ############ P\n", 700 | " \"פּ\": \"p\",\n", 701 | " \"פְּ\": \"p\",\n", 702 | " \"פִּ\": \"pi\",\n", 703 | " \"פֹּ\": \"po\",\n", 704 | " \"פֵּ\": \"pe\",\n", 705 | " \"פֶּ\": \"pe\",\n", 706 | " \"פַּ\": \"pa\",\n", 707 | " \"פָּ\": \"pa\",\n", 708 | " \"פֻּ\": \"poo\",\n", 709 | " ############ T\n", 710 | " \"תּ\": \"t\",\n", 711 | " \"תְּ\": \"t\",\n", 712 | " \"תִּ\": \"ti\",\n", 713 | " \"תֹּ\": \"to\",\n", 714 | " \"תֵּ\": \"te\",\n", 715 | " \"תֶּ\": \"te\",\n", 716 | " \"תַּ\": \"ta\",\n", 717 | " \"תָּ\": \"ta\",\n", 718 | " \"תֻּ\": \"too\",\n", 719 | " ############ S\n", 720 | " \"שׂ\": \"s\",\n", 721 | " \"שְׂ\": \"s\",\n", 722 | " \"שִׂ\": \"si\",\n", 723 | " \"שֹׂ\": \"so\",\n", 724 | " \"שֵׂ\": \"se\",\n", 725 | " \"שֶׂ\": \"se\",\n", 726 | " \"שַׂ\": \"sa\",\n", 727 | " \"שָׂ\": \"sa\",\n", 728 | " \"שֻׂ\": \"soo\",\n", 729 | " ########### SH\n", 730 | " \"שׁ\": \"sh\",\n", 731 | " \"שְׁ\": \"sh\",\n", 732 
| " \"שִׁ\": \"shi\",\n", 733 | " \"שֹׁ\": \"sho\",\n", 734 | " \"שֵׁ\": \"she\",\n", 735 | " \"שֶׁ\": \"she\",\n", 736 | " \"שַׁ\": \"sha\",\n", 737 | " \"שָׁ\": \"sha\",\n", 738 | " \"שֻׁ\": \"shoo\",\n", 739 | " }\n", 740 | " \n", 741 | " vav = {\n", 742 | " \"וֵּו\": \"ve\",\n", 743 | " \"וּ\": \"oo\",\n", 744 | " \"וּו\": \"oo\",\n", 745 | " \"וֹ\": \"o\",\n", 746 | " \"וֹו\": \"oo\",\n", 747 | " \"וְ\": \"ve\",\n", 748 | " \"וֱו\": \"ve\",\n", 749 | " \"וִ\": \"vi\",\n", 750 | " \"וִו\": \"vi\",\n", 751 | " \"וַ\": \"va\",\n", 752 | " \"וַו\": \"va\",\n", 753 | " \"וֶ\": \"ve\",\n", 754 | " \"וֶו\": \"ve\",\n", 755 | " \"וָ\": \"va\",\n", 756 | " \"וָו\": \"va\",\n", 757 | " \"וֻ\": \"oo\",\n", 758 | " \"וֻו\": \"oo\"\n", 759 | " }\n", 760 | " \n", 761 | " \n", 762 | " letters_map = {\n", 763 | " \"א\": \"\",\n", 764 | " \"ב\": \"v\",\n", 765 | " \"ג\": \"g\",\n", 766 | " \"ד\": \"d\",\n", 767 | " \"ה\": \"hh\",\n", 768 | " \"ו\": \"v\",\n", 769 | " \"ז\": \"z\",\n", 770 | " \"ח\": \"h\",\n", 771 | " \"ט\": \"t\",\n", 772 | " \"י\": \"y\",\n", 773 | " \"כ\": \"h\",\n", 774 | " \"ל\": \"l\",\n", 775 | " \"מ\": \"m\",\n", 776 | " \"נ\": \"n\",\n", 777 | " \"ס\": \"s\",\n", 778 | " \"ע\": \"\",\n", 779 | " \"פ\": \"f\",\n", 780 | " \"צ\": \"ts\",\n", 781 | " \"ק\": \"k\",\n", 782 | " \"ר\": \"r\",\n", 783 | " \"ש\": \"sh\",\n", 784 | " \"ת\": \"t\",\n", 785 | " \"ן\": \"n\",\n", 786 | " \"ם\": \"m\",\n", 787 | " \"ף\": \"f\",\n", 788 | " \"ץ\": \"ts\",\n", 789 | " \"ך\": \"h\",\n", 790 | " }\n", 791 | " \n", 792 | " patah_ganav={\n", 793 | " \"חַ\": \"ah\",\n", 794 | " \"חָ\": \"ah\",\n", 795 | " \"הַ\": \"hha\",\n", 796 | " \"הָ\": \"hha\",\n", 797 | " \"עַ\": \"a\",\n", 798 | " \"עָ\": \"a\",\n", 799 | " \n", 800 | " }\n", 801 | " \n", 802 | " tzuptzik_letters={\n", 803 | " ##G\n", 804 | " \"ג\": \"j\",\n", 805 | " \"גְ\": \"j\",\n", 806 | " \"גִ\": \"ji\",\n", 807 | " \"גֹ\": \"jo\",\n", 808 | " \"גֵ\": \"je\",\n", 809 | " \"גֶ\": \"je\",\n", 810 | " \"גַ\": 
\"ja\",\n", 811 | " \"גָ\": \"ja\",\n", 812 | " \"גֻ\": \"joo\",\n", 813 | " \"גּ\": \"j\",\n", 814 | " \"גְּ\": \"j\",\n", 815 | " \"גִּ\": \"ji\",\n", 816 | " \"גֹּ\": \"jo\",\n", 817 | " \"גֵּ\": \"je\",\n", 818 | " \"גֶּ\": \"je\",\n", 819 | " \"גַּ\": \"ja\",\n", 820 | " \"גָּ\": \"ja\",\n", 821 | " \"גֻּ\": \"joo\",\n", 822 | " \n", 823 | " ##ch\n", 824 | " \"צ\": \"ch\",\n", 825 | " \"צְ\": \"ch\",\n", 826 | " \"צִ\": \"chi\",\n", 827 | " \"צֹ\": \"cho\",\n", 828 | " \"צֵ\": \"che\",\n", 829 | " \"צֶ\": \"che\",\n", 830 | " \"צַ\": \"cha\",\n", 831 | " \"צָ\": \"cha\",\n", 832 | " \"צֻ\": \"choo\",\n", 833 | " \n", 834 | " ##Z\n", 835 | " \"ז\": \"zh\",\n", 836 | " \"זְ\": \"zh\",\n", 837 | " \"זִ\": \"zhi\",\n", 838 | " \"זֹ\": \"zho\",\n", 839 | " \"זֵ\": \"zhe\",\n", 840 | " \"זֶ\": \"zhe\",\n", 841 | " \"זַ\": \"zha\",\n", 842 | " \"זָ\": \"zha\",\n", 843 | " \"זֻ\": \"zhoo\",\n", 844 | " }\n", 845 | " \n", 846 | " if last_letter:\n", 847 | " if obj in patah_ganav:\n", 848 | " return patah_ganav[obj]\n", 849 | " \n", 850 | " if tzuptzik==True:\n", 851 | " if obj in tzuptzik_letters:\n", 852 | " return tzuptzik_letters[obj]\n", 853 | " \n", 854 | " if obj in beged_kefet_shin_sin:\n", 855 | " return beged_kefet_shin_sin[obj]\n", 856 | " elif obj in vav:\n", 857 | " return vav[obj]\n", 858 | " else:\n", 859 | " lst = break_to_list(obj)\n", 860 | " string = \"\"\n", 861 | " for item in lst:\n", 862 | " if item in letters_map:\n", 863 | " string += letters_map[item]\n", 864 | " if item in nikud_map:\n", 865 | " string += nikud_map[item]\n", 866 | " \n", 867 | " return string\n", 868 | " \n", 869 | " \n", 870 | "##takes hebrew word and turns it into the sound in english\n", 871 | "def HebWordToEng(word,index):\n", 872 | " new_sentence=\"\"\n", 873 | " hs = Hebrew(word)\n", 874 | " hs = Hebrew(list(hs.graphemes)).string\n", 875 | " for i, letter in enumerate(hs):\n", 876 | " \n", 877 | " tzuptzik = False\n", 878 | " if i < len(hs) - 1:\n", 879 | " if hs[i + 1] 
== '\\'':\n", 880 | " tzuptzik = True\n", 881 | " \n", 882 | " tav = English(letter, tzuptzik, i == len(hs) - 1)\n", 883 | " new_sentence += tav\n", 884 | " \n", 885 | " ##clean list:\n", 886 | " try:\n", 887 | " if new_sentence[-1] == 'y' and new_sentence[-2] == 'y':\n", 888 | " new_sentence = new_sentence.replace(\"yy\", \"y\")\n", 889 | " except:\n", 890 | " pass\n", 891 | " return new_sentence\n", 892 | " \n", 893 | "##takes hebrew sentence and turns it into english sounds\n", 894 | "def ARPA(sentence,index=0):\n", 895 | " words = sentence.split()\n", 896 | " new_sentence = \"\"\n", 897 | " \n", 898 | " for word in words:\n", 899 | " ##if number not in string\n", 900 | " if not has_number(word):\n", 901 | " ret_sentence=HebWordToEng(word,index)\n", 902 | " new_sentence+=ret_sentence+\" \"\n", 903 | " \n", 904 | " #if there is a number:\n", 905 | " else:\n", 906 | " try:\n", 907 | " before_num,num,after_num=split_number_and_string(word)\n", 908 | " \n", 909 | " if has_number(after_num) or has_number(before_num):\n", 910 | " ##raise(Exception)\n", 911 | " ##fixme - need to break down number in case of different scenarios (check discord)\n", 912 | " list_of_numbers=clean_number(word)\n", 913 | " for number in list_of_numbers:\n", 914 | " ret_sentence = HebWordToEng(number, index)\n", 915 | " new_sentence += ret_sentence + \" \"\n", 916 | " \n", 917 | " else:\n", 918 | " ret_sentence = HebWordToEng(before_num, index)\n", 919 | " new_sentence += ret_sentence+\" \"\n", 920 | " \n", 921 | " num = [s for s in word if s.isdigit()]\n", 922 | " num=\"\".join(num)\n", 923 | " num=int(num)\n", 924 | " list_of_numbers=NumberToHebrew(num)\n", 925 | " for number in list_of_numbers:\n", 926 | " ret_sentence=HebWordToEng(number,index)\n", 927 | " new_sentence += ret_sentence + \" \"\n", 928 | " \n", 929 | " ret_sentence = HebWordToEng(after_num, index)\n", 930 | " new_sentence += ret_sentence + \" \"\n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " except:\n", 935 | " 
print(\"error from split_number_and_string in line:\", index,\"with word: \",word)\n", 936 | " \n", 937 | " \n", 938 | " # new_sentence1 = ' '.join(new_sentence.split())\n", 939 | " return new_sentence\n", 940 | " \n", 941 | "\n", 942 | " \n", 943 | " \n", 944 | "\n", 945 | "\n", 946 | "\n", 947 | "torch.set_grad_enabled(False)\n", 948 | "\n", 949 | "# initialize Tacotron2 with the pretrained model\n", 950 | "hparams = create_hparams()" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "metadata": { 957 | "id": "FCyv8MmOvla7" 958 | }, 959 | "outputs": [], 960 | "source": [ 961 | "# Load Tacotron2 (run this cell every time you change the model)\n", 962 | "hparams.sampling_rate = 22050 # Don't change this\n", 963 | "hparams.max_decoder_steps = 1000 # How long the audio will be before it cuts off (1000 is about 11 seconds)\n", 964 | "hparams.gate_threshold = 0.1 # Model must be 90% sure the clip is over before ending generation (the higher this number is, the more likely that the AI will keep generating until it reaches the Max Decoder Steps)\n", 965 | "model = Tacotron2(hparams)\n", 966 | "model.load_state_dict(torch.load(tacotron2_pretrained_model)['state_dict'])\n", 967 | "_ = model.cuda().eval()" 968 | ] 969 | }, 970 | { 971 | "cell_type": "code", 972 | "execution_count": null, 973 | "metadata": { 974 | "id": "QnW74wETvoIS" 975 | }, 976 | "outputs": [], 977 | "source": [ 978 | "# Load WaveGlow\n", 979 | "waveglow = torch.load(waveglow_pretrained_model)['model']\n", 980 | "waveglow.cuda().eval()\n", 981 | "for k in waveglow.convinv:\n", 982 | " k.float()\n", 983 | "denoiser = Denoiser(waveglow)" 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": null, 989 | "metadata": {}, 990 | "outputs": [], 991 | "source": [ 992 | "#@Input the hebrew text with nikud\n", 993 | "text = \"בָּנַיי\" #@param {type:\"string\"}" 994 | ] 995 | }, 996 | { 997 | "cell_type": "code", 998 | "execution_count": null, 999 | 
"metadata": { 1000 | "id": "B5Zb4w8bvog5" 1001 | }, 1002 | "outputs": [], 1003 | "source": [ 1004 | "\n", 1005 | "sigma = 0.8\n", 1006 | "denoise_strength = 0.1\n", 1007 | "# try to switch raw data to True maybe the results will be better\n", 1008 | "raw_input = False # disables automatic ARPAbet conversion, useful for inputting your own ARPAbet pronounciations or just for testing\n", 1009 | "\n", 1010 | "for i in text.split(\"\\n\"):\n", 1011 | " if len(i) < 1: continue;\n", 1012 | " print(i)\n", 1013 | " if raw_input:\n", 1014 | " if i[-1] != \";\": i=i+\";\" \n", 1015 | " else: i = ARPA(i)\n", 1016 | " print(i)\n", 1017 | " with torch.no_grad(): # save VRAM by not including gradients\n", 1018 | " sequence = np.array(text_to_sequence(i, ['english_cleaners']))[None, :]\n", 1019 | " sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n", 1020 | " mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n", 1021 | " plot_data((mel_outputs_postnet.float().data.cpu().numpy()[0],alignments.float().data.cpu().numpy()[0].T))\n", 1022 | " audio = waveglow.infer(mel_outputs_postnet, sigma=sigma); print(\"\"); ipd.display(ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate))" 1023 | ] 1024 | } 1025 | ], 1026 | "metadata": { 1027 | "colab": { 1028 | "provenance": [] 1029 | }, 1030 | "kernelspec": { 1031 | "display_name": "venv", 1032 | "language": "python", 1033 | "name": "python3" 1034 | }, 1035 | "language_info": { 1036 | "codemirror_mode": { 1037 | "name": "ipython", 1038 | "version": 3 1039 | }, 1040 | "file_extension": ".py", 1041 | "mimetype": "text/x-python", 1042 | "name": "python", 1043 | "nbconvert_exporter": "python", 1044 | "pygments_lexer": "ipython3", 1045 | "version": "3.10.17" 1046 | } 1047 | }, 1048 | "nbformat": 4, 1049 | "nbformat_minor": 0 1050 | } 1051 | -------------------------------------------------------------------------------- /data_preprocess.py: 
-------------------------------------------------------------------------------- 1 | import HebrewToEnglish 2 | import pandas as pd 3 | from sklearn.model_selection import train_test_split 4 | from tqdm import tqdm 5 | import argparse 6 | import os 7 | 8 | #### THIS IS AN EXAMPLE OF HOW TO CONVERT THE ROBOSHAUL DATASET TO A LIST FOR THE MODEL #### 9 | 10 | def process_dataset(csv_path, audio_path): 11 | """ 12 | Process Hebrew text-to-speech dataset by converting Hebrew to English and creating train/val splits. 13 | 14 | Args: 15 | csv_path (str): Path to the pipe-separated CSV metadata file (needs 'file_id' and 'transcript' columns) 16 | audio_path (str): Full path to the directory containing audio files 17 | 18 | Writes filelists/train_list.txt and filelists/validation_list.txt, one "wav path|text" entry per line. 19 | """ 20 | # Load the dataset 21 | df = pd.read_csv(csv_path, sep='|', encoding='utf-8') 22 | 23 | # Convert the 'transcript' column to a list 24 | text_list = df['transcript'].tolist() 25 | 26 | # Convert Hebrew text to English transliteration 27 | hebrew_to_english = HebrewToEnglish.HebrewToEnglish 28 | english_text_list = [hebrew_to_english(text) for text in tqdm(text_list, desc="Converting Hebrew to English")] 29 | 30 | list_of_path_and_text = [] 31 | for path, text in tqdm(zip(df['file_id'], english_text_list), desc="Creating path-text pairs", total=len(df)): 32 | # Create full path to audio file 33 | full_audio_path = os.path.join(audio_path, f"{path}.wav") 34 | list_of_path_and_text.append((full_audio_path, text)) 35 | 36 | # Split the (path, text) pairs together so every transcript stays attached to its own recording. 37 | # train_test_split shuffles; splitting only the texts and re-zipping them with df['file_id'] 38 | # in CSV order paired transcripts with the wrong audio files. 39 | train_pairs, val_pairs = train_test_split(list_of_path_and_text, test_size=0.2, random_state=42) 40 | 41 | # Create filelists directory if it doesn't exist 42 | os.makedirs('filelists', exist_ok=True) 43 | 44 | # Save the training set 45 | with open('filelists/train_list.txt', 'w', encoding='utf-8') as f: 46 | for full_audio_path, text in tqdm(train_pairs, desc="Saving training set"): 47 | f.write(f"{full_audio_path}|{text}\n") 48 | 49 | # Save the validation
set 50 | with open('filelists/validation_list.txt', 'w', encoding='utf-8') as f: 51 | for full_audio_path, text in tqdm(val_pairs, desc="Saving validation set"): 52 | f.write(f"{full_audio_path}|{text}\n") 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser(description='Process Hebrew TTS dataset') 56 | parser.add_argument('--csv_path', type=str, default='data/saspeech_gold_standard/metadata_full.csv', 57 | help='Path to the CSV metadata file') 58 | parser.add_argument('--audio_path', type=str, default='data/saspeech_gold_standard/wavs', 59 | help='Full path to the directory containing audio files') 60 | 61 | args = parser.parse_args() 62 | 63 | process_dataset(args.csv_path, args.audio_path) 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /filelists/train_list.txt: -------------------------------------------------------------------------------- 1 | data/saspeech_gold_standard/wavs/gold_000_line_000.npy|hhinehh hhatshoovahh shelo meaz q -------------------------------------------------------------------------------- /filelists/validation_list.txt: -------------------------------------------------------------------------------- 1 | data/saspeech_gold_standard/wavs/gold_015_line_049.npy|beyn hhayeter q lahats mod kaved shebank hhapoaliym q shebatkoofahh hhazot hheviyn shealoviych hhistabeh pohh meal hhatsavavr q -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | from typing import Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import matplotlib.pyplot as plt 10 | from scipy.io.wavfile import write 11 | 12 | from tacotron2.hparams import create_hparams 13 | from
tacotron2.model import Tacotron2 14 | from tacotron2.text import text_to_sequence 15 | from waveglow.denoiser import Denoiser 16 | import HebrewToEnglish 17 | 18 | class TTSInference: 19 | """Text-to-Speech inference class using Tacotron2 and WaveGlow.""" 20 | 21 | def __init__(self, tacotron2_model_path: str, waveglow_model_path: str): 22 | self.tacotron2_model_path = tacotron2_model_path 23 | self.waveglow_model_path = waveglow_model_path 24 | self.model = None 25 | self.waveglow = None 26 | self.denoiser = None 27 | self.hparams = None 28 | 29 | def load_models(self) -> bool: 30 | """Load Tacotron2 and WaveGlow models.""" 31 | try: 32 | # Validate model files exist 33 | if not os.path.exists(self.tacotron2_model_path): 34 | print(f"ERROR: Tacotron2 model not found: {self.tacotron2_model_path}") 35 | return False 36 | if not os.path.exists(self.waveglow_model_path): 37 | print(f"ERROR: WaveGlow model not found: {self.waveglow_model_path}") 38 | return False 39 | 40 | print("Loading Tacotron2 model...") 41 | self.hparams = create_hparams() 42 | self.hparams.sampling_rate = 22050 43 | self.hparams.max_decoder_steps = 1000 44 | self.hparams.gate_threshold = 0.1 45 | 46 | self.model = Tacotron2(self.hparams) 47 | checkpoint = torch.load(self.tacotron2_model_path, map_location='cpu', weights_only=False) 48 | self.model.load_state_dict(checkpoint['state_dict']) 49 | self.model.eval() 50 | 51 | print("Loading WaveGlow model...") 52 | waveglow_checkpoint = torch.load(self.waveglow_model_path, map_location='cpu', weights_only=False) 53 | self.waveglow = waveglow_checkpoint['model'] 54 | self.waveglow.eval() 55 | 56 | for k in self.waveglow.convinv: 57 | k.float() 58 | self.denoiser = Denoiser(self.waveglow) 59 | 60 | print("Models loaded successfully!") 61 | return True 62 | 63 | except Exception as e: 64 | print(f"ERROR: Error loading models: {str(e)}") 65 | return False 66 | 67 | def synthesize_speech(self, text: str, output_dir: str = "output", 68 | sigma: float = 0.8,
raw_input: bool = False) -> list: 69 | """Synthesize speech from text and save audio files.""" 70 | if self.model is None or self.waveglow is None: # load_models() not called, or it failed 71 | print("ERROR: Models not loaded. Call load_models() first.") 72 | return [] 73 | 74 | # Create output directory (parents=True so nested paths such as results/run1 also work) 75 | Path(output_dir).mkdir(parents=True, exist_ok=True) 76 | generated_files = [] 77 | 78 | ARPA = HebrewToEnglish.HebrewToEnglish 79 | 80 | for i, line in enumerate(text.split("\n")): 81 | line = line.strip() 82 | if len(line) < 1: 83 | continue 84 | 85 | print(f"Processing line {i+1}: {line}") 86 | 87 | try: 88 | # Process text 89 | if raw_input: 90 | if not line.endswith(";"): 91 | line = line + ";" 92 | else: 93 | line = ARPA(line) 94 | 95 | print(f"Processed text: {line}") 96 | 97 | # Generate audio 98 | with torch.no_grad(): 99 | sequence = np.array(text_to_sequence(line, ['english_cleaners']))[None, :] 100 | sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long() 101 | 102 | mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence) 103 | audio = self.waveglow.infer(mel_outputs_postnet, sigma=sigma) 104 | 105 | # Save audio 106 | audio_numpy = audio[0].data.cpu().numpy() 107 | filename = f"output_audio_{i+1}_{hash(line) & 0x7FFFFFFF}.wav" # NOTE: str hash is salted per process (PYTHONHASHSEED), so this suffix differs between runs 108 | filepath = os.path.join(output_dir, filename) 109 | 110 | write(filepath, self.hparams.sampling_rate, audio_numpy) 111 | generated_files.append(filepath) 112 | print(f"Audio saved: {filepath}") 113 | 114 | except Exception as e: 115 | print(f"ERROR: Error processing line '{line}': {str(e)}") 116 | continue 117 | 118 | return generated_files 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser(description="Hebrew Text-to-Speech Inference") 122 | parser.add_argument("--text", type=str, default="בָּנַיי", 123 | help="Text to synthesize (default: בָּנַיי)") 124 | parser.add_argument("--text-file", type=str, 125 | help="Path to text file to synthesize") 126 | parser.add_argument("--tacotron2-model", type=str, 127 |
default="checkpoints/shaul_gold_only_with_special.pt", 128 | help="Path to Tacotron2 model") 129 | parser.add_argument("--waveglow-model", type=str, 130 | default="waveglow_weights/waveglow_256channels_universal_v4.pt", 131 | help="Path to WaveGlow model") 132 | parser.add_argument("--output-dir", type=str, default="inference_results", 133 | help="Output directory for audio files") 134 | parser.add_argument("--sigma", type=float, default=0.8, 135 | help="WaveGlow sigma parameter") 136 | parser.add_argument("--raw-input", action="store_true", 137 | help="Use raw input without Hebrew to English conversion") 138 | 139 | args = parser.parse_args() 140 | 141 | # Get input text 142 | if args.text_file: 143 | try: 144 | with open(args.text_file, 'r', encoding='utf-8') as f: 145 | input_text = f.read().strip() 146 | except Exception as e: 147 | print(f"ERROR: Error reading text file: {str(e)}") 148 | sys.exit(1) 149 | else: 150 | input_text = args.text 151 | 152 | if not input_text: 153 | print("ERROR: No input text provided") 154 | sys.exit(1) 155 | 156 | # Initialize TTS inference 157 | tts = TTSInference(args.tacotron2_model, args.waveglow_model) 158 | 159 | # Load models 160 | if not tts.load_models(): 161 | print("ERROR: Failed to load models") 162 | sys.exit(1) 163 | 164 | # Synthesize speech 165 | print(f"Synthesizing speech for text: {input_text}") 166 | generated_files = tts.synthesize_speech( 167 | input_text, 168 | output_dir=args.output_dir, 169 | sigma=args.sigma, 170 | raw_input=args.raw_input 171 | ) 172 | 173 | if generated_files: 174 | print(f"Successfully generated {len(generated_files)} audio files:") 175 | for file in generated_files: 176 | print(f" - {file}") 177 | else: 178 | print("WARNING: No audio files were generated") 179 | 180 | if __name__ == "__main__": 181 | main() -------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | # CUDA 12.4 support 2 | torch==2.5.1 3 | torchaudio==2.5.1 4 | torchvision==0.20.1 5 | --extra-index-url https://download.pytorch.org/whl/cu124 6 | 7 | # Additional dependencies 8 | numpy 9 | tqdm 10 | matplotlib 11 | requests 12 | librosa 13 | inflect 14 | tensorboardX 15 | Hebrew 16 | gdown 17 | unidecode 18 | pandas -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import math 5 | from numpy import finfo 6 | from typing import Optional, Tuple, Any 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | 14 | from tacotron2.model import Tacotron2 15 | from tacotron2.data_utils import TextMelLoader, TextMelCollate 16 | from tacotron2.loss_function import Tacotron2Loss 17 | from tacotron2.logger import Tacotron2Logger 18 | from tacotron2.hparams import create_hparams 19 | from tacotron2.distributed import apply_gradient_allreduce 20 | import tacotron2.layers as layers 21 | from tacotron2.utils import load_wav_to_torch, load_filepaths_and_text 22 | from tacotron2.text import text_to_sequence 23 | 24 | import matplotlib 25 | matplotlib.use('Agg') # Use non-interactive backend 26 | import matplotlib.pyplot as plt 27 | import glob 28 | import requests 29 | import subprocess 30 | from math import e 31 | from tqdm import tqdm 32 | from distutils.dir_util import copy_tree 33 | import torchaudio 34 | 35 | 36 | 37 | 38 | def list_available_checkpoints(checkpoint_dir: str = "checkpoints") -> list: 39 | """List available checkpoint files in the directory.""" 40 | if not os.path.exists(checkpoint_dir): 41 | return [] 42 | 43 | checkpoints = [] 44 | for file in os.listdir(checkpoint_dir): 45 | if file.endswith('.pt') or file.endswith('.pth') or 'checkpoint' in 
file.lower(): 46 | checkpoints.append(os.path.join(checkpoint_dir, file)) 47 | 48 | return sorted(checkpoints) 49 | 50 | 51 | def select_checkpoint(available_checkpoints: list, preferred_checkpoint: str = None) -> str: 52 | """Select a checkpoint from available options.""" 53 | if not available_checkpoints: 54 | return None 55 | 56 | if preferred_checkpoint: 57 | # Look for exact match first 58 | for checkpoint in available_checkpoints: 59 | if preferred_checkpoint in checkpoint: 60 | print(f"Found matching checkpoint: {checkpoint}") 61 | return checkpoint 62 | 63 | print(f"WARNING: Preferred checkpoint '{preferred_checkpoint}' not found") 64 | 65 | # If no preference or not found, list available options 66 | print("Available checkpoints:") 67 | for i, checkpoint in enumerate(available_checkpoints): 68 | print(f" {i+1}: {os.path.basename(checkpoint)}") 69 | 70 | # Return the latest (last in sorted list) by default 71 | selected = available_checkpoints[-1] 72 | print(f"Using latest checkpoint: {os.path.basename(selected)}") 73 | return selected 74 | 75 | def create_mels(hparams) -> None: 76 | """Generate mel spectrograms from audio files.""" 77 | print("Generating Mels") 78 | 79 | stft = layers.TacotronSTFT( 80 | hparams.filter_length, hparams.hop_length, hparams.win_length, 81 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 82 | hparams.mel_fmax) 83 | 84 | def save_mel(filename: str) -> None: 85 | audio, sampling_rate = load_wav_to_torch(filename) 86 | 87 | 88 | if sampling_rate != stft.sampling_rate: 89 | audio = torchaudio.transforms.Resample( 90 | orig_freq=sampling_rate, new_freq=stft.sampling_rate)(audio) 91 | 92 | audio_norm = audio / hparams.max_wav_value 93 | audio_norm = audio_norm.unsqueeze(0) 94 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 95 | melspec = stft.mel_spectrogram(audio_norm) 96 | melspec = torch.squeeze(melspec, 0).cpu().numpy() 97 | np.save(filename.replace('.wav', '.npy'), melspec) 98 | 99 | # 
Collect wav files from training and validation lists 100 | wav_files = set() 101 | 102 | 103 | # Get wav files from training list 104 | training_files = load_filepaths_and_text(hparams.training_files) 105 | for file_info in training_files: 106 | audio_path = file_info[0] 107 | # Convert .npy path back to .wav for mel generation 108 | if audio_path.endswith('.npy'): 109 | wav_path = audio_path.replace('.npy', '.wav') 110 | wav_files.add(wav_path) 111 | elif audio_path.endswith('.wav'): 112 | wav_files.add(audio_path) 113 | 114 | # Get wav files from validation list 115 | validation_files = load_filepaths_and_text(hparams.validation_files) 116 | for file_info in validation_files: 117 | audio_path = file_info[0] 118 | # Convert .npy path back to .wav for mel generation 119 | if audio_path.endswith('.npy'): 120 | wav_path = audio_path.replace('.npy', '.wav') 121 | wav_files.add(wav_path) 122 | elif audio_path.endswith('.wav'): 123 | wav_files.add(audio_path) 124 | 125 | 126 | wav_files = list(wav_files) 127 | 128 | if not wav_files: 129 | print("WARNING: No .wav files found in training/validation lists") 130 | return 131 | 132 | print(f"Found {len(wav_files)} unique .wav files to process") 133 | for wav_file in tqdm(wav_files, desc="Creating mel spectrograms"): 134 | save_mel(wav_file) 135 | 136 | 137 | def reduce_tensor(tensor: torch.Tensor, n_gpus: int) -> torch.Tensor: 138 | """Reduce tensor across GPUs.""" 139 | rt = tensor.clone() 140 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 141 | rt /= n_gpus 142 | return rt 143 | 144 | 145 | def init_distributed(hparams, n_gpus: int, rank: int, group_name: Optional[str]) -> None: 146 | """Initialize distributed training.""" 147 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 
148 | print("Initializing Distributed") 149 | 150 | torch.cuda.set_device(rank % torch.cuda.device_count()) 151 | 152 | dist.init_process_group( 153 | backend=hparams.dist_backend, 154 | init_method=hparams.dist_url, 155 | world_size=n_gpus, 156 | rank=rank, 157 | group_name=group_name 158 | ) 159 | 160 | print("Done initializing distributed") 161 | 162 | 163 | def prepare_dataloaders(hparams) -> Tuple[DataLoader, Any, Any]: 164 | """Prepare training and validation data loaders.""" 165 | trainset = TextMelLoader(hparams.training_files, hparams) 166 | valset = TextMelLoader(hparams.validation_files, hparams) 167 | collate_fn = TextMelCollate(hparams.n_frames_per_step) 168 | 169 | if hparams.distributed_run: 170 | train_sampler = DistributedSampler(trainset) 171 | shuffle = False 172 | else: 173 | train_sampler = None 174 | shuffle = True 175 | 176 | train_loader = DataLoader( 177 | trainset, 178 | num_workers=1, 179 | shuffle=shuffle, 180 | sampler=train_sampler, 181 | batch_size=hparams.batch_size, 182 | pin_memory=False, 183 | drop_last=True, 184 | collate_fn=collate_fn 185 | ) 186 | return train_loader, valset, collate_fn 187 | 188 | 189 | def prepare_directories_and_logger(output_directory: str, log_directory: str, rank: int): 190 | """Prepare output directories and logger.""" 191 | if rank == 0: 192 | os.makedirs(output_directory, exist_ok=True) 193 | os.chmod(output_directory, 0o775) 194 | tacotron_logger = Tacotron2Logger(os.path.join(output_directory, log_directory)) 195 | else: 196 | tacotron_logger = None 197 | return tacotron_logger 198 | 199 | 200 | def load_model(hparams): 201 | """Load and initialize the Tacotron2 model.""" 202 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 203 | model = Tacotron2(hparams).to(device) 204 | if hparams.fp16_run: 205 | model.decoder.attention_layer.score_mask_value = finfo('float16').min 206 | 207 | if hparams.distributed_run: 208 | model = apply_gradient_allreduce(model) 209 | 210 | return model 
211 | 212 | 213 | def warm_start_model(checkpoint_path: str, model, ignore_layers: list): 214 | """Load model weights from checkpoint for warm start.""" 215 | if not os.path.isfile(checkpoint_path): 216 | raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") 217 | 218 | print(f"Warm starting model from checkpoint '{checkpoint_path}'") 219 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 220 | model_dict = checkpoint_dict['state_dict'] 221 | 222 | if ignore_layers: 223 | model_dict = {k: v for k, v in model_dict.items() if k not in ignore_layers} 224 | dummy_dict = model.state_dict() 225 | dummy_dict.update(model_dict) 226 | model_dict = dummy_dict 227 | 228 | model.load_state_dict(model_dict) 229 | return model 230 | 231 | 232 | def load_checkpoint(checkpoint_path: str, model, optimizer) -> Tuple[Any, Any, float, int]: 233 | """Load model, optimizer state and training progress from checkpoint.""" 234 | print(f"Loading checkpoint '{checkpoint_path}'") 235 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 236 | model.load_state_dict(checkpoint_dict['state_dict']) 237 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 238 | learning_rate = checkpoint_dict['learning_rate'] 239 | iteration = checkpoint_dict['iteration'] 240 | print(f"Loaded checkpoint from iteration {iteration}") 241 | return model, optimizer, learning_rate, iteration 242 | 243 | 244 | def save_checkpoint(model, optimizer, learning_rate: float, iteration: int, filepath: str) -> None: 245 | """Save model checkpoint.""" 246 | print(f"Saving model and optimizer state at iteration {iteration} to {filepath}") 247 | try: 248 | checkpoint_data = { 249 | 'iteration': iteration, 250 | 'state_dict': model.state_dict(), 251 | 'optimizer': optimizer.state_dict(), 252 | 'learning_rate': learning_rate 253 | } 254 | torch.save(checkpoint_data, filepath) 255 | print("Model saved successfully") 256 | except KeyboardInterrupt: 257 | print("Interrupt received 
while saving, completing save operation...") 258 | torch.save(checkpoint_data, filepath) 259 | print("Model saved successfully after interrupt") 260 | 261 | 262 | def plot_alignment(alignment: np.ndarray, info: Optional[str] = None, 263 | width: int = 1000, height: int = 600) -> None: 264 | """Plot alignment matrix and save to file.""" 265 | fig, ax = plt.subplots(figsize=(width/100, height/100)) 266 | im = ax.imshow(alignment, cmap='inferno', aspect='auto', origin='lower', 267 | interpolation='none') 268 | ax.autoscale(enable=True, axis="y", tight=True) 269 | fig.colorbar(im, ax=ax) 270 | xlabel = 'Decoder timestep' 271 | if info is not None: 272 | xlabel += '\n\n' + info 273 | plt.xlabel(xlabel) 274 | plt.ylabel('Encoder timestep') 275 | plt.tight_layout() 276 | plt.savefig('alignment.png', dpi=100, bbox_inches='tight') 277 | plt.close() 278 | 279 | 280 | def validate(model, criterion, valset, iteration: int, batch_size: int, n_gpus: int, 281 | collate_fn, tacotron_logger, distributed_run: bool, rank: int, 282 | epoch: int, start_epoch: float, learning_rate: float, hparams) -> None: 283 | """Perform validation.""" 284 | model.eval() 285 | with torch.no_grad(): 286 | val_sampler = DistributedSampler(valset) if distributed_run else None 287 | val_loader = DataLoader( 288 | valset, sampler=val_sampler, num_workers=1, 289 | shuffle=False, batch_size=batch_size, 290 | pin_memory=False, collate_fn=collate_fn 291 | ) 292 | 293 | val_loss = 0.0 294 | for i, batch in enumerate(val_loader): 295 | x, y = model.parse_batch(batch) 296 | y_pred = model(x) 297 | loss = criterion(y_pred, y) 298 | if distributed_run: 299 | reduced_val_loss = reduce_tensor(loss.data, n_gpus).item() 300 | else: 301 | reduced_val_loss = loss.item() 302 | val_loss += reduced_val_loss 303 | val_loss = val_loss / (i + 1) 304 | 305 | model.train() 306 | if rank == 0: 307 | elapsed_time = (time.perf_counter() - start_epoch) / 60 308 | print(f"Epoch: {epoch} Validation loss {iteration}: {val_loss:.9f} 
Time: {elapsed_time:.1f}m LR: {learning_rate:.6f}") 309 | tacotron_logger.log_validation(val_loss, model, y, y_pred, iteration) 310 | 311 | if hparams.show_alignments: 312 | _, mel_outputs, gate_outputs, alignments = y_pred 313 | idx = torch.randint(0, alignments.size(0), (1,)).item() 314 | plot_alignment(alignments[idx].data.cpu().numpy().T) 315 | 316 | 317 | def train(output_directory: str, log_directory: str, checkpoint_path: Optional[str], 318 | warm_start: bool, n_gpus: int, rank: int, group_name: Optional[str], 319 | hparams, log_directory2: Optional[str], checkpoint_folder_path: Optional[str], 320 | preferred_checkpoint: Optional[str]) -> None: 321 | """Main training function.""" 322 | # Set device 323 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 324 | print(f"Using device: {device}") 325 | 326 | if hparams.distributed_run: 327 | init_distributed(hparams, n_gpus, rank, group_name) 328 | 329 | torch.manual_seed(hparams.seed) 330 | if torch.cuda.is_available(): 331 | torch.cuda.manual_seed(hparams.seed) 332 | 333 | model = load_model(hparams) 334 | learning_rate = hparams.learning_rate 335 | optimizer = torch.optim.Adam( 336 | model.parameters(), 337 | lr=learning_rate, 338 | weight_decay=hparams.weight_decay 339 | ) 340 | 341 | if hparams.fp16_run: 342 | try: 343 | from apex import amp 344 | model, optimizer = amp.initialize(model, optimizer, opt_level='O2') 345 | except ImportError: 346 | print("ERROR: apex not available, disabling fp16_run") 347 | hparams.fp16_run = False 348 | 349 | if hparams.distributed_run: 350 | model = apply_gradient_allreduce(model) 351 | 352 | criterion = Tacotron2Loss() 353 | tacotron_logger = prepare_directories_and_logger(output_directory, log_directory, rank) 354 | train_loader, valset, collate_fn = prepare_dataloaders(hparams) 355 | 356 | # Load checkpoint or pretrained model 357 | iteration = 0 358 | epoch_offset = 0 359 | 360 | if checkpoint_path and os.path.isfile(checkpoint_path): 361 | if 
warm_start: 362 | model = warm_start_model(checkpoint_path, model, hparams.ignore_layers) 363 | else: 364 | model, optimizer, _learning_rate, iteration = load_checkpoint( 365 | checkpoint_path, model, optimizer) 366 | if hparams.use_saved_learning_rate: 367 | learning_rate = _learning_rate 368 | iteration += 1 369 | epoch_offset = max(0, int(iteration / len(train_loader))) 370 | elif checkpoint_folder_path: 371 | print(f"Loading checkpoints from local folder: {checkpoint_folder_path}") 372 | available_checkpoints = load_local_checkpoint_folder(checkpoint_folder_path) 373 | 374 | if available_checkpoints: 375 | selected_checkpoint = select_checkpoint(available_checkpoints, preferred_checkpoint) 376 | if selected_checkpoint: 377 | if warm_start: 378 | model = warm_start_model(selected_checkpoint, model, hparams.ignore_layers) 379 | else: 380 | model, optimizer, _learning_rate, iteration = load_checkpoint( 381 | selected_checkpoint, model, optimizer) 382 | if hparams.use_saved_learning_rate: 383 | learning_rate = _learning_rate 384 | iteration += 1 385 | epoch_offset = max(0, int(iteration / len(train_loader))) 386 | else: 387 | print("WARNING: No checkpoints found in specified folder") 388 | else: 389 | print("No checkpoint specified, starting from scratch") 390 | 391 | model.train() 392 | 393 | # Main training loop 394 | for epoch in tqdm(range(epoch_offset, hparams.epochs), desc="Epochs"): 395 | print(f"Starting Epoch: {epoch} Iteration: {iteration}") 396 | start_epoch = time.perf_counter() 397 | 398 | for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc="Batches"): 399 | start = time.perf_counter() 400 | 401 | # Learning rate schedule 402 | if iteration < hparams.decay_start: 403 | learning_rate = hparams.A_ 404 | else: 405 | iteration_adjusted = iteration - hparams.decay_start 406 | learning_rate = (hparams.A_ * (e ** (-iteration_adjusted / hparams.B_))) + hparams.C_ 407 | 408 | learning_rate = max(hparams.min_learning_rate, learning_rate) 
409 | 410 | for param_group in optimizer.param_groups: 411 | param_group['lr'] = learning_rate 412 | 413 | model.zero_grad() 414 | x, y = model.parse_batch(batch) 415 | y_pred = model(x) 416 | loss = criterion(y_pred, y) 417 | 418 | if hparams.distributed_run: 419 | reduced_loss = reduce_tensor(loss.data, n_gpus).item() 420 | else: 421 | reduced_loss = loss.item() 422 | 423 | if hparams.fp16_run: 424 | from apex import amp 425 | with amp.scale_loss(loss, optimizer) as scaled_loss: 426 | scaled_loss.backward() 427 | grad_norm = torch.nn.utils.clip_grad_norm_( 428 | amp.master_params(optimizer), hparams.grad_clip_thresh) 429 | is_overflow = math.isnan(grad_norm) 430 | else: 431 | loss.backward() 432 | grad_norm = torch.nn.utils.clip_grad_norm_( 433 | model.parameters(), hparams.grad_clip_thresh) 434 | is_overflow = False 435 | 436 | optimizer.step() 437 | 438 | if not is_overflow and rank == 0: 439 | duration = time.perf_counter() - start 440 | tacotron_logger.log_training( 441 | reduced_loss, grad_norm, learning_rate, duration, iteration) 442 | 443 | iteration += 1 444 | 445 | validate(model, criterion, valset, iteration, hparams.batch_size, n_gpus, 446 | collate_fn, tacotron_logger, hparams.distributed_run, rank, epoch, 447 | start_epoch, learning_rate, hparams) 448 | 449 | save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path) 450 | 451 | if log_directory2: 452 | try: 453 | copy_tree(log_directory, log_directory2) 454 | except Exception as e: 455 | print(f"WARNING: Failed to copy logs: {e}") 456 | 457 | 458 | def check_dataset(hparams) -> None: 459 | """Check dataset files for common issues.""" 460 | def check_file_array(filelist_arr: list, dataset_type: str) -> None: 461 | print(f"Checking {dataset_type} Files") 462 | for i, file_info in enumerate(filelist_arr): 463 | if len(file_info) > 2: 464 | print(f"WARNING: {file_info} has multiple '|', this may cause issues") 465 | 466 | file_path, text = file_info[0], file_info[1] 467 | 468 | # Check 
file extensions 469 | if hparams.load_mel_from_disk and '.wav' in file_path: 470 | print(f"WARNING: {file_path} is .wav but expecting .npy") 471 | elif not hparams.load_mel_from_disk and '.npy' in file_path: 472 | print(f"WARNING: {file_path} is .npy but expecting .wav") 473 | 474 | # Check file existence 475 | if not os.path.exists(file_path): 476 | print(f"WARNING: {file_path} does not exist") 477 | 478 | # Check text content 479 | if len(text) < 3: 480 | print(f"INFO: {file_path} has very little text: '{text}'") 481 | 482 | if text.strip() and text.strip()[-1] not in "!?,.;:": 483 | print(f"INFO: {file_path} has no ending punctuation") 484 | 485 | # Check training files 486 | try: 487 | training_files = load_filepaths_and_text(hparams.training_files) 488 | check_file_array(training_files, "Training") 489 | except Exception as e: 490 | print(f"ERROR: Error checking training files: {e}") 491 | 492 | # Check validation files 493 | try: 494 | validation_files = load_filepaths_and_text(hparams.validation_files) 495 | check_file_array(validation_files, "Validation") 496 | except Exception as e: 497 | print(f"ERROR: Error checking validation files: {e}") 498 | 499 | print("Finished dataset check") 500 | 501 | 502 | def update_filelists_for_npy() -> None: 503 | """Update filelist files to use .npy extensions instead of .wav.""" 504 | filelist_files = glob.glob("filelists/*.txt") 505 | 506 | for filepath in filelist_files: 507 | try: 508 | with open(filepath, 'r', encoding='utf-8') as file: 509 | content = file.read() 510 | 511 | # Replace .wav| with .npy| 512 | modified_content = content.replace('.wav|', '.npy|') 513 | 514 | with open(filepath, 'w', encoding='utf-8') as file: 515 | file.write(modified_content) 516 | 517 | print(f"Updated {filepath} to use .npy extensions") 518 | except Exception as e: 519 | print(f"ERROR: Error updating {filepath}: {e}") 520 | 521 | 522 | def load_local_checkpoint_folder(folder_path: str) -> list: 523 | """Load checkpoint files from a 
local folder.""" 524 | if not os.path.exists(folder_path): 525 | print(f"WARNING: Checkpoint folder '{folder_path}' does not exist") 526 | return [] 527 | 528 | checkpoints = [] 529 | for file in os.listdir(folder_path): 530 | if file.endswith('.pt') or file.endswith('.pth') or 'checkpoint' in file.lower(): 531 | checkpoints.append(os.path.join(folder_path, file)) 532 | 533 | if checkpoints: 534 | print(f"Found {len(checkpoints)} checkpoint(s) in '{folder_path}'") 535 | else: 536 | print(f"WARNING: No checkpoint files found in '{folder_path}'") 537 | 538 | return sorted(checkpoints) 539 | 540 | def main(args: argparse.Namespace) -> None: 541 | """Main function.""" 542 | # Create hyperparameters first 543 | hparams = create_hparams() 544 | 545 | # Configure hyperparameters 546 | model_filename = 'current_model.pt' # Add .pt extension 547 | hparams.training_files = args.training_files 548 | hparams.validation_files = args.validation_files 549 | hparams.p_attention_dropout = 0.1 550 | hparams.p_decoder_dropout = 0.1 551 | hparams.decay_start = 15000 552 | hparams.A_ = 5e-4 553 | hparams.B_ = 8000 554 | hparams.C_ = 0 555 | hparams.min_learning_rate = 1e-5 556 | hparams.show_alignments = True 557 | hparams.batch_size = 32 558 | hparams.load_mel_from_disk = True 559 | hparams.ignore_layers = [] 560 | hparams.epochs = 10000 561 | 562 | # Generate mels if requested (before updating filelists) 563 | if args.generate_mels: 564 | create_mels(hparams) 565 | 566 | # Update filelists to use .npy after mel generation 567 | update_filelists_for_npy() 568 | 569 | # Set CUDA backend settings 570 | torch.backends.cudnn.enabled = hparams.cudnn_enabled 571 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark 572 | 573 | # Set checkpoint path 574 | checkpoint_path = args.checkpoint_path or os.path.join(args.output_directory, model_filename) 575 | 576 | # Check dataset 577 | check_dataset(hparams) 578 | 579 | # Start training 580 | train( 581 | 
output_directory=args.output_directory, 582 | log_directory=args.log_directory, 583 | checkpoint_path=checkpoint_path, 584 | warm_start=args.warm_start, 585 | n_gpus=args.n_gpus, 586 | rank=args.rank, 587 | group_name=args.group_name, 588 | hparams=hparams, 589 | log_directory2=args.log_directory2, 590 | checkpoint_folder_path=args.checkpoint_folder_path, 591 | preferred_checkpoint=args.preferred_checkpoint 592 | ) 593 | 594 | 595 | if __name__ == "__main__": 596 | parser = argparse.ArgumentParser(description='Tacotron2 Training') 597 | parser.add_argument('--warm_start', action='store_true', default=False, 598 | help='Warm start model from checkpoint') 599 | parser.add_argument('--n_gpus', type=int, default=1, 600 | help='Number of GPUs to use for training') 601 | parser.add_argument('--rank', type=int, default=0, 602 | help='Rank of the current GPU') 603 | parser.add_argument('--num_workers', type=int, default=1, 604 | help='Number of data loading workers') 605 | parser.add_argument('--group_name', type=str, default=None, 606 | help='Name of the distributed group') 607 | parser.add_argument('--output_directory', type=str, default='checkpoints', 608 | help='Directory to save checkpoints') 609 | parser.add_argument('--log_directory', type=str, default='logs', 610 | help='Directory to save tensorboard logs') 611 | parser.add_argument('--log_directory2', type=str, default=None, 612 | help='Directory to copy tensorboard logs after each epoch') 613 | parser.add_argument('--checkpoint_path', type=str, default=None, 614 | help='Path to the checkpoint file') 615 | parser.add_argument('--checkpoint_folder_path', type=str, default=None, 616 | help='Local folder path containing checkpoints') 617 | parser.add_argument('--preferred_checkpoint', type=str, default=None, 618 | help='Preferred checkpoint filename to load (if empty, uses latest)') 619 | parser.add_argument('--training_files', type=str, default='filelists/train_list.txt', 620 | help='Path to training files list') 
621 | parser.add_argument('--validation_files', type=str, default='filelists/validation_list.txt', 622 | help='Path to validation files list') 623 | parser.add_argument('--generate_mels', action='store_true', default=True, 624 | help='Generate mel spectrograms from audio files') 625 | 626 | args = parser.parse_args() 627 | main(args) --------------------------------------------------------------------------------