├── .gitignore ├── .gitmodules ├── .travis.yml ├── CHANGELOG.md ├── CMakeLists.txt ├── LICENSE.md ├── README.md ├── cli ├── Batch.h ├── BatchReader.cc ├── BatchReader.h ├── BatchWriter.cc ├── BatchWriter.h ├── CMakeLists.txt └── translate.cc ├── cmake ├── FindEigen3.cmake └── FindMKL.cmake ├── include └── onmt │ ├── Dictionary.h │ ├── Eigen │ └── MatrixBatch.h │ ├── ITranslator.h │ ├── Logger.h │ ├── Model.h │ ├── Model.hxx │ ├── PhraseTable.h │ ├── Profiler.h │ ├── StorageLoader.h │ ├── SubDict.h │ ├── Threads.h │ ├── TranslationOptions.h │ ├── TranslationResult.h │ ├── Translator.h │ ├── Translator.hxx │ ├── TranslatorFactory.h │ ├── Utils.h │ ├── android_gnustl_compat.h │ ├── cuda │ ├── Kernels.cuh │ └── Utils.h │ ├── nn │ ├── CAddTable.h │ ├── CMulTable.h │ ├── ConcatTable.h │ ├── Container.h │ ├── Graph.h │ ├── Identity.h │ ├── JoinTable.h │ ├── Linear.h │ ├── LogSoftMax.h │ ├── LookupTable.h │ ├── MM.h │ ├── Module.h │ ├── ModuleFactory.h │ ├── ModuleFactory.hxx │ ├── MulConstant.h │ ├── Node.h │ ├── ParallelTable.h │ ├── Replicate.h │ ├── Reshape.h │ ├── SelectTable.h │ ├── Sequential.h │ ├── Sigmoid.h │ ├── SoftMax.h │ ├── SplitTable.h │ ├── Squeeze.h │ ├── Sum.h │ ├── Tanh.h │ ├── cuLinear.h │ └── qLinear.h │ ├── onmt.h │ ├── simd │ └── MatrixMult.h │ └── th │ ├── Env.h │ ├── Obj.h │ ├── Obj.hxx │ └── Utils.h ├── lib └── TH │ ├── CMakeLists.txt │ ├── COPYRIGHT.txt │ ├── THDiskFile.c │ ├── THDiskFile.h │ ├── THFile.c │ ├── THFile.h │ ├── THFilePrivate.h │ ├── THGeneral.c │ └── THGeneral.h └── src ├── Dictionary.cc ├── ITranslator.cc ├── Logger.cc ├── PhraseTable.cc ├── Profiler.cc ├── SubDict.cc ├── Threads.cc ├── TranslationOptions.cc ├── TranslationResult.cc ├── TranslatorFactory.cc ├── cuda ├── Kernels.cu └── Utils.cc ├── simd ├── AVX2_MatrixMult.cc ├── AVX512_MatrixMult.cc └── SSE_MatrixMult.cc └── th ├── Env.cc ├── Obj.cc └── Utils.cc /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lib/tokenizer"] 2 | path = lib/tokenizer 3 | url = https://github.com/OpenNMT/Tokenizer.git 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: cpp 3 | compiler: 4 | - gcc 5 | - clang 6 | addons: 7 | apt: 8 | packages: 9 | - cmake 10 | - libboost-program-options-dev 11 | before_install: 12 | - export ROOT_TRAVIS_DIR=$(pwd) 13 | - wget --no-check-certificate https://bitbucket.org/eigen/eigen/get/3.3.2.tar.bz2 14 | - tar xf 3.3.2.tar.bz2 15 | - mkdir -p ~/lib/ 16 | - mkdir -p eigen-eigen-da9b4e14c255/build && cd eigen-eigen-da9b4e14c255/build 17 | - cmake -DCMAKE_INSTALL_PREFIX=~/lib/ .. 18 | - make install 19 | - cd ../../ && rm -rf eigen-eigen-da9b4e14c255 20 | - cd $ROOT_TRAVIS_DIR 21 | script: 22 | - git submodule update --init 23 | - mkdir build && cd build 24 | - cmake -DCMAKE_INSTALL_PREFIX=~/lib/ .. 
25 | - make 26 | - make install 27 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## [Unreleased] 2 | 3 | ### Breaking changes 4 | 5 | * Translation results now have an additional dimension covering the multiple hypotheses for each batch 6 | 7 | ### New features 8 | 9 | * Add n-best feature 10 | 11 | ### Fixes and improvements 12 | 13 | ## [v0.6.10](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.10) (2018-09-10) 14 | 15 | ### Fixes and improvements 16 | 17 | * Allow linking to external OpenNMT/Tokenizer 18 | 19 | ## [v0.6.9](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.9) (2018-09-07) 20 | 21 | ### Fixes and improvements 22 | 23 | * Update Tokenizer to v1.8.1 24 | 25 | ## [v0.6.8](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.8) (2018-09-07) 26 | 27 | ### Fixes and improvements 28 | 29 | * Update Tokenizer to v1.8.0 30 | 31 | ## [v0.6.7](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.7) (2018-09-04) 32 | 33 | ### Fixes and improvements 34 | 35 | * Update Tokenizer to v1.7.0 36 | 37 | ## [v0.6.6](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.6) (2018-08-29) 38 | 39 | ### Fixes and improvements 40 | 41 | * Update Tokenizer to v1.6.2 42 | 43 | ## [v0.6.5](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.5) (2018-07-30) 44 | 45 | ### Fixes and improvements 46 | 47 | * Update Tokenizer to v1.6.0 48 | 49 | ## [v0.6.4](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.4) (2018-07-13) 50 | 51 | ### Fixes and improvements 52 | 53 | * Update Tokenizer to v1.5.3 54 | 55 | ## [v0.6.3](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.3) (2018-07-12) 56 | 57 | ### Fixes and improvements 58 | 59 | * Update Tokenizer to v1.5.2 60 | 61 | ## [v0.6.2](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.2) (2018-07-12) 62 | 63 | ### Fixes and improvements 64 | 65 | * Update 
Tokenizer to v1.5.1 66 | 67 | ## [v0.6.1](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.1) (2018-07-05) 68 | 69 | ### Fixes and improvements 70 | 71 | * Update Tokenizer to v1.5.0 72 | 73 | ## [v0.6.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.6.0) (2018-06-14) 74 | 75 | ### New features 76 | 77 | * Add 16-bit quantization with SSE and AVX2 optimizations 78 | * Introduce vocabulary mapping to speed up decoding 79 | * Report profiles per module block (encoder_fwd, encoder_bwd, decoder, generator) 80 | * Support cloning a `Translator` while sharing the model data 81 | * Translate batches in parallel via `cli/translate` 82 | 83 | ### Fixes and improvements 84 | 85 | * Fix GRU support 86 | * Update Tokenizer to v1.4.0 87 | 88 | ## [v0.5.4](https://github.com/OpenNMT/CTranslate/releases/tag/v0.5.4) (2018-04-10) 89 | 90 | ### Fixes and improvements 91 | 92 | * Update Tokenizer to v1.3.0 93 | 94 | ## [v0.5.3](https://github.com/OpenNMT/CTranslate/releases/tag/v0.5.3) (2018-03-28) 95 | 96 | ### Fixes and improvements 97 | 98 | * Update Tokenizer to v1.2.0 99 | 100 | ## [v0.5.2](https://github.com/OpenNMT/CTranslate/releases/tag/v0.5.2) (2018-01-23) 101 | 102 | ### Fixes and improvements 103 | 104 | * Update Tokenizer to v1.1.1 105 | 106 | ## [v0.5.1](https://github.com/OpenNMT/CTranslate/releases/tag/v0.5.1) (2018-01-22) 107 | 108 | ### Fixes and improvements 109 | 110 | * Update Tokenizer to v1.1.0 111 | 112 | ## [v0.5.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.5.0) (2017-12-11) 113 | 114 | ### New features 115 | 116 | * Add module profiling 117 | * Link against Intel® MKL if available 118 | * [*experimental*] Offload matrix multiplication to the GPU 119 | 120 | ### Fixes and improvements 121 | 122 | * Improve Eigen library finder logic 123 | 124 | ## [v0.4.1](https://github.com/OpenNMT/CTranslate/releases/tag/v0.4.1) (2017-03-08) 125 | 126 | ### Fixes and improvements 127 | 128 | * Fix install rule for TH dependency 129 | 130 | ## 
[v0.4.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.4.0) (2017-03-08) 131 | 132 | ### New features 133 | 134 | * Add CMake install rule 135 | * Add static library compilation support 136 | 137 | ### Fixes and improvements 138 | 139 | * Tokenization is now an external library 140 | 141 | ## [v0.3.2](https://github.com/OpenNMT/CTranslate/releases/tag/v0.3.2) (2017-02-08) 142 | 143 | ### Fixes and improvements 144 | 145 | * Fix error when decoded sequences reached `max_sent_length` 146 | * Fix incorrect extraction of word features 147 | 148 | ## [v0.3.1](https://github.com/OpenNMT/CTranslate/releases/tag/v0.3.1) (2017-01-29) 149 | 150 | ### Fixes and improvements 151 | 152 | * Fix `--joiner_new` option when using BPE 153 | * Fix segmentation fault when a translator is destroyed and other instances are in use 154 | 155 | ## [v0.3.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.3.0) (2017-01-26) 156 | 157 | ### New features 158 | 159 | * Tokenization and detokenization 160 | 161 | ### Fixes and improvements 162 | 163 | * Fix errors when using models with word features 164 | * Remove Boost dependency when compiling as a library 165 | 166 | ## [v0.2.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.2.0) (2017-01-19) 167 | 168 | ### New features 169 | 170 | * Programmatically set the number of threads to use 171 | 172 | ### Fixes and improvements 173 | 174 | * Simplify project include with a single public header `onmt/onmt.h` 175 | * Fix compilation on Mac OS 176 | 177 | ## [v0.1.0](https://github.com/OpenNMT/CTranslate/releases/tag/v0.1.0) (2017-01-11) 178 | 179 | Initial release. 
180 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1.0) 2 | 3 | # Set policy for setting the MSVC runtime library for static MSVC builds 4 | if(POLICY CMP0091) 5 | cmake_policy(SET CMP0091 NEW) 6 | endif() 7 | 8 | project(onmt) 9 | 10 | option(LIB_ONLY "Do not compile clients" OFF) 11 | option(WITH_OPENMP "Use OpenMP if available" ON) 12 | option(WITH_CUDA "Use CUDA if available" ON) 13 | option(WITH_MKL "Use Intel® MKL if available" ON) 14 | option(WITH_QLINEAR "Add Quantized linear module, value can be OFF, SSE, AVX2 or AVX512" OFF) 15 | option(WITH_BOOST_LOG "Use Boost log if available" ON) 16 | option(BUILD_SHARED_LIBS "Build shared libraries" ON) 17 | 18 | set(CMAKE_CXX_STANDARD 11) 19 | 20 | # Set Release build type by default to get sane performance. 21 | if(NOT CMAKE_BUILD_TYPE) 22 | set(CMAKE_BUILD_TYPE Release) 23 | endif(NOT CMAKE_BUILD_TYPE) 24 | 25 | message(STATUS "Build type: " ${CMAKE_BUILD_TYPE}) 26 | # OpenMP is optional: when found, its flags are appended to CMAKE_CXX_FLAGS. 27 | if(WITH_OPENMP) 28 | find_package(OpenMP) 29 | if(OPENMP_FOUND) 30 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 31 | else() 32 | message(WARNING "OpenMP not found: Compilation will not use OpenMP") 33 | endif() 34 | endif() 35 | # Android builds are forced to library-only (clients are not compiled). 36 | if(ANDROID) 37 | set(LIB_ONLY ON) 38 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") 39 | if(ANDROID_STL MATCHES "gnustl") 40 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DANDROID_GNUSTL_COMPAT") 41 | endif() 42 | endif() 43 | # Static MSVC builds select the static C runtime through CMAKE_MSVC_RUNTIME_LIBRARY, which needs CMake 3.15+. 44 | if(MSVC) 45 | if(NOT BUILD_SHARED_LIBS) 46 | if(CMAKE_VERSION VERSION_LESS "3.15.0") 47 | message(FATAL_ERROR "Use CMake 3.15 or later when setting BUILD_SHARED_LIBS to OFF") 48 | endif() 49 | set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") 50 | endif() 51 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") 52 | else() 53 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") 54 | 
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pthread") 55 | endif() 56 | # Bundled TH library (lib/TH). 57 | add_subdirectory(lib/TH) 58 | 59 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/") 60 | # EIGEN_ROOT and EIGEN3_ROOT are both accepted as hints and exported as the EIGEN3_ROOT environment variable (presumably read by cmake/FindEigen3.cmake). 61 | if (EIGEN_ROOT) 62 | set(ENV{EIGEN3_ROOT} ${EIGEN_ROOT}) 63 | endif() 64 | if (EIGEN3_ROOT) 65 | set(ENV{EIGEN3_ROOT} ${EIGEN3_ROOT}) 66 | endif() 67 | 68 | find_package(Eigen3 3.3 REQUIRED) 69 | if(WITH_CUDA) 70 | find_package(CUDA 6.5) 71 | endif() 72 | if(WITH_MKL) 73 | find_package(MKL) 74 | endif() 75 | if(WITH_BOOST_LOG) 76 | find_package(Boost COMPONENTS log) 77 | endif() 78 | # Base include directories, sources and link libraries; extended below depending on the optional dependencies that were found. 79 | set(INCLUDE_DIRECTORIES 80 | ${CMAKE_CURRENT_SOURCE_DIR}/lib 81 | ${CMAKE_CURRENT_SOURCE_DIR}/include 82 | ${EIGEN3_INCLUDE_DIR} 83 | ${PROJECT_BINARY_DIR} 84 | ) 85 | set(SOURCES 86 | src/th/Env.cc 87 | src/th/Obj.cc 88 | src/th/Utils.cc 89 | src/Dictionary.cc 90 | src/SubDict.cc 91 | src/PhraseTable.cc 92 | src/Profiler.cc 93 | src/ITranslator.cc 94 | src/TranslatorFactory.cc 95 | src/TranslationOptions.cc 96 | src/TranslationResult.cc 97 | src/Threads.cc 98 | ) 99 | set(LIBRARIES 100 | OpenNMTTokenizer 101 | TH) 102 | # Use an installed OpenNMT Tokenizer when both its library and headers are found; otherwise build the lib/tokenizer submodule. 103 | find_library(ONMT_TOKENIZER_LIBRARY NAMES OpenNMTTokenizer) 104 | find_path(ONMT_TOKENIZER_INCLUDE_DIR NAMES onmt/Tokenizer.h) 105 | 106 | if(NOT ONMT_TOKENIZER_LIBRARY OR NOT ONMT_TOKENIZER_INCLUDE_DIR) 107 | message(STATUS "Using OpenNMT tokenizer submodule") 108 | add_subdirectory(lib/tokenizer) 109 | list(APPEND INCLUDE_DIRECTORIES ${CMAKE_CURRENT_SOURCE_DIR}/lib/tokenizer/include) 110 | else() 111 | message(STATUS "Found OpenNMT tokenizer: ${ONMT_TOKENIZER_LIBRARY}") 112 | list(APPEND INCLUDE_DIRECTORIES ${ONMT_TOKENIZER_INCLUDE_DIR}) 113 | get_filename_component(ONMT_TOKENIZER_LIBRARY_DIR ${ONMT_TOKENIZER_LIBRARY} DIRECTORY) 114 | link_directories(${ONMT_TOKENIZER_LIBRARY_DIR}) 115 | endif() 116 | # WITH_QLINEAR selects one SIMD implementation of the quantized linear module; any enabled value other than SSE, AVX2 or AVX512 is a configuration error. 117 | if (WITH_QLINEAR) 118 | IF(WITH_QLINEAR STREQUAL "AVX2") 119 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWITH_QLINEAR -DSIMD_AVX2") 120 | LIST(APPEND 
SOURCES 121 | src/simd/AVX2_MatrixMult.cc 122 | ) 123 | MESSAGE(STATUS "Using AVX2 instruction sets") 124 | ELSEIF(WITH_QLINEAR STREQUAL "SSE") 125 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWITH_QLINEAR -DSIMD_SSE") 126 | LIST(APPEND SOURCES 127 | src/simd/SSE_MatrixMult.cc 128 | ) 129 | MESSAGE(STATUS "Using SSE instruction sets") 130 | ELSEIF(WITH_QLINEAR STREQUAL "AVX512") 131 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWITH_QLINEAR -DSIMD_AVX512") 132 | LIST(APPEND SOURCES 133 | src/simd/AVX512_MatrixMult.cc 134 | ) 135 | MESSAGE(STATUS "Using AVX512 instruction sets") 136 | ELSE(WITH_QLINEAR STREQUAL "AVX2") 137 | MESSAGE(FATAL_ERROR "Incorrect value of WITH_QLINEAR option") 138 | ENDIF(WITH_QLINEAR STREQUAL "AVX2") 139 | endif() 140 | # Optional backends: MKL (also used by Eigen through EIGEN_USE_MKL_ALL) and Boost.Log (adds src/Logger.cc). 141 | if (MKL_FOUND) 142 | add_definitions(-DWITH_MKL) 143 | add_definitions(-DEIGEN_USE_MKL_ALL) 144 | list(APPEND INCLUDE_DIRECTORIES ${MKL_INCLUDE_DIR}) 145 | list(APPEND LIBRARIES ${MKL_LIBRARIES}) 146 | endif() 147 | 148 | if (Boost_LOG_FOUND) 149 | add_definitions(-DWITH_BOOST_LOG) 150 | list(APPEND INCLUDE_DIRECTORIES ${Boost_INCLUDE_DIRS}) 151 | list(APPEND LIBRARIES ${Boost_LIBRARIES}) 152 | list(APPEND SOURCES src/Logger.cc) 153 | if (NOT Boost_USE_STATIC_LIBS) 154 | add_definitions(-DBOOST_ALL_DYN_LINK) 155 | if (WIN32 AND NOT CYGWIN) 156 | add_definitions(-DBOOST_ALL_NO_LIB) 157 | endif() 158 | endif() 159 | endif() 160 | # With CUDA, the library is built through FindCUDA so the .cu kernels are compiled and cuBLAS is linked. 161 | if (CUDA_FOUND) 162 | add_definitions(-DWITH_CUDA) 163 | if(MSVC) 164 | if(BUILD_SHARED_LIBS) 165 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/MD$<$:d>") 166 | else() 167 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler=/MT$<$:d>") 168 | endif() 169 | endif() 170 | cuda_include_directories(${INCLUDE_DIRECTORIES}) 171 | cuda_add_library(${PROJECT_NAME} 172 | ${SOURCES} 173 | src/cuda/Utils.cc 174 | src/cuda/Kernels.cu 175 | ) 176 | cuda_add_cublas_to_target(${PROJECT_NAME}) 177 | else() 178 | add_library(${PROJECT_NAME} ${SOURCES}) 179 | endif() 180 | # Generate the export-macro header onmt/onmt_export.h in the build tree, then link libraries and publish include directories. 
181 | include(GNUInstallDirs) 182 | include(GenerateExportHeader) 183 | string(TOLOWER ${PROJECT_NAME} PROJECT_NAME_LOWER) 184 | generate_export_header(${PROJECT_NAME} EXPORT_FILE_NAME ${PROJECT_BINARY_DIR}/onmt/${PROJECT_NAME_LOWER}_export.h) 185 | target_link_libraries(${PROJECT_NAME} ${LIBRARIES}) 186 | target_include_directories(${PROJECT_NAME} PUBLIC ${INCLUDE_DIRECTORIES}) 187 | 188 | if (MSVC AND NOT (MSVC_VERSION LESS 1910)) 189 | if (Boost_LOG_FOUND) 190 | target_compile_options(${PROJECT_NAME} PUBLIC /permissive-) #to avoid C4596 errors in VS2017+ 191 | endif() 192 | if (OPENMP_FOUND) 193 | target_compile_options(${PROJECT_NAME} PUBLIC /Zc:twoPhase-) #to avoid "C2338: two-phase name lookup is not supported for C++/CLI, C++/CX, or OpenMP; use /Zc:twoPhase-" errors in VS2017+ 194 | endif() 195 | endif() 196 | # The command-line clients live in cli/ and are skipped when LIB_ONLY is set. 197 | if (NOT LIB_ONLY) 198 | add_subdirectory(cli) 199 | endif() 200 | # Install the library, the public headers from include/onmt, and the generated export header. 201 | install( 202 | TARGETS ${PROJECT_NAME} 203 | RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} 204 | ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} 205 | LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} 206 | INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} 207 | ) 208 | install( 209 | DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/include/onmt" 210 | DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" 211 | FILES_MATCHING PATTERN "*.h*" PATTERN "*.cuh" 212 | ) 213 | install( 214 | FILES "${PROJECT_BINARY_DIR}/onmt/${PROJECT_NAME_LOWER}_export.h" 215 | DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/onmt" 216 | ) 217 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **This project is considered obsolete as the Torch framework is no longer maintained. For compatibility with OpenNMT-tf or OpenNMT-py, please check out [CTranslate2](https://github.com/OpenNMT/CTranslate2).** 2 | 3 | [![Build Status](https://api.travis-ci.org/OpenNMT/CTranslate.svg?branch=master)](https://travis-ci.org/OpenNMT/CTranslate) 4 | 5 | # CTranslate 6 | 7 | CTranslate is a C++ implementation of OpenNMT's `translate.lua` script with no LuaTorch dependencies. It facilitates the use of OpenNMT models in existing products and on various platforms using [Eigen](http://eigen.tuxfamily.org) as a backend. 8 | 9 | CTranslate provides optimized CPU translation and optionally offloads matrix multiplication to a CUDA-compatible device using [cuBLAS](http://docs.nvidia.com/cuda/cublas/). It only supports OpenNMT models released with the [`release_model.lua`](https://github.com/OpenNMT/OpenNMT/tree/master/tools#release-model) script.
10 | 11 | ## Dependencies 12 | 13 | * [Eigen](http://eigen.tuxfamily.org/index.php?title=Main_Page) >= 3.3 14 | * [Boost](http://www.boost.org/) (`log`, when `-DWITH_BOOST_LOG=ON`; `program_options`, when `-DLIB_ONLY=OFF`) 15 | 16 | ### Optional 17 | 18 | * [CUDA](https://developer.nvidia.com/cuda-toolkit) for offloading matrix multiplication to a GPU 19 | * [Intel® MKL](https://software.intel.com/en-us/intel-mkl) for an alternative BLAS backend 20 | 21 | ## Compiling 22 | 23 | *CMake and a compiler that supports the C++11 standard are required to compile the project.* 24 | 25 | ``` 26 | git submodule update --init 27 | mkdir build 28 | cd build 29 | cmake .. 30 | make 31 | ``` 32 | 33 | It will produce the dynamic library `libonmt.so` (or `.dylib` on Mac OS, `.dll` on Windows) and the translation client `cli/translate`. 34 | 35 | CTranslate also bundles OpenNMT's [Tokenizer](https://github.com/OpenNMT/Tokenizer) which provides the tokenization tools `lib/tokenizer/cli/tokenize` and `lib/tokenizer/cli/detokenize`. 36 | 37 | ### Options 38 | 39 | * To give a hint about the Eigen location, use the `-DEIGEN3_ROOT=<path>` option. 40 | * To compile only the library, use the `-DLIB_ONLY=ON` flag. 41 | * To disable [OpenMP](http://www.openmp.org), use the `-DWITH_OPENMP=OFF` flag. 
42 | * To enable optimization through quantization in matrix multiplications, use the `-DWITH_QLINEAR=SSE|AVX2|AVX512` flag (`OFF` by default) and enable the matching extended instruction set via `-DCMAKE_CXX_FLAGS`: 43 | * `-DWITH_QLINEAR=AVX2` requires at least `-mavx2` 44 | * `-DWITH_QLINEAR=SSE` requires at least `-mssse3` 45 | 46 | ### Performance tips 47 | 48 | * Use extended instruction sets: 49 | * if you are not cross-compiling, add `-DCMAKE_CXX_FLAGS="-march=native"` to the `cmake` command above to optimize for speed; 50 | * otherwise, select recent [SIMD extensions](https://gcc.gnu.org/onlinedocs/gcc-5.5.0/gcc/x86-Options.html#x86-Options) to improve performance while meeting portability requirements. 51 | * Consider installing [Intel® MKL](https://software.intel.com/en-us/intel-mkl) when you are targeting Intel®-powered platforms. If found, the project will automatically link against it. 52 | * Consider using quantization options as described above. 53 | * When using `cli/translate`, consider fine-tuning the level of parallelism: 54 | * the `--parallel` option enables concurrent translation of `--batch_size` sentences 55 | * the `--threads` option enables each translation to use multiple threads 56 | * *Bottom-line:* if you want optimal throughput for a collection of sentences, increase `--parallel` and set `--threads` to 1; if you want minimal latency for a single batch, set `--parallel` to 1, and increase `--threads`. 57 | 58 | ## Using 59 | 60 | ### Clients 61 | 62 | See `--help` on the clients to discover available options and usage. They have the same interface as their Lua counterparts. 63 | 64 | ### Library 65 | 66 | This project is also a convenient way to load OpenNMT models and translate texts in existing software. 67 | 68 | Here is a very simple example: 69 | 70 | ```cpp 71 | #include 72 | 73 | #include 74 | 75 | int main() 76 | { 77 | // Create a new Translator object.
78 | auto translator = onmt::TranslatorFactory::build("enfr_model_release.t7"); 79 | 80 | // Translate a tokenized sentence. 81 | std::cout << translator->translate("Hello world !") << std::endl; 82 | 83 | return 0; 84 | } 85 | 86 | ``` 87 | 88 | For a more advanced usage, see: 89 | 90 | * `include/onmt/TranslatorFactory.h` to instantiate a new translator 91 | * `include/onmt/ITranslator.h` (the `Translator` interface) to translate sequences or batch of sequences 92 | * `include/onmt/TranslationResult.h` to retrieve results and attention vectors 93 | * `include/onmt/Threads.h` to programmatically control the number of threads to use 94 | 95 | Also see the headers available in the [Tokenizer](https://github.com/OpenNMT/Tokenizer) that are accessible when linking against CTranslate. 96 | 97 | ## Supported features 98 | 99 | CTranslate focuses on supporting model configurations that are likely to be used in production settings. It covers models trained with the default options, plus some variants: 100 | 101 | * additional input or output word features 102 | * `brnn` encoder (with `sum` or `concat` merge policy) 103 | * `dot` attention 104 | * residual connections 105 | * no input feeding 106 | 107 | Additionally, CTranslate misses some advanced features of `translate.lua`: 108 | 109 | * gold data score 110 | * hypotheses filtering 111 | * beam search normalization 112 | -------------------------------------------------------------------------------- /cli/Batch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | // A numbered group of input sentences plus, once translated, their results (outer index: sentence in the batch; inner index, where present: hypothesis, best first). 6 | class Batch 7 | { 8 | public: 9 | Batch() 10 | : _id(0) 11 | { 12 | } 13 | 14 | Batch(const std::vector& input, size_t id) 15 | : _input(input) 16 | , _id(id) 17 | { 18 | } 19 | // Stores the translation outputs; every argument is copied into the batch. 20 | void set_result(const std::vector >& translations, 21 | const std::vector >& score, 22 | const std::vector >& count_tgt_words, 23 | const std::vector >& count_tgt_unk_words, 24 | const 
std::vector& count_src_words, 25 | const std::vector& count_src_unk_words) 26 | { 27 | _translations = translations; 28 | _score = score; 29 | _count_tgt_words = count_tgt_words; 30 | _count_tgt_unk_words = count_tgt_unk_words; 31 | _count_src_words = count_src_words; 32 | _count_src_unk_words = count_src_unk_words; 33 | } 34 | 35 | size_t size() const 36 | { 37 | return _input.size(); 38 | } 39 | 40 | bool empty() const 41 | { 42 | return _input.empty(); 43 | } 44 | 45 | const std::vector& get_input() const 46 | { 47 | return _input; 48 | } 49 | 50 | const std::vector >& get_translations() const 51 | { 52 | return _translations; 53 | } 54 | 55 | const std::vector >& get_score() const 56 | { 57 | return _score; 58 | } 59 | 60 | const std::vector >& get_count_tgt_words() const 61 | { 62 | return _count_tgt_words; 63 | } 64 | 65 | const std::vector >& get_count_tgt_unk_words() const 66 | { 67 | return _count_tgt_unk_words; 68 | } 69 | 70 | const std::vector& get_count_src_words() const 71 | { 72 | return _count_src_words; 73 | } 74 | 75 | const std::vector& get_count_src_unk_words() const 76 | { 77 | return _count_src_unk_words; 78 | } 79 | // Sequence number of the batch (BatchReader numbers batches from 1); 0 for a default-constructed Batch. 80 | size_t get_id() const 81 | { 82 | return _id; 83 | } 84 | 85 | private: 86 | std::vector _input; 87 | std::vector > _translations; 88 | std::vector > _score; 89 | std::vector _count_src_words; 90 | std::vector _count_src_unk_words; 91 | std::vector > _count_tgt_words; 92 | std::vector > _count_tgt_unk_words; 93 | size_t _id; 94 | }; 95 | -------------------------------------------------------------------------------- /cli/BatchReader.cc: -------------------------------------------------------------------------------- 1 | #include "BatchReader.h" 2 | #include 3 | // Reads from `file`; if it cannot be opened, an error is logged via ONMT_LOG_STREAM_SEV. 4 | BatchReader::BatchReader(const std::string& file, size_t batch_size) 5 | : _file(file.c_str()) 6 | , _in(_file) 7 | , _batch_size(batch_size) 8 | , _batch_id(0) 9 | , _read_lines(0) 10 | { 11 | if (!_file.is_open()) 12 | ONMT_LOG_STREAM_SEV("cannot open '" << file
<< '\'', boost::log::trivial::error); 13 | } 14 | // Reads from a caller-owned stream (e.g. std::cin); it is stored by reference, so it must outlive the reader. 15 | BatchReader::BatchReader(std::istream& in, size_t batch_size) 16 | : _in(in) 17 | , _batch_size(batch_size) 18 | , _batch_id(0) 19 | , _read_lines(0) 20 | { 21 | } 22 | // Returns the next batch of at most _batch_size lines; concurrent callers are serialized by _reader_mutex. Every call consumes a new id, and the batch is empty once the input is exhausted. 23 | Batch BatchReader::read_next() 24 | { 25 | std::lock_guard lock(_reader_mutex); 26 | std::vector batch; 27 | 28 | if (_in.eof()) 29 | return Batch(batch, ++_batch_id); 30 | 31 | batch.reserve(_batch_size); 32 | 33 | std::string line; 34 | while (batch.size() < _batch_size && std::getline(_in, line)) 35 | batch.push_back(line); 36 | 37 | _read_lines += batch.size(); 38 | 39 | return Batch(batch, ++_batch_id); 40 | } 41 | -------------------------------------------------------------------------------- /cli/BatchReader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "Batch.h" 7 | // Splits line-based input into numbered Batch objects. read_next() locks _reader_mutex; read_lines() reads the counter without locking. 8 | class BatchReader 9 | { 10 | public: 11 | BatchReader(const std::string& file, size_t batch_size); 12 | BatchReader(std::istream& in, size_t batch_size); 13 | 14 | Batch read_next(); 15 | // Number of lines returned by read_next() so far. 16 | size_t read_lines() const 17 | { 18 | return _read_lines; 19 | } 20 | 21 | private: 22 | std::ifstream _file; 23 | std::istream& _in; 24 | size_t _batch_size; 25 | size_t _batch_id; 26 | size_t _read_lines; 27 | std::mutex _reader_mutex; 28 | }; 29 | -------------------------------------------------------------------------------- /cli/BatchWriter.cc: -------------------------------------------------------------------------------- 1 | #include "BatchWriter.h" 2 | #include 3 | #include 4 | #include 5 | // Writes to `file`; if it cannot be opened, an error is logged via ONMT_LOG_STREAM_SEV. 6 | BatchWriter::BatchWriter(const std::string& file) 7 | : _file(file.c_str()) 8 | , _out(_file) 9 | , _sentence_no(0) 10 | , _last_batch_id(0) 11 | , _total_count_src_words(0) 12 | , _total_count_src_unk_words(0) 13 | , _total_count_tgt_words(0) 14 | , _total_count_tgt_unk_words(0) 15 | , _total_score(0) 16 | { 17 | if (!_file.is_open()) 18 | ONMT_LOG_STREAM_SEV("cannot open '"
<< file << '\'', boost::log::trivial::error); 19 | } 20 | // Writes to a caller-owned stream (e.g. std::cout); it is stored by reference, so it must outlive the writer. 21 | BatchWriter::BatchWriter(std::ostream& out) 22 | : _out(out) 23 | , _sentence_no(0) 24 | , _last_batch_id(0) 25 | , _total_count_src_words(0) 26 | , _total_count_src_unk_words(0) 27 | , _total_count_tgt_words(0) 28 | , _total_count_tgt_unk_words(0) 29 | , _total_score(0) 30 | { 31 | } 32 | // Writes batches strictly in id order under _writer_mutex. A batch that arrives early is copied into _pending_batches; each in-order write then drains any parked successors. 33 | void BatchWriter::write(const Batch& batch) 34 | { 35 | std::lock_guard lock(_writer_mutex); 36 | size_t batch_id = batch.get_id(); 37 | if (batch_id == _last_batch_id + 1) 38 | { 39 | auto input = batch.get_input(); 40 | auto translations = batch.get_translations(); 41 | auto score = batch.get_score(); 42 | auto count_tgt_words = batch.get_count_tgt_words(); 43 | auto count_tgt_unk_words = batch.get_count_tgt_unk_words(); 44 | auto count_src_words = batch.get_count_src_words(); 45 | auto count_src_unk_words = batch.get_count_src_unk_words(); 46 | // Local copies of the results, reassigned from a parked batch when draining below. 47 | while (batch_id == _last_batch_id + 1) 48 | { 49 | for (size_t b = 0; b < input.size(); ++b) 50 | { 51 | _total_count_src_words += count_src_words[b]; 52 | _total_count_src_unk_words += count_src_unk_words[b]; 53 | ONMT_LOG_STREAM_SEV("SENT " << ++_sentence_no << ": " << input[b], boost::log::trivial::info); 54 | if (translations[b].size() > 1) 55 | { 56 | ONMT_LOG_STREAM_SEV("", boost::log::trivial::info); 57 | ONMT_LOG_STREAM_SEV("BEST HYP:", boost::log::trivial::info); 58 | } 59 | // One output line per hypothesis; when there are several, each is also logged with its score. 60 | for (size_t n = 0; n < translations[b].size(); ++n) 61 | { 62 | _out << translations[b][n] << std::endl; 63 | if (translations[b].size() > 1) 64 | { 65 | ONMT_LOG_STREAM_SEV('[' << std::fixed << std::setprecision(2) << score[b][n] << "] " << translations[b][n], boost::log::trivial::info); 66 | } 67 | else 68 | { 69 | ONMT_LOG_STREAM_SEV("PRED " << _sentence_no << ": " << translations[b][n], boost::log::trivial::info); 70 | ONMT_LOG_STREAM_SEV("PRED SCORE: " << std::fixed << std::setprecision(2) << score[b][n], boost::log::trivial::info); 71 | } 72 | 73 | // count target unknown
words and words generated on 1-best 74 | if (n == 0) 75 | { 76 | _total_count_tgt_words += count_tgt_words[b][n]; 77 | _total_count_tgt_unk_words += count_tgt_unk_words[b][n]; 78 | _total_score += score[b][n]; 79 | } 80 | } 81 | // Blank log line after each sentence. 82 | ONMT_LOG_STREAM_SEV("", boost::log::trivial::info); 83 | } 84 | // This batch is done; if the next id is already parked, load it so the loop continues, otherwise the loop ends. 85 | _last_batch_id = batch_id; 86 | auto it = _pending_batches.find(_last_batch_id + 1); 87 | if (it != _pending_batches.end()) 88 | { 89 | input = it->second.get_input(); 90 | translations = it->second.get_translations(); 91 | score = it->second.get_score(); 92 | count_tgt_words = it->second.get_count_tgt_words(); 93 | count_tgt_unk_words = it->second.get_count_tgt_unk_words(); 94 | count_src_words = it->second.get_count_src_words(); 95 | count_src_unk_words = it->second.get_count_src_unk_words(); 96 | batch_id = _last_batch_id + 1; 97 | _pending_batches.erase(it); 98 | } 99 | } 100 | } 101 | else 102 | { 103 | _pending_batches[batch_id] = batch; 104 | } 105 | } 106 | // Running totals. Target word counts and the score cover only the first (1-best) hypothesis; these getters do not lock _writer_mutex. 107 | size_t BatchWriter::total_count_src_words() const 108 | { 109 | return _total_count_src_words; 110 | } 111 | 112 | size_t BatchWriter::total_count_src_unk_words() const 113 | { 114 | return _total_count_src_unk_words; 115 | } 116 | 117 | size_t BatchWriter::total_count_tgt_words() const 118 | { 119 | return _total_count_tgt_words; 120 | } 121 | 122 | size_t BatchWriter::total_count_tgt_unk_words() const 123 | { 124 | return _total_count_tgt_unk_words; 125 | } 126 | 127 | float BatchWriter::total_score() const 128 | { 129 | return _total_score; 130 | } 131 | -------------------------------------------------------------------------------- /cli/BatchWriter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "Batch.h" 8 | // Writes translated batches in batch-id order and accumulates translation statistics. 9 | class BatchWriter 10 | { 11 | public: 12 | BatchWriter(const std::string& file); 13 | BatchWriter(std::ostream& out); 14 | 15 | void write(const Batch& batch); 16 | size_t 
total_count_src_words() const; 17 | size_t total_count_src_unk_words() const; 18 | size_t total_count_tgt_words() const; 19 | size_t total_count_tgt_unk_words() const; 20 | float total_score() const; 21 | 22 | private: 23 | std::ofstream _file; 24 | std::ostream& _out; 25 | size_t _sentence_no; 26 | std::map _pending_batches; 27 | size_t _last_batch_id; 28 | size_t _total_count_src_words; 29 | size_t _total_count_src_unk_words; 30 | size_t _total_count_tgt_words; 31 | size_t _total_count_tgt_unk_words; 32 | float _total_score; 33 | std::mutex _writer_mutex; 34 | }; 35 | -------------------------------------------------------------------------------- /cli/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(Boost REQUIRED COMPONENTS program_options) 2 | 3 | include_directories( 4 | ${Boost_INCLUDE_DIRS} 5 | ) 6 | 7 | if (WIN32 AND NOT CYGWIN AND NOT Boost_USE_STATIC_LIBS) 8 | add_definitions(-DBOOST_ALL_NO_LIB) #Boost: Tells the config system not to automatically select which libraries to link against 9 | add_definitions(-DBOOST_ALL_DYN_LINK) #Boost: Forces all libraries that have separate source, to be linked as dll's rather than static libraries on Microsoft Windows 10 | endif() 11 | # Command-line translation client, linked against the onmt library and Boost. 12 | add_executable(translate 13 | translate.cc 14 | BatchReader.cc 15 | BatchWriter.cc 16 | ) 17 | target_link_libraries(translate 18 | ${PROJECT_NAME} 19 | ${Boost_LIBRARIES} 20 | ) 21 | 22 | install( 23 | TARGETS translate 24 | DESTINATION bin/ 25 | ) 26 | -------------------------------------------------------------------------------- /cli/translate.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #ifdef WITH_BOOST_LOG 16 | # include 17 | #endif 18 | 19 | #include "BatchReader.h" 20 | #include "BatchWriter.h" 21 | 
22 | namespace po = boost::program_options; 23 | 24 | int main(int argc, char* argv[]) 25 | { 26 | po::options_description desc("OpenNMT Translator"); 27 | desc.add_options() 28 | ("help", "display available options") 29 | ("model", po::value(), "path to the OpenNMT model") 30 | ("src", po::value(), "path to the file to translate (read from the standard input if not set)") 31 | ("tgt", po::value(), "path to the output file (write to the standard output if not set") 32 | ("phrase_table", po::value()->default_value(""), "path to the phrase table") 33 | ("vocab_mapping", po::value()->default_value(""), "path to a vocabulary mapping table") 34 | ("replace_unk", po::bool_switch()->default_value(false), "replace unknown tokens by source tokens with the highest attention") 35 | ("replace_unk_tagged", po::bool_switch()->default_value(false), "The same as -replace_unk, but wrap the replaced token in⦅unk:xxxxx⦆if it is not found in the phrase table") 36 | ("batch_size", po::value()->default_value(30), "batch size") 37 | ("beam_size", po::value()->default_value(5), "beam size") 38 | ("n_best", po::value()->default_value(1), "n best") 39 | ("max_sent_length", po::value()->default_value(250), "maximum sentence length to produce") 40 | ("time", po::bool_switch()->default_value(false), "output average translation time") 41 | ("profiler", po::bool_switch()->default_value(false), "output per module computation time") 42 | ("parallel", po::value()->default_value(1), "number of parallel translator") 43 | ("threads", po::value()->default_value(0), "number of threads to use (set to 0 to use the number defined by OpenMP)") 44 | ("cuda", po::bool_switch()->default_value(false), "use cuda when available") 45 | ("qlinear", po::bool_switch()->default_value(false), "use quantized linear for speed-up") 46 | #ifdef WITH_BOOST_LOG 47 | ("log_file", po::value()->default_value(""), "path to the log file (write to standard output if not set)") 48 | ("disable_logs", 
po::bool_switch()->default_value(false), "if set, output nothing") 49 | ("log_level", po::value()->default_value("INFO"), "output logs at this level and above (accepted: DEBUG, INFO, WARNING, ERROR, NONE)") 50 | #endif 51 | ; 52 | 53 | po::variables_map vm; 54 | po::store(po::parse_command_line(argc, argv, desc), vm); 55 | po::notify(vm); 56 | 57 | if (vm.count("help")) 58 | { 59 | std::cerr << desc << std::endl; 60 | return 1; 61 | } 62 | 63 | if (!vm.count("model")) 64 | { 65 | std::cerr << "missing model" << std::endl; 66 | return 1; 67 | } 68 | 69 | #ifdef WITH_BOOST_LOG 70 | onmt::Logger::init(vm["log_file"].as(), vm["disable_logs"].as(), vm["log_level"].as()); 71 | #endif 72 | 73 | if (vm["threads"].as() > 0) 74 | onmt::Threads::set(vm["threads"].as()); 75 | 76 | onmt::TranslationOptions options(vm["max_sent_length"].as(), 77 | vm["beam_size"].as(), 78 | vm["n_best"].as(), 79 | vm["replace_unk"].as(), 80 | vm["replace_unk_tagged"].as()); 81 | 82 | std::vector> translator_pool; 83 | translator_pool.emplace_back(onmt::TranslatorFactory::build(vm["model"].as(), 84 | vm["phrase_table"].as(), 85 | vm["vocab_mapping"].as(), 86 | vm["cuda"].as(), 87 | vm["qlinear"].as(), 88 | vm["profiler"].as())); 89 | for (size_t i = 0; i < vm["parallel"].as() - 1; ++i) { 90 | translator_pool.emplace_back(onmt::TranslatorFactory::clone(translator_pool.front())); 91 | } 92 | 93 | std::unique_ptr reader; 94 | if (vm.count("src")) 95 | reader.reset(new BatchReader(vm["src"].as(), vm["batch_size"].as())); 96 | else 97 | reader.reset(new BatchReader(std::cin, vm["batch_size"].as())); 98 | 99 | std::unique_ptr writer; 100 | if (vm.count("tgt")) 101 | writer.reset(new BatchWriter(vm["tgt"].as())); 102 | else 103 | writer.reset(new BatchWriter(std::cout)); 104 | 105 | std::chrono::high_resolution_clock::time_point t1, t2; 106 | 107 | if (vm["time"].as()) 108 | t1 = std::chrono::high_resolution_clock::now(); 109 | 110 | std::vector> futures; 111 | 112 | for (auto& translator: 
translator_pool) 113 | { 114 | futures.emplace_back( 115 | std::async(std::launch::async, 116 | [](BatchReader* p_reader, BatchWriter* p_writer, onmt::ITranslator* p_trans, const onmt::TranslationOptions& options) 117 | { 118 | while (true) 119 | { 120 | auto batch = p_reader->read_next(); 121 | if (batch.empty()) 122 | break; 123 | 124 | std::vector > res; 125 | std::vector > score; 126 | std::vector > count_tgt_words, count_tgt_unk_words; 127 | std::vector count_src_words, count_src_unk_words; 128 | try 129 | { 130 | res = p_trans->get_translations_batch(batch.get_input(), score, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 131 | } 132 | catch (const std::exception& e) 133 | { 134 | ONMT_LOG_STREAM_SEV(e.what(), boost::log::trivial::error); 135 | throw; 136 | } 137 | 138 | batch.set_result(res, score, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words); 139 | p_writer->write(batch); 140 | } 141 | return true; 142 | }, 143 | reader.get(), 144 | writer.get(), 145 | translator.get(), 146 | options)); 147 | } 148 | 149 | for (auto& f: futures) 150 | f.wait(); 151 | 152 | ONMT_LOG_STREAM_SEV("Translated " << writer->total_count_src_words() << " words, src unk count: " << writer->total_count_src_unk_words() 153 | << ", coverage: " << std::floor(writer->total_count_src_unk_words() * 1000.0 / writer->total_count_src_words()) / 10.0 << "%, " 154 | << "tgt words: " << writer->total_count_tgt_words() << " words, tgt unk count: " << writer->total_count_tgt_unk_words() 155 | << ", coverage: " << std::floor(writer->total_count_tgt_unk_words() * 1000.0 / writer->total_count_tgt_words()) / 10.0 << "%, ", boost::log::trivial::info); 156 | ONMT_LOG_STREAM_SEV("PRED AVG SCORE: " << std::fixed << std::setprecision(2) << writer->total_score() / writer->total_count_tgt_words() 157 | << ", PRED PPL: " << std::exp(-writer->total_score() / writer->total_count_tgt_words()), boost::log::trivial::info); 158 | 159 | if 
(vm["time"].as()) 160 | { 161 | t2 = std::chrono::high_resolution_clock::now(); 162 | std::chrono::duration sec = t2 - t1; 163 | size_t num_sents = reader->read_lines(); 164 | std::cerr << "avg real (seconds/sentence)\t" << sec.count() / num_sents << std::endl; 165 | } 166 | 167 | return 0; 168 | } 169 | -------------------------------------------------------------------------------- /cmake/FindEigen3.cmake: -------------------------------------------------------------------------------- 1 | # - Try to find Eigen3 lib 2 | # 3 | # This module supports requiring a minimum version, e.g. you can do 4 | # find_package(Eigen3 3.1.2) 5 | # to require version 3.1.2 or newer of Eigen3. 6 | # 7 | # Once done this will define 8 | # 9 | # EIGEN3_FOUND - system has eigen lib with correct version 10 | # EIGEN3_INCLUDE_DIR - the eigen include directory 11 | # EIGEN3_VERSION - eigen version 12 | # 13 | # This module reads hints about search locations from 14 | # the following enviroment variables: 15 | # 16 | # EIGEN3_ROOT 17 | # EIGEN3_ROOT_DIR 18 | 19 | # Copyright (c) 2006, 2007 Montel Laurent, 20 | # Copyright (c) 2008, 2009 Gael Guennebaud, 21 | # Copyright (c) 2009 Benoit Jacob 22 | # Redistribution and use is allowed according to the terms of the 2-clause BSD license. 
23 | 24 | if(NOT Eigen3_FIND_VERSION) 25 | if(NOT Eigen3_FIND_VERSION_MAJOR) 26 | set(Eigen3_FIND_VERSION_MAJOR 2) 27 | endif(NOT Eigen3_FIND_VERSION_MAJOR) 28 | if(NOT Eigen3_FIND_VERSION_MINOR) 29 | set(Eigen3_FIND_VERSION_MINOR 91) 30 | endif(NOT Eigen3_FIND_VERSION_MINOR) 31 | if(NOT Eigen3_FIND_VERSION_PATCH) 32 | set(Eigen3_FIND_VERSION_PATCH 0) 33 | endif(NOT Eigen3_FIND_VERSION_PATCH) 34 | 35 | set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}") 36 | endif(NOT Eigen3_FIND_VERSION) 37 | 38 | macro(_eigen3_check_version) 39 | file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header) 40 | 41 | string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}") 42 | set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}") 43 | string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}") 44 | set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}") 45 | string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}") 46 | set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}") 47 | 48 | set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION}) 49 | if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) 50 | set(EIGEN3_VERSION_OK FALSE) 51 | else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) 52 | set(EIGEN3_VERSION_OK TRUE) 53 | endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) 54 | 55 | if(NOT EIGEN3_VERSION_OK) 56 | 57 | message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, " 58 | "but at least version ${Eigen3_FIND_VERSION} is required") 59 | endif(NOT EIGEN3_VERSION_OK) 60 | endmacro(_eigen3_check_version) 61 | 62 | macro(_search_eigen3) 63 | if(NOT EIGEN3_INCLUDE_DIR) 64 | find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library 65 | 
HINTS 66 | ENV EIGEN3_ROOT 67 | ENV EIGEN3_ROOT_DIR 68 | PATHS 69 | ${CMAKE_INSTALL_PREFIX}/include 70 | ${KDE4_INCLUDE_DIR} 71 | PATH_SUFFIXES eigen3 eigen 72 | ) 73 | endif(NOT EIGEN3_INCLUDE_DIR) 74 | endmacro(_search_eigen3) 75 | 76 | if (EIGEN3_INCLUDE_DIR) 77 | 78 | # in cache already 79 | _eigen3_check_version() 80 | set(EIGEN3_FOUND ${EIGEN3_VERSION_OK}) 81 | 82 | else (EIGEN3_INCLUDE_DIR) 83 | 84 | if(DEFINED ENV{EIGEN3_ROOT}) 85 | _search_eigen3() 86 | else() 87 | # search if an Eigen3Config.cmake is available in the system, 88 | # if successful this would set EIGEN3_INCLUDE_DIR and the rest of 89 | # the script will work as usual 90 | find_package(Eigen3 ${Eigen3_FIND_VERSION} NO_MODULE QUIET) 91 | _search_eigen3() 92 | endif() 93 | 94 | if(EIGEN3_INCLUDE_DIR) 95 | _eigen3_check_version() 96 | endif(EIGEN3_INCLUDE_DIR) 97 | 98 | include(FindPackageHandleStandardArgs) 99 | find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK) 100 | 101 | mark_as_advanced(EIGEN3_INCLUDE_DIR) 102 | 103 | endif(EIGEN3_INCLUDE_DIR) 104 | 105 | -------------------------------------------------------------------------------- /cmake/FindMKL.cmake: -------------------------------------------------------------------------------- 1 | # Modified from Caffe. 2 | 3 | 4 | # All contributions by the University of California: 5 | # Copyright (c) 2014-2017 The Regents of the University of California (Regents) 6 | # All rights reserved. 7 | 8 | # All other contributions: 9 | # Copyright (c) 2014-2017, the respective contributors 10 | # All rights reserved. 11 | 12 | # Caffe uses a shared copyright model: each contributor holds copyright over 13 | # their contributions to Caffe. The project versioning records all such 14 | # contribution and copyright details. 
If a contributor wants to further mark 15 | # their specific copyright on a particular contribution, they should indicate 16 | # their copyright solely in the commit message of the change when it is 17 | # committed. 18 | 19 | 20 | # Find the MKL libraries 21 | # 22 | # Options: 23 | # 24 | # MKL_USE_STATIC_LIBS : use static libraries 25 | # 26 | # This module defines the following variables: 27 | # 28 | # MKL_FOUND : True mkl is found 29 | # MKL_INCLUDE_DIR : unclude directory 30 | # MKL_LIBRARIES : the libraries to link against. 31 | 32 | 33 | # ---[ Options 34 | option(MKL_USE_STATIC_LIBS "Use static libraries" OFF) 35 | 36 | # ---[ Root folders 37 | if(WIN32) 38 | set(ProgramFilesx86 "ProgramFiles(x86)") 39 | set(INTEL_ROOT_DEFAULT $ENV{${ProgramFilesx86}}/IntelSWTools/compilers_and_libraries/windows) 40 | else() 41 | set(INTEL_ROOT_DEFAULT "/opt/intel") 42 | endif() 43 | set(INTEL_ROOT ${INTEL_ROOT_DEFAULT} CACHE PATH "Folder contains intel libs") 44 | find_path(MKL_ROOT include/mkl.h PATHS $ENV{MKLROOT} ${INTEL_ROOT}/mkl 45 | DOC "Folder contains MKL") 46 | 47 | # ---[ Find include dir 48 | find_path(MKL_INCLUDE_DIR mkl.h PATHS ${MKL_ROOT} PATH_SUFFIXES include) 49 | set(__looked_for MKL_INCLUDE_DIR) 50 | 51 | # ---[ Find libraries 52 | if(CMAKE_SIZEOF_VOID_P EQUAL 4) 53 | set(__path_suffixes lib lib/ia32) 54 | else() 55 | set(__path_suffixes lib lib/intel64) 56 | endif() 57 | 58 | set(__mkl_libs "") 59 | 60 | if(CMAKE_SIZEOF_VOID_P EQUAL 4) 61 | if(WIN32) 62 | list(APPEND __mkl_libs intel_c) 63 | else() 64 | list(APPEND __mkl_libs intel) 65 | endif() 66 | else() 67 | list(APPEND __mkl_libs intel_lp64) 68 | endif() 69 | 70 | if(WIN32) 71 | list(APPEND __mkl_libs intel_thread) 72 | else() 73 | list(APPEND __mkl_libs gnu_thread) 74 | endif() 75 | 76 | list(APPEND __mkl_libs core) 77 | 78 | foreach (__lib ${__mkl_libs}) 79 | set(__mkl_lib "mkl_${__lib}") 80 | string(TOUPPER ${__mkl_lib} __mkl_lib_upper) 81 | 82 | if(WIN32) 83 | if(NOT MKL_USE_STATIC_LIBS) 84 | 
set(__mkl_lib "${__mkl_lib}_dll") 85 | endif() 86 | else() 87 | if(MKL_USE_STATIC_LIBS) 88 | set(__mkl_lib "lib${__mkl_lib}.a") 89 | endif() 90 | endif() 91 | 92 | find_library(${__mkl_lib_upper}_LIBRARY 93 | NAMES ${__mkl_lib} 94 | PATHS ${MKL_ROOT} "${MKL_INCLUDE_DIR}/.." 95 | PATH_SUFFIXES ${__path_suffixes} 96 | DOC "The path to Intel(R) MKL ${__mkl_lib} library") 97 | mark_as_advanced(${__mkl_lib_upper}_LIBRARY) 98 | 99 | list(APPEND __looked_for ${__mkl_lib_upper}_LIBRARY) 100 | list(APPEND MKL_LIBRARIES ${${__mkl_lib_upper}_LIBRARY}) 101 | endforeach() 102 | 103 | if(WIN32) 104 | set(__iomp5_libs iomp5 libiomp5md.lib) 105 | find_library(MKL_RTL_LIBRARY ${__iomp5_libs} 106 | PATHS ${INTEL_ROOT} ${INTEL_ROOT}/compiler ${MKL_ROOT}/.. ${MKL_ROOT}/../compiler 107 | PATH_SUFFIXES ${__path_suffixes} 108 | DOC "Path to OpenMP runtime library") 109 | 110 | list(APPEND __looked_for MKL_RTL_LIBRARY) 111 | list(APPEND MKL_LIBRARIES ${MKL_RTL_LIBRARY}) 112 | endif() 113 | 114 | include(FindPackageHandleStandardArgs) 115 | find_package_handle_standard_args(MKL DEFAULT_MSG ${__looked_for}) 116 | 117 | if(MKL_FOUND) 118 | message(STATUS "Found MKL (include: ${MKL_INCLUDE_DIR}, lib: ${MKL_LIBRARIES}") 119 | endif() 120 | -------------------------------------------------------------------------------- /include/onmt/Dictionary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "onmt/onmt_export.h" 8 | #include "onmt/th/Obj.h" 9 | 10 | namespace onmt 11 | { 12 | 13 | class ONMT_EXPORT Dictionary 14 | { 15 | public: 16 | static const size_t pad_id; 17 | static const size_t unk_id; 18 | static const size_t bos_id; 19 | static const size_t eos_id; 20 | 21 | Dictionary(); 22 | Dictionary(th::Class* dict); 23 | 24 | void load(th::Class* dict); 25 | 26 | size_t get_size() const; 27 | 28 | size_t get_word_id(const std::string& word) const; 29 | const std::string& 
get_id_word(size_t id) const; 30 | 31 | private: 32 | std::vector _id2word; 33 | std::unordered_map _word2id; 34 | }; 35 | 36 | } 37 | -------------------------------------------------------------------------------- /include/onmt/Eigen/MatrixBatch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace onmt 7 | { 8 | namespace Eigen 9 | { 10 | 11 | template 12 | using RowMajorMat = ::Eigen::Matrix; 13 | 14 | template 15 | using MatrixBatchBase = RowMajorMat; 16 | 17 | template 18 | using Map = ::Eigen::Map; 19 | 20 | template 21 | using RowMajorMatMap = Map>; 22 | 23 | // This class inherits from Eigen::Matrix to simulate a batch of Matrix 24 | // (a.k.a. a 3D Tensor). The object stores a hidden dimension, for example: 25 | // 26 | // Eigen::MatrixBatch mat(2, 4000); // 2x4000 27 | // mat.setHiddenDim(4); // virtually 2x4x1000 but still 2x4000 28 | // 29 | // When setHiddenDim is not called, the class behaves like a standard Matrix. 
30 | template 31 | class MatrixBatch: public MatrixBatchBase 32 | { 33 | public: 34 | using MatrixBatchBase::MatrixBatchBase; 35 | 36 | MatrixBatch() 37 | : MatrixBatchBase() 38 | { 39 | } 40 | 41 | template 42 | MatrixBatch(const ::Eigen::MatrixBase& other) 43 | : MatrixBatchBase(other) 44 | { 45 | } 46 | 47 | template 48 | MatrixBatch& operator=(const ::Eigen::MatrixBase& other) 49 | { 50 | this->MatrixBatchBase::operator=(other); 51 | return *this; 52 | } 53 | 54 | void setHiddenDim(size_t size) 55 | { 56 | _rows = size; 57 | _cols = this->cols() / size; 58 | _ndim = 3; 59 | } 60 | 61 | void resetHiddenDim() 62 | { 63 | _ndim = 2; 64 | } 65 | 66 | MatrixBatchBase batch(size_t b) const 67 | { 68 | if (_ndim == 3) 69 | return Eigen::Map >(this->row(b).data(), _rows, _cols); 70 | else 71 | return this->row(b); 72 | } 73 | 74 | MatrixBatchBase sum(int dimension) const 75 | { 76 | size_t new_cols = 0; 77 | 78 | if (dimension == 2) 79 | { 80 | new_cols = _cols; 81 | } 82 | else if (dimension == 3) 83 | { 84 | new_cols = _rows; 85 | } 86 | 87 | MatrixBatchBase out(this->batches(), new_cols); 88 | 89 | for (size_t b = 0; b < this->batches(); ++b) 90 | { 91 | Eigen::Map > mat(this->row(b).data(), _rows, _cols); 92 | if (dimension == 2) 93 | out.row(b).noalias() = mat.colwise().sum(); 94 | else if (dimension == 3) 95 | out.row(b).noalias() = mat.rowwise().sum(); 96 | } 97 | 98 | return out; 99 | } 100 | 101 | void squeeze(int dimension) 102 | { 103 | if (dimension == 2 && _rows == 1) 104 | resetHiddenDim(); 105 | else if (dimension == 3 && _cols == 1) 106 | resetHiddenDim(); 107 | } 108 | 109 | void assign(size_t b, MatrixBatch& mat) 110 | { 111 | this->row(b).noalias() = Eigen::Map >(mat.data(), 1, mat.cols() * mat.rows()); 112 | } 113 | 114 | size_t batches() const 115 | { 116 | return this->rows(); 117 | } 118 | 119 | size_t virtualRows() const 120 | { 121 | if (_ndim == 3) 122 | return _rows; 123 | else 124 | return 1; 125 | } 126 | 127 | size_t virtualCols() const 
128 | { 129 | if (_ndim == 3) 130 | return _cols; 131 | else 132 | return this->cols(); 133 | } 134 | 135 | std::ostream& printSizes(std::ostream& os) const 136 | { 137 | os << this->rows() << "x"; 138 | 139 | if (_ndim == 3) 140 | os << _rows << "x" << _cols; 141 | else 142 | os << this->cols(); 143 | 144 | os << std::endl; 145 | 146 | return os; 147 | } 148 | 149 | private: 150 | size_t _ndim; 151 | size_t _rows; 152 | size_t _cols; 153 | }; 154 | 155 | } 156 | 157 | template 158 | std::ostream& operator<<(std::ostream& os, std::vector > table) 159 | { 160 | os << "{" << std::endl; 161 | 162 | for (const auto& mat: table) 163 | { 164 | os << " MatrixBatch<" << typeid(T).name() << ">: "; 165 | mat.printSizes(os); 166 | } 167 | 168 | os << "}"; 169 | 170 | return os; 171 | } 172 | 173 | } 174 | -------------------------------------------------------------------------------- /include/onmt/ITranslator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "onmt/ITokenizer.h" 7 | #include "onmt/TranslationOptions.h" 8 | #include "onmt/TranslationResult.h" 9 | 10 | namespace onmt 11 | { 12 | 13 | class ITranslator 14 | { 15 | public: 16 | virtual ~ITranslator() = default; 17 | 18 | // Translate a raw text. If the tokenizer is not given, the input text is split on spaces. 
19 | virtual std::string 20 | translate(const std::string& text, 21 | const TranslationOptions& options = TranslationOptions()); 22 | virtual std::string 23 | translate(const std::string& text, 24 | float& score, 25 | size_t& count_tgt_words, 26 | size_t& count_tgt_unk_words, 27 | size_t& count_src_words, 28 | size_t& count_src_unk_words, 29 | const TranslationOptions& options = TranslationOptions()); 30 | virtual std::string 31 | translate(const std::string& text, 32 | ITokenizer& tokenizer, 33 | const TranslationOptions& options = TranslationOptions()); 34 | virtual std::string 35 | translate(const std::string& text, 36 | ITokenizer& tokenizer, 37 | float& score, 38 | size_t& count_tgt_words, 39 | size_t& count_tgt_unk_words, 40 | size_t& count_src_words, 41 | size_t& count_src_unk_words, 42 | const TranslationOptions& options = TranslationOptions()); 43 | 44 | // Multiple translations version of the previous methods. 45 | 46 | virtual std::vector 47 | get_translations(const std::string& text, 48 | const TranslationOptions& options = TranslationOptions()); 49 | virtual std::vector 50 | get_translations(const std::string& text, 51 | std::vector& scores, 52 | std::vector& count_tgt_words, 53 | std::vector& count_tgt_unk_words, 54 | size_t& count_src_words, 55 | size_t& count_src_unk_words, 56 | const TranslationOptions& options = TranslationOptions()); 57 | virtual std::vector 58 | get_translations(const std::string& text, 59 | ITokenizer& tokenizer, 60 | const TranslationOptions& options = TranslationOptions()); 61 | virtual std::vector 62 | get_translations(const std::string& text, 63 | ITokenizer& tokenizer, 64 | std::vector& scores, 65 | std::vector& count_tgt_words, 66 | std::vector& count_tgt_unk_words, 67 | size_t& count_src_words, 68 | size_t& count_src_unk_words, 69 | const TranslationOptions& options = TranslationOptions()) = 0; 70 | 71 | // Translate pre-tokenized text. (As with previous methods, this is also for multiple translations.) 
72 | virtual TranslationResult 73 | translate(const std::vector& tokens, 74 | const std::vector >& features, 75 | const TranslationOptions& options = TranslationOptions()); 76 | virtual TranslationResult 77 | translate(const std::vector& tokens, 78 | const std::vector >& features, 79 | size_t& count_src_unk_words, 80 | const TranslationOptions& options = TranslationOptions()) = 0; 81 | 82 | // Batch version of the previous methods: translate several sequences at once. 83 | 84 | virtual std::vector 85 | translate_batch(const std::vector& texts, 86 | const TranslationOptions& options = TranslationOptions()); 87 | virtual std::vector 88 | translate_batch(const std::vector& texts, 89 | std::vector& scores, 90 | std::vector& count_tgt_words, 91 | std::vector& count_tgt_unk_words, 92 | std::vector& count_src_words, 93 | std::vector& count_src_unk_words, 94 | const TranslationOptions& options = TranslationOptions()); 95 | virtual std::vector 96 | translate_batch(const std::vector& texts, 97 | ITokenizer& tokenizer, 98 | const TranslationOptions& options = TranslationOptions()); 99 | virtual std::vector 100 | translate_batch(const std::vector& texts, 101 | ITokenizer& tokenizer, 102 | std::vector& scores, 103 | std::vector& count_tgt_words, 104 | std::vector& count_tgt_unk_words, 105 | std::vector& count_src_words, 106 | std::vector& count_src_unk_words, 107 | const TranslationOptions& options = TranslationOptions()); 108 | 109 | // Multiple translations version of the previous methods. 
110 | 111 | virtual std::vector > 112 | get_translations_batch(const std::vector& texts, 113 | const TranslationOptions& options = TranslationOptions()); 114 | virtual std::vector > 115 | get_translations_batch(const std::vector& texts, 116 | std::vector >& scores, 117 | std::vector >& count_tgt_words, 118 | std::vector >& count_tgt_unk_words, 119 | std::vector& count_src_words, 120 | std::vector& count_src_unk_words, 121 | const TranslationOptions& options = TranslationOptions()); 122 | virtual std::vector > 123 | get_translations_batch(const std::vector& texts, 124 | ITokenizer& tokenizer, 125 | const TranslationOptions& options = TranslationOptions()); 126 | virtual std::vector > 127 | get_translations_batch(const std::vector& texts, 128 | ITokenizer& tokenizer, 129 | std::vector >& scores, 130 | std::vector >& count_tgt_words, 131 | std::vector >& count_tgt_unk_words, 132 | std::vector& count_src_words, 133 | std::vector& count_src_unk_words, 134 | const TranslationOptions& options = TranslationOptions()) = 0; 135 | 136 | // Translate pre-tokenized text. (As with previous methods, this is also for multiple translations.) 
137 | virtual TranslationResult 138 | translate_batch(const std::vector >& batch_tokens, 139 | const std::vector > >& batch_features, 140 | const TranslationOptions& options = TranslationOptions()); 141 | virtual TranslationResult 142 | translate_batch(const std::vector >& batch_tokens, 143 | const std::vector > >& batch_features, 144 | std::vector& batch_count_src_unk_words, 145 | const TranslationOptions& options = TranslationOptions()) = 0; 146 | }; 147 | 148 | } 149 | -------------------------------------------------------------------------------- /include/onmt/Logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/onmt_export.h" 4 | #include 5 | 6 | #include 7 | 8 | namespace onmt 9 | { 10 | 11 | class ONMT_EXPORT Logger 12 | { 13 | public: 14 | static void init(const std::string& log_file, bool disable_logs, const std::string& log_level); 15 | static boost::log::sources::severity_logger_mt& lg(); 16 | 17 | private: 18 | static boost::log::sources::severity_logger_mt _lg; 19 | }; 20 | 21 | } 22 | -------------------------------------------------------------------------------- /include/onmt/Model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "onmt/Dictionary.h" 6 | #include "onmt/nn/ModuleFactory.h" 7 | #include "onmt/th/Env.h" 8 | 9 | namespace onmt 10 | { 11 | 12 | template 13 | class Model 14 | { 15 | public: 16 | Model(const std::string& filename); 17 | 18 | void create_graph(nn::ModuleFactory& factory, 19 | std::vector& encoder, 20 | std::vector& decoder); 21 | 22 | const Dictionary& get_src_dict() const; 23 | const Dictionary& get_tgt_dict() const; 24 | const std::vector& get_src_feat_dicts() const; 25 | const std::vector& get_tgt_feat_dicts() const; 26 | 27 | const std::string& get_option_string(const std::string& key) const; 28 | bool get_option_flag(const std::string& key, bool 
default_value = false) const; 29 | 30 | template 31 | T get_option_value(const std::string& key, T default_value = 0) const; 32 | 33 | private: 34 | void load_options(th::Table* obj); 35 | void load_dictionaries(th::Table* obj, Dictionary& words, std::vector& features); 36 | void load_dictionaries(th::Table* obj); 37 | 38 | void load_modules(th::Table* obj, 39 | std::vector& modules, 40 | nn::ModuleFactory& module_factory) const; 41 | 42 | th::Env _env; 43 | th::Table* _root; 44 | 45 | Dictionary _src_dict; 46 | Dictionary _tgt_dict; 47 | std::vector _src_feat_dicts; 48 | std::vector _tgt_feat_dicts; 49 | 50 | std::unordered_map _options_value; 51 | std::unordered_map _options_str; 52 | const std::string _empty_str; 53 | }; 54 | 55 | } 56 | 57 | #include "Model.hxx" 58 | -------------------------------------------------------------------------------- /include/onmt/Model.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/th/Obj.h" 4 | #include "onmt/Utils.h" 5 | 6 | namespace onmt 7 | { 8 | 9 | template 10 | Model::Model(const std::string& filename) 11 | { 12 | ONMT_LOG_STREAM_SEV("Loading '" << filename << "'...", boost::log::trivial::info); 13 | THFile* tf = THDiskFile_new(filename.c_str(), "r", 0); 14 | THFile_binary(tf); 15 | THDiskFile_longSize(tf, th::dfLongSize); 16 | 17 | th::Obj* obj = read_obj(tf, _env); 18 | 19 | THFile_free(tf); 20 | 21 | _root = dynamic_cast(obj); 22 | 23 | load_options(th::get_field(_root, "options")); 24 | load_dictionaries(th::get_field(_root, "dicts")); 25 | } 26 | 27 | template 28 | void Model::load_options(th::Table* obj) 29 | { 30 | const auto& opt = obj->get_object(); 31 | 32 | for (auto pair: opt) 33 | { 34 | const std::string& key = pair.first; 35 | 36 | if (dynamic_cast(pair.second)) 37 | { 38 | double value = static_cast(dynamic_cast(pair.second)->get_value()); 39 | _options_value[key] = value; 40 | } 41 | else if (dynamic_cast(pair.second)) 42 | { 43 
| bool value = dynamic_cast(pair.second)->get_value(); 44 | _options_value[key] = value ? 1 : 0; 45 | } 46 | else if (dynamic_cast(pair.second)) 47 | { 48 | const std::string& str = dynamic_cast(pair.second)->get_value(); 49 | _options_str[key] = str; 50 | } 51 | } 52 | } 53 | 54 | template 55 | void Model::load_dictionaries(th::Table* obj, 56 | Dictionary& words, 57 | std::vector& features) 58 | { 59 | words.load(th::get_field(obj, "words")); 60 | 61 | auto features_set = th::get_field(obj, "features"); 62 | auto features_dicts = features_set->get_array(); 63 | 64 | for (size_t i = 0; i < features_dicts.size(); ++i) 65 | { 66 | features.emplace_back(dynamic_cast(features_dicts[i])); 67 | } 68 | } 69 | 70 | template 71 | void Model::load_dictionaries(th::Table* obj) 72 | { 73 | load_dictionaries(th::get_field(obj, "src"), _src_dict, _src_feat_dicts); 74 | load_dictionaries(th::get_field(obj, "tgt"), _tgt_dict, _tgt_feat_dicts); 75 | } 76 | 77 | template 78 | void Model::load_modules( 79 | th::Table* obj, 80 | std::vector& modules, 81 | nn::ModuleFactory& module_factory) const 82 | { 83 | auto modules_set = th::get_field(obj, "modules"); 84 | auto modules_data = modules_set->get_array(); 85 | 86 | for (auto module: modules_data) 87 | { 88 | th::Class* mod = dynamic_cast(module); 89 | 90 | if (mod) 91 | modules.push_back(module_factory.build(mod)); 92 | else if (dynamic_cast(module)) 93 | load_modules(dynamic_cast(module), modules, module_factory); 94 | } 95 | } 96 | 97 | template 98 | struct record_graph { 99 | record_graph():currentGraph(nullptr),attentionGraph(nullptr) {} 100 | nn::Module *currentGraph; 101 | nn::Module *attentionGraph; 102 | }; 103 | 104 | template 105 | static void* _find_attentiongraph(nn::Module* M, void* t) 106 | { 107 | if (M->get_name() == "nn.gModule"){ 108 | ((record_graph*)t)->currentGraph = M; 109 | } 110 | else if (M->get_custom_name() == "softmaxAttn") { 111 | ((record_graph*)t)->attentionGraph = 112 | 
((record_graph*)t)->currentGraph; 113 | } 114 | return 0; 115 | } 116 | 117 | template 118 | static void* _mark_block(nn::Module* M, void* t) 119 | { 120 | M->set_block((const char*)t); 121 | return 0; 122 | } 123 | 124 | template 125 | void Model::create_graph( 126 | nn::ModuleFactory& factory, 127 | std::vector& encoder, 128 | std::vector& decoder) 129 | { 130 | auto models = th::get_field(_root, "models"); 131 | load_modules(th::get_field(models, "encoder"), encoder, factory); 132 | load_modules(th::get_field(models, "decoder"), decoder, factory); 133 | 134 | /* annotate the different modules for profiling */ 135 | factory.get_module(encoder[0])->apply(_mark_block, (void*)"encoder_fwd"); 136 | if (encoder.size() > 1) 137 | factory.get_module(encoder[1])->apply(_mark_block, (void*)"encoder_bwd"); 138 | auto decoder_mod = factory.get_module(decoder[0]); 139 | decoder_mod->apply(_mark_block, (void*)"decoder"); 140 | factory.get_module(decoder[1])->apply(_mark_block, (void*)"generator"); 141 | 142 | /* find the attention module and annotate it specifically */ 143 | record_graph rg; 144 | decoder_mod->apply(_find_attentiongraph, (void*)&rg); 145 | if (rg.attentionGraph) 146 | rg.attentionGraph->apply(_mark_block, (void*)"attention"); 147 | } 148 | 149 | template 150 | const Dictionary& Model::get_src_dict() const 151 | { 152 | return _src_dict; 153 | } 154 | template 155 | const Dictionary& Model::get_tgt_dict() const 156 | { 157 | return _tgt_dict; 158 | } 159 | template 160 | const std::vector& Model::get_src_feat_dicts() const 161 | { 162 | return _src_feat_dicts; 163 | } 164 | template 165 | const std::vector& Model::get_tgt_feat_dicts() const 166 | { 167 | return _tgt_feat_dicts; 168 | } 169 | 170 | template 171 | const std::string& Model::get_option_string(const std::string& key) const 172 | { 173 | auto it = _options_str.find(key); 174 | 175 | if (it == _options_str.cend()) 176 | return _empty_str; 177 | 178 | return it->second; 179 | } 180 | 181 | template 
182 | bool Model::get_option_flag(const std::string& key, 183 | bool default_value) const 184 | { 185 | return get_option_value(key, static_cast(default_value)) == 1; 186 | } 187 | 188 | template 189 | template 190 | T Model::get_option_value(const std::string& key, T default_value) const 191 | { 192 | auto it = _options_value.find(key); 193 | 194 | if (it == _options_value.cend()) 195 | return default_value; 196 | 197 | return static_cast(it->second); 198 | } 199 | 200 | } 201 | -------------------------------------------------------------------------------- /include/onmt/PhraseTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "onmt/onmt_export.h" 7 | 8 | namespace onmt 9 | { 10 | 11 | class ONMT_EXPORT PhraseTable 12 | { 13 | public: 14 | PhraseTable(const std::string& file); 15 | 16 | bool is_empty() const; 17 | size_t get_size() const; 18 | 19 | std::string lookup(const std::string& src) const; 20 | 21 | private: 22 | std::unordered_map _src_to_tgt; 23 | }; 24 | 25 | } 26 | -------------------------------------------------------------------------------- /include/onmt/Profiler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace onmt 11 | { 12 | 13 | class Profiler 14 | { 15 | public: 16 | Profiler(bool enabled = false, bool start_chrono = false); 17 | ~Profiler(); 18 | 19 | void enable(); 20 | void disable(); 21 | void reset(); 22 | 23 | int get_id() const; 24 | 25 | void start(); 26 | void stop(const std::string& module_name); 27 | 28 | friend std::ostream& operator<<(std::ostream& os, const Profiler& profiler); 29 | 30 | private: 31 | bool _enabled; 32 | std::chrono::microseconds _total_time; 33 | std::stack _start; 34 | std::unordered_map _cumulated; 35 | int _id; 36 | 37 | static std::mutex _profiler_mutex; 
38 | static int _counter; 39 | }; 40 | 41 | } 42 | -------------------------------------------------------------------------------- /include/onmt/StorageLoader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Eigen/MatrixBatch.h" 4 | 5 | #include "onmt/th/Utils.h" 6 | #include "onmt/th/Obj.h" 7 | 8 | namespace onmt 9 | { 10 | 11 | // This class can be specialized to implement different loading behaviours 12 | // (including conversions) according to the source and/or target types. 13 | template 14 | class StorageLoader; 15 | 16 | // This default specialization maps the storage to a Eigen structure without 17 | // any changes (same precision and same storage order). 18 | template 19 | class StorageLoader, T> 20 | { 21 | public: 22 | static Eigen::RowMajorMatMap get_matrix(th::Table* module_data, 23 | const std::string& name) 24 | { 25 | th::Tensor* tensor = th::get_field*>(module_data, name); 26 | 27 | if (!tensor) 28 | return Eigen::RowMajorMatMap(nullptr, 0, 0); 29 | 30 | size_t rows = tensor->get_size()[0]; 31 | size_t cols = tensor->get_dimension() == 1 ? 
1 : tensor->get_size()[1]; 32 | 33 | const T* storage_data = get_tensor_data(tensor); 34 | 35 | return Eigen::RowMajorMatMap(storage_data, rows, cols); 36 | } 37 | }; 38 | 39 | } 40 | -------------------------------------------------------------------------------- /include/onmt/SubDict.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "onmt/onmt_export.h" 9 | #include "onmt/Eigen/MatrixBatch.h" 10 | #include "onmt/Dictionary.h" 11 | 12 | namespace onmt 13 | { 14 | 15 | class ONMT_EXPORT SubDict { 16 | public: 17 | /* build subdict class given a dictionary and map file */ 18 | SubDict(const std::string& map_file, const Dictionary& dict); 19 | 20 | /* given a sequence of words, extract sub-dictionary */ 21 | void extract(const std::vector& words, std::set& r) const; 22 | 23 | bool empty() const 24 | { 25 | return _map_rules.size() == 0; 26 | } 27 | 28 | private: 29 | std::vector > > _map_rules; 30 | }; 31 | 32 | } 33 | -------------------------------------------------------------------------------- /include/onmt/Threads.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/onmt_export.h" 4 | 5 | namespace onmt 6 | { 7 | 8 | class ONMT_EXPORT Threads 9 | { 10 | public: 11 | static void set(int number); 12 | static int get(); 13 | }; 14 | 15 | } 16 | -------------------------------------------------------------------------------- /include/onmt/TranslationOptions.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "onmt/onmt_export.h" 6 | 7 | namespace onmt 8 | { 9 | 10 | class ONMT_EXPORT TranslationOptions 11 | { 12 | public: 13 | TranslationOptions(size_t max_sent_length = 250, 14 | size_t beam_size = 5, 15 | size_t n_best = 1, 16 | bool replace_unk = true, 17 | bool replace_unk_tagged = false); 
18 | 19 | size_t max_sent_length() const; 20 | size_t& max_sent_length(); 21 | size_t beam_size() const; 22 | size_t& beam_size(); 23 | size_t n_best() const; 24 | size_t& n_best(); 25 | bool replace_unk() const; 26 | bool& replace_unk(); 27 | bool replace_unk_tagged() const; 28 | bool& replace_unk_tagged(); 29 | 30 | private: 31 | size_t _max_sent_length; 32 | size_t _beam_size; 33 | size_t _n_best; 34 | bool _replace_unk; 35 | bool _replace_unk_tagged; 36 | }; 37 | 38 | } 39 | -------------------------------------------------------------------------------- /include/onmt/TranslationResult.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "onmt/onmt_export.h" 7 | 8 | namespace onmt 9 | { 10 | 11 | class ONMT_EXPORT TranslationResult 12 | { 13 | public: 14 | TranslationResult(const std::vector > >& words, 15 | const std::vector > > >& features, 16 | const std::vector > > >& attention, 17 | const std::vector >& score, 18 | const std::vector >& count_unk_words); 19 | 20 | const std::vector& get_words(size_t job_index = 0, size_t translation_index = 0) const; 21 | const std::vector >& get_features(size_t job_index = 0, size_t translation_index = 0) const; 22 | const std::vector >& get_attention(size_t job_index = 0, size_t translation_index = 0) const; 23 | float get_score(size_t job_index = 0, size_t translation_index = 0) const; 24 | size_t get_count_unk_words(size_t job_index = 0, size_t translation_index = 0) const; 25 | 26 | const std::vector >& get_words_job(size_t job_index = 0) const; 27 | const std::vector > >& get_features_job(size_t job_index = 0) const; 28 | const std::vector > >& get_attention_job(size_t job_index = 0) const; 29 | const std::vector& get_score_job(size_t job_index = 0) const; 30 | const std::vector& get_count_unk_words_job(size_t job_index = 0) const; 31 | size_t count_job(size_t job_index = 0) const; 32 | 33 | const std::vector > >& 
get_words_batch() const; 34 | const std::vector > > >& get_features_batch() const; 35 | const std::vector > > >& get_attention_batch() const; 36 | const std::vector >& get_score_batch() const; 37 | const std::vector >& get_count_unk_words_batch() const; 38 | 39 | size_t count() const; 40 | bool has_features() const; 41 | 42 | private: 43 | std::vector > > _words; 44 | std::vector > > > _features; 45 | std::vector > > > _attention; 46 | std::vector > _score; 47 | std::vector > _count_unk_words; 48 | }; 49 | 50 | } 51 | -------------------------------------------------------------------------------- /include/onmt/Translator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Eigen/MatrixBatch.h" 4 | 5 | #include "Model.h" 6 | #include "Dictionary.h" 7 | #include "PhraseTable.h" 8 | #include "ITranslator.h" 9 | #include "SubDict.h" 10 | 11 | namespace onmt 12 | { 13 | 14 | template , 15 | typename MatIn = Eigen::RowMajorMatMap, 16 | typename MatEmb = Eigen::RowMajorMatMap, 17 | typename ModelT = float> 18 | class Translator: public ITranslator 19 | { 20 | public: 21 | friend class TranslatorFactory; 22 | 23 | std::vector 24 | get_translations(const std::string& text, 25 | ITokenizer& tokenizer, 26 | std::vector& scores, 27 | std::vector& count_tgt_words, 28 | std::vector& count_tgt_unk_words, 29 | size_t& count_src_words, 30 | size_t& count_src_unk_words, 31 | const TranslationOptions& options = TranslationOptions()) override; 32 | 33 | std::vector > 34 | get_translations_batch(const std::vector& texts, 35 | ITokenizer& tokenizer, 36 | std::vector >& scores, 37 | std::vector >& count_tgt_words, 38 | std::vector >& count_tgt_unk_words, 39 | std::vector& count_src_words, 40 | std::vector& count_src_unk_words, 41 | const TranslationOptions& options = TranslationOptions()) override; 42 | 43 | TranslationResult 44 | translate(const std::vector& tokens, 45 | const std::vector >& features, 46 | size_t& 
count_src_unk_words, 47 | const TranslationOptions& options = TranslationOptions()) override; 48 | 49 | TranslationResult 50 | translate_batch(const std::vector >& batch_tokens, 51 | const std::vector > >& batch_features, 52 | std::vector& batch_count_src_unk_words, 53 | const TranslationOptions& options = TranslationOptions()) override; 54 | 55 | protected: 56 | Translator(const std::string& model, 57 | const std::string& phrase_table, 58 | const std::string& vocab_mapping, 59 | bool cuda, 60 | bool qlinear, 61 | bool profiling); 62 | Translator(const Translator& other); 63 | 64 | /* profiling - starting first to profile load time */ 65 | bool _profiling; 66 | Profiler _profiler; 67 | 68 | // Members shared across translator instances. 69 | std::shared_ptr> _model; 70 | std::shared_ptr _phrase_table; 71 | std::shared_ptr _subdict; 72 | std::shared_ptr> _encoder_mod_ids; 73 | std::shared_ptr> _decoder_mod_ids; 74 | 75 | bool _cuda; 76 | bool _qlinear; 77 | 78 | std::vector 79 | get_encoder_input(size_t t, 80 | const std::vector >& batch_ids, 81 | const std::vector > >& batch_feat_ids, 82 | const std::vector& rnn_state_enc) const; 83 | 84 | void 85 | encode(const std::vector >& batch_tokens, 86 | const std::vector >& batch_ids, 87 | const std::vector > >& batch_feat_ids, 88 | std::vector& rnn_state_enc, 89 | MatFwd& context); 90 | 91 | TranslationResult 92 | decode(const std::vector >& batch_tokens, 93 | size_t source_l, 94 | const std::vector& rnn_state_enc, 95 | const MatFwd& context, 96 | const std::vector& subvocab, 97 | const TranslationOptions& options); 98 | 99 | private: 100 | void init_graph(); 101 | 102 | nn::ModuleFactory _factory; 103 | nn::Module* _encoder; 104 | nn::Module* _encoder_bwd; 105 | nn::Module* _decoder; 106 | nn::Module* _generator; 107 | }; 108 | 109 | 110 | template 111 | using DefaultTranslator = Translator, 112 | Eigen::RowMajorMatMap, 113 | Eigen::RowMajorMatMap, 114 | T>; 115 | 116 | } 117 | 118 | #include "Translator.hxx" 119 | 
-------------------------------------------------------------------------------- /include/onmt/TranslatorFactory.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "onmt/onmt_export.h" 6 | #include "Translator.h" 7 | 8 | namespace onmt 9 | { 10 | 11 | class ONMT_EXPORT TranslatorFactory 12 | { 13 | public: 14 | static std::unique_ptr build(const std::string& model, 15 | const std::string& phrase_table = "", 16 | const std::string& vocab_mapping = "", 17 | bool cuda = false, 18 | bool qlinear = false, 19 | bool profiling = false); 20 | 21 | static std::unique_ptr clone(const std::unique_ptr& translator); 22 | }; 23 | 24 | } 25 | -------------------------------------------------------------------------------- /include/onmt/Utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #ifdef WITH_BOOST_LOG 7 | # include "onmt/Logger.h" 8 | # define ONMT_LOG_STREAM_SEV(s, l) BOOST_LOG_SEV(onmt::Logger::lg(), l) << s 9 | #else 10 | # define ONMT_LOG_STREAM_SEV(s, l) ((void)0) 11 | #endif 12 | 13 | namespace onmt 14 | { 15 | 16 | inline void *align( std::size_t alignment, std::size_t size, 17 | void *&ptr, std::size_t &space ) { 18 | // Copyright 2014 David Krauss 19 | // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57350 20 | std::uintptr_t pn = reinterpret_cast< std::uintptr_t >( ptr ); 21 | std::uintptr_t aligned = ( pn + alignment - 1 ) & - alignment; 22 | std::size_t padding = aligned - pn; 23 | if ( space < size + padding ) return nullptr; 24 | space -= padding; 25 | return ptr = reinterpret_cast< void * >( aligned ); 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /include/onmt/android_gnustl_compat.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef ANDROID_GNUSTL_COMPAT 4 | 5 |
#include 6 | #include 7 | #include 8 | 9 | namespace std 10 | { 11 | template 12 | std::string to_string(T value) 13 | { 14 | std::ostringstream os; 15 | os << value; 16 | return os.str(); 17 | } 18 | 19 | inline unsigned long stoul(const std::string& s) 20 | { 21 | /* parse straight into unsigned long (base 10): atol yields a signed long and its behaviour is undefined when the value does not fit, e.g. above LONG_MAX on 32-bit ABIs */ return strtoul(s.c_str(), nullptr, 10); 22 | } 23 | } 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /include/onmt/cuda/Kernels.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace onmt 4 | { 5 | namespace cuda 6 | { 7 | namespace kernels 8 | { 9 | 10 | void add(float* a, const float* b, int len); 11 | 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /include/onmt/cuda/Utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define CUDA_CHECK(ans) { onmt::cuda::cudaAssert((ans), __FILE__, __LINE__); } 11 | #define CUBLAS_CHECK(ans) { onmt::cuda::cublasAssert((ans), __FILE__, __LINE__); } 12 | 13 | namespace onmt 14 | { 15 | namespace cuda 16 | { 17 | 18 | std::string cublasGetStatusString(cublasStatus_t status); 19 | 20 | inline 21 | void cudaAssert(cudaError_t code, const std::string& file, int line) 22 | { 23 | if (code != cudaSuccess) 24 | throw std::runtime_error("CUDA failed with error " + std::string(cudaGetErrorString(code)) + " at " + file + ":" + std::to_string(line)); 25 | } 26 | 27 | inline 28 | void cublasAssert(cublasStatus_t status, const std::string& file, int line) 29 | { 30 | if (status != CUBLAS_STATUS_SUCCESS) 31 | throw std::runtime_error("cuBLAS failed with status " + cublasGetStatusString(status) + " at " + file + ":" + std::to_string(line)); 32 | } 33 | 34 | template 35 | void to_device(T* device, const T* host, int rows, int cols) 36 | { 37 | CUBLAS_CHECK(cublasSetMatrix(rows, cols,
sizeof (T), host, rows, device, rows)); 38 | } 39 | 40 | template 41 | T* to_device(const T* host, int rows, int cols) 42 | { 43 | T* device = nullptr; 44 | 45 | /* sizeof first: the byte count is computed in size_t, so rows * cols cannot overflow int */ CUDA_CHECK(cudaMalloc(&device, sizeof (T) * rows * cols)); 46 | CUBLAS_CHECK(cublasSetMatrix(rows, cols, sizeof (T), host, rows, device, rows)); 47 | 48 | return device; 49 | } 50 | 51 | template 52 | T* to_device(int rows, int cols) 53 | { 54 | T* device = nullptr; 55 | 56 | /* size_t arithmetic, see above */ CUDA_CHECK(cudaMalloc(&device, sizeof (T) * rows * cols)); 57 | 58 | return device; 59 | } 60 | 61 | template 62 | T* to_device(const T* host, int n) 63 | { 64 | T* device = nullptr; 65 | 66 | CUDA_CHECK(cudaMalloc(&device, n * sizeof (T))); 67 | 68 | if (host) 69 | CUBLAS_CHECK(cublasSetVector(n, sizeof (T), host, 1, device, 1)); 70 | 71 | return device; 72 | } 73 | 74 | template 75 | T* to_host(const T* device, T* host, int rows, int cols) 76 | { 77 | /* allocate the destination when the caller passes none; fail loudly instead of handing a null buffer to cuBLAS */ if (!host) { 78 | host = (T*) std::malloc(sizeof (T) * rows * cols); if (!host) throw std::runtime_error("to_host: failed to allocate host buffer"); } 79 | 80 | CUBLAS_CHECK(cublasGetMatrix(rows, cols, sizeof (T), device, rows, host, rows)); 81 | 82 | return host; 83 | } 84 | 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /include/onmt/nn/CAddTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class CAddTable: public Module 12 | { 13 | public: 14 | CAddTable() 15 | : Module("nn.CAddTable") 16 | { 17 | } 18 | 19 | CAddTable(const CAddTable& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new CAddTable(*this); 27 | } 28 | 29 | void forward_impl(const std::vector& inputs) override 30 | { 31 | this->_output = inputs[0]; 32 | 33 | for (size_t i = 1; i < inputs.size(); ++i) 34 | this->_output.noalias() += inputs[i]; 35 | } 36 | 37 | }; 38 | 39 | } 40 | } 41 |
-------------------------------------------------------------------------------- /include/onmt/nn/CMulTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class CMulTable: public Module 12 | { 13 | public: 14 | CMulTable() 15 | : Module("nn.CMulTable") 16 | { 17 | } 18 | 19 | CMulTable(const CMulTable& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new CMulTable(*this); 27 | } 28 | 29 | /* element-wise product of all inputs, accumulated into _output */ void forward_impl(const std::vector& inputs) override 30 | { 31 | this->_output = inputs[0]; 32 | 33 | for (size_t i = 1; i < inputs.size(); ++i) 34 | this->_output.array() *= inputs[i].array(); /* in-place coefficient-wise product: the destination is also an operand, so noalias() would be the wrong promise */ 35 | } 36 | }; 37 | 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /include/onmt/nn/ConcatTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Container.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class ConcatTable: public Container 12 | { 13 | public: 14 | ConcatTable(th::Table* data, ModuleFactory& factory) 15 | : Container("nn.ConcatTable", data, factory) 16 | { 17 | } 18 | 19 | ConcatTable(const ConcatTable& other, 20 | const ModuleFactory& factory) 21 | : Container(other, factory) 22 | { 23 | } 24 | 25 | Module* clone(const ModuleFactory* factory) const override 26 | { 27 | return new ConcatTable(*this, *factory); 28 | } 29 | 30 | /* feeds the same inputs to every child and collects each child's first output */ void forward_impl(const std::vector& inputs) override 31 | { 32 | this->_outputs.resize(this->_sequence->size()); /* NOTE: growing _outputs may reallocate it, which invalidates the base-class _output reference (bound to _outputs.front()); this module only touches _outputs */ 33 | 34 | for (size_t i = 0; i < this->_sequence->size(); ++i) 35 | this->_outputs[i] = this->_factory.get_module((*(this->_sequence))[i])->forward(inputs)[0]; 36 | } 37 | }; 38 | 39 | } 40 | } 41 |
-------------------------------------------------------------------------------- /include/onmt/nn/Container.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "onmt/nn/ModuleFactory.h" 5 | #include "onmt/th/Utils.h" 6 | #include "onmt/th/Obj.h" 7 | 8 | namespace onmt 9 | { 10 | namespace nn 11 | { 12 | 13 | template 14 | class Container: public Module 15 | { 16 | public: 17 | Container(const std::string& name, 18 | th::Table* data, 19 | ModuleFactory& factory) 20 | : Module(name, false) 21 | , _factory(factory) 22 | , _sequence(new std::vector()) 23 | { 24 | th::Table* modules = th::get_field(data, "modules"); 25 | _sequence->reserve(modules->get_array().size()); 26 | 27 | for (auto module_obj: modules->get_array()) 28 | { 29 | th::Class* module = dynamic_cast(module_obj); 30 | _sequence->push_back(factory.build(module)); 31 | } 32 | } 33 | 34 | Container(const Container& other, 35 | const ModuleFactory& factory) 36 | : Module(other) 37 | , _factory(factory) 38 | , _sequence(other._sequence) 39 | { 40 | } 41 | 42 | /* apply recursively a generic function to each node of the graph (overrides Module::apply) */ 43 | void* apply(void* (*func)(Module*, void*), void* data) override 44 | { 45 | func(this, data); 46 | 47 | for (auto child: *_sequence) 48 | _factory.get_module(child)->apply(func, data); 49 | return 0; 50 | } 51 | 52 | virtual void forward_impl(const std::vector& inputs) = 0; 53 | 54 | protected: 55 | const ModuleFactory& _factory; 56 | std::shared_ptr> _sequence; 57 | }; 58 | 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /include/onmt/nn/Graph.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "onmt/nn/ModuleFactory.h" 7 | #include "onmt/nn/Node.h" 8 | #include "onmt/th/Obj.h" 9 | #include "onmt/th/Utils.h" 10 | 11 | namespace onmt 12 | { 13 | namespace nn 14 | {
15 | 16 | template 17 | class Graph: public Module 18 | { 19 | public: 20 | Graph(th::Class* module, 21 | const std::string& name, 22 | ModuleFactory& factory) 23 | : Module(name, false) 24 | , _root_id(build_graph(dynamic_cast( 25 | dynamic_cast(module->get_data()) 26 | ->get_object().at("forwardnodes")) 27 | ->get_array()[0], 28 | factory)) 29 | , _root(&(_node_map.find(_root_id)->second)) 30 | { 31 | } 32 | 33 | Graph(const Graph& other, 34 | const ModuleFactory& factory) 35 | : Module(other) 36 | , _root_id(other._root_id) 37 | { 38 | for (const auto& it: other._node_map) 39 | { 40 | _node_map.emplace(std::piecewise_construct, 41 | std::forward_as_tuple(it.first), std::forward_as_tuple(it.second, factory, _node_map)); 42 | } 43 | _root = &(_node_map.find(_root_id)->second); 44 | } 45 | 46 | Module* clone(const ModuleFactory* factory) const override 47 | { 48 | return new Graph(*this, *factory); 49 | } 50 | 51 | void forward_impl(const std::vector& inputs) override 52 | { 53 | _root->forward(inputs, this->_outputs, nullptr); 54 | } 55 | 56 | Module* find(const std::string& custom_name) override 57 | { 58 | if (this->_custom_name == custom_name) 59 | return this; 60 | 61 | return _root->find(custom_name); 62 | } 63 | 64 | void* apply(void* (*func)(Module*, void*), void* data) override 65 | { 66 | for(auto &it: _node_map) 67 | it.second.set_unvisited(); 68 | 69 | func(this, data); 70 | 71 | _root->apply(func, data); 72 | return nullptr; 73 | } 74 | 75 | // Dump the graph in the DOT format.
76 | void to_dot(const std::string& file, const std::string& name) 77 | { 78 | for(auto &it: _node_map) 79 | it.second.set_unvisited(); 80 | 81 | std::ofstream out(file.c_str()); 82 | out << "digraph " << name << " {" << std::endl; 83 | _root->to_dot(out); 84 | out << "}" << std::endl; 85 | } 86 | 87 | private: 88 | std::unordered_map> _node_map; 89 | size_t _root_id; 90 | Node* _root; 91 | 92 | size_t build_graph(th::Obj* root, ModuleFactory& factory) 93 | { 94 | th::Class* root_class = dynamic_cast(root); 95 | th::Table* root_fields = dynamic_cast(root_class->get_data()); 96 | th::Number* root_id = dynamic_cast(root_fields->get_object().at("id")); 97 | 98 | size_t id = static_cast(root_id->get_value()); 99 | 100 | auto placeholder = _node_map.emplace(std::piecewise_construct, 101 | std::forward_as_tuple(id), std::forward_as_tuple(id, factory, _node_map)); 102 | Node& root_node = placeholder.first->second; 103 | 104 | if (!placeholder.second) // already exists 105 | return id; 106 | 107 | th::Table* root_data = th::get_field(root_fields, "data"); 108 | th::Class* module_class = th::get_field(root_data, "module"); 109 | if (!module_class) 110 | root_node.set_module_id(std::numeric_limits::max()); 111 | else 112 | root_node.set_module_id(factory.build(module_class)); 113 | th::Number* selectindex = th::get_field(root_data, "selectindex"); 114 | if (selectindex) 115 | root_node.set_select_index(static_cast(selectindex->get_value())-1); 116 | th::Table* mapindex = th::get_field(root_data, "mapindex"); 117 | 118 | for (auto it = mapindex->get_array().begin(); it != mapindex->get_array().end(); ++it) 119 | { 120 | th::Table* data = dynamic_cast(*it); 121 | if (data) 122 | { 123 | th::Number* nodeid = th::get_field(data, "forwardNodeId"); 124 | size_t nodeid_val = static_cast(nodeid->get_value()); 125 | root_node.add_input_index(nodeid_val); 126 | } 127 | } 128 | 129 | th::Table* children = th::get_field(root_fields, "children"); 130 | 131 | if (children) 132 | { 133 |
for (auto it = children->get_array().begin(); it != children->get_array().end(); ++it) 134 | { 135 | if ((*it)->type() == th::ObjType::TORCH) 136 | { 137 | size_t child = build_graph(*it, factory); 138 | root_node.add_child(child); 139 | } 140 | } 141 | } 142 | 143 | return id; 144 | } 145 | }; 146 | 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /include/onmt/nn/Identity.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class Identity: public Module 12 | { 13 | public: 14 | Identity() 15 | : Module("nn.Identity") 16 | { 17 | } 18 | 19 | Identity(const Identity& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new Identity(*this); 27 | } 28 | 29 | void forward_impl(const std::vector& inputs) override 30 | { 31 | this->_outputs = inputs; 32 | } 33 | }; 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /include/onmt/nn/JoinTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class JoinTable: public Module 12 | { 13 | public: 14 | JoinTable() 15 | : Module("nn.JoinTable") 16 | { 17 | } 18 | 19 | JoinTable(const JoinTable& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new JoinTable(*this); 27 | } 28 | 29 | void forward_impl(const std::vector& inputs) override 30 | { 31 | // Compute final size.
32 | int rows = inputs[0].rows(); 33 | int cols = 0; 34 | 35 | for (size_t i = 0; i < inputs.size(); ++i) 36 | cols += inputs[i].cols(); 37 | 38 | this->_output.resize(rows, cols); 39 | 40 | // Join column-wise by default. 41 | int offset = 0; 42 | 43 | for (size_t i = 0; i < inputs.size(); ++i) 44 | { 45 | this->_output.block(0, offset, rows, inputs[i].cols()) = inputs[i]; 46 | offset += inputs[i].cols(); 47 | } 48 | } 49 | }; 50 | 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /include/onmt/nn/Linear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "onmt/nn/Module.h" 5 | #include "onmt/th/Obj.h" 6 | #include "onmt/StorageLoader.h" 7 | 8 | #ifdef ANDROID_GNUSTL_COMPAT 9 | # include "onmt/android_gnustl_compat.h" 10 | #endif 11 | 12 | namespace onmt 13 | { 14 | namespace nn 15 | { 16 | 17 | template 18 | class Linear: public Module 19 | { 20 | public: 21 | Linear(th::Table* data) 22 | : Module("nn.Linear") 23 | , _weight(new MatIn(StorageLoader::get_matrix(data, "weight"))) 24 | , _bias(new MatIn(StorageLoader::get_matrix(data, "bias"))) 25 | { 26 | _wrows = _weight->rows(); 27 | _wcols = _weight->cols(); 28 | _rwrows = 0; 29 | } 30 | 31 | Linear(const Linear& other) 32 | : Module(other) 33 | , _weight(other._weight) 34 | , _bias(other._bias) 35 | { 36 | _wrows = _weight->rows(); 37 | _wcols = _weight->cols(); 38 | _rwrows = 0; 39 | } 40 | 41 | virtual ~Linear() 42 | { 43 | } 44 | 45 | Module* clone(const ModuleFactory*) const override 46 | { 47 | return new Linear(*this); 48 | } 49 | 50 | virtual void forward_impl(const MatFwd& input) override 51 | { 52 | if (_rwrows) 53 | { 54 | this->_output.resize(input.rows(), _rwrows); 55 | this->_output = input * _rweight.transpose(); 56 | if (_rbias.rows() > 0) 57 | { 58 | for (int i = 0; i < input.rows(); ++i) 59 | this->_output.row(i).noalias() += _rbias.transpose(); 60 | } 61 | } 62 |
else 63 | { 64 | this->_output.resize(input.rows(), _wrows); 65 | this->_output = input * _weight->transpose(); 66 | if (_bias->rows() > 0) 67 | { 68 | for (int i = 0; i < input.rows(); ++i) 69 | this->_output.row(i).noalias() += _bias->transpose(); 70 | } 71 | } 72 | } 73 | 74 | std::string get_details() const override 75 | { 76 | std::string details = std::to_string(_wcols) + "->" + std::to_string(_wrows); 77 | if (_bias->rows() == 0) 78 | details += " without bias"; 79 | return details; 80 | } 81 | 82 | size_t get_weight_rows() const 83 | { 84 | return _wrows; 85 | } 86 | 87 | /* reduce the weight matrix (and the bias, when the layer has one) to a given vocabulary, v is the list of index to keep */ 88 | virtual void apply_subdictionary(const std::vector& v) 89 | { 90 | _rwrows = v.size(); 91 | _rweight.resize(v.size(), _wcols); 92 | /* a Linear loaded without bias has an empty (0-row) _bias: keep _rbias empty as well, so forward_impl skips the bias add instead of reading rows that do not exist */ const bool has_bias = _bias->rows() > 0; _rbias.resize(has_bias ? v.size() : 0, 1); 93 | /* build sub-matrix */ 94 | for (size_t i = 0; i < v.size(); i++) 95 | { 96 | _rweight.row(i) = _weight->row(v[i]); 97 | if (has_bias) _rbias.row(i) = _bias->row(v[i]); 98 | } 99 | } 100 | 101 | protected: 102 | std::shared_ptr _weight; 103 | std::shared_ptr _bias; 104 | Eigen::RowMajorMat _rweight; 105 | Eigen::RowMajorMat _rbias; 106 | size_t _wrows, _wcols; 107 | size_t _rwrows; 108 | }; 109 | 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /include/onmt/nn/LogSoftMax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class LogSoftMax: public Module 12 | { 13 | public: 14 | LogSoftMax() 15 | : Module("nn.LogSoftMax") 16 | { 17 | } 18 | 19 | LogSoftMax(const LogSoftMax& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new LogSoftMax(*this); 27 | } 28 | 29 | /* row-wise log-softmax, shifted by the row max for numerical stability */ void forward_impl(const MatFwd& input) override 30 | { 31 | this->_output.resize(input.rows(), input.cols()); 32 | 33 |
for (int i = 0; i < input.rows(); ++i) 34 | { 35 | auto v = input.row(i); 36 | double max = v.maxCoeff(); 37 | double log_z = log((v.array() - max).exp().sum()) + max; 38 | this->_output.row(i) = v.array() - log_z; 39 | } 40 | } 41 | }; 42 | 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /include/onmt/nn/LookupTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "onmt/nn/Module.h" 5 | #include "onmt/th/Obj.h" 6 | #include "onmt/StorageLoader.h" 7 | 8 | namespace onmt 9 | { 10 | namespace nn 11 | { 12 | 13 | template 14 | class LookupTable: public Module 15 | { 16 | public: 17 | LookupTable(th::Table* data) 18 | : Module("nn.LookupTable") 19 | , _weight(new MatEmb(StorageLoader::get_matrix(data, "weight"))) 20 | { 21 | } 22 | 23 | LookupTable(const LookupTable& other) 24 | : Module(other) 25 | , _weight(other._weight) 26 | { 27 | } 28 | 29 | Module* clone(const ModuleFactory*) const override 30 | { 31 | return new LookupTable(*this); 32 | } 33 | 34 | void forward_impl(const MatFwd& input) override 35 | { 36 | this->_output.resize(input.rows(), _weight->cols()); 37 | 38 | for (size_t i = 0; i < input.batches(); ++i) 39 | this->_output.row(i).noalias() = _weight->row(input(i, 0)); 40 | } 41 | 42 | private: 43 | std::shared_ptr _weight; 44 | }; 45 | 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /include/onmt/nn/MM.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class MM: public Module 12 | { 13 | public: 14 | MM(th::Table* data) 15 | : Module("nn.MM") 16 | , _transA(get_boolean(data, "transA")) 17 | , _transB(get_boolean(data, "transB")) 18 | { 19 | } 20 | 21 | MM(const MM& other) 22 | : Module(other) 23 | ,
_transA(other._transA) 24 | , _transB(other._transB) 25 | { 26 | } 27 | 28 | Module* clone(const ModuleFactory*) const override 29 | { 30 | return new MM(*this); 31 | } 32 | 33 | /* batched matrix product: for each batch i, output[i] = op(A_i) * op(B_i), op being an optional transpose */ void forward_impl(const std::vector& inputs) override 34 | { 35 | /* each product is (transA ? cols(A) : rows(A)) x (transB ? rows(B) : cols(B)): size the output from the effective, post-transpose shapes */ const auto out_rows = _transA ? inputs[0].virtualCols() : inputs[0].virtualRows(); const auto out_cols = _transB ? inputs[1].virtualRows() : inputs[1].virtualCols(); this->_output.resize(inputs[0].rows(), out_rows * out_cols); 36 | this->_output.setHiddenDim(out_rows); 37 | 38 | for (size_t i = 0; i < inputs[0].batches(); ++i) 39 | { 40 | MatFwd m1 = inputs[0].batch(i); 41 | MatFwd m2 = inputs[1].batch(i); 42 | 43 | if (_transA) 44 | m1.transposeInPlace(); 45 | if (_transB) 46 | m2.transposeInPlace(); 47 | 48 | MatFwd res = m1 * m2; 49 | 50 | this->_output.assign(i, res); 51 | } 52 | } 53 | 54 | private: 55 | bool _transA; 56 | bool _transB; 57 | }; 58 | 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /include/onmt/nn/Module.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "onmt/Profiler.h" 8 | 9 | namespace onmt 10 | { 11 | namespace nn 12 | { 13 | 14 | template 15 | class ModuleFactory; 16 | 17 | template 18 | class Module 19 | { 20 | public: 21 | Module(const std::string& name) 22 | : _name(name) 23 | , _profile(true) 24 | , _profiler(nullptr) 25 | , _outputs(1) 26 | , _output(_outputs.front()) 27 | { 28 | } 29 | 30 | Module(const std::string& name, bool profile) 31 | : _name(name) 32 | , _profile(profile) 33 | , _profiler(nullptr) 34 | , _outputs(1) 35 | , _output(_outputs.front()) 36 | { 37 | } 38 | 39 | Module(const Module& other) 40 | : _name(other._name) 41 | , _custom_name(other._custom_name) 42 | , _profile(other._profile) 43 | , _block(other._block) 44 | , _profiler(nullptr) 45 | , _outputs(1) 46 | , _output(_outputs.front()) 47 | { 48 | } 49 | 50 | virtual ~Module() 51 | { 52 | } 53 | 54 | virtual Module* clone(const ModuleFactory* factory) const = 0; 55 | 56 | const
std::vector& forward(const std::vector& inputs) 57 | { 58 | if (_profile && _profiler) 59 | _profiler->start(); 60 | 61 | forward_impl(inputs); 62 | 63 | if (_profile && _profiler) 64 | _profiler->stop(_block + (!_custom_name.empty() ? _custom_name : _name)); 65 | 66 | if (_post_process) 67 | _post_process(_outputs); 68 | 69 | return _outputs; 70 | } 71 | 72 | const MatFwd& forward_one(const MatFwd& input) 73 | { 74 | return forward(std::vector(1, input))[0]; 75 | } 76 | 77 | virtual void forward_impl(const std::vector& inputs) 78 | { 79 | forward_impl(inputs.front()); 80 | } 81 | 82 | virtual void forward_impl(const MatFwd& input) 83 | { 84 | _output = input; 85 | } 86 | 87 | virtual Module* find(const std::string& custom_name) 88 | { 89 | if (_custom_name == custom_name) 90 | return this; 91 | 92 | return nullptr; 93 | } 94 | 95 | virtual void* apply(void* (*func)(Module*, void*), void* data) 96 | { 97 | return func(this, data); 98 | } 99 | 100 | std::function&)>& post_process_fun() 101 | { 102 | return _post_process; 103 | } 104 | 105 | const std::vector& get_outputs() const 106 | { 107 | return _outputs; 108 | } 109 | 110 | const std::string& get_name() const 111 | { 112 | return _name; 113 | } 114 | 115 | const std::string& get_custom_name() const 116 | { 117 | return _custom_name; 118 | } 119 | 120 | virtual std::string get_details() const 121 | { 122 | return ""; 123 | } 124 | 125 | void set_custom_name(const std::string& custom_name) 126 | { 127 | _custom_name = custom_name; 128 | } 129 | 130 | void set_profiler(Profiler* profiler) 131 | { 132 | _profiler = profiler; 133 | } 134 | 135 | void set_block(const char* s) 136 | { 137 | _block = std::string(s) + ":"; 138 | } 139 | 140 | protected: 141 | std::string _name; 142 | std::string _custom_name; 143 | bool _profile; 144 | std::string _block; 145 | Profiler* _profiler; 146 | 147 | std::vector _outputs; 148 | MatFwd& _output; /* NOTE: bound to _outputs.front() by every constructor; any operation that reallocates _outputs (a resize or assignment beyond its capacity) leaves this reference dangling, so modules that grow _outputs must not use _output afterwards */ 149 | 150 | std::function&)> _post_process; 151 | }; 152 | 153 | } 154 | } 155 |
-------------------------------------------------------------------------------- /include/onmt/nn/ModuleFactory.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | #include "onmt/th/Obj.h" 5 | 6 | #ifdef WITH_CUDA 7 | # include "onmt/cuda/Utils.h" 8 | #endif 9 | 10 | namespace onmt 11 | { 12 | namespace nn 13 | { 14 | 15 | template 16 | class ModuleFactory 17 | { 18 | public: 19 | ModuleFactory(Profiler& profiler, bool cuda, bool qlinear); 20 | ModuleFactory(const ModuleFactory& other); 21 | ~ModuleFactory(); 22 | 23 | size_t build(th::Class* obj); 24 | Module* get_module(size_t id) const; 25 | void set_profiler(Profiler& profiler); 26 | 27 | private: 28 | std::vector*> _storage; 29 | Profiler* _profiler; 30 | bool _cuda; 31 | bool _qlinear; 32 | #ifdef WITH_CUDA 33 | cublasHandle_t _handle; 34 | #endif 35 | }; 36 | 37 | } 38 | } 39 | 40 | #include "onmt/nn/ModuleFactory.hxx" 41 | -------------------------------------------------------------------------------- /include/onmt/nn/ModuleFactory.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "onmt/nn/Linear.h" 6 | #include "onmt/nn/LookupTable.h" 7 | #include "onmt/nn/CAddTable.h" 8 | #include "onmt/nn/CMulTable.h" 9 | #include "onmt/nn/Sigmoid.h" 10 | #include "onmt/nn/Tanh.h" 11 | #include "onmt/nn/SplitTable.h" 12 | #include "onmt/nn/JoinTable.h" 13 | #include "onmt/nn/SelectTable.h" 14 | #include "onmt/nn/Reshape.h" 15 | #include "onmt/nn/Replicate.h" 16 | #include "onmt/nn/Identity.h" 17 | #include "onmt/nn/SoftMax.h" 18 | #include "onmt/nn/LogSoftMax.h" 19 | #include "onmt/nn/MM.h" 20 | #include "onmt/nn/Sum.h" 21 | #include "onmt/nn/Squeeze.h" 22 | #include "onmt/nn/MulConstant.h" 23 | 24 | #include "onmt/nn/Sequential.h" 25 | #include "onmt/nn/ParallelTable.h" 26 | #include "onmt/nn/ConcatTable.h" 27 | 28 | #include "onmt/nn/Graph.h" 29 
| 30 | #ifdef WITH_CUDA 31 | # include "onmt/nn/cuLinear.h" 32 | #endif 33 | 34 | #ifdef WITH_QLINEAR 35 | # include "onmt/nn/qLinear.h" 36 | #endif 37 | 38 | namespace onmt 39 | { 40 | namespace nn 41 | { 42 | 43 | template 44 | ModuleFactory::ModuleFactory(Profiler& profiler, bool cuda, bool qlinear) 45 | : _profiler(&profiler) 46 | , _cuda(cuda) 47 | , _qlinear(qlinear) 48 | { 49 | if (_cuda) 50 | { 51 | #ifdef WITH_CUDA 52 | CUBLAS_CHECK(cublasCreate(&_handle)); 53 | #else 54 | throw std::runtime_error("CTranslate was not compiled with CUDA support"); 55 | #endif 56 | } 57 | if (_qlinear) 58 | { 59 | #ifndef WITH_QLINEAR 60 | throw std::runtime_error("CTranslate was not compiled with QLINEAR support"); 61 | #endif 62 | } 63 | } 64 | 65 | template 66 | ModuleFactory::ModuleFactory(const ModuleFactory& other) 67 | : _profiler(nullptr) 68 | , _cuda(other._cuda) 69 | , _qlinear(other._qlinear) 70 | { 71 | if (_cuda) 72 | { 73 | #ifdef WITH_CUDA 74 | CUBLAS_CHECK(cublasCreate(&_handle)); 75 | #else 76 | throw std::runtime_error("CTranslate was not compiled with CUDA support"); 77 | #endif 78 | } 79 | if (_qlinear) 80 | { 81 | #ifndef WITH_QLINEAR 82 | throw std::runtime_error("CTranslate was not compiled with QLINEAR support"); 83 | #endif 84 | } 85 | 86 | for (const auto& mod: other._storage) 87 | { 88 | _storage.push_back(mod->clone(this)); 89 | #ifdef WITH_CUDA 90 | if (_cuda) 91 | { 92 | auto cul = dynamic_cast*>(_storage.back()); 93 | if (cul) 94 | cul->set_handle(&_handle); 95 | } 96 | #endif 97 | } 98 | } 99 | 100 | template 101 | ModuleFactory::~ModuleFactory() 102 | { 103 | #ifdef WITH_CUDA 104 | if (_cuda) 105 | CUBLAS_CHECK(cublasDestroy(_handle)); 106 | #endif 107 | 108 | for (const auto& mod: _storage) 109 | delete mod; 110 | } 111 | 112 | template 113 | size_t 114 | ModuleFactory::build(th::Class* obj) 115 | { 116 | std::string name = obj->get_classname(); 117 | auto data = dynamic_cast(obj->get_data()); 118 | 119 | Module* mod = nullptr; 120 | 121 | 
if (name == "nn.Linear") 122 | { 123 | #ifdef WITH_CUDA 124 | if (_cuda) 125 | mod = new cuLinear(data, &_handle); 126 | else 127 | #endif 128 | #ifdef WITH_QLINEAR 129 | if (_qlinear) 130 | mod = new qLinear(data); 131 | else 132 | #endif 133 | mod = new Linear(data); 134 | } 135 | else if (name == "nn.LookupTable") 136 | mod = new LookupTable(data); 137 | else if (name == "nn.CAddTable") 138 | mod = new CAddTable(); 139 | else if (name == "nn.CMulTable") 140 | mod = new CMulTable(); 141 | else if (name == "nn.Sigmoid") 142 | mod = new Sigmoid(); 143 | else if (name == "nn.Tanh") 144 | mod = new Tanh(); 145 | else if (name == "nn.SplitTable") 146 | mod = new SplitTable(); 147 | else if (name == "nn.JoinTable") 148 | mod = new JoinTable(); 149 | else if (name == "nn.SelectTable") 150 | mod = new SelectTable(data); 151 | else if (name == "nn.Reshape") 152 | mod = new Reshape(data); 153 | else if (name == "nn.Replicate") 154 | mod = new Replicate(data); 155 | else if (name == "nn.SoftMax") 156 | mod = new SoftMax(); 157 | else if (name == "nn.LogSoftMax") 158 | mod = new LogSoftMax(); 159 | else if (name == "nn.MM") 160 | mod = new MM(data); 161 | else if (name == "nn.Sum") 162 | mod = new Sum(data); 163 | else if (name == "nn.Squeeze") 164 | mod = new Squeeze(data); 165 | else if (name == "nn.MulConstant") 166 | mod = new MulConstant(data); 167 | else if (name == "nn.Sequential") 168 | mod = new Sequential(data, *this); 169 | else if (name == "nn.ConcatTable") 170 | mod = new ConcatTable(data, *this); 171 | else if (name == "nn.ParallelTable") 172 | mod = new ParallelTable(data, *this); 173 | else if (name == "nn.Identity" || name == "nn.Dropout") 174 | mod = new Identity(); 175 | else if (name == "nn.gModule") 176 | mod = new Graph(obj, name, *this); 177 | else 178 | { 179 | auto net = th::get_field(data, "net"); 180 | 181 | if (net) 182 | return build(net); 183 | else 184 | throw std::runtime_error(name + " is not supported yet"); 185 | } 186 | 187 | auto 
custom_name = th::get_field(data, "name"); 188 | 189 | if (custom_name) 190 | mod->set_custom_name(custom_name->get_value()); 191 | 192 | mod->set_profiler(_profiler); 193 | 194 | size_t id = _storage.size(); 195 | _storage.push_back(mod); 196 | 197 | return id; 198 | } 199 | 200 | template 201 | Module* 202 | ModuleFactory::get_module(size_t id) const 203 | { 204 | return _storage.at(id); 205 | } 206 | 207 | template 208 | void ModuleFactory::set_profiler(Profiler& profiler) 209 | { 210 | _profiler = &profiler; 211 | for (auto& mod : _storage) 212 | mod->set_profiler(_profiler); 213 | } 214 | 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /include/onmt/nn/MulConstant.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class MulConstant: public Module 12 | { 13 | public: 14 | MulConstant(th::Table* data) 15 | : Module("nn.MulConstant") 16 | , _scalar(th::get_scalar(data, "constant_scalar")) 17 | { 18 | } 19 | 20 | MulConstant(const MulConstant& other) 21 | : Module(other) 22 | , _scalar(other._scalar) 23 | { 24 | } 25 | 26 | Module* clone(const ModuleFactory*) const override 27 | { 28 | return new MulConstant(*this); 29 | } 30 | 31 | void forward_impl(const MatFwd& input) 32 | { 33 | this->_output.noalias() = input * _scalar; 34 | } 35 | 36 | private: 37 | ModelT _scalar; 38 | }; 39 | 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /include/onmt/nn/Node.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "onmt/nn/ModuleFactory.h" 10 | 11 | namespace onmt 12 | { 13 | namespace nn 14 | { 15 | 16 | template 17 | class Node 18 | { 19 | public: 20 | Node(size_t id, const 
ModuleFactory& factory, 21 | std::unordered_map>& node_map) 22 | : _visited(false) 23 | , _id(id) 24 | , _select_index(-1) 25 | , _expected_inputs(0) 26 | , _children() 27 | , _factory(factory) 28 | , _node_map(node_map) 29 | { 30 | } 31 | 32 | Node(const Node& other, const ModuleFactory& factory, 33 | std::unordered_map>& node_map) 34 | : _visited(false) 35 | , _id(other._id) 36 | , _select_index(other._select_index) 37 | , _expected_inputs(other._index_to_input.size()) 38 | , _module_id(other._module_id) 39 | , _children(other._children) 40 | , _input_to_index(other._input_to_index) 41 | , _index_to_input(other._index_to_input) 42 | , _module_inputs(other._index_to_input.size()) 43 | , _factory(factory) 44 | , _node_map(node_map) 45 | { 46 | } 47 | 48 | void set_id(size_t id) 49 | { 50 | _id = id; 51 | } 52 | 53 | void set_select_index(int index) 54 | { 55 | _select_index = index; 56 | } 57 | 58 | void add_child(size_t child) 59 | { 60 | _children.push_back(child); 61 | } 62 | 63 | void set_module_id(size_t module_id) 64 | { 65 | _module_id = module_id; 66 | } 67 | 68 | void add_input_index(size_t id) 69 | { 70 | size_t index = _index_to_input.size(); 71 | _index_to_input.push_back(id); 72 | _input_to_index[id] = index; 73 | 74 | _module_inputs.resize(_module_inputs.size() + 1); 75 | _expected_inputs += 1; 76 | } 77 | 78 | Module* find(const std::string& custom_name) 79 | { 80 | if (_module_id != std::numeric_limits::max()) 81 | { 82 | auto mod = _factory.get_module(_module_id); 83 | auto res = mod->find(custom_name); 84 | if (res) 85 | return res; 86 | } 87 | 88 | for (auto child: _children) 89 | { 90 | auto& node = _node_map.at(child); 91 | auto res = node.find(custom_name); 92 | if (res) 93 | return res; 94 | } 95 | 96 | return nullptr; 97 | } 98 | 99 | void* apply(void* (*func)(Module*, void*), void* data) 100 | { 101 | if (_visited) 102 | return nullptr; 103 | 104 | _visited = true; 105 | 106 | if (_module_id != std::numeric_limits::max()) 107 | { 108 | auto 
mod = _factory.get_module(_module_id); 109 | mod->apply(func, data); 110 | } 111 | 112 | for (auto child: _children) 113 | { 114 | auto& node = _node_map.at(child); 115 | node.apply(func, data); 116 | } 117 | 118 | return nullptr; 119 | } 120 | 121 | void to_dot(std::ostream& os) const 122 | { 123 | if (_visited) 124 | return; 125 | 126 | _visited = true; 127 | 128 | os << _id << " [label=\"Node" << _id; 129 | if (_module_id != std::numeric_limits::max()) 130 | { 131 | auto mod = _factory.get_module(_module_id); 132 | os << "\nmodule = " << mod->get_name(); 133 | std::string details = mod->get_details(); 134 | if (!details.empty()) 135 | os << " " << details; 136 | } 137 | if (_select_index >= 0) 138 | { 139 | os << "\ninput = { }"; 140 | os << "\nselectindex = " << _select_index; 141 | } 142 | if (_index_to_input.size() > 1) 143 | { 144 | os << "\nmapindex = {"; 145 | for (auto it = _index_to_input.begin(); it != _index_to_input.end(); it++) 146 | { 147 | if (it != _index_to_input.begin()) 148 | os << ","; 149 | os << "Node" << *it; 150 | } 151 | os << "}"; 152 | } 153 | os << "\"];" << std::endl; 154 | 155 | for (auto child: _children) 156 | { 157 | auto& node = _node_map.at(child); 158 | os << _id << " -> " << node._id << ";" << std::endl; 159 | node.to_dot(os); 160 | } 161 | } 162 | 163 | void forward(const std::vector& node_inputs, 164 | std::vector& final_output, 165 | const Node* from) 166 | { 167 | if (_select_index >= 0) // Pick input matrix from table. 168 | _module_inputs[0] = node_inputs[_select_index]; 169 | else if (_index_to_input.size() > 1) // Map matrix into input table. 170 | _module_inputs[_input_to_index.at(from->_id)] = node_inputs[0]; 171 | else // node_inputs is also input of the module. 172 | _module_inputs = node_inputs; 173 | 174 | _expected_inputs--; 175 | 176 | // Only forward into the module when all inputs were forwarded into the node. 
177 | if (_expected_inputs <= 0) 178 | { 179 | Module* mod = nullptr; 180 | if (_module_id != std::numeric_limits::max()) 181 | mod = _factory.get_module(_module_id); 182 | if (mod && mod->get_name() != "nn.Identity") 183 | _output = mod->forward(_module_inputs); 184 | else 185 | _output = _module_inputs; 186 | 187 | // Reset input table and count to make the node reentrant. 188 | _expected_inputs = _index_to_input.size(); 189 | _module_inputs.resize(_expected_inputs); 190 | 191 | if (_children.empty()) // No child == graph output. 192 | final_output = _output; 193 | else 194 | { 195 | // Forward output into every chilren node 196 | for (auto child: _children) 197 | { 198 | auto& node = _node_map.at(child); 199 | node.forward(_output, final_output, this); 200 | } 201 | } 202 | } 203 | } 204 | 205 | void set_unvisited() 206 | { 207 | _visited = false; 208 | } 209 | 210 | private: 211 | mutable bool _visited; 212 | size_t _id; 213 | int _select_index; 214 | int _expected_inputs; 215 | size_t _module_id; 216 | std::vector _children; 217 | std::map _input_to_index; 218 | std::vector _index_to_input; 219 | std::vector _module_inputs; 220 | std::vector _output; 221 | const ModuleFactory& _factory; 222 | std::unordered_map>& _node_map; 223 | }; 224 | 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /include/onmt/nn/ParallelTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Container.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class ParallelTable: public Container 12 | { 13 | public: 14 | ParallelTable(th::Table* data, ModuleFactory& factory) 15 | : Container("nn.ParallelTable", data, factory) 16 | { 17 | } 18 | 19 | ParallelTable(const ParallelTable& other, 20 | const ModuleFactory& factory) 21 | : Container(other, factory) 22 | { 23 | } 24 | 25 | Module* clone(const ModuleFactory* factory) const override 
26 | { 27 | return new ParallelTable(*this, *factory); 28 | } 29 | 30 | void forward_impl(const std::vector& inputs) override 31 | { 32 | this->_outputs.resize(this->_sequence->size()); 33 | 34 | for (size_t i = 0; i < this->_sequence->size(); ++i) 35 | { 36 | auto mod = this->_factory.get_module((*(this->_sequence))[i]); 37 | if (inputs.size() == 1 && this->_sequence->size() > 1) 38 | { 39 | // This is a special case when the inputs table is actually bundled in a single matrix. 40 | // The dimensions that do not have a corresponding module are all forwarded to the 41 | // last module in the sequence. 42 | if (i == this->_sequence->size() - 1) 43 | { 44 | std::vector in; 45 | for (int j = i; j < inputs[0].cols(); ++j) 46 | in.push_back(inputs[0].col(j)); 47 | 48 | auto res = mod->forward(in); 49 | 50 | for (size_t j = i; j - i < res.size(); ++j) 51 | this->_outputs[j] = res[j - i]; 52 | } 53 | else 54 | { 55 | this->_outputs[i] = mod->forward_one(inputs[0].col(i)); 56 | } 57 | } 58 | else 59 | { 60 | this->_outputs[i] = mod->forward_one(inputs[i]); 61 | } 62 | } 63 | } 64 | }; 65 | 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /include/onmt/nn/Replicate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | #include "onmt/th/Obj.h" 5 | #include "onmt/th/Utils.h" 6 | 7 | namespace onmt 8 | { 9 | namespace nn 10 | { 11 | 12 | template 13 | class Replicate: public Module 14 | { 15 | public: 16 | Replicate(th::Table* data) 17 | : Module("nn.Replicate") 18 | , _dimension(th::get_number(data, "dim")) 19 | , _nfeatures(th::get_number(data, "nfeatures")) 20 | { 21 | } 22 | 23 | Replicate(const Replicate& other) 24 | : Module(other) 25 | , _dimension(other._dimension) 26 | , _nfeatures(other._nfeatures) 27 | { 28 | } 29 | 30 | Module* clone(const ModuleFactory*) const override 31 | { 32 | return new Replicate(*this); 33 | } 34 | 35 | 
void forward_impl(const MatFwd& input) override 36 | { 37 | this->_output = input; 38 | 39 | if (_dimension == 2) 40 | this->_output.setHiddenDim(_nfeatures); 41 | else if (_dimension == 3) 42 | this->_output.setHiddenDim(input.cols()); 43 | 44 | if (_nfeatures > 1) 45 | this->_output = input.replicate(1, _nfeatures); 46 | } 47 | 48 | private: 49 | int _dimension; 50 | int _nfeatures; 51 | }; 52 | 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /include/onmt/nn/Reshape.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class Reshape: public Module 12 | { 13 | public: 14 | Reshape(th::Table* data) 15 | : Module("nn.Reshape") 16 | , _dims(th::get_storage_as_vector(data, "size")) 17 | { 18 | } 19 | 20 | Reshape(const Reshape& other) 21 | : Module(other) 22 | , _dims(other._dims) 23 | { 24 | } 25 | 26 | Module* clone(const ModuleFactory*) const override 27 | { 28 | return new Reshape(*this); 29 | } 30 | 31 | void forward_impl(const std::vector& inputs) override 32 | { 33 | // also do the SplitTable 34 | long leading_dim = _dims[0]; 35 | 36 | this->_outputs.resize(leading_dim); 37 | 38 | for (long i = 0; i < leading_dim; ++i) 39 | { 40 | this->_outputs[i] = inputs[0].block(0, 41 | i * (inputs[0].cols() / leading_dim), 42 | inputs[0].rows(), 43 | inputs[0].cols() / leading_dim); 44 | } 45 | } 46 | 47 | private: 48 | std::vector _dims; 49 | }; 50 | 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /include/onmt/nn/SelectTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | #include "onmt/th/Obj.h" 5 | #include "onmt/th/Utils.h" 6 | 7 | namespace onmt 8 | { 9 | namespace nn 10 | { 11 | 12 | template 13 | class SelectTable: 
public Module 14 | { 15 | public: 16 | SelectTable(th::Table* data) 17 | : Module("nn.SelectTable") 18 | , _index(th::get_number(data, "index")) 19 | { 20 | } 21 | 22 | SelectTable(const SelectTable& other) 23 | : Module(other) 24 | , _index(other._index) 25 | { 26 | } 27 | 28 | Module* clone(const ModuleFactory*) const override 29 | { 30 | return new SelectTable(*this); 31 | } 32 | 33 | void forward_impl(const std::vector& inputs) override 34 | { 35 | int index; 36 | if (_index < 0) 37 | index = inputs.size() + _index; 38 | else 39 | index = _index - 1; // Lua is 1-indexed. 40 | 41 | this->_output = inputs[index]; 42 | } 43 | 44 | private: 45 | int _index; 46 | }; 47 | 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /include/onmt/nn/Sequential.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Container.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class Sequential: public Container 12 | { 13 | public: 14 | Sequential(th::Table* data, ModuleFactory& factory) 15 | : Container("nn.Sequential", data, factory) 16 | { 17 | } 18 | 19 | Sequential(const Sequential& other, 20 | const ModuleFactory& factory) 21 | : Container(other, factory) 22 | { 23 | } 24 | 25 | Module* clone(const ModuleFactory* factory) const override 26 | { 27 | return new Sequential(*this, *factory); 28 | } 29 | 30 | void forward_impl(const std::vector& inputs) override 31 | { 32 | if (this->_sequence->empty()) 33 | { 34 | this->_outputs = inputs; 35 | return; 36 | } 37 | 38 | auto it = this->_sequence->begin(); 39 | this->_outputs = this->_factory.get_module(*it)->forward(inputs); 40 | 41 | for (it++; it != this->_sequence->end(); it++) 42 | { 43 | this->_outputs = this->_factory.get_module(*it)->forward(this->_outputs); 44 | } 45 | } 46 | }; 47 | 48 | } 49 | } 50 | 
-------------------------------------------------------------------------------- /include/onmt/nn/Sigmoid.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class Sigmoid: public Module 12 | { 13 | public: 14 | Sigmoid() 15 | : Module("nn.Sigmoid") 16 | { 17 | } 18 | 19 | Sigmoid(const Sigmoid& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new Sigmoid(*this); 27 | } 28 | 29 | void forward_impl(const MatFwd& input) override 30 | { 31 | this->_output = (1.0 + (-input).array().exp()).inverse(); 32 | } 33 | }; 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /include/onmt/nn/SoftMax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class SoftMax: public Module 12 | { 13 | public: 14 | SoftMax() 15 | : Module("nn.SoftMax") 16 | { 17 | } 18 | 19 | SoftMax(const SoftMax& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new SoftMax(*this); 27 | } 28 | 29 | void forward_impl(const MatFwd& input) 30 | { 31 | this->_output.resizeLike(input); 32 | 33 | for (int i = 0; i < input.rows(); ++i) 34 | { 35 | auto v = input.row(i); 36 | double max = v.maxCoeff(); 37 | this->_output.row(i) = ((v.array() - (log((v.array() - max).exp().sum()) + max)).exp()); 38 | } 39 | } 40 | }; 41 | 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /include/onmt/nn/SplitTable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 
| { 7 | namespace nn 8 | { 9 | 10 | template 11 | class SplitTable: public Module 12 | { 13 | public: 14 | SplitTable() 15 | : Module("nn.SplitTable") 16 | { 17 | } 18 | 19 | SplitTable(const SplitTable& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new SplitTable(*this); 27 | } 28 | 29 | void forward_impl(const std::vector& inputs) override 30 | { 31 | // it is assumed that the previous reshape did the split 32 | this->_outputs = inputs; 33 | } 34 | }; 35 | 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /include/onmt/nn/Squeeze.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | #include "onmt/th/Obj.h" 5 | 6 | namespace onmt 7 | { 8 | namespace nn 9 | { 10 | 11 | template 12 | class Squeeze: public Module 13 | { 14 | public: 15 | Squeeze(th::Table* data) 16 | : Module("nn.Squeeze") 17 | , _dimension(get_number(data, "dimension")) 18 | { 19 | } 20 | 21 | Squeeze(const Squeeze& other) 22 | : Module(other) 23 | , _dimension(other._dimension) 24 | { 25 | } 26 | 27 | Module* clone(const ModuleFactory*) const override 28 | { 29 | return new Squeeze(*this); 30 | } 31 | 32 | void forward_impl(const MatFwd& input) override 33 | { 34 | this->_output = input; 35 | this->_output.squeeze(_dimension); 36 | } 37 | 38 | private: 39 | int _dimension; 40 | }; 41 | 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /include/onmt/nn/Sum.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | #include "onmt/th/Obj.h" 5 | 6 | namespace onmt 7 | { 8 | namespace nn 9 | { 10 | 11 | template 12 | class Sum: public Module 13 | { 14 | public: 15 | Sum(th::Table* data) 16 | : Module("nn.Sum") 17 | , _dimension(get_number(data, "dimension")) 18 | 
{ 19 | } 20 | 21 | Sum(const Sum& other) 22 | : Module(other) 23 | , _dimension(other._dimension) 24 | { 25 | } 26 | 27 | Module* clone(const ModuleFactory*) const override 28 | { 29 | return new Sum(*this); 30 | } 31 | 32 | void forward_impl(const MatFwd& input) override 33 | { 34 | this->_output = input.sum(_dimension); 35 | } 36 | 37 | private: 38 | int _dimension; 39 | }; 40 | 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /include/onmt/nn/Tanh.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Module.h" 4 | 5 | namespace onmt 6 | { 7 | namespace nn 8 | { 9 | 10 | template 11 | class Tanh: public Module 12 | { 13 | public: 14 | Tanh() 15 | : Module("nn.Tanh") 16 | { 17 | } 18 | 19 | Tanh(const Tanh& other) 20 | : Module(other) 21 | { 22 | } 23 | 24 | Module* clone(const ModuleFactory*) const override 25 | { 26 | return new Tanh(*this); 27 | } 28 | 29 | void forward_impl(const MatFwd& input) override 30 | { 31 | this->_output = input.array().tanh().matrix(); 32 | } 33 | }; 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /include/onmt/nn/cuLinear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/nn/Linear.h" 4 | #include "onmt/cuda/Utils.h" 5 | #include "onmt/cuda/Kernels.cuh" 6 | 7 | namespace onmt 8 | { 9 | namespace nn 10 | { 11 | 12 | template 13 | class cuLinear: public Linear 14 | { 15 | public: 16 | cuLinear(th::Table* data, cublasHandle_t* handle) 17 | : Linear(data) 18 | , _handle(handle) 19 | // cuBLAS works with col-major matrices. 20 | , _bias_device((this->_bias->rows() > 0) ? 
cuda::to_device(this->_bias->data(), this->_bias->rows()) : nullptr, 21 | [](float* p) { CUDA_CHECK(cudaFree(p)); }) 22 | , _weight_device(cuda::to_device(this->_weight->data(), this->_weight->cols(), this->_weight->rows()), 23 | [](float* p) { CUDA_CHECK(cudaFree(p)); }) 24 | , _input_device(nullptr) 25 | , _output_device(nullptr) 26 | , _expanded_bias_device(nullptr) 27 | , _allocated_batches(0) 28 | { 29 | } 30 | 31 | cuLinear(const cuLinear& other) 32 | : Linear(other) 33 | , _handle(other._handle) 34 | , _bias_device(other._bias_device) 35 | , _weight_device(other._weight_device) 36 | , _input_device(nullptr) 37 | , _output_device(nullptr) 38 | , _expanded_bias_device(nullptr) 39 | , _allocated_batches(0) 40 | { 41 | } 42 | 43 | ~cuLinear() 44 | { 45 | CUDA_CHECK(cudaFree(_input_device)); 46 | CUDA_CHECK(cudaFree(_output_device)); 47 | CUDA_CHECK(cudaFree(_expanded_bias_device)); 48 | } 49 | 50 | Module* clone(const ModuleFactory*) const override 51 | { 52 | return new cuLinear(*this); 53 | } 54 | 55 | void set_handle(cublasHandle_t* handle) 56 | { 57 | _handle = handle; 58 | } 59 | 60 | void forward_impl(const MatFwd& input) override 61 | { 62 | // See http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm 63 | 64 | const size_t batch_size = input.rows(); 65 | const int input_size = input.cols(); 66 | const int output_size = this->_weight->rows(); 67 | 68 | if (batch_size > _allocated_batches) 69 | this->realloc_device_buffers(batch_size); 70 | 71 | cuda::to_device(_input_device, input.data(), input_size, batch_size); 72 | 73 | float alpha = 1; 74 | float beta = 0; 75 | 76 | CUBLAS_CHECK(cublasSgemm(*_handle, 77 | CUBLAS_OP_T, CUBLAS_OP_N, 78 | output_size, batch_size, input_size, 79 | &alpha, 80 | _weight_device.get(), input_size, 81 | _input_device, input_size, 82 | &beta, 83 | _output_device, output_size)); 84 | 85 | if (_expanded_bias_device) 86 | cuda::kernels::add(_output_device, _expanded_bias_device, batch_size * output_size); 87 | 88 | 
this->_output.resize(batch_size, this->_weight->rows()); 89 | cuda::to_host(_output_device, this->_output.data(), output_size, batch_size); 90 | } 91 | 92 | virtual void apply_subdictionary(const std::vector&) 93 | { 94 | throw std::runtime_error("subdictionary not implemented for cuLinear"); 95 | } 96 | 97 | private: 98 | void realloc_device_buffers(int num_batches) 99 | { 100 | CUDA_CHECK(cudaFree(_output_device)); 101 | CUDA_CHECK(cudaFree(_input_device)); 102 | CUDA_CHECK(cudaFree(_expanded_bias_device)); 103 | 104 | _output_device = cuda::to_device(this->_weight->rows(), num_batches); 105 | _input_device = cuda::to_device(this->_weight->cols(), num_batches); 106 | 107 | if (_bias_device) 108 | { 109 | _expanded_bias_device = cuda::to_device(this->_weight->rows(), num_batches); 110 | for (int i = 0; i < num_batches; ++i) 111 | CUDA_CHECK(cudaMemcpy(_expanded_bias_device + i * this->_weight->rows(), 112 | _bias_device.get(), 113 | this->_weight->rows() * sizeof (float), 114 | cudaMemcpyDeviceToDevice)); 115 | 116 | } 117 | 118 | _allocated_batches = num_batches; 119 | } 120 | 121 | cublasHandle_t* _handle; 122 | 123 | std::shared_ptr _bias_device; 124 | std::shared_ptr _weight_device; 125 | 126 | // Preallocate device buffers. 
127 | float* _input_device; 128 | float* _output_device; 129 | float* _expanded_bias_device; 130 | size_t _allocated_batches; 131 | }; 132 | 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /include/onmt/nn/qLinear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "onmt/nn/Linear.h" 6 | #include "onmt/th/Obj.h" 7 | #include "onmt/Utils.h" 8 | #include "onmt/simd/MatrixMult.h" 9 | 10 | namespace onmt 11 | { 12 | namespace nn 13 | { 14 | 15 | template 16 | class qLinear: public Linear 17 | { 18 | public: 19 | qLinear(th::Table* data) 20 | : Linear(data) 21 | , _quant_weight_buffer(_malloc_align(_quant_weight, this->_wrows * this->_wcols / SIMD_VSIZE), 22 | [](void* p) { std::free(p); }) 23 | , _quant_input_buffer(nullptr) 24 | { 25 | // Quantize the weight - ncols=width is supposed to be multiple of SIMD_VSIZE 26 | if (this->_wcols % SIMD_VSIZE) 27 | throw std::runtime_error("Weight matrix width should be multiple of 8/16 for qLinear"); 28 | simd::Quantize(this->_weight->data(), _quant_weight, this->_wrows, this->_wcols); 29 | } 30 | 31 | qLinear(const qLinear& other) 32 | : Linear(other) 33 | , _quant_weight_buffer(other._quant_weight_buffer) 34 | , _quant_input_buffer(nullptr) 35 | , _quant_weight(other._quant_weight) 36 | { 37 | } 38 | 39 | virtual ~qLinear() 40 | { 41 | std::free(_quant_input_buffer); 42 | } 43 | 44 | Module* clone(const ModuleFactory*) const override 45 | { 46 | return new qLinear(*this); 47 | } 48 | 49 | /* aligned allocation method - in c++17 we have aligned_alloc that we can use */ 50 | void* _malloc_align(SIMD_TYPE *&data, size_t size) 51 | { 52 | return _realloc_align(nullptr, data, size); 53 | } 54 | 55 | void* _realloc_align(void *buffer, SIMD_TYPE *&data, size_t size) 56 | { 57 | size_t buf_size = (size + 1) * sizeof(SIMD_TYPE); 58 | void* p = std::realloc(buffer, buf_size); 59 | if (!p) 60 | 
{ 61 | std::free(buffer); 62 | throw std::runtime_error("Cannot allocate memory"); 63 | } 64 | void* ptr = p; 65 | align(sizeof(SIMD_TYPE), size * sizeof(SIMD_TYPE), ptr, buf_size); 66 | data = reinterpret_cast(ptr); 67 | return p; 68 | } 69 | 70 | virtual void forward_impl(const MatFwd& input) override 71 | { 72 | if (this->_rwrows) 73 | this->_output.resize(input.rows(), this->_rwrows); 74 | else 75 | this->_output.resize(input.rows(), this->_wrows); 76 | 77 | /* quantize the input */ 78 | _quant_input_buffer = _realloc_align(_quant_input_buffer, _quant_input, input.rows() * input.cols() / SIMD_VSIZE); 79 | simd::Quantize(input.data(), _quant_input, input.rows(), input.cols()); 80 | 81 | simd::MatrixMult(_quant_input, _quant_weight, this->_output.data(), 82 | input.rows(), (this->_rwrows ? this->_rwrows : this->_wrows), this->_wcols, 83 | _subdict); 84 | 85 | /* add bias */ 86 | if (this->_bias->rows() > 0) 87 | { 88 | if (this->_rwrows) 89 | for (int i = 0; i < input.rows(); ++i) 90 | this->_output.row(i).noalias() += this->_rbias.transpose(); 91 | else 92 | for (int i = 0; i < input.rows(); ++i) 93 | this->_output.row(i).noalias() += this->_bias->transpose(); 94 | } 95 | } 96 | 97 | /* reduce a linear weigth matrix to a given vocabulary */ 98 | virtual void apply_subdictionary(const std::vector& v) override 99 | { 100 | this->_rwrows = v.size(); 101 | _subdict = v; 102 | this->_rbias.resize(v.size(), 1); 103 | /* adjust bias */ 104 | for (size_t i = 0; i < v.size(); i++) { 105 | this->_rbias.row(i) = this->_bias->row(v[i]); 106 | } 107 | } 108 | 109 | protected: 110 | std::shared_ptr _quant_weight_buffer; 111 | void* _quant_input_buffer; 112 | SIMD_TYPE* _quant_weight; 113 | SIMD_TYPE* _quant_input; 114 | std::vector _subdict; 115 | }; 116 | 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /include/onmt/onmt.h: -------------------------------------------------------------------------------- 1 | #pragma once 
2 | 3 | #include "onmt/TranslatorFactory.h" 4 | #include "onmt/Threads.h" 5 | #include "onmt/Tokenizer.h" 6 | -------------------------------------------------------------------------------- /include/onmt/simd/MatrixMult.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #ifdef SIMD_SSE 10 | # define SIMD_TYPE __m128i 11 | # define SIMD_VSIZE 8 12 | #elif SIMD_AVX2 13 | # define SIMD_TYPE __m256i 14 | # define SIMD_VSIZE 16 15 | #elif SIMD_AVX512 16 | # define SIMD_TYPE __m512i 17 | # define SIMD_VSIZE 32 18 | #else 19 | # error "no simd type defined" 20 | #endif 21 | 22 | namespace onmt 23 | { 24 | namespace simd 25 | { 26 | 27 | // We quantize with 10 bits of precision. This works well "universally". 28 | // See the top of SSE_MatricMult.cc for more info on why. 29 | const float quant_mult = 1000.0; 30 | 31 | // If we quantize to n bits and then multiple the values together, the result will be quantized to n^2 bits. 32 | // So we must divide by 1.0/(n^2) to get back the original value. 
33 | const float unquant_mult = 1.0 / (quant_mult * quant_mult); 34 | 35 | void Quantize(const float * input, 36 | SIMD_TYPE * output, 37 | int num_rows, 38 | int width); 39 | 40 | void MatrixMult(const SIMD_TYPE * A, 41 | const SIMD_TYPE * B, 42 | float * C, 43 | int num_A_rows, 44 | int num_B_rows, 45 | int width, 46 | const std::vector & subdict); 47 | 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /include/onmt/th/Env.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "onmt/onmt_export.h" 7 | #include "onmt/th/Obj.h" 8 | 9 | namespace onmt 10 | { 11 | namespace th 12 | { 13 | 14 | class ONMT_EXPORT Env 15 | { 16 | public: 17 | ~Env(); 18 | 19 | Obj* get_object(int index); 20 | void set_object(Obj* obj, int index = -1); 21 | 22 | private: 23 | std::map _idx_obj; 24 | std::vector _list_obj; 25 | }; 26 | 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /include/onmt/th/Obj.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "onmt/onmt_export.h" 10 | #include "TH/THDiskFile.h" 11 | 12 | namespace onmt 13 | { 14 | namespace th 15 | { 16 | 17 | class Env; 18 | 19 | static const int dfLongSize = 8; 20 | 21 | enum class ObjType 22 | { 23 | NIL = 0, 24 | NUMBER = 1, 25 | STRING = 2, 26 | TABLE = 3, 27 | TORCH = 4, 28 | BOOLEAN = 5, 29 | FUNCTION = 6, 30 | RECUR_FUNCTION = 8, 31 | LEGACY_RECUR_FUNCTION = 7 32 | }; 33 | 34 | 35 | class ONMT_EXPORT Obj 36 | { 37 | public: 38 | Obj(ObjType type); 39 | virtual ~Obj() = default; 40 | virtual void read(THFile*, Env& Env) = 0; 41 | 42 | inline ObjType type() const 43 | { 44 | return _type; 45 | } 46 | private: 47 | const ObjType _type; 48 | }; 49 | 50 | 51 | Obj* read_obj(THFile *tf, Env &Env); 52 
| 53 | 54 | class ONMT_EXPORT Nil: public Obj 55 | { 56 | public: 57 | Nil(); 58 | Nil(THFile*, Env& Env); 59 | void read(THFile*, Env& Env); 60 | }; 61 | 62 | 63 | class ONMT_EXPORT Number: public Obj 64 | { 65 | public: 66 | Number(); 67 | Number(THFile*, Env& Env); 68 | void read(THFile*, Env& Env); 69 | double get_value() const; 70 | private: 71 | double _value; 72 | }; 73 | 74 | 75 | class ONMT_EXPORT Boolean: public Obj 76 | { 77 | public: 78 | Boolean(); 79 | Boolean(THFile*, Env& Env); 80 | void read(THFile*, Env& Env); 81 | bool get_value() const; 82 | private: 83 | bool _value; 84 | }; 85 | 86 | 87 | class ONMT_EXPORT String: public Obj 88 | { 89 | public: 90 | String(); 91 | String(THFile*, Env& Env); 92 | void read(THFile*, Env& Env); 93 | const std::string& get_value() const; 94 | private: 95 | std::string _value; 96 | }; 97 | 98 | 99 | class ONMT_EXPORT Table: public Obj { 100 | public: 101 | enum class TableType 102 | { 103 | None, 104 | Array, 105 | Object, 106 | Map 107 | }; 108 | 109 | Table(); 110 | Table(THFile*, Env& Env); 111 | void read(THFile*, Env& Env); 112 | Table& insert(Obj* key, Obj* value); 113 | 114 | const std::map& get_map() const; 115 | const std::map& get_object() const; 116 | const std::vector& get_array() const; 117 | 118 | private: 119 | TableType _type; 120 | std::map _map; 121 | std::map _object; 122 | std::vector _array; 123 | }; 124 | 125 | 126 | class ONMT_EXPORT TorchObj: public Obj 127 | { 128 | public: 129 | TorchObj(const std::string& classname, int version); 130 | virtual ~TorchObj() = default; 131 | const std::string& get_classname() const; 132 | private: 133 | std::string _classname; 134 | int _version; 135 | }; 136 | 137 | 138 | class ONMT_EXPORT Creator 139 | { 140 | public: 141 | Creator(const std::string& classname, int version); 142 | virtual TorchObj* create(const std::string& classname, int version) = 0; 143 | }; 144 | 145 | 146 | template 147 | class CreatorImpl: public Creator 148 | { 149 | public: 150 | 
CreatorImpl(const std::string& classname, int version) 151 | : Creator(classname, version) 152 | { 153 | } 154 | 155 | virtual TorchObj* create(const std::string& classname, int version) 156 | { 157 | return new T(classname, version); 158 | } 159 | }; 160 | 161 | 162 | class ONMT_EXPORT Factory 163 | { 164 | public: 165 | static TorchObj* create(const std::string& classname, int version); 166 | static void register_it(const std::string& classname, int version, Creator* creator); 167 | private: 168 | static std::map, Creator*>& get_table(); 169 | }; 170 | 171 | 172 | // forward declaration 173 | template 174 | class Tensor; 175 | 176 | 177 | template 178 | class Storage: public TorchObj 179 | { 180 | public: 181 | Storage(const std::string &classname, int version); 182 | Storage(const std::string &classname, int version, THFile*, Env&); 183 | ~Storage(); 184 | 185 | void read(THFile*, Env& Env); 186 | const T* get_data() const; 187 | long get_size() const; 188 | void release(); 189 | 190 | private: 191 | friend class Tensor; 192 | static const CreatorImpl< Storage > creator; 193 | 194 | long _size; 195 | T *_data; 196 | }; 197 | 198 | 199 | template 200 | class Tensor: public TorchObj 201 | { 202 | public: 203 | Tensor(const std::string& classname, int version); 204 | Tensor(const std::string& classname, int version, THFile*, Env&); 205 | ~Tensor(); 206 | 207 | void read(THFile*, Env&); 208 | Obj* get_storage() const; 209 | const long* get_size() const; 210 | int get_dimension() const; 211 | long get_storage_offset() const; 212 | void release_storage() const; 213 | private: 214 | static const CreatorImpl< Tensor > creator; 215 | 216 | int _n_dimension; 217 | long* _size; 218 | long* _stride; 219 | long _storage_offset; 220 | Obj* _thstorage; 221 | }; 222 | 223 | 224 | class ONMT_EXPORT Class: public TorchObj 225 | { 226 | public: 227 | Class(const std::string& classname, int version); 228 | Class(const std::string& classname, int version, THFile*, Env&); 229 | 230 | 
void read(THFile*, Env&); 231 | Obj* get_data() const; 232 | 233 | private: 234 | Obj* _data; 235 | }; 236 | 237 | } 238 | } 239 | 240 | #include "Obj.hxx" 241 | -------------------------------------------------------------------------------- /include/onmt/th/Obj.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace onmt 6 | { 7 | namespace th 8 | { 9 | 10 | // Storage 11 | template 12 | Storage::Storage(const std::string& classname, int version) 13 | : TorchObj(classname, version) 14 | , _data(nullptr) 15 | { 16 | } 17 | 18 | template 19 | Storage::~Storage() 20 | { 21 | if (_data != nullptr) 22 | release(); 23 | } 24 | 25 | template 26 | const T* Storage::get_data() const 27 | { 28 | return _data; 29 | } 30 | 31 | template 32 | long Storage::get_size() const 33 | { 34 | return _size; 35 | } 36 | 37 | template 38 | void Storage::release() 39 | { 40 | THFree(_data); 41 | _data = nullptr; 42 | } 43 | 44 | 45 | // Tensor 46 | template 47 | Tensor::Tensor(const std::string& classname, int version) 48 | : TorchObj(classname, version) 49 | , _n_dimension(0) 50 | , _size(nullptr) 51 | , _stride(nullptr) 52 | , _thstorage(nullptr) 53 | { 54 | } 55 | 56 | template 57 | Tensor::~Tensor() 58 | { 59 | THFree(_size); 60 | THFree(_stride); 61 | } 62 | 63 | template 64 | Obj* Tensor::get_storage() const 65 | { 66 | return _thstorage; 67 | } 68 | 69 | template 70 | const long* Tensor::get_size() const 71 | { 72 | return _size; 73 | } 74 | 75 | template 76 | int Tensor::get_dimension() const 77 | { 78 | return _n_dimension; 79 | } 80 | 81 | template 82 | long Tensor::get_storage_offset() const 83 | { 84 | return _storage_offset; 85 | } 86 | 87 | template 88 | void Tensor::release_storage() const 89 | { 90 | dynamic_cast*>(_thstorage)->release(); 91 | } 92 | 93 | template 94 | void Tensor::read(THFile* tf, Env& env) 95 | { 96 | if (!tf) 97 | return; 98 | 99 | _n_dimension = THFile_readIntScalar(tf); 
100 | _size = reinterpret_cast(THAlloc(dfLongSize * _n_dimension)); 101 | _stride = reinterpret_cast(THAlloc(dfLongSize * _n_dimension)); 102 | THFile_readLongRaw(tf, _size, _n_dimension); 103 | THFile_readLongRaw(tf, _stride, _n_dimension); 104 | _storage_offset = THFile_readLongScalar(tf)-1; 105 | _thstorage = read_obj(tf, env); 106 | } 107 | 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /include/onmt/th/Utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "onmt/th/Obj.h" 4 | 5 | namespace onmt 6 | { 7 | namespace th 8 | { 9 | 10 | int get_number(Table* module_data, const std::string& name); 11 | bool get_boolean(Table* module_data, const std::string& name); 12 | 13 | template 14 | T get_field(Obj* obj, const std::string& name) 15 | { 16 | auto table = dynamic_cast(obj); 17 | 18 | if (!table) 19 | return nullptr; 20 | 21 | auto it = table->get_object().find(name); 22 | 23 | if (it == table->get_object().end()) 24 | return nullptr; 25 | 26 | return dynamic_cast(it->second); 27 | } 28 | 29 | template 30 | T get_scalar(Table* module_data, const std::string& name) 31 | { 32 | Number* dim = get_field(module_data, name); 33 | return dim ? 
static_cast(dim->get_value()) : -1; 34 | } 35 | 36 | template 37 | std::vector get_storage_as_vector(Obj* obj, const std::string& name) 38 | { 39 | Storage* storage = get_field*>(obj, name); 40 | 41 | const T* data = storage->get_data(); 42 | auto size = storage->get_size(); 43 | 44 | std::vector vec; 45 | vec.reserve(size); 46 | 47 | for (int i = 0; i < size; ++i) 48 | vec.push_back(data[i]); 49 | 50 | return vec; 51 | } 52 | 53 | template 54 | const T* get_tensor_data(Tensor* tensor) 55 | { 56 | auto storage = dynamic_cast*>(tensor->get_storage()); 57 | return storage->get_data() + tensor->get_storage_offset(); 58 | } 59 | 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /lib/TH/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(TH) 2 | 3 | set(PUBLIC_HEADERS 4 | THGeneral.h 5 | THFile.h 6 | THDiskFile.h 7 | ) 8 | 9 | add_library(${PROJECT_NAME} 10 | THGeneral.c 11 | THFile.c 12 | THDiskFile.c 13 | ) 14 | 15 | target_include_directories(${PROJECT_NAME} PUBLIC 16 | ${CMAKE_CURRENT_SOURCE_DIR} 17 | ${PROJECT_BINARY_DIR} 18 | ) 19 | 20 | include(GNUInstallDirs) 21 | include(GenerateExportHeader) 22 | string(TOLOWER ${PROJECT_NAME} PROJECT_NAME_LOWER) 23 | generate_export_header(${PROJECT_NAME} EXPORT_FILE_NAME ${PROJECT_BINARY_DIR}/TH/${PROJECT_NAME_LOWER}_export.h) 24 | 25 | install( 26 | TARGETS ${PROJECT_NAME} 27 | RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} 28 | ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} 29 | LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} 30 | ) 31 | install( 32 | FILES ${PUBLIC_HEADERS} "${PROJECT_BINARY_DIR}/TH/${PROJECT_NAME_LOWER}_export.h" 33 | DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/TH" 34 | ) 35 | -------------------------------------------------------------------------------- /lib/TH/COPYRIGHT.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2014 Idiap Research Institute 
(Ronan Collobert) 2 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 3 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 4 | Copyright (c) 2011-2013 NYU (Clement Farabet) 5 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 6 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 7 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 8 | 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | 1. Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 17 | 2. Redistributions in binary form must reproduce the above copyright 18 | notice, this list of conditions and the following disclaimer in the 19 | documentation and/or other materials provided with the distribution. 20 | 21 | 3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America 22 | and IDIAP Research Institute nor the names of its contributors may be 23 | used to endorse or promote products derived from this software without 24 | specific prior written permission. 25 | 26 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 27 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 28 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 29 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 30 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 31 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 32 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 33 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 34 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 35 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 36 | POSSIBILITY OF SUCH DAMAGE. 37 | -------------------------------------------------------------------------------- /lib/TH/THDiskFile.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_DISK_FILE_INC 2 | #define TH_DISK_FILE_INC 3 | 4 | #include "THFile.h" 5 | 6 | TH_API THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); 7 | TH_API THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); 8 | 9 | TH_API const char *THDiskFile_name(THFile *self); 10 | 11 | TH_API int THDiskFile_isLittleEndianCPU(void); 12 | TH_API int THDiskFile_isBigEndianCPU(void); 13 | TH_API void THDiskFile_nativeEndianEncoding(THFile *self); 14 | TH_API void THDiskFile_littleEndianEncoding(THFile *self); 15 | TH_API void THDiskFile_bigEndianEncoding(THFile *self); 16 | TH_API void THDiskFile_longSize(THFile *self, int size); 17 | TH_API void THDiskFile_noBuffer(THFile *self); 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /lib/TH/THFile.c: -------------------------------------------------------------------------------- 1 | #include "THFile.h" 2 | #include "THFilePrivate.h" 3 | 4 | #define IMPLEMENT_THFILE_R(TYPEC, TYPE) \ 5 | size_t THFile_read##TYPEC##Raw(THFile *self, TYPE *data, size_t n) \ 6 | { \ 7 | return (*self->vtable->read##TYPEC)(self, data, n); \ 8 | } 9 | 10 | IMPLEMENT_THFILE_R(Byte, unsigned char) 11 | IMPLEMENT_THFILE_R(Char, char) 
12 | IMPLEMENT_THFILE_R(Short, short) 13 | IMPLEMENT_THFILE_R(Int, int) 14 | IMPLEMENT_THFILE_R(Long, long) 15 | IMPLEMENT_THFILE_R(Float, float) 16 | IMPLEMENT_THFILE_R(Double, double) 17 | 18 | size_t THFile_readStringRaw(THFile *self, const char *format, char **str_) 19 | { 20 | return self->vtable->readString(self, format, str_); 21 | } 22 | 23 | void THFile_synchronize(THFile *self) 24 | { 25 | self->vtable->synchronize(self); 26 | } 27 | 28 | void THFile_seek(THFile *self, size_t position) 29 | { 30 | self->vtable->seek(self, position); 31 | } 32 | 33 | void THFile_seekEnd(THFile *self) 34 | { 35 | self->vtable->seekEnd(self); 36 | } 37 | 38 | size_t THFile_position(THFile *self) 39 | { 40 | return self->vtable->position(self); 41 | } 42 | 43 | void THFile_close(THFile *self) 44 | { 45 | self->vtable->close(self); 46 | } 47 | 48 | void THFile_free(THFile *self) 49 | { 50 | self->vtable->free(self); 51 | } 52 | 53 | int THFile_isOpened(THFile *self) 54 | { 55 | return self->vtable->isOpened(self); 56 | } 57 | 58 | #define IMPLEMENT_THFILE_FLAGS(FLAG) \ 59 | int THFile_##FLAG(THFile *self) \ 60 | { \ 61 | return self->FLAG; \ 62 | } 63 | 64 | IMPLEMENT_THFILE_FLAGS(isQuiet) 65 | IMPLEMENT_THFILE_FLAGS(isReadable) 66 | IMPLEMENT_THFILE_FLAGS(isWritable) 67 | IMPLEMENT_THFILE_FLAGS(isBinary) 68 | IMPLEMENT_THFILE_FLAGS(isAutoSpacing) 69 | IMPLEMENT_THFILE_FLAGS(hasError) 70 | 71 | void THFile_binary(THFile *self) 72 | { 73 | self->isBinary = 1; 74 | } 75 | 76 | void THFile_ascii(THFile *self) 77 | { 78 | self->isBinary = 0; 79 | } 80 | 81 | void THFile_autoSpacing(THFile *self) 82 | { 83 | self->isAutoSpacing = 1; 84 | } 85 | 86 | void THFile_noAutoSpacing(THFile *self) 87 | { 88 | self->isAutoSpacing = 0; 89 | } 90 | 91 | void THFile_quiet(THFile *self) 92 | { 93 | self->isQuiet = 1; 94 | } 95 | 96 | void THFile_pedantic(THFile *self) 97 | { 98 | self->isQuiet = 0; 99 | } 100 | 101 | void THFile_clearError(THFile *self) 102 | { 103 | self->hasError = 0; 104 | } 
105 | 106 | #define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \ 107 | TYPE THFile_read##TYPEC##Scalar(THFile *self) \ 108 | { \ 109 | TYPE scalar; \ 110 | THFile_read##TYPEC##Raw(self, &scalar, 1); \ 111 | return scalar; \ 112 | } 113 | 114 | IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) 115 | IMPLEMENT_THFILE_SCALAR(Char, char) 116 | IMPLEMENT_THFILE_SCALAR(Short, short) 117 | IMPLEMENT_THFILE_SCALAR(Int, int) 118 | IMPLEMENT_THFILE_SCALAR(Long, long) 119 | IMPLEMENT_THFILE_SCALAR(Float, float) 120 | IMPLEMENT_THFILE_SCALAR(Double, double) 121 | -------------------------------------------------------------------------------- /lib/TH/THFile.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_FILE_INC 2 | #define TH_FILE_INC 3 | 4 | #include "THGeneral.h" 5 | 6 | typedef struct THFile__ THFile; 7 | 8 | TH_API int THFile_isOpened(THFile *self); 9 | TH_API int THFile_isQuiet(THFile *self); 10 | TH_API int THFile_isReadable(THFile *self); 11 | TH_API int THFile_isWritable(THFile *self); 12 | TH_API int THFile_isBinary(THFile *self); 13 | TH_API int THFile_isAutoSpacing(THFile *self); 14 | TH_API int THFile_hasError(THFile *self); 15 | 16 | TH_API void THFile_binary(THFile *self); 17 | TH_API void THFile_ascii(THFile *self); 18 | TH_API void THFile_autoSpacing(THFile *self); 19 | TH_API void THFile_noAutoSpacing(THFile *self); 20 | TH_API void THFile_quiet(THFile *self); 21 | TH_API void THFile_pedantic(THFile *self); 22 | TH_API void THFile_clearError(THFile *self); 23 | 24 | /* scalar */ 25 | TH_API unsigned char THFile_readByteScalar(THFile *self); 26 | TH_API char THFile_readCharScalar(THFile *self); 27 | TH_API short THFile_readShortScalar(THFile *self); 28 | TH_API int THFile_readIntScalar(THFile *self); 29 | TH_API long THFile_readLongScalar(THFile *self); 30 | TH_API float THFile_readFloatScalar(THFile *self); 31 | TH_API double THFile_readDoubleScalar(THFile *self); 32 | 33 | /* raw */ 34 | TH_API size_t 
THFile_readByteRaw(THFile *self, unsigned char *data, size_t n); 35 | TH_API size_t THFile_readCharRaw(THFile *self, char *data, size_t n); 36 | TH_API size_t THFile_readShortRaw(THFile *self, short *data, size_t n); 37 | TH_API size_t THFile_readIntRaw(THFile *self, int *data, size_t n); 38 | TH_API size_t THFile_readLongRaw(THFile *self, long *data, size_t n); 39 | TH_API size_t THFile_readFloatRaw(THFile *self, float *data, size_t n); 40 | TH_API size_t THFile_readDoubleRaw(THFile *self, double *data, size_t n); 41 | TH_API size_t THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ 42 | 43 | TH_API void THFile_synchronize(THFile *self); 44 | TH_API void THFile_seek(THFile *self, size_t position); 45 | TH_API void THFile_seekEnd(THFile *self); 46 | TH_API size_t THFile_position(THFile *self); 47 | TH_API void THFile_close(THFile *self); 48 | TH_API void THFile_free(THFile *self); 49 | 50 | #endif 51 | -------------------------------------------------------------------------------- /lib/TH/THFilePrivate.h: -------------------------------------------------------------------------------- 1 | struct THFile__ 2 | { 3 | struct THFileVTable *vtable; 4 | 5 | int isQuiet; 6 | int isReadable; 7 | int isWritable; 8 | int isBinary; 9 | int isAutoSpacing; 10 | int hasError; 11 | }; 12 | 13 | /* virtual table definition */ 14 | 15 | struct THFileVTable 16 | { 17 | int (*isOpened)(THFile *self); 18 | 19 | size_t (*readByte)(THFile *self, unsigned char *data, size_t n); 20 | size_t (*readChar)(THFile *self, char *data, size_t n); 21 | size_t (*readShort)(THFile *self, short *data, size_t n); 22 | size_t (*readInt)(THFile *self, int *data, size_t n); 23 | size_t (*readLong)(THFile *self, long *data, size_t n); 24 | size_t (*readFloat)(THFile *self, float *data, size_t n); 25 | size_t (*readDouble)(THFile *self, double *data, size_t n); 26 | size_t (*readString)(THFile *self, const char *format, char **str_); 27 | 28 | void 
(*synchronize)(THFile *self); 29 | void (*seek)(THFile *self, size_t position); 30 | void (*seekEnd)(THFile *self); 31 | size_t (*position)(THFile *self); 32 | void (*close)(THFile *self); 33 | void (*free)(THFile *self); 34 | }; 35 | -------------------------------------------------------------------------------- /lib/TH/THGeneral.c: -------------------------------------------------------------------------------- 1 | #include "THGeneral.h" 2 | 3 | #if (defined(__unix) || defined(_WIN32)) 4 | #if defined(__FreeBSD__) 5 | #include 6 | #else 7 | #include 8 | #endif 9 | #elif defined(__APPLE__) 10 | #include 11 | #endif 12 | 13 | /* Torch Error Handling */ 14 | static void defaultErrorHandlerFunction(const char *msg) 15 | { 16 | printf("$ Error: %s\n", msg); 17 | exit(-1); 18 | } 19 | 20 | void _THError(const char *file, const int line, const char *fmt, ...) 21 | { 22 | char msg[2048]; 23 | va_list args; 24 | 25 | /* vasprintf not standard */ 26 | /* vsnprintf: how to handle if does not exists? */ 27 | va_start(args, fmt); 28 | int n = vsnprintf(msg, 2048, fmt, args); 29 | va_end(args); 30 | 31 | if(n < 2048) { 32 | snprintf(msg + n, 2048 - n, " at %s:%d", file, line); 33 | } 34 | 35 | defaultErrorHandlerFunction(msg); 36 | } 37 | 38 | void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) { 39 | char msg[1024]; 40 | va_list args; 41 | va_start(args, fmt); 42 | vsnprintf(msg, 1024, fmt, args); 43 | va_end(args); 44 | _THError(file, line, "Assertion `%s' failed. %s", exp, msg); 45 | } 46 | 47 | void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...) 48 | { 49 | if(!condition) { 50 | char msg[2048]; 51 | va_list args; 52 | 53 | /* vasprintf not standard */ 54 | /* vsnprintf: how to handle if does not exists? 
*/ 55 | va_start(args, fmt); 56 | int n = vsnprintf(msg, 2048, fmt, args); 57 | va_end(args); 58 | 59 | if(n < 2048) { 60 | snprintf(msg + n, 2048 - n, " at %s:%d", file, line); 61 | } 62 | 63 | defaultErrorHandlerFunction(msg); 64 | } 65 | } 66 | 67 | void* THRealloc(void *ptr, ptrdiff_t size) 68 | { 69 | return realloc(ptr, size); 70 | } 71 | 72 | void* THAlloc(ptrdiff_t size) 73 | { 74 | return malloc(size); 75 | } 76 | 77 | void THFree(void *ptr) 78 | { 79 | free(ptr); 80 | } 81 | -------------------------------------------------------------------------------- /lib/TH/THGeneral.h: -------------------------------------------------------------------------------- 1 | #ifndef TH_GENERAL_INC 2 | #define TH_GENERAL_INC 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "TH/th_export.h" 15 | 16 | #ifdef __cplusplus 17 | # define TH_EXTERNC extern "C" 18 | #else 19 | # define TH_EXTERNC extern 20 | #endif 21 | 22 | #define TH_API TH_EXTERNC TH_EXPORT 23 | 24 | TH_API void _THError(const char *file, const int line, const char *fmt, ...); 25 | TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...); 26 | TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...); 27 | TH_API void* THRealloc(void *ptr, ptrdiff_t size); 28 | TH_API void* THAlloc(ptrdiff_t size); 29 | TH_API void THFree(void *ptr); 30 | 31 | #define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__) 32 | 33 | #define THCleanup(...) __VA_ARGS__ 34 | 35 | #define THArgCheck(...) \ 36 | do { \ 37 | _THArgCheck(__FILE__, __LINE__, __VA_ARGS__); \ 38 | } while(0) 39 | 40 | #define THArgCheckWithCleanup(condition, cleanup, ...) 
\ 41 | do if (!(condition)) { \ 42 | cleanup \ 43 | _THArgCheck(__FILE__, __LINE__, 0, __VA_ARGS__); \ 44 | } while(0) 45 | 46 | #define THAssert(exp) \ 47 | do { \ 48 | if (!(exp)) { \ 49 | _THAssertionFailed(__FILE__, __LINE__, #exp, ""); \ 50 | } \ 51 | } while(0) 52 | 53 | #define THAssertMsg(exp, ...) \ 54 | do { \ 55 | if (!(exp)) { \ 56 | _THAssertionFailed(__FILE__, __LINE__, #exp, __VA_ARGS__); \ 57 | } \ 58 | } while(0) 59 | 60 | #endif 61 | -------------------------------------------------------------------------------- /src/Dictionary.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/Dictionary.h" 2 | 3 | #ifdef ANDROID_GNUSTL_COMPAT 4 | # include "onmt/android_gnustl_compat.h" 5 | #endif 6 | 7 | #include 8 | 9 | #include "onmt/th/Utils.h" 10 | 11 | namespace onmt 12 | { 13 | 14 | const size_t Dictionary::pad_id = 0; 15 | const size_t Dictionary::unk_id = 1; 16 | const size_t Dictionary::bos_id = 2; 17 | const size_t Dictionary::eos_id = 3; 18 | 19 | Dictionary::Dictionary() 20 | { 21 | } 22 | 23 | Dictionary::Dictionary(th::Class* dict) 24 | { 25 | load(dict); 26 | } 27 | 28 | void Dictionary::load(th::Class* dict) 29 | { 30 | auto dict_data = dynamic_cast(dict->get_data()); 31 | auto id2word = th::get_field(dict_data, "idxToLabel"); 32 | 33 | auto array = id2word->get_array(); 34 | 35 | for (size_t i = 0; i < array.size(); ++i) 36 | { 37 | const std::string& word = dynamic_cast(array[i])->get_value(); 38 | _id2word.push_back(word); 39 | _word2id[word] = i; 40 | } 41 | } 42 | 43 | size_t Dictionary::get_size() const 44 | { 45 | return _id2word.size(); 46 | } 47 | 48 | size_t Dictionary::get_word_id(const std::string& word) const 49 | { 50 | auto it = _word2id.find(word); 51 | 52 | if (it == _word2id.cend()) 53 | return unk_id; 54 | 55 | return it->second; 56 | } 57 | 58 | const std::string& Dictionary::get_id_word(size_t id) const 59 | { 60 | return _id2word[id]; 61 | } 62 | 63 | } 64 | 
-------------------------------------------------------------------------------- /src/ITranslator.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/ITranslator.h" 2 | #include "onmt/SpaceTokenizer.h" 3 | #include 4 | 5 | namespace onmt 6 | { 7 | 8 | std::string 9 | ITranslator::translate(const std::string& text, 10 | const TranslationOptions& options) 11 | { 12 | return translate(text, SpaceTokenizer::get_instance(), options); 13 | } 14 | 15 | std::string 16 | ITranslator::translate(const std::string& text, 17 | float& score, 18 | size_t& count_tgt_words, 19 | size_t& count_tgt_unk_words, 20 | size_t& count_src_words, 21 | size_t& count_src_unk_words, 22 | const TranslationOptions& options) 23 | { 24 | return translate(text, SpaceTokenizer::get_instance(), score, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 25 | } 26 | 27 | std::string 28 | ITranslator::translate(const std::string& text, 29 | ITokenizer& tokenizer, 30 | const TranslationOptions& options) 31 | { 32 | float score; 33 | size_t count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words; 34 | return translate(text, tokenizer, score, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 35 | } 36 | 37 | std::string 38 | ITranslator::translate(const std::string& text, 39 | ITokenizer& tokenizer, 40 | float& score, 41 | size_t& count_tgt_words, 42 | size_t& count_tgt_unk_words, 43 | size_t& count_src_words, 44 | size_t& count_src_unk_words, 45 | const TranslationOptions& options) 46 | { 47 | std::vector best_scores; 48 | std::vector best_count_tgt_words, best_count_tgt_unk_words; 49 | auto res = get_translations(text, tokenizer, best_scores, best_count_tgt_words, best_count_tgt_unk_words, count_src_words, count_src_unk_words, options); 50 | score = best_scores.at(0); 51 | count_tgt_words = best_count_tgt_words.at(0); 52 | count_tgt_unk_words = 
best_count_tgt_unk_words.at(0); 53 | return res.at(0); 54 | } 55 | 56 | std::vector 57 | ITranslator::get_translations(const std::string& text, 58 | const TranslationOptions& options) 59 | { 60 | return get_translations(text, SpaceTokenizer::get_instance(), options); 61 | } 62 | 63 | std::vector 64 | ITranslator::get_translations(const std::string& text, 65 | std::vector& scores, 66 | std::vector& count_tgt_words, 67 | std::vector& count_tgt_unk_words, 68 | size_t& count_src_words, 69 | size_t& count_src_unk_words, 70 | const TranslationOptions& options) 71 | { 72 | return get_translations(text, SpaceTokenizer::get_instance(), scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 73 | } 74 | 75 | std::vector 76 | ITranslator::get_translations(const std::string& text, 77 | ITokenizer& tokenizer, 78 | const TranslationOptions& options) 79 | { 80 | std::vector scores; 81 | std::vector count_tgt_words, count_tgt_unk_words; 82 | size_t count_src_words, count_src_unk_words; 83 | return get_translations(text, tokenizer, scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 84 | } 85 | 86 | TranslationResult 87 | ITranslator::translate(const std::vector& tokens, 88 | const std::vector >& features, 89 | const TranslationOptions& options) 90 | { 91 | size_t count_src_unk_words; 92 | return translate(tokens, features, count_src_unk_words, options); 93 | } 94 | 95 | std::vector 96 | ITranslator::translate_batch(const std::vector& texts, 97 | const TranslationOptions& options) 98 | { 99 | return translate_batch(texts, SpaceTokenizer::get_instance(), options); 100 | } 101 | 102 | std::vector 103 | ITranslator::translate_batch(const std::vector& texts, 104 | std::vector& scores, 105 | std::vector& count_tgt_words, 106 | std::vector& count_tgt_unk_words, 107 | std::vector& count_src_words, 108 | std::vector& count_src_unk_words, 109 | const TranslationOptions& options) 110 | { 111 | return 
translate_batch(texts, SpaceTokenizer::get_instance(), scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 112 | } 113 | 114 | std::vector 115 | ITranslator::translate_batch(const std::vector& texts, 116 | ITokenizer& tokenizer, 117 | const TranslationOptions& options) 118 | { 119 | std::vector scores; 120 | std::vector count_tgt_words, count_tgt_unk_words; 121 | std::vector count_src_words, count_src_unk_words; 122 | return translate_batch(texts, tokenizer, scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 123 | } 124 | 125 | std::vector 126 | ITranslator::translate_batch(const std::vector& texts, 127 | ITokenizer& tokenizer, 128 | std::vector& scores, 129 | std::vector& count_tgt_words, 130 | std::vector& count_tgt_unk_words, 131 | std::vector& count_src_words, 132 | std::vector& count_src_unk_words, 133 | const TranslationOptions& options) 134 | { 135 | std::vector translations; 136 | scores.clear(); 137 | count_tgt_words.clear(); 138 | count_tgt_unk_words.clear(); 139 | std::vector > batch_scores; 140 | std::vector > batch_count_tgt_words, batch_count_tgt_unk_words; 141 | auto res = get_translations_batch(texts, tokenizer, batch_scores, batch_count_tgt_words, batch_count_tgt_unk_words, count_src_words, count_src_unk_words, options); 142 | for (size_t i = 0; i < res.size(); ++i) 143 | { 144 | translations.push_back(std::move(res[i].at(0))); 145 | scores.push_back(batch_scores[i].at(0)); 146 | count_tgt_words.push_back(batch_count_tgt_words[i].at(0)); 147 | count_tgt_unk_words.push_back(batch_count_tgt_unk_words[i].at(0)); 148 | } 149 | 150 | return translations; 151 | } 152 | 153 | std::vector > 154 | ITranslator::get_translations_batch(const std::vector& texts, 155 | const TranslationOptions& options) 156 | { 157 | return get_translations_batch(texts, SpaceTokenizer::get_instance(), options); 158 | } 159 | 160 | std::vector > 161 | ITranslator::get_translations_batch(const 
std::vector& texts, 162 | std::vector >& scores, 163 | std::vector >& count_tgt_words, 164 | std::vector >& count_tgt_unk_words, 165 | std::vector& count_src_words, 166 | std::vector& count_src_unk_words, 167 | const TranslationOptions& options) 168 | { 169 | return get_translations_batch(texts, SpaceTokenizer::get_instance(), scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 170 | } 171 | 172 | std::vector > 173 | ITranslator::get_translations_batch(const std::vector& texts, 174 | ITokenizer& tokenizer, 175 | const TranslationOptions& options) 176 | { 177 | std::vector > scores; 178 | std::vector > count_tgt_words, count_tgt_unk_words; 179 | std::vector count_src_words, count_src_unk_words; 180 | return get_translations_batch(texts, tokenizer, scores, count_tgt_words, count_tgt_unk_words, count_src_words, count_src_unk_words, options); 181 | } 182 | 183 | TranslationResult 184 | ITranslator::translate_batch(const std::vector >& batch_tokens, 185 | const std::vector > >& batch_features, 186 | const TranslationOptions& options) 187 | { 188 | std::vector batch_count_src_unk_words; 189 | return translate_batch(batch_tokens, batch_features, batch_count_src_unk_words, options); 190 | } 191 | 192 | } 193 | -------------------------------------------------------------------------------- /src/Logger.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/Logger.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace onmt 14 | { 15 | 16 | boost::log::sources::severity_logger_mt Logger::_lg; 17 | 18 | void Logger::init(const std::string& log_file, bool disable_logs, const std::string& log_level) 19 | { 20 | if (disable_logs || log_level == "NONE") 21 | { 22 | boost::log::core::get()->set_logging_enabled(false); 23 | } 24 | else 25 | { 26 | if (!log_file.empty()) 27 | { 28 | 
boost::log::add_file_log 29 | ( 30 | boost::log::keywords::file_name = log_file, 31 | boost::log::keywords::auto_flush = true, 32 | boost::log::keywords::format = 33 | ( 34 | boost::log::expressions::stream 35 | << '[' << boost::log::expressions::format_date_time("TimeStamp", "%Y-%m-%d %H:%M:%S") 36 | << ' ' << boost::log::trivial::severity 37 | << "] " << boost::log::expressions::smessage 38 | ) 39 | ); 40 | } 41 | 42 | boost::log::trivial::severity_level level = boost::log::trivial::info; 43 | if (!log_level.empty() && log_level != "INFO") 44 | { 45 | if (log_level == "DEBUG") 46 | level = boost::log::trivial::debug; 47 | else if (log_level == "WARNING") 48 | level = boost::log::trivial::warning; 49 | else if (log_level == "ERROR") 50 | level = boost::log::trivial::error; 51 | else 52 | std::cerr << "invalid log level specified: " << log_level << "; using default log level" << std::endl; 53 | } 54 | 55 | auto core = boost::log::core::get(); 56 | core->set_filter 57 | ( 58 | boost::log::trivial::severity >= level 59 | ); 60 | 61 | core->add_global_attribute(boost::log::aux::default_attribute_names::timestamp(), boost::log::attributes::local_clock()); 62 | } 63 | } 64 | 65 | boost::log::sources::severity_logger_mt& Logger::lg() 66 | { 67 | return _lg; 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /src/PhraseTable.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/PhraseTable.h" 2 | 3 | #include 4 | 5 | namespace onmt 6 | { 7 | 8 | static const std::string separator = "|||"; 9 | 10 | PhraseTable::PhraseTable(const std::string& file) 11 | { 12 | if (!file.empty()) 13 | { 14 | std::ifstream in(file.c_str()); 15 | std::string line; 16 | 17 | while (std::getline(in, line)) 18 | { 19 | size_t sep_idx = line.find(separator); 20 | 21 | std::string src = line.substr(0, sep_idx); 22 | std::string tgt = line.substr(sep_idx + separator.length()); 23 | 24 | 
_src_to_tgt[src] = tgt; 25 | } 26 | } 27 | } 28 | 29 | bool PhraseTable::is_empty() const 30 | { 31 | return _src_to_tgt.empty(); 32 | } 33 | 34 | size_t PhraseTable::get_size() const 35 | { 36 | return _src_to_tgt.size(); 37 | } 38 | 39 | std::string PhraseTable::lookup(const std::string& src) const 40 | { 41 | auto it = _src_to_tgt.find(src); 42 | 43 | if (it == _src_to_tgt.cend()) 44 | return ""; 45 | 46 | return it->second; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/Profiler.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/Profiler.h" 2 | 3 | #include 4 | #include 5 | 6 | namespace onmt 7 | { 8 | int Profiler::_counter = 0; 9 | std::mutex Profiler::_profiler_mutex; 10 | 11 | Profiler::Profiler(bool enabled, bool start_chrono) 12 | : _enabled(enabled) 13 | , _total_time(std::chrono::microseconds::zero()) 14 | { 15 | if (start_chrono) 16 | start(); 17 | reset(); 18 | /* assign unique id */ 19 | std::lock_guard lock(_profiler_mutex); 20 | _id = ++_counter; 21 | } 22 | 23 | Profiler::~Profiler() 24 | { 25 | if (_enabled) 26 | { 27 | std::cerr << *this; 28 | } 29 | } 30 | 31 | void Profiler::enable() 32 | { 33 | _enabled = true; 34 | } 35 | 36 | void Profiler::disable() 37 | { 38 | _enabled = false; 39 | } 40 | 41 | void Profiler::reset() 42 | { 43 | _cumulated.clear(); 44 | } 45 | 46 | int Profiler::get_id() const 47 | { 48 | return _id; 49 | } 50 | 51 | void Profiler::start() 52 | { 53 | if (_enabled) 54 | { 55 | _start.emplace(std::chrono::high_resolution_clock::now()); 56 | } 57 | } 58 | 59 | void Profiler::stop(const std::string& module_name) 60 | { 61 | if (_enabled) 62 | { 63 | auto diff = std::chrono::high_resolution_clock::now() - _start.top(); 64 | auto elapsed = std::chrono::duration_cast(diff); 65 | _total_time += elapsed; 66 | _start.pop(); 67 | _cumulated[module_name] += elapsed; 68 | } 69 | } 70 | 71 | std::ostream& 
operator<<(std::ostream& os, const Profiler& profiler) 72 | { 73 | std::lock_guard lock(Profiler::_profiler_mutex); 74 | // Sort accumulated time. 75 | std::vector > samples; 76 | for (const auto& sample: profiler._cumulated) 77 | samples.emplace_back(sample); 78 | 79 | std::sort(samples.begin(), samples.end(), 80 | [] (const std::pair& a, 81 | const std::pair& b) 82 | { 83 | return a.second > b.second; 84 | }); 85 | 86 | for (auto it: samples) 87 | { 88 | os << "[" << profiler.get_id() << "]" 89 | << '\t' 90 | << it.first 91 | << '\t' 92 | << static_cast(it.second.count()) / 1000 << "ms" 93 | << '\t' 94 | << "(" << (static_cast(it.second.count()) / static_cast(profiler._total_time.count())) * 100 << "%)" 95 | << std::endl; 96 | } 97 | 98 | return os; 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /src/SubDict.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/SubDict.h" 2 | 3 | #include 4 | 5 | namespace onmt 6 | { 7 | 8 | SubDict::SubDict(const std::string& map_file, const Dictionary& dict) 9 | { 10 | if (!map_file.empty()) 11 | { 12 | std::ifstream mapf(map_file); 13 | if (!mapf.is_open()) 14 | throw std::invalid_argument("Unable to open dictionary vocab mapping file `" + map_file + "`"); 15 | std::string line; 16 | while (std::getline(mapf, line)) 17 | { 18 | std::string token; 19 | std::string key; 20 | std::vector values; 21 | bool target = false; 22 | size_t ngram = 1; 23 | 24 | for (size_t i = 0; i < line.length(); ++i) 25 | { 26 | if (line[i] == '\t') 27 | { 28 | target = true; 29 | std::swap(key, token); 30 | } 31 | else if (line[i] == ' ') 32 | { 33 | if (target) 34 | { 35 | values.push_back(dict.get_word_id(token)); 36 | token.clear(); 37 | } 38 | else 39 | { 40 | token += line[i]; 41 | ++ngram; 42 | } 43 | } 44 | else 45 | token += line[i]; 46 | } 47 | 48 | if (!token.empty()) 49 | values.push_back(dict.get_word_id(token)); 50 | 51 | if 
(ngram > _map_rules.size()) 52 | _map_rules.resize(ngram); 53 | 54 | _map_rules[ngram - 1][key] = values; 55 | } 56 | } 57 | } 58 | 59 | void SubDict::extract(const std::vector& words, std::set& r) const 60 | { 61 | r.insert(Dictionary::unk_id); 62 | r.insert(Dictionary::bos_id); 63 | r.insert(Dictionary::eos_id); 64 | r.insert(Dictionary::pad_id); 65 | 66 | auto it = _map_rules[0].find(""); 67 | if (it != _map_rules[0].end()) 68 | { 69 | for (const auto& v: it->second) 70 | r.insert(v); 71 | } 72 | 73 | for (size_t i = 0; i < words.size(); i++) 74 | { 75 | std::string tok = words[i]; 76 | size_t h = 0; 77 | do { 78 | if (h > 0) 79 | { 80 | if (i + h >= words.size()) 81 | break; 82 | tok += " " + words[i + h]; 83 | } 84 | auto it = _map_rules[h].find(tok); 85 | if (it != _map_rules[h].end()) 86 | { 87 | for (const auto& v: it->second) 88 | r.insert(v); 89 | } 90 | h++; 91 | } while (h < _map_rules.size()); 92 | } 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /src/Threads.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/Threads.h" 2 | 3 | #include 4 | #ifdef WITH_MKL 5 | # include 6 | #endif 7 | 8 | namespace onmt 9 | { 10 | 11 | void Threads::set(int number) 12 | { 13 | Eigen::setNbThreads(number); 14 | #ifdef WITH_MKL 15 | mkl_set_num_threads(number); 16 | #endif 17 | } 18 | 19 | int Threads::get() 20 | { 21 | return Eigen::nbThreads(); 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/TranslationOptions.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/TranslationOptions.h" 2 | 3 | namespace onmt 4 | { 5 | 6 | TranslationOptions::TranslationOptions(size_t max_sent_length, 7 | size_t beam_size, 8 | size_t n_best, 9 | bool replace_unk, 10 | bool replace_unk_tagged) 11 | : _max_sent_length(max_sent_length) 12 | , _beam_size(beam_size) 13 | , 
_n_best(n_best) 14 | , _replace_unk(replace_unk) 15 | , _replace_unk_tagged(replace_unk_tagged) 16 | { 17 | } 18 | 19 | size_t TranslationOptions::max_sent_length() const 20 | { 21 | return _max_sent_length; 22 | } 23 | 24 | size_t& TranslationOptions::max_sent_length() 25 | { 26 | return _max_sent_length; 27 | } 28 | 29 | size_t TranslationOptions::beam_size() const 30 | { 31 | return _beam_size; 32 | } 33 | 34 | size_t& TranslationOptions::beam_size() 35 | { 36 | return _beam_size; 37 | } 38 | 39 | size_t TranslationOptions::n_best() const 40 | { 41 | return _n_best; 42 | } 43 | 44 | size_t& TranslationOptions::n_best() 45 | { 46 | return _n_best; 47 | } 48 | 49 | bool TranslationOptions::replace_unk() const 50 | { 51 | return _replace_unk; 52 | } 53 | 54 | bool& TranslationOptions::replace_unk() 55 | { 56 | return _replace_unk; 57 | } 58 | 59 | bool TranslationOptions::replace_unk_tagged() const 60 | { 61 | return _replace_unk_tagged; 62 | } 63 | 64 | bool& TranslationOptions::replace_unk_tagged() 65 | { 66 | return _replace_unk_tagged; 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /src/TranslationResult.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/TranslationResult.h" 2 | 3 | namespace onmt 4 | { 5 | 6 | TranslationResult::TranslationResult(const std::vector > >& words, 7 | const std::vector > > >& features, 8 | const std::vector > > >& attention, 9 | const std::vector >& score, 10 | const std::vector >& count_unk_words) 11 | : _words(words) 12 | , _features(features) 13 | , _attention(attention) 14 | , _score(score) 15 | , _count_unk_words(count_unk_words) 16 | { 17 | } 18 | 19 | const std::vector& TranslationResult::get_words(size_t job_index, size_t translation_index) const 20 | { 21 | return _words[job_index][translation_index]; 22 | } 23 | 24 | const std::vector >& TranslationResult::get_features(size_t job_index, size_t translation_index) 
const 25 | { 26 | return _features[job_index][translation_index]; 27 | } 28 | 29 | const std::vector >& TranslationResult::get_attention(size_t job_index, size_t translation_index) const 30 | { 31 | return _attention[job_index][translation_index]; 32 | } 33 | 34 | float TranslationResult::get_score(size_t job_index, size_t translation_index) const 35 | { 36 | return _score[job_index][translation_index]; 37 | } 38 | 39 | size_t TranslationResult::get_count_unk_words(size_t job_index, size_t translation_index) const 40 | { 41 | return _count_unk_words[job_index][translation_index]; 42 | } 43 | 44 | const std::vector >& TranslationResult::get_words_job(size_t job_index) const 45 | { 46 | return _words[job_index]; 47 | } 48 | 49 | const std::vector > >& TranslationResult::get_features_job(size_t job_index) const 50 | { 51 | return _features[job_index]; 52 | } 53 | 54 | const std::vector > >& TranslationResult::get_attention_job(size_t job_index) const 55 | { 56 | return _attention[job_index]; 57 | } 58 | 59 | const std::vector& TranslationResult::get_score_job(size_t job_index) const 60 | { 61 | return _score[job_index]; 62 | } 63 | 64 | const std::vector& TranslationResult::get_count_unk_words_job(size_t job_index) const 65 | { 66 | return _count_unk_words[job_index]; 67 | } 68 | 69 | size_t TranslationResult::count_job(size_t job_index) const 70 | { 71 | return _words[job_index].size(); 72 | } 73 | 74 | const std::vector > >& TranslationResult::get_words_batch() const 75 | { 76 | return _words; 77 | } 78 | 79 | const std::vector > > >& TranslationResult::get_features_batch() const 80 | { 81 | return _features; 82 | } 83 | 84 | const std::vector > > >& TranslationResult::get_attention_batch() const 85 | { 86 | return _attention; 87 | } 88 | 89 | const std::vector >& TranslationResult::get_score_batch() const 90 | { 91 | return _score; 92 | } 93 | 94 | const std::vector >& TranslationResult::get_count_unk_words_batch() const 95 | { 96 | return _count_unk_words; 97 | } 
98 | 99 | size_t TranslationResult::count() const 100 | { 101 | return _words.size(); 102 | } 103 | 104 | bool TranslationResult::has_features() const 105 | { 106 | return !_features.empty(); 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /src/TranslatorFactory.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/TranslatorFactory.h" 2 | 3 | namespace onmt 4 | { 5 | 6 | std::unique_ptr TranslatorFactory::build(const std::string& model, 7 | const std::string& phrase_table, 8 | const std::string& vocab_mapping, 9 | bool cuda, 10 | bool qlinear, 11 | bool profiling) 12 | { 13 | ITranslator* t = nullptr; 14 | 15 | t = new DefaultTranslator(model, 16 | phrase_table, 17 | vocab_mapping, 18 | cuda, 19 | qlinear, 20 | profiling); 21 | 22 | return std::unique_ptr(t); 23 | } 24 | 25 | std::unique_ptr 26 | TranslatorFactory::clone(const std::unique_ptr& translator) 27 | { 28 | ITranslator* t = new DefaultTranslator( 29 | dynamic_cast&>(*translator)); 30 | return std::unique_ptr(t); 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/cuda/Kernels.cu: -------------------------------------------------------------------------------- 1 | #include "onmt/cuda/Kernels.cuh" 2 | 3 | #include 4 | 5 | namespace onmt 6 | { 7 | namespace cuda 8 | { 9 | namespace kernels 10 | { 11 | 12 | struct AddOp 13 | { 14 | __device__ __forceinline__ void operator()(float* out, const float* in) 15 | { 16 | *out += *in; 17 | } 18 | }; 19 | 20 | template 21 | __global__ void pointwise2_kernel(float* __restrict__ dst, 22 | const float* __restrict__ src, 23 | int len) 24 | { 25 | int stride = gridDim.x * blockDim.x; 26 | int tid = blockDim.x * blockIdx.x + threadIdx.x; 27 | Op op; 28 | 29 | for (int i = tid; i < len; i += stride) 30 | { 31 | op(dst + i, src + i); 32 | } 33 | } 34 | 35 | template 36 | void pointwise2(float* dst, const 
float* src, int len) 37 | { 38 | int grid_size = -1; 39 | int block_size = -1; 40 | 41 | cudaOccupancyMaxPotentialBlockSize(&grid_size, &block_size, &pointwise2_kernel); 42 | grid_size = (len + block_size - 1) / block_size; 43 | 44 | pointwise2_kernel<<>>(dst, src, len); 45 | } 46 | 47 | void add(float* a, const float* b, int len) 48 | { 49 | pointwise2(a, b, len); 50 | } 51 | 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /src/cuda/Utils.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/cuda/Utils.h" 2 | 3 | namespace onmt 4 | { 5 | namespace cuda 6 | { 7 | 8 | std::string cublasGetStatusString(cublasStatus_t status) 9 | { 10 | switch (status) 11 | { 12 | case CUBLAS_STATUS_SUCCESS: 13 | return "CUBLAS_STATUS_SUCCESS"; 14 | case CUBLAS_STATUS_NOT_INITIALIZED: 15 | return "CUBLAS_STATUS_NOT_INITIALIZED"; 16 | case CUBLAS_STATUS_ALLOC_FAILED: 17 | return "CUBLAS_STATUS_ALLOC_FAILED"; 18 | case CUBLAS_STATUS_INVALID_VALUE: 19 | return "CUBLAS_STATUS_INVALID_VALUE"; 20 | case CUBLAS_STATUS_ARCH_MISMATCH: 21 | return "CUBLAS_STATUS_ARCH_MISMATCH"; 22 | case CUBLAS_STATUS_MAPPING_ERROR: 23 | return "CUBLAS_STATUS_MAPPING_ERROR"; 24 | case CUBLAS_STATUS_EXECUTION_FAILED: 25 | return "CUBLAS_STATUS_EXECUTION_FAILED"; 26 | case CUBLAS_STATUS_INTERNAL_ERROR: 27 | return "CUBLAS_STATUS_INTERNAL_ERROR"; 28 | case CUBLAS_STATUS_NOT_SUPPORTED: 29 | return "CUBLAS_STATUS_NOT_SUPPORTED"; 30 | case CUBLAS_STATUS_LICENSE_ERROR: 31 | return "CUBLAS_STATUS_LICENSE_ERROR"; 32 | default: 33 | return "UNKNOWN"; 34 | } 35 | } 36 | 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/th/Env.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/th/Env.h" 2 | 3 | namespace onmt 4 | { 5 | namespace th 6 | { 7 | 8 | Env::~Env() 9 | { 10 | for (const auto& pair: _idx_obj) 11 | delete 
pair.second; 12 | 13 | for (const auto& obj: _list_obj) 14 | delete obj; 15 | } 16 | 17 | Obj* Env::get_object(int index) 18 | { 19 | auto it = _idx_obj.find(index); 20 | 21 | if (it != _idx_obj.end()) 22 | return it->second; 23 | 24 | return nullptr; 25 | } 26 | 27 | void Env::set_object(Obj* thobj, int index) 28 | { 29 | if (index >= 0) 30 | _idx_obj[index] = thobj; 31 | else 32 | _list_obj.push_back(thobj); 33 | } 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/th/Utils.cc: -------------------------------------------------------------------------------- 1 | #include "onmt/th/Utils.h" 2 | 3 | namespace onmt 4 | { 5 | namespace th 6 | { 7 | 8 | int get_number(Table* module_data, const std::string& name) 9 | { 10 | return get_scalar(module_data, name); 11 | } 12 | 13 | bool get_boolean(Table* module_data, const std::string& name) 14 | { 15 | Boolean* b = get_field(module_data, name); 16 | return b && b->get_value(); 17 | } 18 | 19 | } 20 | } 21 | --------------------------------------------------------------------------------