├── .clang-format ├── CMakeLists.txt ├── LICENSE ├── LICENSES.3rdparty.txt ├── README.md ├── bootstrap.sh ├── cmake ├── FindASan.cmake ├── FindMSan.cmake ├── FindSanitizers.cmake ├── FindTSan.cmake ├── FindUBSan.cmake ├── asan-wrapper └── sanitize-helpers.cmake ├── experiment ├── Makefile ├── README.md ├── number_to_words │ ├── Makefile │ └── main.cc └── text-util.cc ├── sample ├── hparams.json ├── output01.wav ├── processed01.wav └── sequence01.json ├── src ├── audio_util.cc ├── audio_util.h ├── cxxopts.hpp ├── dr_wav.h ├── json.hpp ├── main.cc ├── tf_synthesizer.cc └── tf_synthesizer.h └── tacotron_frozen.pb /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | IndentWidth: 2 4 | TabWidth: 2 5 | UseTab: Never 6 | BreakBeforeBraces: Attach 7 | Standard: Cpp11 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5.1) 2 | 3 | project(TTS) 4 | 5 | set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) 6 | find_package(Sanitizers) # Address sanitizer. 7 | 8 | # threads 9 | find_package(Threads) 10 | 11 | # Add custom build type DebugOpt 12 | message("* Adding build types...") 13 | IF (MSVC) 14 | SET(CMAKE_CXX_FLAGS_DEBUGOPT 15 | "-DDEBUG /DEBUG /O2" 16 | CACHE STRING "Flags used by the C++ compiler during coverage builds." 17 | FORCE ) 18 | SET(CMAKE_C_FLAGS_DEBUGOPT 19 | "-DDEBUG /DEBUG /O2" 20 | CACHE STRING "Flags used by the C compiler during coverage builds." 21 | FORCE ) 22 | ELSE () # Assume gcc 23 | SET(CMAKE_CXX_FLAGS_DEBUGOPT 24 | "-g -O2 -fno-omit-frame-pointer" 25 | CACHE STRING "Flags used by the C++ compiler during coverage builds." 26 | FORCE ) 27 | SET(CMAKE_C_FLAGS_DEBUGOPT 28 | "-g -O2 -fno-omit-frame-pointer" 29 | CACHE STRING "Flags used by the C compiler during coverage builds." 
30 | FORCE ) 31 | ENDIF() 32 | 33 | SET(CMAKE_EXE_LINKER_FLAGS_DEBUGOPT 34 | "" 35 | CACHE STRING "Flags used for linking binaries during coverage builds." 36 | FORCE ) 37 | SET(CMAKE_SHARED_LINKER_FLAGS_DEBUGOPT 38 | "" 39 | CACHE STRING "Flags used by the shared libraries linker during coverage builds." 40 | FORCE ) 41 | MARK_AS_ADVANCED( 42 | CMAKE_CXX_FLAGS_DEBUGOPT 43 | CMAKE_C_FLAGS_DEBUGOPT 44 | CMAKE_EXE_LINKER_FLAGS_DEBUGOPT 45 | CMAKE_SHARED_LINKER_FLAGS_DEBUGOPT ) 46 | 47 | IF(NOT CMAKE_BUILD_TYPE) 48 | SET(CMAKE_BUILD_TYPE Release 49 | CACHE STRING "Choose the type of build : None Debug Release RelWithDebInfo MinSizeRel DebugOpt." 50 | FORCE) 51 | ENDIF() 52 | message("* Current build type is : ${CMAKE_BUILD_TYPE}") 53 | 54 | # C++11 55 | set (CMAKE_CXX_STANDARD 11) 56 | 57 | # PIC 58 | set (CMAKE_POSITION_INDEPENDENT_CODE ON) 59 | 60 | include_directories( 61 | ${CMAKE_SOURCE_DIR}/src 62 | ) 63 | 64 | set (CORE_SOURCE 65 | ${CMAKE_SOURCE_DIR}/src/main.cc 66 | ${CMAKE_SOURCE_DIR}/src/tf_synthesizer.cc 67 | ${CMAKE_SOURCE_DIR}/src/audio_util.cc 68 | ) 69 | 70 | link_directories( 71 | ${TENSORFLOW_BUILD_DIR} 72 | ) 73 | 74 | add_executable( tts 75 | ${CORE_SOURCE} 76 | ) 77 | 78 | target_include_directories(tts 79 | # TensorFlow 80 | PUBLIC ${TENSORFLOW_DIR} 81 | 82 | # for array_ops.h 83 | PUBLIC ${TENSORFLOW_DIR}/bazel-genfiles 84 | 85 | # headers for external packages 86 | PUBLIC ${TENSORFLOW_EXTERNAL_DIR}/external/protobuf_archive/src 87 | PUBLIC ${TENSORFLOW_EXTERNAL_DIR}/external/eigen_archive 88 | PUBLIC ${TENSORFLOW_EXTERNAL_DIR}/external/nsync/public 89 | PUBLIC ${TENSORFLOW_EXTERNAL_DIR}/external/com_google_absl 90 | 91 | # this project 92 | PUBLIC ${CMAKE_SOURCE_DIR}/src 93 | ) 94 | 95 | target_link_libraries( tts PRIVATE 96 | # TensorFlow C++ (resolved through link_directories above) 97 | tensorflow_cc 98 | 99 | ${TTS_EXT_LIBS} 100 | ${CMAKE_THREAD_LIBS_INIT} 101 | ${CMAKE_DL_LIBS} 102 | ) 103 | 104 | 105 | # Increase warning level for clang.
106 | IF (CMAKE_CXX_COMPILER_ID MATCHES "Clang") 107 | target_compile_options(tts PRIVATE -Weverything -Werror -Wno-padded -Wno-c++98-compat-pedantic -Wno-documentation -Wno-documentation-unknown-command) 108 | ENDIF () 109 | 110 | add_sanitizers(tts) 111 | 112 | # [VisualStudio] 113 | if (WIN32) 114 | # Set `tts` as a startup project for VS IDE (cmake 3.6.0 or later required) 115 | if (NOT CMAKE_VERSION VERSION_LESS 3.6.0) 116 | set_property(DIRECTORY PROPERTY VS_STARTUP_PROJECT tts) 117 | endif () 118 | 119 | # For easier debugging in VS IDE(cmake 3.8.0 or later required) 120 | if (NOT CMAKE_VERSION VERSION_LESS 3.8.0) 121 | # Run the debugger from the top-level source directory. 122 | set_target_properties(tts PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}") 123 | endif () 124 | endif () 125 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Syoyo Fujita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSES.3rdparty.txt: -------------------------------------------------------------------------------- 1 | - dr_wav 2 | 3 | Public domain 4 | 5 | - keithito tacotron 6 | 7 | MIT license 8 | 9 | - cxxopts.hpp 10 | 11 | MIT license 12 | 13 | - json.hpp 14 | 15 | MIT license 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-to-speech in (partially) C++ using Tacotron model + TensorFlow 2 | 3 | Runs the Tacotron model with the TensorFlow C++ API. 4 | 5 | It is well suited to running TTS on mobile or embedded devices. 6 | 7 | Code is based on keithito's tacotron implementation: https://github.com/keithito/tacotron 8 | 9 | ## Status 10 | 11 | Experimental. 12 | 13 | Python preprocessing is required to generate sequence data from a text. 14 | 15 | ## Requirements 16 | 17 | * TensorFlow r1.8+ 18 | * Ubuntu 16.04 or later 19 | * C++ compiler + cmake 20 | 21 | ## Dump graph 22 | 23 | In keithito's tacotron repo, append a `tf.train.write_graph` call to `Synthesizer.load` to save the TensorFlow graph. 24 | 25 | ``` 26 | class Synthesizer: 27 | def load(self, checkpoint_path, model_name='tacotron'): 28 | 29 | ...
30 | 31 | # write graph 32 | tf.train.write_graph(self.session.graph.as_graph_def(), "models/", "graph.pb") 33 | ``` 34 | 35 | ## Freeze graph 36 | 37 | Freeze the graph, for example: 38 | 39 | ``` 40 | freeze_graph \ 41 | --input_graph=models/graph.pb \ 42 | --input_checkpoint=./tacotron-20180906/model.ckpt \ 43 | --output_graph=models/tacotron_frozen.pb \ 44 | --output_node_names=model/griffinlim/Squeeze 45 | ``` 46 | 47 | An example frozen graph file (`tacotron_frozen.pb`) is included in this repo. 48 | 49 | ## Build 50 | 51 | Edit the libtensorflow_cc.so path in `bootstrap.sh` (this assumes you built TensorFlow from source), then 52 | 53 | ``` 54 | $ ./bootstrap.sh 55 | $ cd build 56 | $ make 57 | ``` 58 | 59 | ### Note on libtensorflow_cc 60 | 61 | Please make sure to build libtensorflow_cc with `--config=monolithic`. Otherwise you will get undefined-symbol errors at the linking stage. 62 | 63 | https://www.tensorflow.org/install/source#preconfigured_configurations 64 | 65 | 66 | 67 | ## Run 68 | 69 | Prepare a sequence JSON file. 70 | A sequence can be generated using the `text_to_sequence()` function in keithito's tacotron repo. 71 | 72 | See `sample/sequence01.json` for a generated example. 73 | 74 | Then, 75 | 76 | ``` 77 | $ ./tts -i ../sample/sequence01.json -g ../tacotron_frozen.pb output.wav 78 | ``` 79 | 80 | Example `output01.wav` and `processed01.wav` files are included in `sample/`. 81 | 82 | ### Optional parameter 83 | 84 | You can specify hyperparameter settings (JSON format) using the `-h` option. 85 | See `sample/hparams.json` for an example. 86 | 87 | ``` 88 | $ ./tts -i ../sample/sequence01.json -h ../sample/hparams.json -g ../tacotron_frozen.pb output.wav 89 | ``` 90 | 91 | ## Performance 92 | 93 | Currently the TensorFlow C++ code path uses only a single CPU core, so it is slow. 94 | On a 2018-era CPU, synthesis takes roughly 10x the length of the synthesized audio (e.g. 60 secs for 6 secs of audio).
95 | 96 | ## TODO 97 | 98 | * Write all TTS pipeline fully in C++ 99 | * [ ] Text to sequence(Issue #1) 100 | * [ ] Convert to lower case 101 | * [ ] Expand abbreviation 102 | * [ ] Normalize numbers(number_to_words. python inflect equivalent) 103 | * [ ] Remove extra whitespace 104 | * [ ] Use CPU implementation of Griffin-Lim 105 | 106 | ## License 107 | 108 | MIT license. 109 | 110 | Pretrained model used for freezing graph is obtained from keithito's repo. 111 | 112 | ### Third party licenses 113 | 114 | - json.hpp : MIT license 115 | - cxxopts.hpp : MIT license 116 | - dr_wav : Public domain 117 | -------------------------------------------------------------------------------- /bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # source code directory of tensorflow 4 | TF_DIR=`pwd`/../tensorflow 5 | 6 | # external source code directory of tensorflow(e.g. Eigen) 7 | TF_EXTERNAL_DIR=`pwd`/../tensorflow/bazel-tensorflow 8 | 9 | # bazel build directory of tensorflow where `libtensorflow.so` exists. 10 | # Please specify absolute path, otherwise cmake cannot find lib**.a 11 | TF_BUILD_DIR=`pwd`/../tensorflow/bazel-bin/tensorflow 12 | 13 | rm -rf build 14 | 15 | cmake -DTENSORFLOW_DIR=${TF_DIR} \ 16 | -DTENSORFLOW_EXTERNAL_DIR=${TF_EXTERNAL_DIR} \ 17 | -DTENSORFLOW_BUILD_DIR=${TF_BUILD_DIR} \ 18 | -DSANITIZE_ADDRESS=On \ 19 | -Bbuild \ 20 | -H. 
21 | -------------------------------------------------------------------------------- /cmake/FindASan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_ADDRESS "Enable AddressSanitizer for sanitized targets." Off) 26 | 27 | set(FLAG_CANDIDATES 28 | # Clang 3.2+ use this version. The no-omit-frame-pointer option is optional. 
29 | "-g -fsanitize=address -fno-omit-frame-pointer" 30 | "-g -fsanitize=address" 31 | 32 | # Older deprecated flag for ASan 33 | "-g -faddress-sanitizer" 34 | ) 35 | 36 | 37 | if (SANITIZE_ADDRESS AND (SANITIZE_THREAD OR SANITIZE_MEMORY)) 38 | message(FATAL_ERROR "AddressSanitizer is not compatible with " 39 | "ThreadSanitizer or MemorySanitizer.") 40 | endif () 41 | 42 | 43 | include(sanitize-helpers) 44 | 45 | if (SANITIZE_ADDRESS) 46 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "AddressSanitizer" 47 | "ASan") 48 | 49 | find_program(ASan_WRAPPER "asan-wrapper" PATHS ${CMAKE_MODULE_PATH}) 50 | mark_as_advanced(ASan_WRAPPER) 51 | endif () 52 | 53 | function (add_sanitize_address TARGET) 54 | if (NOT SANITIZE_ADDRESS) 55 | return() 56 | endif () 57 | 58 | saitizer_add_flags(${TARGET} "AddressSanitizer" "ASan") 59 | endfunction () 60 | -------------------------------------------------------------------------------- /cmake/FindMSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 
16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_MEMORY "Enable MemorySanitizer for sanitized targets." Off) 26 | 27 | set(FLAG_CANDIDATES 28 | "-g -fsanitize=memory" 29 | ) 30 | 31 | 32 | include(sanitize-helpers) 33 | 34 | if (SANITIZE_MEMORY) 35 | if (NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") 36 | message(WARNING "MemorySanitizer disabled because " 37 | "MemorySanitizer is supported for Linux systems only.") 38 | set(SANITIZE_MEMORY Off CACHE BOOL 39 | "Enable MemorySanitizer for sanitized targets." FORCE) 40 | elseif (NOT "${CMAKE_SIZEOF_VOID_P}" EQUAL 8) 41 | message(WARNING "MemorySanitizer disabled because " 42 | "MemorySanitizer is supported for 64bit systems only.") 43 | set(SANITIZE_MEMORY Off CACHE BOOL 44 | "Enable MemorySanitizer for sanitized targets."
FORCE) 45 | else () 46 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "MemorySanitizer" 47 | "MSan") 48 | endif () 49 | endif () 50 | 51 | function (add_sanitize_memory TARGET) 52 | if (NOT SANITIZE_MEMORY) 53 | return() 54 | endif () 55 | 56 | saitizer_add_flags(${TARGET} "MemorySanitizer" "MSan") 57 | endfunction () 58 | -------------------------------------------------------------------------------- /cmake/FindSanitizers.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | # If any of the used compiler is a GNU compiler, add a second option to static 26 | # link against the sanitizers. 
27 | option(SANITIZE_LINK_STATIC "Try to link static against sanitizers." Off) 28 | 29 | 30 | 31 | 32 | set(FIND_QUIETLY_FLAG "") 33 | if (DEFINED Sanitizers_FIND_QUIETLY) 34 | set(FIND_QUIETLY_FLAG "QUIET") 35 | endif () 36 | 37 | find_package(ASan ${FIND_QUIETLY_FLAG}) 38 | find_package(TSan ${FIND_QUIETLY_FLAG}) 39 | find_package(MSan ${FIND_QUIETLY_FLAG}) 40 | find_package(UBSan ${FIND_QUIETLY_FLAG}) 41 | 42 | 43 | 44 | 45 | function(sanitizer_add_blacklist_file FILE) 46 | if(NOT IS_ABSOLUTE "${FILE}") 47 | set(FILE "${CMAKE_CURRENT_SOURCE_DIR}/${FILE}") 48 | endif() 49 | get_filename_component(FILE "${FILE}" REALPATH) 50 | 51 | sanitizer_check_compiler_flags("-fsanitize-blacklist=${FILE}" 52 | "SanitizerBlacklist" "SanBlist") 53 | endfunction() 54 | 55 | function(add_sanitizers ...) 56 | # If no sanitizer is enabled, return immediately. 57 | if (NOT (SANITIZE_ADDRESS OR SANITIZE_MEMORY OR SANITIZE_THREAD OR 58 | SANITIZE_UNDEFINED)) 59 | return() 60 | endif () 61 | 62 | foreach (TARGET ${ARGV}) 63 | # Check if this target will be compiled by exactly one compiler. Other- 64 | # wise sanitizers can't be used for it: warn and skip only this target. 65 | sanitizer_target_compilers(${TARGET} TARGET_COMPILER) 66 | list(LENGTH TARGET_COMPILER NUM_COMPILERS) 67 | if (NUM_COMPILERS GREATER 1) 68 | message(WARNING "Can't use any sanitizers for target ${TARGET}, " 69 | "because it will be compiled by incompatible compilers. " 70 | "Target will be compiled without sanitizers.") 71 | continue() 72 | 73 | # If the target is compiled by no known compiler, ignore it. 74 | elseif (NUM_COMPILERS EQUAL 0) 75 | message(WARNING "Can't use any sanitizers for target ${TARGET}, " 76 | "because it uses an unknown compiler. Target will be " 77 | "compiled without sanitizers.") 78 | continue() 79 | endif () 80 | 81 | # Add sanitizers for target.
82 | add_sanitize_address(${TARGET}) 83 | add_sanitize_thread(${TARGET}) 84 | add_sanitize_memory(${TARGET}) 85 | add_sanitize_undefined(${TARGET}) 86 | endforeach () 87 | endfunction() 88 | -------------------------------------------------------------------------------- /cmake/FindTSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_THREAD "Enable ThreadSanitizer for sanitized targets." Off) 26 | 27 | set(FLAG_CANDIDATES 28 | "-g -fsanitize=thread" 29 | ) 30 | 31 | 32 | # ThreadSanitizer is not compatible with MemorySanitizer.
33 | if (SANITIZE_THREAD AND SANITIZE_MEMORY) 34 | message(FATAL_ERROR "ThreadSanitizer is not compatible with " 35 | "MemorySanitizer.") 36 | endif () 37 | 38 | 39 | include(sanitize-helpers) 40 | 41 | if (SANITIZE_THREAD) 42 | if (NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") 43 | message(WARNING "ThreadSanitizer disabled because " 44 | "ThreadSanitizer is supported for Linux systems only.") 45 | set(SANITIZE_THREAD Off CACHE BOOL 46 | "Enable ThreadSanitizer for sanitized targets." FORCE) 47 | elseif (NOT "${CMAKE_SIZEOF_VOID_P}" EQUAL 8) 48 | message(WARNING "ThreadSanitizer disabled because " 49 | "ThreadSanitizer is supported for 64bit systems only.") 50 | set(SANITIZE_THREAD Off CACHE BOOL 51 | "Enable ThreadSanitizer for sanitized targets." FORCE) 52 | else () 53 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" "ThreadSanitizer" 54 | "TSan") 55 | endif () 56 | endif () 57 | 58 | function (add_sanitize_thread TARGET) 59 | if (NOT SANITIZE_THREAD) 60 | return() 61 | endif () 62 | 63 | saitizer_add_flags(${TARGET} "ThreadSanitizer" "TSan") 64 | endfunction () 65 | -------------------------------------------------------------------------------- /cmake/FindUBSan.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this
permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | option(SANITIZE_UNDEFINED 26 | "Enable UndefinedBehaviorSanitizer for sanitized targets." Off) 27 | 28 | set(FLAG_CANDIDATES 29 | "-g -fsanitize=undefined" 30 | ) 31 | 32 | 33 | include(sanitize-helpers) 34 | 35 | if (SANITIZE_UNDEFINED) 36 | sanitizer_check_compiler_flags("${FLAG_CANDIDATES}" 37 | "UndefinedBehaviorSanitizer" "UBSan") 38 | endif () 39 | 40 | function (add_sanitize_undefined TARGET) 41 | if (NOT SANITIZE_UNDEFINED) 42 | return() 43 | endif () 44 | 45 | saitizer_add_flags(${TARGET} "UndefinedBehaviorSanitizer" "UBSan") 46 | endfunction () 47 | -------------------------------------------------------------------------------- /cmake/asan-wrapper: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 6 | # 2013 Matthew Arsenault 7 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to 
the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 26 | 27 | # This script is a wrapper for AddressSanitizer. In some special cases you need 28 | # to preload AddressSanitizer to avoid error messages - e.g. if you're 29 | # preloading another library to your application. At the moment this script will 30 | # only do something, if we're running on a Linux platform. OSX might not be 31 | # affected. 32 | 33 | 34 | # Exit immediately, if platform is not Linux. Quote "$@" so that arguments containing whitespace reach the application unchanged. 35 | if [ "$(uname)" != "Linux" ] 36 | then 37 | exec "$@" 38 | fi 39 | 40 | 41 | # Get the used libasan of the application ($1). If a libasan was found, it will 42 | # be prepended to LD_PRELOAD. 43 | libasan=$(ldd "$1" | grep libasan | sed "s/^[[:space:]]//" | cut -d' ' -f1) 44 | if [ -n "$libasan" ] 45 | then 46 | if [ -n "$LD_PRELOAD" ] 47 | then 48 | export LD_PRELOAD="$libasan:$LD_PRELOAD" 49 | else 50 | export LD_PRELOAD="$libasan" 51 | fi 52 | fi 53 | 54 | # Execute the application.
55 | exec "$@" 56 | -------------------------------------------------------------------------------- /cmake/sanitize-helpers.cmake: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 4 | # 2013 Matthew Arsenault 5 | # 2015-2016 RWTH Aachen University, Federal Republic of Germany 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | # Helper function to get the language of a source file.
26 | function (sanitizer_lang_of_source FILE RETURN_VAR) 27 | get_filename_component(FILE_EXT "${FILE}" EXT) 28 | string(TOLOWER "${FILE_EXT}" FILE_EXT) 29 | string(REGEX REPLACE "^\\.(.*)$" "\\1" FILE_EXT "${FILE_EXT}") 30 | 31 | get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES) 32 | foreach (LANG ${ENABLED_LANGUAGES}) 33 | list(FIND CMAKE_${LANG}_SOURCE_FILE_EXTENSIONS "${FILE_EXT}" TEMP) 34 | if (NOT ${TEMP} EQUAL -1) 35 | set(${RETURN_VAR} "${LANG}" PARENT_SCOPE) 36 | return() 37 | endif () 38 | endforeach() 39 | 40 | set(${RETURN_VAR} "" PARENT_SCOPE) 41 | endfunction () 42 | 43 | 44 | # Helper function to get compilers used by a target. 45 | function (sanitizer_target_compilers TARGET RETURN_VAR) 46 | # Check if all sources for target use the same compiler. If a target uses 47 | # e.g. C and Fortran mixed and uses different compilers (e.g. clang and 48 | # gfortran) this can trigger huge problems, because different compilers may 49 | # use different implementations for sanitizers. 50 | set(BUFFER "") 51 | get_target_property(TSOURCES ${TARGET} SOURCES) 52 | foreach (FILE ${TSOURCES}) 53 | # If expression was found, FILE is a generator-expression for an object 54 | # library. Object libraries will be ignored. 55 | string(REGEX MATCH "TARGET_OBJECTS:([^ >]+)" _file "${FILE}") 56 | if ("${_file}" STREQUAL "") 57 | sanitizer_lang_of_source("${FILE}" LANG) 58 | if (LANG) 59 | list(APPEND BUFFER ${CMAKE_${LANG}_COMPILER_ID}) 60 | endif () 61 | endif () 62 | endforeach () 63 | 64 | list(REMOVE_DUPLICATES BUFFER) 65 | set(${RETURN_VAR} "${BUFFER}" PARENT_SCOPE) 66 | endfunction () 67 | 68 | 69 | # Helper function to check compiler flags for language compiler.
function (sanitizer_check_compiler_flag FLAG LANG VARIABLE)
    # Probe whether the compiler for LANG accepts FLAG by delegating to the
    # matching Check<LANG>CompilerFlag module, which receives VARIABLE as its
    # result variable. For any language other than C, CXX or Fortran no check
    # is run and VARIABLE is left untouched.
    #
    # LANG is quoted so that a language name which happens to match a defined
    # variable is compared as a string and not dereferenced a second time.
    if ("${LANG}" STREQUAL "C")
        include(CheckCCompilerFlag)
        check_c_compiler_flag("${FLAG}" ${VARIABLE})

    elseif ("${LANG}" STREQUAL "CXX")
        include(CheckCXXCompilerFlag)
        check_cxx_compiler_flag("${FLAG}" ${VARIABLE})

    elseif ("${LANG}" STREQUAL "Fortran")
        # CheckFortranCompilerFlag was introduced in CMake 3.x. To be compatible
        # with older Cmake versions, we will check if this module is present
        # before we use it. Otherwise we will define Fortran coverage support as
        # not available.
        include(CheckFortranCompilerFlag OPTIONAL RESULT_VARIABLE INCLUDED)
        if (INCLUDED)
            check_fortran_compiler_flag("${FLAG}" ${VARIABLE})
        elseif (NOT CMAKE_REQUIRED_QUIET)
            message(STATUS "Performing Test ${VARIABLE}")
            message(STATUS "Performing Test ${VARIABLE}"
                " - Failed (Check not supported)")
        endif ()
    endif()
