├── tests ├── __init__.py ├── test_jsut_corpus.py ├── test_predictor.py ├── test_g2p_util.py ├── test_metrics.py └── test_modules.py ├── marine ├── bin │ ├── __init__.py │ ├── conf │ │ └── train │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── ap.yaml │ │ │ ├── binary.yaml │ │ │ └── simple.yaml │ │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── crf-mtl.yaml │ │ │ ├── att-mtl.yaml │ │ │ └── linear-att-mtl.yaml │ │ │ ├── optim │ │ │ ├── __init__.py │ │ │ └── adam_plate.yaml │ │ │ ├── train │ │ │ ├── __init__.py │ │ │ └── local.yaml │ │ │ ├── criterions │ │ │ ├── __init__.py │ │ │ ├── crf-mtl.yaml │ │ │ ├── att-mtl.yaml │ │ │ └── linear-att-mtl.yaml │ │ │ └── config.yaml │ ├── build_vocab.py │ ├── prepare_features_pyopenjtalk.py │ ├── test.py │ ├── jsut2corpus.py │ └── make_raw_corpus.py ├── data │ ├── __init__.py │ ├── feature │ │ ├── __init__.py │ │ └── feature_set.py │ ├── dictionaries │ │ └── .gitignore │ ├── dataset.py │ ├── util.py │ └── pad.py ├── modules │ ├── __init__.py │ └── attention.py ├── utils │ ├── __init__.py │ ├── g2p_util │ │ ├── __init__.py │ │ ├── boundary.py │ │ ├── accent.py │ │ ├── util.py │ │ └── g2p.py │ ├── regex.py │ ├── pretrained.py │ ├── post_process.py │ ├── metrics.py │ └── openjtalk_util.py ├── __init__.py ├── dict │ ├── accent_status │ │ └── artifact.tsv │ └── accent_phrase_boundary │ │ └── artifact.tsv ├── criterions │ ├── __init__.py │ ├── crossentopyloss.py │ └── log_likelihood.py ├── models │ ├── __init__.py │ ├── bilstm_encoder.py │ ├── base_model.py │ ├── embedding.py │ ├── linear_decoder.py │ ├── crf_decoder.py │ ├── util.py │ └── att_lstm_decoder.py ├── logger.py └── types.py ├── MANIFEST.in ├── recipe ├── common │ ├── database │ │ ├── 20220912_jsut_script_ids │ │ │ ├── val │ │ │ │ └── ids.pkl │ │ │ └── test │ │ │ │ └── ids.pkl │ │ └── 20220912_jsut_vocab_min_2 │ │ │ └── vocab.pkl │ ├── build_vocab.sh │ ├── jsut2corpus.sh │ ├── make_raw_corpus.sh │ ├── pack_corpus.sh │ ├── train.sh │ └── parse_options.sh ├── 
20220912_release │ ├── conf │ │ └── train │ │ │ ├── criterions │ │ │ └── loglikehood.yaml │ │ │ ├── config.yaml │ │ │ ├── optim │ │ │ └── adam.yaml │ │ │ ├── train │ │ │ └── basic.yaml │ │ │ ├── data │ │ │ └── mora_based_seq.yaml │ │ │ └── model │ │ │ └── mtl_lstm_encoder_crf_decoder.yaml │ └── run.sh └── 20250122_marine-plus │ ├── conf │ └── train │ │ ├── criterions │ │ └── loglikehood.yaml │ │ ├── config.yaml │ │ ├── optim │ │ └── adam.yaml │ │ ├── train │ │ └── basic.yaml │ │ ├── data │ │ └── mora_based_seq.yaml │ │ └── model │ │ └── mtl_lstm_encoder_crf_decoder.yaml │ ├── visualize_wrong_mora_diff.py │ └── run.sh ├── .vscode ├── extensions.json └── settings.json ├── .editorconfig ├── .github └── workflows │ ├── deploy.yml │ └── ci.yml ├── pyproject.toml ├── .gitignore ├── setup.py └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/marine/data/feature/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/bin/conf/train/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marine/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6-post3" 2 | -------------------------------------------------------------------------------- /marine/data/dictionaries/.gitignore: -------------------------------------------------------------------------------- 1 | *.dic 2 | *.json 3 | -------------------------------------------------------------------------------- /marine/dict/accent_status/artifact.tsv: -------------------------------------------------------------------------------- 1 | 音声合成 音声_合成/オ,ン,セ,ー_ゴ@,ー,セ,ー 2 | -------------------------------------------------------------------------------- /marine/dict/accent_phrase_boundary/artifact.tsv: 
-------------------------------------------------------------------------------- 1 | 音声合成 音声_合成/オ,ン,セ,ー_ゴ,ー,セ,ー 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE 2 | recursive-include marine *.py *.yaml *.tsv 3 | -------------------------------------------------------------------------------- /marine/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 2 | from .crossentopyloss import CrossEntropyLoss 3 | from .log_likelihood import LogLikelhood 4 | -------------------------------------------------------------------------------- /marine/utils/g2p_util/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 2 | from .g2p import mora2phon, pron2mora, pron2phon 3 | from .util import SUPPORTED_MORA, get_phoneme 4 | -------------------------------------------------------------------------------- /recipe/common/database/20220912_jsut_script_ids/val/ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsukumijima/marine-plus/main/recipe/common/database/20220912_jsut_script_ids/val/ids.pkl -------------------------------------------------------------------------------- /recipe/common/database/20220912_jsut_vocab_min_2/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsukumijima/marine-plus/main/recipe/common/database/20220912_jsut_vocab_min_2/vocab.pkl -------------------------------------------------------------------------------- /recipe/common/database/20220912_jsut_script_ids/test/ids.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsukumijima/marine-plus/main/recipe/common/database/20220912_jsut_script_ids/test/ids.pkl -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "editorconfig.editorconfig", 5 | "ms-python.python", 6 | "ms-python.vscode-pylance", 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /marine/bin/conf/train/criterions/crf-mtl.yaml: -------------------------------------------------------------------------------- 1 | break: 2 | _target_: marine.criterions.LogLikelhood 3 | boundary: 4 | _target_: marine.criterions.LogLikelhood 5 | accent: 6 | _target_: marine.criterions.LogLikelhood -------------------------------------------------------------------------------- /recipe/common/build_vocab.sh: -------------------------------------------------------------------------------- 1 | # NOTE: the script is supposed to be used in recipes like: 2 | # . script.sh 3 | # Please don't try to run the shell directly. 
4 | 5 | marine-build-vocab $feature_file_dir $vocab_dir -m $vocab_min_freq 6 | -------------------------------------------------------------------------------- /marine/bin/conf/train/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - criterions: att-mtl 3 | - optim: adam_plate 4 | - model: att-mtl 5 | - train: local 6 | - data: ap 7 | - override hydra/job_logging: colorlog 8 | - override hydra/hydra_logging: colorlog -------------------------------------------------------------------------------- /marine/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F403 2 | from .att_lstm_decoder import * 3 | from .base_model import * 4 | from .bilstm_encoder import * 5 | from .crf_decoder import * 6 | from .embedding import * 7 | from .linear_decoder import * 8 | from .util import * 9 | -------------------------------------------------------------------------------- /marine/bin/conf/train/criterions/att-mtl.yaml: -------------------------------------------------------------------------------- 1 | intonation_phrase_boundary: 2 | _target_: marine.criterions.LogLikelhood 3 | accent_phrase_boundary: 4 | _target_: marine.criterions.LogLikelhood 5 | accent_status: 6 | _target_: marine.criterions.CrossEntropyLoss 7 | -------------------------------------------------------------------------------- /marine/bin/conf/train/criterions/linear-att-mtl.yaml: -------------------------------------------------------------------------------- 1 | intonation_phrase_boundary: 2 | _target_: marine.criterions.CrossEntropyLoss 3 | accent_phrase_boundary: 4 | _target_: marine.criterions.CrossEntropyLoss 5 | accent_status: 6 | _target_: marine.criterions.CrossEntropyLoss -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/criterions/loglikehood.yaml: 
-------------------------------------------------------------------------------- 1 | intonation_phrase_boundary: 2 | _target_: marine.criterions.LogLikelhood 3 | accent_phrase_boundary: 4 | _target_: marine.criterions.LogLikelhood 5 | accent_status: 6 | _target_: marine.criterions.LogLikelhood 7 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/conf/train/criterions/loglikehood.yaml: -------------------------------------------------------------------------------- 1 | intonation_phrase_boundary: 2 | _target_: marine.criterions.LogLikelhood 3 | accent_phrase_boundary: 4 | _target_: marine.criterions.LogLikelhood 5 | accent_status: 6 | _target_: marine.criterions.LogLikelhood 7 | -------------------------------------------------------------------------------- /marine/bin/conf/train/optim/adam_plate.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | _target_: torch.optim.Adam 3 | lr: 1e-3 4 | eps: 1e-7 5 | betas: [0.9, 0.999] 6 | 7 | scheduler: 8 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 9 | mode: min 10 | threshold: 1e-3 11 | factor: 0.7 12 | -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: basic 3 | - data: mora_based_seq 4 | - model: mtl_lstm_encoder_crf_decoder 5 | - criterions: loglikehood 6 | - optim: adam 7 | - override hydra/job_logging: colorlog 8 | - override hydra/hydra_logging: colorlog 9 | -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | _target_: torch.optim.Adam 3 | lr: 1e-3 4 | eps: 1e-7 5 | betas: [0.9, 0.999] 6 | 7 | scheduler: 8 | _target_: 
torch.optim.lr_scheduler.ReduceLROnPlateau 9 | mode: min 10 | threshold: 1e-3 11 | factor: 0.7 12 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/conf/train/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train: basic 3 | - data: mora_based_seq 4 | - model: mtl_lstm_encoder_crf_decoder 5 | - criterions: loglikehood 6 | - optim: adam 7 | - override hydra/job_logging: colorlog 8 | - override hydra/hydra_logging: colorlog 9 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/conf/train/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | _target_: torch.optim.Adam 3 | lr: 1e-3 4 | eps: 1e-7 5 | betas: [0.9, 0.999] 6 | 7 | scheduler: 8 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 9 | mode: min 10 | threshold: 1e-3 11 | factor: 0.7 12 | -------------------------------------------------------------------------------- /recipe/common/jsut2corpus.sh: -------------------------------------------------------------------------------- 1 | # NOTE: the script is supposed to be used in recipes like: 2 | # . script.sh 3 | # Please don't try to run the shell directly. 4 | 5 | marine-jsut2corpus $jsut_script_path $raw_corpus_dir --accent_status_seq_level $accent_status_seq_level --accent_status_represent_mode $accent_status_represent_mode 6 | -------------------------------------------------------------------------------- /recipe/common/make_raw_corpus.sh: -------------------------------------------------------------------------------- 1 | # NOTE: the script is supposed to be used in recipes like: 2 | # . script.sh 3 | # Please don't try to run the shell directly. 
4 | 5 | marine-make-raw-corpus $jsut_script_path $raw_corpus_dir --accent_status_seq_level $accent_status_seq_level --accent_status_represent_mode $accent_status_represent_mode 6 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | insert_final_newline = true 7 | indent_size = 4 8 | indent_style = space 9 | trim_trailing_whitespace = true 10 | 11 | [*.md] 12 | indent_size = 2 13 | trim_trailing_whitespace = false 14 | 15 | [*.csv] 16 | insert_final_newline = false 17 | 18 | [*.yml] 19 | indent_size = 2 20 | -------------------------------------------------------------------------------- /marine/bin/conf/train/train/local.yaml: -------------------------------------------------------------------------------- 1 | seed: 12345 2 | verbose: 100 3 | 4 | num_epochs: 20 5 | checkpoint_interval: 5 6 | save_optimizer_state: true 7 | save_test_log: true 8 | save_vocab_path: true 9 | upload_tensorborad_event: true 10 | 11 | gpus: 1 12 | 13 | out_dir: checkpoint 14 | model_name: test 15 | test_checkpoint_filename: latest.pth 16 | 17 | tensorboard_event_path: null 18 | test_log_dir: null 19 | -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/train/basic.yaml: -------------------------------------------------------------------------------- 1 | seed: 12345 2 | verbose: 100 3 | 4 | num_epochs: 20 5 | checkpoint_interval: 5 6 | save_optimizer_state: true 7 | save_test_log: true 8 | save_vocab_path: true 9 | upload_tensorborad_event: true 10 | 11 | gpus: 1 12 | 13 | out_dir: checkpoint 14 | model_name: test 15 | test_checkpoint_filename: latest.pth 16 | 17 | tensorboard_event_path: null 18 | test_log_dir: null 19 | -------------------------------------------------------------------------------- 
/recipe/20250122_marine-plus/conf/train/train/basic.yaml: -------------------------------------------------------------------------------- 1 | seed: 12345 2 | verbose: 100 3 | 4 | num_epochs: 20 5 | checkpoint_interval: 5 6 | save_optimizer_state: true 7 | save_test_log: true 8 | save_vocab_path: true 9 | upload_tensorborad_event: true 10 | 11 | gpus: 1 12 | 13 | out_dir: checkpoint 14 | model_name: test 15 | test_checkpoint_filename: latest.pth 16 | 17 | tensorboard_event_path: null 18 | test_log_dir: null 19 | -------------------------------------------------------------------------------- /recipe/common/pack_corpus.sh: -------------------------------------------------------------------------------- 1 | # NOTE: the script is supposed to be used in recipes like: 2 | # . script.sh 3 | # Please don't try to run the shell directly. 4 | 5 | if [ -z ${exist_target_id_dir} ]; then 6 | cmd_args="-t $val_test_size" 7 | else 8 | cmd_args="--target_id_dir $exist_target_id_dir" 9 | fi 10 | 11 | marine-pack-corpus $raw_corpus_dir $feature_file_dir $vocab_path $feature_pack_dir -s $accent_status_seq_level -f $feature_table_key $cmd_args 12 | -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/data/mora_based_seq.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 256 2 | num_workers: 8 3 | 4 | data_dir: dir/to/data 5 | 6 | feature_table_key: open-jtalk 7 | 8 | input_keys: 9 | - mora 10 | - surface 11 | - pos 12 | - c_type 13 | - c_form 14 | - accent_type 15 | - accent_con_type 16 | 17 | input_length_key: mora 18 | 19 | output_keys: # = tasks 20 | - intonation_phrase_boundary 21 | - accent_phrase_boundary 22 | - accent_status 23 | 24 | output_sizes: 25 | intonation_phrase_boundary: 3 26 | accent_phrase_boundary: 3 27 | accent_status: 3 28 | 29 | represent_mode: binary 30 | -------------------------------------------------------------------------------- 
/recipe/20250122_marine-plus/conf/train/data/mora_based_seq.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 256 2 | num_workers: 8 3 | 4 | data_dir: dir/to/data 5 | 6 | feature_table_key: open-jtalk 7 | 8 | input_keys: 9 | - mora 10 | - surface 11 | - pos 12 | - c_type 13 | - c_form 14 | - accent_type 15 | - accent_con_type 16 | 17 | input_length_key: mora 18 | 19 | output_keys: # = tasks 20 | - intonation_phrase_boundary 21 | - accent_phrase_boundary 22 | - accent_status 23 | 24 | output_sizes: 25 | intonation_phrase_boundary: 3 26 | accent_phrase_boundary: 3 27 | accent_status: 3 28 | 29 | represent_mode: binary 30 | -------------------------------------------------------------------------------- /recipe/common/train.sh: -------------------------------------------------------------------------------- 1 | # NOTE: the script is supposed to be used in recipes like: 2 | # . script.sh 3 | # Please don't try to run the shell directly. 4 | 5 | cmd_args="--config-dir $script_dir/conf/train \ 6 | train=$train data=$data model=$model criterions=$criterions optim=$optim \ 7 | train.out_dir=$model_dir train.model_name=$tag train.save_vocab_path=false \ 8 | train.tensorboard_event_path=$tensorboard_dir train.test_log_dir=$in_domain_test_log_dir \ 9 | data.feature_table_key=$feature_table_key data.data_dir=$feature_pack_dir model.vocab_path=$vocab_path" 10 | 11 | marine-train $cmd_args 12 | -------------------------------------------------------------------------------- /marine/bin/conf/train/data/ap.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 512 2 | num_workers: 4 3 | 4 | data_dir: dir/to/data 5 | 6 | feature_table_key: unidic-csj 7 | 8 | input_keys: 9 | - mora 10 | - surface 11 | - pos 12 | - word_type 13 | - c_type 14 | - c_form 15 | - accent_type 16 | - accent_con_type 17 | - accent_mod_type 18 | 19 | input_length_key: mora 20 | 21 | output_keys: # = tasks 22 
| - intonation_phrase_boundary 23 | - accent_phrase_boundary 24 | - accent_status 25 | 26 | output_sizes: 27 | intonation_phrase_boundary: 3 28 | accent_phrase_boundary: 3 29 | accent_status: 21 30 | 31 | represent_mode: binary 32 | -------------------------------------------------------------------------------- /marine/bin/conf/train/data/binary.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 512 2 | num_workers: 4 3 | 4 | data_dir: dir/to/data 5 | 6 | feature_table_key: unidic-csj 7 | 8 | input_keys: 9 | - mora 10 | - surface 11 | - pos 12 | - word_type 13 | - c_type 14 | - c_form 15 | - accent_type 16 | - accent_con_type 17 | - accent_mod_type 18 | 19 | input_length_key: mora 20 | 21 | output_keys: # = tasks 22 | - intonation_phrase_boundary 23 | - accent_phrase_boundary 24 | - accent_status 25 | 26 | output_sizes: 27 | intonation_phrase_boundary: 3 28 | accent_phrase_boundary: 3 29 | accent_status: 3 30 | 31 | represent_mode: binary 32 | -------------------------------------------------------------------------------- /marine/bin/conf/train/data/simple.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 512 2 | num_workers: 4 3 | 4 | data_dir: dir/to/data 5 | 6 | feature_table_key: unidic-csj 7 | 8 | input_keys: 9 | - mora 10 | - surface 11 | - pos 12 | - word_type 13 | - c_type 14 | - c_form 15 | - accent_type 16 | - accent_con_type 17 | - accent_mod_type 18 | 19 | input_length_key: mora 20 | 21 | output_keys: # = tasks 22 | - intonation_phrase_boundary 23 | - accent_phrase_boundary 24 | - accent_status 25 | 26 | output_sizes: 27 | intonation_phrase_boundary: 3 28 | accent_phrase_boundary: 3 29 | accent_status: 3 30 | 31 | represent_mode: high_low 32 | -------------------------------------------------------------------------------- /marine/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing 
import Any 2 | 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class AccentDataset(Dataset[dict[str, Any]]): 7 | def __init__(self, data: dict[str, Any]): 8 | self.data = data 9 | 10 | def __len__(self): 11 | return len(self.data["labels"]) 12 | 13 | def __getitem__(self, index: int) -> dict[str, Any]: 14 | item = { 15 | "features": self.data["features"][index], 16 | "labels": self.data["labels"][index], 17 | } 18 | 19 | if "ids" in self.data.keys(): 20 | item["ids"] = self.data["ids"][index] 21 | 22 | return item 23 | -------------------------------------------------------------------------------- /marine/criterions/crossentopyloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | 5 | class CrossEntropyLoss(_Loss): 6 | def __init__(self) -> None: 7 | super().__init__() 8 | self.loss_func = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=0) 9 | 10 | def forward( 11 | self, 12 | logits: torch.Tensor, 13 | labels: torch.Tensor, 14 | mask: torch.Tensor | None = None, 15 | ) -> torch.Tensor: 16 | loss = torch.zeros(1, device=logits.device) 17 | batch_size = logits.size(0) 18 | 19 | for logit, label, m in zip(logits, labels, mask): 20 | loss += self.loss_func(logit[m], label[m]) 21 | 22 | return torch.sum(loss / batch_size) 23 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/visualize_wrong_mora_diff.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import yaml 5 | 6 | from marine.utils.openjtalk_util import print_diff_hl 7 | 8 | 9 | # Load text.yaml 10 | with open(Path(__file__).parent / "data" / "text.yaml", encoding="utf-8") as f: 11 | text_data = yaml.safe_load(f) 12 | 13 | df = pd.read_csv( 14 | Path(__file__).parent / "wrong_mora_info.csv", 15 | sep="|", 16 | names=["wav", "jtalk", 
"anotation"], 17 | ) 18 | 19 | for i in df.iterrows(): 20 | wav_id = i[1]["wav"] 21 | print(f"\n========== {wav_id} ==========") 22 | if wav_id in text_data: 23 | print(f" Text : {text_data[wav_id]['text_level0']}") 24 | print_diff_hl(i[1]["jtalk"], i[1]["anotation"]) 25 | -------------------------------------------------------------------------------- /marine/criterions/log_likelihood.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import torch 4 | from torch.nn.modules.loss import _Loss 5 | 6 | 7 | class LogLikelhood(_Loss): 8 | def __init__( 9 | self, 10 | log_likehood_func: Callable[ 11 | [torch.Tensor, torch.Tensor, torch.Tensor | None], torch.Tensor 12 | ], 13 | ) -> None: 14 | super().__init__() 15 | self.log_likehood_func = log_likehood_func 16 | 17 | def forward( 18 | self, 19 | classified: torch.Tensor, 20 | label: torch.Tensor, 21 | mask: torch.Tensor | None = None, 22 | ) -> torch.Tensor: 23 | batch_size = label.size(0) 24 | log_likelihood = self.log_likehood_func(classified, label, mask) 25 | 26 | return -log_likelihood / batch_size 27 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | inputs: 9 | publish_pypi: 10 | type: boolean 11 | required: true 12 | description: 'Publish to PyPI (Production)' 13 | 14 | jobs: 15 | build-and-publish: 16 | runs-on: ubuntu-latest 17 | name: Build and publish distribution 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v4 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.11' 27 | 28 | - name: Build wheel and sdist 29 | run: pipx run build 30 | 31 | - name: Publish to PyPI 32 | if: ${{ github.event.inputs.publish_pypi == 'true' || 
github.event_name == 'push' }} 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | packages-dir: dist 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /marine/models/bilstm_encoder.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 3 | 4 | 5 | class BiLSTMEncoder(nn.Module): 6 | def __init__( 7 | self, 8 | input_size: int, 9 | hidden_size: int, 10 | num_layers: int, 11 | shared_with: str | None = None, 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.lstm = nn.LSTM( 16 | input_size, 17 | hidden_size, 18 | num_layers=num_layers, 19 | batch_first=True, 20 | bidirectional=True, 21 | ) 22 | 23 | def forward(self, embeddings: Tensor, lengths: Tensor) -> Tensor: 24 | # LSTM -> B * T * Hidden-size 25 | packed = pack_padded_sequence( 26 | embeddings, lengths, batch_first=True, enforce_sorted=True 27 | ) 28 | logits, _ = self.lstm(packed) 29 | logits, _ = pad_packed_sequence(logits, batch_first=True) 30 | 31 | return logits 32 | -------------------------------------------------------------------------------- /marine/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from os.path import dirname 4 | 5 | 6 | format = "[%(asctime)s][%(name)s:%(module)s][%(levelname)s] - %(message)s" 7 | 8 | 9 | def getLogger( 10 | verbose: int = 0, filename: str | None = None, name: str = "marine" 11 | ) -> logging.Logger: 12 | handlers = [] 13 | 14 | stream_handler = logging.StreamHandler() 15 | stream_handler.setFormatter(logging.Formatter(format)) 16 | handlers.append(stream_handler) 17 | 18 | if filename is not None: 19 | os.makedirs(dirname(filename), exist_ok=True) 20 | file_handler = logging.FileHandler(filename=filename) 21 | file_handler.setLevel(logging.INFO) 22 | 
file_handler.setFormatter(logging.Formatter(format)) 23 | handlers.append(file_handler) 24 | 25 | logging.basicConfig(handlers=handlers) 26 | 27 | logger = logging.getLogger(name) 28 | if verbose >= 100: 29 | logger.setLevel(logging.DEBUG) 30 | elif verbose > 0: 31 | logger.setLevel(logging.INFO) 32 | else: 33 | logger.setLevel(logging.WARN) 34 | 35 | return logger 36 | -------------------------------------------------------------------------------- /marine/utils/g2p_util/boundary.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, cast 2 | 3 | from .util import CONNECTABLE_MORA, HALF_PUNCTUATION 4 | 5 | 6 | def represent_syllable_boundary( 7 | index: int, moras: list[str], len_phonemes: int 8 | ) -> list[Literal[0, 1]]: 9 | """ 10 | Represent syllable boundary by 2 types 11 | - types: 12 | - 0: non-boundary 13 | - 1: syllable boundary 14 | """ 15 | 16 | if moras[index] in HALF_PUNCTUATION: 17 | return [1] 18 | 19 | # Init the boundary into normal boundary 20 | # e.g., か -> [k, a] -> [0, 1] 21 | # わ -> [wa] -> [1] 22 | boundary = cast(list[Literal[0, 1]], [0] * (len_phonemes - 1) + [1]) 23 | 24 | # remove syllable boundaray if next mora is connectable mora (i.e. ッ, ン) 25 | # e.g., [ ... "デ", "ン" ] -> [ ... [d0, e1], [N1] ] -> [ ... 
[d0, e0], [N1] ] 26 | if index + 1 < len(moras): 27 | if ( 28 | moras[index] not in CONNECTABLE_MORA 29 | and moras[index + 1] in CONNECTABLE_MORA 30 | ): 31 | boundary[-1] = 0 32 | else: 33 | boundary[-1] = 1 34 | 35 | return boundary 36 | -------------------------------------------------------------------------------- /marine/models/base_model.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | from torch import Tensor, nn 4 | 5 | from marine.models.embedding import SimpleEmbedding 6 | 7 | 8 | class BaseModel(nn.Module): 9 | def __init__( 10 | self, 11 | embedding: SimpleEmbedding, 12 | encoders: Mapping[str, nn.Module], 13 | decoders: Mapping[str, nn.Module], 14 | ) -> None: 15 | super().__init__() 16 | self.embedding = embedding 17 | self.encoders = nn.ModuleDict(encoders) 18 | self.decoders = nn.ModuleDict(decoders) 19 | 20 | def forward( 21 | self, 22 | task: str, 23 | embedding_features: dict[str, Tensor], 24 | lengths: Tensor, 25 | mask: Tensor, 26 | prev_decoder_outputs: dict[str, Tensor] | None = None, 27 | decoder_targets: dict[str, Tensor] | None = None, 28 | ) -> Tensor | tuple[Tensor, ...]: 29 | embeddings = self.embedding(**embedding_features) 30 | encoder_outputs = self.encoders[task](embeddings, lengths) 31 | decoder_outputs = self.decoders[task]( 32 | encoder_outputs, mask, prev_decoder_outputs, decoder_targets 33 | ) 34 | 35 | return decoder_outputs 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["wheel", "setuptools"] 3 | 4 | [tool.taskipy.tasks] 5 | install = "if [ ! -d '.venv' ]; then python -m venv .venv; fi && .venv/bin/pip install -U -e '.[dev,pyopenjtalk]'" 6 | lint = ".venv/bin/ruff check --fix ." 7 | format = ".venv/bin/ruff format ." 
8 | test = ".venv/bin/pytest" 9 | 10 | [tool.ruff] 11 | # 1行の長さを最大88文字に設定 12 | line-length = 88 13 | # インデントの幅を4スペースに設定 14 | indent-width = 4 15 | # Python 3.10 を利用する 16 | target-version = "py310" 17 | 18 | [tool.ruff.lint] 19 | # flake8, pycodestyle, pyupgrade, isort, Ruff 固有のルールを使う 20 | select = ["F", "E", "W", "UP", "I", "RUF"] 21 | ignore = [ 22 | "E501", # 1行の長さを超えている場合の警告を抑制 23 | "E731", # Do not assign a `lambda` expression, use a `def` を抑制 24 | "UP038", # 非推奨化されているルールのため 25 | "RUF001", # 全角記号など `ambiguous unicode character` も使いたいため 26 | "RUF002", # 全角記号など `ambiguous unicode character` も使いたいため 27 | "RUF003", # 全角記号など `ambiguous unicode character` も使いたいため 28 | "RUF005", # 万が一のリグレッション回避のため抑制 29 | ] 30 | 31 | [tool.ruff.lint.isort] 32 | # インポートブロックの後に2行空ける 33 | lines-after-imports = 2 34 | 35 | [tool.ruff.format] 36 | # ダブルクオートを使う 37 | quote-style = "double" 38 | # インデントにはスペースを使う 39 | indent-style = "space" 40 | -------------------------------------------------------------------------------- /marine/models/embedding.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | from torch import Tensor, cat, nn 4 | 5 | from marine.data.feature.feature_set import FeatureSet 6 | 7 | 8 | class SimpleEmbedding(nn.Module): 9 | def __init__( 10 | self, 11 | embeding_sizes: Mapping[str, int], 12 | dropout: float | None, 13 | feature_set: FeatureSet, 14 | ) -> None: 15 | super().__init__() 16 | 17 | # embeddings 18 | self.embeddings = nn.ModuleDict( 19 | { 20 | key: nn.Embedding( 21 | len(feature_set.feature_to_id[key]), 22 | embeding_sizes[key], 23 | padding_idx=feature_set.feature_to_id[key][feature_set.pad_token], 24 | ) 25 | for key in feature_set.feature_keys 26 | } 27 | ) 28 | 29 | if dropout: 30 | self.dropout = nn.Dropout(dropout) 31 | else: 32 | self.dropout = None 33 | 34 | def forward(self, **kwargs: dict[str, Tensor]) -> Tensor: 35 | # Embedding -> B * T * Embedding-size 36 | embs = 
[self.embeddings[key](kwargs[key]) for key in self.embeddings.keys()] 37 | embs = cat(embs, dim=2) 38 | 39 | if self.dropout: 40 | embs = self.dropout(embs) 41 | 42 | return embs 43 | -------------------------------------------------------------------------------- /marine/utils/regex.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | _re_has_longvowel = None 5 | _re_hiragana = None 6 | _re_katakana = None 7 | _re_kanji = None 8 | _re_letter = None 9 | _re_symbol = None 10 | 11 | 12 | def is_hiragana(surface: str) -> bool: 13 | global _re_hiragana 14 | if _re_hiragana is None: 15 | _re_hiragana = re.compile(r"^[ぁ-ん]+$") 16 | return _re_hiragana.match(surface) is not None 17 | 18 | 19 | def is_katakana(surface: str) -> bool: 20 | global _re_katakana 21 | if _re_katakana is None: 22 | _re_katakana = re.compile(r"^[ァ-ヴヶ]+$") 23 | return _re_katakana.match(surface) is not None 24 | 25 | 26 | def is_kanji(surface: str) -> bool: 27 | global _re_kanji 28 | if _re_kanji is None: 29 | _re_kanji = re.compile(r"^[一-龠]+$") 30 | return _re_kanji.match(surface) is not None 31 | 32 | 33 | def is_letter(surface: str) -> bool: 34 | global _re_letter 35 | if _re_letter is None: 36 | _re_letter = re.compile(r"^[a-zA-Z]+$") 37 | return _re_letter.match(surface) is not None 38 | 39 | 40 | def is_symbol(surface: str) -> bool: 41 | global _re_symbol 42 | if _re_symbol is None: 43 | _re_symbol = re.compile(r"^[〆々ー,.?!]+$") 44 | return _re_symbol.match(surface) is not None 45 | 46 | 47 | def has_longvowel(text: str) -> bool: 48 | global _re_has_longvowel 49 | if _re_has_longvowel is None: 50 | _re_has_longvowel = re.compile(r"(aa|ii|uu|ee|oo)$") 51 | return _re_has_longvowel.search(text) is not None 52 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | // Auto-format Python files with Ruff on save 3 |
"[python]": { 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll.ruff": "explicit", 6 | "source.organizeImports.ruff": "explicit", 7 | }, 8 | "editor.defaultFormatter": "charliermarsh.ruff", 9 | "editor.formatOnSave": true, 10 | }, 11 | // Enable Pylance type checking 12 | "python.languageServer": "Pylance", 13 | "python.analysis.typeCheckingMode": "strict", 14 | // Suppress (or downgrade) some of Pylance's type-checking diagnostics 15 | "python.analysis.diagnosticSeverityOverrides": { 16 | "reportAssignmentType": "warning", 17 | "reportConstantRedefinition": "none", 18 | "reportGeneralTypeIssues": "warning", 19 | "reportMissingParameterType": "warning", 20 | "reportMissingTypeStubs": "none", 21 | "reportPossiblyUnboundVariable": "warning", 22 | "reportPrivateImportUsage": "none", 23 | "reportPrivateUsage": "warning", 24 | "reportShadowedImports": "none", 25 | "reportUnnecessaryComparison": "none", 26 | "reportUnnecessaryIsInstance": "none", 27 | "reportUnknownArgumentType": "none", 28 | "reportUnknownLambdaType": "none", 29 | "reportUnknownMemberType": "none", 30 | "reportUnknownParameterType": "warning", 31 | "reportUnknownVariableType": "none", 32 | "reportUntypedFunctionDecorator": "none", 33 | "reportUnusedFunction": "none", 34 | "reportUnusedVariable": "information", 35 | }, 36 | } 37 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | include: 20 | - os: ubuntu-latest 21 |
python-version: '3.10' 22 | - os: ubuntu-latest 23 | python-version: '3.11' 24 | - os: ubuntu-latest 25 | python-version: '3.12' 26 | - os: macos-latest 27 | python-version: '3.10' 28 | - os: macos-latest 29 | python-version: '3.11' 30 | - os: macos-latest 31 | python-version: '3.12' 32 | - os: windows-latest 33 | python-version: '3.10' 34 | - os: windows-latest 35 | python-version: '3.11' 36 | - os: windows-latest 37 | python-version: '3.12' 38 | 39 | steps: 40 | - uses: actions/checkout@v4 41 | - name: Set up Python ${{ matrix.python-version }} 42 | uses: actions/setup-python@v5 43 | with: 44 | python-version: ${{ matrix.python-version }} 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install -e ".[dev]" 49 | - name: Lint with Ruff 50 | run: | 51 | ruff check . 52 | - name: Test with pytest 53 | run: | 54 | pytest 55 | -------------------------------------------------------------------------------- /marine/data/util.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | from hydra.utils import to_absolute_path 6 | from joblib import load 7 | from omegaconf import DictConfig 8 | from torch.utils.data import DataLoader 9 | 10 | from .dataset import AccentDataset 11 | from .pad import Padsequence 12 | 13 | 14 | logger = getLogger(__name__) 15 | 16 | 17 | def load_dataset( 18 | config: DictConfig, phases: list[str] | None = None 19 | ) -> dict[str, DataLoader[dict[str, Any]]]: 20 | dataloader = {} 21 | 22 | data_dir = Path(to_absolute_path(config.data.data_dir)) 23 | 24 | if phases is None: 25 | phases = ["train", "val", "test"] 26 | elif not isinstance(phases, list): 27 | raise TypeError(f"Unvailable values: {phases}") 28 | 29 | for phase in phases: 30 | is_train = phase == "train" 31 | targets = ["features", "labels", "ids"] 32 | data = {} 33 | 34 | for target in targets: 35 | data_path =
data_dir / phase / f"{target}.pkl" 36 | data[target] = load(data_path) 37 | 38 | dataset = AccentDataset(data) 39 | 40 | if logger is not None:  # always true: logger is the module-level logging.getLogger(__name__) 41 | logger.info(f"{phase} data size : {len(dataset):,}") 42 | 43 | if is_train:  # only the training split is shuffled 44 | shuffle = True 45 | else: 46 | shuffle = False 47 | 48 | dataloader[phase] = DataLoader( 49 | dataset, 50 | batch_size=config.data.batch_size, 51 | shuffle=shuffle, 52 | collate_fn=Padsequence( 53 | input_keys=config.data.input_keys, 54 | input_length_key=config.data.input_length_key, 55 | output_keys=config.data.output_keys, 56 | num_classes=config.data.output_sizes, 57 | ), 58 | num_workers=config.data.num_workers, 59 | ) 60 | 61 | return dataloader 62 | -------------------------------------------------------------------------------- /marine/bin/conf/train/model/crf-mtl.yaml: -------------------------------------------------------------------------------- 1 | vocab_path: null 2 | 3 | # general setting for each module 4 | embedding: 5 | _target_: marine.models.SimpleEmbedding 6 | embeding_sizes: 7 | surface: 512 8 | mora: 256 9 | pos: 128 10 | word_type: 64 11 | c_type: 256 12 | c_form: 128 13 | accent_type: 64 14 | accent_con_type: 64 15 | accent_mod_type: 64 16 | dropout: 0.5 17 | 18 | encoder: 19 | param: 20 | _target_: marine.models.BiLSTMEncoder 21 | num_layers: 3 22 | hidden_size: 512 23 | # input_size: depend on setting for embedding output size 24 | shared_with: 25 | intonation_phrase_boundary: null 26 | accent_phrase_boundary: intonation_phrase_boundary 27 | accent_status: accent_phrase_boundary 28 | 29 | decoder: 30 | intonation_phrase_boundary: 31 | _target_: marine.models.CRFDecoder 32 | prev_task_embedding_label_list: null 33 | prev_task_embedding_size: null 34 | prev_task_dropout: 0.5 35 | # input_size: depend on setting for encoder output size / is hierarchical decoder 36 | # output_size: depend on setting for each label 37 | accent_phrase_boundary: 38 | _target_: marine.models.CRFDecoder 39 |
prev_task_embedding_label_list: 40 | - intonation_phrase_boundary 41 | prev_task_embedding_size: 42 | intonation_phrase_boundary: 64 43 | prev_task_dropout: 0.5 44 | # input_size: depend on setting for encoder output size / is hierarchical decoder 45 | # output_size: depend on setting for each label 46 | accent_status: 47 | _target_: marine.models.CRFDecoder 48 | prev_task_embedding_label_list: 49 | - accent_phrase_boundary 50 | prev_task_embedding_size: 51 | accent_phrase_boundary: 64 52 | prev_task_dropout: 0.5 53 | # input_size: depend on setting for encoder output size / is hierarchical decoder 54 | # output_size: depend on setting for each label 55 | 56 | base: 57 | _target_: marine.models.BaseModel 58 | -------------------------------------------------------------------------------- /tests/test_jsut_corpus.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | from marine.bin.jsut2corpus import parse_jsut_annotation 4 | 5 | 6 | logger = getLogger("test") 7 | 8 | 9 | def test_jsut_parser(): 10 | for ( 11 | jsut_annotaion, 12 | accent_status_seq_level, 13 | accent_status_represent_mode, 14 | expect, 15 | ) in [ 16 | ( 17 | "^ム]シロ#ロ[ンゲノホ]ーガ#ハ[ゲヤス]イッテ#キ[ータゾ?$", 18 | "ap", 19 | "binary", 20 | { 21 | "pron": "ムシロロンゲノホーガハゲヤスイッテキータゾ", 22 | "accent_status": "1,5,4,0", 23 | "accent_phrase_boundary": "0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0", 24 | "intonation_phrase_boundary": "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 25 | }, 26 | ), 27 | ( 28 | "^ム]シロ#ロ[ンゲノホ]ーガ#ハ[ゲヤス]イッテ#キ[ータゾ?$", 29 | "mora", 30 | "binary", 31 | { 32 | "pron": "ムシロロンゲノホーガハゲヤスイッテキータゾ", 33 | "accent_status": "1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0", 34 | "accent_phrase_boundary": "0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0", 35 | "intonation_phrase_boundary": "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 36 | }, 37 | ), 38 | ( 39 | "^ム]シロ#ロ[ンゲノホ]ーガ#ハ[ゲヤス]イッテ#キ[ータゾ?$", 40 | "mora", 41 | "high_low", 42 | { 43 | "pron":
"ムシロロンゲノホーガハゲヤスイッテキータゾ", 44 | "accent_status": "1,0,0,0,1,1,1,1,0,0,0,1,1,1,0,0,0,0,1,1,1", 45 | "accent_phrase_boundary": "0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0", 46 | "intonation_phrase_boundary": "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0", 47 | }, 48 | ), 49 | ]: 50 | result = parse_jsut_annotation( 51 | jsut_annotaion, accent_status_seq_level, accent_status_represent_mode 52 | ) 53 | assert result == expect 54 | -------------------------------------------------------------------------------- /recipe/20220912_release/conf/train/model/mtl_lstm_encoder_crf_decoder.yaml: -------------------------------------------------------------------------------- 1 | vocab_path: null 2 | 3 | # general setting for each module 4 | embedding: 5 | _target_: marine.models.SimpleEmbedding 6 | embeding_sizes: 7 | surface: 512 8 | mora: 256 9 | pos: 128 10 | word_type: 64 11 | c_type: 256 12 | c_form: 128 13 | accent_type: 64 14 | accent_con_type: 64 15 | accent_mod_type: 64 16 | dropout: 0.5 17 | 18 | encoder: 19 | param: 20 | _target_: marine.models.BiLSTMEncoder 21 | num_layers: 3 22 | hidden_size: 512 23 | # input_size: depend on setting for embedding output size 24 | shared_with: 25 | intonation_phrase_boundary: null 26 | accent_phrase_boundary: intonation_phrase_boundary 27 | accent_status: accent_phrase_boundary 28 | 29 | decoder: 30 | intonation_phrase_boundary: 31 | _target_: marine.models.CRFDecoder 32 | prev_task_embedding_label_list: null 33 | prev_task_embedding_size: null 34 | prev_task_dropout: 0.5 35 | # input_size: depend on setting for encoder output size / is hierarchical decoder 36 | # output_size: depend on setting for each label 37 | accent_phrase_boundary: 38 | _target_: marine.models.CRFDecoder 39 | prev_task_embedding_label_list: 40 | - intonation_phrase_boundary 41 | prev_task_embedding_size: 42 | intonation_phrase_boundary: 64 43 | prev_task_dropout: 0.5 44 | # input_size: depend on setting for encoder output size / is hierarchical decoder 45 | #
output_size: depend on setting for each label 46 | accent_status: 47 | _target_: marine.models.CRFDecoder 48 | prev_task_embedding_label_list: 49 | - accent_phrase_boundary 50 | prev_task_embedding_size: 51 | accent_phrase_boundary: 64 52 | prev_task_dropout: 0.5 53 | # input_size: depend on setting for encoder output size / is hierarchical decoder 54 | # output_size: depend on setting for each label 55 | 56 | base: 57 | _target_: marine.models.BaseModel 58 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/conf/train/model/mtl_lstm_encoder_crf_decoder.yaml: -------------------------------------------------------------------------------- 1 | vocab_path: null 2 | 3 | # general setting for each module 4 | embedding: 5 | _target_: marine.models.SimpleEmbedding 6 | embeding_sizes: 7 | surface: 512 8 | mora: 256 9 | pos: 128 10 | word_type: 64 11 | c_type: 256 12 | c_form: 128 13 | accent_type: 64 14 | accent_con_type: 64 15 | accent_mod_type: 64 16 | dropout: 0.5 17 | 18 | encoder: 19 | param: 20 | _target_: marine.models.BiLSTMEncoder 21 | num_layers: 3 22 | hidden_size: 512 23 | # input_size: depend on setting for embedding output size 24 | shared_with: 25 | intonation_phrase_boundary: null 26 | accent_phrase_boundary: intonation_phrase_boundary 27 | accent_status: accent_phrase_boundary 28 | 29 | decoder: 30 | intonation_phrase_boundary: 31 | _target_: marine.models.CRFDecoder 32 | prev_task_embedding_label_list: null 33 | prev_task_embedding_size: null 34 | prev_task_dropout: 0.5 35 | # input_size: depend on setting for encoder output size / is hierarchical decoder 36 | # output_size: depend on setting for each label 37 | accent_phrase_boundary: 38 | _target_: marine.models.CRFDecoder 39 | prev_task_embedding_label_list: 40 | - intonation_phrase_boundary 41 | prev_task_embedding_size: 42 | intonation_phrase_boundary: 64 43 | prev_task_dropout: 0.5 44 | # input_size: depend on setting for encoder output
size / is hierarchical decoder 45 | # output_size: depend on setting for each label 46 | accent_status: 47 | _target_: marine.models.CRFDecoder 48 | prev_task_embedding_label_list: 49 | - accent_phrase_boundary 50 | prev_task_embedding_size: 51 | accent_phrase_boundary: 64 52 | prev_task_dropout: 0.5 53 | # input_size: depend on setting for encoder output size / is hierarchical decoder 54 | # output_size: depend on setting for each label 55 | 56 | base: 57 | _target_: marine.models.BaseModel 58 | -------------------------------------------------------------------------------- /marine/utils/g2p_util/accent.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | 4 | def set_accent_status( 5 | accent: int, 6 | ) -> tuple[Literal[0, 1], int]: 7 | high = -1 8 | end_low = -1 9 | 10 | # 1-type : H-L-L ... 11 | if accent == 1: 12 | high = 0 13 | end_low = 1 14 | 15 | # 0-type : L-H-H ... 16 | elif accent <= 0: 17 | high = 1 18 | 19 | # N-type : L-H ... H_n-L_n+1-L_n+2 ..
20 | else: 21 | high = 1 22 | end_low = accent 23 | 24 | return high, end_low 25 | 26 | 27 | def represent_accent_high_low(index: int, high: int, end_low: int) -> Literal[0, 1]: 28 | """ 29 | Represent the accent by a current status of the mora 30 | - types: 31 | - 0: Low 32 | - 1: High 33 | """ 34 | 35 | # Init accent into flat-accent(=0) 36 | accent = 0 37 | 38 | # set current accent to high 39 | if index >= high and (end_low < 0 or index < end_low): 40 | accent = 1 41 | 42 | return accent 43 | 44 | 45 | def represent_longvowel_accent_high_low( 46 | index: int, high: int, end_low: int 47 | ) -> Literal[0, 1]: 48 | # Init accent into flat accent 49 | accent = 0 50 | 51 | # Same as represent_accent_high_low, except index == end_low is also high (`<=` instead of `<`) 52 | if index >= high and (end_low < 0 or index <= end_low): 53 | accent = 1 54 | 55 | return accent 56 | 57 | 58 | def represent_accent_binary(index: int, high: int, end_low: int) -> Literal[0, 1]: 59 | """ 60 | Represent the accent by a current status of the mora 61 | - types: 62 | - 0: Not accent nucleus 63 | - 1: Accent nucleus 64 | """ 65 | 66 | # 1 only at the accent nucleus (0-based index end_low - 1); an end_low of -1 (flat type, as set by set_accent_status) never matches a non-negative index 67 | accent = 1 if index == end_low - 1 else 0 68 | 69 | return accent 70 | 71 | 72 | def represent_longvowel_accent_binary( 73 | index: int, high: int, end_low: int 74 | ) -> Literal[0, 1]: 75 | # Long-vowel variant: matches index == end_low, one position later than represent_accent_binary; presumably reflects how the caller indexes long vowels (TODO confirm) 76 | return 1 if index == end_low else 0 77 | -------------------------------------------------------------------------------- /marine/bin/conf/train/model/att-mtl.yaml: -------------------------------------------------------------------------------- 1 | vocab_path: null 2 | 3 | # general setting for each module 4 | embedding: 5 | _target_: marine.models.SimpleEmbedding 6 | embeding_sizes: 7 | surface: 512 8 | mora: 256 9 | pos: 128 10 | word_type: 64 11 | c_type: 256 12 | c_form: 128 13 | accent_type: 64 14 | accent_con_type: 64 15 | accent_mod_type: 64 16 | dropout: 0.5
17 | 18 | encoder: 19 | param: 20 | _target_: marine.models.BiLSTMEncoder 21 | num_layers: 3 22 | hidden_size: 512 23 | # input_size: depend on setting for embedding output size 24 | shared_with: 25 | intonation_phrase_boundary: null 26 | accent_phrase_boundary: intonation_phrase_boundary 27 | accent_status: accent_phrase_boundary 28 | 29 | decoder: 30 | intonation_phrase_boundary: 31 | _target_: marine.models.CRFDecoder 32 | prev_task_embedding_label_list: null 33 | prev_task_embedding_size: null 34 | prev_task_dropout: 0.5 35 | # input_size: depend on setting for encoder output size / is hierarchical decoder 36 | # output_size: depend on setting for each label 37 | accent_phrase_boundary: 38 | _target_: marine.models.CRFDecoder 39 | prev_task_embedding_label_list: 40 | - intonation_phrase_boundary 41 | prev_task_embedding_size: 42 | intonation_phrase_boundary: 64 43 | prev_task_dropout: 0.5 44 | # input_size: depend on setting for encoder output size / is hierarchical decoder 45 | # output_size: depend on setting for each label 46 | accent_status: 47 | _target_: marine.models.AttentionBasedLSTMDecoder 48 | prev_task_embedding_label_list: 49 | - accent_phrase_boundary 50 | prev_task_embedding_size: 51 | accent_phrase_boundary: 64 52 | decoder_embedding_size: 128 53 | hidden_size: 512 54 | num_layers: 2 55 | attention_hidden_size: 256 56 | zoneout: 0.1 57 | prev_task_dropout: 0.5 58 | decoder_prev_out_dropout: 0.5 59 | # input_size: depend on setting for encoder output size / is hierarchical decoder 60 | # output_size: depend on setting for each label 61 | 62 | base: 63 | _target_: marine.models.BaseModel 64 | -------------------------------------------------------------------------------- /marine/bin/conf/train/model/linear-att-mtl.yaml: -------------------------------------------------------------------------------- 1 | vocab_path: null 2 | 3 | # general setting for each module 4 | embedding: 5 | _target_: marine.models.SimpleEmbedding 6 | embeding_sizes: 7 |
surface: 512 8 | mora: 256 9 | pos: 128 10 | word_type: 64 11 | c_type: 256 12 | c_form: 128 13 | accent_type: 64 14 | accent_con_type: 64 15 | accent_mod_type: 64 16 | dropout: 0.5 17 | 18 | encoder: 19 | param: 20 | _target_: marine.models.BiLSTMEncoder 21 | num_layers: 3 22 | hidden_size: 512 23 | # input_size: depend on setting for embedding output size 24 | shared_with: 25 | intonation_phrase_boundary: null 26 | accent_phrase_boundary: intonation_phrase_boundary 27 | accent_status: accent_phrase_boundary 28 | 29 | decoder: 30 | intonation_phrase_boundary: 31 | _target_: marine.models.LinearDecoder 32 | prev_task_embedding_label_list: null 33 | prev_task_embedding_size: null 34 | prev_task_dropout: 0.5 35 | # input_size: depend on setting for encoder output size / is hierarchical decoder 36 | # output_size: depend on setting for each label 37 | accent_phrase_boundary: 38 | _target_: marine.models.LinearDecoder 39 | prev_task_embedding_label_list: 40 | - intonation_phrase_boundary 41 | prev_task_embedding_size: 42 | intonation_phrase_boundary: 64 43 | prev_task_dropout: 0.5 44 | # input_size: depend on setting for encoder output size / is hierarchical decoder 45 | # output_size: depend on setting for each label 46 | accent_status: 47 | _target_: marine.models.AttentionBasedLSTMDecoder 48 | prev_task_embedding_label_list: 49 | - accent_phrase_boundary 50 | prev_task_embedding_size: 51 | accent_phrase_boundary: 64 52 | decoder_embedding_size: 128 53 | hidden_size: 512 54 | num_layers: 2 55 | attention_hidden_size: 256 56 | zoneout: 0.1 57 | prev_task_dropout: 0.5 58 | decoder_prev_out_dropout: 0.5 59 | # input_size: depend on setting for encoder output size / is hierarchical decoder 60 | # output_size: depend on setting for each label 61 | 62 | base: 63 | _target_: marine.models.BaseModel 64 | -------------------------------------------------------------------------------- /marine/bin/build_vocab.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | from joblib import dump 7 | from tqdm import tqdm 8 | 9 | from marine.logger import getLogger 10 | from marine.utils.util import load_json_corpus 11 | 12 | 13 | logger = None 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser( 18 | description="Generate vocabulary file for word-embedding", 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 20 | ) 21 | parser.add_argument("in_path", type=Path, help="Path or directory for feature file") 22 | parser.add_argument("out_dir", type=Path, help="Output directory") 23 | parser.add_argument( 24 | "--min_freq", 25 | "-m", 26 | type=int, 27 | default=2, 28 | help=""" 29 | Minimum word frequency in the whole corpus, 30 | which to judge whether the word include in the vocabulary 31 | """, 32 | ) 33 | parser.add_argument( 34 | "--verbose", 35 | "-v", 36 | action="store_true", 37 | help="Whether print log for debug", 38 | ) 39 | return parser 40 | 41 | 42 | def count_words(words): 43 | freqs = defaultdict(int) 44 | 45 | for word in tqdm(words, "Counting word", leave=False): 46 | freqs[word] += 1 47 | 48 | freqs = sorted(freqs.items(), key=lambda x: x[1], reverse=True) 49 | 50 | return freqs 51 | 52 | 53 | def filter_words(freqs, min_freq): 54 | return list(filter(lambda x: x[1] >= min_freq, freqs)) 55 | 56 | 57 | def save_vocab(freqs, output_dir): 58 | words = [surface for surface, _ in freqs] 59 | dump(words, output_dir / "vocab.pkl", compress=True) 60 | 61 | 62 | def entry(argv=sys.argv): 63 | global logger 64 | 65 | args = get_parser().parse_args(argv[1:]) 66 | logger = getLogger(args.verbose) 67 | logger.debug(f"Loaded parameters: {args}") 68 | 69 | input_path = args.in_path 70 | output_dir = args.out_dir 71 | 72 | if not output_dir.exists(): 73 | output_dir.mkdir(parents=True) 74 | 75 | corpus =
load_json_corpus(input_path) 76 | words = [node["surface"] for script in corpus for node in script["nodes"]] 77 | logger.info(f"Loaded {len(words):,} words") 78 | 79 | freqs = count_words(words) 80 | logger.info(f"Get {len(freqs):,} unique words") 81 | 82 | freqs = filter_words(freqs, args.min_freq) 83 | logger.info(f"Filtered {len(freqs):,} words") 84 | 85 | save_vocab(freqs, output_dir) 86 | 87 | 88 | if __name__ == "__main__": 89 | sys.exit(entry()) 90 | -------------------------------------------------------------------------------- /marine/utils/pretrained.py: -------------------------------------------------------------------------------- 1 | # Acknowledgement: some of the code was adapted from ttslearn 2 | # Copyright 2021 Ryuichi Yamamoto (MIT License) 3 | 4 | import os 5 | import shutil 6 | import tarfile 7 | from os.path import join 8 | from pathlib import Path 9 | from urllib.request import urlretrieve 10 | 11 | from tqdm.auto import tqdm 12 | 13 | 14 | DEFAULT_CACHE_DIR = join(os.path.expanduser("~"), ".cache", "marine") 15 | CACHE_DIR = os.environ.get("MARINE_CACHE_DIR", DEFAULT_CACHE_DIR) 16 | 17 | DEFAULT_VERSION = "v0.0.6-post1" 18 | MODEL_BASE_URL = "https://github.com/tsukumijima/marine-plus/releases/download/" 19 | 20 | 21 | # https://github.com/tqdm/tqdm#hooks-and-callbacks 22 | class _TqdmUpTo(tqdm): # type: ignore 23 | def update_to(self, b=1, bsize=1, tsize=None): 24 | if tsize is not None: 25 | self.total = tsize 26 | return self.update(b * bsize - self.n) 27 | 28 | 29 | def retrieve_pretrained_model(version: str | None = None) -> Path: 30 | """Retrieve pretrained model from local cache or download from GitHub. 31 | Args: 32 | version (str): Version of pretrained model. 33 | Returns: 34 | str: Path to the pretrained model. 35 | Raises: 36 | ValueError: If the pretrained model is not found. 
37 | Examples: 38 | >>> from marine.utils.pretrained import retrieve_pretrained_model 39 | >>> from marine.predict import Predictor 40 | >>> model_dir = retrieve_pretrained_model("v0.0.6-post1") 41 | >>> predictor = Tacotron2PWGTTS(model_dir=model_dir, device="cpu") 42 | """ 43 | 44 | if version is None: 45 | version = DEFAULT_VERSION 46 | elif not isinstance(version, str): 47 | raise TypeError(f"version must be str not {type(version)}") 48 | 49 | url = MODEL_BASE_URL + f"{version}/model.tar.gz" 50 | 51 | # NOTE: assuming that filename and extracted is the same 52 | out_dir = Path(CACHE_DIR) / version 53 | filename = Path(CACHE_DIR) / f"{version}/model.tar.gz" 54 | 55 | # re-download models 56 | if out_dir.exists() and len(list(out_dir.glob("*.pth"))) == 0: 57 | shutil.rmtree(out_dir) 58 | 59 | if not out_dir.exists(): 60 | print(f'Downloading: "{url}"') 61 | 62 | out_dir.mkdir(parents=True, exist_ok=True) 63 | 64 | with _TqdmUpTo( 65 | unit="B", 66 | unit_scale=True, 67 | unit_divisor=1024, 68 | miniters=1, 69 | desc=f"{version}/model.tar.gz", 70 | ) as t: # all optional kwargs 71 | urlretrieve(url, filename, reporthook=t.update_to) 72 | t.total = t.n 73 | with tarfile.open(filename, mode="r|gz") as f: 74 | f.extractall(path=out_dir) 75 | os.remove(filename) 76 | 77 | return out_dir 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mid-outs 2 | /logs 3 | /data 4 | /outputs 5 | /tensorboard 6 | /checkpoint 7 | 8 | # Recipe related 9 | recipe*/*/data* 10 | recipe*/*/downloads 11 | recipe*/*/tensorboard 12 | recipe*/*/outputs 13 | !recipe*/*/database 14 | wrong_mora_info* 15 | 16 | # For MacOS 17 | .DS_Store 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 
downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | -------------------------------------------------------------------------------- /recipe/20250122_marine-plus/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ue 4 | 5 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 6 | MARINE_ROOT=$script_dir/../.. 7 | COMMON_ROOT=$script_dir/../common 8 | DATABASE_DIR=$script_dir/database 9 | 10 | stage=0 11 | stop_stage=0 12 | 13 | # Setup parameter for experiment 14 | ## Parameters for dataset 15 | accent_status_seq_level="mora" 16 | accent_status_represent_mode="binary" 17 | feature_table_key="open-jtalk" 18 | ## When exist_vocab_dir given, this parameter will be ignored 19 | vocab_min_freq=2 20 | ## When exist_target_id_dir given, this parameter will be ignored 21 | val_test_size=100 22 | 23 | jsut_script_path=$script_dir/data 24 | output_dir=$script_dir/outputs 25 | tag=20250122_marine-plus 26 | 27 | exist_vocab_dir="" # always rebuild 28 | exist_feature_dir="" # always rebuild 29 | exist_target_id_dir="" # always rebuild 30 | 31 | .
$COMMON_ROOT/parse_options.sh || exit 1 32 | 33 | # Parepare output files 34 | output_root=$output_dir/$tag 35 | mkdir -p $output_root 36 | 37 | # Setup output directory 38 | raw_corpus_dir=$output_root/raw 39 | model_dir=$output_root/model 40 | feature_pack_dir=$output_root/feature_pack 41 | tensorboard_dir=$output_root/tensorboard 42 | test_log_dir=$output_root/log 43 | jsut_script_basename="$(basename $jsut_script_path)" 44 | in_domain_test_log_dir=$test_log_dir/${jsut_script_basename%.*} 45 | 46 | # update vocab_path 47 | if [ -z ${exist_vocab_dir} ]; then 48 | vocab_dir=$output_root/vocab 49 | else 50 | vocab_dir=$exist_vocab_dir 51 | fi 52 | 53 | vocab_path=$vocab_dir/vocab.pkl 54 | 55 | # update feature file 56 | if [ -z ${exist_feature_dir} ]; then 57 | feature_file_dir=$output_root/feature 58 | else 59 | feature_file_dir=$exist_feature_dir 60 | fi 61 | 62 | 63 | # Setup hydra config for training 64 | train=basic 65 | data=mora_based_seq 66 | model=mtl_lstm_encoder_crf_decoder 67 | criterions=loglikehood 68 | optim=adam 69 | 70 | # Setup directory for test 71 | checkpoint_dir=$model_dir/$tag 72 | 73 | 74 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 75 | echo "stage 1: Convert raw corpus to json" 76 | . $COMMON_ROOT/make_raw_corpus.sh 77 | fi 78 | 79 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] && [ -z ${exist_feature_dir} ]; then 80 | echo "stage 2: Extract feature" 81 | python $MARINE_ROOT/marine/bin/prepare_features_pyopenjtalk.py $raw_corpus_dir $feature_file_dir 82 | fi 83 | 84 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ] && [ -z ${exist_vocab_dir} ]; then 85 | echo "stage 3: Build vocabulary" 86 | . $COMMON_ROOT/build_vocab.sh 87 | fi 88 | 89 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 90 | echo "stage 4: Feature generation" 91 | . $COMMON_ROOT/pack_corpus.sh 92 | fi 93 | 94 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 95 | echo "stage 5: Train model and test" 96 | . 
$COMMON_ROOT/train.sh 97 | fi 98 | -------------------------------------------------------------------------------- /recipe/20220912_release/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ue 4 | 5 | script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) 6 | MARINE_ROOT=$script_dir/../.. 7 | COMMON_ROOT=$script_dir/../common 8 | DATABASE_DIR=$script_dir/database 9 | 10 | stage=0 11 | stop_stage=0 12 | 13 | # Setup parameter for experiment 14 | ## Parameters for dataset 15 | accent_status_seq_level="mora" 16 | accent_status_represent_mode="binary" 17 | feature_table_key="open-jtalk" 18 | ## When exist_vocab_dir given, this parameter will be ignored 19 | vocab_min_freq=2 20 | ## When exist_target_id_dir given, this parameter will be ignored 21 | val_test_size=100 22 | 23 | jsut_script_path=$HOME/data 24 | output_dir=$script_dir/outputs 25 | tag=20220912_release 26 | 27 | exist_vocab_dir=$COMMON_ROOT/database/20220912_jsut_vocab_min_2 28 | exist_feature_dir=$COMMON_ROOT/database/20220912_auto_annotated_feature 29 | exist_target_id_dir=$COMMON_ROOT/database/20220912_jsut_script_ids 30 | 31 | . 
$COMMON_ROOT/parse_options.sh || exit 1 32 | 33 | # Parepare output files 34 | output_root=$output_dir/$tag 35 | mkdir -p $output_root 36 | 37 | # Setup output directory 38 | raw_corpus_dir=$output_root/raw 39 | model_dir=$output_root/model 40 | feature_pack_dir=$output_root/feature_pack 41 | tensorboard_dir=$output_root/tensorboard 42 | test_log_dir=$output_root/log 43 | jsut_script_basename="$(basename $jsut_script_path)" 44 | in_domain_test_log_dir=$test_log_dir/${jsut_script_basename%.*} 45 | 46 | # update vocab_path 47 | if [ -z ${exist_vocab_dir} ]; then 48 | vocab_dir=$output_root/vocab 49 | else 50 | vocab_dir=$exist_vocab_dir 51 | fi 52 | 53 | vocab_path=$vocab_dir/vocab.pkl 54 | 55 | # update feature file 56 | if [ -z ${exist_feature_dir} ]; then 57 | feature_file_dir=$output_root/feature 58 | else 59 | feature_file_dir=$exist_feature_dir 60 | fi 61 | 62 | 63 | # Setup hydra config for training 64 | train=basic 65 | data=mora_based_seq 66 | model=mtl_lstm_encoder_crf_decoder 67 | criterions=loglikehood 68 | optim=adam 69 | 70 | # Setup directory for test 71 | checkpoint_dir=$model_dir/$tag 72 | 73 | 74 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 75 | echo "stage 1: Convert jsut to json" 76 | . $COMMON_ROOT/jsut2corpus.sh 77 | fi 78 | 79 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] && [ -z ${exist_feature_dir} ]; then 80 | echo "stage 2: Extract feature" 81 | python $MARINE_ROOT/marine/bin/prepare_features_pyopenjtalk.py $raw_corpus_dir $feature_file_dir 82 | fi 83 | 84 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ] && [ -z ${exist_vocab_dir} ]; then 85 | echo "stage 3: Build vocabulary" 86 | . $COMMON_ROOT/build_vocab.sh 87 | fi 88 | 89 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 90 | echo "stage 4: Feature generation" 91 | . $COMMON_ROOT/pack_corpus.sh 92 | fi 93 | 94 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 95 | echo "stage 5: Train model and test" 96 | . 
$COMMON_ROOT/train.sh 97 | fi 98 | -------------------------------------------------------------------------------- /marine/models/linear_decoder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | from torch import Tensor, cat, nn 4 | 5 | 6 | class LinearDecoder(nn.Module): 7 | def __init__( 8 | self, 9 | input_size: int, 10 | output_size: int, 11 | prev_task_embedding_label_list: Sequence[str] | None = None, 12 | prev_task_embedding_label_size: Mapping[str, int] | None = None, 13 | prev_task_embedding_size: Mapping[str, int] | None = None, 14 | prev_task_dropout: float | None = None, 15 | padding_idx: int = 0, 16 | ) -> None: 17 | super().__init__() 18 | if ( 19 | prev_task_embedding_label_size 20 | and prev_task_embedding_label_list 21 | and prev_task_embedding_size 22 | ): 23 | embeddings = {} 24 | dropouts = {} 25 | 26 | for key in prev_task_embedding_label_list: 27 | embeddings[key] = nn.Embedding( 28 | prev_task_embedding_label_size[key], 29 | prev_task_embedding_size[key], 30 | padding_idx=padding_idx, 31 | ) 32 | input_size += prev_task_embedding_size[key] 33 | 34 | if prev_task_dropout: 35 | dropouts[key] = nn.Dropout(prev_task_dropout) 36 | 37 | self.prev_task_embedding = nn.ModuleDict(embeddings) 38 | 39 | if len(dropouts) > 0: 40 | self.prev_task_dropout = nn.ModuleDict(dropouts) 41 | else: 42 | self.prev_task_dropout = None 43 | 44 | else: 45 | self.prev_task_embedding = None 46 | self.prev_task_dropout = None 47 | 48 | # NOTE: output_size must includes size for [PAD] 49 | self.linear = nn.Linear(input_size, output_size, bias=True) 50 | 51 | def forward( 52 | self, 53 | logits: Tensor, 54 | mask: Tensor, 55 | prev_decoder_outputs: dict[str, Tensor] | None = None, 56 | decoder_targets: Tensor | None = None, 57 | ) -> Tensor: 58 | if self.prev_task_embedding is not None and prev_decoder_outputs is not None: 59 | prev_decoder_output_embs = [] 60 | 61 | for key in 
self.prev_task_embedding.keys(): 62 | prev_decoder_output = prev_decoder_outputs[key] 63 | prev_decoder_output_emb = self.prev_task_embedding[key]( 64 | prev_decoder_output 65 | ) 66 | 67 | if self.prev_task_dropout: 68 | prev_decoder_output_emb = self.prev_task_dropout[key]( 69 | prev_decoder_output_emb 70 | ) 71 | 72 | prev_decoder_output_embs.append(prev_decoder_output_emb) 73 | 74 | logits = cat([logits] + prev_decoder_output_embs, dim=2) 75 | 76 | # Linear -> B * T * Output-size 77 | linear_logits = self.linear(logits) 78 | 79 | return linear_logits 80 | -------------------------------------------------------------------------------- /marine/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class ZoneOutCell(nn.Module): 7 | def __init__(self, cell, zoneout=0.1): 8 | super().__init__() 9 | self.cell = cell 10 | self.hidden_size = cell.hidden_size 11 | self.zoneout = zoneout 12 | 13 | def forward(self, inputs, hidden): 14 | next_hidden = self.cell(inputs, hidden) 15 | next_hidden = self._zoneout(hidden, next_hidden, self.zoneout) 16 | return next_hidden 17 | 18 | def _zoneout(self, h, next_h, prob): 19 | h_0, c_0 = h 20 | h_1, c_1 = next_h 21 | h_1 = self._apply_zoneout(h_0, h_1, prob) 22 | c_1 = self._apply_zoneout(c_0, c_1, prob) 23 | return h_1, c_1 24 | 25 | def _apply_zoneout(self, h, next_h, prob): 26 | if self.training: 27 | mask = h.new(*h.size()).bernoulli_(prob) 28 | return mask * h + (1 - mask) * next_h 29 | else: 30 | return prob * h + (1 - prob) * next_h 31 | 32 | 33 | class BahdanauAttention(nn.Module): 34 | """Bahdanau-style attention 35 | This is an attention mechanism originally used in Tacotron. 
36 | Args: 37 | encoder_dim (int): dimension of encoder outputs 38 | decoder_dim (int): dimension of decoder outputs 39 | hidden_dim (int): dimension of hidden state 40 | """ 41 | 42 | def __init__(self, encoder_dim=512, decoder_dim=1024, hidden_dim=128): 43 | super().__init__() 44 | self.mlp_enc = nn.Linear(encoder_dim, hidden_dim) 45 | self.mlp_dec = nn.Linear(decoder_dim, hidden_dim, bias=False) 46 | self.w = nn.Linear(hidden_dim, 1) 47 | 48 | self.processed_memory = None 49 | 50 | def reset(self): 51 | """Reset the internal buffer""" 52 | self.processed_memory = None 53 | 54 | def forward( 55 | self, 56 | encoder_outs, 57 | decoder_state, 58 | mask=None, 59 | ): 60 | """Forward step 61 | Args: 62 | encoder_outs (torch.FloatTensor): encoder outputs 63 | src_lens (list): length of each input batch 64 | decoder_state (torch.FloatTensor): decoder hidden state 65 | mask (torch.FloatTensor): mask for padding 66 | """ 67 | 68 | if self.processed_memory is None: 69 | self.processed_memory = self.mlp_enc(encoder_outs) 70 | 71 | decoder_state = self.mlp_dec(decoder_state).unsqueeze(1) 72 | 73 | erg = self.w(torch.tanh(self.processed_memory + decoder_state)).squeeze(-1) 74 | 75 | if mask is not None: 76 | # invert mask 77 | mask = ~mask 78 | erg.masked_fill_(mask, -float("inf")) 79 | 80 | attention_weights = F.softmax(erg, dim=1) 81 | 82 | attention_context = torch.sum( 83 | encoder_outs * attention_weights.unsqueeze(-1), dim=1 84 | ) 85 | 86 | return attention_context, attention_weights 87 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import re 3 | from os.path import exists, join 4 | 5 | from setuptools import find_packages, setup 6 | 7 | 8 | def find_version(*file_paths: str) -> str: 9 | with codecs.open(join(*file_paths), "r") as fp: 10 | version_file = fp.read() 11 | version_match = re.search(r"^__version__ = 
['\"]([^'\"]*)['\"]", version_file, re.M) 12 | if version_match: 13 | return version_match.group(1) 14 | raise RuntimeError("Unable to find version string.") 15 | 16 | 17 | if exists("README.md"): 18 | with open("README.md", encoding="utf-8") as fh: 19 | LONG_DESC = LONG_DESC = fh.read() 20 | else: 21 | LONG_DESC = "" 22 | 23 | 24 | setup( 25 | name="marine-plus", 26 | version=find_version("marine", "__init__.py"), 27 | description="Marine: Multi-task learning based on Japanese accent estimation (Also supports Windows and Python 3.12)", 28 | packages=find_packages(), 29 | author="Byeongseon Park", 30 | author_email="6gsn.park@gmail.com", 31 | long_description=LONG_DESC, 32 | long_description_content_type="text/markdown", 33 | include_package_data=True, 34 | install_requires=[ 35 | "importlib_resources; python_version<'3.9'", 36 | "numpy >= 1.21.0, <2", 37 | "torch >= 1.7.0", 38 | "hydra-core >= 1.1.0", 39 | "hydra_colorlog >= 1.1.0", 40 | "tqdm", 41 | "joblib", 42 | "pykakasi >= 2.3.0", 43 | ], 44 | extras_require={ 45 | "dev": [ 46 | "torchmetrics", 47 | "scikit-learn", 48 | "docstr-coverage", 49 | "tensorboard", 50 | "matplotlib", 51 | "pytest", 52 | "pytest-cov", 53 | "docstr-coverage", 54 | "ruff", 55 | "taskipy", 56 | "click", 57 | "pandas", 58 | "httpx", 59 | ], 60 | "docs": [ 61 | "sphinx", 62 | "sphinx-autobuild", 63 | "sphinx_rtd_theme", 64 | "nbsphinx>=0.8.6", 65 | "Jinja2>=3.0.1", 66 | "pandoc", 67 | "ipython", 68 | "jupyter", 69 | ], 70 | "pyopenjtalk": ["pyopenjtalk-plus"], 71 | }, 72 | entry_points={ 73 | "console_scripts": [ 74 | "marine-make-raw-corpus = marine.bin.make_raw_corpus:entry", 75 | "marine-jsut2corpus = marine.bin.jsut2corpus:entry", 76 | "marine-build-vocab = marine.bin.build_vocab:entry", 77 | "marine-pack-corpus = marine.bin.pack_corpus:entry", 78 | "marine-train = marine.bin.train:entry", 79 | "marine-test = marine.bin.test:entry", 80 | ], 81 | }, 82 | classifiers=[ 83 | "Operating System :: Unix", 84 | "Operating System :: MacOS", 
85 | "Programming Language :: Python", 86 | "Programming Language :: Python :: 3", 87 | "Programming Language :: Python :: 3.10", 88 | "Programming Language :: Python :: 3.11", 89 | "Programming Language :: Python :: 3.12", 90 | "License :: OSI Approved :: Apache Software License", 91 | "Topic :: Scientific/Engineering", 92 | "Topic :: Software Development", 93 | "Intended Audience :: Science/Research", 94 | "Intended Audience :: Developers", 95 | ], 96 | ) 97 | -------------------------------------------------------------------------------- /marine/data/pad.py: -------------------------------------------------------------------------------- 1 | from typing import Any, cast 2 | 3 | import numpy as np 4 | import torch 5 | from numpy.typing import NDArray 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | from marine.types import BatchFeature, BatchItem, PadFeature, PadOutputs 9 | 10 | 11 | class Padsequence: 12 | def __init__( 13 | self, 14 | input_keys: list[str], 15 | input_length_key: str, 16 | output_keys: list[str], 17 | num_classes: int, 18 | is_inference: bool = False, 19 | padding_idx: int = 0, 20 | ) -> None: 21 | self.input_keys = input_keys 22 | self.input_length_key = input_length_key 23 | self.output_keys = output_keys 24 | self.num_classes = num_classes 25 | self.is_inference = is_inference 26 | self.padding_idx = padding_idx 27 | 28 | def pad_feature(self, inputs: list[BatchFeature]) -> PadFeature: 29 | padded_feature: dict[str, Any] = {} 30 | 31 | for key in self.input_keys: 32 | feature = [ 33 | torch.tensor(features[key], dtype=torch.int64) for features in inputs 34 | ] 35 | 36 | if key in self.input_length_key: 37 | padded_feature[f"{key}_length"] = torch.tensor( 38 | [len(f) for f in feature], dtype=torch.int64 39 | ) 40 | 41 | padded_x = pad_sequence( 42 | feature, 43 | batch_first=True, 44 | padding_value=self.padding_idx, 45 | ) 46 | padded_feature[key] = padded_x 47 | 48 | return cast(PadFeature, padded_feature) 49 | 50 | def __call__( 
51 | self, batch: list[BatchItem] 52 | ) -> tuple[ 53 | PadFeature, PadOutputs | None, list[NDArray[np.uint8]], list[str] | None 54 | ]: 55 | # sort by length 56 | if not self.is_inference: 57 | batch = sorted( 58 | batch, 59 | key=lambda x: len(x["features"][self.input_length_key]), 60 | reverse=True, 61 | ) 62 | 63 | inputs = [x["features"] for x in batch] 64 | padded_inputs = self.pad_feature(inputs) 65 | 66 | if not self.is_inference: 67 | padded_outputs = cast( 68 | PadOutputs, 69 | { 70 | key: { 71 | "label": pad_sequence( 72 | [ 73 | # Convert 1-based label (for pad) 74 | torch.tensor(x["labels"][key] + 1, dtype=torch.long) # type: ignore 75 | for x in batch 76 | ], 77 | batch_first=True, 78 | padding_value=self.padding_idx, 79 | ), 80 | "length": torch.tensor([len(x["labels"][key]) for x in batch]), # type: ignore 81 | } 82 | for key in self.output_keys 83 | }, 84 | ) 85 | script_ids = [x["ids"] for x in batch if x["ids"] is not None] # type: ignore 86 | else: 87 | padded_outputs = None 88 | script_ids = None 89 | 90 | morph_boundary = [x["features"]["morph_boundary"] for x in batch] 91 | 92 | return padded_inputs, padded_outputs, morph_boundary, script_ids 93 | -------------------------------------------------------------------------------- /marine/bin/prepare_features_pyopenjtalk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | from concurrent.futures import ProcessPoolExecutor 5 | from multiprocessing import cpu_count 6 | from pathlib import Path 7 | 8 | from tqdm import tqdm 9 | 10 | from marine.data.dictionary_downloader import download_and_apply_dictionaries 11 | from marine.logger import getLogger 12 | from marine.utils.openjtalk_util import convert_open_jtalk_node_to_feature 13 | from marine.utils.util import load_json_corpus 14 | 15 | 16 | # download and apply OpenJTalk dictionaries 17 | download_and_apply_dictionaries() 18 | 19 | logger = None 20 | 21 | 22 | 
def get_parser(): 23 | parser = argparse.ArgumentParser( 24 | description="Convert Special format txt format data to json file", 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 26 | ) 27 | parser.add_argument("in_path", type=Path, help="Input path or directory") 28 | parser.add_argument("out_dir", type=Path, help="Output directory") 29 | parser.add_argument("--n_jobs", type=int, default=8, help="Number of jobs") 30 | parser.add_argument( 31 | "--verbose", 32 | "-v", 33 | type=int, 34 | default=50, 35 | help="Logging level", 36 | ) 37 | return parser 38 | 39 | 40 | def extract_feature(script_id, text): 41 | features = {"script_id": script_id, "nodes": []} 42 | 43 | try: 44 | from pyopenjtalk import run_frontend 45 | except BaseException: 46 | raise ImportError( 47 | 'Please install pyopenjtalk by `pip install -e ".[dev,pyopenjtalk]"`' 48 | ) 49 | 50 | # drop full-context label 51 | nodes = run_frontend(text) 52 | features["nodes"] = convert_open_jtalk_node_to_feature(nodes) 53 | 54 | return features 55 | 56 | 57 | def _sort_corpus_by_script_id(corpus): 58 | return list(sorted(corpus, key=lambda x: x["script_id"])) 59 | 60 | 61 | def entry(argv=sys.argv): 62 | global logger 63 | 64 | args = get_parser().parse_args(argv[1:]) 65 | logger = getLogger(args.verbose) 66 | logger.debug(args) 67 | 68 | # Process 69 | n_jobs = min(cpu_count(), args.n_jobs) 70 | 71 | if not args.out_dir.exists(): 72 | args.out_dir.mkdir(parents=True) 73 | 74 | # Load corpus 75 | corpus = load_json_corpus(args.in_path) 76 | 77 | if n_jobs > 1: 78 | logger.info(f"Processing {len(corpus):,} scripts with {n_jobs} jobs") 79 | with ProcessPoolExecutor(n_jobs) as executor: 80 | futures = [ 81 | executor.submit( 82 | extract_feature, 83 | script["script_id"], 84 | script["surface"], 85 | ) 86 | for script in corpus 87 | ] 88 | corpus = [ 89 | future.result() 90 | for future in tqdm( 91 | futures, desc="Convert corpus to feature", leave=False 92 | ) 93 | ] 94 | else: 95 | 
logger.info(f"Processing {len(corpus):,} scripts in a single thread") 96 | corpus = [ 97 | extract_feature(script["script_id"], script["surface"]) for script in corpus 98 | ] 99 | 100 | # corpus = _sort_corpus_by_script_id(corpus) 101 | 102 | output_path = args.out_dir / "feature.json" 103 | 104 | with open(output_path, "w", encoding="utf-8") as file: 105 | json.dump(corpus, file, ensure_ascii=False, indent=4, separators=(",", ": ")) 106 | 107 | 108 | if __name__ == "__main__": 109 | sys.exit(entry()) 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # marine-plus 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/marine-plus.svg)](https://pypi.python.org/pypi/marine-plus) 4 | [![Python package](https://github.com/tsukumijima/marine-plus/actions/workflows/ci.yml/badge.svg)](https://github.com/tsukumijima/marine-plus/actions/workflows/ci.yml) 5 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) 6 | 7 | marine-plus は、主に Windows 対応や新しい Python バージョンのサポートなどコードのメンテナンスを目的とした、[marine](https://github.com/6gsn/marine) の派生ライブラリです。 8 | 9 | ## Installation 10 | 11 | 下記コマンドを実行して、ライブラリをインストールできます。 12 | 13 | ```bash 14 | pip install marine-plus 15 | ``` 16 | 17 | 下記のドキュメントは、[marine](https://github.com/6gsn/marine) 本家のドキュメントを改変なしでそのまま引き継いでいます。 18 | これらのドキュメントの内容が marine-plus にも通用するかは保証されません。 19 | 20 | ------- 21 | 22 | # **MARINE** : **M**ulti-task lea**R**n**I**ng-based Japa**N**ese accent **E**stimation 23 | 24 | [![PyPI](https://img.shields.io/pypi/v/marine.svg)](https://pypi.python.org/pypi/marine) 25 | [![Python package](https://github.com/6gsn/marine/actions/workflows/ci.yml/badge.svg)](https://github.com/6gsn/marine/actions/workflows/ci.yml) 26 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) 27 | 
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7092054.svg)](https://doi.org/10.5281/zenodo.7092054) 28 | 29 | `marine` is a tool kit for building the Japanese accent estimation model proposed in [our paper](https://www.isca-speech.org/archive/interspeech_2022/park22b_interspeech.html) ([demo](https://6gsn.github.io/demos/mtl_accent/)). 30 | 31 | For academic use, please cite the following paper ([ISCA archive](https://www.isca-speech.org/archive/interspeech_2022/park22b_interspeech.html)). 32 | 33 | ```bibtex 34 | @inproceedings{park22b_interspeech, 35 | author={Byeongseon Park and Ryuichi Yamamoto and Kentaro Tachibana}, 36 | title={{A Unified Accent Estimation Method Based on Multi-Task Learning for Japanese Text-to-Speech}}, 37 | year=2022, 38 | booktitle={Proc. Interspeech 2022}, 39 | pages={1931--1935}, 40 | doi={10.21437/Interspeech.2022-334} 41 | } 42 | ``` 43 | 44 | ## Notice 45 | 46 | The model included in this package is trained using [JSUT corpus](https://sites.google.com/site/shinnosuketakamichi/publication/jsut), which is not the same as the dataset in [our paper](https://www.isca-speech.org/archive/interspeech_2022/park22b_interspeech.html). Therefore, the model's performance is also not equal to the performance introduced in our paper. 
47 | 48 | ## Get started 49 | 50 | ### Installation 51 | 52 | ```shell 53 | $ pip install marine 54 | ``` 55 | 56 | ### For development 57 | 58 | ```shell 59 | $ pip install -e ".[dev]" 60 | ``` 61 | 62 | ### Quick demo 63 | 64 | ```python 65 | In [1]: from marine.predict import Predictor 66 | 67 | In [2]: nodes = [{"surface": "こんにちは", "pos": "感動詞:*:*:*", "pron": "コンニチワ", "c_type": "*", "c_form": "*", "accent_type": 0, "accent_con_type": "-1", "chain_flag": -1}] 68 | 69 | In [3]: predictor = Predictor() 70 | 71 | In [4]: predictor.predict([nodes]) 72 | Out[4]: 73 | {'mora': [['コ', 'ン', 'ニ', 'チ', 'ワ']], 74 | 'intonation_phrase_boundary': [[0, 0, 0, 0, 0]], 75 | 'accent_phrase_boundary': [[0, 0, 0, 0, 0]], 76 | 'accent_status': [[0, 0, 0, 0, 0]]} 77 | 78 | In [5]: predictor.predict([nodes], accent_represent_mode="high_low") 79 | Out[5]: 80 | {'mora': [['コ', 'ン', 'ニ', 'チ', 'ワ']], 81 | 'intonation_phrase_boundary': [[0, 0, 0, 0, 0]], 82 | 'accent_phrase_boundary': [[0, 0, 0, 0, 0]], 83 | 'accent_status': [[0, 1, 1, 1, 1]]} 84 | ``` 85 | 86 | ### Build model yourself 87 | 88 | Coming soon... 89 | 90 | ## LICENSE 91 | 92 | - marine: Apache 2.0 license ([LICENSE](LICENSE)) 93 | - JSUT: CC-BY-SA 4.0 license, etc. 
(Please check [jsut-label/LICENCE.txt](https://github.com/sarulab-speech/jsut-label/blob/master/LICENCE.txt)) 94 | -------------------------------------------------------------------------------- /marine/models/crf_decoder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | from typing import cast 3 | 4 | from torch import BoolTensor, Tensor, cat, nn 5 | 6 | from marine.modules.crf_tagger import ConditionalRandomField 7 | 8 | 9 | def _broadcast_tags( 10 | predicted_tags: Sequence[Sequence[int]], classfied: Tensor 11 | ) -> Tensor: 12 | class_probabilities = classfied * 0.0 13 | 14 | for i, instance_tags in enumerate(predicted_tags): 15 | for j, tag_id in enumerate(instance_tags): 16 | class_probabilities[i, j, tag_id] = 1 17 | 18 | return class_probabilities 19 | 20 | 21 | class CRFDecoder(nn.Module): 22 | def __init__( 23 | self, 24 | input_size: int, 25 | output_size: int, 26 | prev_task_embedding_label_list: Sequence[str] | None = None, 27 | prev_task_embedding_label_size: Mapping[str, int] | None = None, 28 | prev_task_embedding_size: Mapping[str, int] | None = None, 29 | prev_task_dropout: float | None = None, 30 | padding_idx: int = 0, 31 | ) -> None: 32 | super().__init__() 33 | if ( 34 | prev_task_embedding_label_size 35 | and prev_task_embedding_label_list 36 | and prev_task_embedding_size 37 | ): 38 | embeddings = {} 39 | dropouts = {} 40 | for key in prev_task_embedding_label_list: 41 | embeddings[key] = nn.Embedding( 42 | prev_task_embedding_label_size[key], 43 | prev_task_embedding_size[key], 44 | padding_idx=padding_idx, 45 | ) 46 | input_size += prev_task_embedding_size[key] 47 | 48 | if prev_task_dropout: 49 | dropouts[key] = nn.Dropout(prev_task_dropout) 50 | 51 | self.prev_task_embedding = nn.ModuleDict(embeddings) 52 | 53 | if len(dropouts) > 0: 54 | self.prev_task_dropout = nn.ModuleDict(dropouts) 55 | else: 56 | self.prev_task_dropout = None 57 | 58 | else: 
59 | self.prev_task_embedding = None 60 | self.prev_task_dropout = None 61 | 62 | # NOTE: output_size must includes size for [PAD] 63 | self.linear = nn.Linear(input_size, output_size) 64 | self.crf = ConditionalRandomField(output_size) 65 | 66 | def forward( 67 | self, 68 | logits: Tensor, 69 | mask: Tensor, 70 | prev_decoder_outputs: dict[str, Tensor] | None = None, 71 | decoder_targets: dict[str, Tensor] | None = None, 72 | ) -> tuple[Tensor, Tensor]: 73 | if self.prev_task_embedding is not None and prev_decoder_outputs is not None: 74 | prev_decoder_output_embs = [] 75 | for key in self.prev_task_embedding.keys(): 76 | prev_decoder_output = prev_decoder_outputs[key] 77 | prev_decoder_output_emb = self.prev_task_embedding[key]( 78 | prev_decoder_output 79 | ) 80 | 81 | if self.prev_task_dropout: 82 | prev_decoder_output_emb = self.prev_task_dropout[key]( 83 | prev_decoder_output_emb 84 | ) 85 | 86 | prev_decoder_output_embs.append(prev_decoder_output_emb) 87 | 88 | logits = cat([logits] + prev_decoder_output_embs, dim=2) 89 | 90 | # Linear -> B * T * Output-size 91 | linear_logits = self.linear(logits) 92 | 93 | # CRFs 94 | best_paths = self.crf.viterbi_tags(linear_logits, cast(BoolTensor, mask)) 95 | crf_logits = [x for x, _ in best_paths] 96 | crf_logits = _broadcast_tags(cast(list[list[int]], crf_logits), linear_logits) 97 | 98 | return linear_logits, crf_logits 99 | -------------------------------------------------------------------------------- /recipe/common/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 
# --------------------------------------------------------------------------------
# /marine/models/util.py:
# --------------------------------------------------------------------------------
from logging import getLogger
from typing import Any

from hydra.utils import instantiate
from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer

from marine.data.feature.feature_set import FeatureSet


logger = getLogger(__name__)


def init_model(
    tasks: list[str],
    config: DictConfig,
    feature_set: FeatureSet,
    device: str,
    is_train: bool = False,
) -> nn.Module | tuple[nn.Module, dict[str, nn.Module], Optimizer, Any]:
    """Instantiate the multi-task model (and training objects) from a Hydra config.

    Modules are built in a fixed order: embedding, per-task encoders, then
    per-task decoders (each followed by its criterion when training), the base
    model, and finally optimizer and scheduler. Keep this order stable; it
    presumably determines random weight initialization under a fixed seed.

    Args:
        tasks: Task names. Each is looked up in ``config.model.encoder.shared_with``,
            ``config.data.output_sizes`` and ``config.model.decoder`` (and in
            ``config.criterions`` when training).
        config: Composed Hydra config with ``model`` and ``data`` sections, plus
            ``criterions`` and ``optim`` when training.
        feature_set: Feature set handed to the embedding layer.
        device: Device the model and criterions are moved to.
        is_train: When True, also build criterions, optimizer and scheduler.

    Returns:
        The model alone, or ``(model, criterions, optimizer, scheduler)`` when
        ``is_train`` is True.
    """
    criterions: dict[str, nn.Module] = {}
    model_cfg = config.model

    # Shared embedding layer.
    embedding = instantiate(model_cfg.embedding, feature_set=feature_set)

    # Encoder input = concatenation of every input-feature embedding; the output
    # width is twice the hidden size (presumably a bidirectional encoder).
    encoder_input_size = sum(
        model_cfg.embedding.embeding_sizes[key] for key in config.data.input_keys
    )
    encoder_output_size = model_cfg.encoder.param.hidden_size * 2

    encoders: dict[str, nn.Module] = {}
    for task in tasks:
        share_target = model_cfg.encoder.shared_with[task]
        if share_target and share_target in encoders:
            # Reuse the very same module instance, i.e. shared weights.
            logger.info(f"{task} has shared encoder with {share_target}")
            encoders[task] = encoders[share_target]
            continue
        encoders[task] = instantiate(
            model_cfg.encoder.param, input_size=encoder_input_size
        )

    # Every decoder reads the encoder output.
    decoder_kwargs = {
        task: {
            "input_size": encoder_output_size,
            "output_size": config.data.output_sizes[task],
        }
        for task in tasks
    }

    decoders: dict[str, nn.Module] = {}
    for task in tasks:
        decoder_cfg = model_cfg.decoder[task]
        prev_labels = decoder_cfg["prev_task_embedding_label_list"]

        if prev_labels:
            assert prev_labels == list(
                decoder_cfg["prev_task_embedding_size"].keys()
            ), "Not matched embedding setting for previous tasks: {} != {}".format(
                prev_labels,
                decoder_cfg["prev_task_embedding_size"].keys(),
            )
            # Previous-task embedding tables are sized by that task's output size.
            decoder_kwargs[task]["prev_task_embedding_label_size"] = {
                prev_task: decoder_kwargs[prev_task]["output_size"]
                for prev_task in prev_labels
                if prev_task in tasks
            }

        decoders[task] = instantiate(decoder_cfg, **decoder_kwargs[task])

        if not is_train:
            continue

        criterion_cfg = config.criterions[task]
        # The CRF log-likelihood criterion needs the decoder's CRF layer.
        criterion_kwargs = (
            {"log_likehood_func": decoders[task].crf}
            if criterion_cfg._target_ == "marine.criterions.LogLikelhood"
            else {}
        )
        criterions[task] = instantiate(criterion_cfg, **criterion_kwargs).to(device)

    model = instantiate(
        model_cfg.base,
        embedding=embedding,
        encoders=encoders,
        decoders=decoders,
    ).to(device)

    logger.debug(f"model has initialized\n{model}")

    if not is_train:
        return model

    optimizer = instantiate(config.optim.optimizer, params=model.parameters())
    scheduler = instantiate(config.optim.scheduler, optimizer=optimizer)
    return model, criterions, optimizer, scheduler


# --------------------------------------------------------------------------------
# /marine/types.py:
# --------------------------------------------------------------------------------
/marine/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, TypedDict 2 | 3 | import numpy as np 4 | from numpy.typing import NDArray 5 | from torch import Tensor 6 | 7 | 8 | # アクセント表現モード 9 | AccentRepresentMode = Literal[ 10 | "binary", # アクセント核位置を1、それ以外を0で表現 11 | "high_low", # 各モーラの高低を表現 (0=低 / 1=高) 12 | ] 13 | 14 | 15 | class NJDFeature(TypedDict): 16 | """OpenJTalk の形態素解析結果・アクセント推定結果を表す型""" 17 | 18 | string: str # 表層形 19 | pos: str # 品詞 20 | pos_group1: str # 品詞細分類1 21 | pos_group2: str # 品詞細分類2 22 | pos_group3: str # 品詞細分類3 23 | ctype: str # 活用型 24 | cform: str # 活用形 25 | orig: str # 原形 26 | read: str # 読み 27 | pron: str # 発音形式 28 | acc: int # アクセント型 (0: 平板型, 1-n: n番目のモーラにアクセント核) 29 | mora_size: int # モーラ数 30 | chain_rule: str # アクセント結合規則 31 | chain_flag: int # アクセント句の連結フラグ 32 | 33 | 34 | class MarineFeature(TypedDict): 35 | """marine 内部で使用する形態素素性を表す型""" 36 | 37 | surface: str # 表層形 38 | pron: str | None # 発音形式 39 | pos: str # 品詞 (例: "名詞:代名詞:一般:*") 40 | c_type: str # 活用型 41 | c_form: str # 活用形 42 | accent_type: int # アクセント型 43 | accent_con_type: str # アクセント結合型 44 | chain_flag: int # アクセント句の連結フラグ 45 | 46 | 47 | class OpenJTalkFormatLabel(TypedDict): 48 | """OpenJTalkフォーマットのラベルを表す型""" 49 | 50 | # fmt: off 51 | accent_status: list[int] # アクセント核位置 (0: 無核, 1-n: n番目のモーラにアクセント核) 52 | accent_phrase_boundary: list[Literal[-1, 0, 1]] # アクセント句境界 (-1: 文頭, 0: 非境界, 1: 境界) 53 | # fmt: on 54 | 55 | 56 | class MarineLabel(TypedDict): 57 | """marine 内部で使用するラベルを表す型""" 58 | 59 | # fmt: off 60 | mora: list[list[str]] # モーラ列 (例: [["コ", "ン", "ニ", "チ", "ワ"]]) 61 | intonation_phrase_boundary: list[list[Literal[0, 1]]] # イントネーション句境界 (0: 非境界, 1: 境界) 62 | accent_phrase_boundary: list[list[Literal[0, 1]]] # アクセント句境界 (0: 非境界, 1: 境界) 63 | accent_status: list[list[Literal[0, 1]]] # アクセント核位置 (binary: 0/1, high_low: 0/1) 64 | # fmt: on 65 | 66 | 67 | class AnnotateLabel(TypedDict): 68 | """アノテーションラベルを表す型""" 69 | 70 | 
token_type: Literal["morph", "mora"] # トークンの単位 71 | labels: list[list[int]] | list[Tensor] # ラベル列 72 | 73 | 74 | class PredictAnnotates(TypedDict, total=False): 75 | """推論時のアノテーションを表す型""" 76 | 77 | intonation_phrase_boundary: AnnotateLabel 78 | accent_phrase_boundary: AnnotateLabel 79 | accent_status: AnnotateLabel 80 | 81 | 82 | class BatchFeature(TypedDict): 83 | """バッチの特徴量を表す型""" 84 | 85 | # 必須フィールド 86 | morph_boundary: NDArray[np.uint8] # 形態素境界情報 87 | # config.data.input_keys に依存するフィールド (推論に用いるモデルの config.yaml 定義次第では省略される) 88 | mora: NDArray[np.uint8] # モーラ ID 列 89 | surface: NDArray[np.uint8] # 表層形 ID 列 90 | pos: NDArray[np.uint8] # 品詞 ID 列 91 | c_type: NDArray[np.uint8] # 活用型 ID 列 92 | c_form: NDArray[np.uint8] # 活用形 ID 列 93 | accent_type: NDArray[np.uint8] # アクセント型 ID 列 94 | accent_con_type: NDArray[np.uint8] # アクセント結合型 ID 列 95 | chain_flag: NDArray[np.uint8] # アクセント句の連結フラグ ID 列 (現在の学習レシピでは未使用) # fmt: skip 96 | 97 | 98 | class BatchItem(TypedDict): 99 | """バッチの各要素を表す型""" 100 | 101 | features: BatchFeature # 特徴量 102 | labels: dict[str, list[int]] | None # ラベル (推論時は None) 103 | ids: str | None # スクリプトID (推論時は None) 104 | 105 | 106 | class ModelInputs(TypedDict): 107 | """モデルの入力を表す型""" 108 | 109 | embedding_features: dict[str, Tensor] # 埋め込み特徴量 110 | lengths: Tensor # 系列長 111 | mask: Tensor # マスク 112 | prev_decoder_outputs: dict[str, Tensor] # 前のデコーダーの出力 113 | 114 | 115 | class PadFeature(TypedDict): 116 | """パディングされた特徴量を表す型""" 117 | 118 | # 必須フィールド 119 | morph_boundary: list[list[list[int]]] # 形態素境界情報 120 | # config.data.input_keys に依存するフィールド (推論に用いるモデルの config.yaml 定義次第では省略される) 121 | mora: Tensor # モーラ ID 列 122 | mora_length: Tensor # モーラ長 (config.data.input_length_key によって定義される) 123 | surface: Tensor # 表層形 ID 列 124 | pos: Tensor # 品詞 ID 列 125 | c_type: Tensor # 活用型 ID 列 126 | c_form: Tensor # 活用形 ID 列 127 | accent_type: Tensor # アクセント型 ID 列 128 | accent_con_type: Tensor # アクセント結合型 ID 列 129 | chain_flag: Tensor # アクセント句の連結フラグ ID 列 (現在の学習レシピでは未使用) 130 | 131 | 
132 | class PadOutputLabel(TypedDict): 133 | """パディングされた出力ラベルを表す型""" 134 | 135 | label: Tensor # ラベル列 136 | length: Tensor # 系列長 137 | 138 | 139 | class PadOutputs(TypedDict): 140 | """パディングされた出力を表す型""" 141 | 142 | intonation_phrase_boundary: PadOutputLabel 143 | accent_phrase_boundary: PadOutputLabel 144 | accent_status: PadOutputLabel 145 | -------------------------------------------------------------------------------- /tests/test_predictor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from marine.predict import Predictor 4 | from marine.types import MarineFeature 5 | 6 | 7 | @pytest.fixture 8 | def predictor() -> Predictor: 9 | """load inference model using default config""" 10 | return Predictor() 11 | 12 | 13 | def test_predict(predictor: Predictor) -> None: 14 | """just to confirm predict() is working without errors.""" 15 | nodes: list[MarineFeature] = [ 16 | { 17 | "surface": "水", 18 | "pron": "ミズ", 19 | "pos": "名詞:一般:*:*", 20 | "c_type": "*", 21 | "c_form": "*", 22 | "accent_type": 0, 23 | "accent_con_type": "C3", 24 | "chain_flag": -1, 25 | }, 26 | { 27 | "surface": "を", 28 | "pron": "オ", 29 | "pos": "助詞:格助詞:一般:*", 30 | "c_type": "*", 31 | "c_form": "*", 32 | "accent_type": 0, 33 | "accent_con_type": "動詞%F5,名詞%F1", 34 | "chain_flag": 1, 35 | }, 36 | { 37 | "surface": "マレーシア", 38 | "pron": "マレーシア", 39 | "pos": "名詞:固有名詞:地域:国", 40 | "c_type": "*", 41 | "c_form": "*", 42 | "accent_type": 2, 43 | "accent_con_type": "C1", 44 | "chain_flag": 0, 45 | }, 46 | { 47 | "surface": "から", 48 | "pron": "カラ", 49 | "pos": "助詞:格助詞:一般:*", 50 | "c_type": "*", 51 | "c_form": "*", 52 | "accent_type": 2, 53 | "accent_con_type": "名詞%F1", 54 | "chain_flag": 1, 55 | }, 56 | { 57 | "surface": "買わ", 58 | "pron": "カワ", 59 | "pos": "動詞:自立:*:*", 60 | "c_type": "五段・ワ行促音便", 61 | "c_form": "未然形", 62 | "accent_type": 0, 63 | "accent_con_type": "*", 64 | "chain_flag": 0, 65 | }, 66 | { 67 | "surface": "なく", 68 | "pron": "ナク", 
69 | "pos": "助動詞:*:*:*", 70 | "c_type": "特殊・ナイ", 71 | "c_form": "連用テ接続", 72 | "accent_type": 1, 73 | "accent_con_type": "動詞%F3@0", 74 | "chain_flag": 1, 75 | }, 76 | { 77 | "surface": "て", 78 | "pron": "テ", 79 | "pos": "助詞:接続助詞:*:*", 80 | "c_type": "*", 81 | "c_form": "*", 82 | "accent_type": 0, 83 | "accent_con_type": "動詞%F1,形容詞%F1,名詞%F5", 84 | "chain_flag": 1, 85 | }, 86 | { 87 | "surface": "は", 88 | "pron": "ワ", 89 | "pos": "助詞:係助詞:*:*", 90 | "c_type": "*", 91 | "c_form": "*", 92 | "accent_type": 0, 93 | "accent_con_type": "名詞%F1,動詞%F2@0,形容詞%F2@0", 94 | "chain_flag": 1, 95 | }, 96 | { 97 | "surface": "なら", 98 | "pron": "ナラ", 99 | "pos": "動詞:非自立:*:*", 100 | "c_type": "五段・ラ行", 101 | "c_form": "未然形", 102 | "accent_type": 2, 103 | "accent_con_type": "*", 104 | "chain_flag": 0, 105 | }, 106 | { 107 | "surface": "ない", 108 | "pron": "ナイ", 109 | "pos": "助動詞:*:*:*", 110 | "c_type": "特殊・ナイ", 111 | "c_form": "基本形", 112 | "accent_type": 1, 113 | "accent_con_type": "動詞%F3@0,形容詞%F2@1", 114 | "chain_flag": 1, 115 | }, 116 | { 117 | "surface": "の", 118 | "pron": "ノ", 119 | "pos": "名詞:非自立:一般:*", 120 | "c_type": "*", 121 | "c_form": "*", 122 | "accent_type": 2, 123 | "accent_con_type": "動詞%F2@0,形容詞%F2@-1", 124 | "chain_flag": 0, 125 | }, 126 | { 127 | "surface": "です", 128 | "pron": "デス", 129 | "pos": "助動詞:*:*:*", 130 | "c_type": "特殊・デス", 131 | "c_form": "基本形", 132 | "accent_type": 1, 133 | "accent_con_type": "名詞%F2@1,動詞%F1,形容詞%F2@0", 134 | "chain_flag": 1, 135 | }, 136 | { 137 | "surface": ".", 138 | "pron": None, 139 | "pos": "記号:句点:*:*", 140 | "c_type": "*", 141 | "c_form": "*", 142 | "accent_type": 0, 143 | "accent_con_type": "*", 144 | "chain_flag": 0, 145 | }, 146 | ] 147 | 148 | print(predictor.predict([nodes], accent_represent_mode="binary")) 149 | print(predictor.predict([nodes], accent_represent_mode="high_low")) 150 | 151 | # If you want the format for OpenJTalk, `accent_represent_mode` will be fixed as `binary 152 | print( 153 | predictor.predict( 154 | [nodes], 155 | 
accent_represent_mode="binary", 156 | require_open_jtalk_format=True, 157 | ) 158 | ) 159 | -------------------------------------------------------------------------------- /marine/utils/g2p_util/util.py: -------------------------------------------------------------------------------- 1 | UNACCENTED_MORA = "ン" 2 | CONNECTABLE_MORA = set(["ッ", UNACCENTED_MORA]) 3 | 4 | NON_MORA_LIST = set(["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ"]) 5 | 6 | LONGVOWEL_CHARACTER = "ー" 7 | 8 | FULL_PUNCTUATION = "、。?!" 9 | HALF_PUNCTUATION = ",.?!" 10 | 11 | PHON_TABLE = { 12 | # x 13 | "ア": ["a"], 14 | "イ": ["i"], 15 | "ウ": ["u"], 16 | "エ": ["e"], 17 | "オ": ["o"], 18 | # w 19 | "ワ": ["w", "a"], 20 | "ウィ": ["w", "i"], 21 | "ウェ": ["w", "e"], 22 | "ウォ": ["w", "o"], 23 | # y 24 | "ヤ": ["y", "a"], 25 | "ユ": ["y", "u"], 26 | "ヨ": ["y", "o"], 27 | "イェ": ["y", "e"], 28 | # k 29 | "カ": ["k", "a"], 30 | "キ": ["k", "i"], 31 | "ク": ["k", "u"], 32 | "ケ": ["k", "e"], 33 | "コ": ["k", "o"], 34 | # kw 35 | "クヮ": ["kw", "a"], 36 | "キャ": ["ky", "a"], 37 | "キュ": ["ky", "u"], 38 | "キェ": ["ky", "e"], 39 | "キョ": ["ky", "o"], 40 | # g 41 | "ガ": ["g", "a"], 42 | "ギ": ["g", "i"], 43 | "グ": ["g", "u"], 44 | "ゲ": ["g", "e"], 45 | "ゴ": ["g", "o"], 46 | # gw 47 | "グヮ": ["gw", "a"], 48 | # gy 49 | "ギャ": ["gy", "a"], 50 | "ギュ": ["gy", "u"], 51 | "ギェ": ["gy", "e"], 52 | "ギョ": ["gy", "o"], 53 | # s 54 | "サ": ["s", "a"], 55 | "スィ": ["s", "i"], 56 | "ス": ["s", "u"], 57 | "セ": ["s", "e"], 58 | "ソ": ["s", "o"], 59 | # sh 60 | "シ": ["sh", "i"], 61 | "シェ": ["sh", "e"], 62 | "シャ": ["sh", "a"], 63 | "シュ": ["sh", "u"], 64 | "ショ": ["sh", "o"], 65 | # z 66 | "ザ": ["z", "a"], 67 | "ズィ": ["z", "i"], 68 | "ズ": ["z", "u"], 69 | "ゼ": ["z", "e"], 70 | "ゾ": ["z", "o"], 71 | # j 72 | "ジャ": ["j", "a"], 73 | "ジ": ["j", "i"], 74 | "ジュ": ["j", "u"], 75 | "ジェ": ["j", "e"], 76 | "ジョ": ["j", "o"], 77 | # t 78 | "タ": ["t", "a"], 79 | "ティ": ["t", "i"], 80 | "トゥ": ["t", "u"], 81 | "テ": ["t", "e"], 82 | "ト": ["t", "o"], 83 | # ty 84 | "テャ": ["ty", 
"a"], 85 | "テュ": ["ty", "u"], 86 | "テョ": ["ty", "o"], 87 | # d 88 | "ダ": ["d", "a"], 89 | "ディ": ["d", "i"], 90 | "ドゥ": ["d", "u"], 91 | "デ": ["d", "e"], 92 | "ド": ["d", "o"], 93 | # dy 94 | "デャ": ["dy", "a"], 95 | "デュ": ["dy", "u"], 96 | "デョ": ["dy", "o"], 97 | # ch 98 | "チャ": ["ch", "a"], 99 | "チ": ["ch", "i"], 100 | "チュ": ["ch", "u"], 101 | "チェ": ["ch", "e"], 102 | "チョ": ["ch", "o"], 103 | # ts 104 | "ツァ": ["ts", "a"], 105 | "ツィ": ["ts", "i"], 106 | "ツ": ["ts", "u"], 107 | "ツェ": ["ts", "e"], 108 | "ツォ": ["ts", "o"], 109 | # n 110 | "ナ": ["n", "a"], 111 | "ニ": ["n", "i"], 112 | "ヌ": ["n", "u"], 113 | "ネ": ["n", "e"], 114 | "ノ": ["n", "o"], 115 | # ny 116 | "ニャ": ["ny", "a"], 117 | "ニュ": ["ny", "u"], 118 | "ニェ": ["ny", "e"], 119 | "ニョ": ["ny", "o"], 120 | # h 121 | "ハ": ["h", "a"], 122 | "ヒ": ["h", "i"], 123 | "ヘ": ["h", "e"], 124 | "ホ": ["h", "o"], 125 | # hy 126 | "ヒャ": ["hy", "a"], 127 | "ヒュ": ["hy", "u"], 128 | "ヒェ": ["hy", "e"], 129 | "ヒョ": ["hy", "ふ"], 130 | # p 131 | "パ": ["p", "a"], 132 | "ピ": ["p", "i"], 133 | "プ": ["p", "u"], 134 | "ペ": ["p", "e"], 135 | "ポ": ["p", "o"], 136 | # py 137 | "ピャ": ["py", "a"], 138 | "ピュ": ["py", "u"], 139 | "ピェ": ["py", "e"], 140 | "ピョ": ["py", "o"], 141 | # b 142 | "バ": ["b", "a"], 143 | "ビ": ["b", "i"], 144 | "ブ": ["b", "u"], 145 | "ベ": ["b", "e"], 146 | "ボ": ["b", "o"], 147 | # by 148 | "ビャ": ["by", "a"], 149 | "ビュ": ["by", "u"], 150 | "ビェ": ["by", "e"], 151 | "ビョ": ["by", "o"], 152 | "ヴャ": ["by", "a"], 153 | "ヴュ": ["by", "u"], 154 | "ヴョ": ["by", "o"], 155 | # v 156 | "ヴァ": ["v", "a"], 157 | "ヴィ": ["v", "i"], 158 | "ヴ": ["v", "u"], 159 | "ヴェ": ["v", "e"], 160 | "ヴォ": ["v", "o"], 161 | # f 162 | "ファ": ["f", "a"], 163 | "フィ": ["f", "i"], 164 | "フ": ["f", "u"], 165 | "フェ": ["f", "e"], 166 | "フォ": ["f", "o"], 167 | # m 168 | "マ": ["m", "a"], 169 | "ミ": ["m", "i"], 170 | "ム": ["m", "u"], 171 | "メ": ["m", "e"], 172 | "モ": ["m", "o"], 173 | # my 174 | "ミャ": ["my", "a"], 175 | "ミュ": ["my", "u"], 176 | "ミェ": ["my", "e"], 177 | 
"ミョ": ["my", "o"], 178 | # r 179 | "ラ": ["r", "a"], 180 | "リ": ["r", "i"], 181 | "ル": ["r", "u"], 182 | "レ": ["r", "e"], 183 | "ロ": ["r", "o"], 184 | # ry 185 | "リャ": ["ry", "a"], 186 | "リュ": ["ry", "u"], 187 | "リェ": ["ry", "e"], 188 | "リョ": ["ry", "o"], 189 | "ン": ["N"], 190 | "ッ": ["cl"], 191 | # for punctuation 192 | ",": [","], 193 | ".": ["."], 194 | "?": ["?"], 195 | "!": ["!"], 196 | "、": [", "], 197 | "。": ["."], 198 | "?": ["?"], 199 | "!": ["!"], 200 | } 201 | 202 | SUPPORTED_MORA = set(PHON_TABLE.keys()) 203 | 204 | 205 | def get_phoneme(mora: str, current_phonemes: list[str]) -> list[str]: 206 | """ 207 | Convert mora(single or double Katakana characters) to Phoneme 208 | """ 209 | 210 | # If the current mora is a long-vowel symbol, add a vowel to the previous phoneme 211 | # e.g., ワー -> w + a + a -> w + aa 212 | if current_phonemes and mora == LONGVOWEL_CHARACTER: 213 | # cl (i.e., ッ) should not be copied from long-vowel (coco #28) 214 | if current_phonemes[-1] != "cl": 215 | current_phonemes[-1] = f"{current_phonemes[-1]}{current_phonemes[-1][-1]}" 216 | phoneme = [] 217 | else: 218 | try: 219 | phoneme = PHON_TABLE[mora] 220 | except KeyError: 221 | raise ValueError(f"Not supported mora : {mora}") 222 | 223 | return phoneme 224 | -------------------------------------------------------------------------------- /marine/utils/post_process.py: -------------------------------------------------------------------------------- 1 | import re 2 | from csv import reader 3 | from pathlib import Path 4 | from typing import Any, Literal 5 | 6 | import numpy as np 7 | 8 | from marine.types import AccentRepresentMode, MarineFeature 9 | from marine.utils.g2p_util import pron2mora 10 | from marine.utils.g2p_util.g2p import ACCENT_REPRESENT_FUNC_TABLE 11 | 12 | 13 | FEATURE_PARSE_SYMBOL = "/" 14 | FEATURE_NODE_SPLIT_SYMBOL = "_" 15 | FEATURE_MORA_SPLIT_SYMBOL = "," 16 | FEATURE_ACCENT_SYMBOL = "@" 17 | 18 | FEATURE_SYOMBOL_REMOVER = str.maketrans( 19 | { 20 | 
FEATURE_NODE_SPLIT_SYMBOL: "", 21 | FEATURE_MORA_SPLIT_SYMBOL: "", 22 | FEATURE_ACCENT_SYMBOL: "", 23 | } 24 | ) 25 | 26 | PADDING_VALUE_FOR_LABEL = 1 27 | 28 | 29 | def _make_align_array(surfaces: list[str]) -> list[int]: 30 | aligns = [] 31 | index = 0 32 | 33 | while index < len(surfaces): 34 | current_surface = surfaces[index] 35 | blank_len = len(current_surface) - 1 36 | 37 | boundary = [index + 1] 38 | blank = [0] * blank_len if blank_len >= 1 else [] 39 | 40 | aligns = aligns + boundary + blank 41 | 42 | index += 1 43 | 44 | return aligns 45 | 46 | 47 | def _is_available_match(aligns: list[int], head: int, tail: int) -> bool: 48 | return (head >= 0 and aligns[head] > 0) and ( 49 | (tail == len(aligns)) or (tail < len(aligns) and aligns[tail] != 0) 50 | ) 51 | 52 | 53 | def _search_mark(padded_aligns: list[int]) -> int: 54 | while len(padded_aligns) > 0: 55 | mark = padded_aligns.pop(-1) 56 | if mark > 0: 57 | return mark 58 | return -1 59 | 60 | 61 | def aligns2mask(aligns: list[int], head: int, tail: int) -> tuple[int, int] | None: 62 | if _is_available_match(aligns, head, tail): 63 | start = aligns[head] - 1 64 | end = _search_mark(aligns[head:tail]) 65 | return (start, end) 66 | else: 67 | return None 68 | 69 | 70 | def convert_feature_to_value( 71 | target: str, 72 | pron: str, 73 | label: int, 74 | ) -> tuple[list[str], list[int] | dict[AccentRepresentMode, list[int]]]: 75 | if target == "accent_status": 76 | moras = pron2mora(pron) 77 | assert isinstance(moras, list) 78 | value = {} 79 | 80 | for accent_represent_mode in ACCENT_REPRESENT_FUNC_TABLE.keys(): 81 | _, represented_accent = pron2mora(moras, label, accent_represent_mode) 82 | assert isinstance(represented_accent, list) 83 | value[accent_represent_mode] = represented_accent 84 | else: 85 | moras = pron2mora(pron) 86 | assert isinstance(moras, list) 87 | value = len(moras) * [0] 88 | 89 | if label > 1: 90 | value[label - 1] = 1 91 | 92 | return moras, value 93 | 94 | 95 | def 
load_postprocess_vocab(vocab_dir: Path, tasks: list[str]) -> dict[str, Any]: 96 | vocab = {key: {} for key in tasks} 97 | 98 | for dict_dir in vocab_dir.iterdir(): 99 | target = dict_dir.name 100 | 101 | if target == "vocab.pkl": 102 | continue 103 | 104 | assert target in tasks 105 | 106 | for dict_path in dict_dir.glob("*.tsv"): 107 | with dict_path.open("r", encoding="utf-8") as dict_file: 108 | table = reader(dict_file, delimiter="\t") 109 | 110 | for pattern, value in table: 111 | regex = re.compile(pattern) 112 | surface, feature = value.split(FEATURE_PARSE_SYMBOL) 113 | surfaces = surface.split(FEATURE_NODE_SPLIT_SYMBOL) 114 | pron = feature.translate(FEATURE_SYOMBOL_REMOVER) 115 | 116 | features = [ 117 | [mora for mora in moras.split(FEATURE_MORA_SPLIT_SYMBOL)] 118 | for moras in feature.split(FEATURE_NODE_SPLIT_SYMBOL) 119 | ] 120 | 121 | assert len(surfaces) == len(features), ( 122 | f"Wrong length entry : ({surfaces} != {features})" 123 | ) 124 | 125 | labels = [ 126 | (0 if node_index == 0 else len(features[node_index - 1])) 127 | + mora_index 128 | + 1 129 | for node_index, moras in enumerate(features) 130 | for mora_index, mora in enumerate(moras) 131 | if mora.endswith(FEATURE_ACCENT_SYMBOL) 132 | ] 133 | 134 | if labels: 135 | # only use first appeared symbol 136 | label = labels[0] 137 | else: 138 | label = -1 139 | 140 | moras, values = convert_feature_to_value(target, pron, label) 141 | vocab[target][pattern] = (regex, moras, values) 142 | 143 | return vocab 144 | 145 | 146 | def apply_postprocess_dict( 147 | task: str, 148 | nodes: list[MarineFeature], 149 | labels: list[int], 150 | moras: list[str], 151 | boundary: list[Literal[0, 1]], 152 | postprocess_targets: re.Pattern[Any], 153 | postprocess_vocab: dict[str, Any], 154 | accent_represent_mode: AccentRepresentMode = "binary", 155 | ) -> list[int]: 156 | surfaces = [node["surface"] for node in nodes] 157 | surface = "".join(surfaces) 158 | 159 | targets = postprocess_targets.findall(surface) 
160 | 161 | if targets: 162 | aligns = _make_align_array(surfaces) 163 | 164 | for target in targets: 165 | regex, pron, values = postprocess_vocab[target] 166 | 167 | for match in regex.finditer(surface): 168 | head, tail = match.span() 169 | 170 | node_mask = aligns2mask(aligns, head, tail) 171 | 172 | if node_mask: 173 | # get mora-based boundary's position 174 | boundary_indexs = ( 175 | [0] + list(np.where(boundary > 0)[0]) + [len(moras)] # type: ignore 176 | ) 177 | node_start, node_end = node_mask 178 | 179 | mora_mask = slice( 180 | boundary_indexs[node_start], 181 | boundary_indexs[node_end], 182 | ) 183 | 184 | if moras[mora_mask] == pron: 185 | if task == "accent_status": 186 | value = values[accent_represent_mode] 187 | else: 188 | value = values 189 | 190 | labels[mora_mask] = value 191 | 192 | return labels 193 | -------------------------------------------------------------------------------- /marine/data/feature/feature_set.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from pathlib import Path 3 | from typing import Any, cast 4 | 5 | import joblib 6 | import numpy as np 7 | from numpy.typing import NDArray 8 | 9 | from marine.data.feature.feature_table import ( 10 | FEATURE_TABLES, 11 | PUNCTUATIONS, 12 | parse_accent_con_type, 13 | ) 14 | from marine.types import BatchFeature, MarineFeature 15 | from marine.utils.g2p_util import pron2mora 16 | 17 | 18 | logger = getLogger(__name__) 19 | 20 | 21 | class FeatureSet: 22 | """ 23 | A converter for embedding features 24 | """ 25 | 26 | def __init__( 27 | self, 28 | vocab_path: str | Path, 29 | feature_table_key: str = "unidic-csj", 30 | feature_keys: list[str] | None = None, 31 | pad_token: str = "[PAD]", 32 | unk_token: str = "[UNK]", 33 | ) -> None: 34 | self.vocab_path = Path(vocab_path) 35 | 36 | self.pad_token = pad_token 37 | self.unk_token = unk_token 38 | self.default_tokens = [self.pad_token, self.unk_token] 39 | 
self.feature_table = FEATURE_TABLES[feature_table_key] 40 | 41 | if feature_keys: 42 | self.feature_keys = feature_keys 43 | else: 44 | self.feature_keys = list(self.feature_table.keys()) 45 | 46 | self.feature_to_id = {key: {} for key in self.feature_keys} 47 | self.id_to_feature = {key: {} for key in self.feature_keys} 48 | 49 | self.init_feature_set() 50 | 51 | def _load_vocab(self) -> list[str]: 52 | if not self.vocab_path.exists(): 53 | logger.error(f"Vocab has not found : {self.vocab_path}") 54 | raise FileNotFoundError(f"Vocab has not found : {self.vocab_path}") 55 | 56 | vocab = joblib.load(self.vocab_path) 57 | logger.info(f"Vocab loaded from {self.vocab_path} : {len(vocab)} words") 58 | 59 | return vocab 60 | 61 | def init_feature_set(self) -> None: 62 | self._load_vocab() 63 | 64 | for key in self.feature_keys: 65 | if key == "surface": 66 | feature_set = self.default_tokens + self._load_vocab() 67 | else: 68 | if key not in self.feature_table.keys(): 69 | raise ValueError( 70 | f"Feature key must be one of {self.feature_table.keys()}" 71 | ) 72 | feature_set = self.default_tokens + (self.feature_table[key] or []) 73 | 74 | feature_to_id = { 75 | feature_value: index for index, feature_value in enumerate(feature_set) 76 | } 77 | id_to_feature = { 78 | index: feature_value for feature_value, index in feature_to_id.items() 79 | } 80 | self.feature_to_id[key] = feature_to_id 81 | self.id_to_feature[key] = id_to_feature 82 | 83 | def convert_feature_to_id( 84 | self, feature_key: str, features: list[str | int] 85 | ) -> NDArray[np.uint8]: 86 | if feature_key not in self.feature_to_id: 87 | raise ValueError( 88 | f"Not initialized feature key: the key must be one of {self.feature_to_id}" 89 | ) 90 | 91 | return np.array( 92 | [ 93 | self.feature_to_id[feature_key].get( 94 | value, self.feature_to_id[feature_key][self.unk_token] 95 | ) 96 | for value in features 97 | ], 98 | dtype=np.uint8, 99 | ) 100 | 101 | def convert_id_to_feature(self, feature_key: str, 
ids: list[int]) -> NDArray[Any]: 102 | if feature_key not in self.id_to_feature: 103 | raise ValueError( 104 | f"Not initialized feature key: the key must be one of {self.id_to_feature}" 105 | ) 106 | 107 | return np.array( 108 | [ 109 | self.id_to_feature[feature_key].get(value, self.unk_token) 110 | for value in ids 111 | ] 112 | ) 113 | 114 | def convert_nodes_to_feature(self, nodes: list[MarineFeature]) -> BatchFeature: 115 | """ 116 | Input: dict型のリスト 117 | example: 118 | [ 119 | { 120 | "surface": "今回", 121 | "pron": "コンカイ", 122 | "pos": "名詞:副詞可能:*:*", 123 | "c_type": "*", 124 | "c_form": "*", 125 | "accent_type": 1, 126 | "accent_con_type": "C1", 127 | "chain_flag": -1 128 | }, 129 | ... 130 | ] 131 | """ 132 | 133 | features = {key: np.array([], dtype=np.uint8) for key in self.feature_to_id} 134 | 135 | # init morph boundary for inference 136 | features["morph_boundary"] = np.array([], dtype=np.uint8) 137 | 138 | for node in nodes: 139 | mora = self.convert_feature_to_id( 140 | "mora", 141 | cast( 142 | list[str | int], 143 | pron2mora(node["pron"]) if node["pron"] else [node["surface"]], 144 | ), 145 | ) 146 | 147 | morph_boundary = np.array([1] + ([0] * (len(mora) - 1)), dtype=np.uint8) 148 | 149 | # Push features 150 | features["mora"] = np.concatenate([features["mora"], mora], axis=0) 151 | features["morph_boundary"] = np.concatenate( 152 | [features["morph_boundary"], morph_boundary], axis=0 153 | ) 154 | 155 | for key, table in self.feature_to_id.items(): 156 | if key in ["mora", "morph_boundary"]: 157 | continue 158 | 159 | if key == "accent_con_type": 160 | value = parse_accent_con_type( 161 | node["accent_con_type"], node["pos"], unk_token=self.unk_token 162 | ) 163 | else: 164 | value = node[key] 165 | feature = table.get(value, table[self.unk_token]) 166 | # Convert to numpy array first, then cast to uint8 to maintain 167 | # the same overflow behavior (avoid deprecation warning) 168 | feature = np.array([feature] * len(mora)).astype(np.uint8) 
169 | features[key] = np.concatenate([features[key], feature], axis=0) 170 | 171 | # First Mora could not be boundary 172 | # (boundary should be [0, 0, 1, 0, 0 ...]) 173 | features["morph_boundary"][0] = 0 174 | 175 | return cast(BatchFeature, features) 176 | 177 | def get_punctuation_ids(self) -> list[int]: 178 | return [self.feature_to_id["mora"][punctuation] for punctuation in PUNCTUATIONS] 179 | -------------------------------------------------------------------------------- /marine/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | import torch 4 | from torchmetrics import F1Score, Metric, MetricCollection 5 | 6 | from marine.types import AccentRepresentMode 7 | from marine.utils.util import convert_ap_based_accent_to_mora_based_accent 8 | 9 | 10 | class SentenceLevelAccuracy(Metric): 11 | """Metrics to calculate sentence level accuray.""" 12 | 13 | full_state_update = True 14 | 15 | def __init__(self) -> None: 16 | super().__init__() 17 | self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") 18 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 19 | 20 | def update( 21 | self, preds: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor 22 | ) -> None: 23 | """Update variables for accuracy.""" 24 | 25 | # preds: (B, T), target: (B, T), mask: (B, T) 26 | # sequence_level_matchs: (B) 27 | sequence_level_matchs = torch.LongTensor( 28 | [ 29 | (target[mask] == pred[mask]).all() 30 | for pred, target, mask in zip(preds, targets, masks) 31 | ] 32 | ) 33 | 34 | self.correct += torch.sum(sequence_level_matchs) 35 | self.total += sequence_level_matchs.numel() # == batch size 36 | 37 | def compute(self) -> torch.Tensor: 38 | """Compute accuracy using variables.""" 39 | return self.correct.float() / self.total 40 | 41 | 42 | class MultiTaskMetrics: 43 | """Metrics to calculate scores.""" 44 | 45 | def __init__( 46 | self, 47 | phase: 
Literal["train", "val", "test"], 48 | task_label_sizes: dict[str, int], 49 | average: Literal["micro", "macro", "weighted", "none"] = "macro", 50 | accent_represent_mode: AccentRepresentMode = "binary", 51 | require_ap_level_f1_score: bool = False, 52 | device: Literal["cpu", "cuda"] = "cpu", 53 | ) -> None: 54 | self.phase = phase 55 | self.tasks = task_label_sizes.keys() 56 | self.accent_represent_mode: AccentRepresentMode = accent_represent_mode 57 | self.require_ap_level_f1_score = require_ap_level_f1_score 58 | self.device = device 59 | 60 | self.metrics = { 61 | task_name: MetricCollection( 62 | { 63 | "mora_level_f1_score": F1Score( 64 | task="multiclass", 65 | num_classes=2, # the AN label represents High/Low or non-AN/AN 66 | average=average, 67 | ).to(device), 68 | "ap_level_f1_score": F1Score( 69 | task="multiclass", 70 | num_classes=task_label_size, 71 | average=average, 72 | ).to(device), 73 | "sentence_level_accuracy": SentenceLevelAccuracy().to(device), 74 | } 75 | if task_name == "accent_status" and require_ap_level_f1_score 76 | else { 77 | "mora_level_f1_score": F1Score( 78 | task="multiclass", 79 | num_classes=task_label_size, 80 | average=average, 81 | ).to(device), 82 | "sentence_level_accuracy": SentenceLevelAccuracy().to(device), 83 | } 84 | ).to(device) 85 | for task_name, task_label_size in task_label_sizes.items() 86 | } 87 | 88 | def update( 89 | self, 90 | task: str, 91 | preds: torch.Tensor, 92 | targets: torch.Tensor, 93 | masks: torch.Tensor, 94 | padding_idx: int = 0, 95 | **kwargs: Any, 96 | ) -> None: 97 | """Update variables for accuracy.""" 98 | assert task in self.tasks, f"Not initialized task: {task} not in {self.tasks}" 99 | 100 | masked_preds, masked_targets = preds[masks], targets[masks] 101 | 102 | # Verify that masked target label sequence not includes [PAD] token 103 | # TODO: This assert could be omited using `ignore_index` of `F1Score` 104 | # However, the option dosen't behave as explained until 2022/08/10 105 | # see 
https://github.com/Lightning-AI/metrics/issues/613 106 | assert padding_idx not in masked_targets 107 | 108 | if task == "accent_status" and self.require_ap_level_f1_score: 109 | mora_pred, mora_target = self._convert_ap_seq_to_mora_seq( 110 | preds, targets, masks, **kwargs 111 | ) 112 | self.metrics[task]["mora_level_f1_score"].update(mora_pred, mora_target) 113 | self.metrics[task]["ap_level_f1_score"].update(masked_preds, masked_targets) 114 | else: 115 | self.metrics[task]["mora_level_f1_score"].update( 116 | masked_preds, masked_targets 117 | ) 118 | 119 | self.metrics[task]["sentence_level_accuracy"].update(preds, targets, masks) 120 | 121 | def compute(self) -> dict[str, dict[str, float]]: 122 | """Compute accuracy using variables.""" 123 | return { 124 | task_name: { 125 | score_name: metrics.compute().cpu().item() 126 | for score_name, metrics in self.metrics[task_name].items() 127 | } 128 | for task_name in self.tasks 129 | } 130 | 131 | def reset(self) -> None: 132 | """Reset all variables in metrics.""" 133 | for task_name in self.tasks: 134 | for metrics in self.metrics[task_name].values(): 135 | metrics.reset() 136 | 137 | def _convert_ap_seq_to_mora_seq( 138 | self, 139 | preds: torch.Tensor, 140 | targets: torch.Tensor, 141 | ap_seq_masks: torch.Tensor, 142 | predicted_accent_phrase_boundaries: torch.Tensor, 143 | target_accent_phrase_boundaries: torch.Tensor, 144 | mora_seq_masks: torch.Tensor, 145 | ) -> tuple[torch.Tensor, torch.Tensor]: 146 | """Convert accent phrase-based accent status sequence to mora-based.""" 147 | mora_preds = convert_ap_based_accent_to_mora_based_accent( 148 | preds, 149 | predicted_accent_phrase_boundaries, 150 | ap_seq_masks, 151 | mora_seq_masks, 152 | self.accent_represent_mode, 153 | ) 154 | mora_targets = convert_ap_based_accent_to_mora_based_accent( 155 | targets, 156 | target_accent_phrase_boundaries, 157 | ap_seq_masks, 158 | mora_seq_masks, 159 | self.accent_represent_mode, 160 | ) 161 | 162 | return ( 163 | 
torch.LongTensor(mora_preds).to(self.device), 164 | torch.LongTensor(mora_targets).to(self.device), 165 | ) 166 | -------------------------------------------------------------------------------- /marine/utils/g2p_util/g2p.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any, Literal 3 | 4 | from numpy.typing import NDArray 5 | 6 | from marine.types import AccentRepresentMode 7 | 8 | from .accent import ( 9 | represent_accent_binary, 10 | represent_accent_high_low, 11 | represent_longvowel_accent_binary, 12 | represent_longvowel_accent_high_low, 13 | set_accent_status, 14 | ) 15 | from .boundary import represent_syllable_boundary 16 | from .util import ( 17 | CONNECTABLE_MORA, 18 | HALF_PUNCTUATION, 19 | LONGVOWEL_CHARACTER, 20 | NON_MORA_LIST, 21 | SUPPORTED_MORA, 22 | get_phoneme, 23 | ) 24 | 25 | 26 | ACCENT_REPRESENT_FUNC_TABLE: dict[ 27 | AccentRepresentMode, dict[str, Callable[[int, int, int], Literal[0, 1]]] 28 | ] = { 29 | "binary": { 30 | "represent_accent": represent_accent_binary, 31 | "represent_longvowel_accent": represent_longvowel_accent_binary, 32 | }, 33 | "high_low": { 34 | "represent_accent": represent_accent_high_low, 35 | "represent_longvowel_accent": represent_longvowel_accent_high_low, 36 | }, 37 | } 38 | 39 | 40 | def pron2mora( 41 | pron: str | list[str] | NDArray[Any], 42 | accent: int | None = None, 43 | represent_mode: AccentRepresentMode = "binary", 44 | ) -> list[str] | tuple[list[str], list[Literal[0, 1]]]: 45 | moras: list[str] = [] 46 | i = 0 47 | 48 | while i < len(pron): 49 | current_pron = pron[i] 50 | 51 | if current_pron in NON_MORA_LIST and len(moras) > 0: 52 | merged_mora = f"{moras[-1]}{current_pron}" 53 | 54 | if merged_mora in SUPPORTED_MORA: 55 | moras[-1] = merged_mora 56 | else: 57 | moras.append(current_pron) 58 | else: 59 | moras.append(current_pron) 60 | 61 | i += 1 62 | 63 | if accent is not None: 64 | if not 
isinstance(accent, int): 65 | raise TypeError(f"Accent is must be int not {type(accent)}") 66 | 67 | if represent_mode not in ACCENT_REPRESENT_FUNC_TABLE.keys(): 68 | raise NotImplementedError(f"Not Implemented mode : {represent_mode}") 69 | 70 | # init rule 71 | accent_rule = ACCENT_REPRESENT_FUNC_TABLE[represent_mode] 72 | represent_accent = accent_rule["represent_accent"] 73 | represent_longvowel_accent = accent_rule["represent_longvowel_accent"] 74 | 75 | # init satus 76 | high, end_low = set_accent_status(accent) 77 | represented_accents: list[Literal[0, 1]] = [] 78 | 79 | for index, mora in enumerate(moras): 80 | # if currnet mora is long-vowel syombol, update last mora 81 | if _is_longvowel(mora): 82 | represented_accent = represent_longvowel_accent(index, high, end_low) 83 | represented_accents.append(represented_accent) 84 | else: 85 | represented_accent = represent_accent(index, high, end_low) 86 | represented_accents.append(represented_accent) 87 | 88 | assert len(moras) == len(represented_accents), ( 89 | f"Wrong repersentation : {moras} != {represented_accents}" 90 | ) 91 | 92 | return moras, represented_accents 93 | 94 | return moras 95 | 96 | 97 | def _is_longvowel(mora: str) -> bool: 98 | return mora == LONGVOWEL_CHARACTER 99 | 100 | 101 | # Consider whether the mora is long-vowel 102 | # and previous mora was not unaccneted mora for escapte excepted case 103 | # e.g., ンー = NN11, ッー = cl11 104 | def _is_prev_mora_not_unaccented_mora(index: int, moras: list[str]) -> bool: 105 | return index > 0 and moras[index - 1] not in CONNECTABLE_MORA 106 | 107 | 108 | def mora2phon( 109 | moras: list[str], 110 | accents: list[Literal[0, 1]] | None = None, 111 | ignore_longvowel_accent: bool = False, 112 | use_punctuation: bool = True, 113 | punctuation_accent_label: int = 7, 114 | ) -> list[str] | tuple[list[str], list[int], list[Literal[0, 1]]]: 115 | phonemes = [] 116 | 117 | if accents is None: 118 | for mora in moras: 119 | if not use_punctuation and mora in 
HALF_PUNCTUATION: 120 | continue 121 | 122 | phoneme = get_phoneme(mora, phonemes) 123 | phonemes += phoneme 124 | 125 | return phonemes 126 | 127 | else: 128 | if len(accents) != len(moras): 129 | raise ValueError( 130 | f"Accent is must be same to length of mora (got : {len(accents) != len(moras)})" 131 | ) 132 | 133 | represented_accents = [] 134 | represented_boundaries = [] 135 | 136 | for index, (mora, accent) in enumerate(zip(moras, accents)): 137 | if not use_punctuation and mora in HALF_PUNCTUATION: 138 | continue 139 | 140 | phoneme = get_phoneme(mora, phonemes) 141 | len_phoneme = len(phoneme) 142 | represented_boundary = represent_syllable_boundary( 143 | index, moras, len_phoneme 144 | ) 145 | phonemes += phoneme 146 | 147 | # if currnet mora is long-vowel syombol, update last mora 148 | if _is_longvowel(mora): 149 | # if currnet mora is long-vowel symbol, only update previous accent 150 | if not ignore_longvowel_accent and _is_prev_mora_not_unaccented_mora( 151 | index, moras 152 | ): 153 | represented_accents[-1] = accent 154 | represented_boundaries[-1] = represented_boundary[-1] 155 | elif mora in HALF_PUNCTUATION: 156 | represented_accent = [punctuation_accent_label] 157 | represented_accents += represented_accent 158 | represented_boundaries += represented_boundary 159 | else: 160 | represented_accent = [0] * (len_phoneme - 1) + [accent] 161 | represented_accents += represented_accent 162 | represented_boundaries += represented_boundary 163 | 164 | return phonemes, represented_accents, represented_boundaries 165 | 166 | 167 | def pron2phon( 168 | pron: str, 169 | accent: int | None = None, 170 | represent_mode: AccentRepresentMode = "binary", 171 | ) -> list[str] | tuple[list[str], list[int], list[Literal[0, 1]]]: 172 | if represent_mode not in ACCENT_REPRESENT_FUNC_TABLE.keys(): 173 | raise NotImplementedError(f"Not Implemented mode : {represent_mode}") 174 | 175 | if accent is None: 176 | result = pron2mora(pron, accent, represent_mode) 177 | 
assert isinstance(result, list) 178 | moras = result 179 | accents = None 180 | else: 181 | result = pron2mora(pron, accent, represent_mode) 182 | assert isinstance(result, tuple) 183 | moras, accents = result 184 | 185 | return mora2phon(moras, accents) 186 | -------------------------------------------------------------------------------- /marine/models/att_lstm_decoder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from marine.modules.attention import BahdanauAttention, ZoneOutCell 7 | from marine.utils.util import get_ap_length 8 | 9 | 10 | class AttentionBasedLSTMDecoder(nn.Module): 11 | def __init__( 12 | self, 13 | input_size: int = 512, 14 | output_size: int = 20, 15 | hidden_size: int = 512, 16 | num_layers: int = 2, 17 | attention_hidden_size: int = 128, 18 | decoder_embedding_size: int = 256, 19 | zoneout: float = 0.1, 20 | prev_task_dropout: float = 0.5, 21 | decoder_prev_out_dropout: float = 0.5, 22 | prev_task_embedding_label_list: Sequence[str] | None = None, 23 | prev_task_embedding_label_size: Mapping[str, int] | None = None, 24 | prev_task_embedding_size: Mapping[str, int] | None = None, 25 | padding_idx: int = 0, 26 | ) -> None: 27 | super().__init__() 28 | # NOTE: output_size must includes size for [PAD] 29 | self.output_size = output_size 30 | 31 | if ( 32 | prev_task_embedding_label_size 33 | and prev_task_embedding_label_list 34 | and prev_task_embedding_size 35 | ): 36 | embeddings = {} 37 | dropouts = {} 38 | for key in prev_task_embedding_label_list: 39 | embeddings[key] = nn.Embedding( 40 | prev_task_embedding_label_size[key], 41 | prev_task_embedding_size[key], 42 | padding_idx=padding_idx, 43 | ) 44 | input_size += prev_task_embedding_size[key] 45 | 46 | if prev_task_dropout: 47 | dropouts[key] = nn.Dropout(prev_task_dropout) 48 | 49 | self.prev_task_embedding = nn.ModuleDict(embeddings) 50 
| 51 | if len(dropouts) > 0: 52 | self.prev_task_dropout = nn.ModuleDict(dropouts) 53 | else: 54 | self.prev_task_dropout = None 55 | 56 | else: 57 | self.prev_task_embedding = None 58 | self.prev_task_dropout = None 59 | 60 | self.attention = BahdanauAttention( 61 | input_size, 62 | hidden_size, 63 | attention_hidden_size, 64 | ) 65 | 66 | # in_dim: output-size + [PAD + SOS] 67 | self.decoder_embedding = nn.Embedding( 68 | self.output_size + 1, decoder_embedding_size, padding_idx=padding_idx 69 | ) 70 | 71 | self.decoder_prev_out_dropout = nn.Dropout(decoder_prev_out_dropout) 72 | 73 | # Setup autogressive LSTM layer 74 | self.lstm = nn.ModuleList() 75 | for layer in range(num_layers): 76 | lstm = nn.LSTMCell( 77 | input_size + decoder_embedding_size if layer == 0 else hidden_size, 78 | hidden_size, 79 | ) 80 | self.lstm += [ZoneOutCell(lstm, zoneout)] 81 | 82 | # Setup feature projection layer 83 | project_size = input_size + hidden_size 84 | # out_dim: output-size + [PAD] 85 | self.projection = nn.Linear(project_size, self.output_size, bias=False) 86 | 87 | def _zero_state(self, hs: Tensor) -> Tensor: 88 | init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size) 89 | return init_hs 90 | 91 | def forward( 92 | self, 93 | encoder_outputs: Tensor, 94 | mask: Tensor, 95 | prev_task_outputs: dict[str, Tensor], 96 | decoder_targets: Tensor | None = None, 97 | ) -> tuple[Tensor, Tensor, list[int]]: 98 | is_inference = decoder_targets is None 99 | ap_lengths = get_ap_length(prev_task_outputs["accent_phrase_boundary"], mask) 100 | 101 | if is_inference: 102 | max_decoder_time_steps = max(ap_lengths) 103 | else: 104 | max_decoder_time_steps = decoder_targets.shape[1] 105 | 106 | if self.prev_task_embedding is not None: 107 | prev_task_output_embs = [] 108 | for key in self.prev_task_embedding.keys(): 109 | prev_task_output = prev_task_outputs[key] 110 | prev_task_output_emb = self.prev_task_embedding[key](prev_task_output) 111 | 112 | if self.prev_task_dropout: 113 | 
prev_task_output_emb = self.prev_task_dropout[key]( 114 | prev_task_output_emb 115 | ) 116 | 117 | prev_task_output_embs.append(prev_task_output_emb) 118 | 119 | encoder_outputs = torch.cat( 120 | [encoder_outputs] + prev_task_output_embs, dim=2 121 | ) 122 | 123 | h_list, c_list = [], [] 124 | for _ in range(len(self.lstm)): 125 | h_list.append(self._zero_state(encoder_outputs)) 126 | c_list.append(self._zero_state(encoder_outputs)) 127 | 128 | go_frame = encoder_outputs.new_zeros(encoder_outputs.size(0), dtype=torch.long) 129 | go_frame[:] = self.output_size # SOS 130 | prev_out = go_frame 131 | 132 | self.attention.reset() 133 | 134 | outs: list[Tensor] = [] 135 | att_ws: list[Tensor] = [] 136 | t = 0 137 | 138 | while True: 139 | att_c, att_w = self.attention(encoder_outputs, h_list[0], mask) 140 | 141 | decoder_emb = self.decoder_embedding(prev_out) 142 | decoder_emb = self.decoder_prev_out_dropout(decoder_emb) 143 | 144 | # LSTM 145 | xs = torch.cat([att_c, decoder_emb], dim=1) 146 | h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0])) 147 | for i in range(1, len(self.lstm)): 148 | h_list[i], c_list[i] = self.lstm[i]( 149 | h_list[i - 1], (h_list[i], c_list[i]) 150 | ) 151 | hcs = torch.cat([h_list[-1], att_c], dim=1) 152 | 153 | outs.append( 154 | # (B, out_dim) -> (B, 1, out_dim) 155 | self.projection(hcs).view( 156 | encoder_outputs.size(0), 157 | -1, 158 | self.output_size, 159 | ) 160 | ) 161 | att_ws.append(att_w) 162 | 163 | if is_inference: 164 | # List[(B, 1, out_dim)] -> (B, out_dim) -> (B, 1) 165 | prev_out = torch.argmax(outs[-1][:, -1, :], dim=1) 166 | else: 167 | # Teacher forcing 168 | # (B, Lmax) -> (B, 1) 169 | prev_out = decoder_targets[:, t] 170 | 171 | t += 1 172 | if t >= max_decoder_time_steps: 173 | break 174 | 175 | outs_tensor = torch.cat(outs, dim=1) # (B, Lmax, out_dim) 176 | att_ws_tensor = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax) 177 | 178 | return outs_tensor, att_ws_tensor, ap_lengths 179 | 
-------------------------------------------------------------------------------- /tests/test_g2p_util.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | from marine.utils.g2p_util.accent import set_accent_status 4 | from marine.utils.g2p_util.g2p import pron2mora, pron2phon 5 | 6 | 7 | logger = getLogger("test") 8 | 9 | 10 | def test_accent_type_prediction(): 11 | for accent, expected in [ 12 | # Type 0 13 | (0, (1, -1)), 14 | # Type 1 15 | (1, (0, 1)), 16 | # Type 2 17 | (2, (1, 2)), 18 | # Type 3 19 | (3, (1, 3)), 20 | ]: 21 | status = set_accent_status(accent) 22 | assert status == expected 23 | 24 | 25 | def test_g2p_as_high_low(): 26 | for index, (pron, accent, expected) in enumerate( 27 | [ 28 | # For Basic Pattern 29 | ("ソレデワ", 3, "s00 o01 r00 e11 d00 e11 w00 a01"), 30 | ("カジノオ", 1, "k00 a11 j00 i01 n00 o01 o01"), 31 | # For ッ 32 | ("オクッテイル", 0, "o01 k00 u10 cl11 t00 e11 i11 r00 u11"), 33 | ("ハジマッタ", 0, "h00 a01 j00 i11 m00 a10 cl11 t00 a11"), 34 | ("ギインリッポー", 4, "g00 i01 i10 N11 r00 i10 cl01 p00 oo01"), 35 | # For long Vowel(ー) 36 | ("シンデイル", 0, "sh00 i00 N11 d00 e11 i11 r00 u11"), 37 | ("ユキオンナオ", 3, "y00 u01 k00 i11 o10 N01 n00 a01 o01"), 38 | ("ツーワリョーガ", 3, "ts00 uu11 w00 a11 ry00 oo01 g00 a01"), 39 | ("アソンデイル", 0, "a01 s00 o10 N11 d00 e11 i11 r00 u11"), 40 | ("トーナンアジアオ", 5, "t00 oo11 n00 a10 N11 a11 j00 i01 a01 o01"), 41 | ("サイセーシマス", 6, "s00 a01 i11 s00 ee11 sh00 i11 m00 a11 s00 u01"), 42 | ( 43 | "トーゴーガタリゾートノ", 44 | 8, 45 | "t00 oo11 g00 oo11 g00 a11 t00 a11 r00 i11 z00 oo11 t00 o01 n00 o01", 46 | ), 47 | # For ン(N) 48 | ("ンダモシタン", 4, "N01 d00 a11 m00 o11 sh00 i11 t00 a00 N01"), 49 | ("チョーセンスル", 0, "ch00 oo11 s00 e10 N11 s00 u11 r00 u11"), 50 | ("ヘンコーオ", 0, "h00 e00 N11 k00 oo11 o11"), 51 | # For long vowel + ン(N) 52 | ("バチーンッテ", 2, "b00 a01 ch00 ii10 N01 cl01 t00 e01"), 53 | ("バーンアウトワ", 4, "b00 aa10 N11 a11 u01 t00 o01 w00 a01"), 54 | # For ンー(NN) 55 | ("ンー", 1, "NN11"), 56 | 
("ンー", 0, "NN01"), 57 | ("ジャンケンー", 1, "j00 a10 N01 k00 e00 NN01"), 58 | ("オニーチャンー", 2, "o01 n00 ii11 ch00 a00 NN01"), 59 | # For ッー(cl) 60 | ("ッー", 1, "cl11"), 61 | ("ッー", 0, "cl01"), 62 | ] 63 | ): 64 | phons, accents, boundaries = pron2phon(pron, accent, represent_mode="high_low") 65 | 66 | phonemes = " ".join( 67 | [f"{p}{a}{b}" for p, a, b in zip(phons, accents, boundaries)] 68 | ) 69 | 70 | logger.info(f"No.{index} {pron} / {accent}") 71 | logger.info(f"answer\t:\t{phonemes}") 72 | logger.info(f"expected\t:\t{expected}") 73 | logger.info("---------") 74 | 75 | assert phonemes == expected 76 | 77 | 78 | def test_g2p_as_binary(): 79 | for index, (pron, accent, expected) in enumerate( 80 | [ 81 | # For Basic Pattern 82 | ("ソレデワ", 3, "s00 o01 r00 e01 d00 e11 w00 a01"), 83 | ("カジノオ", 1, "k00 a11 j00 i01 n00 o01 o01"), 84 | # For ッ 85 | ("オクッテイル", 0, "o01 k00 u00 cl01 t00 e01 i01 r00 u01"), 86 | ("ハジマッタ", 0, "h00 a01 j00 i01 m00 a00 cl01 t00 a01"), 87 | ("ギインリッポー", 4, "g00 i01 i00 N01 r00 i10 cl01 p00 oo01"), 88 | # For long Vowel(ー) 89 | ("シンデイル", 0, "sh00 i00 N01 d00 e01 i01 r00 u01"), 90 | ("ユキオンナオ", 3, "y00 u01 k00 i01 o10 N01 n00 a01 o01"), 91 | ("ツーワリョーガ", 3, "ts00 uu01 w00 a11 ry00 oo01 g00 a01"), 92 | ("アソンデイル", 0, "a01 s00 o00 N01 d00 e01 i01 r00 u01"), 93 | ("トーナンアジアオ", 5, "t00 oo01 n00 a00 N01 a11 j00 i01 a01 o01"), 94 | ("サイセーシマス", 6, "s00 a01 i01 s00 ee01 sh00 i01 m00 a11 s00 u01"), 95 | ( 96 | "トーゴーガタリゾートノ", 97 | 8, 98 | "t00 oo01 g00 oo01 g00 a01 t00 a01 r00 i01 z00 oo11 t00 o01 n00 o01", 99 | ), 100 | # For ン(N) 101 | ("ンダモシタン", 4, "N01 d00 a01 m00 o01 sh00 i11 t00 a00 N01"), 102 | ("チョーセンスル", 0, "ch00 oo01 s00 e00 N01 s00 u01 r00 u01"), 103 | ("ヘンコーオ", 0, "h00 e00 N01 k00 oo01 o01"), 104 | # For long vowel + ン(N) 105 | ("バチーンッテ", 2, "b00 a01 ch00 ii10 N01 cl01 t00 e01"), 106 | ("バーンアウトワ", 4, "b00 aa00 N01 a11 u01 t00 o01 w00 a01"), 107 | # For ンー(NN) 108 | ("ンー", 1, "NN11"), 109 | ("ンー", 0, "NN01"), 110 | ("ジャンケンー", 1, "j00 a10 N01 k00 e00 NN01"), 
111 | ("オニーチャンー", 2, "o01 n00 ii11 ch00 a00 NN01"), 112 | # For ッー(cl) 113 | ("ッー", 1, "cl11"), 114 | ("ッー", 0, "cl01"), 115 | ] 116 | ): 117 | phons, accents, boundaries = pron2phon(pron, accent, represent_mode="binary") 118 | 119 | phonemes = " ".join( 120 | [f"{p}{a}{b}" for p, a, b in zip(phons, accents, boundaries)] 121 | ) 122 | 123 | logger.info(f"No.{index} {pron} / {accent}") 124 | logger.info(f"answer\t:\t{phonemes}") 125 | logger.info(f"expected\t:\t{expected}") 126 | logger.info("---------") 127 | 128 | assert phonemes == expected 129 | 130 | 131 | def test_mora_split(): 132 | for index, (pron, expected) in enumerate( 133 | [ 134 | # For Basic Pattern 135 | ("ソレデワ", ["ソ", "レ", "デ", "ワ"]), 136 | ("カジノオ", ["カ", "ジ", "ノ", "オ"]), 137 | # For long Vowel(ー) 138 | ("シンデイル", ["シ", "ン", "デ", "イ", "ル"]), 139 | ("ユキオンナオ", ["ユ", "キ", "オ", "ン", "ナ", "オ"]), 140 | ("ツーワリョーガ", ["ツ", "ー", "ワ", "リョ", "ー", "ガ"]), 141 | ("ツーヮワリョーガ", ["ツ", "ー", "ヮ", "ワ", "リョ", "ー", "ガ"]), 142 | ("アソンデイル", ["ア", "ソ", "ン", "デ", "イ", "ル"]), 143 | ("トーナンアジアオ", ["ト", "ー", "ナ", "ン", "ア", "ジ", "ア", "オ"]), 144 | ("サイセーシマス", ["サ", "イ", "セ", "ー", "シ", "マ", "ス"]), 145 | ( 146 | "トーゴーガタリゾートノ", 147 | ["ト", "ー", "ゴ", "ー", "ガ", "タ", "リ", "ゾ", "ー", "ト", "ノ"], 148 | ), 149 | # For ン(N) 150 | ("ンダモシタン", ["ン", "ダ", "モ", "シ", "タ", "ン"]), 151 | ("チョーセンスル", ["チョ", "ー", "セ", "ン", "ス", "ル"]), 152 | ("ヘンコーオ", ["ヘ", "ン", "コ", "ー", "オ"]), 153 | ("ヘンコーオォ", ["ヘ", "ン", "コ", "ー", "オ", "ォ"]), 154 | # For long vowel + ン(N) 155 | ("バチーンッテ", ["バ", "チ", "ー", "ン", "ッ", "テ"]), 156 | ("バーンアウトワ", ["バ", "ー", "ン", "ア", "ウ", "ト", "ワ"]), 157 | # For ンー(NN) 158 | ("ンー", ["ン", "ー"]), 159 | ("ジャンケンー", ["ジャ", "ン", "ケ", "ン", "ー"]), 160 | ("オニーチャンー", ["オ", "ニ", "ー", "チャ", "ン", "ー"]), 161 | # For ッー(cl) 162 | ("ッー", ["ッ", "ー"]), 163 | ("ッョ", ["ッ", "ョ"]), 164 | ("オクッテイル", ["オ", "ク", "ッ", "テ", "イ", "ル"]), 165 | ("ハジマッタ", ["ハ", "ジ", "マ", "ッ", "タ"]), 166 | ("ギインリッポー", ["ギ", "イ", "ン", "リ", "ッ", "ポ", "ー"]), 167 | ] 168 | ): 169 | moras = 
pron2mora(pron) 170 | 171 | logger.info(f"No.{index} {pron}") 172 | logger.info(f"answer\t:\t{moras}") 173 | logger.info(f"expected\t:\t{expected}") 174 | logger.info("---------") 175 | 176 | assert moras == expected 177 | -------------------------------------------------------------------------------- /marine/bin/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | import sys 5 | from pathlib import Path 6 | 7 | import torch 8 | from omegaconf import OmegaConf 9 | from tqdm import tqdm 10 | 11 | from marine.data.feature.feature_set import FeatureSet 12 | from marine.data.util import load_dataset 13 | from marine.logger import getLogger 14 | from marine.models import ( 15 | AttentionBasedLSTMDecoder, 16 | CRFDecoder, 17 | LinearDecoder, 18 | init_model, 19 | ) 20 | from marine.utils.metrics import MultiTaskMetrics 21 | from marine.utils.util import ( 22 | convert_readable_labels, 23 | group_by_script_id, 24 | init_seed, 25 | log_scores, 26 | pack_inputs, 27 | pack_outputs, 28 | pad_incomplete_accent_logits, 29 | plot_batch_attention, 30 | ) 31 | 32 | 33 | logger = None 34 | 35 | 36 | def get_parser(): 37 | parser = argparse.ArgumentParser( 38 | description="Test model", 39 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 40 | ) 41 | parser.add_argument( 42 | "checkpoint_dir", type=Path, help="Directory for checkpoint to test" 43 | ) 44 | parser.add_argument( 45 | "--out_dir", "-o", type=Path, default=None, help="Directory test log" 46 | ) 47 | parser.add_argument( 48 | "--checkpoint_filename", 49 | "-f", 50 | type=str, 51 | default="latest.pth", 52 | help="Model's file name to test", 53 | ) 54 | parser.add_argument( 55 | "--data_dir", "-d", type=Path, default=None, help="Directory of dataset to test" 56 | ) 57 | parser.add_argument( 58 | "--vocab_path", "-b", type=Path, default=None, help="Path of vocab file" 59 | ) 60 | parser.add_argument("--n_jobs", type=int, 
default=8, help="Number of jobs") 61 | parser.add_argument( 62 | "--accent_status_represent_mode", 63 | "-m", 64 | type=str, 65 | choices=["binary", "high_low"], 66 | default="binary", 67 | help="Representation mode for accent status label", 68 | ) 69 | parser.add_argument( 70 | "--random_seed", 71 | "-r", 72 | type=int, 73 | default=12345, 74 | help="Random seed for sampling", 75 | ) 76 | parser.add_argument( 77 | "--verbose", 78 | "-v", 79 | type=int, 80 | default=50, 81 | help="Logging level", 82 | ) 83 | return parser 84 | 85 | 86 | def test_model( 87 | model, 88 | checkpoint_dir, 89 | checkpoint_file, 90 | tasks, 91 | dataloader, 92 | config, 93 | feature_set, 94 | tensorboard_writer=None, 95 | logger=None, 96 | device="cpu", 97 | ): 98 | model_path = checkpoint_dir / checkpoint_file 99 | states = torch.load(model_path, weights_only=False) 100 | 101 | phase = "test" 102 | fig_logging_targets = random.choices(range(config.data.batch_size), k=10) 103 | 104 | logger.info(f"Load checkpoint from {model_path} ({states['epoch']}th epoch)") 105 | model.load_state_dict(states["state_dict"]) 106 | 107 | dataloader = dataloader[phase] 108 | 109 | if "accent_status" in tasks: 110 | has_att_based_model = isinstance( 111 | model.decoders["accent_status"], AttentionBasedLSTMDecoder 112 | ) 113 | else: 114 | has_att_based_model = False 115 | 116 | # for logging 117 | metrics = MultiTaskMetrics( 118 | phase, 119 | config.data.output_sizes, 120 | accent_represent_mode=config.data.represent_mode, 121 | require_ap_level_f1_score=has_att_based_model, 122 | device=device, 123 | ) 124 | 125 | total_logs = {task: {} for task in tasks} 126 | 127 | for batch_index, (inputs, outputs, _, script_ids) in enumerate( 128 | tqdm(dataloader, desc=f"{phase}: ", leave=False) 129 | ): 130 | # pack inputs to device 131 | inputs = pack_inputs(inputs, config.data.input_keys, device) 132 | outputs = pack_outputs(outputs, device) 133 | 134 | prev_decoder_output = {} 135 | 136 | for task_index, task in 
enumerate(tasks): 137 | output, output_mask = outputs[task]["label"], outputs[task]["mask"] 138 | decoder_outputs = model(task, **inputs) 139 | 140 | # predict 141 | if isinstance(model.decoders[task], CRFDecoder): 142 | _, crf_logits = decoder_outputs 143 | logits = crf_logits 144 | elif isinstance(model.decoders[task], LinearDecoder): 145 | logits = decoder_outputs 146 | elif isinstance(model.decoders[task], AttentionBasedLSTMDecoder): 147 | logits, attentions, ap_lengths = decoder_outputs 148 | 149 | # plot attention when first batch on test 150 | if tensorboard_writer and batch_index == 0 and task == "accent_status": 151 | plot_batch_attention( 152 | inputs, 153 | logits, 154 | ap_lengths, 155 | attentions, 156 | feature_set, 157 | plot_targets=fig_logging_targets, 158 | tensorboard_writer=tensorboard_writer, 159 | phase=phase, 160 | epoch=0, 161 | script_ids=script_ids, 162 | ) 163 | 164 | # resize logits 165 | logits = pad_incomplete_accent_logits(logits, output_mask) 166 | 167 | # logits: (B, T, dim) -> (B, T) 168 | predicts = torch.argmax(logits, dim=2) 169 | 170 | # Log predicts 171 | total_logs[task].update( 172 | convert_readable_labels(predicts, output, output_mask, script_ids) 173 | ) 174 | 175 | # Update metrics 176 | # for ap-based seq 177 | if task == "accent_status" and has_att_based_model: 178 | metrics.update( 179 | task, 180 | predicts, 181 | output, 182 | output_mask, 183 | predicted_accent_phrase_boundaries=inputs["prev_decoder_outputs"][ 184 | "accent_phrase_boundary" 185 | ], 186 | target_accent_phrase_boundaries=outputs["accent_phrase_boundary"][ 187 | "label" 188 | ], 189 | mora_seq_masks=outputs["accent_phrase_boundary"]["mask"], 190 | ) 191 | 192 | # for mora-based seq 193 | else: 194 | metrics.update(task, predicts, output, output_mask) 195 | 196 | if task_index < len(tasks) - 1: 197 | prev_decoder_output[task] = predicts 198 | inputs["prev_decoder_outputs"][task] = predicts 199 | 200 | # Group by script_id 201 | total_logs = 
group_by_script_id(total_logs) 202 | 203 | # Logging scores 204 | log_scores( 205 | phase=phase, 206 | epoch=0, 207 | tasks=tasks, 208 | metrics=metrics, 209 | logs=total_logs, 210 | loss=None, 211 | tensorboard_writer=tensorboard_writer, 212 | ) 213 | 214 | return total_logs 215 | 216 | 217 | def entry(argv=sys.argv): 218 | global logger 219 | args = get_parser().parse_args(argv[1:]) 220 | logger = getLogger(args.verbose) 221 | logger.debug(f"Loaded parameters: {args}") 222 | 223 | init_seed(args.random_seed) 224 | 225 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 226 | 227 | checkpoint_dir = args.checkpoint_dir 228 | 229 | if not checkpoint_dir.exists(): 230 | raise FileNotFoundError("Checkpoint dir not found") 231 | 232 | checkpoint_config_path = checkpoint_dir / "config.yaml" 233 | 234 | if not checkpoint_config_path.exists(): 235 | raise FileNotFoundError("config file not found") 236 | 237 | checkpoint_config = OmegaConf.load(checkpoint_config_path) 238 | 239 | logger.info("Loaded config") 240 | logger.info(checkpoint_config) 241 | 242 | if args.data_dir: 243 | checkpoint_config.data.data_dir = str(args.data_dir) 244 | 245 | checkpoint_config.data.num_workers = args.n_jobs 246 | checkpoint_config.data.represent_mode = args.accent_status_represent_mode 247 | 248 | dataloader = load_dataset(checkpoint_config, phases=["test"]) 249 | tasks = checkpoint_config.data.output_keys 250 | 251 | if args.out_dir: 252 | log_dir = Path(args.out_dir) 253 | else: 254 | log_dir = Path("logs") / checkpoint_dir.name 255 | 256 | if not log_dir.exists(): 257 | log_dir.mkdir(parents=True) 258 | 259 | log_path = log_dir / f"{checkpoint_dir.name}_test_log.json" 260 | 261 | # init feature set 262 | if args.vocab_path: 263 | if args.vocab_path.exists(): 264 | feature_set = FeatureSet( 265 | args.vocab_path, 266 | feature_table_key=checkpoint_config.data.feature_table_key, 267 | feature_keys=checkpoint_config.data.input_keys, 268 | ) 269 | else: 270 | raise 
FileNotFoundError(f"Not found vocab file: {args.vocab_path}") 271 | else: 272 | raise FileNotFoundError( 273 | "Please specify vocab file path in args or config/model/*.yaml" 274 | ) 275 | 276 | # Init test model 277 | model = init_model(tasks, checkpoint_config, feature_set, device) 278 | logs = test_model( 279 | model, 280 | checkpoint_dir, 281 | args.checkpoint_filename, 282 | tasks, 283 | dataloader, 284 | checkpoint_config, 285 | feature_set, 286 | logger=logger, 287 | device=device, 288 | ) 289 | 290 | # save log 291 | with open(log_path, "w", encoding="utf-8") as file: 292 | json.dump(logs, file, ensure_ascii=False, indent=4, separators=(",", ": ")) 293 | 294 | 295 | if __name__ == "__main__": 296 | sys.exit(entry()) 297 | -------------------------------------------------------------------------------- /marine/bin/jsut2corpus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import yaml 9 | from tqdm import tqdm 10 | 11 | from marine.logger import getLogger 12 | from marine.utils.g2p_util import pron2mora 13 | 14 | 15 | logger = None 16 | 17 | UNUSED_SYMBOL_REMOVER = str.maketrans("", "", "^$[?") 18 | ACCENT_NUCLEUS_SYMBOL = "]" 19 | ACCENT_PHRASE_BOUNDARY_SYMBOL = "#" 20 | INTONATION_PHRASE_BOUNDARY_SYMBOL = "_" 21 | INTONATION_PHRASE_BOUNDARY_PUNCTUATION = "," 22 | 23 | 24 | def get_parser(): 25 | parser = argparse.ArgumentParser( 26 | description="Convert Special format txt format data to json file", 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 28 | ) 29 | parser.add_argument("in_path", type=Path, help="Input path or directory") 30 | parser.add_argument("out_dir", type=Path, help="Output directory") 31 | parser.add_argument( 32 | "--accent_status_seq_level", 33 | "-s", 34 | type=str, 35 | choices=["ap", "mora"], 36 | default="ap", 37 | help="Sequence level for accent status 
label", 38 | ) 39 | parser.add_argument( 40 | "--accent_status_represent_mode", 41 | "-m", 42 | type=str, 43 | choices=["binary", "high_low"], 44 | default="binary", 45 | help="""Representation mode for accent status label 46 | (this option will be ignored when --accent-status-seq-level is chosen as 'ap')""", 47 | ) 48 | parser.add_argument( 49 | "--verbose", 50 | "-v", 51 | type=int, 52 | default=50, 53 | help="Logging level", 54 | ) 55 | return parser 56 | 57 | 58 | def alignment_feature(features, target_feature, ignore_features): 59 | # filtering 60 | for ignore_feature in ignore_features: 61 | features = features[features != ignore_feature] 62 | 63 | # detect 64 | _features = features == target_feature 65 | 66 | # move & pad 67 | mask = ~(np.concatenate([_features[1:], [False]])) 68 | _features = _features[mask] 69 | 70 | if target_feature in [ 71 | ACCENT_PHRASE_BOUNDARY_SYMBOL, 72 | INTONATION_PHRASE_BOUNDARY_SYMBOL, 73 | ]: 74 | _features = np.concatenate([[False], _features[:-1]]) 75 | 76 | return _features 77 | 78 | 79 | def convert_mask_seq_to_int_seq(feature): 80 | return np.where(feature, 1, 0) 81 | 82 | 83 | def merge_ip_ap_boundary(accent_phrase_boundaries, intonation_phrase_boundaries): 84 | assert len(accent_phrase_boundaries) == len(intonation_phrase_boundaries) 85 | 86 | return accent_phrase_boundaries + intonation_phrase_boundaries 87 | 88 | 89 | def binary_accent_to_ap_accent( 90 | binary_accents, 91 | accent_phrase_boundaries, 92 | accent_label=1, 93 | accnet_phrase_label=1, 94 | ): 95 | assert len(binary_accents) == len(accent_phrase_boundaries) 96 | 97 | accent_phrase_boundary_indexes = np.where( 98 | accent_phrase_boundaries == accnet_phrase_label 99 | )[0] 100 | splitted_binary_accents = np.split(binary_accents, accent_phrase_boundary_indexes) 101 | 102 | ap_accent_labbels = [] 103 | 104 | for binary_accent in splitted_binary_accents: 105 | accent_indexs = np.where(binary_accent == accent_label)[0] 106 | 107 | if len(accent_indexs) >= 1: 
108 | acc = accent_indexs[0] + 1 # zero pad 109 | else: 110 | acc = 0 111 | 112 | ap_accent_labbels.append(acc) 113 | 114 | return ap_accent_labbels 115 | 116 | 117 | def binary_accent_to_high_low_accent( 118 | moras, 119 | binary_accents, 120 | accent_phrase_boundaries, 121 | accent_label=1, 122 | accnet_phrase_label=1, 123 | accent_status_represent_mode="high_low", 124 | ): 125 | assert len(moras) == len(binary_accents) == len(accent_phrase_boundaries) 126 | 127 | accent_phrase_boundary_indexes = np.where( 128 | accent_phrase_boundaries == accnet_phrase_label 129 | )[0] 130 | splitted_moras = np.split(moras, accent_phrase_boundary_indexes) 131 | splitted_binary_accents = np.split(binary_accents, accent_phrase_boundary_indexes) 132 | 133 | high_low_accent_labels = [] 134 | 135 | for mora, binary_accent in zip(splitted_moras, splitted_binary_accents): 136 | accent_indexs = np.where(binary_accent == accent_label)[0] 137 | 138 | if len(accent_indexs) >= 1: 139 | _acc = accent_indexs[0] + 1 # zero pad 140 | _, acc = pron2mora(mora, int(_acc), accent_status_represent_mode) 141 | if accent_status_represent_mode == "high_low" and mora[int(_acc)] == "ー": 142 | acc[int(_acc)] = 0 143 | else: 144 | _, acc = pron2mora(mora, 0, accent_status_represent_mode) 145 | 146 | high_low_accent_labels += acc 147 | 148 | return high_low_accent_labels 149 | 150 | 151 | def convert_to_srt_feature(features, splitter=","): 152 | if isinstance(features, np.ndarray): 153 | features = features.tolist() 154 | elif not isinstance(features, list): 155 | raise TypeError("Wrong type of feature") 156 | 157 | return splitter.join([str(value) for value in features]) 158 | 159 | 160 | def parse_jsut_annotation( 161 | annotation, accent_status_seq_level, accent_status_represent_mode 162 | ): 163 | features = {} 164 | 165 | # preprocessing: remove unused symbols 166 | annotation = annotation.translate(UNUSED_SYMBOL_REMOVER) 167 | # preprocessing: replace ヲ -> オ 168 | annotation = annotation.replace("ヲ", 
"オ") 169 | # preprocessing: parse as sequence 170 | mora_based_annotation = np.array(pron2mora(annotation)) 171 | 172 | # filtering symbol 173 | moras = mora_based_annotation[ 174 | (mora_based_annotation != ACCENT_NUCLEUS_SYMBOL) 175 | & (mora_based_annotation != ACCENT_PHRASE_BOUNDARY_SYMBOL) 176 | & (mora_based_annotation != INTONATION_PHRASE_BOUNDARY_SYMBOL) 177 | ] 178 | 179 | assert len(moras) > 0, "Empty annotation" 180 | 181 | binary_accents = alignment_feature( 182 | mora_based_annotation, 183 | ACCENT_NUCLEUS_SYMBOL, 184 | ignore_features=ACCENT_PHRASE_BOUNDARY_SYMBOL 185 | + INTONATION_PHRASE_BOUNDARY_SYMBOL, 186 | ) 187 | accent_phrase_boundaries = alignment_feature( 188 | mora_based_annotation, 189 | ACCENT_PHRASE_BOUNDARY_SYMBOL, 190 | ignore_features=ACCENT_NUCLEUS_SYMBOL + INTONATION_PHRASE_BOUNDARY_SYMBOL, 191 | ) 192 | intonation_phrase_boundaries = alignment_feature( 193 | mora_based_annotation, 194 | INTONATION_PHRASE_BOUNDARY_SYMBOL, 195 | ignore_features=ACCENT_NUCLEUS_SYMBOL + ACCENT_PHRASE_BOUNDARY_SYMBOL, 196 | ) 197 | 198 | assert ( 199 | len(moras) 200 | == len(binary_accents) 201 | == len(accent_phrase_boundaries) 202 | == len(intonation_phrase_boundaries) 203 | ), ( 204 | f"{len(moras)} != {len(binary_accents)} != {len(accent_phrase_boundaries)} != {len(intonation_phrase_boundaries)}" 205 | ) 206 | 207 | accents = convert_mask_seq_to_int_seq(binary_accents) 208 | accent_phrase_boundaries = convert_mask_seq_to_int_seq(accent_phrase_boundaries) 209 | intonation_phrase_boundaries = convert_mask_seq_to_int_seq( 210 | intonation_phrase_boundaries 211 | ) 212 | 213 | accent_phrase_boundaries = merge_ip_ap_boundary( 214 | accent_phrase_boundaries, intonation_phrase_boundaries 215 | ) 216 | 217 | if accent_status_seq_level == "ap": 218 | accents = binary_accent_to_ap_accent(accents, accent_phrase_boundaries) 219 | elif accent_status_seq_level == "mora" and accent_status_represent_mode != "binary": 220 | accents = 
binary_accent_to_high_low_accent( 221 | moras, 222 | accents, 223 | accent_phrase_boundaries, 224 | accent_status_represent_mode=accent_status_represent_mode, 225 | ) 226 | 227 | features["pron"] = convert_to_srt_feature(moras, splitter="") 228 | features["accent_status"] = convert_to_srt_feature(accents) 229 | features["accent_phrase_boundary"] = convert_to_srt_feature( 230 | accent_phrase_boundaries 231 | ) 232 | features["intonation_phrase_boundary"] = convert_to_srt_feature( 233 | intonation_phrase_boundaries 234 | ) 235 | 236 | return features 237 | 238 | 239 | def load_jsut_corpus( 240 | jsut_corpus_dir, accent_status_seq_level, accent_status_represent_mode 241 | ): 242 | text_yaml_path = jsut_corpus_dir / "text_kana" / "basic5000.yaml" 243 | annotation_yaml_path = jsut_corpus_dir / "e2e_symbol" / "katakana.yaml" 244 | 245 | scripts = [] 246 | texts = {} 247 | annotations = {} 248 | 249 | with open(text_yaml_path, encoding="utf-8") as file: 250 | texts = yaml.safe_load(file) 251 | 252 | with open(annotation_yaml_path, encoding="utf-8") as file: 253 | annotations = yaml.safe_load(file) 254 | 255 | assert texts.keys() == annotations.keys(), "Not matched text and annotations" 256 | 257 | for script_id in tqdm(texts.keys(), "Parse anntoations"): 258 | features = {} 259 | 260 | surface = texts[script_id]["text_level0"] 261 | annotation = annotations[script_id] 262 | feature = parse_jsut_annotation( 263 | annotation, accent_status_seq_level, accent_status_represent_mode 264 | ) 265 | 266 | features["script_id"] = script_id 267 | features["surface"] = surface 268 | features.update(feature) 269 | 270 | scripts.append(features) 271 | 272 | logger.info(f"Loaded {len(scripts)} scripts") 273 | 274 | return scripts 275 | 276 | 277 | def entry(argv=sys.argv): 278 | global logger 279 | 280 | args = get_parser().parse_args(argv[1:]) 281 | logger = getLogger(args.verbose) 282 | logger.debug(f"Loaded parameters: {args}") 283 | 284 | scripts = load_jsut_corpus( 285 | 
args.in_path, args.accent_status_seq_level, args.accent_status_represent_mode 286 | ) 287 | 288 | if not args.out_dir.exists(): 289 | args.out_dir.mkdir(parents=True) 290 | 291 | today = datetime.date.today().strftime("%y%m%d") 292 | with open( 293 | args.out_dir / f"just_corpus_{today}.json", "w", encoding="utf-8" 294 | ) as file: 295 | json.dump(scripts, file, ensure_ascii=False, indent=4, separators=(",", ": ")) 296 | 297 | 298 | if __name__ == "__main__": 299 | sys.exit(entry()) 300 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import importlib.resources as importlib_resources 2 | import json 3 | from logging import getLogger 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import pytest 8 | import torch 9 | from numpy.testing import assert_almost_equal 10 | 11 | from marine.utils.metrics import MultiTaskMetrics, SentenceLevelAccuracy 12 | 13 | 14 | logger = getLogger("test") 15 | BASE_DIR = Path(str(importlib_resources.files("marine"))) 16 | 17 | 18 | @pytest.fixture 19 | def sentence_level_accuracy() -> SentenceLevelAccuracy: 20 | return SentenceLevelAccuracy() 21 | 22 | 23 | @pytest.fixture 24 | def high_low_multi_task_metrics() -> MultiTaskMetrics: 25 | phase = "train" 26 | task_label_sizes = {"accent_phrase_boundary": 3} 27 | 28 | return MultiTaskMetrics(phase, task_label_sizes, device="cpu") 29 | 30 | 31 | @pytest.fixture 32 | def ap_multi_task_metrics() -> MultiTaskMetrics: 33 | phase = "train" 34 | task_label_sizes = {"accent_status": 21} 35 | require_ap_level_f1_score = True 36 | 37 | return MultiTaskMetrics( 38 | phase, 39 | task_label_sizes, 40 | require_ap_level_f1_score=require_ap_level_f1_score, 41 | device="cpu", 42 | ) 43 | 44 | 45 | @pytest.fixture 46 | def full_multi_task_metrics() -> MultiTaskMetrics: 47 | phase = "train" 48 | task_label_sizes = { 49 | "intonation_phrase_boundary": 3, 
50 | "accent_phrase_boundary": 3, 51 | "accent_status": 21, 52 | } 53 | require_ap_level_f1_score = True 54 | device = "cpu" 55 | 56 | return MultiTaskMetrics( 57 | phase, 58 | task_label_sizes, 59 | require_ap_level_f1_score=require_ap_level_f1_score, 60 | device=device, 61 | ) 62 | 63 | 64 | @pytest.fixture 65 | def test_log_sample() -> dict[str, Any]: 66 | logs = None 67 | sample_path = BASE_DIR.parent / "tests" / "samples" / "test_log_sample.json" 68 | with open(sample_path, encoding="utf-8") as file: 69 | logs = json.load(file) 70 | return logs 71 | 72 | 73 | def test_sentence_level_accuracy(sentence_level_accuracy): 74 | for pred, target, mask, expect in [ 75 | ( 76 | torch.tensor([[1, 2, 1, 1, 2, 1, 0], [1, 2, 1, 1, 1, 0, 0]]), 77 | torch.tensor([[1, 1, 2, 1, 1, 2, 0], [1, 1, 2, 1, 1, 0, 0]]), 78 | torch.tensor( 79 | [ 80 | [True, True, True, True, True, True, False], 81 | [True, True, True, True, True, False, False], 82 | ], 83 | ), 84 | 0.0, 85 | ), 86 | ( 87 | torch.tensor([[1, 1, 2, 1, 1, 2, 0], [1, 2, 1, 1, 1, 0, 0]]), 88 | torch.tensor([[1, 1, 2, 1, 1, 2, 0], [1, 1, 2, 1, 1, 0, 0]]), 89 | torch.tensor( 90 | [ 91 | [True, True, True, True, True, True, False], 92 | [True, True, True, True, True, False, False], 93 | ], 94 | ), 95 | 0.5, 96 | ), 97 | ( 98 | torch.tensor([[1, 1, 2, 1, 1, 2, -1], [1, 2, 1, 1, 1, 0, 0]]), 99 | torch.tensor([[1, 1, 2, 1, 1, 2, -1], [1, 2, 1, 1, 1, 0, 0]]), 100 | torch.tensor( 101 | [ 102 | [True, True, True, True, True, True, False], 103 | [True, True, True, True, True, False, False], 104 | ], 105 | ), 106 | 1.0, 107 | ), 108 | ]: 109 | sentence_level_accuracy.update(pred, target, mask) 110 | score = sentence_level_accuracy.compute() 111 | assert_almost_equal(score, expect) 112 | sentence_level_accuracy.reset() 113 | 114 | 115 | def test_ap_muti_task_metrics(ap_multi_task_metrics): 116 | for task, ap_pred, ap_target, ap_mask, kwargs, expect in [ 117 | ( 118 | "accent_status", 119 | # NOTE: AP-based accent label requires remove 
padding index (=0) 120 | # e,g, torch.tensor([4]) means 3rd mora has accent nucleus 121 | torch.tensor([[1, 1, 2], [1, 3, 0]]), 122 | torch.tensor([[1, 1, 3], [1, 2, 0]]), 123 | torch.tensor([[True, True, True], [True, True, False]]), 124 | { 125 | # NOTE: accent phrase boundary label requires remove padding index (=0) 126 | # e,g, torch.tensor([1, 2, 1]) means 2nd mora has accent phrase boundary 127 | "predicted_accent_phrase_boundaries": torch.tensor( 128 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 129 | ), 130 | "target_accent_phrase_boundaries": torch.tensor( 131 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 132 | ), 133 | "mora_seq_masks": torch.tensor( 134 | [ 135 | [True, True, True, True, True, True, True, False], 136 | [True, True, True, True, True, False, False, False], 137 | ] 138 | ), 139 | }, 140 | { 141 | "ap_level_f1_score": 0.3333333432674408, 142 | "mora_level_f1_score": 0.40000003576278687, 143 | "sentence_level_accuracy": 0.0, 144 | }, 145 | ), 146 | ( 147 | "accent_status", 148 | torch.tensor([[1, 1, 3], [1, 2, -1]]), 149 | torch.tensor([[1, 1, 3], [1, 3, -1]]), 150 | torch.tensor([[True, True, True], [True, True, False]]), 151 | { 152 | "predicted_accent_phrase_boundaries": torch.tensor( 153 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 154 | ), 155 | "target_accent_phrase_boundaries": torch.tensor( 156 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 157 | ), 158 | "mora_seq_masks": torch.tensor( 159 | [ 160 | [True, True, True, True, True, True, True, False], 161 | [True, True, True, True, True, False, False, False], 162 | ] 163 | ), 164 | }, 165 | { 166 | "ap_level_f1_score": 0.5555555820465088, 167 | "mora_level_f1_score": 0.699999988079071, 168 | "sentence_level_accuracy": 0.5, 169 | }, 170 | ), 171 | ( 172 | "accent_status", 173 | torch.tensor([[1, 1, 3], [1, 3, 0]]), 174 | torch.tensor([[1, 1, 3], [1, 3, 0]]), 175 | torch.tensor([[True, True, True], [True, True, False]]), 176 | { 177 | 
"predicted_accent_phrase_boundaries": torch.tensor( 178 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 179 | ), 180 | "target_accent_phrase_boundaries": torch.tensor( 181 | [[1, 2, 1, 1, 2, 1, 1, 0], [1, 2, 1, 1, 1, 0, 0, 0]] 182 | ), 183 | "mora_seq_masks": torch.tensor( 184 | [ 185 | [True, True, True, True, True, True, True, False], 186 | [True, True, True, True, True, False, False, False], 187 | ] 188 | ), 189 | }, 190 | { 191 | "ap_level_f1_score": 1.0, 192 | "mora_level_f1_score": 1.0, 193 | "sentence_level_accuracy": 1.0, 194 | }, 195 | ), 196 | ]: 197 | ap_multi_task_metrics.update(task, ap_pred, ap_target, ap_mask, **kwargs) 198 | scores = ap_multi_task_metrics.compute() 199 | 200 | for score_name, expect_score in expect.items(): 201 | assert_almost_equal(scores[task][score_name], expect_score) 202 | 203 | ap_multi_task_metrics.reset() 204 | 205 | 206 | def test_ap_muti_task_metrics_by_log(test_log_sample, full_multi_task_metrics): 207 | """Unit test for MultiTaskMetrics by sample""" 208 | 209 | def _extract_label(key, log): 210 | label = [int(v) for v in log[key].split(",")] 211 | # (T) -> (1, T) 212 | label = torch.tensor(label).unsqueeze(0).to("cpu") 213 | return label 214 | 215 | expected_scores = { 216 | "intonation_phrase_boundary": { 217 | "mora_level_f1_score": 0.9472237825393677, 218 | "sentence_level_accuracy": 0.7064254283905029, 219 | }, 220 | "accent_phrase_boundary": { 221 | "mora_level_f1_score": 0.993394672870636, 222 | "sentence_level_accuracy": 0.875554, 223 | }, 224 | "accent_status": { 225 | "ap_level_f1_score": 0.6181702017784119, 226 | "mora_level_f1_score": 0.9536119699478149, 227 | "sentence_level_accuracy": 0.6894387006759644, 228 | }, 229 | } 230 | 231 | for script_status in test_log_sample.values(): 232 | for task_name, task_status in script_status.items(): 233 | preds = _extract_label("predict", task_status) 234 | targets = _extract_label("target", task_status) 235 | 236 | if task_name == "accent_status": 237 | 
predicted_ap_boundary = _extract_label( 238 | "predict", script_status["accent_phrase_boundary"] 239 | ) 240 | target_ap_bondary = _extract_label( 241 | "target", script_status["accent_phrase_boundary"] 242 | ) 243 | # ap_mask: (1, T_ap) 244 | ap_seq_mask = torch.full(targets.shape, True).to("cpu") 245 | # mora_mask: (1, T_mora) 246 | mora_seq_mask = torch.full(target_ap_bondary.shape, True).to("cpu") 247 | 248 | full_multi_task_metrics.update( 249 | task_name, 250 | preds, 251 | targets, 252 | ap_seq_mask, 253 | predicted_accent_phrase_boundaries=predicted_ap_boundary, 254 | target_accent_phrase_boundaries=target_ap_bondary, 255 | mora_seq_masks=mora_seq_mask, 256 | ) 257 | else: 258 | masks = torch.full(targets.shape, True).to("cpu") 259 | full_multi_task_metrics.update(task_name, preds, targets, masks) 260 | 261 | scores = full_multi_task_metrics.compute() 262 | 263 | # verify keys the metric has is correct 264 | assert expected_scores.keys() == scores.keys() 265 | 266 | for task_name, task_status in expected_scores.items(): 267 | for score_name, score in task_status.items(): 268 | assert_almost_equal(score, scores[task_name][score_name]) 269 | -------------------------------------------------------------------------------- /marine/utils/openjtalk_util.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import re 3 | import warnings 4 | from typing import Literal, cast 5 | 6 | import numpy as np 7 | import pykakasi 8 | from numpy.typing import NDArray 9 | 10 | from marine.data.feature.feature_table import RAW_FEATURE_KEYS 11 | from marine.types import MarineFeature, MarineLabel, NJDFeature, OpenJTalkFormatLabel 12 | 13 | 14 | kakasi = pykakasi.kakasi() 15 | BOIN_DICT = {"a": "ア", "i": "イ", "u": "ウ", "e": "エ", "o": "オ", "n": "ン"} 16 | 17 | 18 | OPEN_JTALK_FEATURE_INDEX_TABLE = { 19 | "surface": 0, 20 | "pos": [1, 2, 3, 4], 21 | "c_type": 5, 22 | "c_form": 6, 23 | "pron": 9, 24 | "accent_type": 10, 25 | 
"accent_con_type": 11, 26 | "chain_flag": 12, 27 | } 28 | OPEN_JTALK_FEATURE_RENAME_TABLE = { 29 | "surface": "string", 30 | "pos": ["pos", "pos_group1", "pos_group2", "pos_group3"], 31 | "c_type": "ctype", 32 | "c_form": "cform", 33 | "accent_type": "acc", 34 | "pron": "pron", 35 | "accent_con_type": "chain_rule", 36 | "chain_flag": "chain_flag", 37 | } 38 | 39 | PUNCTUATION_FULL_TO_HALF_TABLE = { 40 | "、": ",", 41 | "。": ".", 42 | "?": "?", 43 | "!": "!", 44 | } 45 | PUNCTUATION_FULL_TO_HALF_TRANS = str.maketrans(PUNCTUATION_FULL_TO_HALF_TABLE) 46 | 47 | 48 | # TODO: pyopenjtalk から呼ばれている convert_njd_feature_to_marine_feature() とロジックが同じ (?) なので統合する 49 | def convert_open_jtalk_node_to_feature( 50 | nodes: list[NJDFeature], 51 | ) -> list[MarineFeature]: 52 | features: list[MarineFeature] = [] 53 | raw_feature_keys = RAW_FEATURE_KEYS["open-jtalk"] 54 | pre_pron = None 55 | 56 | for node in nodes: 57 | # parse feature 58 | node_feature = {} 59 | for feature_key in raw_feature_keys: 60 | jtalk_key = cast(str, OPEN_JTALK_FEATURE_RENAME_TABLE[feature_key]) 61 | 62 | if feature_key == "pos": 63 | value = ":".join([node[_k] for _k in cast(list[str], jtalk_key)]) 64 | elif feature_key == "accent_type": 65 | value = int(node[jtalk_key]) 66 | elif feature_key == "accent_con_type": 67 | value = node[jtalk_key].replace("/", ",") 68 | elif feature_key == "chain_flag": 69 | value = int(node[jtalk_key]) 70 | elif feature_key == "pron": 71 | if node[jtalk_key][0] == "ー": 72 | try: 73 | value = trans_hyphen2katakana(pre_pron + node[jtalk_key])[ 74 | -len(node[jtalk_key]) : 75 | ] 76 | except Exception: 77 | print(node[jtalk_key]) 78 | value = node[jtalk_key] 79 | pre_pron = value 80 | else: 81 | value = node[jtalk_key].replace("’", "").replace("ヲ", "オ") 82 | try: 83 | value = trans_hyphen2katakana(value) 84 | except Exception: 85 | print(value) 86 | 87 | pre_pron = value 88 | else: 89 | value = node[jtalk_key] 90 | 91 | node_feature[feature_key] = value 92 | 93 | if 
node_feature["surface"] == "・": 94 | continue 95 | elif node_feature["surface"] in PUNCTUATION_FULL_TO_HALF_TABLE.keys(): 96 | surface = node_feature["surface"].translate(PUNCTUATION_FULL_TO_HALF_TRANS) 97 | pron = None 98 | node_feature["surface"] = surface 99 | node_feature["pron"] = pron 100 | 101 | features.append(cast(MarineFeature, node_feature)) 102 | 103 | return features 104 | 105 | 106 | def convert_njd_feature_to_marine_feature( 107 | njd_features: list[NJDFeature], 108 | ) -> list[MarineFeature]: 109 | marine_features: list[MarineFeature] = [] 110 | 111 | raw_feature_keys = RAW_FEATURE_KEYS["open-jtalk"] 112 | for njd_feature in njd_features: 113 | marine_feature = {} 114 | for feature_key in raw_feature_keys: 115 | if feature_key == "pos": 116 | value = ":".join( 117 | [ 118 | njd_feature["pos"], 119 | njd_feature["pos_group1"], 120 | njd_feature["pos_group2"], 121 | njd_feature["pos_group3"], 122 | ] 123 | ) 124 | elif feature_key == "accent_con_type": 125 | value = njd_feature["chain_rule"].replace("/", ",") 126 | elif feature_key == "pron": 127 | value = njd_feature["pron"].replace("’", "").replace("ヲ", "オ") 128 | else: 129 | value = njd_feature[ 130 | cast(str, OPEN_JTALK_FEATURE_RENAME_TABLE[feature_key]) 131 | ] 132 | marine_feature[feature_key] = value 133 | 134 | if marine_feature["surface"] == "・": 135 | continue 136 | elif marine_feature["surface"] in PUNCTUATION_FULL_TO_HALF_TABLE.keys(): 137 | surface = marine_feature["surface"].translate( 138 | PUNCTUATION_FULL_TO_HALF_TRANS 139 | ) 140 | pron = None 141 | marine_feature["surface"] = surface 142 | marine_feature["pron"] = pron 143 | 144 | marine_features.append(cast(MarineFeature, marine_feature)) 145 | 146 | return marine_features 147 | 148 | 149 | def convert_open_jtalk_format_label( 150 | labels: MarineLabel, 151 | morph_boundaries: list[NDArray[np.uint8]], 152 | accent_nucleus_label: int = 1, 153 | accent_phrase_boundary_label: int = 1, 154 | morph_boundary_label: int = 1, 155 | ) -> 
OpenJTalkFormatLabel: 156 | assert "accent_status" in labels.keys(), "`accent_status` is missing in labels" 157 | assert "accent_phrase_boundary" in labels.keys(), ( 158 | "`accent_phrase_boundary` is missing in labels" 159 | ) 160 | 161 | # squeeze results 162 | mora_accent_status = labels["accent_status"][0] 163 | mora_accent_phrase_boundary = labels["accent_phrase_boundary"][0] 164 | morph_boundary = morph_boundaries[0] 165 | 166 | assert len(mora_accent_status) == len(mora_accent_phrase_boundary), ( 167 | "Not match sequence lenght between" 168 | "`accent_status`, `morph_boundary`, and `accent_phrase_boundary`" 169 | ) 170 | 171 | mora_accent_phrase_boundary = np.array(mora_accent_phrase_boundary) 172 | 173 | # convert mora-based accent phrase boundary label to morph-based label 174 | morph_boundary_indexes = np.where(morph_boundary == morph_boundary_label)[0] 175 | morph_accent_phrase_boundary = np.split( 176 | mora_accent_phrase_boundary, morph_boundary_indexes 177 | ) 178 | # `chain_flag` in OpenJTalk represents the status whether the morph will be connected 179 | morph_accent_phrase_boundary = cast( 180 | list[Literal[-1, 0, 1]], 181 | [ 182 | 0 if boundary[0] == accent_phrase_boundary_label else 1 183 | for boundary in morph_accent_phrase_boundary 184 | ], 185 | ) 186 | # first `chain_flag` must be -1 187 | morph_accent_phrase_boundary[0] = -1 188 | num_boundary = morph_accent_phrase_boundary.count(0) + 1 189 | 190 | # convert mora-based accent status label to ap-based label 191 | # アクセント句境界かつ形態素句境界のindexを取得に修正 192 | mora_accent_phrase_boundary_indexes = np.where( 193 | mora_accent_phrase_boundary + morph_boundary 194 | == accent_phrase_boundary_label + morph_boundary_label 195 | )[0] 196 | phrase_accent_statuses = np.split( 197 | mora_accent_status, mora_accent_phrase_boundary_indexes 198 | ) 199 | phrase_accent_status_labels = [] 200 | 201 | for phrase_accent_status in phrase_accent_statuses: 202 | accent_nucleus_indexes = np.where(phrase_accent_status 
== accent_nucleus_label)[ 203 | 0 204 | ] 205 | if len(accent_nucleus_indexes) == 0: 206 | accent_nucleus_index = 0 207 | else: 208 | accent_nucleus_index = accent_nucleus_indexes[0] + 1 209 | phrase_accent_status_labels.append(accent_nucleus_index) 210 | 211 | if len(phrase_accent_status_labels) > num_boundary: 212 | warnings.warn( 213 | ( 214 | "Lenght of AP-based accent status will be adjusted " 215 | "by morph-based accent phrase boundary: " 216 | f"{len(phrase_accent_status_labels)} > {num_boundary}" 217 | ), 218 | stacklevel=2, 219 | ) 220 | phrase_accent_status_labels = phrase_accent_status_labels[:num_boundary] 221 | 222 | # convert mora-based accent status to morph-based label 223 | # the accent label for OpenJTalk pushed in first morph 224 | morph_accent_status: list[int] = [ 225 | phrase_accent_status_labels.pop(0) if morph_accent_phrase_flag < 1 else 0 226 | for morph_accent_phrase_flag in morph_accent_phrase_boundary 227 | ] 228 | 229 | return { 230 | "accent_status": morph_accent_status, 231 | "accent_phrase_boundary": morph_accent_phrase_boundary, 232 | } 233 | 234 | 235 | def trans_hyphen2katakana(text: str) -> str: 236 | """ 237 | 伸ばし棒をカタカナに変換 238 | 例:きょー→きょお 239 | """ 240 | hyphen_string_list = re.findall("..ー", text) 241 | text = replace_hyphen(text, hyphen_string_list) 242 | 243 | hyphen_string_list = re.findall(".ー", text) 244 | text = replace_hyphen(text, hyphen_string_list) 245 | 246 | return text 247 | 248 | 249 | def replace_hyphen(text: str, hyphen_string_list: list[str]) -> str: 250 | for _str in hyphen_string_list: 251 | if "[" in _str or "]" in _str: 252 | result = kakasi.convert(_str.replace("[", "").replace("]", ""))[0] 253 | else: 254 | _str_wo_hyphen = _str.replace("ー", "") 255 | result = kakasi.convert(_str_wo_hyphen)[-1] 256 | 257 | transed_hyphen_string = _str[:-1] + BOIN_DICT[result["hepburn"][-1]] 258 | text = text.replace(_str, transed_hyphen_string) 259 | 260 | return text 261 | 262 | 263 | def print_diff_hl(ground_truth: 
str, target: str) -> None: 264 | """ 265 | 文字列の差異をハイライト表示する 266 | """ 267 | color_dic = {"red": "\033[31m", "green": "\033[32m", "end": "\033[0m"} 268 | 269 | d = difflib.Differ() 270 | diffs = d.compare(ground_truth, target) 271 | 272 | result = "" 273 | for diff in diffs: 274 | status, _, character = list(diff) 275 | if status == "-": 276 | character = color_dic["red"] + character + color_dic["end"] 277 | elif status == "+": 278 | character = color_dic["green"] + character + color_dic["end"] 279 | else: 280 | pass 281 | result += character 282 | 283 | print(f" OpenJTalk : {ground_truth}") 284 | print(f" Annotation : {target}") 285 | print(f"Diff Result : {result}") 286 | -------------------------------------------------------------------------------- /marine/bin/make_raw_corpus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import yaml 9 | from tqdm import tqdm 10 | 11 | from marine.logger import getLogger 12 | from marine.utils.g2p_util import pron2mora 13 | from marine.utils.openjtalk_util import trans_hyphen2katakana 14 | 15 | 16 | logger = None 17 | 18 | 19 | UNUSED_SYMBOL_REMOVER = str.maketrans("", "", "^$[!?") 20 | ACCENT_NUCLEUS_SYMBOL = "]" 21 | ACCENT_PHRASE_BOUNDARY_SYMBOL = "#" 22 | INTONATION_PHRASE_BOUNDARY_SYMBOL = "_" 23 | INTONATION_PHRASE_BOUNDARY_PUNCTUATION = "," 24 | 25 | 26 | def get_parser(): 27 | parser = argparse.ArgumentParser( 28 | description="Convert Special format txt format data to json file", 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 30 | ) 31 | parser.add_argument( 32 | "in_path", default="./data", type=Path, help="Input path or directory" 33 | ) 34 | parser.add_argument("out_dir", default="./raw", type=Path, help="Output directory") 35 | 36 | parser.add_argument( 37 | "--text_f_name", 38 | type=str, 39 | default="text.yaml", 40 | help="text yaml file 
name.", 41 | ) 42 | parser.add_argument( 43 | "--annot_f_name", 44 | type=str, 45 | default="annotation.yaml", 46 | help="annotation yaml file name.", 47 | ) 48 | parser.add_argument( 49 | "--accent_status_seq_level", 50 | "-s", 51 | type=str, 52 | choices=["ap", "mora"], 53 | default="mora", 54 | help="Sequence level for accent status label", 55 | ) 56 | parser.add_argument( 57 | "--accent_status_represent_mode", 58 | "-m", 59 | type=str, 60 | choices=["binary", "high_low"], 61 | default="binary", 62 | help="""Representation mode for accent status label 63 | (this option will be ignored when --accent-status-seq-level is chosen as 'ap')""", 64 | ) 65 | parser.add_argument( 66 | "--verbose", 67 | "-v", 68 | type=int, 69 | default=50, 70 | help="Logging level", 71 | ) 72 | return parser 73 | 74 | 75 | def alignment_feature(features, target_feature, ignore_features): 76 | # filtering 77 | for ignore_feature in ignore_features: 78 | features = features[features != ignore_feature] 79 | 80 | # detect 81 | _features = features == target_feature 82 | 83 | # move & pad 84 | mask = ~(np.concatenate([_features[1:], [False]])) 85 | _features = _features[mask] 86 | 87 | if target_feature in [ 88 | ACCENT_PHRASE_BOUNDARY_SYMBOL, 89 | INTONATION_PHRASE_BOUNDARY_SYMBOL, 90 | ]: 91 | _features = np.concatenate([[False], _features[:-1]]) 92 | 93 | return _features 94 | 95 | 96 | def convert_mask_seq_to_int_seq(feature): 97 | return np.where(feature, 1, 0) 98 | 99 | 100 | def merge_ip_ap_boundary(accent_phrase_boundaries, intonation_phrase_boundaries): 101 | assert len(accent_phrase_boundaries) == len(intonation_phrase_boundaries) 102 | 103 | return accent_phrase_boundaries + intonation_phrase_boundaries 104 | 105 | 106 | def binary_accent_to_ap_accent( 107 | binary_accents, 108 | accent_phrase_boundaries, 109 | accent_label=1, 110 | accnet_phrase_label=1, 111 | ): 112 | assert len(binary_accents) == len(accent_phrase_boundaries) 113 | 114 | accent_phrase_boundary_indexes = np.where( 
115 | accent_phrase_boundaries == accnet_phrase_label 116 | )[0] 117 | splitted_binary_accents = np.split(binary_accents, accent_phrase_boundary_indexes) 118 | 119 | ap_accent_labbels = [] 120 | 121 | for binary_accent in splitted_binary_accents: 122 | accent_indexs = np.where(binary_accent == accent_label)[0] 123 | 124 | if len(accent_indexs) >= 1: 125 | acc = accent_indexs[0] + 1 # zero pad 126 | else: 127 | acc = 0 128 | 129 | ap_accent_labbels.append(acc) 130 | 131 | return ap_accent_labbels 132 | 133 | 134 | def binary_accent_to_high_low_accent( 135 | moras, 136 | binary_accents, 137 | accent_phrase_boundaries, 138 | accent_label=1, 139 | accnet_phrase_label=1, 140 | accent_status_represent_mode="high_low", 141 | ): 142 | assert len(moras) == len(binary_accents) == len(accent_phrase_boundaries) 143 | 144 | accent_phrase_boundary_indexes = np.where( 145 | accent_phrase_boundaries == accnet_phrase_label 146 | )[0] 147 | splitted_moras = np.split(moras, accent_phrase_boundary_indexes) 148 | splitted_binary_accents = np.split(binary_accents, accent_phrase_boundary_indexes) 149 | 150 | high_low_accent_labels = [] 151 | 152 | for mora, binary_accent in zip(splitted_moras, splitted_binary_accents): 153 | accent_indexs = np.where(binary_accent == accent_label)[0] 154 | 155 | if len(accent_indexs) >= 1: 156 | _acc = accent_indexs[0] + 1 # zero pad 157 | _, acc = pron2mora(mora, int(_acc), accent_status_represent_mode) 158 | if accent_status_represent_mode == "high_low" and mora[int(_acc)] == "ー": 159 | acc[int(_acc)] = 0 160 | else: 161 | _, acc = pron2mora(mora, 0, accent_status_represent_mode) 162 | 163 | high_low_accent_labels += acc 164 | 165 | return high_low_accent_labels 166 | 167 | 168 | def convert_to_srt_feature(features, splitter=","): 169 | if isinstance(features, np.ndarray): 170 | features = features.tolist() 171 | elif not isinstance(features, list): 172 | raise TypeError("Wrong type of feature") 173 | 174 | return splitter.join([str(value) for value in 
features]) 175 | 176 | 177 | def parse_jsut_annotation( 178 | annotation, accent_status_seq_level, accent_status_represent_mode 179 | ): 180 | features = {} 181 | 182 | # preprocessing: remove unused symbols 183 | annotation = annotation.translate(UNUSED_SYMBOL_REMOVER) 184 | # preprocessing: replace ヲ -> オ 185 | annotation = annotation.replace("ヲ", "オ") 186 | 187 | annotation = trans_hyphen2katakana(annotation) 188 | # preprocessing: parse as sequence 189 | mora_based_annotation = np.array(pron2mora(annotation)) 190 | 191 | # filtering symbol 192 | moras = mora_based_annotation[ 193 | (mora_based_annotation != ACCENT_NUCLEUS_SYMBOL) 194 | & (mora_based_annotation != ACCENT_PHRASE_BOUNDARY_SYMBOL) 195 | & (mora_based_annotation != INTONATION_PHRASE_BOUNDARY_SYMBOL) 196 | ] 197 | 198 | assert len(moras) > 0, "Empty annotation" 199 | 200 | binary_accents = alignment_feature( 201 | mora_based_annotation, 202 | ACCENT_NUCLEUS_SYMBOL, 203 | ignore_features=ACCENT_PHRASE_BOUNDARY_SYMBOL 204 | + INTONATION_PHRASE_BOUNDARY_SYMBOL, 205 | ) 206 | accent_phrase_boundaries = alignment_feature( 207 | mora_based_annotation, 208 | ACCENT_PHRASE_BOUNDARY_SYMBOL, 209 | ignore_features=ACCENT_NUCLEUS_SYMBOL + INTONATION_PHRASE_BOUNDARY_SYMBOL, 210 | ) 211 | intonation_phrase_boundaries = alignment_feature( 212 | mora_based_annotation, 213 | INTONATION_PHRASE_BOUNDARY_SYMBOL, 214 | ignore_features=ACCENT_NUCLEUS_SYMBOL + ACCENT_PHRASE_BOUNDARY_SYMBOL, 215 | ) 216 | 217 | assert ( 218 | len(moras) 219 | == len(binary_accents) 220 | == len(accent_phrase_boundaries) 221 | == len(intonation_phrase_boundaries) 222 | ), ( 223 | f"{len(moras)} != {len(binary_accents)} != {len(accent_phrase_boundaries)} != {len(intonation_phrase_boundaries)},{annotation}" 224 | ) 225 | 226 | accents = convert_mask_seq_to_int_seq(binary_accents) 227 | accent_phrase_boundaries = convert_mask_seq_to_int_seq(accent_phrase_boundaries) 228 | intonation_phrase_boundaries = convert_mask_seq_to_int_seq( 229 | 
intonation_phrase_boundaries 230 | ) 231 | 232 | accent_phrase_boundaries = merge_ip_ap_boundary( 233 | accent_phrase_boundaries, intonation_phrase_boundaries 234 | ) 235 | 236 | if accent_status_seq_level == "ap": 237 | accents = binary_accent_to_ap_accent(accents, accent_phrase_boundaries) 238 | elif accent_status_seq_level == "mora" and accent_status_represent_mode != "binary": 239 | accents = binary_accent_to_high_low_accent( 240 | moras, 241 | accents, 242 | accent_phrase_boundaries, 243 | accent_status_represent_mode=accent_status_represent_mode, 244 | ) 245 | 246 | features["pron"] = convert_to_srt_feature(moras, splitter="") 247 | features["accent_status"] = convert_to_srt_feature(accents) 248 | features["accent_phrase_boundary"] = convert_to_srt_feature( 249 | accent_phrase_boundaries 250 | ) 251 | features["intonation_phrase_boundary"] = convert_to_srt_feature( 252 | intonation_phrase_boundaries 253 | ) 254 | 255 | return features 256 | 257 | 258 | def load_yaml_corpus( 259 | yaml_corpus_dir, 260 | accent_status_seq_level, 261 | accent_status_represent_mode, 262 | text_file_name, 263 | annotation_file_name, 264 | ): 265 | text_yaml_path = os.path.join(yaml_corpus_dir, text_file_name) 266 | annotation_yaml_path = os.path.join(yaml_corpus_dir, annotation_file_name) 267 | 268 | scripts = [] 269 | texts = {} 270 | annotations = {} 271 | 272 | with open(text_yaml_path, encoding="utf-8") as file: 273 | texts = yaml.safe_load(file) 274 | 275 | with open(annotation_yaml_path, encoding="utf-8") as file: 276 | annotations = yaml.safe_load(file) 277 | 278 | assert texts.keys() == annotations.keys(), "Not matched text and annotations" 279 | 280 | for script_id in tqdm(texts.keys(), "Parse anntoations"): 281 | features = {} 282 | 283 | surface = texts[script_id]["text_level0"] 284 | annotation = annotations[script_id] 285 | feature = parse_jsut_annotation( 286 | annotation, accent_status_seq_level, accent_status_represent_mode 287 | ) 288 | 289 | features["script_id"] 
= script_id 290 | features["surface"] = surface 291 | features.update(feature) 292 | 293 | scripts.append(features) 294 | 295 | logger.info(f"Loaded {len(scripts)} scripts") 296 | 297 | return scripts 298 | 299 | 300 | def entry(argv=sys.argv): 301 | global logger 302 | 303 | args = get_parser().parse_args(argv[1:]) 304 | logger = getLogger(args.verbose) 305 | logger.debug(f"Loaded parameters: {args}") 306 | 307 | scripts = load_yaml_corpus( 308 | args.in_path, 309 | args.accent_status_seq_level, 310 | args.accent_status_represent_mode, 311 | args.text_f_name, 312 | args.annot_f_name, 313 | ) 314 | 315 | if not args.out_dir.exists(): 316 | args.out_dir.mkdir(parents=True) 317 | 318 | with open(args.out_dir / "raw_corpus.json", "w", encoding="utf-8") as file: 319 | json.dump(scripts, file, ensure_ascii=False, indent=4, separators=(",", ": ")) 320 | 321 | 322 | if __name__ == "__main__": 323 | sys.exit(entry()) 324 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Allen Institute for AI 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | import random 5 | from logging import getLogger 6 | from typing import Any 7 | 8 | import torch 9 | from numpy.testing import assert_almost_equal, assert_equal 10 | 11 | from marine.modules.crf_tagger import logsumexp, viterbi_decode 12 | 13 | 14 | logger = getLogger("test") 15 | 16 | 17 | def test_logsumexp(): 18 | # First a simple example where we add probabilities in log space. 
19 | tensor = torch.FloatTensor([[0.4, 0.1, 0.2]]) 20 | log_tensor = tensor.log() 21 | log_summed = logsumexp(log_tensor, dim=-1, keepdim=False) 22 | assert_almost_equal(log_summed.exp().data.numpy(), [0.7]) 23 | log_summed = logsumexp(log_tensor, dim=-1, keepdim=True) 24 | assert_almost_equal(log_summed.exp().data.numpy(), [[0.7]]) 25 | 26 | # Then some more atypical examples, and making sure this will work with how we handle 27 | # log masks. 28 | tensor = torch.FloatTensor([[float("-inf"), 20.0]]) 29 | assert_almost_equal(logsumexp(tensor).data.numpy(), [20.0]) 30 | tensor = torch.FloatTensor([[-200.0, 20.0]]) 31 | assert_almost_equal(logsumexp(tensor).data.numpy(), [20.0]) 32 | tensor = torch.FloatTensor([[20.0, 20.0], [-200.0, 200.0]]) 33 | assert_almost_equal(logsumexp(tensor, dim=0).data.numpy(), [20.0, 200.0]) 34 | 35 | 36 | def test_viterbi_decode(): 37 | # Test Viterbi decoding is equal to greedy decoding with no pairwise potentials. 38 | sequence_logits = torch.nn.functional.softmax(torch.rand([5, 9]), dim=-1) 39 | transition_matrix = torch.zeros([9, 9]) 40 | indices, _ = viterbi_decode(sequence_logits.data, transition_matrix) 41 | _, argmax_indices = torch.max(sequence_logits, 1) 42 | assert indices == argmax_indices.data.squeeze().tolist() 43 | 44 | # Test Viterbi decoding works with start and end transitions 45 | sequence_logits = torch.nn.functional.softmax(torch.rand([5, 9]), dim=-1) 46 | transition_matrix = torch.zeros([9, 9]) 47 | allowed_start_transitions = torch.zeros([9]) 48 | # Force start tag to be an 8 49 | allowed_start_transitions[:8] = float("-inf") 50 | allowed_end_transitions = torch.zeros([9]) 51 | # Force end tag to be a 0 52 | allowed_end_transitions[1:] = float("-inf") 53 | indices, _ = viterbi_decode( 54 | sequence_logits.data, 55 | transition_matrix, 56 | allowed_end_transitions=allowed_end_transitions, 57 | allowed_start_transitions=allowed_start_transitions, 58 | ) 59 | assert indices[0] == 8 60 | assert indices[-1] == 0 61 | 62 
| # Test that pairwise potentials affect the sequence correctly and that 63 | # viterbi_decode can handle -inf values. 64 | sequence_logits = torch.FloatTensor( 65 | [ 66 | [0, 0, 0, 3, 5], 67 | [0, 0, 0, 3, 4], 68 | [0, 0, 0, 3, 4], 69 | [0, 0, 0, 3, 4], 70 | [0, 0, 0, 3, 4], 71 | [0, 0, 0, 3, 4], 72 | ] 73 | ) 74 | # The same tags shouldn't appear sequentially. 75 | transition_matrix = torch.zeros([5, 5]) 76 | for i in range(5): 77 | transition_matrix[i, i] = float("-inf") 78 | indices, _ = viterbi_decode(sequence_logits, transition_matrix) 79 | assert indices == [4, 3, 4, 3, 4, 3] 80 | 81 | # Test that unbalanced pairwise potentials break ties 82 | # between paths with equal unary potentials. 83 | sequence_logits = torch.FloatTensor( 84 | [ 85 | [0, 0, 0, 4, 4], 86 | [0, 0, 0, 4, 4], 87 | [0, 0, 0, 4, 4], 88 | [0, 0, 0, 4, 4], 89 | [0, 0, 0, 4, 4], 90 | [0, 0, 0, 4, 4], 91 | ] 92 | ) 93 | # The 5th tag has a penalty for appearing sequentially 94 | # or for transitioning to the 4th tag, making the best 95 | # path uniquely to take the 4th tag only. 96 | transition_matrix = torch.zeros([5, 5]) 97 | transition_matrix[4, 4] = -10 98 | transition_matrix[4, 3] = -10 99 | transition_matrix[3, 4] = -10 100 | indices, _ = viterbi_decode(sequence_logits, transition_matrix) 101 | assert indices == [3, 3, 3, 3, 3, 3] 102 | 103 | sequence_logits = torch.FloatTensor([[1, 0, 0, 4], [1, 0, 6, 2], [0, 3, 0, 4]]) 104 | # Best path would normally be [3, 2, 3] but we add a 105 | # potential from 2 -> 1, making [3, 2, 1] the best path. 106 | transition_matrix = torch.zeros([4, 4]) 107 | transition_matrix[0, 0] = 1 108 | transition_matrix[2, 1] = 5 109 | indices, value = viterbi_decode(sequence_logits, transition_matrix) 110 | assert indices == [3, 2, 1] 111 | assert value.numpy() == 18 112 | 113 | # Test that providing evidence results in paths containing specified tags. 
114 | sequence_logits = torch.FloatTensor( 115 | [ 116 | [0, 0, 0, 7, 7], 117 | [0, 0, 0, 7, 7], 118 | [0, 0, 0, 7, 7], 119 | [0, 0, 0, 7, 7], 120 | [0, 0, 0, 7, 7], 121 | [0, 0, 0, 7, 7], 122 | ] 123 | ) 124 | # The 5th tag has a penalty for appearing sequentially 125 | # or for transitioning to the 4th tag, making the best 126 | # path to take the 4th tag for every label. 127 | transition_matrix = torch.zeros([5, 5]) 128 | transition_matrix[4, 4] = -10 129 | transition_matrix[4, 3] = -2 130 | transition_matrix[3, 4] = -2 131 | # The 1st, 4th and 5th sequence elements are observed - they should be 132 | # equal to 2, 0 and 4. The last tag should be equal to 3, because although 133 | # the penalty for transitioning to the 4th tag is -2, the unary potential 134 | # is 7, which is greater than the combination for any of the other labels. 135 | observations = [2, -1, -1, 0, 4, -1] 136 | indices, _ = viterbi_decode(sequence_logits, transition_matrix, observations) 137 | assert indices == [2, 3, 3, 0, 4, 3] 138 | 139 | 140 | def test_viterbi_decode_top_k(): 141 | # Test cases taken from: https://gist.github.com/PetrochukM/afaa3613a99a8e7213d2efdd02ae4762 142 | 143 | # Test Viterbi decoding is equal to greedy decoding with no pairwise potentials. 144 | sequence_logits = torch.autograd.Variable(torch.rand([5, 9])) 145 | transition_matrix = torch.zeros([9, 9]) 146 | 147 | indices, _ = viterbi_decode(sequence_logits.data, transition_matrix, top_k=5) 148 | 149 | _, argmax_indices = torch.max(sequence_logits, 1) 150 | assert indices[0] == argmax_indices.data.squeeze().tolist() 151 | 152 | # Test that pairwise potentials effect the sequence correctly and that 153 | # viterbi_decode can handle -inf values. 154 | sequence_logits = torch.FloatTensor( 155 | [ 156 | [0, 0, 0, 3, 4], 157 | [0, 0, 0, 3, 4], 158 | [0, 0, 0, 3, 4], 159 | [0, 0, 0, 3, 4], 160 | [0, 0, 0, 3, 4], 161 | [0, 0, 0, 3, 4], 162 | ] 163 | ) 164 | # The same tags shouldn't appear sequentially. 
165 | transition_matrix = torch.zeros([5, 5]) 166 | for i in range(5): 167 | transition_matrix[i, i] = float("-inf") 168 | indices, _ = viterbi_decode(sequence_logits, transition_matrix, top_k=5) 169 | assert indices[0] == [3, 4, 3, 4, 3, 4] 170 | 171 | # Test that unbalanced pairwise potentials break ties 172 | # between paths with equal unary potentials. 173 | sequence_logits = torch.FloatTensor( 174 | [ 175 | [0, 0, 0, 4, 4], 176 | [0, 0, 0, 4, 4], 177 | [0, 0, 0, 4, 4], 178 | [0, 0, 0, 4, 4], 179 | [0, 0, 0, 4, 4], 180 | [0, 0, 0, 4, 0], 181 | ] 182 | ) 183 | # The 5th tag has a penalty for appearing sequentially 184 | # or for transitioning to the 4th tag, making the best 185 | # path uniquely to take the 4th tag only. 186 | transition_matrix = torch.zeros([5, 5]) 187 | transition_matrix[4, 4] = -10 188 | transition_matrix[4, 3] = -10 189 | indices, _ = viterbi_decode(sequence_logits, transition_matrix, top_k=5) 190 | assert indices[0] == [3, 3, 3, 3, 3, 3] 191 | 192 | sequence_logits = torch.FloatTensor([[1, 0, 0, 4], [1, 0, 6, 2], [0, 3, 0, 4]]) 193 | # Best path would normally be [3, 2, 3] but we add a 194 | # potential from 2 -> 1, making [3, 2, 1] the best path. 
195 | transition_matrix = torch.zeros([4, 4]) 196 | transition_matrix[0, 0] = 1 197 | transition_matrix[2, 1] = 5 198 | indices, value = viterbi_decode(sequence_logits, transition_matrix, top_k=5) 199 | assert indices[0] == [3, 2, 1] 200 | assert value[0] == 18 201 | 202 | def _brute_decode( 203 | tag_sequence: torch.Tensor, transition_matrix: torch.Tensor, top_k: int = 5 204 | ) -> Any: 205 | """ 206 | Top-k decoder that uses brute search 207 | instead of the Viterbi Decode dynamic programing algorithm 208 | """ 209 | # Create all possible sequences 210 | sequences = [[]] # type: ignore 211 | 212 | for i in range(len(tag_sequence)): 213 | new_sequences = [] # type: ignore 214 | for j in range(len(tag_sequence[i])): 215 | for sequence in sequences: 216 | new_sequences.append(sequence[:] + [j]) 217 | sequences = new_sequences 218 | 219 | # Score 220 | scored_sequences = [] # type: ignore 221 | for sequence in sequences: 222 | emission_score = sum(tag_sequence[i, j] for i, j in enumerate(sequence)) 223 | transition_score = sum( 224 | transition_matrix[sequence[i - 1], sequence[i]] 225 | for i in range(1, len(sequence)) 226 | ) 227 | score = emission_score + transition_score 228 | scored_sequences.append((score, sequence)) 229 | 230 | # Get the top k scores / paths 231 | top_k_sequences = sorted(scored_sequences, key=lambda r: r[0], reverse=True)[ 232 | :top_k 233 | ] 234 | scores, paths = zip(*top_k_sequences) 235 | 236 | return paths, scores # type: ignore 237 | 238 | def _sanitize(x: Any) -> Any: 239 | """ 240 | Sanitize turns PyTorch and Numpy types into basic Python types 241 | """ 242 | if isinstance(x, (str, float, int, bool)): 243 | return x 244 | elif isinstance(x, torch.Tensor): 245 | return x.cpu().tolist() 246 | elif isinstance(x, (list, tuple, set)): 247 | return [_sanitize(x_i) for x_i in x] 248 | else: 249 | raise ValueError(f"Cannot sanitize {x} of type {type(x)}. ") 250 | 251 | # Run 100 randomly generated parameters and compare the outputs. 
252 | for _ in range(100): 253 | num_tags = random.randint(1, 5) 254 | seq_len = random.randint(1, 5) 255 | k = random.randint(1, 5) 256 | sequence_logits = torch.rand([seq_len, num_tags]) 257 | transition_matrix = torch.rand([num_tags, num_tags]) 258 | viterbi_paths_v1, viterbi_scores_v1 = viterbi_decode( 259 | sequence_logits, transition_matrix, top_k=k 260 | ) 261 | viterbi_path_brute, viterbi_score_brute = _brute_decode( 262 | sequence_logits, transition_matrix, top_k=k 263 | ) 264 | assert_almost_equal( 265 | list(viterbi_score_brute), viterbi_scores_v1.tolist(), decimal=3 266 | ) 267 | assert_equal(_sanitize(viterbi_paths_v1), viterbi_path_brute) 268 | --------------------------------------------------------------------------------