├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── _note.ipynb ├── audio ├── __init__.py ├── audio_processing.py ├── stft.py └── tools.py ├── config └── kss │ ├── model.yaml │ ├── preprocess.yaml │ └── train.yaml ├── conformer ├── LICENSE ├── README.md ├── __init__.py ├── conformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── embedding.py │ ├── encoder.py │ ├── feed_forward.py │ ├── model.py │ └── modules.py ├── docs │ ├── .DS_Store │ ├── Makefile │ ├── Model.html │ ├── Modules.html │ ├── Submodules.html │ ├── _sources │ │ ├── Model.rst.txt │ │ ├── Modules.rst.txt │ │ ├── Submodules.rst.txt │ │ └── index.rst.txt │ ├── _static │ │ ├── basic.css │ │ ├── css │ │ │ ├── badge_only.css │ │ │ ├── fonts │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.svg │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ ├── lato-normal.woff │ │ │ │ └── lato-normal.woff2 │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── jquery-3.5.1.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── badge_only.js │ │ │ ├── html5shiv-printshiv.min.js │ │ │ ├── html5shiv.min.js │ │ │ └── theme.js │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── underscore-1.3.1.js │ │ └── underscore.js │ ├── genindex.html │ ├── index.html │ ├── make.bat │ ├── objects.inv │ ├── search.html │ ├── searchindex.js │ └── source │ │ ├── Model.rst │ │ ├── Modules.rst │ │ ├── Submodules.rst │ │ ├── conf.py │ │ └── index.rst └── setup.py ├── data_utils.py ├── evaluate.py ├── mel_processing.py ├── model ├── __init__.py ├── cvaejets.py ├── layers.py ├── loss.py └── modules.py ├── preprocess.py ├── preprocessed_data └── kss │ ├── speakers.json │ └── stats.json ├── requirements.txt ├── samples ├── CVAEJETS-sample-0.75.wav ├── CVAEJETS-sample-1.00-vc.wav ├── CVAEJETS-sample-1.00.wav ├── CVAEJETS-sample-1.50.wav ├── CVAEJETS-tensorboard-losses1.png ├── CVAEJETS-tensorboard-losses2.png └── CVAEJETS-tensorboard-stats.png ├── text └── __init__.py ├── train.py └── utils ├── model.py ├── pitch_utils.py ├── stft_loss.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel 2 | 3 | COPY ./requirements.txt /tmp/requirements.txt 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | ENV TZ=Asia/Seoul 6 | # RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 7 | RUN apt-get update && \ 8 | apt-get -y -qq update && \ 9 | apt-get install -y apt-utils && \ 10 | apt-get install -y curl && \ 11 | apt-get install -y tzdata && \ 12 | apt-get install -y python3-ipykernel && \ 13 | apt-get install -y espeak && \ 14 | apt-get install -y openssh-client && \ 15 | apt-get install -y curl && \ 16 | apt-get install -y git && \ 17 | apt-get install -y vim && \ 18 | apt-get install -y libsndfile-dev && \ 19 | apt-get install -y gcc && \ 20 | apt-get install -y ffmpeg && \ 21 | apt-get install -y locales && \ 22 | apt-get install -y language-pack-ko && \ 23 | apt-get install -y default-jre && \ 24 | apt-get install -y screen && \ 25 | apt-get install -y zip && \ 26 | apt-get install -y unzip && \ 27 | apt-get install -y sshfs 28 | 29 | RUN pip install --upgrade pip 30 | RUN /bin/bash -c "bash <(curl -s https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh) &&\ 31 | pip install -r /tmp/requirements.txt --ignore-installed" 32 | RUN pip install --upgrade tensorboard 33 | RUN pip install jupyter -U 34 | 35 | VOLUME /home/work 36 | WORKDIR /home/work 37 | 38 | CMD ["/bin/bash"] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT 
License 2 | 3 | Copyright (c) 2022 choihk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 1. Using the FastSpeech2, HiFi-GAN, VITS, and Conformer open-source projects, this repository provides a simple implementation of JETS (End-To-End) and trains it quickly on the Korean dataset (KSS). 3 | 2. For adversarial training, the discriminator modules from VITS are used as-is. 4 | 3. For effective alignment learning, blank tokens are added inside the text sequence. 5 | 4. In this repository, using the L1 reconstruction loss proposed by HiFi-GAN (log-mel magnitude only) as-is causes issues with the adversarial loss, so it is replaced with an STFT loss in which a log STFT magnitude term and an L1 norm are computed together. 6 | 5. For extensibility, the decoder of the original FastSpeech2 architecture is replaced with the Normalizing Flows (CouplingLayer) of VITS, so a Posterior Encoder is used as well (for quality improvement and voice conversion). 7 | 6. The original Posterior Encoder takes a linear spectrogram as input, but this repository uses a mel spectrogram instead. 8 | 7. The original open-source projects train on MFA-based preprocessed data, whereas this repository trains with alignment learning and feeds training data through data_utils.py to avoid the disk-space overhead such preprocessing can cause. 9 | 8. A conda environment would also work, but this repository only provides a Docker environment; it is assumed that docker and nvidia-docker are installed on Ubuntu. 10 | 9. Depending on your GPU and CUDA version, the torch image at the top of the Dockerfile may need to be modified. 11 | 10. The preprocessing step only extracts the transcripts and stats needed for training. 12 | 11. No other preprocessing is required. 13 | 12. Compared with the previous repository, [VAEJETS](https://github.com/choiHkk/VAEJETS), the model is more powerful and training time is reduced. 14 | 13. Because training is End-To-End and adversarial, a large amount of training is required to generate high-quality audio. 15 | 16 | ## Dataset 17 | 1. download dataset - https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset 18 | 2. `unzip /path/to/the/kss.zip -d /path/to/the/kss` 19 | 3. `mkdir /path/to/the/CVAEJETS/data/dataset` 20 | 4. `mv /path/to/the/kss.zip /path/to/the/CVAEJETS/data/dataset` 21 | 22 | ## Docker build 23 | 1. `cd /path/to/the/CVAEJETS` 24 | 2. `docker build --tag cvaejets:latest .` 25 | 26 | ## Training 27 | 1. `nvidia-docker run -it --name 'CVAEJETS' -v /path/to/CVAEJETS:/home/work/CVAEJETS --ipc=host --privileged cvaejets:latest` 28 | 2. `cd /home/work/CVAEJETS` 29 | 5. `ln -s /home/work/CVAEJETS/data/dataset/kss` 30 | 6. `python preprocess.py ./config/kss/preprocess.yaml` 31 | 7. `python train.py -p ./config/kss/preprocess.yaml -m ./config/kss/model.yaml -t ./config/kss/train.yaml` 32 | 8.
`python train.py --restore_step -p ./config/kss/preprocess.yaml -m ./config/kss/model.yaml -t ./config/kss/train.yaml` 33 | 9. arguments 34 | * -p : preprocess config path 35 | * -m : model config path 36 | * -t : train config path 37 | 10. (OPTIONAL) `tensorboard --logdir=outdir/logdir` 38 | 39 | ## Tensorboard losses 40 | ![CVAEJETS-tensorboard-losses1](https://user-images.githubusercontent.com/69423543/185771913-20621fca-c0fb-4e41-93f4-905e2ffaa13e.png) 41 | ![CVAEJETS-tensorboard-losses2](https://user-images.githubusercontent.com/69423543/185771915-65a16463-91c5-4030-ad46-379cc420de1a.png) 42 | 43 | 44 | ## Tensorboard Stats 45 | ![CVAEJETS-tensorboard-stats](https://user-images.githubusercontent.com/69423543/185771918-f14d33d2-e2f3-4bfe-bd66-a7ac9523edac.png) 46 | 47 | 48 | ## Reference 49 | 1. [VAENAR-TTS: Variational Auto-Encoder based Non-AutoRegressive Text-to-Speech Synthesis](https://arxiv.org/abs/2107.03298) 50 | 2. [JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech](https://arxiv.org/abs/2203.16852) 51 | 3. [Comprehensive-Transformer-TTS](https://github.com/keonlee9420/Comprehensive-Transformer-TTS) 52 | 4. [Comprehensive-E2E-TTS](https://github.com/keonlee9420/Comprehensive-E2E-TTS) 53 | 5. [Conformer](https://github.com/sooftware/conformer) - [paper](https://arxiv.org/abs/2005.08100) 54 | 6. [FastSpeech2](https://github.com/ming024/FastSpeech2) 55 | 7. [HiFi-GAN](https://github.com/jik876/hifi-gan) 56 | 8. [VAEJETS](https://github.com/choiHkk/VAEJETS) 57 | 9. [VITS](https://github.com/jaywalnut310/vits) 58 | -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 
39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from scipy.signal import get_window 7 | from librosa.util import pad_center, tiny 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | from audio.audio_processing import ( 11 | dynamic_range_compression, 12 | dynamic_range_decompression, 13 | window_sumsquare, 14 | ) 15 | 16 | 17 | class STFT(torch.nn.Module): 18 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 19 | 20 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 21 | super(STFT, self).__init__() 22 | self.filter_length = filter_length 23 | self.hop_length = hop_length 24 | self.win_length = win_length 25 | self.window = window 26 | self.forward_transform = None 27 | scale = self.filter_length / self.hop_length 28 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 29 | 30 | cutoff = int((self.filter_length / 2 + 1)) 31 | fourier_basis = np.vstack( 32 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 33 | ) 34 | 35 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 36 | inverse_basis = torch.FloatTensor( 37 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 38 | ) 39 | 40 | if window is not None: 41 | assert filter_length >= win_length 42 | # get window and zero center pad it to filter_length 43 | fft_window = get_window(window, win_length, fftbins=True) 44 | fft_window = pad_center(fft_window, filter_length) 45 | fft_window = torch.from_numpy(fft_window).float() 46 | 
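            # note (added): `forward_basis` stacks the real and imaginary DFT components as conv
            # filters and `inverse_basis` is their (scaled) pseudo-inverse; after applying the
            # zero-padded window below, transform() computes a windowed STFT with F.conv1d and
            # inverse() reconstructs the signal with F.conv_transpose1d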
47 | # window the bases 48 | forward_basis *= fft_window 49 | inverse_basis *= fft_window 50 | 51 | self.register_buffer("forward_basis", forward_basis.float()) 52 | self.register_buffer("inverse_basis", inverse_basis.float()) 53 | 54 | def transform(self, input_data): 55 | num_batches = input_data.size(0) 56 | num_samples = input_data.size(1) 57 | 58 | self.num_samples = num_samples 59 | 60 | # similar to librosa, reflect-pad the input 61 | input_data = input_data.view(num_batches, 1, num_samples) 62 | input_data = F.pad( 63 | input_data.unsqueeze(1), 64 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 65 | mode="reflect", 66 | ) 67 | input_data = input_data.squeeze(1) 68 | 69 | forward_transform = F.conv1d( 70 | input_data,#.cuda(), 71 | torch.autograd.Variable(self.forward_basis, requires_grad=False),#.cuda(), 72 | stride=self.hop_length, 73 | padding=0, 74 | ).cpu() 75 | 76 | cutoff = int((self.filter_length / 2) + 1) 77 | real_part = forward_transform[:, :cutoff, :] 78 | imag_part = forward_transform[:, cutoff:, :] 79 | 80 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 81 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 82 | 83 | return magnitude, phase 84 | 85 | def inverse(self, magnitude, phase): 86 | recombine_magnitude_phase = torch.cat( 87 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 88 | ) 89 | 90 | inverse_transform = F.conv_transpose1d( 91 | recombine_magnitude_phase, 92 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 93 | stride=self.hop_length, 94 | padding=0, 95 | ) 96 | 97 | if self.window is not None: 98 | window_sum = window_sumsquare( 99 | self.window, 100 | magnitude.size(-1), 101 | hop_length=self.hop_length, 102 | win_length=self.win_length, 103 | n_fft=self.filter_length, 104 | dtype=np.float32, 105 | ) 106 | # remove modulation effects 107 | approx_nonzero_indices = torch.from_numpy( 108 | np.where(window_sum > tiny(window_sum))[0] 109 | ) 110 | window_sum = torch.autograd.Variable( 111 | torch.from_numpy(window_sum), requires_grad=False 112 | ) 113 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 114 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 115 | approx_nonzero_indices 116 | ] 117 | 118 | # scale by hop ratio 119 | inverse_transform *= float(self.filter_length) / self.hop_length 120 | 121 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 122 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 123 | 124 | return inverse_transform 125 | 126 | def forward(self, input_data): 127 | self.magnitude, self.phase = self.transform(input_data) 128 | reconstruction = self.inverse(self.magnitude, self.phase) 129 | return reconstruction 130 | 131 | 132 | class TacotronSTFT(torch.nn.Module): 133 | def __init__( 134 | self, 135 | filter_length, 136 | hop_length, 137 | win_length, 138 | n_mel_channels, 139 | sampling_rate, 140 | mel_fmin, 141 | mel_fmax 142 | ): 143 | super(TacotronSTFT, self).__init__() 144 | self.n_mel_channels = n_mel_channels 145 | self.sampling_rate = sampling_rate 146 | self.stft_fn = STFT(filter_length, hop_length, win_length) 147 | mel_basis = librosa_mel_fn( 148 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 149 | ) 150 | mel_basis = torch.from_numpy(mel_basis).float() 151 | self.register_buffer("mel_basis", mel_basis) 152 | 153 | def spectral_normalize(self, magnitudes): 154 | output = dynamic_range_compression(magnitudes) 155 | 
return output 156 | 157 | def spectral_de_normalize(self, magnitudes): 158 | output = dynamic_range_decompression(magnitudes) 159 | return output 160 | 161 | def mel_spectrogram(self, y): 162 | """Computes mel-spectrograms from a batch of waves 163 | PARAMS 164 | ------ 165 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 166 | 167 | RETURNS 168 | ------- 169 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 170 | """ 171 | assert torch.min(y.data) >= -1 172 | assert torch.max(y.data) <= 1 173 | 174 | magnitudes, phases = self.stft_fn.transform(y.float()) 175 | magnitudes = magnitudes.data 176 | mel_output = torch.matmul( 177 | self.mel_basis.to(dtype=y.dtype, device=y.device), 178 | magnitudes.to(dtype=y.dtype, device=y.device) 179 | ) 180 | mel_output = self.spectral_normalize(mel_output) 181 | energy = torch.norm(magnitudes, dim=1) 182 | 183 | return mel_output, energy 184 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | 5 | from audio.audio_processing import griffin_lim 6 | 7 | 8 | def get_mel_from_wav(audio, _stft): 9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 10 | audio = torch.autograd.Variable(audio, requires_grad=False) 11 | melspec, energy = _stft.mel_spectrogram(audio) 12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 14 | 15 | return melspec, energy 16 | 17 | 18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 19 | mel = torch.stack([mel]) 20 | mel_decompress = _stft.spectral_de_normalize(mel) 21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 22 | spec_from_mel_scaling = 1000 23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 26 | 27 | audio = griffin_lim( 28 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters 29 | ) 30 | 31 | audio = audio.squeeze() 32 | audio = audio.cpu().numpy() 33 | audio_path = out_filename 34 | write(audio_path, _stft.sampling_rate, audio) 35 | -------------------------------------------------------------------------------- /config/kss/model.yaml: -------------------------------------------------------------------------------- 1 | speaker_encoder: 2 | speaker_encoder_hidden: 256 3 | 4 | transformer: 5 | encoder_layer: 4 6 | encoder_head: 2 7 | encoder_hidden: 192 8 | feed_forward_expansion_factor: 4 9 | conv_expansion_factor: 2 10 | input_dropout_p: 0.1 11 | feed_forward_dropout_p: 0.1 12 | attention_dropout_p: 0.1 13 | conv_dropout_p: 0.1 14 | conv_kernel_size: 31 15 | half_step_residual: True 16 | 17 | posterior_encoder: 18 | posterior_encoder_kernel_size: 5 19 | posterior_encoder_dilation_rate: 1 20 | posterior_encoder_n_layers: 16 21 | 22 | residual_coupling_block: 23 | residual_coupling_block_kernel_size: 5 24 | residual_coupling_block_dilation_rate: 1 25 | residual_coupling_block_n_layers: 4 26 | residual_coupling_block_n_flows: 4 27 | 28 | variance_predictor: 29 | filter_size: 256 30 | kernel_size: 3 31 | dropout: 0.5 32 | cwt_hidden_size: 128 33 | cwt_std_scale: 0.8 34 | cwt_out_dims: 11 35 | cwt_stats_out_dims: 2 36 | 37 | variance_embedding: 38 | energy_quantization: "linear" # support 'linear' or 
'log', 'log' is allowed only if the energy values are not normalized during preprocessing 39 | n_bins: 256 40 | 41 | generator: 42 | resblock: "1" 43 | segment_size: 8192 44 | generator_hidden: 192 45 | upsample_rates: [8,8,2,2] 46 | upsample_kernel_sizes: [16,16,4,4] 47 | upsample_initial_channel: 512 48 | resblock_kernel_sizes: [3,7,11] 49 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 50 | 51 | temperature: 0.0005 52 | 53 | max_seq_len: 1000 54 | -------------------------------------------------------------------------------- /config/kss/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "kss" 2 | 3 | path: 4 | corpus_path: "/home/work/CVAEJETS/kss" 5 | raw_path: "/home/work/CVAEJETS/kss/transcript.v.1.2.txt" 6 | preprocessed_path: "./preprocessed_data/kss" 7 | training_files: "./preprocessed_data/kss/train.txt" 8 | validation_files: "./preprocessed_data/kss/val.txt" 9 | 10 | preprocessing: 11 | val_size: 512 12 | duration: 13 | beta_binomial_scaling_factor: 1 14 | text: 15 | text_cleaners: ["korean_cleaners"] 16 | language: "ko" 17 | use_intersperse: True 18 | audio: 19 | trim_top_db: 35 20 | trim_frame_length: 6000 21 | trim_hop_length: 200 22 | sampling_rate: 22050 23 | max_wav_value: 32768.0 24 | stft: 25 | filter_length: 1024 26 | hop_length: 256 27 | win_length: 1024 28 | mel: 29 | n_mel_channels: 80 30 | mel_fmin: 0 31 | mel_fmax: 11025 32 | pitch: 33 | feature: "frame_level" 34 | normalization: True 35 | cwt_scales: -1 36 | pitch_norm: "log" 37 | pitch_norm_eps: 0.000000001 38 | use_uv: True 39 | energy: 40 | feature: "frame_level" 41 | normalization: True 42 | -------------------------------------------------------------------------------- /config/kss/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/kss" 3 | log_path: "./output/log/kss" 4 | result_path: "./output/result/kss" 5 | optimizer: 6 | batch_size: 16 7 | betas: [0.8, 0.99] 8 | eps: 0.000000001 9 | learning_rate: 0.0002 10 | lr_decay: 0.999 11 | step: 12 | total_step: 300000 13 | log_step: 1 14 | val_step: 1000 15 | save_step: 1000 16 | duration: 17 | binarization_start_steps: 6000 18 | binarization_loss_enable_steps: 18000 19 | binarization_loss_warmup_steps: 10000 20 | loss: 21 | stft: 22 | c_stft: 20 23 | fft_sizes: [1024, 2048, 512] 24 | hop_sizes: [128, 256, 64] 25 | win_lengths: [1024, 2048, 512] -------------------------------------------------------------------------------- /conformer/README.md: -------------------------------------------------------------------------------- 1 |

6 | **PyTorch implementation of Conformer: Convolution-augmented Transformer for Speech Recognition.**
11 | ***

14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | Transformer models are good at capturing content-based global interactions, while CNNs exploit local features effectively. Conformer combine convolution neural networks and transformers to model both local and global dependencies of an audio sequence in a parameter-efficient way. Conformer significantly outperforms the previous Transformer and CNN based models achieving state-of-the-art accuracies. 32 | 33 | 34 | 35 | This repository contains only model code, but you can train with conformer at [openspeech](https://github.com/openspeech-team/openspeech) 36 | 37 | ## Installation 38 | This project recommends Python 3.7 or higher. 39 | We recommend creating a new virtual environment for this project (using virtual env or conda). 40 | 41 | ### Prerequisites 42 | * Numpy: `pip install numpy` (Refer [here](https://github.com/numpy/numpy) for problem installing Numpy). 43 | * Pytorch: Refer to [PyTorch website](http://pytorch.org/) to install the version w.r.t. your environment. 44 | 45 | ### Install from source 46 | Currently we only support installation from source code using setuptools. Checkout the source code and run the 47 | following commands: 48 | 49 | ``` 50 | pip install -e . 51 | ``` 52 | 53 | ## Usage 54 | 55 | ```python 56 | import torch 57 | import torch.nn as nn 58 | from conformer import Conformer 59 | 60 | batch_size, sequence_length, dim = 3, 12345, 80 61 | 62 | cuda = torch.cuda.is_available() 63 | device = torch.device('cuda' if cuda else 'cpu') 64 | 65 | criterion = nn.CTCLoss().to(device) 66 | 67 | inputs = torch.rand(batch_size, sequence_length, dim).to(device) 68 | input_lengths = torch.LongTensor([12345, 12300, 12000]) 69 | targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2], 70 | [1, 3, 3, 3, 3, 3, 4, 5, 2, 0], 71 | [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device) 72 | target_lengths = torch.LongTensor([9, 8, 7]) 73 | 74 | model = Conformer(num_classes=10, 75 | input_dim=dim, 76 | encoder_dim=32, 77 | num_encoder_layers=3).to(device) 78 | 79 | # Forward propagate 80 | outputs, output_lengths = model(inputs, input_lengths) 81 | 82 | # Calculate CTC Loss 83 | loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths) 84 | ``` 85 | 86 | ## Troubleshoots and Contributing 87 | If you have any questions, bug reports, and feature requests, please [open an issue](https://github.com/sooftware/conformer/issues) on github or 88 | contacts sh951011@gmail.com please. 89 | 90 | I appreciate any kind of feedback or contribution. Feel free to proceed with small issues like bug fixes, documentation improvement. For major contributions and new features, please discuss with the collaborators in corresponding issues. 91 | 92 | ## Code Style 93 | I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. Especially the style of docstrings is important to generate documentation. 
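## Using the encoder in CVAEJETS

In CVAEJETS only the encoder stack of this package is used as the text encoder, configured from the `transformer` block of `config/kss/model.yaml` shown earlier. The sketch below is illustrative rather than the actual wiring in `model/cvaejets.py`; the import path and the choice of `input_dim` are assumptions.

```python
import torch

# vendored package layout in this repository (assumed import path)
from conformer.conformer.encoder import ConformerEncoder

# values copied from the `transformer` section of config/kss/model.yaml
encoder = ConformerEncoder(
    input_dim=192,                      # assumed equal to the phoneme-embedding size (encoder_hidden)
    encoder_dim=192,                    # encoder_hidden
    num_layers=4,                       # encoder_layer
    num_attention_heads=2,              # encoder_head
    feed_forward_expansion_factor=4,
    conv_expansion_factor=2,
    input_dropout_p=0.1,
    feed_forward_dropout_p=0.1,
    attention_dropout_p=0.1,
    conv_dropout_p=0.1,
    conv_kernel_size=31,
    half_step_residual=True,
)

x = torch.randn(2, 57, 192)             # (batch, text_length, input_dim)
lengths = torch.LongTensor([57, 50])
outputs, output_lengths = encoder(x, lengths)
print(outputs.shape)                    # torch.Size([2, 57, 192])
```

Because the convolution subsampling front-end is disabled in this copy of the encoder (see `encoder.py`), the time dimension and the input lengths pass through unchanged.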
94 | 95 | ## Reference 96 | - [Conformer: Convolution-augmented Transformer for Speech Recognition](https://arxiv.org/pdf/2005.08100.pdf) 97 | - [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) 98 | - [kimiyoung/transformer-xl](https://github.com/kimiyoung/transformer-xl) 99 | - [espnet/espnet](https://github.com/espnet/espnet) 100 | 101 | ## Author 102 | 103 | * Soohwan Kim [@sooftware](https://github.com/sooftware) 104 | * Contacts: sh951011@gmail.com 105 | -------------------------------------------------------------------------------- /conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .conformer.model import Conformer 16 | -------------------------------------------------------------------------------- /conformer/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # from .model import Conformer 16 | -------------------------------------------------------------------------------- /conformer/conformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | 18 | 19 | class Swish(nn.Module): 20 | """ 21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied 22 | to a variety of challenging domains such as Image classification and Machine translation. 
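    In functional form (matching the forward pass below): Swish(x) = x * sigmoid(x).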
23 | """ 24 | def __init__(self): 25 | super(Swish, self).__init__() 26 | 27 | def forward(self, inputs: Tensor) -> Tensor: 28 | return inputs * inputs.sigmoid() 29 | 30 | 31 | class GLU(nn.Module): 32 | """ 33 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing 34 | in the paper “Language Modeling with Gated Convolutional Networks” 35 | """ 36 | def __init__(self, dim: int) -> None: 37 | super(GLU, self).__init__() 38 | self.dim = dim 39 | 40 | def forward(self, inputs: Tensor) -> Tensor: 41 | outputs, gate = inputs.chunk(2, dim=self.dim) 42 | return outputs * gate.sigmoid() 43 | -------------------------------------------------------------------------------- /conformer/conformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch import Tensor 20 | from typing import Optional 21 | 22 | from .embedding import PositionalEncoding 23 | from .modules import Linear 24 | 25 | 26 | class RelativeMultiHeadAttention(nn.Module): 27 | """ 28 | Multi-head attention with relative positional encoding. 29 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" 30 | 31 | Args: 32 | d_model (int): The dimension of model 33 | num_heads (int): The number of attention heads. 34 | dropout_p (float): probability of dropout 35 | 36 | Inputs: query, key, value, pos_embedding, mask 37 | - **query** (batch, time, dim): Tensor containing query vector 38 | - **key** (batch, time, dim): Tensor containing key vector 39 | - **value** (batch, time, dim): Tensor containing value vector 40 | - **pos_embedding** (batch, time, dim): Positional embedding tensor 41 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 42 | 43 | Returns: 44 | - **outputs**: Tensor produces by relative multi head attention module. 45 | """ 46 | def __init__( 47 | self, 48 | d_model: int = 512, 49 | num_heads: int = 16, 50 | dropout_p: float = 0.1, 51 | ): 52 | super(RelativeMultiHeadAttention, self).__init__() 53 | assert d_model % num_heads == 0, "d_model % num_heads should be zero." 
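        # each of the num_heads heads attends over d_head = d_model // num_heads dimensions;
        # u_bias / v_bias below are the Transformer-XL content and position bias terms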
54 | self.d_model = d_model 55 | self.d_head = int(d_model / num_heads) 56 | self.num_heads = num_heads 57 | self.sqrt_dim = math.sqrt(d_model) 58 | 59 | self.query_proj = Linear(d_model, d_model) 60 | self.key_proj = Linear(d_model, d_model) 61 | self.value_proj = Linear(d_model, d_model) 62 | self.pos_proj = Linear(d_model, d_model, bias=False) 63 | 64 | self.dropout = nn.Dropout(p=dropout_p) 65 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 66 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 67 | torch.nn.init.xavier_uniform_(self.u_bias) 68 | torch.nn.init.xavier_uniform_(self.v_bias) 69 | 70 | self.out_proj = Linear(d_model, d_model) 71 | 72 | def forward( 73 | self, 74 | query: Tensor, 75 | key: Tensor, 76 | value: Tensor, 77 | pos_embedding: Tensor, 78 | mask: Optional[Tensor] = None, 79 | ) -> Tensor: 80 | batch_size = value.size(0) 81 | 82 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) 83 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 84 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 85 | pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) 86 | 87 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) 88 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)) 89 | pos_score = self._relative_shift(pos_score) 90 | 91 | score = (content_score + pos_score) / self.sqrt_dim 92 | 93 | if mask is not None: 94 | mask = mask.unsqueeze(1) 95 | score.masked_fill_(mask, -1e9) 96 | 97 | attn = F.softmax(score, -1) 98 | attn = self.dropout(attn) 99 | 100 | context = torch.matmul(attn, value).transpose(1, 2) 101 | context = context.contiguous().view(batch_size, -1, self.d_model) 102 | 103 | return self.out_proj(context) 104 | 105 | def _relative_shift(self, pos_score: Tensor) -> Tensor: 106 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() 107 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1) 108 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1) 109 | 110 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) 111 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) 112 | 113 | return pos_score 114 | 115 | 116 | class MultiHeadedSelfAttentionModule(nn.Module): 117 | """ 118 | Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, 119 | the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention 120 | module to generalize better on different input length and the resulting encoder is more robust to the variance of 121 | the utterance length. Conformer use prenorm residual units with dropout which helps training 122 | and regularizing deeper models. 123 | 124 | Args: 125 | d_model (int): The dimension of model 126 | num_heads (int): The number of attention heads. 127 | dropout_p (float): probability of dropout 128 | 129 | Inputs: inputs, mask 130 | - **inputs** (batch, time, dim): Tensor containing input vector 131 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 132 | 133 | Returns: 134 | - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. 
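
    In equation form (matching the forward pass below), the layer-normalized inputs are used as
    query, key and value, a relative positional embedding R of the same length is generated, and
    per attention head

        score   = ((Q + u) K^T + rel_shift((Q + v) R^T)) / sqrt(d_model)
        context = dropout(softmax(score)) V
        outputs = dropout(out_proj(context))

    where u and v are the learned bias parameters of RelativeMultiHeadAttention.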
135 | """ 136 | def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1): 137 | super(MultiHeadedSelfAttentionModule, self).__init__() 138 | self.positional_encoding = PositionalEncoding(d_model) 139 | self.layer_norm = nn.LayerNorm(d_model) 140 | self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p) 141 | self.dropout = nn.Dropout(p=dropout_p) 142 | 143 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None): 144 | batch_size, seq_length, _ = inputs.size() 145 | pos_embedding = self.positional_encoding(seq_length) 146 | pos_embedding = pos_embedding.repeat(batch_size, 1, 1) 147 | 148 | inputs = self.layer_norm(inputs) 149 | outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask) 150 | 151 | return self.dropout(outputs) 152 | -------------------------------------------------------------------------------- /conformer/conformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .activation import Swish, GLU 21 | from .modules import Transpose 22 | 23 | 24 | class DepthwiseConv1d(nn.Module): 25 | """ 26 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, 27 | this operation is termed in literature as depthwise convolution. 28 | 29 | Args: 30 | in_channels (int): Number of channels in the input 31 | out_channels (int): Number of channels produced by the convolution 32 | kernel_size (int or tuple): Size of the convolving kernel 33 | stride (int, optional): Stride of the convolution. Default: 1 34 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 35 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 36 | 37 | Inputs: inputs 38 | - **inputs** (batch, in_channels, time): Tensor containing input vector 39 | 40 | Returns: outputs 41 | - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. 
42 | """ 43 | def __init__( 44 | self, 45 | in_channels: int, 46 | out_channels: int, 47 | kernel_size: int, 48 | stride: int = 1, 49 | padding: int = 0, 50 | bias: bool = False, 51 | ) -> None: 52 | super(DepthwiseConv1d, self).__init__() 53 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels" 54 | self.conv = nn.Conv1d( 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | kernel_size=kernel_size, 58 | groups=in_channels, 59 | stride=stride, 60 | padding=padding, 61 | bias=bias, 62 | ) 63 | 64 | def forward(self, inputs: Tensor) -> Tensor: 65 | return self.conv(inputs) 66 | 67 | 68 | class PointwiseConv1d(nn.Module): 69 | """ 70 | When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution. 71 | This operation often used to match dimensions. 72 | 73 | Args: 74 | in_channels (int): Number of channels in the input 75 | out_channels (int): Number of channels produced by the convolution 76 | stride (int, optional): Stride of the convolution. Default: 1 77 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 78 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 79 | 80 | Inputs: inputs 81 | - **inputs** (batch, in_channels, time): Tensor containing input vector 82 | 83 | Returns: outputs 84 | - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution. 85 | """ 86 | def __init__( 87 | self, 88 | in_channels: int, 89 | out_channels: int, 90 | stride: int = 1, 91 | padding: int = 0, 92 | bias: bool = True, 93 | ) -> None: 94 | super(PointwiseConv1d, self).__init__() 95 | self.conv = nn.Conv1d( 96 | in_channels=in_channels, 97 | out_channels=out_channels, 98 | kernel_size=1, 99 | stride=stride, 100 | padding=padding, 101 | bias=bias, 102 | ) 103 | 104 | def forward(self, inputs: Tensor) -> Tensor: 105 | return self.conv(inputs) 106 | 107 | 108 | class ConformerConvModule(nn.Module): 109 | """ 110 | Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU). 111 | This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution 112 | to aid training deep models. 113 | 114 | Args: 115 | in_channels (int): Number of channels in the input 116 | kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31 117 | dropout_p (float, optional): probability of dropout 118 | 119 | Inputs: inputs 120 | inputs (batch, time, dim): Tensor contains input sequences 121 | 122 | Outputs: outputs 123 | outputs (batch, time, dim): Tensor produces by conformer convolution module. 
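
    Layer order (matching `self.sequential` below):
        LayerNorm -> PointwiseConv1d (2x channels) -> GLU -> DepthwiseConv1d -> BatchNorm1d
        -> Swish -> PointwiseConv1d -> Dropout,
    with the result transposed back to (batch, time, dim).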
124 | """ 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | kernel_size: int = 31, 129 | expansion_factor: int = 2, 130 | dropout_p: float = 0.1, 131 | ) -> None: 132 | super(ConformerConvModule, self).__init__() 133 | assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" 134 | assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2" 135 | 136 | self.sequential = nn.Sequential( 137 | nn.LayerNorm(in_channels), 138 | Transpose(shape=(1, 2)), 139 | PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True), 140 | GLU(dim=1), 141 | DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2), 142 | nn.BatchNorm1d(in_channels), 143 | Swish(), 144 | PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True), 145 | nn.Dropout(p=dropout_p), 146 | ) 147 | 148 | def forward(self, inputs: Tensor) -> Tensor: 149 | return self.sequential(inputs).transpose(1, 2) 150 | 151 | 152 | class Conv2dSubampling(nn.Module): 153 | """ 154 | Convolutional 2D subsampling (to 1/4 length) 155 | 156 | Args: 157 | in_channels (int): Number of channels in the input image 158 | out_channels (int): Number of channels produced by the convolution 159 | 160 | Inputs: inputs 161 | - **inputs** (batch, time, dim): Tensor containing sequence of inputs 162 | 163 | Returns: outputs, output_lengths 164 | - **outputs** (batch, time, dim): Tensor produced by the convolution 165 | - **output_lengths** (batch): list of sequence output lengths 166 | """ 167 | def __init__(self, in_channels: int, out_channels: int) -> None: 168 | super(Conv2dSubampling, self).__init__() 169 | self.sequential = nn.Sequential( 170 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2), 171 | nn.ReLU(), 172 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2), 173 | nn.ReLU(), 174 | ) 175 | 176 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 177 | outputs = self.sequential(inputs.unsqueeze(1)) 178 | batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size() 179 | 180 | outputs = outputs.permute(0, 2, 1, 3) 181 | outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim) 182 | 183 | output_lengths = input_lengths >> 2 184 | output_lengths -= 1 185 | 186 | return outputs, output_lengths 187 | -------------------------------------------------------------------------------- /conformer/conformer/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | from torch import Tensor 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | """ 23 | Positional Encoding proposed in "Attention Is All You Need". 
24 | Since transformer contains no recurrence and no convolution, in order for the model to make 25 | use of the order of the sequence, we must add some positional information. 26 | 27 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 30 | """ 31 | def __init__(self, d_model: int = 512, max_len: int = 10000) -> None: 32 | super(PositionalEncoding, self).__init__() 33 | pe = torch.zeros(max_len, d_model, requires_grad=False) 34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 36 | pe[:, 0::2] = torch.sin(position * div_term) 37 | pe[:, 1::2] = torch.cos(position * div_term) 38 | pe = pe.unsqueeze(0) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, length: int) -> Tensor: 42 | return self.pe[:, :length] -------------------------------------------------------------------------------- /conformer/conformer/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .feed_forward import FeedForwardModule 21 | from .attention import MultiHeadedSelfAttentionModule 22 | from .convolution import ( 23 | ConformerConvModule, 24 | Conv2dSubampling, 25 | ) 26 | from .modules import ( 27 | ResidualConnectionModule, 28 | Linear, 29 | ) 30 | 31 | 32 | class ConformerBlock(nn.Module): 33 | """ 34 | Conformer block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module 35 | and the Convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing 36 | the original feed-forward layer in the Transformer block into two half-step feed-forward layers, 37 | one before the attention layer and one after. 
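
    With half_step_residual=True (the default), the block computes

        x1 = x  + 0.5 * FFN(x)
        x2 = x1 + MHSA(x1)
        x3 = x2 + Conv(x2)
        y  = LayerNorm(x3 + 0.5 * FFN(x3))

    i.e. the Macaron-style feed-forward sandwich described above.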
38 | 39 | Args: 40 | encoder_dim (int, optional): Dimension of conformer encoder 41 | num_attention_heads (int, optional): Number of attention heads 42 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 43 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 44 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 45 | attention_dropout_p (float, optional): Probability of attention module dropout 46 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 47 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 48 | half_step_residual (bool): Flag indication whether to use half step residual or not 49 | 50 | Inputs: inputs 51 | - **inputs** (batch, time, dim): Tensor containing input vector 52 | 53 | Returns: outputs 54 | - **outputs** (batch, time, dim): Tensor produces by conformer block. 55 | """ 56 | def __init__( 57 | self, 58 | encoder_dim: int = 512, 59 | num_attention_heads: int = 8, 60 | feed_forward_expansion_factor: int = 4, 61 | conv_expansion_factor: int = 2, 62 | feed_forward_dropout_p: float = 0.1, 63 | attention_dropout_p: float = 0.1, 64 | conv_dropout_p: float = 0.1, 65 | conv_kernel_size: int = 31, 66 | half_step_residual: bool = True, 67 | ): 68 | super(ConformerBlock, self).__init__() 69 | if half_step_residual: 70 | self.feed_forward_residual_factor = 0.5 71 | else: 72 | self.feed_forward_residual_factor = 1 73 | 74 | self.sequential = nn.Sequential( 75 | ResidualConnectionModule( 76 | module=FeedForwardModule( 77 | encoder_dim=encoder_dim, 78 | expansion_factor=feed_forward_expansion_factor, 79 | dropout_p=feed_forward_dropout_p, 80 | ), 81 | module_factor=self.feed_forward_residual_factor, 82 | ), 83 | ResidualConnectionModule( 84 | module=MultiHeadedSelfAttentionModule( 85 | d_model=encoder_dim, 86 | num_heads=num_attention_heads, 87 | dropout_p=attention_dropout_p, 88 | ), 89 | ), 90 | ResidualConnectionModule( 91 | module=ConformerConvModule( 92 | in_channels=encoder_dim, 93 | kernel_size=conv_kernel_size, 94 | expansion_factor=conv_expansion_factor, 95 | dropout_p=conv_dropout_p, 96 | ), 97 | ), 98 | ResidualConnectionModule( 99 | module=FeedForwardModule( 100 | encoder_dim=encoder_dim, 101 | expansion_factor=feed_forward_expansion_factor, 102 | dropout_p=feed_forward_dropout_p, 103 | ), 104 | module_factor=self.feed_forward_residual_factor, 105 | ), 106 | nn.LayerNorm(encoder_dim), 107 | ) 108 | 109 | def forward(self, inputs: Tensor) -> Tensor: 110 | return self.sequential(inputs) 111 | 112 | 113 | class ConformerEncoder(nn.Module): 114 | """ 115 | Conformer encoder first processes the input with a convolution subsampling layer and then 116 | with a number of conformer blocks. 
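    Note: in this copy the Conv2dSubampling front-end is commented out and replaced by a Linear
    `input_projection`, so no time-axis subsampling is applied and the input lengths are returned
    unchanged.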
117 | 118 | Args: 119 | input_dim (int, optional): Dimension of input vector 120 | encoder_dim (int, optional): Dimension of conformer encoder 121 | num_layers (int, optional): Number of conformer blocks 122 | num_attention_heads (int, optional): Number of attention heads 123 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 124 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 125 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 126 | attention_dropout_p (float, optional): Probability of attention module dropout 127 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 128 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 129 | half_step_residual (bool): Flag indication whether to use half step residual or not 130 | 131 | Inputs: inputs, input_lengths 132 | - **inputs** (batch, time, dim): Tensor containing input vector 133 | - **input_lengths** (batch): list of sequence input lengths 134 | 135 | Returns: outputs, output_lengths 136 | - **outputs** (batch, out_channels, time): Tensor produces by conformer encoder. 137 | - **output_lengths** (batch): list of sequence output lengths 138 | """ 139 | def __init__( 140 | self, 141 | input_dim: int = 80, 142 | encoder_dim: int = 512, 143 | num_layers: int = 17, 144 | num_attention_heads: int = 8, 145 | feed_forward_expansion_factor: int = 4, 146 | conv_expansion_factor: int = 2, 147 | input_dropout_p: float = 0.1, 148 | feed_forward_dropout_p: float = 0.1, 149 | attention_dropout_p: float = 0.1, 150 | conv_dropout_p: float = 0.1, 151 | conv_kernel_size: int = 31, 152 | half_step_residual: bool = True, 153 | ): 154 | super(ConformerEncoder, self).__init__() 155 | # self.conv_subsample = Conv2dSubampling(in_channels=1, out_channels=encoder_dim) 156 | # self.input_projection = nn.Sequential( 157 | # Linear(encoder_dim * (((input_dim - 1) // 2 - 1) // 2), encoder_dim), 158 | # nn.Dropout(p=input_dropout_p), 159 | # ) 160 | self.input_projection = nn.Sequential( 161 | Linear(input_dim, encoder_dim), 162 | nn.Dropout(p=input_dropout_p), 163 | ) 164 | self.layers = nn.ModuleList([ConformerBlock( 165 | encoder_dim=encoder_dim, 166 | num_attention_heads=num_attention_heads, 167 | feed_forward_expansion_factor=feed_forward_expansion_factor, 168 | conv_expansion_factor=conv_expansion_factor, 169 | feed_forward_dropout_p=feed_forward_dropout_p, 170 | attention_dropout_p=attention_dropout_p, 171 | conv_dropout_p=conv_dropout_p, 172 | conv_kernel_size=conv_kernel_size, 173 | half_step_residual=half_step_residual, 174 | ) for _ in range(num_layers)]) 175 | 176 | def count_parameters(self) -> int: 177 | """ Count parameters of encoder """ 178 | return sum([p.numel for p in self.parameters()]) 179 | 180 | def update_dropout(self, dropout_p: float) -> None: 181 | """ Update dropout probability of encoder """ 182 | for name, child in self.named_children(): 183 | if isinstance(child, nn.Dropout): 184 | child.p = dropout_p 185 | 186 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 187 | """ 188 | Forward propagate a `inputs` for encoder training. 189 | 190 | Args: 191 | inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 192 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 193 | input_lengths (torch.LongTensor): The length of input tensor. 
``(batch)`` 194 | 195 | Returns: 196 | (Tensor, Tensor) 197 | 198 | * outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size 199 | ``(batch, seq_length, dimension)`` 200 | * output_lengths (torch.LongTensor): The length of output tensor. ``(batch)`` 201 | """ 202 | # outputs, output_lengths = self.conv_subsample(inputs, input_lengths) 203 | outputs = self.input_projection(inputs) 204 | 205 | for layer in self.layers: 206 | outputs = layer(outputs) 207 | 208 | return outputs, input_lengths 209 | 210 | # return outputs, output_lengths 211 | -------------------------------------------------------------------------------- /conformer/conformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | 19 | from .activation import Swish 20 | from .modules import Linear 21 | 22 | 23 | class FeedForwardModule(nn.Module): 24 | """ 25 | Conformer Feed Forward Module follows pre-norm residual units and applies layer normalization within the residual unit 26 | and on the input before the first linear layer. This module also applies Swish activation and dropout, which help 27 | regularize the network. 28 | 29 | Args: 30 | encoder_dim (int): Dimension of conformer encoder 31 | expansion_factor (int): Expansion factor of feed forward module. 32 | dropout_p (float): Ratio of dropout 33 | 34 | Inputs: inputs 35 | - **inputs** (batch, time, dim): Tensor containing input sequences 36 | 37 | Outputs: outputs 38 | - **outputs** (batch, time, dim): Tensor produced by feed forward module. 39 | """ 40 | def __init__( 41 | self, 42 | encoder_dim: int = 512, 43 | expansion_factor: int = 4, 44 | dropout_p: float = 0.1, 45 | ) -> None: 46 | super(FeedForwardModule, self).__init__() 47 | self.sequential = nn.Sequential( 48 | nn.LayerNorm(encoder_dim), 49 | Linear(encoder_dim, encoder_dim * expansion_factor, bias=True), 50 | Swish(), 51 | nn.Dropout(p=dropout_p), 52 | Linear(encoder_dim * expansion_factor, encoder_dim, bias=True), 53 | nn.Dropout(p=dropout_p), 54 | ) 55 | 56 | def forward(self, inputs: Tensor) -> Tensor: 57 | return self.sequential(inputs) 58 | -------------------------------------------------------------------------------- /conformer/conformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .encoder import ConformerEncoder 21 | from .modules import Linear 22 | 23 | 24 | class Conformer(nn.Module): 25 | """ 26 | Conformer: Convolution-augmented Transformer for Speech Recognition 27 | The paper used a single-LSTM Transducer decoder; currently only the conformer encoder 28 | shown in the paper is implemented. 29 | 30 | Args: 31 | num_classes (int): Number of classification classes 32 | input_dim (int, optional): Dimension of input vector 33 | encoder_dim (int, optional): Dimension of conformer encoder 34 | num_encoder_layers (int, optional): Number of conformer blocks 35 | num_attention_heads (int, optional): Number of attention heads 36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 39 | attention_dropout_p (float, optional): Probability of attention module dropout 40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 42 | half_step_residual (bool): Flag indicating whether to use half-step residual or not 43 | 44 | Inputs: inputs, input_lengths 45 | - **inputs** (batch, time, dim): Tensor containing input vector 46 | - **input_lengths** (batch): list of sequence input lengths 47 | 48 | Returns: outputs, output_lengths 49 | - **outputs** (batch, time, dim): Tensor produced by conformer.
50 | - **output_lengths** (batch): list of sequence output lengths 51 | """ 52 | def __init__( 53 | self, 54 | model_config: dict 55 | # num_classes: int, 56 | # input_dim: int = 80, 57 | # encoder_dim: int = 512, 58 | # num_encoder_layers: int = 17, 59 | # num_attention_heads: int = 8, 60 | # feed_forward_expansion_factor: int = 4, 61 | # conv_expansion_factor: int = 2, 62 | # input_dropout_p: float = 0.1, 63 | # feed_forward_dropout_p: float = 0.1, 64 | # attention_dropout_p: float = 0.1, 65 | # conv_dropout_p: float = 0.1, 66 | # conv_kernel_size: int = 31, 67 | # half_step_residual: bool = True, 68 | ) -> None: 69 | super(Conformer, self).__init__() 70 | 71 | input_dim = model_config["transformer"]["encoder_hidden"] 72 | encoder_dim = model_config["transformer"]["encoder_hidden"] 73 | num_encoder_layers = model_config["transformer"]["encoder_layer"] 74 | num_attention_heads = model_config["transformer"]["encoder_head"] 75 | feed_forward_expansion_factor = model_config["transformer"]["feed_forward_expansion_factor"] 76 | conv_expansion_factor = model_config["transformer"]["conv_expansion_factor"] 77 | input_dropout_p = model_config["transformer"]["input_dropout_p"] 78 | feed_forward_dropout_p = model_config["transformer"]["feed_forward_dropout_p"] 79 | attention_dropout_p = model_config["transformer"]["attention_dropout_p"] 80 | conv_dropout_p = model_config["transformer"]["conv_dropout_p"] 81 | conv_kernel_size = model_config["transformer"]["conv_kernel_size"] 82 | half_step_residual = model_config["transformer"]["half_step_residual"] 83 | 84 | self.encoder = ConformerEncoder( 85 | input_dim=input_dim, 86 | encoder_dim=encoder_dim, 87 | num_layers=num_encoder_layers, 88 | num_attention_heads=num_attention_heads, 89 | feed_forward_expansion_factor=feed_forward_expansion_factor, 90 | conv_expansion_factor=conv_expansion_factor, 91 | input_dropout_p=input_dropout_p, 92 | feed_forward_dropout_p=feed_forward_dropout_p, 93 | attention_dropout_p=attention_dropout_p, 94 | conv_dropout_p=conv_dropout_p, 95 | conv_kernel_size=conv_kernel_size, 96 | half_step_residual=half_step_residual, 97 | ) 98 | # self.fc = Linear(encoder_dim, num_classes, bias=False) 99 | 100 | def count_parameters(self) -> int: 101 | """ Count parameters of encoder """ 102 | return self.encoder.count_parameters() 103 | 104 | def update_dropout(self, dropout_p) -> None: 105 | """ Update dropout probability of model """ 106 | self.encoder.update_dropout(dropout_p) 107 | 108 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 109 | """ 110 | Forward propagate `inputs` through the conformer encoder (no decoder or targets are used here). 111 | 112 | Args: 113 | inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically for inputs this will be a padded 114 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 115 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)`` 116 | 117 | Returns: 118 | * (Tensor, Tensor): Encoder outputs and encoder output lengths. 119 | """ 120 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 121 | return encoder_outputs, encoder_output_lengths 122 | 123 | # outputs = self.fc(encoder_outputs) 124 | # outputs = nn.functional.log_softmax(outputs, dim=-1) 125 | # return outputs, encoder_output_lengths 126 | -------------------------------------------------------------------------------- /conformer/conformer/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim.
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | from torch import Tensor 19 | 20 | 21 | class ResidualConnectionModule(nn.Module): 22 | """ 23 | Residual Connection Module. 24 | outputs = (module(inputs) x module_factor + inputs x input_factor) 25 | """ 26 | def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0): 27 | super(ResidualConnectionModule, self).__init__() 28 | self.module = module 29 | self.module_factor = module_factor 30 | self.input_factor = input_factor 31 | 32 | def forward(self, inputs: Tensor) -> Tensor: 33 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) 34 | 35 | 36 | class Linear(nn.Module): 37 | """ 38 | Wrapper class of torch.nn.Linear 39 | Weight initialize by xavier initialization and bias initialize to zeros. 40 | """ 41 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 42 | super(Linear, self).__init__() 43 | self.linear = nn.Linear(in_features, out_features, bias=bias) 44 | init.xavier_uniform_(self.linear.weight) 45 | if bias: 46 | init.zeros_(self.linear.bias) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | return self.linear(x) 50 | 51 | 52 | class View(nn.Module): 53 | """ Wrapper class of torch.view() for Sequential module. """ 54 | def __init__(self, shape: tuple, contiguous: bool = False): 55 | super(View, self).__init__() 56 | self.shape = shape 57 | self.contiguous = contiguous 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | if self.contiguous: 61 | x = x.contiguous() 62 | 63 | return x.view(*self.shape) 64 | 65 | 66 | class Transpose(nn.Module): 67 | """ Wrapper class of torch.transpose() for Sequential module. """ 68 | def __init__(self, shape: tuple): 69 | super(Transpose, self).__init__() 70 | self.shape = shape 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | return x.transpose(*self.shape) 74 | -------------------------------------------------------------------------------- /conformer/docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/.DS_Store -------------------------------------------------------------------------------- /conformer/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /conformer/docs/Model.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Model — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |

50 | 51 | 106 | 107 |
108 | 109 | 110 | 116 | 117 | 118 |
119 | 120 |
121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 |
141 | 142 |
    143 | 144 |
  • »
  • 145 | 146 |
  • Model
  • 147 | 148 | 149 |
  • 150 | 151 | 152 | View page source 153 | 154 | 155 |
  • 156 | 157 |
158 | 159 | 160 |
161 |
162 |
163 |
164 | 165 |
166 |

Model

167 |
168 |

Conformer

169 |
170 |
171 |

Encoder

172 |
173 |
174 |

Decoder

175 |
176 |
177 | 178 | 179 |
180 | 181 |
182 |
183 | 187 | 188 |
189 | 190 |
191 |

192 | © Copyright 2021, Soohwan Kim. 193 | 194 |

195 |
196 | 197 | 198 | 199 | Built with Sphinx using a 200 | 201 | theme 202 | 203 | provided by Read the Docs. 204 | 205 |
206 |
207 |
208 | 209 |
210 | 211 |
212 | 213 | 214 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /conformer/docs/Modules.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Conformer Modules — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
[Sphinx-rendered HTML page; the markup did not survive extraction into this dump. Recoverable content: page "Conformer Modules" with sections Attention, Convolution and Feed Forward (the rendered form of docs/source/Modules.rst), Read the Docs navigation, and the footer "© Copyright 2021, Soohwan Kim. Built with Sphinx using a theme provided by Read the Docs."]
212 | 213 | 214 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /conformer/docs/Submodules.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Submodules — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
[Sphinx-rendered HTML page; the markup did not survive extraction into this dump. Recoverable content: page "Submodules" with sections Activation, Modules and Embedding (the rendered form of docs/source/Submodules.rst), Read the Docs navigation, and the footer "© Copyright 2021, Soohwan Kim. Built with Sphinx using a theme provided by Read the Docs."]
210 | 211 | 212 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /conformer/docs/_sources/Model.rst.txt: -------------------------------------------------------------------------------- 1 | 2 | Model 3 | ===================================================== 4 | 5 | Conformer 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.model 9 | :members: 10 | 11 | Encoder 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.encoder 15 | :members: 16 | 17 | Decoder 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.decoder 21 | :members: 22 | -------------------------------------------------------------------------------- /conformer/docs/_sources/Modules.rst.txt: -------------------------------------------------------------------------------- 1 | 2 | Conformer Modules 3 | ===================================================== 4 | 5 | Attention 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.attention 9 | :members: 10 | 11 | Convolution 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.convolution 15 | :members: 16 | 17 | Feed Forward 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.feed_forward 21 | :members: -------------------------------------------------------------------------------- /conformer/docs/_sources/Submodules.rst.txt: -------------------------------------------------------------------------------- 1 | 2 | Submodules 3 | ===================================================== 4 | 5 | Activation 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.activation 9 | :members: 10 | 11 | Modules 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.modules 15 | :members: 16 | 17 | Embedding 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.embedding 21 | :members: -------------------------------------------------------------------------------- /conformer/docs/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. conformer documentation master file, created by 2 | sphinx-quickstart on Sun Jan 24 01:16:16 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Conformer's documentation! 7 | ===================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 1 11 | :caption: PACKAGE 12 | 13 | Model 14 | Modules 15 | Submodules 16 | 17 | 18 | 19 | Indices and tables 20 | ================== 21 | 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | * :ref:`search` 25 | -------------------------------------------------------------------------------- /conformer/docs/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version 
.icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /conformer/docs/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /conformer/docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for all documentation. 6 | * 7 | * :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 
9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | */ 33 | jQuery.urldecode = function(x) { 34 | return decodeURIComponent(x).replace(/\+/g, ' '); 35 | }; 36 | 37 | /** 38 | * small helper function to urlencode strings 39 | */ 40 | jQuery.urlencode = encodeURIComponent; 41 | 42 | /** 43 | * This function returns the parsed url parameters of the 44 | * current request. Multiple values per key are supported, 45 | * it will always return arrays of strings for the value parts. 46 | */ 47 | jQuery.getQueryParameters = function(s) { 48 | if (typeof s === 'undefined') 49 | s = document.location.search; 50 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 51 | var result = {}; 52 | for (var i = 0; i < parts.length; i++) { 53 | var tmp = parts[i].split('=', 2); 54 | var key = jQuery.urldecode(tmp[0]); 55 | var value = jQuery.urldecode(tmp[1]); 56 | if (key in result) 57 | result[key].push(value); 58 | else 59 | result[key] = [value]; 60 | } 61 | return result; 62 | }; 63 | 64 | /** 65 | * highlight a given string on a jquery object by wrapping it in 66 | * span elements with the given class name. 67 | */ 68 | jQuery.fn.highlightText = function(text, className) { 69 | function highlight(node, addItems) { 70 | if (node.nodeType === 3) { 71 | var val = node.nodeValue; 72 | var pos = val.toLowerCase().indexOf(text); 73 | if (pos >= 0 && 74 | !jQuery(node.parentNode).hasClass(className) && 75 | !jQuery(node.parentNode).hasClass("nohighlight")) { 76 | var span; 77 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 78 | if (isInSVG) { 79 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 80 | } else { 81 | span = document.createElement("span"); 82 | span.className = className; 83 | } 84 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 85 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 86 | document.createTextNode(val.substr(pos + text.length)), 87 | node.nextSibling)); 88 | node.nodeValue = val.substr(0, pos); 89 | if (isInSVG) { 90 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 91 | var bbox = node.parentElement.getBBox(); 92 | rect.x.baseVal.value = bbox.x; 93 | rect.y.baseVal.value = bbox.y; 94 | rect.width.baseVal.value = bbox.width; 95 | rect.height.baseVal.value = bbox.height; 96 | rect.setAttribute('class', className); 97 | addItems.push({ 98 | "parent": node.parentNode, 99 | "target": rect}); 100 | } 101 | } 102 | } 103 | else if (!jQuery(node).is("button, select, textarea")) { 104 | jQuery.each(node.childNodes, function() { 105 | highlight(this, addItems); 106 | }); 107 | } 108 | } 109 | var addItems = []; 110 | var result = this.each(function() { 111 | highlight(this, addItems); 112 | }); 113 | for (var i = 0; i < addItems.length; ++i) { 114 | jQuery(addItems[i].parent).before(addItems[i].target); 115 | } 116 | return result; 117 | }; 118 | 119 | /* 120 | * backward 
compatibility for jQuery.browser 121 | * This will be supported until firefox bug is fixed. 122 | */ 123 | if (!jQuery.browser) { 124 | jQuery.uaMatch = function(ua) { 125 | ua = ua.toLowerCase(); 126 | 127 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 128 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 129 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 130 | /(msie) ([\w.]+)/.exec(ua) || 131 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 132 | []; 133 | 134 | return { 135 | browser: match[ 1 ] || "", 136 | version: match[ 2 ] || "0" 137 | }; 138 | }; 139 | jQuery.browser = {}; 140 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 141 | } 142 | 143 | /** 144 | * Small JavaScript module for the documentation. 145 | */ 146 | var Documentation = { 147 | 148 | init : function() { 149 | this.fixFirefoxAnchorBug(); 150 | this.highlightSearchWords(); 151 | this.initIndexTable(); 152 | if (DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) { 153 | this.initOnKeyListeners(); 154 | } 155 | }, 156 | 157 | /** 158 | * i18n support 159 | */ 160 | TRANSLATIONS : {}, 161 | PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, 162 | LOCALE : 'unknown', 163 | 164 | // gettext and ngettext don't access this so that the functions 165 | // can safely bound to a different name (_ = Documentation.gettext) 166 | gettext : function(string) { 167 | var translated = Documentation.TRANSLATIONS[string]; 168 | if (typeof translated === 'undefined') 169 | return string; 170 | return (typeof translated === 'string') ? translated : translated[0]; 171 | }, 172 | 173 | ngettext : function(singular, plural, n) { 174 | var translated = Documentation.TRANSLATIONS[singular]; 175 | if (typeof translated === 'undefined') 176 | return (n == 1) ? singular : plural; 177 | return translated[Documentation.PLURALEXPR(n)]; 178 | }, 179 | 180 | addTranslations : function(catalog) { 181 | for (var key in catalog.messages) 182 | this.TRANSLATIONS[key] = catalog.messages[key]; 183 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 184 | this.LOCALE = catalog.locale; 185 | }, 186 | 187 | /** 188 | * add context elements like header anchor links 189 | */ 190 | addContextElements : function() { 191 | $('div[id] > :header:first').each(function() { 192 | $('\u00B6'). 193 | attr('href', '#' + this.id). 194 | attr('title', _('Permalink to this headline')). 195 | appendTo(this); 196 | }); 197 | $('dt[id]').each(function() { 198 | $('\u00B6'). 199 | attr('href', '#' + this.id). 200 | attr('title', _('Permalink to this definition')). 201 | appendTo(this); 202 | }); 203 | }, 204 | 205 | /** 206 | * workaround a firefox stupidity 207 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 208 | */ 209 | fixFirefoxAnchorBug : function() { 210 | if (document.location.hash && $.browser.mozilla) 211 | window.setTimeout(function() { 212 | document.location.href += ''; 213 | }, 10); 214 | }, 215 | 216 | /** 217 | * highlight the search words provided in the url in the text 218 | */ 219 | highlightSearchWords : function() { 220 | var params = $.getQueryParameters(); 221 | var terms = (params.highlight) ? 
params.highlight[0].split(/\s+/) : []; 222 | if (terms.length) { 223 | var body = $('div.body'); 224 | if (!body.length) { 225 | body = $('body'); 226 | } 227 | window.setTimeout(function() { 228 | $.each(terms, function() { 229 | body.highlightText(this.toLowerCase(), 'highlighted'); 230 | }); 231 | }, 10); 232 | $('') 234 | .appendTo($('#searchbox')); 235 | } 236 | }, 237 | 238 | /** 239 | * init the domain index toggle buttons 240 | */ 241 | initIndexTable : function() { 242 | var togglers = $('img.toggler').click(function() { 243 | var src = $(this).attr('src'); 244 | var idnum = $(this).attr('id').substr(7); 245 | $('tr.cg-' + idnum).toggle(); 246 | if (src.substr(-9) === 'minus.png') 247 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 248 | else 249 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 250 | }).css('display', ''); 251 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 252 | togglers.click(); 253 | } 254 | }, 255 | 256 | /** 257 | * helper function to hide the search marks again 258 | */ 259 | hideSearchWords : function() { 260 | $('#searchbox .highlight-link').fadeOut(300); 261 | $('span.highlighted').removeClass('highlighted'); 262 | }, 263 | 264 | /** 265 | * make the url absolute 266 | */ 267 | makeURL : function(relativeURL) { 268 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 269 | }, 270 | 271 | /** 272 | * get the current relative url 273 | */ 274 | getCurrentURL : function() { 275 | var path = document.location.pathname; 276 | var parts = path.split(/\//); 277 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 278 | if (this === '..') 279 | parts.pop(); 280 | }); 281 | var url = parts.join('/'); 282 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 283 | }, 284 | 285 | initOnKeyListeners: function() { 286 | $(document).keydown(function(event) { 287 | var activeElementType = document.activeElement.tagName; 288 | // don't navigate when in search box or textarea 289 | if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT' 290 | && !event.altKey && !event.ctrlKey && !event.metaKey && !event.shiftKey) { 291 | switch (event.keyCode) { 292 | case 37: // left 293 | var prevHref = $('link[rel="prev"]').prop('href'); 294 | if (prevHref) { 295 | window.location.href = prevHref; 296 | return false; 297 | } 298 | case 39: // right 299 | var nextHref = $('link[rel="next"]').prop('href'); 300 | if (nextHref) { 301 | window.location.href = nextHref; 302 | return false; 303 | } 304 | } 305 | } 306 | }); 307 | } 308 | }; 309 | 310 | // quick alias for translations 311 | _ = Documentation.gettext; 312 | 313 | $(document).ready(function() { 314 | Documentation.init(); 315 | }); 316 | -------------------------------------------------------------------------------- /conformer/docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: 'latest', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /conformer/docs/_static/file.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/_static/file.png -------------------------------------------------------------------------------- /conformer/docs/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /conformer/docs/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var 
b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /conformer/docs/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function 
i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /conformer/docs/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 
0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}t.length>0&&($(".wy-menu-vertical .current").removeClass("current"),t.addClass("current"),t.closest("li.toctree-l1").addClass("current"),t.closest("li.toctree-l1").parent().addClass("current"),t.closest("li.toctree-l1").addClass("current"),t.closest("li.toctree-l2").addClass("current"),t.closest("li.toctree-l3").addClass("current"),t.closest("li.toctree-l4").addClass("current"),t.closest("li.toctree-l5").addClass("current"),t[0].scrollIntoView())}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t2;a== 12 | null&&(a=[]);if(y&&a.reduce===y)return e&&(c=b.bind(c,e)),f?a.reduce(c,d):a.reduce(c);j(a,function(a,b,i){f?d=c.call(e,d,a,b,i):(d=a,f=true)});if(!f)throw new TypeError("Reduce of empty array with no initial value");return d};b.reduceRight=b.foldr=function(a,c,d,e){var f=arguments.length>2;a==null&&(a=[]);if(z&&a.reduceRight===z)return e&&(c=b.bind(c,e)),f?a.reduceRight(c,d):a.reduceRight(c);var g=b.toArray(a).reverse();e&&!f&&(c=b.bind(c,e));return f?b.reduce(g,c,d,e):b.reduce(g,c)};b.find=b.detect= 13 | function(a,c,b){var e;E(a,function(a,g,h){if(c.call(b,a,g,h))return e=a,true});return e};b.filter=b.select=function(a,c,b){var e=[];if(a==null)return e;if(A&&a.filter===A)return a.filter(c,b);j(a,function(a,g,h){c.call(b,a,g,h)&&(e[e.length]=a)});return e};b.reject=function(a,c,b){var e=[];if(a==null)return e;j(a,function(a,g,h){c.call(b,a,g,h)||(e[e.length]=a)});return e};b.every=b.all=function(a,c,b){var e=true;if(a==null)return e;if(B&&a.every===B)return a.every(c,b);j(a,function(a,g,h){if(!(e= 14 | e&&c.call(b,a,g,h)))return n});return e};var E=b.some=b.any=function(a,c,d){c||(c=b.identity);var e=false;if(a==null)return e;if(C&&a.some===C)return a.some(c,d);j(a,function(a,b,h){if(e||(e=c.call(d,a,b,h)))return n});return!!e};b.include=b.contains=function(a,c){var b=false;if(a==null)return b;return p&&a.indexOf===p?a.indexOf(c)!=-1:b=E(a,function(a){return a===c})};b.invoke=function(a,c){var d=i.call(arguments,2);return b.map(a,function(a){return(b.isFunction(c)?c||a:a[c]).apply(a,d)})};b.pluck= 15 | function(a,c){return b.map(a,function(a){return a[c]})};b.max=function(a,c,d){if(!c&&b.isArray(a))return Math.max.apply(Math,a);if(!c&&b.isEmpty(a))return-Infinity;var 
e={computed:-Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;b>=e.computed&&(e={value:a,computed:b})});return e.value};b.min=function(a,c,d){if(!c&&b.isArray(a))return Math.min.apply(Math,a);if(!c&&b.isEmpty(a))return Infinity;var e={computed:Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;bd?1:0}),"value")};b.groupBy=function(a,c){var d={},e=b.isFunction(c)?c:function(a){return a[c]};j(a,function(a,b){var c=e(a,b);(d[c]||(d[c]=[])).push(a)});return d};b.sortedIndex=function(a, 17 | c,d){d||(d=b.identity);for(var e=0,f=a.length;e>1;d(a[g])=0})})};b.difference=function(a){var c=b.flatten(i.call(arguments,1));return b.filter(a,function(a){return!b.include(c,a)})};b.zip=function(){for(var a=i.call(arguments),c=b.max(b.pluck(a,"length")),d=Array(c),e=0;e=0;d--)b=[a[d].apply(this,b)];return b[0]}}; 24 | b.after=function(a,b){return a<=0?b():function(){if(--a<1)return b.apply(this,arguments)}};b.keys=J||function(a){if(a!==Object(a))throw new TypeError("Invalid object");var c=[],d;for(d in a)b.has(a,d)&&(c[c.length]=d);return c};b.values=function(a){return b.map(a,b.identity)};b.functions=b.methods=function(a){var c=[],d;for(d in a)b.isFunction(a[d])&&c.push(d);return c.sort()};b.extend=function(a){j(i.call(arguments,1),function(b){for(var d in b)a[d]=b[d]});return a};b.defaults=function(a){j(i.call(arguments, 25 | 1),function(b){for(var d in b)a[d]==null&&(a[d]=b[d])});return a};b.clone=function(a){return!b.isObject(a)?a:b.isArray(a)?a.slice():b.extend({},a)};b.tap=function(a,b){b(a);return a};b.isEqual=function(a,b){return q(a,b,[])};b.isEmpty=function(a){if(b.isArray(a)||b.isString(a))return a.length===0;for(var c in a)if(b.has(a,c))return false;return true};b.isElement=function(a){return!!(a&&a.nodeType==1)};b.isArray=o||function(a){return l.call(a)=="[object Array]"};b.isObject=function(a){return a===Object(a)}; 26 | b.isArguments=function(a){return l.call(a)=="[object Arguments]"};if(!b.isArguments(arguments))b.isArguments=function(a){return!(!a||!b.has(a,"callee"))};b.isFunction=function(a){return l.call(a)=="[object Function]"};b.isString=function(a){return l.call(a)=="[object String]"};b.isNumber=function(a){return l.call(a)=="[object Number]"};b.isNaN=function(a){return a!==a};b.isBoolean=function(a){return a===true||a===false||l.call(a)=="[object Boolean]"};b.isDate=function(a){return l.call(a)=="[object Date]"}; 27 | b.isRegExp=function(a){return l.call(a)=="[object RegExp]"};b.isNull=function(a){return a===null};b.isUndefined=function(a){return a===void 0};b.has=function(a,b){return I.call(a,b)};b.noConflict=function(){r._=G;return this};b.identity=function(a){return a};b.times=function(a,b,d){for(var e=0;e/g,">").replace(/"/g,""").replace(/'/g,"'").replace(/\//g,"/")};b.mixin=function(a){j(b.functions(a), 28 | function(c){K(c,b[c]=a[c])})};var L=0;b.uniqueId=function(a){var b=L++;return a?a+b:b};b.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g};var t=/.^/,u=function(a){return a.replace(/\\\\/g,"\\").replace(/\\'/g,"'")};b.template=function(a,c){var d=b.templateSettings,d="var __p=[],print=function(){__p.push.apply(__p,arguments);};with(obj||{}){__p.push('"+a.replace(/\\/g,"\\\\").replace(/'/g,"\\'").replace(d.escape||t,function(a,b){return"',_.escape("+ 29 | u(b)+"),'"}).replace(d.interpolate||t,function(a,b){return"',"+u(b)+",'"}).replace(d.evaluate||t,function(a,b){return"');"+u(b).replace(/[\r\n\t]/g," ")+";__p.push('"}).replace(/\r/g,"\\r").replace(/\n/g,"\\n").replace(/\t/g,"\\t")+"');}return __p.join('');",e=new 
Function("obj","_",d);return c?e(c,b):function(a){return e.call(this,a,b)}};b.chain=function(a){return b(a).chain()};var m=function(a){this._wrapped=a};b.prototype=m.prototype;var v=function(a,c){return c?b(a).chain():a},K=function(a,c){m.prototype[a]= 30 | function(){var a=i.call(arguments);H.call(a,this._wrapped);return v(c.apply(b,a),this._chain)}};b.mixin(b);j("pop,push,reverse,shift,sort,splice,unshift".split(","),function(a){var b=k[a];m.prototype[a]=function(){var d=this._wrapped;b.apply(d,arguments);var e=d.length;(a=="shift"||a=="splice")&&e===0&&delete d[0];return v(d,this._chain)}});j(["concat","join","slice"],function(a){var b=k[a];m.prototype[a]=function(){return v(b.apply(this._wrapped,arguments),this._chain)}});m.prototype.chain=function(){this._chain= 31 | true;return this};m.prototype.value=function(){return this._wrapped}}).call(this); 32 | -------------------------------------------------------------------------------- /conformer/docs/genindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Index — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
193 | 194 | 195 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | -------------------------------------------------------------------------------- /conformer/docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Welcome to Conformer’s documentation! — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 |
[Sphinx-rendered landing page: "Welcome to Conformer's documentation!", a "PACKAGE" toctree (Model, Modules, Submodules), "Indices and tables", plus the standard Read the Docs navigation and footer (© Copyright 2021, Soohwan Kim).]
212 | 213 | 214 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /conformer/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /conformer/docs/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/conformer/docs/objects.inv -------------------------------------------------------------------------------- /conformer/docs/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Search — conformer latest documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 |
199 | 200 | 201 | 206 | 207 | 208 | 209 | 210 | 211 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /conformer/docs/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({docnames:["Model","Modules","Submodules","index"],envversion:{"sphinx.domains.c":2,"sphinx.domains.changeset":1,"sphinx.domains.citation":1,"sphinx.domains.cpp":3,"sphinx.domains.index":1,"sphinx.domains.javascript":2,"sphinx.domains.math":2,"sphinx.domains.python":2,"sphinx.domains.rst":2,"sphinx.domains.std":1,"sphinx.ext.intersphinx":1,"sphinx.ext.todo":2,"sphinx.ext.viewcode":1,sphinx:56},filenames:["Model.rst","Modules.rst","Submodules.rst","index.rst"],objects:{},objnames:{},objtypes:{},terms:{index:3,model:3,modul:3,packag:3,page:3,search:3,submodul:3},titles:["Model","Conformer Modules","Submodules","Welcome to Conformer\u2019s documentation!"],titleterms:{activ:2,attent:1,conform:[0,1,3],convolut:1,decod:0,document:3,embed:2,encod:0,feed:1,forward:1,indic:3,model:0,modul:[1,2],submodul:2,tabl:3,welcom:3}}) -------------------------------------------------------------------------------- /conformer/docs/source/Model.rst: -------------------------------------------------------------------------------- 1 | 2 | Model 3 | ===================================================== 4 | 5 | Conformer 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.model 9 | :members: 10 | 11 | Encoder 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.encoder 15 | :members: 16 | 17 | Decoder 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.decoder 21 | :members: 22 | -------------------------------------------------------------------------------- /conformer/docs/source/Modules.rst: -------------------------------------------------------------------------------- 1 | 2 | Conformer Modules 3 | ===================================================== 4 | 5 | Attention 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.attention 9 | :members: 10 | 11 | Convolution 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.convolution 15 | :members: 16 | 17 | Feed Forward 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.feed_forward 21 | :members: -------------------------------------------------------------------------------- /conformer/docs/source/Submodules.rst: -------------------------------------------------------------------------------- 1 | 2 | Submodules 3 | ===================================================== 4 | 5 | Activation 6 | -------------------------------------------- 7 | 8 | .. automodule:: conformer.activation 9 | :members: 10 | 11 | Modules 12 | -------------------------------------------- 13 | 14 | .. automodule:: conformer.modules 15 | :members: 16 | 17 | Embedding 18 | -------------------------------------------- 19 | 20 | .. automodule:: conformer.embedding 21 | :members: -------------------------------------------------------------------------------- /conformer/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | 14 | import os 15 | import sys 16 | sys.path.append(os.path.abspath('.')) 17 | sys.path.append(os.path.abspath('..')) 18 | import sphinx_rtd_theme 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'conformer' 24 | copyright = '2021, Soohwan Kim' 25 | author = 'Soohwan Kim' 26 | 27 | # The full version, including alpha/beta/rc tags 28 | release = 'latest' 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | 'sphinx.ext.autosummary', 38 | 'sphinx.ext.doctest', 39 | 'sphinx.ext.intersphinx', 40 | 'sphinx.ext.todo', 41 | 'sphinx.ext.coverage', 42 | 'sphinx.ext.mathjax', 43 | 'sphinx.ext.ifconfig', 44 | 'sphinx.ext.napoleon', 45 | "sphinx_rtd_theme", 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.imgmath', 48 | 'sphinx.ext.ifconfig', 49 | 'sphinx.ext.viewcode', 50 | 'sphinx.ext.githubpages', 51 | 'recommonmark', 52 | ] 53 | 54 | napoleon_use_ivar = True 55 | 56 | # Add any paths that contain templates here, relative to this directory. 57 | templates_path = ['_templates'] 58 | 59 | imgmath_image_format = 'svg' 60 | imgmath_latex = 'xelatex' 61 | imgmath_latex_args = ['--no-pdf'] 62 | 63 | # Source parsers 64 | source_parsers = { 65 | #'.md': 'recommonmark.parser.CommonMarkParser' 66 | } 67 | 68 | # The suffix(es) of source filenames. 69 | # You can specify multiple suffix as a list of string: 70 | # 71 | # source_suffix = ['.rst', '.md'] 72 | source_suffix = ['.rst', '.md'] 73 | 74 | # The master toctree document. 75 | master_doc = 'index' 76 | 77 | # The language for content autogenerated by Sphinx. Refer to documentation 78 | # for a list of supported languages. 79 | # 80 | # This is also used if you do content translation via gettext catalogs. 81 | # Usually you set "language" from the command line for these cases. 82 | language = None 83 | 84 | # List of patterns, relative to source directory, that match files and 85 | # directories to ignore when looking for source files. 86 | # This pattern also affects html_static_path and html_extra_path. 87 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 88 | 89 | # The name of the Pygments (syntax highlighting) style to use. 90 | # The name of the Pygments (syntax highlighting) style to use. 91 | pygments_style = 'sphinx' 92 | 93 | # If true, `todo` and `todoList` produce output, else they produce nothing. 94 | todo_include_todos = True 95 | 96 | # -- Options for HTML output ------------------------------------------------- 97 | 98 | # The theme to use for HTML and HTML Help pages. See the documentation for 99 | # a list of builtin themes. 100 | # 101 | html_theme = 'sphinx_rtd_theme' 102 | 103 | # Theme options are theme-specific and customize the look and feel of a theme 104 | # further. For a list of options available for each theme, see the 105 | # documentation. 
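# A small sphinx_rtd_theme example (illustrative only; these options are not set
# in this project's conf.py) could look like:
#
#     html_theme_options = {
#         'collapse_navigation': False,  # keep the sidebar toctree expanded
#         'navigation_depth': 2,         # how many toctree levels the sidebar shows
#     }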
106 | # 107 | # html_theme_options = {} 108 | 109 | # Add any paths that contain custom static files (such as style sheets) here, 110 | # relative to this directory. They are copied after the builtin static files, 111 | # so a file named "default.css" will overwrite the builtin "default.css". 112 | html_static_path = ['_static'] 113 | 114 | # Custom sidebar templates, must be a dictionary that maps document names 115 | # to template names. 116 | # 117 | # The default sidebars (for documents that don't match any pattern) are 118 | # defined by theme itself. Builtin themes are using these templates by 119 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 120 | # 'searchbox.html']``. 121 | # 122 | # html_sidebars = {} 123 | 124 | 125 | # -- Options for HTMLHelp output --------------------------------------------- 126 | 127 | # Output file base name for HTML help builder. 128 | htmlhelp_basename = 'Conformer.doc' 129 | 130 | 131 | # -- Options for LaTeX output ------------------------------------------------ 132 | 133 | latex_elements = { 134 | # The paper size ('letterpaper' or 'a4paper'). 135 | # 136 | # 'papersize': 'letterpaper', 137 | 138 | # The font size ('10pt', '11pt' or '12pt'). 139 | # 140 | # 'pointsize': '10pt', 141 | 142 | # Additional stuff for the LaTeX preamble. 143 | # 144 | # 'preamble': '', 145 | 146 | # Latex figure (float) alignment 147 | # 148 | # 'figure_align': 'htbp', 149 | } 150 | 151 | # Grouping the document tree into LaTeX files. List of tuples 152 | # (source start file, target name, title, 153 | # author, documentclass [howto, manual, or own class]). 154 | latex_documents = [ 155 | (master_doc, 'Conformer.tex', 'Convolution-augmented Transformer for Speech Recognition', 156 | 'SooHwan Kim', 'sooftware'), 157 | ] 158 | 159 | 160 | # -- Options for manual page output ------------------------------------------ 161 | 162 | # One entry per manual page. List of tuples 163 | # (source start file, name, description, authors, manual section). 164 | man_pages = [ 165 | (master_doc, 'Conformer', 'Convolution-augmented Transformer for Speech Recognition', 166 | [author], 1) 167 | ] 168 | 169 | 170 | # -- Options for Texinfo output ---------------------------------------------- 171 | 172 | # Grouping the document tree into Texinfo files. List of tuples 173 | # (source start file, target name, title, author, 174 | # dir menu entry, description, category) 175 | texinfo_documents = [ 176 | (master_doc, 'Conformer', 'Convolution-augmented Transformer for Speech Recognition', 177 | author, 'Soohwan Kim', 'sooftware', 178 | 'Miscellaneous'), 179 | ] 180 | 181 | 182 | # -- Options for Epub output ------------------------------------------------- 183 | 184 | # Bibliographic Dublin Core info. 185 | epub_title = project 186 | 187 | # The unique identifier of the text. This can be a ISBN number 188 | # or the project homepage. 189 | # 190 | # epub_identifier = '' 191 | 192 | # A unique identification for the text. 193 | # 194 | # epub_uid = '' 195 | 196 | # A list of files that should not be packed into the epub file. 197 | epub_exclude_files = ['search.html'] 198 | 199 | 200 | # Example configuration for intersphinx: refer to the Python standard library. 
201 | intersphinx_mapping = { 202 | 'python': ('https://docs.python.org/', None), 203 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 204 | 'PyTorch': ('http://pytorch.org/docs/master/', None), 205 | } 206 | -------------------------------------------------------------------------------- /conformer/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. conformer documentation master file, created by 2 | sphinx-quickstart on Sun Jan 24 01:16:16 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Conformer's documentation! 7 | ===================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: PACKAGE 12 | 13 | Model 14 | Modules 15 | Submodules 16 | 17 | 18 | 19 | Indices and tables 20 | ================== 21 | 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | * :ref:`search` 25 | -------------------------------------------------------------------------------- /conformer/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import setup, find_packages 16 | 17 | setup( 18 | name='conformer', 19 | packages = find_packages(), 20 | version='latest', 21 | description='Convolution-augmented Transformer for Speech Recognition', 22 | author='Soohwan Kim', 23 | author_email='sh951011@gmail.com', 24 | url='https://github.com/sooftware/conformer', 25 | install_requires=[ 26 | 'torch>=1.4.0', 27 | 'numpy', 28 | ], 29 | keywords=['asr', 'speech_recognition', 'conformer', 'end-to-end'], 30 | python_requires='>=3.6' 31 | ) 32 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | 8 | from utils.model import get_model 9 | from utils.tools import to_device, to_device_inference, log, plot_spectrogram_to_numpy, plot_alignment_to_numpy 10 | from model import CVAEJETSLoss 11 | from data_utils import AudioTextDataset, AudioTextCollate, DataLoader 12 | from mel_processing import mel_spectrogram_torch 13 | 14 | 15 | def evaluate(models, step, configs, device, logger=None): 16 | model, discriminator = models 17 | preprocess_config, model_config, train_config = configs 18 | hop_size = preprocess_config["preprocessing"]["stft"]["hop_length"] 19 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 20 | 21 | # Get dataset 22 | dataset = AudioTextDataset( 23 | preprocess_config['path']['validation_files'], preprocess_config) 24 | 25 | batch_size = train_config["optimizer"]["batch_size"] 26 | collate_fn = AudioTextCollate() 27 | loader = DataLoader( 28 | dataset, 29 | batch_size=batch_size, 30 | shuffle=False, 31 | collate_fn=collate_fn, 32 | num_workers=8, 33 | pin_memory=True, 34 | drop_last=False 35 | ) 36 | 37 | # Get loss function 38 | Loss = CVAEJETSLoss(preprocess_config, model_config, train_config).to(device) 39 | 40 | # Evaluation 41 | loss_sums_disc = [0 for _ in range(1)] # + total 42 | loss_sums_model = [0 for _ in range(12)] # + total 43 | for batch in loader: 44 | batch = to_device(batch, device) 45 | 46 | with torch.no_grad(): 47 | output = model(*(batch[:-1]), step=step, gen=False) 48 | 49 | wav_predictions, indices = output[0], output[7] 50 | wav_targets = batch[-1][...,indices[0]*hop_size:indices[1]*hop_size] 51 | 52 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = discriminator(wav_targets.unsqueeze(1), wav_predictions) 53 | 54 | loss_disc, losses_disc = Loss.disc_loss_fn( 55 | disc_real_outputs=y_d_hat_r, 56 | disc_generated_outputs=y_d_hat_g) 57 | 58 | loss_model, losses_model = Loss.gen_loss_fn( 59 | inputs=batch, 60 | predictions=output, 61 | step=step, 62 | disc_outputs=y_d_hat_g, 63 | fmap_r=fmap_r, 64 | fmap_g=fmap_g) 65 | 66 | for i in range(len(losses_disc)): 67 | loss_sums_disc[i] += list(losses_disc.values())[i].item() * len(batch[0]) 68 | for i in range(len(losses_model)): 69 | loss_sums_model[i] += list(losses_model.values())[i].item() * len(batch[0]) 70 | 71 | # get scalars 72 | loss_means_disc = [loss_sum / len(dataset) for loss_sum in loss_sums_disc] 73 | loss_means_model = [loss_sum / len(dataset) for loss_sum in loss_sums_model] 74 | scalars_disc = {k:v for k,v in zip(losses_disc.keys(), loss_means_disc)} 75 | scalars_model = {k:v for k,v in zip(losses_model.keys(), loss_means_model)} 76 | 77 | message1 = f"Discriminator Validation Step {step}, " + " ".join([str(round(l.item(), 4)) for l in losses_disc.values()]).strip() 78 | message2 = 
f"Model Validation Step {step}, " + " ".join([str(round(l.item(), 4)) for l in losses_model.values()]).strip() 79 | message = f"{message1}\n{message2}" 80 | 81 | # synthesis one sample 82 | with torch.no_grad(): 83 | # segmented output 84 | for i in range(len(batch)-1): 85 | try: 86 | batch[i] = batch[i][:1] 87 | except: 88 | pass 89 | 90 | output = model(*(batch[:-1]), step=step) 91 | wav = output[0] 92 | mel = Loss.synthesizer_loss.get_mel(wav) 93 | wav_len = output[9][0].item() * hop_size 94 | attn_h = output[10] 95 | attn_s = output[11] 96 | 97 | # total output 98 | pairs = to_device_inference( 99 | [batch[0][:1], batch[1][:1], batch[2][:1], None], device) 100 | output_gen = model(*(pairs), gen=True) 101 | wav_gen = output_gen[0] 102 | mel_gen = Loss.synthesizer_loss.get_mel(wav_gen) 103 | wav_gen_len = output_gen[9][0].item() * hop_size 104 | 105 | image_dict = { 106 | "gen/mel": plot_spectrogram_to_numpy(mel[0].cpu().numpy()), 107 | "gen/mel_gen": plot_spectrogram_to_numpy(mel_gen[0].cpu().numpy()), 108 | "all/attn_h": plot_alignment_to_numpy(attn_h[0,0].data.cpu().numpy()), 109 | "all/attn_s": plot_alignment_to_numpy(attn_s[0,0].data.cpu().numpy()) 110 | } 111 | audio_dict = { 112 | "gen/audio": wav[0,:,:wav_len], 113 | "gen/audio_gen": wav_gen[0,:,:wav_gen_len] 114 | } 115 | scalar_dict = {} 116 | scalar_dict.update(scalars_disc) 117 | scalar_dict.update(scalars_model) 118 | if logger is not None: 119 | log(writer=logger, 120 | global_step=step, 121 | images=image_dict, 122 | audios=audio_dict, 123 | scalars=scalar_dict, 124 | audio_sampling_rate=sampling_rate) 125 | 126 | return message 127 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = 
torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 67 | center=center, pad_mode='reflect', normalized=False, onesided=True) 68 | 69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 70 | return spec 71 | 72 | 73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 74 | global mel_basis 75 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 76 | fmax_dtype_device = str(fmax) + '_' + dtype_device 77 | if fmax_dtype_device not in mel_basis: 78 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 81 | spec = spectral_normalize_torch(spec) 82 | return spec 83 | 84 | 85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) > 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | global mel_basis, hann_window 92 | dtype_device = str(y.dtype) + '_' + str(y.device) 93 | fmax_dtype_device = str(fmax) + '_' + dtype_device 94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 95 | if fmax_dtype_device not in mel_basis: 96 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 98 | if wnsize_dtype_device not in hann_window: 99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 100 | 101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 102 | y = y.squeeze(1) 103 | 104 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 105 | center=center, pad_mode='reflect', normalized=False, onesided=True) 106 | 107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 108 | energy = torch.norm(spec, dim=1) 109 | 110 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 111 | spec = spectral_normalize_torch(spec) 112 | 113 | return spec, energy -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .cvaejets import CVAEJETSSynthesizer, MultiPeriodDiscriminator 2 | from .loss import CVAEJETSLoss 3 | -------------------------------------------------------------------------------- /model/cvaejets.py: -------------------------------------------------------------------------------- 1 | from .modules import ( 2 | VarianceAdaptor, 3 | MultiPeriodDiscriminator, 4 | Generator, 5 | ResidualCouplingBlock, 6 | PosteriorEncoder, 7 | ) 8 | from utils.tools import get_mask_from_lengths, partial 9 | from conformer import Conformer as Encoder 10 | from text import symbols 11 | import torch.nn as nn 12 | import torch 13 | import json 14 | import os 15 | 16 | 17 | 18 | class CVAEJETSSynthesizer(nn.Module): 19 | def __init__(self, preprocess_config, model_config, train_config): 20 | super(CVAEJETSSynthesizer, self).__init__() 21 | self.preprocess_config = preprocess_config 22 | self.model_config = model_config 23 | self.train_config = train_config 24 | 25 | speaker_ids_path = 
os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json") 26 | assert os.path.isfile(speaker_ids_path) 27 | with open(speaker_ids_path, "r", encoding='utf8') as f: 28 | n_speaker = len(json.load(f)) 29 | self.speaker_emb = nn.Embedding( 30 | n_speaker, 31 | self.model_config["speaker_encoder"]["speaker_encoder_hidden"], 32 | ) 33 | self.embedding = nn.Embedding( 34 | len(symbols), model_config["transformer"]["encoder_hidden"], padding_idx=0) 35 | self.encoder = Encoder(self.model_config) 36 | self.posterior_encoder = PosteriorEncoder( 37 | self.preprocess_config, self.model_config) 38 | self.variance_adaptor = VarianceAdaptor( 39 | self.preprocess_config, self.model_config, self.train_config) 40 | self.flow = ResidualCouplingBlock(self.model_config) 41 | self.generator = Generator(self.model_config) 42 | 43 | 44 | def forward( 45 | self, 46 | speakers, 47 | texts, 48 | src_lens, 49 | max_src_len, 50 | mels=None, 51 | mel_lens=None, 52 | max_mel_len=None, 53 | cwt_spec_targets=None, 54 | cwt_mean_target=None, 55 | cwt_std_target=None, 56 | uv=None, 57 | e_targets=None, 58 | attn_priors=None, 59 | p_control=1.0, 60 | e_control=1.0, 61 | d_control=1.0, 62 | step=None, 63 | gen=False, 64 | noise_scale=1.0, 65 | ): 66 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 67 | mel_masks = ( 68 | get_mask_from_lengths(mel_lens, max_mel_len) 69 | if mel_lens is not None 70 | else None 71 | ) 72 | x = self.embedding(texts) 73 | x, src_lens = self.encoder(x, src_lens) 74 | g = self.speaker_emb(speakers).unsqueeze(-1) 75 | 76 | ( 77 | m_p, 78 | logs_p, 79 | p_predictions, 80 | e_predictions, 81 | log_d_predictions, 82 | d_rounded, 83 | mel_lens, 84 | mel_masks, 85 | attn_h, 86 | attn_s, 87 | attn_logprob 88 | ) = self.variance_adaptor( 89 | x, 90 | src_lens, 91 | src_masks, 92 | mels, 93 | mel_lens, 94 | mel_masks, 95 | max_mel_len, 96 | cwt_spec_targets, 97 | cwt_mean_target, 98 | cwt_std_target, 99 | uv, 100 | e_targets, 101 | attn_priors, 102 | g, 103 | p_control, 104 | e_control, 105 | d_control, 106 | step, 107 | gen, 108 | ) 109 | 110 | if not gen: 111 | z, m_q, logs_q, _ = self.posterior_encoder(mels, (~mel_masks).float().unsqueeze(1), g=g) 112 | z_p = self.flow(z, (~mel_masks).float().unsqueeze(1), g=g) 113 | z, indices = partial( 114 | y=z, 115 | segment_size=self.model_config["generator"]["segment_size"], 116 | hop_size=self.preprocess_config["preprocessing"]["stft"]["hop_length"]) 117 | else: 118 | m_q, logs_q, indices = None, None, [None, None] 119 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 120 | z = self.flow(z_p, (~mel_masks).float().unsqueeze(1), g=g, reverse=True) 121 | 122 | wav = self.generator(z, g=g) 123 | 124 | return ( 125 | wav, 126 | p_predictions, 127 | e_predictions, 128 | log_d_predictions, 129 | d_rounded, 130 | src_masks, 131 | mel_masks, 132 | indices, 133 | src_lens, 134 | mel_lens, 135 | attn_h, 136 | attn_s, 137 | attn_logprob, 138 | z_p, 139 | m_p, 140 | logs_p, 141 | m_q, 142 | logs_q 143 | ) 144 | 145 | def voice_conversion(self, mels, mel_lens, max_mel_len, sid_src, sid_tgt): 146 | mel_masks = get_mask_from_lengths(mel_lens, max_mel_len) 147 | g_src = self.speaker_emb(sid_src).unsqueeze(-1) 148 | g_tgt = self.speaker_emb(sid_tgt).unsqueeze(-1) 149 | z, m_q, logs_q, y_mask = self.posterior_encoder(mels, (~mel_masks).float().unsqueeze(1), g=g_src) 150 | z_p = self.flow(z, y_mask, g=g_src) 151 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 152 | o_hat = self.generator(z_hat * y_mask, g=g_tgt) 153 | return 
o_hat, y_mask, (z, z_p, z_hat) 154 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Linear(nn.Module): 5 | def __init__(self, in_dim, out_dim, bias=True, w_init="linear"): 6 | super(Linear, self).__init__() 7 | self.in_dim = in_dim 8 | self.out_dim = out_dim 9 | self.linear = nn.Linear(in_dim, out_dim, bias=bias) 10 | 11 | nn.init.xavier_uniform_( 12 | self.linear.weight, gain=nn.init.calculate_gain(w_init) 13 | ) 14 | 15 | def forward(self, x): 16 | return self.linear(x) 17 | 18 | 19 | class Conv(nn.Module): 20 | """ 21 | Convolution Module 22 | """ 23 | 24 | def __init__( 25 | self, 26 | in_channels, 27 | out_channels, 28 | kernel_size=1, 29 | stride=1, 30 | padding=0, 31 | dilation=1, 32 | bias=True, 33 | w_init="linear", 34 | ): 35 | """ 36 | :param in_channels: dimension of input 37 | :param out_channels: dimension of output 38 | :param kernel_size: size of kernel 39 | :param stride: size of stride 40 | :param padding: size of padding 41 | :param dilation: dilation rate 42 | :param bias: boolean. if True, bias is included. 43 | :param w_init: str. weight inits with xavier initialization. 44 | """ 45 | super(Conv, self).__init__() 46 | 47 | self.conv = nn.Conv1d( 48 | in_channels, 49 | out_channels, 50 | kernel_size=kernel_size, 51 | stride=stride, 52 | padding=padding, 53 | dilation=dilation, 54 | bias=bias, 55 | ) 56 | nn.init.xavier_uniform_( 57 | self.conv.weight, gain=nn.init.calculate_gain(w_init) 58 | ) 59 | 60 | def forward(self, x): 61 | x = x.contiguous().transpose(1, 2) 62 | x = self.conv(x) 63 | x = x.contiguous().transpose(1, 2) 64 | 65 | return x 66 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from data_utils import StatParser 2 | import argparse 3 | import yaml 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 9 | args = parser.parse_args() 10 | 11 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 12 | preprocessor = StatParser(config, True) 13 | preprocessor() 14 | -------------------------------------------------------------------------------- /preprocessed_data/kss/speakers.json: -------------------------------------------------------------------------------- 1 | {"kss": 0} -------------------------------------------------------------------------------- /preprocessed_data/kss/stats.json: -------------------------------------------------------------------------------- 1 | {"energy": [-1.236502766609192, 7.244208812713623, 43.81615813554377, 35.41429019837838]} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | librosa 3 | matplotlib 4 | numpy 5 | scipy 6 | tqdm 7 | Unidecode 8 | inflect 9 | jamo 10 | pandas 11 | jpype1 12 | konlpy 13 | jamotools 14 | tweepy==3.10.0 15 | nltk 16 | pillow 17 | natsort 18 | easydict 19 | g2pk 20 | yaml 21 | praat-parselmouth 22 | pycwt -------------------------------------------------------------------------------- /samples/CVAEJETS-sample-0.75.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-sample-0.75.wav -------------------------------------------------------------------------------- /samples/CVAEJETS-sample-1.00-vc.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-sample-1.00-vc.wav -------------------------------------------------------------------------------- /samples/CVAEJETS-sample-1.00.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-sample-1.00.wav -------------------------------------------------------------------------------- /samples/CVAEJETS-sample-1.50.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-sample-1.50.wav -------------------------------------------------------------------------------- /samples/CVAEJETS-tensorboard-losses1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-tensorboard-losses1.png -------------------------------------------------------------------------------- /samples/CVAEJETS-tensorboard-losses2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-tensorboard-losses2.png -------------------------------------------------------------------------------- /samples/CVAEJETS-tensorboard-stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choiHkk/CVAEJETS/df71e7d484926c5ff0eac1a4749dbb2cafe81a6d/samples/CVAEJETS-tensorboard-stats.png -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/sooftware/taKotron2/blob/master/text/__init__.py 2 | import re 3 | import unicodedata 4 | from g2pk import G2p 5 | 6 | CHOSUNGS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 7 | JOONGSUNGS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 8 | JONGSUNGS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 9 | SPECIALS = " ?!" 
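# The three ranges above cover the modern Hangul Jamo block:
#   CHOSUNGS   (U+1100–U+1112): the 19 leading consonants (choseong)
#   JOONGSUNGS (U+1161–U+1175): the 21 vowels (jungseong)
#   JONGSUNGS  (U+11A8–U+11C2): the 27 trailing consonants (jongseong)
# Together with SPECIALS (" ?!") they make up ALL_VOCABS below; VOCAB_DICT then adds
# "_" (index 0, used as the embedding padding_idx) and "~" (appended by tokenize()
# as an end-of-text marker), for 72 symbols in total.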
10 | 11 | ALL_VOCABS = "".join([ 12 | CHOSUNGS, 13 | JOONGSUNGS, 14 | JONGSUNGS, 15 | SPECIALS 16 | ]) 17 | VOCAB_DICT = { 18 | "_": 0, 19 | "~": 1, 20 | } 21 | 22 | for idx, v in enumerate(ALL_VOCABS): 23 | VOCAB_DICT[v] = idx + 2 24 | 25 | symbols = VOCAB_DICT.keys() 26 | 27 | g2p = G2p() 28 | 29 | 30 | def normalize(text): 31 | text = unicodedata.normalize('NFKD', text) 32 | text = text.upper() 33 | regex = unicodedata.normalize('NFKD', r"[^ \u11A8-\u11FF\u1100-\u115E\u1161-\u11A7?!]") 34 | text = re.sub(regex, '', text) 35 | text = re.sub(' +', ' ', text) 36 | text = text.strip() 37 | return text 38 | 39 | 40 | def tokenize(text, encoding: bool = True): 41 | tokens = list() 42 | 43 | for t in text: 44 | if encoding: 45 | tokens.append(VOCAB_DICT[t]) 46 | else: 47 | tokens.append(t) 48 | 49 | if encoding: 50 | tokens.append(VOCAB_DICT['~']) 51 | else: 52 | tokens.append('~') 53 | 54 | return tokens 55 | 56 | 57 | def text_to_sequence(text): 58 | text = g2p(text) 59 | text = normalize(text) 60 | tokens = tokenize(text, encoding=True) 61 | return tokens -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torch.cuda.amp import autocast, GradScaler 9 | from tqdm import tqdm 10 | 11 | from utils.model import get_model, get_param_num 12 | from utils.tools import to_device, log, clip_grad_value_, AttrDict 13 | from model import CVAEJETSLoss 14 | from data_utils import AudioTextDataset, AudioTextCollate, DataLoader, DistributedBucketSampler 15 | from evaluate import evaluate 16 | import json 17 | import random 18 | random.seed(1234) 19 | 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | torch.autograd.set_detect_anomaly(True) 24 | def main(args, configs): 25 | print("Prepare training ...") 26 | 27 | preprocess_config, model_config, train_config = configs 28 | hop_size = preprocess_config["preprocessing"]["stft"]["hop_length"] 29 | 30 | dataset = AudioTextDataset( 31 | preprocess_config['path']['training_files'], preprocess_config) 32 | 33 | batch_size = train_config["optimizer"]["batch_size"] 34 | train_sampler = DistributedBucketSampler( 35 | dataset, 36 | batch_size, 37 | [32,300,400,500,600,700,800,900,1000], 38 | num_replicas=1, 39 | rank=0, 40 | shuffle=True) 41 | collate_fn = AudioTextCollate() 42 | loader = DataLoader( 43 | dataset, 44 | num_workers=8, 45 | shuffle=False, 46 | pin_memory=True, 47 | collate_fn=collate_fn, 48 | batch_sampler=train_sampler 49 | ) 50 | 51 | # Prepare model 52 | (model, discriminator, 53 | model_optimizer, discriminator_optimizer, 54 | scheduler_model, scheduler_discriminator, 55 | epoch) = get_model( 56 | args, configs, device, train=True) 57 | 58 | scaler = GradScaler(enabled=train_config["fp16_run"]) 59 | 60 | model = nn.DataParallel(model) 61 | discriminator = nn.DataParallel(discriminator) 62 | model_num_param = get_param_num(model) 63 | discriminator_num_param = get_param_num(discriminator) 64 | Loss = CVAEJETSLoss(preprocess_config, model_config, train_config).to(device) 65 | print("Number of JETS Parameters:", model_num_param) 66 | print("Number of Discriminator Parameters:", discriminator_num_param) 67 | 68 | # Init logger 69 | for p in train_config["path"].values(): 70 | os.makedirs(p, exist_ok=True) 71 | train_log_path = 
os.path.join(train_config["path"]["log_path"], "train") 72 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 73 | os.makedirs(train_log_path, exist_ok=True) 74 | os.makedirs(val_log_path, exist_ok=True) 75 | train_logger = SummaryWriter(train_log_path) 76 | val_logger = SummaryWriter(val_log_path) 77 | 78 | # Training 79 | step = args.restore_step + 1 80 | total_step = train_config["step"]["total_step"] 81 | log_step = train_config["step"]["log_step"] 82 | save_step = train_config["step"]["save_step"] 83 | val_step = train_config["step"]["val_step"] 84 | 85 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 86 | outer_bar.n = args.restore_step 87 | outer_bar.update() 88 | 89 | while True: 90 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 91 | for batch in loader: 92 | batch = to_device(batch, device) 93 | 94 | with autocast(enabled=train_config["fp16_run"]): 95 | output = model(*(batch[:-1]), step=step, gen=False) 96 | 97 | # wav_predictions, wav_targets, indices 98 | wav_predictions, indices = output[0], output[7] 99 | wav_targets = batch[-1].unsqueeze(1)[...,indices[0]*hop_size:indices[1]*hop_size] 100 | 101 | # Discriminator 102 | y_d_hat_r, y_d_hat_g, _, _ = discriminator(wav_targets, wav_predictions.detach()) 103 | 104 | with autocast(enabled=False): 105 | loss_disc, losses_disc = Loss.disc_loss_fn( 106 | disc_real_outputs=y_d_hat_r, disc_generated_outputs=y_d_hat_g) 107 | 108 | # Discriminator Backward 109 | discriminator_optimizer.zero_grad() 110 | scaler.scale(loss_disc).backward() 111 | scaler.unscale_(discriminator_optimizer) 112 | grad_norm_discriminator = clip_grad_value_(discriminator.parameters(), None) 113 | scaler.step(discriminator_optimizer) 114 | 115 | with autocast(enabled=train_config["fp16_run"]): 116 | # Generator 117 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = discriminator(wav_targets, wav_predictions) 118 | 119 | with autocast(enabled=False): 120 | loss_model, losses_model = Loss.gen_loss_fn( 121 | inputs=batch, 122 | predictions=output, 123 | step=step, 124 | disc_outputs=y_d_hat_g, 125 | fmap_r=fmap_r, 126 | fmap_g=fmap_g) 127 | 128 | # Generator Backward 129 | model_optimizer.zero_grad() 130 | scaler.scale(loss_model).backward() 131 | scaler.unscale_(model_optimizer) 132 | grad_norm_model = clip_grad_value_(model.parameters(), None) 133 | scaler.step(model_optimizer) 134 | scaler.update() 135 | 136 | if step % log_step == 0: 137 | lr = model_optimizer.param_groups[0]['lr'] 138 | message1 = "Step {}/{}, ".format(step, total_step) 139 | message2 = " ".join( 140 | [str(round(l.item(), 4)) for l in losses_disc.values()] + 141 | [str(round(l.item(), 4)) for l in losses_model.values()] + 142 | [str(round(grad_norm_model, 4)), str(round(grad_norm_discriminator, 4)), str(lr)] 143 | ).strip() 144 | 145 | with open(os.path.join(train_log_path, "log.txt"), "a") as f: 146 | f.write(message1 + message2 + "\n") 147 | 148 | outer_bar.write(message1 + message2) 149 | 150 | scalars = {} 151 | scalars.update(losses_disc) 152 | scalars.update(losses_model) 153 | scalars.update( 154 | { 155 | "learning_rate": lr, 156 | "grad_norm_discriminator": grad_norm_discriminator, 157 | "grad_norm_model": grad_norm_model 158 | } 159 | ) 160 | log(writer=train_logger, 161 | global_step=step, 162 | scalars=scalars) 163 | 164 | if step % val_step == 0: 165 | model.eval() 166 | discriminator.eval() 167 | message = evaluate([model, discriminator], step, configs, device, val_logger) 168 | with open(os.path.join(val_log_path, 
"log.txt"), "a") as f: 169 | f.write(message + "\n") 170 | outer_bar.write(message) 171 | 172 | model.train() 173 | discriminator.train() 174 | 175 | if step % save_step == 0: 176 | torch.save( 177 | { 178 | "model": model.module.state_dict(), 179 | "discriminator": discriminator.module.state_dict(), 180 | "model_optimizer": model_optimizer.state_dict(), 181 | "discriminator_optimizer": discriminator_optimizer.state_dict(), 182 | "iteration": epoch, 183 | }, 184 | os.path.join( 185 | train_config["path"]["ckpt_path"], 186 | "{}.pth.tar".format(step), 187 | ), 188 | ) 189 | 190 | if step == total_step: 191 | quit() 192 | step += 1 193 | outer_bar.update(1) 194 | 195 | epoch += 1 196 | scheduler_model.step() 197 | scheduler_discriminator.step() 198 | inner_bar.update(1) 199 | 200 | 201 | if __name__ == "__main__": 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument("--restore_step", type=int, default=0) 204 | parser.add_argument( 205 | "-p", 206 | "--preprocess_config", 207 | type=str, 208 | required=True, 209 | help="path to preprocess.yaml", 210 | ) 211 | parser.add_argument( 212 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 213 | ) 214 | parser.add_argument( 215 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 216 | ) 217 | args = parser.parse_args() 218 | 219 | # Read Config 220 | preprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader) 221 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 222 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 223 | 224 | configs = (preprocess_config, model_config, train_config) 225 | 226 | main(args, configs) 227 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from model import CVAEJETSSynthesizer, MultiPeriodDiscriminator 8 | 9 | 10 | def get_model(args, configs, device, train=False): 11 | (preprocess_config, model_config, train_config) = configs 12 | 13 | model = CVAEJETSSynthesizer(preprocess_config, model_config, train_config).to(device) 14 | if args.restore_step: 15 | ckpt_path = os.path.join( 16 | train_config["path"]["ckpt_path"], 17 | "{}.pth.tar".format(args.restore_step), 18 | ) 19 | ckpt = torch.load(ckpt_path, map_location=device) 20 | model.load_state_dict(ckpt["model"]) 21 | 22 | if train: 23 | discriminator = MultiPeriodDiscriminator().to(device) 24 | 25 | model_optimizer = torch.optim.AdamW( 26 | model.parameters(), 27 | train_config["optimizer"]["learning_rate"], 28 | betas=train_config["optimizer"]["betas"], 29 | eps=train_config["optimizer"]["eps"]) 30 | discriminator_optimizer = torch.optim.AdamW( 31 | discriminator.parameters(), 32 | train_config["optimizer"]["learning_rate"], 33 | betas=train_config["optimizer"]["betas"], 34 | eps=train_config["optimizer"]["eps"]) 35 | 36 | if args.restore_step: 37 | discriminator.load_state_dict(ckpt["discriminator"]) 38 | model_optimizer.load_state_dict(ckpt["model_optimizer"]) 39 | discriminator_optimizer.load_state_dict(ckpt["discriminator_optimizer"]) 40 | iteration = ckpt['iteration'] 41 | else: 42 | iteration = 1 43 | 44 | scheduler_model = torch.optim.lr_scheduler.ExponentialLR( 45 | model_optimizer, gamma=train_config["optimizer"]["lr_decay"], last_epoch=iteration-2) 46 | scheduler_discriminator = 
torch.optim.lr_scheduler.ExponentialLR( 47 | discriminator_optimizer, gamma=train_config["optimizer"]["lr_decay"], last_epoch=iteration-2) 48 | 49 | model.train() 50 | discriminator.train() 51 | return model, discriminator, model_optimizer, discriminator_optimizer, scheduler_model, scheduler_discriminator, iteration 52 | 53 | model.eval() 54 | model.requires_grad_ = False 55 | return model 56 | 57 | 58 | def get_param_num(model): 59 | num_param = sum(param.numel() for param in model.parameters()) 60 | return num_param 61 | -------------------------------------------------------------------------------- /utils/pitch_utils.py: -------------------------------------------------------------------------------- 1 | ######### 2 | # world 3 | ######### 4 | import librosa 5 | import parselmouth 6 | import numpy as np 7 | import pyworld as pw 8 | import torch 9 | import torch.nn.functional as F 10 | from pycwt import wavelet 11 | from scipy.interpolate import interp1d 12 | 13 | gamma = 0 14 | mcepInput = 3 # 0 for dB, 3 for magnitude 15 | alpha = 0.45 16 | en_floor = 10 ** (-80 / 20) 17 | FFT_SIZE = 2048 18 | 19 | 20 | f0_bin = 256 21 | f0_max = 1100.0 22 | f0_min = 50.0 23 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 24 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 25 | 26 | 27 | def dur_to_mel2ph(dur, dur_padding=None, alpha=1.0): 28 | """ 29 | Example (no batch dim version): 30 | 1. dur = [2,2,3] 31 | 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4] 32 | 3. token_mask = [[1,1,0,0,0,0,0], 33 | [0,0,1,1,0,0,0], 34 | [0,0,0,0,1,1,1]] 35 | 4. token_idx * token_mask = [[1,1,0,0,0,0,0], 36 | [0,0,2,2,0,0,0], 37 | [0,0,0,0,3,3,3]] 38 | 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3] 39 | :param dur: Batch of durations of each frame (B, T_txt) 40 | :param dur_padding: Batch of padding of each frame (B, T_txt) 41 | :param alpha: duration rescale coefficient 42 | :return: 43 | mel2ph (B, T_speech) 44 | """ 45 | assert alpha > 0 46 | dur = torch.round(dur.float() * alpha).long() 47 | if dur_padding is not None: 48 | dur = dur * (1 - dur_padding.long()) 49 | token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device) 50 | dur_cumsum = torch.cumsum(dur, 1) 51 | dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode="constant", value=0) 52 | 53 | pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device) 54 | token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None]) 55 | mel2ph = (token_idx * token_mask.long()).sum(1) 56 | return mel2ph 57 | 58 | 59 | def get_f0cwt(f0): 60 | uv, cont_lf0_lpf = get_cont_lf0(f0) 61 | logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf) 62 | logf0s_mean_std_org = np.array([logf0s_mean_org, logf0s_std_org]) 63 | cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org 64 | Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) 65 | return Wavelet_lf0, scales, logf0s_mean_std_org 66 | 67 | 68 | def f0_to_coarse(f0): 69 | is_torch = isinstance(f0, torch.Tensor) 70 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 71 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 72 | 73 | f0_mel[f0_mel <= 1] = 1 74 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 75 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) 76 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 77 | return f0_coarse 78 | 79 | 80 | def 
norm_f0(f0, uv, config): 81 | is_torch = isinstance(f0, torch.Tensor) 82 | if config["pitch_norm"] == "standard": 83 | f0 = (f0 - config["f0_mean"]) / config["f0_std"] 84 | if config["pitch_norm"] == "log": 85 | eps = config["pitch_norm_eps"] 86 | f0 = torch.log2(f0 + eps) if is_torch else np.log2(f0 + eps) 87 | if uv is not None and config["use_uv"]: 88 | f0[uv > 0] = 0 89 | return f0 90 | 91 | 92 | def norm_interp_f0(f0, config): 93 | # is_torch = isinstance(f0, torch.Tensor) 94 | # if is_torch: 95 | # device = f0.device 96 | # f0 = f0.data.cpu().numpy() 97 | uv = f0 == 0 98 | f0 = norm_f0(f0, uv, config) 99 | if sum(uv) == len(f0): 100 | f0[uv] = 0 101 | elif sum(uv) > 0: 102 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 103 | # uv = torch.FloatTensor(uv) 104 | # f0 = torch.FloatTensor(f0) 105 | # if is_torch: 106 | # f0 = f0.to(device) 107 | return f0, uv 108 | 109 | 110 | def denorm_f0(f0, uv, config, pitch_padding=None, min=None, max=None): 111 | if config["pitch_norm"] == "standard": 112 | f0 = f0 * config["f0_std"] + config["f0_mean"] 113 | if config["pitch_norm"] == "log": 114 | f0 = 2 ** f0 115 | if min is not None: 116 | f0 = f0.clamp(min=min) 117 | if max is not None: 118 | f0 = f0.clamp(max=max) 119 | if uv is not None and config["use_uv"]: 120 | f0[uv > 0] = 0 121 | if pitch_padding is not None: 122 | f0[pitch_padding] = 0 123 | return f0 124 | 125 | 126 | def get_pitch(wav_data, mel, config): 127 | """ 128 | :param wav_data: [T] 129 | :param mel: [T, 80] 130 | :param config: 131 | :return: 132 | """ 133 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 134 | hop_length = config["preprocessing"]["stft"]["hop_length"] 135 | time_step = hop_length / sampling_rate * 1000 136 | f0_min = 80 137 | f0_max = 750 138 | 139 | if hop_length == 128: 140 | pad_size = 4 141 | elif hop_length == 256: 142 | pad_size = 2 143 | else: 144 | assert False 145 | 146 | f0 = parselmouth.Sound(wav_data, sampling_rate).to_pitch_ac( 147 | time_step=time_step / 1000, voicing_threshold=0.6, 148 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array["frequency"] 149 | f0 = f0[:len(mel)-8] # to avoid negative rpad 150 | lpad = pad_size * 2 151 | rpad = len(mel) - len(f0) - lpad 152 | f0 = np.pad(f0, [[lpad, rpad]], mode="constant") 153 | # mel and f0 are extracted by 2 different libraries. we should force them to have the same length. 154 | # Attention: we find that new version of some libraries could cause ``rpad'' to be a negetive value... 
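# (Concretely: with hop_length=256 the code above uses pad_size=2, so lpad = 4 frames of
#  left padding and rpad = len(mel) - len(f0) - lpad fills the right, giving
#  lpad + len(f0) + rpad == len(mel); the earlier f0 = f0[:len(mel)-8] clip is what keeps
#  rpad from going negative.)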
155 | # Just to be sure, we recommend users to set up the same environments as them in requirements_auto.txt (by Anaconda) 156 | delta_l = len(mel) - len(f0) 157 | assert np.abs(delta_l) <= 8 158 | if delta_l > 0: 159 | f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) 160 | f0 = f0[:len(mel)] 161 | 162 | # f0, t = pw.dio( 163 | # wav_data.astype(np.float64), 164 | # sampling_rate, 165 | # frame_period=hop_length / sampling_rate * 1000, 166 | # ) 167 | # f0 = pw.stonemask(wav_data.astype(np.float64), f0, t, sampling_rate) 168 | # if np.sum(f0 != 0) <= 1: 169 | # return None, None 170 | 171 | pitch_coarse = f0_to_coarse(f0) 172 | return f0, pitch_coarse 173 | 174 | 175 | def expand_f0_ph(f0, mel2ph, config): 176 | f0 = denorm_f0(f0, None, config) 177 | f0 = F.pad(f0, [1, 0]) 178 | f0 = torch.gather(f0, 1, mel2ph) # [B, T_mel] 179 | return f0 180 | 181 | 182 | ######### 183 | # cwt 184 | ######### 185 | 186 | 187 | def load_wav(wav_file, sr): 188 | wav, _ = librosa.load(wav_file, sr=sr, mono=True) 189 | return wav 190 | 191 | 192 | def convert_continuos_f0(f0): 193 | """CONVERT F0 TO CONTINUOUS F0 194 | Args: 195 | f0 (ndarray): original f0 sequence with the shape (T) 196 | Return: 197 | (ndarray): continuous f0 with the shape (T) 198 | """ 199 | # get uv information as binary 200 | f0 = np.copy(f0) 201 | uv = np.float32(f0 != 0) 202 | 203 | # get start and end of f0 204 | if (f0 == 0).all(): 205 | print("| all of the f0 values are 0.") 206 | return uv, f0 207 | start_f0 = f0[f0 != 0][0] 208 | end_f0 = f0[f0 != 0][-1] 209 | 210 | # padding start and end of f0 sequence 211 | start_idx = np.where(f0 == start_f0)[0][0] 212 | end_idx = np.where(f0 == end_f0)[0][-1] 213 | f0[:start_idx] = start_f0 214 | f0[end_idx:] = end_f0 215 | 216 | # get non-zero frame index 217 | nz_frames = np.where(f0 != 0)[0] 218 | 219 | # perform linear interpolation 220 | f = interp1d(nz_frames, f0[nz_frames]) 221 | cont_f0 = f(np.arange(0, f0.shape[0])) 222 | 223 | return uv, cont_f0 224 | 225 | 226 | def get_cont_lf0(f0, frame_period=5.0): 227 | uv, cont_f0_lpf = convert_continuos_f0(f0) 228 | # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20) 229 | cont_lf0_lpf = np.log(cont_f0_lpf) 230 | return uv, cont_lf0_lpf 231 | 232 | 233 | def get_lf0_cwt(lf0): 234 | """ 235 | input: 236 | signal of shape (N) 237 | output: 238 | Wavelet_lf0 of shape(10, N), scales of shape(10) 239 | """ 240 | mother = wavelet.MexicanHat() 241 | dt = 0.005 242 | dj = 1 243 | s0 = dt * 2 244 | J = 9 245 | 246 | Wavelet_lf0, scales, _, _, _, _ = wavelet.cwt(np.squeeze(lf0), dt, dj, s0, J, mother) 247 | # Wavelet.shape => (J + 1, len(lf0)) 248 | Wavelet_lf0 = np.real(Wavelet_lf0).T 249 | return Wavelet_lf0, scales 250 | 251 | 252 | def norm_scale(Wavelet_lf0): 253 | Wavelet_lf0_norm = np.zeros((Wavelet_lf0.shape[0], Wavelet_lf0.shape[1])) 254 | mean = Wavelet_lf0.mean(0)[None, :] 255 | std = Wavelet_lf0.std(0)[None, :] 256 | Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std 257 | return Wavelet_lf0_norm, mean, std 258 | 259 | 260 | def normalize_cwt_lf0(f0, mean, std): 261 | uv, cont_lf0_lpf = get_cont_lf0(f0) 262 | cont_lf0_norm = (cont_lf0_lpf - mean) / std 263 | Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_norm) 264 | Wavelet_lf0_norm, _, _ = norm_scale(Wavelet_lf0) 265 | 266 | return Wavelet_lf0_norm 267 | 268 | 269 | def get_lf0_cwt_norm(f0s, mean, std): 270 | uvs = [] 271 | cont_lf0_lpfs = [] 272 | cont_lf0_lpf_norms = [] 273 | Wavelet_lf0s = [] 274 | Wavelet_lf0s_norm = [] 275 | scaless = [] 276 | 277 | 
means = [] 278 | stds = [] 279 | for f0 in f0s: 280 | uv, cont_lf0_lpf = get_cont_lf0(f0) 281 | cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std 282 | 283 | Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) # [560,10] 284 | Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0) # [560,10],[1,10],[1,10] 285 | 286 | Wavelet_lf0s_norm.append(Wavelet_lf0_norm) 287 | uvs.append(uv) 288 | cont_lf0_lpfs.append(cont_lf0_lpf) 289 | cont_lf0_lpf_norms.append(cont_lf0_lpf_norm) 290 | Wavelet_lf0s.append(Wavelet_lf0) 291 | scaless.append(scales) 292 | means.append(mean_scale) 293 | stds.append(std_scale) 294 | 295 | return Wavelet_lf0s_norm, scaless, means, stds 296 | 297 | 298 | def inverse_cwt_torch(Wavelet_lf0, scales): 299 | import torch 300 | b = ((torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :] + 1 + 2.5) ** (-2.5)) 301 | lf0_rec = Wavelet_lf0 * b 302 | lf0_rec_sum = lf0_rec.sum(-1) 303 | lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdim=True)) / lf0_rec_sum.std(-1, keepdim=True) 304 | return lf0_rec_sum 305 | 306 | 307 | def inverse_cwt(Wavelet_lf0, scales): 308 | b = ((np.arange(0, len(scales))[None, None, :] + 1 + 2.5) ** (-2.5)) 309 | lf0_rec = Wavelet_lf0 * b 310 | lf0_rec_sum = lf0_rec.sum(-1) 311 | lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdims=True)) / lf0_rec_sum.std(-1, keepdims=True) 312 | return lf0_rec_sum 313 | 314 | 315 | def cwt2f0(cwt_spec, mean, std, cwt_scales): 316 | assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3 317 | import torch 318 | if isinstance(cwt_spec, torch.Tensor): 319 | f0 = inverse_cwt_torch(cwt_spec, cwt_scales) 320 | f0 = f0 * std[:, None] + mean[:, None] 321 | f0 = f0.exp() # [B, T] 322 | else: 323 | f0 = inverse_cwt(cwt_spec, cwt_scales) 324 | f0 = f0 * std[:, None] + mean[:, None] 325 | f0 = np.exp(f0) # [B, T] 326 | return f0 327 | 328 | 329 | def cwt2f0_norm(cwt_spec, mean, std, mel2ph, config): 330 | f0 = cwt2f0(cwt_spec, mean, std, config["cwt_scales"]) 331 | f0 = torch.cat( 332 | [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1) 333 | f0_norm = norm_f0(f0, None, config) 334 | return f0_norm -------------------------------------------------------------------------------- /utils/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | Args: 15 | x (Tensor): Input signal tensor (B, T). 16 | fft_size (int): FFT size. 17 | hop_size (int): Hop size. 18 | win_length (int): Window length. 19 | window (str): Window function type. 20 | Returns: 21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 
22 | """ 23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window.to(x.device)) 24 | real = x_stft[..., 0] 25 | imag = x_stft[..., 1] 26 | 27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 29 | 30 | 31 | class SpectralConvergengeLoss(torch.nn.Module): 32 | """Spectral convergence loss module.""" 33 | 34 | def __init__(self): 35 | """Initilize spectral convergence loss module.""" 36 | super(SpectralConvergengeLoss, self).__init__() 37 | 38 | def forward(self, x_mag, y_mag): 39 | """Calculate forward propagation. 40 | Args: 41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 43 | Returns: 44 | Tensor: Spectral convergence loss value. 45 | """ 46 | 47 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 48 | 49 | 50 | # if(y_mag.shape == x_mag.shape): 51 | # return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 52 | # else: 53 | # return torch.norm(y_mag[:,:-1,:] - x_mag, p="fro") / torch.norm(y_mag[:,:-1,:], p="fro") 54 | 55 | 56 | class LogSTFTMagnitudeLoss(torch.nn.Module): 57 | """Log STFT magnitude loss module.""" 58 | 59 | def __init__(self): 60 | """Initilize los STFT magnitude loss module.""" 61 | super(LogSTFTMagnitudeLoss, self).__init__() 62 | 63 | def forward(self, x_mag, y_mag): 64 | """Calculate forward propagation. 65 | Args: 66 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 67 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 68 | Returns: 69 | Tensor: Log STFT magnitude loss value. 70 | """ 71 | 72 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 73 | 74 | 75 | class STFTLoss(torch.nn.Module): 76 | """STFT loss module.""" 77 | 78 | def __init__(self, fft_size, shift_size, win_length, window="hann_window"): 79 | """Initialize STFT loss module.""" 80 | super(STFTLoss, self).__init__() 81 | self.fft_size = fft_size 82 | self.shift_size = shift_size 83 | self.win_length = win_length 84 | self.window = getattr(torch, window)(win_length) 85 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 86 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 87 | 88 | def forward(self, x, y): 89 | """Calculate forward propagation. 90 | Args: 91 | x (Tensor): Predicted signal (B, T). 92 | y (Tensor): Groundtruth signal (B, T). 93 | Returns: 94 | Tensor: Spectral convergence loss value. 95 | Tensor: Log STFT magnitude loss value. 96 | """ 97 | 98 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 99 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 100 | 101 | if(y_mag.shape != x_mag.shape): 102 | y_mag = y_mag[:,:-1,:] 103 | 104 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 105 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 106 | 107 | return sc_loss, mag_loss 108 | 109 | 110 | class MultiResolutionSTFTLoss(torch.nn.Module): 111 | """Multi resolution STFT loss module.""" 112 | 113 | def __init__(self, fft_sizes, hop_sizes, win_lengths, window="hann_window"): 114 | """Initialize Multi resolution STFT loss module. 115 | Args: 116 | fft_sizes (list): List of FFT sizes. 117 | hop_sizes (list): List of hop sizes. 118 | win_lengths (list): List of window lengths. 119 | window (str): Window function type. 
120 | """ 121 | super(MultiResolutionSTFTLoss, self).__init__() 122 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 123 | self.stft_losses = torch.nn.ModuleList() 124 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 125 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 126 | 127 | def forward(self, x, y): 128 | """Calculate forward propagation. 129 | Args: 130 | x (Tensor): Predicted signal (B, T). 131 | y (Tensor): Groundtruth signal (B, T). 132 | Returns: 133 | Tensor: Multi resolution spectral convergence loss value. 134 | Tensor: Multi resolution log STFT magnitude loss value. 135 | """ 136 | sc_loss = 0.0 137 | mag_loss = 0.0 138 | for f in self.stft_losses: 139 | sc_l, mag_l = f(x, y) 140 | sc_loss += sc_l 141 | mag_loss += mag_l 142 | sc_loss /= len(self.stft_losses) 143 | mag_loss /= len(self.stft_losses) 144 | 145 | return sc_loss + mag_loss --------------------------------------------------------------------------------