endfunction ()


# Helper function to test compiler flags.
#
# FLAG_CANDIDATES  list of flag strings, tried in order; the first one that is
#                  accepted wins.
# NAME             human-readable sanitizer name, used in messages only.
# PREFIX           variable prefix; results are cached per compiler ID in
#                  ${PREFIX}_<COMPILER_ID>_FLAGS (empty string = unavailable).
function (sanitizer_check_compiler_flags FLAG_CANDIDATES NAME PREFIX)
    set(CMAKE_REQUIRED_QUIET ${${PREFIX}_FIND_QUIETLY})

    get_property(ENABLED_LANGUAGES GLOBAL PROPERTY ENABLED_LANGUAGES)
    foreach (LANG ${ENABLED_LANGUAGES})
        # Sanitizer flags are not dependent on language, but the used compiler.
        # So instead of searching flags foreach language, search flags foreach
        # compiler used.
        set(COMPILER ${CMAKE_${LANG}_COMPILER_ID})

        # Languages without a compiler ID are skipped: there is nothing to
        # probe, and without this guard they would produce a
        # "${PREFIX}__FLAGS" cache entry plus a misleading warning.
        if (COMPILER AND NOT DEFINED ${PREFIX}_${COMPILER}_FLAGS)
            foreach (FLAG ${FLAG_CANDIDATES})
                if(NOT CMAKE_REQUIRED_QUIET)
                    message(STATUS "Try ${COMPILER} ${NAME} flag = [${FLAG}]")
                endif()

                # The candidate is also placed in CMAKE_REQUIRED_FLAGS for the
                # check, and the previous cached result is dropped so every
                # candidate is really re-tested.
                set(CMAKE_REQUIRED_FLAGS "${FLAG}")
                unset(${PREFIX}_FLAG_DETECTED CACHE)
                sanitizer_check_compiler_flag("${FLAG}" ${LANG}
                    ${PREFIX}_FLAG_DETECTED)

                if (${PREFIX}_FLAG_DETECTED)
                    # If compiler is a GNU compiler, search for static flag, if
                    # SANITIZE_LINK_STATIC is enabled.
                    if (SANITIZE_LINK_STATIC AND ("${COMPILER}" STREQUAL "GNU"))
                        string(TOLOWER ${PREFIX} PREFIX_lower)
                        sanitizer_check_compiler_flag(
                            "-static-lib${PREFIX_lower}" ${LANG}
                            ${PREFIX}_STATIC_FLAG_DETECTED)

                        if (${PREFIX}_STATIC_FLAG_DETECTED)
                            set(FLAG "-static-lib${PREFIX_lower} ${FLAG}")
                        endif ()
                    endif ()

                    set(${PREFIX}_${COMPILER}_FLAGS "${FLAG}" CACHE STRING
                        "${NAME} flags for ${COMPILER} compiler.")
                    mark_as_advanced(${PREFIX}_${COMPILER}_FLAGS)
                    break()
                endif ()
            endforeach ()

            # No candidate was accepted: cache an empty flag string so the
            # search is not repeated for this compiler.
            if (NOT ${PREFIX}_FLAG_DETECTED)
                set(${PREFIX}_${COMPILER}_FLAGS "" CACHE STRING
                    "${NAME} flags for ${COMPILER} compiler.")
                mark_as_advanced(${PREFIX}_${COMPILER}_FLAGS)

                message(WARNING "${NAME} is not available for ${COMPILER} "
                        "compiler. Targets using this compiler will be "
                        "compiled without ${NAME}.")
            endif ()
        endif ()
    endforeach ()
endfunction ()


# Helper to assign sanitizer flags for TARGET.
#
# NOTE: the misspelled name "saitizer_add_flags" is kept on purpose; it is the
# name this helper is defined under, so renaming it here would break callers.
function (saitizer_add_flags TARGET NAME PREFIX)
    # Get list of compilers used by target and check, if sanitizer is available
    # for this target. Other compiler checks like check for conflicting
    # compilers will be done in add_sanitizers function.
    sanitizer_target_compilers(${TARGET} TARGET_COMPILER)
    if ("${${PREFIX}_${TARGET_COMPILER}_FLAGS}" STREQUAL "")
        return()
    endif()

    # Set compile- and link-flags for target.
    # NOTE(review): SanBlist_<COMPILER>_FLAGS is defined outside this chunk;
    # presumably the sanitizer blacklist flags - confirm against its module.
    set_property(TARGET ${TARGET} APPEND_STRING
        PROPERTY COMPILE_FLAGS " ${${PREFIX}_${TARGET_COMPILER}_FLAGS}")
    set_property(TARGET ${TARGET} APPEND_STRING
        PROPERTY COMPILE_FLAGS " ${SanBlist_${TARGET_COMPILER}_FLAGS}")
    set_property(TARGET ${TARGET} APPEND_STRING
        PROPERTY LINK_FLAGS " ${${PREFIX}_${TARGET_COMPILER}_FLAGS}")
endfunction ()

--------------------------------------------------------------------------------
/experiment/Makefile:
--------------------------------------------------------------------------------
all:
	clang++ -std=c++11 -fsanitize=address -g -Weverything -Werror -Wno-c++98-compat text-util.cc

--------------------------------------------------------------------------------
/experiment/README.md:
--------------------------------------------------------------------------------
Text-to-sequence in C++ experiment

--------------------------------------------------------------------------------
/experiment/number_to_words/Makefile:
--------------------------------------------------------------------------------
all:
	clang++ -std=c++11 -Weverything main.cc

--------------------------------------------------------------------------------
/experiment/number_to_words/main.cc:
--------------------------------------------------------------------------------
#include <cstdlib>
#include <iostream>
#include <string>

// C++ implementation of python inflect's number_to_words()

// Returns the last `n` characters of `input`, or an empty string when `n` is
// zero or `input` is shorter than `n`. A string of exactly `n` characters is
// returned whole, matching Python's `s[-n:]` for that case.
static std::string lastN(const std::string &input, size_t n)
{
  const size_t length = input.size();
  return (n > 0 && length >= n) ? input.substr(length - n) : "";
}

int main(int argc, char **argv)
{
  std::string input = "42nd";

  if (argc > 1) {
    input = std::string(argv[1]);
  }

  // # Handle "stylistic" conversions (up to a given threshold)...
  // if threshold is not None and float(num) > threshold:
  //     spnum = num.split(".", 1)
  //     while comma:
  //         (spnum[0], n) = re.subn(r"(\d)(\d{3}(?:,|\Z))", r"\1,\2", spnum[0])
  //         if n == 0:
  //             break
  //     try:
  //         return "{}.{}".format(spnum[0], spnum[1])
  //     except IndexError:
  //         return "%s" % spnum[0]

  // if group < 0 or group > 3:
  //     raise BadChunkingOptionError
  // nowhite = num.lstrip()
  // if nowhite[0] == "+":
  //     sign = "plus"
  // elif nowhite[0] == "-":
  //     sign = "minus"
  // else:
  //     sign = ""


  // Check ordinal number

  // myord = num[-2:] in ("st", "nd", "rd", "th")
  // if myord:
  //     num = num[:-2]

  std::string ord = lastN(input, 2);
  if ((ord.compare("st") == 0) ||
      (ord.compare("nd") == 0) ||
      (ord.compare("rd") == 0) ||
      (ord.compare("th") == 0)) {
    // strip: keep everything *before* the two-character ordinal suffix
    // (num[:-2]). `ord` is non-empty here, so input.size() >= 2.
    input = input.substr(0, input.size() - 2);
    std::cout << "input = " << input << std::endl;
  }

  //bool finalpoint = false;
  bool is_decimal = false;
  int group = 0;

  //if decimal:
  //    if group != 0:
  //        chunks = num.split(".")
  //    else:
  //        chunks = num.split(".", 1)
  //    if chunks[-1] == "":  # remove blank string if nothing after decimal
  //        chunks = chunks[:-1]
  //        finalpoint = True  # add 'point' to end of output
  //else:
  //    chunks = [num]

  if (is_decimal) {
    if (group != 0) {
      //chunks = num.split(".")
    } else {

    }
  }

  return EXIT_SUCCESS;

}

--------------------------------------------------------------------------------
/experiment/text-util.cc:
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace { 12 | 13 | static const std::set GetValidSymbolsSet() { 14 | const std::vector valid_symbols = { 15 | "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", 16 | "AH1", "AH2", "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", 17 | "AY", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH", "EH0", 18 | "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY", "EY0", "EY1", "EY2", 19 | "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", 20 | "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", 21 | "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", 22 | "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", 23 | "W", "Y", "Z", "ZH"}; 24 | 25 | std::set symbol_set; 26 | 27 | for (auto &s : valid_symbols) { 28 | symbol_set.insert(s); 29 | } 30 | 31 | return symbol_set; 32 | } 33 | 34 | // https://stackoverflow.com/questions/236129/the-most-elegant-way-to-iterate-the-words-of-a-string 35 | std::vector split(const std::string &text, 36 | const std::string &delims) { 37 | std::vector tokens; 38 | std::size_t start = text.find_first_not_of(delims), end = 0; 39 | 40 | while ((end = text.find_first_of(delims, start)) != std::string::npos) { 41 | tokens.push_back(text.substr(start, end - start)); 42 | start = text.find_first_not_of(delims, end); 43 | } 44 | if (start != std::string::npos) tokens.push_back(text.substr(start)); 45 | 46 | return tokens; 47 | } 48 | 49 | static std::string GetPronounciation( 50 | const std::set &valid_symbols_set, const std::string &s) { 51 | 52 | assert(s.size() >= 1); 53 | 54 | std::string ss = s; 55 | ss.pop_back(); // strip 56 | std::vector parts = split(ss, " "); 57 | 58 | for (auto &part : parts) { 59 | if (!valid_symbols_set.count(part)) { 60 | return std::string(); 61 | } 62 | } 63 | 64 | // ' 
'.join(parts) 65 | std::string ret; 66 | for (size_t i = 0; i < parts.size(); i++) { 67 | ret += parts[i]; 68 | if (i != (parts.size() - 1)) { 69 | ret += ' '; 70 | } 71 | } 72 | 73 | return ret; 74 | } 75 | 76 | // CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict 77 | static bool ParseCMUDict(const std::string &filename, 78 | std::map> *cmudict) { 79 | std::ifstream ifs(filename); 80 | if (!ifs) { 81 | return false; 82 | } 83 | 84 | cmudict->clear(); 85 | 86 | std::set valid_symbols_set = GetValidSymbolsSet(); 87 | 88 | std::regex alt_re("\\([0-9]+\\)"); 89 | 90 | std::string line; 91 | while (std::getline(ifs, line)) { 92 | if ((line.size() > 0) && 93 | (((line[0] >= 'A') && (line[0] <= 'Z')) || (line[0] == '\''))) { 94 | std::vector parts = split(line, " "); 95 | 96 | if (parts.size() >= 2) { 97 | std::string word; 98 | assert(!word.empty()); 99 | 100 | std::string pronounciation = 101 | GetPronounciation(valid_symbols_set, parts[1]); 102 | if (!pronounciation.empty()) { 103 | (*cmudict)[word].push_back(pronounciation); 104 | } 105 | } 106 | } 107 | } 108 | 109 | return true; 110 | } 111 | 112 | 113 | /* 114 | 115 | cleaners.py 116 | 117 | ''' 118 | Cleaners are transformations that run over the input text at both training and eval time. 119 | 120 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 121 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 122 | 1. "english_cleaners" for English text 123 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 124 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 125 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 126 | the symbols in symbols.py to match your data). 
127 | ''' 128 | 129 | # Regular expression matching whitespace: 130 | _whitespace_re = re.compile(r'\s+') 131 | 132 | # List of (regular expression, replacement) pairs for abbreviations: 133 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 134 | ('mrs', 'misess'), 135 | ('mr', 'mister'), 136 | ('dr', 'doctor'), 137 | ('st', 'saint'), 138 | ('co', 'company'), 139 | ('jr', 'junior'), 140 | ('maj', 'major'), 141 | ('gen', 'general'), 142 | ('drs', 'doctors'), 143 | ('rev', 'reverend'), 144 | ('lt', 'lieutenant'), 145 | ('hon', 'honorable'), 146 | ('sgt', 'sergeant'), 147 | ('capt', 'captain'), 148 | ('esq', 'esquire'), 149 | ('ltd', 'limited'), 150 | ('col', 'colonel'), 151 | ('ft', 'fort'), 152 | ]] 153 | 154 | 155 | def expand_abbreviations(text): 156 | for regex, replacement in _abbreviations: 157 | text = re.sub(regex, replacement, text) 158 | return text 159 | 160 | 161 | def expand_numbers(text): 162 | return normalize_numbers(text) 163 | 164 | 165 | def lowercase(text): 166 | return text.lower() 167 | 168 | 169 | def collapse_whitespace(text): 170 | return re.sub(_whitespace_re, ' ', text) 171 | 172 | 173 | def convert_to_ascii(text): 174 | return unidecode(text) 175 | 176 | def basic_cleaners(text): 177 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 178 | text = lowercase(text) 179 | text = collapse_whitespace(text) 180 | return text 181 | 182 | 183 | def transliteration_cleaners(text): 184 | '''Pipeline for non-English text that transliterates to ASCII.''' 185 | text = convert_to_ascii(text) 186 | text = lowercase(text) 187 | text = collapse_whitespace(text) 188 | return text 189 | 190 | 191 | def english_cleaners(text): 192 | '''Pipeline for English text, including number and abbreviation expansion.''' 193 | text = convert_to_ascii(text) 194 | text = lowercase(text) 195 | text = expand_numbers(text) 196 | text = expand_abbreviations(text) 197 | text = collapse_whitespace(text) 198 | 
return text 199 | 200 | 201 | 202 | 203 | symbols.py 204 | 205 | _pad = '_' 206 | _eos = '~' 207 | _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' 208 | 209 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 210 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 211 | 212 | # Export all symbols: 213 | symbols = [_pad, _eos] + list(_characters) + _arpabet 214 | 215 | numbers.py 216 | 217 | _inflect = inflect.engine() 218 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 219 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 220 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 221 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 222 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 223 | _number_re = re.compile(r'[0-9]+') 224 | 225 | def _remove_commas(m): 226 | return m.group(1).replace(',', '') 227 | 228 | 229 | def _expand_decimal_point(m): 230 | return m.group(1).replace('.', ' point ') 231 | 232 | 233 | def _expand_dollars(m): 234 | match = m.group(1) 235 | parts = match.split('.') 236 | if len(parts) > 2: 237 | return match + ' dollars' # Unexpected format 238 | dollars = int(parts[0]) if parts[0] else 0 239 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 240 | if dollars and cents: 241 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 242 | cent_unit = 'cent' if cents == 1 else 'cents' 243 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 244 | elif dollars: 245 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 246 | return '%s %s' % (dollars, dollar_unit) 247 | elif cents: 248 | cent_unit = 'cent' if cents == 1 else 'cents' 249 | return '%s %s' % (cents, cent_unit) 250 | else: 251 | return 'zero dollars' 252 | 253 | def _expand_ordinal(m): 254 | return _inflect.number_to_words(m.group(0)) 255 | 256 | 257 | def _expand_number(m): 258 | num = int(m.group(0)) 259 | if num > 1000 and num < 3000: 260 | if num == 2000: 261 | 
return 'two thousand' 262 | elif num > 2000 and num < 2010: 263 | return 'two thousand ' + _inflect.number_to_words(num % 100) 264 | elif num % 100 == 0: 265 | return _inflect.number_to_words(num // 100) + ' hundred' 266 | else: 267 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 268 | else: 269 | return _inflect.number_to_words(num, andword='') 270 | 271 | 272 | def normalize_numbers(text): 273 | text = re.sub(_comma_number_re, _remove_commas, text) 274 | text = re.sub(_pounds_re, r'\1 pounds', text) 275 | text = re.sub(_dollars_re, _expand_dollars, text) 276 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 277 | text = re.sub(_ordinal_re, _expand_ordinal, text) 278 | text = re.sub(_number_re, _expand_number, text) 279 | return text 280 | 281 | */ 282 | 283 | } // namespace 284 | 285 | int main(int argc, char **argv) { 286 | (void)argc; 287 | (void)argv; 288 | 289 | std::string cmudict_filename = ""; 290 | 291 | std::map> cmudict; 292 | 293 | bool ret = ParseCMUDict(cmudict_filename, &cmudict); 294 | if (!ret) { 295 | std::cerr << "Failed to load/parse CMU dict file : " << cmudict_filename << std::endl; 296 | return EXIT_FAILURE; 297 | } 298 | 299 | return EXIT_SUCCESS; 300 | } 301 | -------------------------------------------------------------------------------- /sample/hparams.json: -------------------------------------------------------------------------------- 1 | { 2 | "preemphasis" : 0.97 3 | } 4 | -------------------------------------------------------------------------------- /sample/output01.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syoyo/tacotron-tts-cpp/adb3d62913590aac3b8f1f1434c44c525dbbf8d3/sample/output01.wav -------------------------------------------------------------------------------- /sample/processed01.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/syoyo/tacotron-tts-cpp/adb3d62913590aac3b8f1f1434c44c525dbbf8d3/sample/processed01.wav -------------------------------------------------------------------------------- /sample/sequence01.json: -------------------------------------------------------------------------------- 1 | { "sequence" : 2 | [46, 30, 36, 32, 41, 47, 36, 46, 47, 46, 64, 28, 47, 64, 47, 35, 32, 64, 30, 32, 45, 41, 64, 39, 28, 29, 42, 45, 28, 47, 42, 45, 52, 64, 46, 28, 52, 64, 47, 35, 32, 52, 64, 35, 28, 49, 32, 64, 31, 36, 46, 30, 42, 49, 32, 45, 32, 31, 64, 28, 64, 41, 32, 50, 64, 43, 28, 45, 47, 36, 30, 39, 32, 60, 1] 3 | } 4 | -------------------------------------------------------------------------------- /src/audio_util.cc: -------------------------------------------------------------------------------- 1 | #include "audio_util.h" 2 | 3 | #include 4 | #include 5 | 6 | namespace tts { 7 | 8 | namespace { 9 | 10 | float db_to_amp(const float x) { return std::pow(10.0f, x * 0.05f); } 11 | 12 | } // namespace 13 | 14 | std::vector inv_preemphasis(const float *x, size_t len, 15 | const float scale) { 16 | // scipy.signal.lfilter([1], [1, -hparams.preemphasis], x) 17 | // => 18 | // y[0] = x[0] 19 | // y[1] = -y[0] * (-hparams.preemphasis) + x[1] 20 | // ... 
21 | // y[n] = -y[n-1] * (-hparams.preemphasis) + x[n] 22 | // 23 | 24 | std::vector y; 25 | y.push_back(x[0]); 26 | 27 | for (size_t i = 1; i < len; i++) { 28 | y.emplace_back(x[i] + y[i - 1] * scale); 29 | } 30 | 31 | return y; 32 | } 33 | 34 | size_t find_end_point(const float *wav, const size_t wav_len, 35 | const size_t sample_rate, const float threshold_db, 36 | const float min_silence_sec) { 37 | const size_t window_length = size_t(sample_rate * min_silence_sec); 38 | const size_t hop_length = window_length / 4; 39 | 40 | const float threshold = db_to_amp(threshold_db); 41 | 42 | if (window_length > wav_len) { 43 | return wav_len; 44 | } 45 | 46 | for (size_t x = hop_length; x < (wav_len - window_length); x += hop_length) { 47 | const size_t end_pos = std::min(wav_len, x + window_length); 48 | 49 | // find maximum for range [x, x + window_length] 50 | float m = *(std::max_element(wav + x, wav + end_pos)); 51 | if (m < threshold) { 52 | return std::min(wav_len, x + hop_length); 53 | } 54 | } 55 | 56 | // No silence duration found. 57 | return wav_len; 58 | } 59 | 60 | } // namespace tts 61 | -------------------------------------------------------------------------------- /src/audio_util.h: -------------------------------------------------------------------------------- 1 | #ifndef AUDIO_UTIL_H_ 2 | #define AUDIO_UTIL_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace tts { 8 | 9 | std::vector inv_preemphasis(const float *x, const size_t len, 10 | const float scale); 11 | 12 | // 13 | // Find end point of audio by detecting silence duration. 14 | // @return End frame index. 
15 | // 16 | size_t find_end_point(const float *wav, const size_t wav_len, 17 | const size_t sample_rate, 18 | const float threshold_db = -40.0f, 19 | const float min_silence_sec = 0.8f); 20 | 21 | } // namespace tts 22 | 23 | #endif // AUDIO_UTIL_H_ 24 | -------------------------------------------------------------------------------- /src/cxxopts.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Copyright (c) 2014, 2015, 2016, 2017 Jarryd Beck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | 23 | */ 24 | 25 | #ifndef CXXOPTS_HPP_INCLUDED 26 | #define CXXOPTS_HPP_INCLUDED 27 | 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | 41 | namespace cxxopts 42 | { 43 | static constexpr struct { 44 | uint8_t major, minor, patch; 45 | } version = {2, 1, 0}; 46 | } 47 | 48 | //when we ask cxxopts to use Unicode, help strings are processed using ICU, 49 | //which results in the correct lengths being computed for strings when they 50 | //are formatted for the help output 51 | //it is necessary to make sure that can be found by the 52 | //compiler, and that icu-uc is linked in to the binary. 53 | 54 | #ifdef CXXOPTS_USE_UNICODE 55 | #include 56 | 57 | namespace cxxopts 58 | { 59 | typedef icu::UnicodeString String; 60 | 61 | inline 62 | String 63 | toLocalString(std::string s) 64 | { 65 | return icu::UnicodeString::fromUTF8(std::move(s)); 66 | } 67 | 68 | class UnicodeStringIterator : public 69 | std::iterator 70 | { 71 | public: 72 | 73 | UnicodeStringIterator(const icu::UnicodeString* string, int32_t pos) 74 | : s(string) 75 | , i(pos) 76 | { 77 | } 78 | 79 | value_type 80 | operator*() const 81 | { 82 | return s->char32At(i); 83 | } 84 | 85 | bool 86 | operator==(const UnicodeStringIterator& rhs) const 87 | { 88 | return s == rhs.s && i == rhs.i; 89 | } 90 | 91 | bool 92 | operator!=(const UnicodeStringIterator& rhs) const 93 | { 94 | return !(*this == rhs); 95 | } 96 | 97 | UnicodeStringIterator& 98 | operator++() 99 | { 100 | ++i; 101 | return *this; 102 | } 103 | 104 | UnicodeStringIterator 105 | operator+(int32_t v) 106 | { 107 | return UnicodeStringIterator(s, i + v); 108 | } 109 | 110 | private: 111 | const icu::UnicodeString* s; 112 | int32_t i; 113 | }; 114 | 115 | inline 116 | String& 117 | stringAppend(String&s, String a) 118 | { 119 | return s.append(std::move(a)); 120 | } 121 | 122 | inline 123 | String& 124 | 
stringAppend(String& s, int n, UChar32 c) 125 | { 126 | for (int i = 0; i != n; ++i) 127 | { 128 | s.append(c); 129 | } 130 | 131 | return s; 132 | } 133 | 134 | template 135 | String& 136 | stringAppend(String& s, Iterator begin, Iterator end) 137 | { 138 | while (begin != end) 139 | { 140 | s.append(*begin); 141 | ++begin; 142 | } 143 | 144 | return s; 145 | } 146 | 147 | inline 148 | size_t 149 | stringLength(const String& s) 150 | { 151 | return s.length(); 152 | } 153 | 154 | inline 155 | std::string 156 | toUTF8String(const String& s) 157 | { 158 | std::string result; 159 | s.toUTF8String(result); 160 | 161 | return result; 162 | } 163 | 164 | inline 165 | bool 166 | empty(const String& s) 167 | { 168 | return s.isEmpty(); 169 | } 170 | } 171 | 172 | namespace std 173 | { 174 | inline 175 | cxxopts::UnicodeStringIterator 176 | begin(const icu::UnicodeString& s) 177 | { 178 | return cxxopts::UnicodeStringIterator(&s, 0); 179 | } 180 | 181 | inline 182 | cxxopts::UnicodeStringIterator 183 | end(const icu::UnicodeString& s) 184 | { 185 | return cxxopts::UnicodeStringIterator(&s, s.length()); 186 | } 187 | } 188 | 189 | //ifdef CXXOPTS_USE_UNICODE 190 | #else 191 | 192 | namespace cxxopts 193 | { 194 | typedef std::string String; 195 | 196 | template 197 | T 198 | toLocalString(T&& t) 199 | { 200 | return t; 201 | } 202 | 203 | inline 204 | size_t 205 | stringLength(const String& s) 206 | { 207 | return s.length(); 208 | } 209 | 210 | inline 211 | String& 212 | stringAppend(String&s, String a) 213 | { 214 | return s.append(std::move(a)); 215 | } 216 | 217 | inline 218 | String& 219 | stringAppend(String& s, size_t n, char c) 220 | { 221 | return s.append(n, c); 222 | } 223 | 224 | template 225 | String& 226 | stringAppend(String& s, Iterator begin, Iterator end) 227 | { 228 | return s.append(begin, end); 229 | } 230 | 231 | template 232 | std::string 233 | toUTF8String(T&& t) 234 | { 235 | return std::forward(t); 236 | } 237 | 238 | inline 239 | bool 240 | 
empty(const std::string& s) 241 | { 242 | return s.empty(); 243 | } 244 | } 245 | 246 | //ifdef CXXOPTS_USE_UNICODE 247 | #endif 248 | 249 | namespace cxxopts 250 | { 251 | namespace 252 | { 253 | #ifdef _WIN32 254 | const std::string LQUOTE("\'"); 255 | const std::string RQUOTE("\'"); 256 | #else 257 | const std::string LQUOTE("‘"); 258 | const std::string RQUOTE("’"); 259 | #endif 260 | } 261 | 262 | class Value : public std::enable_shared_from_this 263 | { 264 | public: 265 | 266 | virtual ~Value() = default; 267 | 268 | virtual 269 | std::shared_ptr 270 | clone() const = 0; 271 | 272 | virtual void 273 | parse(const std::string& text) const = 0; 274 | 275 | virtual void 276 | parse() const = 0; 277 | 278 | virtual bool 279 | has_default() const = 0; 280 | 281 | virtual bool 282 | is_container() const = 0; 283 | 284 | virtual bool 285 | has_implicit() const = 0; 286 | 287 | virtual std::string 288 | get_default_value() const = 0; 289 | 290 | virtual std::string 291 | get_implicit_value() const = 0; 292 | 293 | virtual std::shared_ptr 294 | default_value(const std::string& value) = 0; 295 | 296 | virtual std::shared_ptr 297 | implicit_value(const std::string& value) = 0; 298 | 299 | virtual bool 300 | is_boolean() const = 0; 301 | }; 302 | 303 | class OptionException : public std::exception 304 | { 305 | public: 306 | OptionException(const std::string& message) 307 | : m_message(message) 308 | { 309 | } 310 | 311 | virtual const char* 312 | what() const noexcept 313 | { 314 | return m_message.c_str(); 315 | } 316 | 317 | private: 318 | std::string m_message; 319 | }; 320 | 321 | class OptionSpecException : public OptionException 322 | { 323 | public: 324 | 325 | OptionSpecException(const std::string& message) 326 | : OptionException(message) 327 | { 328 | } 329 | }; 330 | 331 | class OptionParseException : public OptionException 332 | { 333 | public: 334 | OptionParseException(const std::string& message) 335 | : OptionException(message) 336 | { 337 | } 338 | }; 
339 | 340 | class option_exists_error : public OptionSpecException 341 | { 342 | public: 343 | option_exists_error(const std::string& option) 344 | : OptionSpecException(u8"Option " + LQUOTE + option + RQUOTE + u8" already exists") 345 | { 346 | } 347 | }; 348 | 349 | class invalid_option_format_error : public OptionSpecException 350 | { 351 | public: 352 | invalid_option_format_error(const std::string& format) 353 | : OptionSpecException(u8"Invalid option format " + LQUOTE + format + RQUOTE) 354 | { 355 | } 356 | }; 357 | 358 | class option_not_exists_exception : public OptionParseException 359 | { 360 | public: 361 | option_not_exists_exception(const std::string& option) 362 | : OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" does not exist") 363 | { 364 | } 365 | }; 366 | 367 | class missing_argument_exception : public OptionParseException 368 | { 369 | public: 370 | missing_argument_exception(const std::string& option) 371 | : OptionParseException( 372 | u8"Option " + LQUOTE + option + RQUOTE + u8" is missing an argument" 373 | ) 374 | { 375 | } 376 | }; 377 | 378 | class option_requires_argument_exception : public OptionParseException 379 | { 380 | public: 381 | option_requires_argument_exception(const std::string& option) 382 | : OptionParseException( 383 | u8"Option " + LQUOTE + option + RQUOTE + u8" requires an argument" 384 | ) 385 | { 386 | } 387 | }; 388 | 389 | class option_not_has_argument_exception : public OptionParseException 390 | { 391 | public: 392 | option_not_has_argument_exception 393 | ( 394 | const std::string& option, 395 | const std::string& arg 396 | ) 397 | : OptionParseException( 398 | u8"Option " + LQUOTE + option + RQUOTE + 399 | u8" does not take an argument, but argument " + 400 | LQUOTE + arg + RQUOTE + " given" 401 | ) 402 | { 403 | } 404 | }; 405 | 406 | class option_not_present_exception : public OptionParseException 407 | { 408 | public: 409 | option_not_present_exception(const std::string& option) 410 | : 
OptionParseException(u8"Option " + LQUOTE + option + RQUOTE + u8" not present") 411 | { 412 | } 413 | }; 414 | 415 | class argument_incorrect_type : public OptionParseException 416 | { 417 | public: 418 | argument_incorrect_type 419 | ( 420 | const std::string& arg 421 | ) 422 | : OptionParseException( 423 | u8"Argument " + LQUOTE + arg + RQUOTE + u8" failed to parse" 424 | ) 425 | { 426 | } 427 | }; 428 | 429 | class option_required_exception : public OptionParseException 430 | { 431 | public: 432 | option_required_exception(const std::string& option) 433 | : OptionParseException( 434 | u8"Option " + LQUOTE + option + RQUOTE + u8" is required but not present" 435 | ) 436 | { 437 | } 438 | }; 439 | 440 | namespace values 441 | { 442 | namespace 443 | { 444 | std::basic_regex integer_pattern 445 | ("(-)?(0x)?([1-9a-zA-Z][0-9a-zA-Z]*)|((0x)?0)"); 446 | std::basic_regex truthy_pattern 447 | ("(t|T)(rue)?"); 448 | std::basic_regex falsy_pattern 449 | ("((f|F)(alse)?)?"); 450 | } 451 | 452 | namespace detail 453 | { 454 | template 455 | struct SignedCheck; 456 | 457 | template 458 | struct SignedCheck 459 | { 460 | template 461 | void 462 | operator()(bool negative, U u, const std::string& text) 463 | { 464 | if (negative) 465 | { 466 | if (u > static_cast(-std::numeric_limits::min())) 467 | { 468 | throw argument_incorrect_type(text); 469 | } 470 | } 471 | else 472 | { 473 | if (u > static_cast(std::numeric_limits::max())) 474 | { 475 | throw argument_incorrect_type(text); 476 | } 477 | } 478 | } 479 | }; 480 | 481 | template 482 | struct SignedCheck 483 | { 484 | template 485 | void 486 | operator()(bool, U, const std::string&) {} 487 | }; 488 | 489 | template 490 | void 491 | check_signed_range(bool negative, U value, const std::string& text) 492 | { 493 | SignedCheck::is_signed>()(negative, value, text); 494 | } 495 | } 496 | 497 | template 498 | R 499 | checked_negate(T&& t, const std::string&, std::true_type) 500 | { 501 | // if we got to here, then `t` is a 
positive number that fits into 502 | // `R`. So to avoid MSVC C4146, we first cast it to `R`. 503 | // See https://github.com/jarro2783/cxxopts/issues/62 for more details. 504 | return -static_cast(t); 505 | } 506 | 507 | template 508 | T 509 | checked_negate(T&&, const std::string& text, std::false_type) 510 | { 511 | throw argument_incorrect_type(text); 512 | } 513 | 514 | template 515 | void 516 | integer_parser(const std::string& text, T& value) 517 | { 518 | std::smatch match; 519 | std::regex_match(text, match, integer_pattern); 520 | 521 | if (match.length() == 0) 522 | { 523 | throw argument_incorrect_type(text); 524 | } 525 | 526 | if (match.length(4) > 0) 527 | { 528 | value = 0; 529 | return; 530 | } 531 | 532 | using US = typename std::make_unsigned::type; 533 | 534 | constexpr auto umax = std::numeric_limits::max(); 535 | constexpr bool is_signed = std::numeric_limits::is_signed; 536 | const bool negative = match.length(1) > 0; 537 | const uint8_t base = match.length(2) > 0 ? 
16 : 10; 538 | 539 | auto value_match = match[3]; 540 | 541 | US result = 0; 542 | 543 | for (auto iter = value_match.first; iter != value_match.second; ++iter) 544 | { 545 | size_t digit = 0; 546 | 547 | if (*iter >= '0' && *iter <= '9') 548 | { 549 | digit = *iter - '0'; 550 | } 551 | else if (base == 16 && *iter >= 'a' && *iter <= 'f') 552 | { 553 | digit = *iter - 'a' + 10; 554 | } 555 | else if (base == 16 && *iter >= 'A' && *iter <= 'F') 556 | { 557 | digit = *iter - 'A' + 10; 558 | } 559 | else 560 | { 561 | throw argument_incorrect_type(text); 562 | } 563 | 564 | if (umax - digit < result * base) 565 | { 566 | throw argument_incorrect_type(text); 567 | } 568 | 569 | result = result * base + digit; 570 | } 571 | 572 | detail::check_signed_range(negative, result, text); 573 | 574 | if (negative) 575 | { 576 | value = checked_negate(result, 577 | text, 578 | std::integral_constant()); 579 | } 580 | else 581 | { 582 | value = result; 583 | } 584 | } 585 | 586 | template 587 | void stringstream_parser(const std::string& text, T& value) 588 | { 589 | std::stringstream in(text); 590 | in >> value; 591 | if (!in) { 592 | throw argument_incorrect_type(text); 593 | } 594 | } 595 | 596 | inline 597 | void 598 | parse_value(const std::string& text, uint8_t& value) 599 | { 600 | integer_parser(text, value); 601 | } 602 | 603 | inline 604 | void 605 | parse_value(const std::string& text, int8_t& value) 606 | { 607 | integer_parser(text, value); 608 | } 609 | 610 | inline 611 | void 612 | parse_value(const std::string& text, uint16_t& value) 613 | { 614 | integer_parser(text, value); 615 | } 616 | 617 | inline 618 | void 619 | parse_value(const std::string& text, int16_t& value) 620 | { 621 | integer_parser(text, value); 622 | } 623 | 624 | inline 625 | void 626 | parse_value(const std::string& text, uint32_t& value) 627 | { 628 | integer_parser(text, value); 629 | } 630 | 631 | inline 632 | void 633 | parse_value(const std::string& text, int32_t& value) 634 | { 635 | 
integer_parser(text, value); 636 | } 637 | 638 | inline 639 | void 640 | parse_value(const std::string& text, uint64_t& value) 641 | { 642 | integer_parser(text, value); 643 | } 644 | 645 | inline 646 | void 647 | parse_value(const std::string& text, int64_t& value) 648 | { 649 | integer_parser(text, value); 650 | } 651 | 652 | inline 653 | void 654 | parse_value(const std::string& text, bool& value) 655 | { 656 | std::smatch result; 657 | std::regex_match(text, result, truthy_pattern); 658 | 659 | if (!result.empty()) 660 | { 661 | value = true; 662 | return; 663 | } 664 | 665 | std::regex_match(text, result, falsy_pattern); 666 | if (!result.empty()) 667 | { 668 | value = false; 669 | return; 670 | } 671 | 672 | throw argument_incorrect_type(text); 673 | } 674 | 675 | inline 676 | void 677 | parse_value(const std::string& text, std::string& value) 678 | { 679 | value = text; 680 | } 681 | 682 | // The fallback parser. It uses the stringstream parser to parse all types 683 | // that have not been overloaded explicitly. It has to be placed in the 684 | // source code before all other more specialized templates. 
685 | template 686 | void 687 | parse_value(const std::string& text, T& value) { 688 | stringstream_parser(text, value); 689 | } 690 | 691 | template 692 | void 693 | parse_value(const std::string& text, std::vector& value) 694 | { 695 | T v; 696 | parse_value(text, v); 697 | value.push_back(v); 698 | } 699 | 700 | template 701 | struct type_is_container 702 | { 703 | static constexpr bool value = false; 704 | }; 705 | 706 | template 707 | struct type_is_container> 708 | { 709 | static constexpr bool value = true; 710 | }; 711 | 712 | template 713 | class abstract_value : public Value 714 | { 715 | using Self = abstract_value; 716 | 717 | public: 718 | abstract_value() 719 | : m_result(std::make_shared()) 720 | , m_store(m_result.get()) 721 | { 722 | } 723 | 724 | abstract_value(T* t) 725 | : m_store(t) 726 | { 727 | } 728 | 729 | virtual ~abstract_value() = default; 730 | 731 | abstract_value(const abstract_value& rhs) 732 | { 733 | if (rhs.m_result) 734 | { 735 | m_result = std::make_shared(); 736 | m_store = m_result.get(); 737 | } 738 | else 739 | { 740 | m_store = rhs.m_store; 741 | } 742 | 743 | m_default = rhs.m_default; 744 | m_implicit = rhs.m_implicit; 745 | m_default_value = rhs.m_default_value; 746 | m_implicit_value = rhs.m_implicit_value; 747 | } 748 | 749 | void 750 | parse(const std::string& text) const 751 | { 752 | parse_value(text, *m_store); 753 | } 754 | 755 | bool 756 | is_container() const 757 | { 758 | return type_is_container::value; 759 | } 760 | 761 | void 762 | parse() const 763 | { 764 | parse_value(m_default_value, *m_store); 765 | } 766 | 767 | bool 768 | has_default() const 769 | { 770 | return m_default; 771 | } 772 | 773 | bool 774 | has_implicit() const 775 | { 776 | return m_implicit; 777 | } 778 | 779 | std::shared_ptr 780 | default_value(const std::string& value) 781 | { 782 | m_default = true; 783 | m_default_value = value; 784 | return shared_from_this(); 785 | } 786 | 787 | std::shared_ptr 788 | implicit_value(const 
std::string& value) 789 | { 790 | m_implicit = true; 791 | m_implicit_value = value; 792 | return shared_from_this(); 793 | } 794 | 795 | std::string 796 | get_default_value() const 797 | { 798 | return m_default_value; 799 | } 800 | 801 | std::string 802 | get_implicit_value() const 803 | { 804 | return m_implicit_value; 805 | } 806 | 807 | bool 808 | is_boolean() const 809 | { 810 | return std::is_same::value; 811 | } 812 | 813 | const T& 814 | get() const 815 | { 816 | if (m_store == nullptr) 817 | { 818 | return *m_result; 819 | } 820 | else 821 | { 822 | return *m_store; 823 | } 824 | } 825 | 826 | protected: 827 | std::shared_ptr m_result; 828 | T* m_store; 829 | 830 | bool m_default = false; 831 | bool m_implicit = false; 832 | 833 | std::string m_default_value; 834 | std::string m_implicit_value; 835 | }; 836 | 837 | template 838 | class standard_value : public abstract_value 839 | { 840 | public: 841 | using abstract_value::abstract_value; 842 | 843 | std::shared_ptr 844 | clone() const 845 | { 846 | return std::make_shared>(*this); 847 | } 848 | }; 849 | 850 | template <> 851 | class standard_value : public abstract_value 852 | { 853 | public: 854 | ~standard_value() = default; 855 | 856 | standard_value() 857 | { 858 | set_implicit(); 859 | } 860 | 861 | standard_value(bool* b) 862 | : abstract_value(b) 863 | { 864 | set_implicit(); 865 | } 866 | 867 | std::shared_ptr 868 | clone() const 869 | { 870 | return std::make_shared>(*this); 871 | } 872 | 873 | private: 874 | 875 | void 876 | set_implicit() 877 | { 878 | m_implicit = true; 879 | m_implicit_value = "true"; 880 | } 881 | }; 882 | } 883 | 884 | template 885 | std::shared_ptr 886 | value() 887 | { 888 | return std::make_shared>(); 889 | } 890 | 891 | template 892 | std::shared_ptr 893 | value(T& t) 894 | { 895 | return std::make_shared>(&t); 896 | } 897 | 898 | class OptionAdder; 899 | 900 | class OptionDetails 901 | { 902 | public: 903 | OptionDetails 904 | ( 905 | const std::string& short_, 906 | 
const std::string& long_, 907 | const String& desc, 908 | std::shared_ptr val 909 | ) 910 | : m_short(short_) 911 | , m_long(long_) 912 | , m_desc(desc) 913 | , m_value(val) 914 | , m_count(0) 915 | { 916 | } 917 | 918 | OptionDetails(const OptionDetails& rhs) 919 | : m_desc(rhs.m_desc) 920 | , m_count(rhs.m_count) 921 | { 922 | m_value = rhs.m_value->clone(); 923 | } 924 | 925 | OptionDetails(OptionDetails&& rhs) = default; 926 | 927 | const String& 928 | description() const 929 | { 930 | return m_desc; 931 | } 932 | 933 | const Value& value() const { 934 | return *m_value; 935 | } 936 | 937 | std::shared_ptr 938 | make_storage() const 939 | { 940 | return m_value->clone(); 941 | } 942 | 943 | const std::string& 944 | short_name() const 945 | { 946 | return m_short; 947 | } 948 | 949 | const std::string& 950 | long_name() const 951 | { 952 | return m_long; 953 | } 954 | 955 | private: 956 | std::string m_short; 957 | std::string m_long; 958 | String m_desc; 959 | std::shared_ptr m_value; 960 | int m_count; 961 | }; 962 | 963 | struct HelpOptionDetails 964 | { 965 | std::string s; 966 | std::string l; 967 | String desc; 968 | bool has_default; 969 | std::string default_value; 970 | bool has_implicit; 971 | std::string implicit_value; 972 | std::string arg_help; 973 | bool is_container; 974 | bool is_boolean; 975 | }; 976 | 977 | struct HelpGroupDetails 978 | { 979 | std::string name; 980 | std::string description; 981 | std::vector options; 982 | }; 983 | 984 | class OptionValue 985 | { 986 | public: 987 | void 988 | parse 989 | ( 990 | std::shared_ptr details, 991 | const std::string& text 992 | ) 993 | { 994 | ensure_value(details); 995 | ++m_count; 996 | m_value->parse(text); 997 | } 998 | 999 | void 1000 | parse_default(std::shared_ptr details) 1001 | { 1002 | ensure_value(details); 1003 | m_value->parse(); 1004 | m_count++; 1005 | } 1006 | 1007 | size_t 1008 | count() const 1009 | { 1010 | return m_count; 1011 | } 1012 | 1013 | template 1014 | const T& 1015 | 
as() const 1016 | { 1017 | #ifdef CXXOPTS_NO_RTTI 1018 | return static_cast&>(*m_value).get(); 1019 | #else 1020 | return dynamic_cast&>(*m_value).get(); 1021 | #endif 1022 | } 1023 | 1024 | private: 1025 | void 1026 | ensure_value(std::shared_ptr details) 1027 | { 1028 | if (m_value == nullptr) 1029 | { 1030 | m_value = details->make_storage(); 1031 | } 1032 | } 1033 | 1034 | std::shared_ptr m_value; 1035 | size_t m_count = 0; 1036 | }; 1037 | 1038 | class KeyValue 1039 | { 1040 | public: 1041 | KeyValue(std::string key_, std::string value_) 1042 | : m_key(std::move(key_)) 1043 | , m_value(std::move(value_)) 1044 | { 1045 | } 1046 | 1047 | const 1048 | std::string& 1049 | key() const 1050 | { 1051 | return m_key; 1052 | } 1053 | 1054 | const std::string 1055 | value() const 1056 | { 1057 | return m_value; 1058 | } 1059 | 1060 | template 1061 | T 1062 | as() const 1063 | { 1064 | T result; 1065 | values::parse_value(m_value, result); 1066 | return result; 1067 | } 1068 | 1069 | private: 1070 | std::string m_key; 1071 | std::string m_value; 1072 | }; 1073 | 1074 | class ParseResult 1075 | { 1076 | public: 1077 | 1078 | ParseResult( 1079 | const std::unordered_map>&, 1080 | std::vector, 1081 | int&, char**&); 1082 | 1083 | size_t 1084 | count(const std::string& o) const 1085 | { 1086 | auto iter = m_options.find(o); 1087 | if (iter == m_options.end()) 1088 | { 1089 | return 0; 1090 | } 1091 | 1092 | auto riter = m_results.find(iter->second); 1093 | 1094 | return riter->second.count(); 1095 | } 1096 | 1097 | const OptionValue& 1098 | operator[](const std::string& option) const 1099 | { 1100 | auto iter = m_options.find(option); 1101 | 1102 | if (iter == m_options.end()) 1103 | { 1104 | throw option_not_present_exception(option); 1105 | } 1106 | 1107 | auto riter = m_results.find(iter->second); 1108 | 1109 | return riter->second; 1110 | } 1111 | 1112 | const std::vector& 1113 | arguments() const 1114 | { 1115 | return m_sequential; 1116 | } 1117 | 1118 | private: 1119 
| 1120 | OptionValue& 1121 | get_option(std::shared_ptr); 1122 | 1123 | void 1124 | parse(int& argc, char**& argv); 1125 | 1126 | void 1127 | add_to_option(const std::string& option, const std::string& arg); 1128 | 1129 | bool 1130 | consume_positional(std::string a); 1131 | 1132 | void 1133 | parse_option 1134 | ( 1135 | std::shared_ptr value, 1136 | const std::string& name, 1137 | const std::string& arg = "" 1138 | ); 1139 | 1140 | void 1141 | parse_default(std::shared_ptr details); 1142 | 1143 | void 1144 | checked_parse_arg 1145 | ( 1146 | int argc, 1147 | char* argv[], 1148 | int& current, 1149 | std::shared_ptr value, 1150 | const std::string& name 1151 | ); 1152 | 1153 | const std::unordered_map> 1154 | &m_options; 1155 | std::vector m_positional; 1156 | std::vector::iterator m_next_positional; 1157 | std::unordered_set m_positional_set; 1158 | std::unordered_map, OptionValue> m_results; 1159 | 1160 | std::vector m_sequential; 1161 | }; 1162 | 1163 | class Options 1164 | { 1165 | public: 1166 | 1167 | Options(std::string program, std::string help_string = "") 1168 | : m_program(std::move(program)) 1169 | , m_help_string(toLocalString(std::move(help_string))) 1170 | , m_custom_help("[OPTION...]") 1171 | , m_positional_help("positional parameters") 1172 | , m_show_positional(false) 1173 | , m_next_positional(m_positional.end()) 1174 | { 1175 | } 1176 | 1177 | Options& 1178 | positional_help(std::string help_text) 1179 | { 1180 | m_positional_help = std::move(help_text); 1181 | return *this; 1182 | } 1183 | 1184 | Options& 1185 | custom_help(std::string help_text) 1186 | { 1187 | m_custom_help = std::move(help_text); 1188 | return *this; 1189 | } 1190 | 1191 | Options& 1192 | show_positional_help() 1193 | { 1194 | m_show_positional = true; 1195 | return *this; 1196 | } 1197 | 1198 | ParseResult 1199 | parse(int& argc, char**& argv); 1200 | 1201 | OptionAdder 1202 | add_options(std::string group = ""); 1203 | 1204 | void 1205 | add_option 1206 | ( 1207 | const 
std::string& group, 1208 | const std::string& s, 1209 | const std::string& l, 1210 | std::string desc, 1211 | std::shared_ptr value, 1212 | std::string arg_help 1213 | ); 1214 | 1215 | //parse positional arguments into the given option 1216 | void 1217 | parse_positional(std::string option); 1218 | 1219 | void 1220 | parse_positional(std::vector options); 1221 | 1222 | void 1223 | parse_positional(std::initializer_list options); 1224 | 1225 | std::string 1226 | help(const std::vector& groups = {""}) const; 1227 | 1228 | const std::vector 1229 | groups() const; 1230 | 1231 | const HelpGroupDetails& 1232 | group_help(const std::string& group) const; 1233 | 1234 | private: 1235 | 1236 | void 1237 | add_one_option 1238 | ( 1239 | const std::string& option, 1240 | std::shared_ptr details 1241 | ); 1242 | 1243 | String 1244 | help_one_group(const std::string& group) const; 1245 | 1246 | void 1247 | generate_group_help 1248 | ( 1249 | String& result, 1250 | const std::vector& groups 1251 | ) const; 1252 | 1253 | void 1254 | generate_all_groups_help(String& result) const; 1255 | 1256 | std::string m_program; 1257 | String m_help_string; 1258 | std::string m_custom_help; 1259 | std::string m_positional_help; 1260 | bool m_show_positional; 1261 | 1262 | std::unordered_map> m_options; 1263 | std::vector m_positional; 1264 | std::vector::iterator m_next_positional; 1265 | std::unordered_set m_positional_set; 1266 | 1267 | //mapping from groups to help options 1268 | std::map m_help; 1269 | }; 1270 | 1271 | class OptionAdder 1272 | { 1273 | public: 1274 | 1275 | OptionAdder(Options& options, std::string group) 1276 | : m_options(options), m_group(std::move(group)) 1277 | { 1278 | } 1279 | 1280 | OptionAdder& 1281 | operator() 1282 | ( 1283 | const std::string& opts, 1284 | const std::string& desc, 1285 | std::shared_ptr value 1286 | = ::cxxopts::value(), 1287 | std::string arg_help = "" 1288 | ); 1289 | 1290 | private: 1291 | Options& m_options; 1292 | std::string m_group; 1293 
| }; 1294 | 1295 | namespace 1296 | { 1297 | constexpr int OPTION_LONGEST = 30; 1298 | constexpr int OPTION_DESC_GAP = 2; 1299 | 1300 | std::basic_regex option_matcher 1301 | ("--([[:alnum:]][-_[:alnum:]]+)(=(.*))?|-([[:alnum:]]+)"); 1302 | 1303 | std::basic_regex option_specifier 1304 | ("(([[:alnum:]]),)?[ ]*([[:alnum:]][-_[:alnum:]]*)?"); 1305 | 1306 | String 1307 | format_option 1308 | ( 1309 | const HelpOptionDetails& o 1310 | ) 1311 | { 1312 | auto& s = o.s; 1313 | auto& l = o.l; 1314 | 1315 | String result = " "; 1316 | 1317 | if (s.size() > 0) 1318 | { 1319 | result += "-" + toLocalString(s) + ","; 1320 | } 1321 | else 1322 | { 1323 | result += " "; 1324 | } 1325 | 1326 | if (l.size() > 0) 1327 | { 1328 | result += " --" + toLocalString(l); 1329 | } 1330 | 1331 | auto arg = o.arg_help.size() > 0 ? toLocalString(o.arg_help) : "arg"; 1332 | 1333 | if (!o.is_boolean) 1334 | { 1335 | if (o.has_implicit) 1336 | { 1337 | result += " [=" + arg + "(=" + toLocalString(o.implicit_value) + ")]"; 1338 | } 1339 | else 1340 | { 1341 | result += " " + arg; 1342 | } 1343 | } 1344 | 1345 | return result; 1346 | } 1347 | 1348 | String 1349 | format_description 1350 | ( 1351 | const HelpOptionDetails& o, 1352 | size_t start, 1353 | size_t width 1354 | ) 1355 | { 1356 | auto desc = o.desc; 1357 | 1358 | if (o.has_default) 1359 | { 1360 | desc += toLocalString(" (default: " + o.default_value + ")"); 1361 | } 1362 | 1363 | String result; 1364 | 1365 | auto current = std::begin(desc); 1366 | auto startLine = current; 1367 | auto lastSpace = current; 1368 | 1369 | auto size = size_t{}; 1370 | 1371 | while (current != std::end(desc)) 1372 | { 1373 | if (*current == ' ') 1374 | { 1375 | lastSpace = current; 1376 | } 1377 | 1378 | if (size > width) 1379 | { 1380 | if (lastSpace == startLine) 1381 | { 1382 | stringAppend(result, startLine, current + 1); 1383 | stringAppend(result, "\n"); 1384 | stringAppend(result, start, ' '); 1385 | startLine = current + 1; 1386 | lastSpace = 
startLine; 1387 | } 1388 | else 1389 | { 1390 | stringAppend(result, startLine, lastSpace); 1391 | stringAppend(result, "\n"); 1392 | stringAppend(result, start, ' '); 1393 | startLine = lastSpace + 1; 1394 | } 1395 | size = 0; 1396 | } 1397 | else 1398 | { 1399 | ++size; 1400 | } 1401 | 1402 | ++current; 1403 | } 1404 | 1405 | //append whatever is left 1406 | stringAppend(result, startLine, current); 1407 | 1408 | return result; 1409 | } 1410 | } 1411 | 1412 | inline 1413 | ParseResult::ParseResult 1414 | ( 1415 | const std::unordered_map>& options, 1416 | std::vector positional, 1417 | int& argc, char**& argv 1418 | ) 1419 | : m_options(options) 1420 | , m_positional(std::move(positional)) 1421 | , m_next_positional(m_positional.begin()) 1422 | { 1423 | parse(argc, argv); 1424 | } 1425 | 1426 | inline 1427 | OptionAdder 1428 | Options::add_options(std::string group) 1429 | { 1430 | return OptionAdder(*this, std::move(group)); 1431 | } 1432 | 1433 | inline 1434 | OptionAdder& 1435 | OptionAdder::operator() 1436 | ( 1437 | const std::string& opts, 1438 | const std::string& desc, 1439 | std::shared_ptr value, 1440 | std::string arg_help 1441 | ) 1442 | { 1443 | std::match_results result; 1444 | std::regex_match(opts.c_str(), result, option_specifier); 1445 | 1446 | if (result.empty()) 1447 | { 1448 | throw invalid_option_format_error(opts); 1449 | } 1450 | 1451 | const auto& short_match = result[2]; 1452 | const auto& long_match = result[3]; 1453 | 1454 | if (!short_match.length() && !long_match.length()) 1455 | { 1456 | throw invalid_option_format_error(opts); 1457 | } else if (long_match.length() == 1 && short_match.length()) 1458 | { 1459 | throw invalid_option_format_error(opts); 1460 | } 1461 | 1462 | auto option_names = [] 1463 | ( 1464 | const std::sub_match& short_, 1465 | const std::sub_match& long_ 1466 | ) 1467 | { 1468 | if (long_.length() == 1) 1469 | { 1470 | return std::make_tuple(long_.str(), short_.str()); 1471 | } 1472 | else 1473 | { 1474 | return 
std::make_tuple(short_.str(), long_.str()); 1475 | } 1476 | }(short_match, long_match); 1477 | 1478 | m_options.add_option 1479 | ( 1480 | m_group, 1481 | std::get<0>(option_names), 1482 | std::get<1>(option_names), 1483 | desc, 1484 | value, 1485 | std::move(arg_help) 1486 | ); 1487 | 1488 | return *this; 1489 | } 1490 | 1491 | inline 1492 | void 1493 | ParseResult::parse_default(std::shared_ptr details) 1494 | { 1495 | m_results[details].parse_default(details); 1496 | } 1497 | 1498 | inline 1499 | void 1500 | ParseResult::parse_option 1501 | ( 1502 | std::shared_ptr value, 1503 | const std::string& /*name*/, 1504 | const std::string& arg 1505 | ) 1506 | { 1507 | auto& result = m_results[value]; 1508 | result.parse(value, arg); 1509 | 1510 | m_sequential.emplace_back(value->long_name(), arg); 1511 | } 1512 | 1513 | inline 1514 | void 1515 | ParseResult::checked_parse_arg 1516 | ( 1517 | int argc, 1518 | char* argv[], 1519 | int& current, 1520 | std::shared_ptr value, 1521 | const std::string& name 1522 | ) 1523 | { 1524 | if (current + 1 >= argc) 1525 | { 1526 | if (value->value().has_implicit()) 1527 | { 1528 | parse_option(value, name, value->value().get_implicit_value()); 1529 | } 1530 | else 1531 | { 1532 | throw missing_argument_exception(name); 1533 | } 1534 | } 1535 | else 1536 | { 1537 | if (value->value().has_implicit()) 1538 | { 1539 | parse_option(value, name, value->value().get_implicit_value()); 1540 | } 1541 | else 1542 | { 1543 | parse_option(value, name, argv[current + 1]); 1544 | ++current; 1545 | } 1546 | } 1547 | } 1548 | 1549 | inline 1550 | void 1551 | ParseResult::add_to_option(const std::string& option, const std::string& arg) 1552 | { 1553 | auto iter = m_options.find(option); 1554 | 1555 | if (iter == m_options.end()) 1556 | { 1557 | throw option_not_exists_exception(option); 1558 | } 1559 | 1560 | parse_option(iter->second, option, arg); 1561 | } 1562 | 1563 | inline 1564 | bool 1565 | ParseResult::consume_positional(std::string a) 1566 | 
{ 1567 | while (m_next_positional != m_positional.end()) 1568 | { 1569 | auto iter = m_options.find(*m_next_positional); 1570 | if (iter != m_options.end()) 1571 | { 1572 | auto& result = m_results[iter->second]; 1573 | if (!iter->second->value().is_container()) 1574 | { 1575 | if (result.count() == 0) 1576 | { 1577 | add_to_option(*m_next_positional, a); 1578 | ++m_next_positional; 1579 | return true; 1580 | } 1581 | else 1582 | { 1583 | ++m_next_positional; 1584 | continue; 1585 | } 1586 | } 1587 | else 1588 | { 1589 | add_to_option(*m_next_positional, a); 1590 | return true; 1591 | } 1592 | } 1593 | ++m_next_positional; 1594 | } 1595 | 1596 | return false; 1597 | } 1598 | 1599 | inline 1600 | void 1601 | Options::parse_positional(std::string option) 1602 | { 1603 | parse_positional(std::vector{std::move(option)}); 1604 | } 1605 | 1606 | inline 1607 | void 1608 | Options::parse_positional(std::vector options) 1609 | { 1610 | m_positional = std::move(options); 1611 | m_next_positional = m_positional.begin(); 1612 | 1613 | m_positional_set.insert(m_positional.begin(), m_positional.end()); 1614 | } 1615 | 1616 | inline 1617 | void 1618 | Options::parse_positional(std::initializer_list options) 1619 | { 1620 | parse_positional(std::vector(std::move(options))); 1621 | } 1622 | 1623 | inline 1624 | ParseResult 1625 | Options::parse(int& argc, char**& argv) 1626 | { 1627 | ParseResult result(m_options, m_positional, argc, argv); 1628 | return result; 1629 | } 1630 | 1631 | inline 1632 | void 1633 | ParseResult::parse(int& argc, char**& argv) 1634 | { 1635 | int current = 1; 1636 | 1637 | int nextKeep = 1; 1638 | 1639 | bool consume_remaining = false; 1640 | 1641 | while (current != argc) 1642 | { 1643 | if (strcmp(argv[current], "--") == 0) 1644 | { 1645 | consume_remaining = true; 1646 | ++current; 1647 | break; 1648 | } 1649 | 1650 | std::match_results result; 1651 | std::regex_match(argv[current], result, option_matcher); 1652 | 1653 | if (result.empty()) 1654 | { 
1655 | //not a flag 1656 | 1657 | //if true is returned here then it was consumed, otherwise it is 1658 | //ignored 1659 | if (consume_positional(argv[current])) 1660 | { 1661 | } 1662 | else 1663 | { 1664 | argv[nextKeep] = argv[current]; 1665 | ++nextKeep; 1666 | } 1667 | //if we return from here then it was parsed successfully, so continue 1668 | } 1669 | else 1670 | { 1671 | //short or long option? 1672 | if (result[4].length() != 0) 1673 | { 1674 | const std::string& s = result[4]; 1675 | 1676 | for (std::size_t i = 0; i != s.size(); ++i) 1677 | { 1678 | std::string name(1, s[i]); 1679 | auto iter = m_options.find(name); 1680 | 1681 | if (iter == m_options.end()) 1682 | { 1683 | throw option_not_exists_exception(name); 1684 | } 1685 | 1686 | auto value = iter->second; 1687 | 1688 | if (i + 1 == s.size()) 1689 | { 1690 | //it must be the last argument 1691 | checked_parse_arg(argc, argv, current, value, name); 1692 | } 1693 | else if (value->value().has_implicit()) 1694 | { 1695 | parse_option(value, name, value->value().get_implicit_value()); 1696 | } 1697 | else 1698 | { 1699 | //error 1700 | throw option_requires_argument_exception(name); 1701 | } 1702 | } 1703 | } 1704 | else if (result[1].length() != 0) 1705 | { 1706 | const std::string& name = result[1]; 1707 | 1708 | auto iter = m_options.find(name); 1709 | 1710 | if (iter == m_options.end()) 1711 | { 1712 | throw option_not_exists_exception(name); 1713 | } 1714 | 1715 | auto opt = iter->second; 1716 | 1717 | //equals provided for long option? 
1718 | if (result[2].length() != 0) 1719 | { 1720 | //parse the option given 1721 | 1722 | parse_option(opt, name, result[3]); 1723 | } 1724 | else 1725 | { 1726 | //parse the next argument 1727 | checked_parse_arg(argc, argv, current, opt, name); 1728 | } 1729 | } 1730 | 1731 | } 1732 | 1733 | ++current; 1734 | } 1735 | 1736 | for (auto& opt : m_options) 1737 | { 1738 | auto& detail = opt.second; 1739 | auto& value = detail->value(); 1740 | 1741 | auto& store = m_results[detail]; 1742 | 1743 | if(!store.count() && value.has_default()){ 1744 | parse_default(detail); 1745 | } 1746 | } 1747 | 1748 | if (consume_remaining) 1749 | { 1750 | while (current < argc) 1751 | { 1752 | if (!consume_positional(argv[current])) { 1753 | break; 1754 | } 1755 | ++current; 1756 | } 1757 | 1758 | //adjust argv for any that couldn't be swallowed 1759 | while (current != argc) { 1760 | argv[nextKeep] = argv[current]; 1761 | ++nextKeep; 1762 | ++current; 1763 | } 1764 | } 1765 | 1766 | argc = nextKeep; 1767 | 1768 | } 1769 | 1770 | inline 1771 | void 1772 | Options::add_option 1773 | ( 1774 | const std::string& group, 1775 | const std::string& s, 1776 | const std::string& l, 1777 | std::string desc, 1778 | std::shared_ptr value, 1779 | std::string arg_help 1780 | ) 1781 | { 1782 | auto stringDesc = toLocalString(std::move(desc)); 1783 | auto option = std::make_shared(s, l, stringDesc, value); 1784 | 1785 | if (s.size() > 0) 1786 | { 1787 | add_one_option(s, option); 1788 | } 1789 | 1790 | if (l.size() > 0) 1791 | { 1792 | add_one_option(l, option); 1793 | } 1794 | 1795 | //add the help details 1796 | auto& options = m_help[group]; 1797 | 1798 | options.options.emplace_back(HelpOptionDetails{s, l, stringDesc, 1799 | value->has_default(), value->get_default_value(), 1800 | value->has_implicit(), value->get_implicit_value(), 1801 | std::move(arg_help), 1802 | value->is_container(), 1803 | value->is_boolean()}); 1804 | } 1805 | 1806 | inline 1807 | void 1808 | Options::add_one_option 1809 | 
( 1810 | const std::string& option, 1811 | std::shared_ptr details 1812 | ) 1813 | { 1814 | auto in = m_options.emplace(option, details); 1815 | 1816 | if (!in.second) 1817 | { 1818 | throw option_exists_error(option); 1819 | } 1820 | } 1821 | 1822 | inline 1823 | String 1824 | Options::help_one_group(const std::string& g) const 1825 | { 1826 | typedef std::vector> OptionHelp; 1827 | 1828 | auto group = m_help.find(g); 1829 | if (group == m_help.end()) 1830 | { 1831 | return ""; 1832 | } 1833 | 1834 | OptionHelp format; 1835 | 1836 | size_t longest = 0; 1837 | 1838 | String result; 1839 | 1840 | if (!g.empty()) 1841 | { 1842 | result += toLocalString(" " + g + " options:\n"); 1843 | } 1844 | 1845 | for (const auto& o : group->second.options) 1846 | { 1847 | if (o.is_container && 1848 | m_positional_set.find(o.l) != m_positional_set.end() && 1849 | !m_show_positional) 1850 | { 1851 | continue; 1852 | } 1853 | 1854 | auto s = format_option(o); 1855 | longest = std::max(longest, stringLength(s)); 1856 | format.push_back(std::make_pair(s, String())); 1857 | } 1858 | 1859 | longest = std::min(longest, static_cast(OPTION_LONGEST)); 1860 | 1861 | //widest allowed description 1862 | auto allowed = size_t{76} - longest - OPTION_DESC_GAP; 1863 | 1864 | auto fiter = format.begin(); 1865 | for (const auto& o : group->second.options) 1866 | { 1867 | if (o.is_container && 1868 | m_positional_set.find(o.l) != m_positional_set.end() && 1869 | !m_show_positional) 1870 | { 1871 | continue; 1872 | } 1873 | 1874 | auto d = format_description(o, longest + OPTION_DESC_GAP, allowed); 1875 | 1876 | result += fiter->first; 1877 | if (stringLength(fiter->first) > longest) 1878 | { 1879 | result += '\n'; 1880 | result += toLocalString(std::string(longest + OPTION_DESC_GAP, ' ')); 1881 | } 1882 | else 1883 | { 1884 | result += toLocalString(std::string(longest + OPTION_DESC_GAP - 1885 | stringLength(fiter->first), 1886 | ' ')); 1887 | } 1888 | result += d; 1889 | result += '\n'; 1890 | 1891 | 
++fiter; 1892 | } 1893 | 1894 | return result; 1895 | } 1896 | 1897 | inline 1898 | void 1899 | Options::generate_group_help 1900 | ( 1901 | String& result, 1902 | const std::vector& print_groups 1903 | ) const 1904 | { 1905 | for (size_t i = 0; i != print_groups.size(); ++i) 1906 | { 1907 | const String& group_help_text = help_one_group(print_groups[i]); 1908 | if (empty(group_help_text)) 1909 | { 1910 | continue; 1911 | } 1912 | result += group_help_text; 1913 | if (i < print_groups.size() - 1) 1914 | { 1915 | result += '\n'; 1916 | } 1917 | } 1918 | } 1919 | 1920 | inline 1921 | void 1922 | Options::generate_all_groups_help(String& result) const 1923 | { 1924 | std::vector all_groups; 1925 | all_groups.reserve(m_help.size()); 1926 | 1927 | for (auto& group : m_help) 1928 | { 1929 | all_groups.push_back(group.first); 1930 | } 1931 | 1932 | generate_group_help(result, all_groups); 1933 | } 1934 | 1935 | inline 1936 | std::string 1937 | Options::help(const std::vector& help_groups) const 1938 | { 1939 | String result = m_help_string + "\nUsage:\n " + 1940 | toLocalString(m_program) + " " + toLocalString(m_custom_help); 1941 | 1942 | if (m_positional.size() > 0 && m_positional_help.size() > 0) { 1943 | result += " " + toLocalString(m_positional_help); 1944 | } 1945 | 1946 | result += "\n\n"; 1947 | 1948 | if (help_groups.size() == 0) 1949 | { 1950 | generate_all_groups_help(result); 1951 | } 1952 | else 1953 | { 1954 | generate_group_help(result, help_groups); 1955 | } 1956 | 1957 | return toUTF8String(result); 1958 | } 1959 | 1960 | inline 1961 | const std::vector 1962 | Options::groups() const 1963 | { 1964 | std::vector g; 1965 | 1966 | std::transform( 1967 | m_help.begin(), 1968 | m_help.end(), 1969 | std::back_inserter(g), 1970 | [] (const std::map::value_type& pair) 1971 | { 1972 | return pair.first; 1973 | } 1974 | ); 1975 | 1976 | return g; 1977 | } 1978 | 1979 | inline 1980 | const HelpGroupDetails& 1981 | Options::group_help(const std::string& group) const 
1982 | { 1983 | return m_help.at(group); 1984 | } 1985 | 1986 | } 1987 | 1988 | #endif //CXXOPTS_HPP_INCLUDED 1989 | -------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #ifdef __clang__ 6 | #pragma clang diagnostic push 7 | #pragma clang diagnostic ignored "-Weverything" 8 | #endif 9 | 10 | #include "cxxopts.hpp" 11 | #include "json.hpp" 12 | #include "audio_util.h" 13 | 14 | #define DR_WAV_IMPLEMENTATION 15 | #include "dr_wav.h" 16 | 17 | #ifdef __clang__ 18 | #pragma clang diagnostic pop 19 | #endif 20 | 21 | #include "tf_synthesizer.h" 22 | 23 | class HyperParameters 24 | { 25 | public: 26 | HyperParameters() : preemphasis(0.97f) {}; 27 | 28 | float preemphasis; 29 | }; 30 | 31 | template 32 | bool GetNumberArray(const nlohmann::json &j, const std::string &name, 33 | std::vector *value) { 34 | if (j.find(name) == j.end()) { 35 | std::cerr << "Property not found : " << name << std::endl; 36 | return false; 37 | } 38 | 39 | if (!j.at(name).is_array()) { 40 | std::cerr << "Property is not an array type : " << name << std::endl; 41 | return false; 42 | } 43 | 44 | std::vector v; 45 | for (auto &element : j.at(name)) { 46 | if (!element.is_number()) { 47 | std::cerr << "An array element is not a number" << std::endl; 48 | return false; 49 | } 50 | 51 | v.push_back(static_cast(element.get())); 52 | } 53 | 54 | (*value) = v; 55 | return true; 56 | } 57 | 58 | 59 | // Load sequence from JSON array 60 | bool LoadSequence(const std::string &sequence_filename, std::vector *sequence) 61 | { 62 | std::ifstream is(sequence_filename); 63 | if (!is) { 64 | std::cerr << "Failed to open/read file : " << sequence_filename << std::endl; 65 | return false; 66 | } 67 | 68 | nlohmann::json j; 69 | is >> j; 70 | 71 | return GetNumberArray(j, "sequence", sequence); 72 | 73 | } 74 | 75 | bool ParseHyperPrameters(const 
std::string &json_filename, HyperParameters *hparams) 76 | { 77 | std::ifstream is(json_filename); 78 | if (!is) { 79 | std::cerr << "Failed to open/read hyper parameter JSON file : " << json_filename << std::endl; 80 | return false; 81 | } 82 | 83 | nlohmann::json j; 84 | is >> j; 85 | 86 | if (j.count("preemphasis")) { 87 | auto param = j["preemphasis"]; 88 | if (param.is_number()) { 89 | hparams->preemphasis = float(param.get()); 90 | } 91 | } 92 | 93 | return true; 94 | } 95 | 96 | void PrintHyperParameters(const HyperParameters &hparams) 97 | { 98 | std::cout << "Hyper parameter and configurations :\n"; 99 | std::cout << " preemphasis : " << hparams.preemphasis << "\n"; 100 | } 101 | 102 | static uint16_t ftous(const float x) 103 | { 104 | int f = int(x); 105 | return std::max(uint16_t(0), std::min(std::numeric_limits::max(), uint16_t(f))); 106 | } 107 | 108 | bool SaveWav(const std::string &filename, const std::vector &samples, const int sample_rate) 109 | { 110 | // We want to save audio with 32bit float format without loosing precision, 111 | // but librosa only supports PCM audio, so save audio data as 16bit PCM. 112 | 113 | drwav_data_format format; 114 | format.container = drwav_container_riff; // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64. 115 | format.format = DR_WAVE_FORMAT_PCM; 116 | format.channels = 1; 117 | format.sampleRate = sample_rate; 118 | format.bitsPerSample = 16; 119 | drwav* pWav = drwav_open_file_write(filename.c_str(), &format); 120 | 121 | std::vector data; 122 | data.resize(samples.size()); 123 | 124 | float max_value = std::fabs(samples[0]); 125 | for (size_t i = 0; i < samples.size(); i++) { 126 | max_value = std::max(max_value, std::fabs(samples[i])); 127 | } 128 | 129 | std::cout << "max value = " << max_value << "\n"; 130 | 131 | float factor = 32767.0f / std::max(0.01f, max_value); 132 | 133 | // normalize & 16bit quantize. 
134 | for (size_t i = 0; i < samples.size(); i++) { 135 | data[i] = ftous(factor * samples[i]); 136 | } 137 | 138 | drwav_uint64 n = static_cast(samples.size()); 139 | drwav_uint64 samples_written = drwav_write(pWav, n, data.data()); 140 | 141 | drwav_close(pWav); 142 | 143 | if (samples_written > 0) return true; 144 | 145 | return false; 146 | } 147 | 148 | 149 | int main(int argc, char **argv) { 150 | cxxopts::Options options("tts", "Tacotron text to speec in C++"); 151 | options.add_options()("i,input", "Input sequence file(JSON)", 152 | cxxopts::value())( 153 | "g,graph", "Input freezed graph file", cxxopts::value()) 154 | ("h,hparams", "Hyper parameters(JSON)", cxxopts::value()); 155 | ("o,output", "Output WAV filename", cxxopts::value()); 156 | 157 | auto result = options.parse(argc, argv); 158 | 159 | if (!result.count("input")) { 160 | std::cerr << "Please specify input sequence file with -i or --input option." 161 | << std::endl; 162 | return EXIT_FAILURE; 163 | } 164 | 165 | if (!result.count("graph")) { 166 | std::cerr << "Please specify freezed graph with -g or --graph option." 
167 | << std::endl; 168 | return EXIT_FAILURE; 169 | } 170 | 171 | HyperParameters hparams; 172 | 173 | if (result.count("hparams")) { 174 | if (!ParseHyperPrameters(result["hparams"].as(), &hparams)) { 175 | return EXIT_FAILURE; 176 | } 177 | } 178 | 179 | std::string input_filename = result["input"].as(); 180 | std::string graph_filename = result["graph"].as(); 181 | std::string output_filename = "output.wav"; 182 | 183 | if (result.count("output")) { 184 | output_filename = result["output"].as(); 185 | } 186 | 187 | std::vector sequence; 188 | if (!LoadSequence(input_filename, &sequence)) { 189 | std::cerr << "Failed to load sequence data : " << input_filename << std::endl; 190 | return EXIT_FAILURE; 191 | } 192 | 193 | std::cout << "sequence = ["; 194 | for (size_t i = 0; i < sequence.size(); i++) { 195 | std::cout << sequence[i]; 196 | if (i != (sequence.size() - 1)) { 197 | std::cout << ", "; 198 | } 199 | } 200 | std::cout << "]" << std::endl; 201 | 202 | // Synthesize(generate wav from sequence) 203 | tts::TensorflowSynthesizer tf_synthesizer; 204 | tf_synthesizer.init(argc, argv); 205 | if (!tf_synthesizer.load(graph_filename, "inputs", 206 | "model/griffinlim/Squeeze")) { 207 | std::cerr << "Failed to load/setup Tensorflow model from a frozen graph : " << graph_filename << std::endl; 208 | return EXIT_FAILURE; 209 | } 210 | 211 | PrintHyperParameters(hparams); 212 | 213 | std::cout << "Synthesize..." << std::endl; 214 | 215 | std::vector input_lengths; 216 | input_lengths.push_back(int(sequence.size())); 217 | 218 | std::vector wav0; 219 | 220 | if (!tf_synthesizer.synthesize(sequence, input_lengths, &wav0)) { 221 | std::cerr << "Failed to synthesize for a given sequence." << std::endl; 222 | return EXIT_FAILURE; 223 | } 224 | 225 | constexpr int32_t sample_rate = 20000; 226 | 227 | // Postprocess audio. 228 | // 1. Inverse preemphasis 229 | // 2. 
Remove silence 230 | std::vector output_wav = tts::inv_preemphasis(wav0.data(), wav0.size(), -hparams.preemphasis); 231 | size_t end_point = tts::find_end_point(output_wav.data(), output_wav.size(), sample_rate); 232 | 233 | std::cout << "Generated wav has " << output_wav.size() << "samples \n"; 234 | std::cout << "Truncated to " << end_point << " samples(by removing silence duration)\n"; 235 | 236 | output_wav.resize(end_point); 237 | 238 | if (!SaveWav(output_filename, output_wav, sample_rate)) { 239 | std::cerr << "Failed to save wav file :" << output_filename << std::endl; 240 | 241 | return EXIT_FAILURE; 242 | } 243 | 244 | return EXIT_SUCCESS; 245 | } 246 | -------------------------------------------------------------------------------- /src/tf_synthesizer.cc: -------------------------------------------------------------------------------- 1 | #include "tf_synthesizer.h" 2 | 3 | #ifdef __clang__ 4 | #pragma clang diagnostic push 5 | #pragma clang diagnostic ignored "-Weverything" 6 | #endif 7 | 8 | #include "tensorflow/cc/ops/array_ops.h" 9 | #include "tensorflow/cc/ops/const_op.h" 10 | #include "tensorflow/cc/ops/image_ops.h" 11 | #include "tensorflow/cc/ops/standard_ops.h" 12 | #include "tensorflow/core/framework/graph.pb.h" 13 | #include "tensorflow/core/framework/tensor.h" 14 | #include "tensorflow/core/graph/default_device.h" 15 | #include "tensorflow/core/graph/graph_def_builder.h" 16 | #include "tensorflow/core/lib/core/errors.h" 17 | #include "tensorflow/core/lib/core/stringpiece.h" 18 | #include "tensorflow/core/lib/core/threadpool.h" 19 | #include "tensorflow/core/lib/io/path.h" 20 | #include "tensorflow/core/lib/strings/stringprintf.h" 21 | #include "tensorflow/core/platform/env.h" 22 | #include "tensorflow/core/platform/init_main.h" 23 | #include "tensorflow/core/platform/logging.h" 24 | #include "tensorflow/core/platform/types.h" 25 | #include "tensorflow/core/public/session.h" 26 | #include "tensorflow/core/util/command_line_flags.h" 27 | 28 | 
#ifdef __clang__ 29 | #pragma clang diagnostic pop 30 | #endif 31 | 32 | #include 33 | 34 | using namespace tensorflow; 35 | using namespace tensorflow::ops; 36 | 37 | namespace tts { 38 | 39 | namespace { 40 | 41 | // Reads a model graph definition from disk, and creates a session object you 42 | // can use to run it. 43 | Status LoadGraph(const string& graph_file_name, 44 | std::unique_ptr* session) { 45 | tensorflow::GraphDef graph_def; 46 | Status load_graph_status = 47 | ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); 48 | if (!load_graph_status.ok()) { 49 | return tensorflow::errors::NotFound("Failed to load compute graph at '", 50 | graph_file_name, "'"); 51 | } 52 | session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); 53 | Status session_create_status = (*session)->Create(graph_def); 54 | if (!session_create_status.ok()) { 55 | return session_create_status; 56 | } 57 | return Status::OK(); 58 | } 59 | 60 | } // anonymous namespace 61 | 62 | class TensorflowSynthesizer::Impl { 63 | public: 64 | void init(int argc, char* argv[]) { 65 | // We need to call this to set up global state for TensorFlow. 66 | tensorflow::port::InitMain(argv[0], &argc, &argv); 67 | } 68 | 69 | bool load(const std::string& graph_filename, const std::string& inp_layer, 70 | const std::string& out_layer) { 71 | // First we load and initialize the model. 
72 | Status load_graph_status = LoadGraph(graph_filename, &session); 73 | if (!load_graph_status.ok()) { 74 | std::cerr << load_graph_status; 75 | return false; 76 | } 77 | 78 | input_layer = inp_layer; 79 | output_layer = out_layer; 80 | 81 | return true; 82 | } 83 | 84 | bool synthesize(const std::vector& input_sequence, const std::vector& input_lengths, std::vector *output) { 85 | 86 | // Batch size = 1 for a while 87 | int N = 1; 88 | 89 | 90 | int input_length = int(input_sequence.size()); 91 | Tensor input_tensor(DT_INT32, {N, input_length}); 92 | 93 | std::copy_n(input_sequence.data(), input_sequence.size(), 94 | input_tensor.flat().data()); 95 | 96 | Tensor input_lengths_tensor(DT_INT32, {N}); 97 | 98 | *(input_lengths_tensor.flat().data()) = input_length; 99 | 100 | auto startT = std::chrono::system_clock::now(); 101 | 102 | 103 | // Run 104 | std::vector output_tensors; 105 | Status run_status = session->Run({{input_layer, input_tensor}, {"input_lengths", input_lengths_tensor}}, 106 | {output_layer}, {}, &output_tensors); 107 | if (!run_status.ok()) { 108 | std::cerr << "Running model failed: " << run_status; 109 | return false; 110 | } 111 | 112 | auto endT = std::chrono::system_clock::now(); 113 | std::chrono::duration ms = endT - startT; 114 | 115 | std::cout << "Synth time : " << ms.count() << " [ms]" << std::endl; 116 | 117 | const Tensor& output_tensor = output_tensors[0]; 118 | std::cout << "output dim " << output_tensor.dims() << std::endl; 119 | 120 | TTypes::ConstTensor tensor = output_tensor.tensor(); 121 | std::cout << "len = " << tensor.dimension(0) << std::endl; 122 | 123 | assert(tensor.dimension(0) > 0); 124 | output->resize(tensor.dimension(0)); 125 | // TODO(LTE): Use memcpy 126 | for (size_t i = 0; i < output->size(); i++) { 127 | (*output)[i] = tensor(i); 128 | } 129 | 130 | return true; 131 | } 132 | 133 | private: 134 | std::unique_ptr session; 135 | std::string input_layer, output_layer; 136 | }; 137 | 138 | // PImpl pattern 139 | 
TensorflowSynthesizer::TensorflowSynthesizer() : impl(new Impl()) {} 140 | TensorflowSynthesizer::~TensorflowSynthesizer() {} 141 | void TensorflowSynthesizer::init(int argc, char* argv[]) { 142 | impl->init(argc, argv); 143 | } 144 | bool TensorflowSynthesizer::load(const std::string& graph_filename, 145 | const std::string& inp_layer, 146 | const std::string& out_layer) { 147 | return impl->load(graph_filename, inp_layer, out_layer); 148 | } 149 | 150 | bool TensorflowSynthesizer::synthesize(const std::vector &input_sequence, const std::vector &input_lengths, std::vector *output) { 151 | return impl->synthesize(input_sequence, input_lengths, output); 152 | } 153 | 154 | 155 | 156 | } // namespace tts 157 | -------------------------------------------------------------------------------- /src/tf_synthesizer.h: -------------------------------------------------------------------------------- 1 | #ifndef TF_SYNTHESIZER_H_ 2 | #define TF_SYNTHESIZER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tts { 9 | 10 | class TensorflowSynthesizer { 11 | public: 12 | TensorflowSynthesizer(); 13 | ~TensorflowSynthesizer(); 14 | 15 | void init(int argc, char* argv[]); 16 | 17 | /// 18 | /// Load's pretrained TF model. 19 | /// 20 | bool load(const std::string& graph_filename, const std::string& inp_layer, 21 | const std::string& out_layer); 22 | 23 | /// 24 | /// Synthesize speech. 25 | /// 26 | /// @param[in] input_sequence Input sequence. Shape = [N, T_in]. 27 | /// @param[in] input_lengths Tensor with shape = [N], where N is batch size 28 | /// and values are the lengths 29 | /// of each sequence in inputs. 
30 | /// @param[out] output Output audio data(floating point) 31 | /// 32 | bool synthesize(const std::vector& input_sequence, const std::vector &input_lengths, std::vector *output); 33 | 34 | private: 35 | class Impl; 36 | std::unique_ptr impl; 37 | }; 38 | 39 | } // namespace tts 40 | 41 | #endif // TF_SYNTHESIZER_H_ 42 | -------------------------------------------------------------------------------- /tacotron_frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syoyo/tacotron-tts-cpp/adb3d62913590aac3b8f1f1434c44c525dbbf8d3/tacotron_frozen.pb --------------------------------------------------------------------------------