├── .github └── workflows │ ├── mac.yml │ ├── ubuntu.yml │ └── windows.yml ├── .gitignore ├── BUILDING ├── CMakeLists.txt ├── COPYING ├── COPYING.3 ├── COPYING.LESSER.3 ├── Doxyfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── clean_query_only.sh ├── cmake ├── KenLMFunctions.cmake ├── kenlmConfig.cmake.in └── modules │ └── FindEigen3.cmake ├── compile_query_only.sh ├── lm ├── CMakeLists.txt ├── bhiksha.cc ├── bhiksha.hh ├── binary_format.cc ├── binary_format.hh ├── blank.hh ├── build_binary_main.cc ├── builder │ ├── CMakeLists.txt │ ├── README.md │ ├── TODO │ ├── adjust_counts.cc │ ├── adjust_counts.hh │ ├── adjust_counts_test.cc │ ├── combine_counts.hh │ ├── corpus_count.cc │ ├── corpus_count.hh │ ├── corpus_count_test.cc │ ├── count_ngrams_main.cc │ ├── debug_print.hh │ ├── discount.hh │ ├── dump_counts_main.cc │ ├── hash_gamma.hh │ ├── header_info.hh │ ├── initial_probabilities.cc │ ├── initial_probabilities.hh │ ├── interpolate.cc │ ├── interpolate.hh │ ├── lmplz_main.cc │ ├── output.cc │ ├── output.hh │ ├── payload.hh │ ├── pipeline.cc │ └── pipeline.hh ├── common │ ├── CMakeLists.txt │ ├── compare.hh │ ├── joint_order.hh │ ├── model_buffer.cc │ ├── model_buffer.hh │ ├── model_buffer_test.cc │ ├── ngram.hh │ ├── ngram_stream.hh │ ├── print.cc │ ├── print.hh │ ├── renumber.cc │ ├── renumber.hh │ ├── size_option.cc │ ├── size_option.hh │ ├── special.hh │ └── test_data │ │ ├── bigendian │ │ ├── toy0.1 │ │ ├── toy0.2 │ │ ├── toy0.3 │ │ ├── toy0.kenlm_intermediate │ │ ├── toy0.vocab │ │ ├── toy1.1 │ │ ├── toy1.2 │ │ ├── toy1.3 │ │ ├── toy1.kenlm_intermediate │ │ └── toy1.vocab │ │ ├── generate.sh │ │ ├── littleendian │ │ ├── toy0.1 │ │ ├── toy0.2 │ │ ├── toy0.3 │ │ ├── toy0.kenlm_intermediate │ │ ├── toy0.vocab │ │ ├── toy1.1 │ │ ├── toy1.2 │ │ ├── toy1.3 │ │ ├── toy1.kenlm_intermediate │ │ └── toy1.vocab │ │ ├── toy0.arpa │ │ └── toy1.arpa ├── config.cc ├── config.hh ├── enumerate_vocab.hh ├── facade.hh ├── filter │ ├── CMakeLists.txt │ ├── arpa_io.cc │ ├── 
arpa_io.hh │ ├── count_io.hh │ ├── filter_main.cc │ ├── format.hh │ ├── phrase.cc │ ├── phrase.hh │ ├── phrase_table_vocab_main.cc │ ├── thread.hh │ ├── vocab.cc │ ├── vocab.hh │ └── wrapper.hh ├── fragment_main.cc ├── interpolate │ ├── CMakeLists.txt │ ├── backoff_matrix.hh │ ├── backoff_reunification.cc │ ├── backoff_reunification.hh │ ├── backoff_reunification_test.cc │ ├── bounded_sequence_encoding.cc │ ├── bounded_sequence_encoding.hh │ ├── bounded_sequence_encoding_test.cc │ ├── interpolate_info.hh │ ├── interpolate_main.cc │ ├── merge_probabilities.cc │ ├── merge_probabilities.hh │ ├── merge_vocab.cc │ ├── merge_vocab.hh │ ├── merge_vocab_test.cc │ ├── normalize.cc │ ├── normalize.hh │ ├── normalize_test.cc │ ├── pipeline.cc │ ├── pipeline.hh │ ├── split_worker.cc │ ├── split_worker.hh │ ├── streaming_example_main.cc │ ├── tune_derivatives.cc │ ├── tune_derivatives.hh │ ├── tune_derivatives_test.cc │ ├── tune_instances.cc │ ├── tune_instances.hh │ ├── tune_instances_test.cc │ ├── tune_matrix.hh │ ├── tune_weights.cc │ ├── tune_weights.hh │ ├── universal_vocab.cc │ └── universal_vocab.hh ├── kenlm_benchmark_main.cc ├── left.hh ├── left_test.cc ├── lm_exception.cc ├── lm_exception.hh ├── max_order.hh ├── model.cc ├── model.hh ├── model_test.cc ├── model_type.hh ├── ngram_query.hh ├── partial.hh ├── partial_test.cc ├── quantize.cc ├── quantize.hh ├── query_main.cc ├── read_arpa.cc ├── read_arpa.hh ├── return.hh ├── search_hashed.cc ├── search_hashed.hh ├── search_trie.cc ├── search_trie.hh ├── sizes.cc ├── sizes.hh ├── state.hh ├── test.arpa ├── test_nounk.arpa ├── trie.cc ├── trie.hh ├── trie_sort.cc ├── trie_sort.hh ├── value.hh ├── value_build.cc ├── value_build.hh ├── virtual_interface.cc ├── virtual_interface.hh ├── vocab.cc ├── vocab.hh ├── weights.hh ├── word_index.hh └── wrappers │ ├── README │ ├── nplm.cc │ └── nplm.hh ├── pyproject.toml ├── python ├── BuildStandalone.cmake ├── CMakeLists.txt ├── _kenlm.pxd ├── example.py ├── kenlm.cpp ├── kenlm.pyx 
├── score_sentence.cc └── score_sentence.hh ├── setup.py └── util ├── CMakeLists.txt ├── bit_packing.cc ├── bit_packing.hh ├── bit_packing_test.cc ├── cat_compressed_main.cc ├── double-conversion ├── CMakeLists.txt ├── LICENSE ├── bignum-dtoa.cc ├── bignum-dtoa.h ├── bignum.cc ├── bignum.h ├── cached-powers.cc ├── cached-powers.h ├── diy-fp.h ├── double-conversion.h ├── double-to-string.cc ├── double-to-string.h ├── fast-dtoa.cc ├── fast-dtoa.h ├── fixed-dtoa.cc ├── fixed-dtoa.h ├── ieee.h ├── string-to-double.cc ├── string-to-double.h ├── strtod.cc ├── strtod.h └── utils.h ├── ersatz_progress.cc ├── ersatz_progress.hh ├── exception.cc ├── exception.hh ├── fake_ostream.hh ├── file.cc ├── file.hh ├── file_piece.cc ├── file_piece.hh ├── file_piece_test.cc ├── file_stream.hh ├── fixed_array.hh ├── float_to_string.cc ├── float_to_string.hh ├── getopt.c ├── getopt.hh ├── have.hh ├── integer_to_string.cc ├── integer_to_string.hh ├── integer_to_string_test.cc ├── joint_sort.hh ├── joint_sort_test.cc ├── mmap.cc ├── mmap.hh ├── multi_intersection.hh ├── multi_intersection_test.cc ├── murmur_hash.cc ├── murmur_hash.hh ├── parallel_read.cc ├── parallel_read.hh ├── pcqueue.hh ├── pcqueue_test.cc ├── pool.cc ├── pool.hh ├── probing_hash_table.hh ├── probing_hash_table_benchmark_main.cc ├── probing_hash_table_test.cc ├── proxy_iterator.hh ├── read_compressed.cc ├── read_compressed.hh ├── read_compressed_test.cc ├── scoped.cc ├── scoped.hh ├── sized_iterator.hh ├── sized_iterator_test.cc ├── sorted_uniform.hh ├── sorted_uniform_test.cc ├── spaces.cc ├── spaces.hh ├── stream ├── CMakeLists.txt ├── block.hh ├── chain.cc ├── chain.hh ├── config.hh ├── count_records.cc ├── count_records.hh ├── io.cc ├── io.hh ├── io_test.cc ├── line_input.cc ├── line_input.hh ├── multi_progress.cc ├── multi_progress.hh ├── multi_stream.hh ├── rewindable_stream.cc ├── rewindable_stream.hh ├── rewindable_stream_test.cc ├── sort.hh ├── sort_test.cc ├── stream.hh ├── stream_test.cc └── typed_stream.hh 
├── string_piece.cc ├── string_piece.hh ├── string_piece_hash.hh ├── string_stream.hh ├── string_stream_test.cc ├── thread_pool.hh ├── tokenize_piece.hh ├── tokenize_piece_test.cc ├── usage.cc └── usage.hh /.github/workflows/mac.yml: -------------------------------------------------------------------------------- 1 | name: Mac 2 | 3 | on: 4 | push: 5 | branches: master 6 | pull_request: 7 | branches: master 8 | 9 | jobs: 10 | build: 11 | runs-on: macOS-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Install Boost 16 | run: | 17 | brew install boost 18 | brew install libomp 19 | brew install eigen 20 | - name: cmake 21 | run: | 22 | cmake -E make_directory build 23 | cd build 24 | cmake .. 25 | - name: Compile 26 | working-directory: build 27 | run: cmake --build . -j2 28 | - name: Test 29 | working-directory: build 30 | run: ctest -j2 31 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu.yml: -------------------------------------------------------------------------------- 1 | name: Ubuntu 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: dependencies 16 | run: | 17 | sudo apt-get update 18 | sudo apt-get install -y build-essential libboost-all-dev cmake zlib1g-dev libbz2-dev liblzma-dev 19 | - name: cmake 20 | run: | 21 | cmake -E make_directory build 22 | cd build 23 | cmake -DCOMPILE_TESTS=ON .. 24 | - name: Compile 25 | working-directory: build 26 | run: cmake --build . 
-j2 27 | - name: Test 28 | working-directory: build 29 | run: ctest -j2 30 | -------------------------------------------------------------------------------- /.github/workflows/windows.yml: -------------------------------------------------------------------------------- 1 | name: Windows 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: windows-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: cmake 16 | run: | 17 | cmake -E make_directory build 18 | cd build 19 | cmake -DBOOST_ROOT="${env:BOOST_ROOT_1_72_0}" .. 20 | - name: Compile 21 | working-directory: build 22 | run: cmake --build . -j2 23 | - name: Test 24 | working-directory: build 25 | run: ctest -j2 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | util/file_piece.cc.gz 2 | *.swp 3 | *.o 4 | doc/ 5 | build/ 6 | /bin 7 | /lib 8 | /tests 9 | ._* 10 | windows/Win32 11 | windows/x64 12 | windows/*.user 13 | windows/*.sdf 14 | windows/*.opensdf 15 | windows/*.suo 16 | CMakeFiles 17 | cmake_install.cmake 18 | CMakeCache.txt 19 | CTestTestfile.cmake 20 | DartConfiguration.tcl 21 | Makefile 22 | *.egg-info/ 23 | -------------------------------------------------------------------------------- /BUILDING: -------------------------------------------------------------------------------- 1 | KenLM has switched to cmake 2 | cmake . 3 | make -j 4 4 | But they recommend building out of tree 5 | mkdir -p build && cd build 6 | cmake .. 7 | make -j 4 8 | 9 | If you only want the query code and do not care about compression (.gz, .bz2, and .xz): 10 | ./compile_query_only.sh 11 | 12 | Windows: 13 | The windows directory has visual studio files. Note that you need to compile 14 | the kenlm project before build_binary and ngram_query projects. 
15 | 16 | OSX: 17 | Missing dependencies can be remedied with brew. 18 | brew install cmake boost eigen 19 | 20 | Debian/Ubuntu: 21 | sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Most of the code here is licensed under the LGPL. There are exceptions that 2 | have their own licenses, listed below. See comments in those files for more 3 | details. 4 | 5 | util/getopt.* is getopt for Windows 6 | util/murmur_hash.cc 7 | util/string_piece.hh and util/string_piece.cc 8 | util/double-conversion/LICENSE covers util/double-conversion except the build files 9 | util/file.cc contains a modified implementation of mkstemp under the LGPL 10 | util/integer_to_string.* is BSD 11 | 12 | For the rest: 13 | 14 | KenLM is free software: you can redistribute it and/or modify 15 | it under the terms of the GNU Lesser General Public License as published 16 | by the Free Software Foundation, either version 2.1 of the License, or 17 | (at your option) any later version. 18 | 19 | KenLM is distributed in the hope that it will be useful, 20 | but WITHOUT ANY WARRANTY; without even the implied warranty of 21 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 22 | GNU Lesser General Public License for more details. 23 | 24 | You should have received a copy of the GNU Lesser General Public License 2.1 25 | along with KenLM code. If not, see . 
26 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | include setup.py 3 | include lm/*.cc 4 | include lm/*.hh 5 | include python/*.cpp 6 | include util/*.cc 7 | include util/*.hh 8 | include util/double-conversion/*.cc 9 | include util/double-conversion/*.h 10 | -------------------------------------------------------------------------------- /clean_query_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -rf {lm,util,util/double-conversion}/*.o bin/{query,build_binary} 3 | -------------------------------------------------------------------------------- /cmake/KenLMFunctions.cmake: -------------------------------------------------------------------------------- 1 | # Helper functions used across the CMake build system 2 | 3 | include(CMakeParseArguments) 4 | 5 | # Adds a bunch of executables to the build, each depending on the specified 6 | # dependent object files and linking against the specified libraries 7 | function(AddExes) 8 | set(multiValueArgs EXES DEPENDS LIBRARIES) 9 | cmake_parse_arguments(AddExes "" "" "${multiValueArgs}" ${ARGN}) 10 | 11 | # Iterate through the executable list 12 | foreach(exe ${AddExes_EXES}) 13 | 14 | # Compile the executable, linking against the requisite dependent object files 15 | add_executable(${exe} ${exe}_main.cc ${AddExes_DEPENDS}) 16 | 17 | # Link the executable against the supplied libraries 18 | target_link_libraries(${exe} ${AddExes_LIBRARIES}) 19 | 20 | # Group executables together 21 | set_target_properties(${exe} PROPERTIES FOLDER executables) 22 | 23 | # End for loop 24 | endforeach(exe) 25 | 26 | # Install the executable files 27 | install(TARGETS ${AddExes_EXES} DESTINATION bin) 28 | endfunction() 29 | 30 | # Adds a single test to the build, depending on the specified dependent 
31 | # object files, linking against the specified libraries, and with the 32 | # specified command line arguments 33 | function(KenLMAddTest) 34 | cmake_parse_arguments(KenLMAddTest "" "TEST" 35 | "DEPENDS;LIBRARIES;TEST_ARGS" ${ARGN}) 36 | 37 | # Compile the executable, linking against the requisite dependent object files 38 | add_executable(${KenLMAddTest_TEST} 39 | ${KenLMAddTest_TEST}.cc 40 | ${KenLMAddTest_DEPENDS}) 41 | 42 | if (Boost_USE_STATIC_LIBS) 43 | set(DYNLINK_FLAGS) 44 | else() 45 | set(DYNLINK_FLAGS COMPILE_FLAGS -DBOOST_TEST_DYN_LINK) 46 | endif() 47 | 48 | # Require the following compile flag 49 | set_target_properties(${KenLMAddTest_TEST} PROPERTIES 50 | ${DYNLINK_FLAGS} 51 | RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/tests) 52 | 53 | target_link_libraries(${KenLMAddTest_TEST} ${KenLMAddTest_LIBRARIES} ${TIMER_LINK}) 54 | 55 | set(test_params "") 56 | if(KenLMAddTest_TEST_ARGS) 57 | set(test_params ${KenLMAddTest_TEST_ARGS}) 58 | endif() 59 | 60 | # Specify command arguments for how to run each unit test 61 | add_test(NAME ${KenLMAddTest_TEST} 62 | COMMAND ${KenLMAddTest_TEST} ${test_params}) 63 | 64 | # Group unit tests together 65 | set_target_properties(${KenLMAddTest_TEST} PROPERTIES FOLDER "unit_tests") 66 | endfunction() 67 | 68 | # Adds a bunch of tests to the build, each depending on the specified 69 | # dependent object files and linking against the specified libraries 70 | function(AddTests) 71 | set(multiValueArgs TESTS DEPENDS LIBRARIES TEST_ARGS) 72 | cmake_parse_arguments(AddTests "" "" "${multiValueArgs}" ${ARGN}) 73 | 74 | # Iterate through the Boost tests list 75 | foreach(test ${AddTests_TESTS}) 76 | KenLMAddTest(TEST ${test} 77 | DEPENDS ${AddTests_DEPENDS} 78 | LIBRARIES ${AddTests_LIBRARIES} 79 | TEST_ARGS ${AddTests_TEST_ARGS}) 80 | endforeach(test) 81 | endfunction() 82 | -------------------------------------------------------------------------------- /cmake/kenlmConfig.cmake.in: 
-------------------------------------------------------------------------------- 1 | @PACKAGE_INIT@ 2 | 3 | include(CMakeFindDependencyMacro) 4 | 5 | find_dependency(Boost) 6 | find_dependency(Threads) 7 | 8 | # Compression libs 9 | if (@ZLIB_FOUND@) 10 | find_dependency(ZLIB) 11 | endif() 12 | if (@BZIP2_FOUND@) 13 | find_dependency(BZip2) 14 | endif() 15 | if (@LIBLZMA_FOUND@) 16 | find_dependency(LibLZMA) 17 | endif() 18 | 19 | include("${CMAKE_CURRENT_LIST_DIR}/kenlmTargets.cmake") 20 | -------------------------------------------------------------------------------- /compile_query_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #This is just an example compilation. You should integrate these files into your build system. Boost jam is provided and preferred. 3 | 4 | echo You must use ./bjam if you want language model estimation, filtering, or support for compressed files \(.gz, .bz2, .xz\) 1>&2 5 | 6 | rm {lm,util}/*.o 2>/dev/null 7 | set -e 8 | 9 | CXX=${CXX:-g++} 10 | 11 | CXXFLAGS+=" -I. -O3 -DNDEBUG -DKENLM_MAX_ORDER=6" 12 | 13 | #If this fails for you, consider using bjam. 
14 | if [ ${#NPLM} != 0 ]; then 15 | CXXFLAGS+=" -DHAVE_NPLM -lneuralLM -L$NPLM/src -I$NPLM/src -lboost_thread-mt -fopenmp" 16 | ADDED_PATHS="lm/wrappers/*.cc" 17 | fi 18 | echo 'Compiling with '$CXX $CXXFLAGS 19 | 20 | #Grab all cc files in these directories except those ending in test.cc or main.cc 21 | objects="" 22 | for i in util/double-conversion/*.cc util/*.cc lm/*.cc $ADDED_PATHS; do 23 | if [ "${i%test.cc}" == "$i" ] && [ "${i%main.cc}" == "$i" ]; then 24 | $CXX $CXXFLAGS -c $i -o ${i%.cc}.o 25 | objects="$objects ${i%.cc}.o" 26 | fi 27 | done 28 | 29 | mkdir -p bin 30 | if [ "$(uname)" != Darwin ]; then 31 | CXXFLAGS="$CXXFLAGS -lrt" 32 | fi 33 | $CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS $LDFLAGS 34 | $CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS $LDFLAGS 35 | -------------------------------------------------------------------------------- /lm/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Explicitly list the source files for this subdirectory 2 | # 3 | # If you add any source files to this subdirectory 4 | # that should be included in the kenlm library, 5 | # (this excludes any unit test files) 6 | # you should add them to the following list: 7 | set(KENLM_LM_SOURCE 8 | bhiksha.cc 9 | binary_format.cc 10 | config.cc 11 | lm_exception.cc 12 | model.cc 13 | quantize.cc 14 | read_arpa.cc 15 | search_hashed.cc 16 | search_trie.cc 17 | sizes.cc 18 | trie.cc 19 | trie_sort.cc 20 | value_build.cc 21 | virtual_interface.cc 22 | vocab.cc 23 | ) 24 | 25 | 26 | # Group these objects together for later use. 
27 | # 28 | # Given add_library(foo OBJECT ${my_foo_sources}), 29 | # refer to these objects as $ 30 | # 31 | add_subdirectory(common) 32 | 33 | add_library(kenlm ${KENLM_LM_SOURCE} ${KENLM_LM_COMMON_SOURCE}) 34 | set_target_properties(kenlm PROPERTIES POSITION_INDEPENDENT_CODE ON) 35 | target_link_libraries(kenlm PUBLIC kenlm_util Threads::Threads) 36 | # Since headers are relative to `include/kenlm` at install time, not just `include` 37 | target_include_directories(kenlm PUBLIC $) 38 | 39 | target_compile_definitions(kenlm PUBLIC -DKENLM_MAX_ORDER=${KENLM_MAX_ORDER}) 40 | 41 | # This directory has children that need to be processed 42 | add_subdirectory(builder) 43 | add_subdirectory(filter) 44 | add_subdirectory(interpolate) 45 | 46 | # Explicitly list the executable files to be compiled 47 | set(EXE_LIST 48 | query 49 | fragment 50 | build_binary 51 | kenlm_benchmark 52 | ) 53 | 54 | set(LM_LIBS kenlm kenlm_util Threads::Threads) 55 | 56 | install( 57 | TARGETS kenlm 58 | EXPORT kenlmTargets 59 | RUNTIME DESTINATION bin 60 | LIBRARY DESTINATION lib 61 | ARCHIVE DESTINATION lib 62 | INCLUDES DESTINATION include 63 | ) 64 | 65 | AddExes(EXES ${EXE_LIST} 66 | LIBRARIES ${LM_LIBS}) 67 | 68 | if(BUILD_TESTING) 69 | 70 | set(KENLM_BOOST_TESTS_LIST left_test partial_test) 71 | AddTests(TESTS ${KENLM_BOOST_TESTS_LIST} 72 | LIBRARIES ${LM_LIBS} 73 | TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa) 74 | 75 | # model_test requires an extra command line parameter 76 | KenLMAddTest(TEST model_test 77 | LIBRARIES ${LM_LIBS} 78 | TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa 79 | ${CMAKE_CURRENT_SOURCE_DIR}/test_nounk.arpa) 80 | endif() 81 | -------------------------------------------------------------------------------- /lm/blank.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BLANK_H 2 | #define LM_BLANK_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace lm { 9 | namespace ngram { 10 | 11 | /* Suppose "foo 
bar" appears with zero backoff but there is no trigram 12 | * beginning with these words. Then, when scoring "foo bar", the model could 13 | * return out_state containing "bar" or even null context if "bar" also has no 14 | * backoff and is never followed by another word. Then the backoff is set to 15 | * kNoExtensionBackoff. If the n-gram might be extended, then out_state must 16 | * contain the full n-gram, in which case kExtensionBackoff is set. In any 17 | * case, if an n-gram has non-zero backoff, the full state is returned so 18 | * backoff can be properly charged. 19 | * These differ only in sign bit because the backoff is in fact zero in either 20 | * case. 21 | */ 22 | const float kNoExtensionBackoff = -0.0; 23 | const float kExtensionBackoff = 0.0; 24 | const uint64_t kNoExtensionQuant = 0; 25 | const uint64_t kExtensionQuant = 1; 26 | 27 | inline void SetExtension(float &backoff) { 28 | if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; 29 | } 30 | 31 | // This compiles down nicely. 
32 | inline bool HasExtension(const float &backoff) { 33 | typedef union { float f; uint32_t i; } UnionValue; 34 | UnionValue compare, interpret; 35 | compare.f = kNoExtensionBackoff; 36 | interpret.f = backoff; 37 | return compare.i != interpret.i; 38 | } 39 | 40 | } // namespace ngram 41 | } // namespace lm 42 | #endif // LM_BLANK_H 43 | -------------------------------------------------------------------------------- /lm/builder/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake file was created by Lane Schwartz 2 | 3 | # Explicitly list the source files for this subdirectory 4 | # 5 | # If you add any source files to this subdirectory 6 | # that should be included in the kenlm library, 7 | # (this excludes any unit test files) 8 | # you should add them to the following list: 9 | # 10 | # In order to set correct paths to these files 11 | # in case this variable is referenced by CMake files in the parent directory, 12 | # we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}. 13 | # 14 | set(KENLM_BUILDER_SOURCE 15 | ${CMAKE_CURRENT_SOURCE_DIR}/adjust_counts.cc 16 | ${CMAKE_CURRENT_SOURCE_DIR}/corpus_count.cc 17 | ${CMAKE_CURRENT_SOURCE_DIR}/initial_probabilities.cc 18 | ${CMAKE_CURRENT_SOURCE_DIR}/interpolate.cc 19 | ${CMAKE_CURRENT_SOURCE_DIR}/output.cc 20 | ${CMAKE_CURRENT_SOURCE_DIR}/pipeline.cc 21 | ) 22 | 23 | 24 | # Group these objects together for later use. 
25 | # 26 | # Given add_library(foo OBJECT ${my_foo_sources}), 27 | # refer to these objects as $ 28 | # 29 | add_library(kenlm_builder ${KENLM_BUILDER_SOURCE}) 30 | 31 | target_link_libraries(kenlm_builder PUBLIC kenlm kenlm_util Threads::Threads) 32 | # Since headers are relative to `include/kenlm` at install time, not just `include` 33 | target_include_directories(kenlm_builder PUBLIC $) 34 | 35 | AddExes(EXES lmplz 36 | LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads) 37 | AddExes(EXES count_ngrams 38 | LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads) 39 | 40 | install( 41 | TARGETS kenlm_builder 42 | EXPORT kenlmTargets 43 | RUNTIME DESTINATION bin 44 | LIBRARY DESTINATION lib 45 | ARCHIVE DESTINATION lib 46 | INCLUDES DESTINATION include 47 | ) 48 | 49 | if(BUILD_TESTING) 50 | 51 | # Explicitly list the Boost test files to be compiled 52 | set(KENLM_BOOST_TESTS_LIST 53 | adjust_counts_test 54 | corpus_count_test 55 | ) 56 | 57 | AddTests(TESTS ${KENLM_BOOST_TESTS_LIST} 58 | LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads) 59 | endif() 60 | -------------------------------------------------------------------------------- /lm/builder/README.md: -------------------------------------------------------------------------------- 1 | Dependencies 2 | ============ 3 | 4 | Boost >= 1.42.0 is required. 5 | 6 | For Ubuntu, 7 | ```bash 8 | sudo apt-get install libboost1.48-all-dev 9 | ``` 10 | 11 | Alternatively, you can download, compile, and install it yourself: 12 | 13 | ```bash 14 | wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz 15 | tar -xvzf boost_1_52_0.tar.gz 16 | cd boost_1_52_0 17 | ./bootstrap.sh 18 | ./b2 19 | sudo ./b2 install 20 | ``` 21 | 22 | Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html. 
23 | 24 | 25 | Building 26 | ======== 27 | 28 | ```bash 29 | bjam 30 | ``` 31 | Your distribution might package bjam and boost-build separately from Boost. Both are required. 32 | 33 | Usage 34 | ===== 35 | 36 | Run 37 | ```bash 38 | $ bin/lmplz 39 | ``` 40 | to see command line arguments 41 | 42 | Running 43 | ======= 44 | 45 | ```bash 46 | bin/lmplz -o 5 text.arpa 47 | ``` 48 | -------------------------------------------------------------------------------- /lm/builder/TODO: -------------------------------------------------------------------------------- 1 | More tests! 2 | Sharding. 3 | Some way to manage all the crazy config options. 4 | Option to build the binary file directly. 5 | Interpolation of different orders. 6 | -------------------------------------------------------------------------------- /lm/builder/adjust_counts.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_ADJUST_COUNTS_H 2 | #define LM_BUILDER_ADJUST_COUNTS_H 3 | 4 | #include "discount.hh" 5 | #include "../lm_exception.hh" 6 | #include "../../util/exception.hh" 7 | 8 | #include 9 | 10 | #include 11 | 12 | namespace util { namespace stream { class ChainPositions; } } 13 | 14 | namespace lm { 15 | namespace builder { 16 | 17 | class BadDiscountException : public util::Exception { 18 | public: 19 | BadDiscountException() throw(); 20 | ~BadDiscountException() throw(); 21 | }; 22 | 23 | struct DiscountConfig { 24 | // Overrides discounts for orders [1,discount_override.size()]. 25 | std::vector overwrite; 26 | // If discounting fails for an order, copy them from here. 27 | Discount fallback; 28 | // What to do when discounts are out of range or would trigger divison by 29 | // zero. It it does something other than THROW_UP, use fallback_discount. 30 | WarningAction bad_action; 31 | }; 32 | 33 | /* Compute adjusted counts. 34 | * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. 
35 | * Output: [1,N]-grams with adjusted counts. 36 | * [1,N)-grams are in suffix order 37 | * N-grams are in undefined order (they're going to be sorted anyway). 38 | */ 39 | class AdjustCounts { 40 | public: 41 | // counts: output 42 | // counts_pruned: output 43 | // discounts: mostly output. If the input already has entries, they will be kept. 44 | // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. 45 | AdjustCounts( 46 | const std::vector &prune_thresholds, 47 | std::vector &counts, 48 | std::vector &counts_pruned, 49 | const std::vector &prune_words, 50 | const DiscountConfig &discount_config, 51 | std::vector &discounts) 52 | : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), 53 | prune_words_(prune_words), discount_config_(discount_config), discounts_(discounts) 54 | {} 55 | 56 | void Run(const util::stream::ChainPositions &positions); 57 | 58 | private: 59 | const std::vector &prune_thresholds_; 60 | std::vector &counts_; 61 | std::vector &counts_pruned_; 62 | const std::vector &prune_words_; 63 | 64 | DiscountConfig discount_config_; 65 | std::vector &discounts_; 66 | }; 67 | 68 | } // namespace builder 69 | } // namespace lm 70 | 71 | #endif // LM_BUILDER_ADJUST_COUNTS_H 72 | 73 | -------------------------------------------------------------------------------- /lm/builder/combine_counts.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_COMBINE_COUNTS_H 2 | #define LM_BUILDER_COMBINE_COUNTS_H 3 | 4 | #include "payload.hh" 5 | #include "../common/ngram.hh" 6 | #include "../common/compare.hh" 7 | #include "../word_index.hh" 8 | #include "../../util/stream/sort.hh" 9 | 10 | #include 11 | #include 12 | 13 | namespace lm { 14 | namespace builder { 15 | 16 | // Sum counts for the same n-gram. 
17 | struct CombineCounts { 18 | bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { 19 | NGram first(first_void, compare.Order()); 20 | // There isn't a const version of NGram. 21 | NGram second(const_cast(second_void), compare.Order()); 22 | if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; 23 | first.Value().count += second.Value().count; 24 | return true; 25 | } 26 | }; 27 | 28 | } // namespace builder 29 | } // namespace lm 30 | 31 | #endif // LM_BUILDER_COMBINE_COUNTS_H 32 | -------------------------------------------------------------------------------- /lm/builder/corpus_count.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_CORPUS_COUNT_H 2 | #define LM_BUILDER_CORPUS_COUNT_H 3 | 4 | #include "../lm_exception.hh" 5 | #include "../word_index.hh" 6 | #include "../../util/scoped.hh" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace util { 14 | class FilePiece; 15 | namespace stream { 16 | class ChainPosition; 17 | } // namespace stream 18 | } // namespace util 19 | 20 | namespace lm { 21 | namespace builder { 22 | 23 | class CorpusCount { 24 | public: 25 | // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size 26 | static float DedupeMultiplier(std::size_t order); 27 | 28 | // How much memory vocabulary will use based on estimated size of the vocab. 29 | static std::size_t VocabUsage(std::size_t vocab_estimate); 30 | 31 | // token_count: out. 32 | // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. 
33 | CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab, uint64_t &token_count, WordIndex &type_count, std::vector &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol); 34 | 35 | void Run(const util::stream::ChainPosition &position); 36 | 37 | private: 38 | template void RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab); 39 | 40 | util::FilePiece &from_; 41 | int vocab_write_; 42 | bool dynamic_vocab_; 43 | uint64_t &token_count_; 44 | WordIndex &type_count_; 45 | std::vector &prune_words_; 46 | const std::string prune_vocab_filename_; 47 | 48 | std::size_t dedupe_mem_size_; 49 | util::scoped_malloc dedupe_mem_; 50 | 51 | WarningAction disallowed_symbol_action_; 52 | }; 53 | 54 | } // namespace builder 55 | } // namespace lm 56 | #endif // LM_BUILDER_CORPUS_COUNT_H 57 | -------------------------------------------------------------------------------- /lm/builder/corpus_count_test.cc: -------------------------------------------------------------------------------- 1 | #include "corpus_count.hh" 2 | 3 | #include "payload.hh" 4 | #include "../common/ngram_stream.hh" 5 | #include "../common/ngram.hh" 6 | 7 | #include "../../util/file.hh" 8 | #include "../../util/file_piece.hh" 9 | #include "../../util/tokenize_piece.hh" 10 | #include "../../util/stream/chain.hh" 11 | #include "../../util/stream/stream.hh" 12 | 13 | #define BOOST_TEST_MODULE CorpusCountTest 14 | #include 15 | 16 | namespace lm { namespace builder { namespace { 17 | 18 | #define Check(str, cnt) { \ 19 | BOOST_REQUIRE(stream); \ 20 | w = stream->begin(); \ 21 | for (util::TokenIter t(str, " "); t; ++t, ++w) { \ 22 | BOOST_CHECK_EQUAL(*t, v[*w]); \ 23 | } \ 24 | BOOST_CHECK_EQUAL((uint64_t)cnt, stream->Value().count); \ 25 | ++stream; \ 26 | } 27 | 28 | class CheckAnswers { 29 | public: 30 | void Run(const util::stream::ChainPosition &position) { 31 | NGramStream stream(position); 32 | const char *v[] 
= {"", "", "", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; 33 | WordIndex *w; 34 | 35 | Check(" looking", 1); 36 | Check(" looking on", 1); 37 | Check("looking on a", 1); 38 | Check("on a little", 2); 39 | Check("a little more", 2); 40 | Check("little more loin", 2); 41 | Check("more loin ", 2); 42 | Check(" on", 2); 43 | Check(" on a", 1); 44 | Check(" on foo", 1); 45 | Check("on foo little", 1); 46 | Check("foo little more", 1); 47 | Check("little more loin", 1); 48 | Check("more loin ", 1); 49 | Check(" bar", 1); 50 | Check(" bar ", 1); 51 | Check(" ", 1); 52 | BOOST_CHECK(!stream); 53 | } 54 | }; 55 | 56 | BOOST_AUTO_TEST_CASE(Short) { 57 | util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp")); 58 | const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n"; 59 | // Blocks of 10 are 60 | // looking on a little more loin on a little[duplicate] more[duplicate] loin[duplicate] [duplicate] on[duplicate] foo 61 | // little more loin bar 62 | 63 | util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1); 64 | util::SeekOrThrow(input_file.get(), 0); 65 | util::FilePiece input_piece(input_file.release(), "temp file"); 66 | 67 | util::stream::ChainConfig config; 68 | config.entry_size = NGram::TotalSize(3); 69 | config.total_memory = config.entry_size * 20; 70 | config.block_count = 2; 71 | 72 | util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab")); 73 | 74 | uint64_t token_count; 75 | WordIndex type_count = 10; 76 | std::vector prune_words; 77 | util::stream::Chain chain(config); 78 | CorpusCount counter(input_piece, vocab.get(), true, token_count, type_count, prune_words, "", chain.BlockSize() / chain.EntrySize(), SILENT); 79 | chain >> boost::ref(counter) >> CheckAnswers() >> util::stream::kRecycle; 80 | 81 | chain.Wait(); 82 | BOOST_CHECK_EQUAL(11, type_count); 83 | } 84 | 85 | }}} // namespaces 86 | 
-------------------------------------------------------------------------------- /lm/builder/debug_print.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_DEBUG_PRINT_H 2 | #define LM_BUILDER_DEBUG_PRINT_H 3 | 4 | #include "payload.hh" 5 | #include "../common/print.hh" 6 | #include "../common/ngram_stream.hh" 7 | #include "../../util/file_stream.hh" 8 | #include "../../util/file.hh" 9 | 10 | #include 11 | 12 | namespace lm { namespace builder { 13 | // Not defined, only specialized. 14 | template void PrintPayload(util::FileStream &to, const BuildingPayload &payload); 15 | template <> inline void PrintPayload(util::FileStream &to, const BuildingPayload &payload) { 16 | to << payload.count; 17 | } 18 | template <> inline void PrintPayload(util::FileStream &to, const BuildingPayload &payload) { 19 | to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); 20 | } 21 | template <> inline void PrintPayload(util::FileStream &to, const BuildingPayload &payload) { 22 | to << payload.complete.prob << ' ' << payload.complete.backoff; 23 | } 24 | 25 | // template parameter is the type stored. 
26 | template <class V> class Print { 27 | public: 28 | static void DumpSeparateFiles(const VocabReconstitute &vocab, const std::string &file_base, util::stream::Chains &chains) { 29 | for (unsigned int i = 0; i < chains.size(); ++i) { 30 | std::string file(file_base + boost::lexical_cast<std::string>(i)); 31 | chains[i] >> Print(vocab, util::CreateOrThrow(file.c_str())); 32 | } 33 | } 34 | 35 | explicit Print(const VocabReconstitute &vocab, int fd) : vocab_(vocab), to_(fd) {} 36 | 37 | void Run(const util::stream::ChainPositions &chains) { 38 | util::scoped_fd fd(to_); 39 | util::FileStream out(to_); 40 | NGramStreams<V> streams(chains); 41 | for (NGramStream<V> *s = streams.begin(); s != streams.end(); ++s) { 42 | DumpStream(*s, out); 43 | } 44 | } 45 | 46 | void Run(const util::stream::ChainPosition &position) { 47 | util::scoped_fd fd(to_); 48 | util::FileStream out(to_); 49 | NGramStream<V> stream(position); 50 | DumpStream(stream, out); 51 | } 52 | 53 | private: 54 | void DumpStream(NGramStream<V> &stream, util::FileStream &to) { 55 | for (; stream; ++stream) { 56 | PrintPayload<V>(to, stream->Value()); 57 | for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { 58 | to << ' ' << vocab_.Lookup(*w) << '=' << *w; 59 | } 60 | to << '\n'; 61 | } 62 | } 63 | 64 | const VocabReconstitute &vocab_; 65 | int to_; 66 | }; 67 | 68 | }} // namespaces 69 | 70 | #endif // LM_BUILDER_DEBUG_PRINT_H 71 | -------------------------------------------------------------------------------- /lm/builder/discount.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_DISCOUNT_H 2 | #define LM_BUILDER_DISCOUNT_H 3 | 4 | #include <algorithm> 5 | 6 | #include <stdint.h> 7 | 8 | namespace lm { 9 | namespace builder { 10 | 11 | struct Discount { 12 | float amount[4]; 13 | 14 | float Get(uint64_t count) const { 15 | return amount[std::min<uint64_t>(count, 3)]; 16 | } 17 | 18 | float Apply(uint64_t count) const { 19 | return static_cast<float>(count) - Get(count); 20 | } 21 | }; 22 | 23 | } // 
namespace builder 24 | } // namespace lm 25 | 26 | #endif // LM_BUILDER_DISCOUNT_H 27 | -------------------------------------------------------------------------------- /lm/builder/dump_counts_main.cc: -------------------------------------------------------------------------------- 1 | #include "../common/print.hh" 2 | #include "../word_index.hh" 3 | #include "../../util/file.hh" 4 | #include "../../util/read_compressed.hh" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | int main(int argc, char *argv[]) { 12 | if (argc != 4) { 13 | std::cerr << "Usage: " << argv[0] << " counts vocabulary order\n" 14 | "The counts file contains records with 4-byte vocabulary ids followed by 8-byte\n" 15 | "counts. Each record has order many vocabulary ids.\n" 16 | "The vocabulary file contains the words delimited by NULL in order of id.\n" 17 | "The vocabulary file may not be compressed because it is mmapped but the counts\n" 18 | "file can be compressed.\n"; 19 | return 1; 20 | } 21 | util::ReadCompressed counts(util::OpenReadOrThrow(argv[1])); 22 | util::scoped_fd vocab_file(util::OpenReadOrThrow(argv[2])); 23 | lm::VocabReconstitute vocab(vocab_file.get()); 24 | unsigned int order = boost::lexical_cast(argv[3]); 25 | std::vector record(sizeof(uint32_t) * order + sizeof(uint64_t)); 26 | while (std::size_t got = counts.ReadOrEOF(&*record.begin(), record.size())) { 27 | UTIL_THROW_IF(got != record.size(), util::Exception, "Read " << got << " bytes at the end of file, which is not a complete record of length " << record.size()); 28 | const lm::WordIndex *words = reinterpret_cast(&*record.begin()); 29 | for (const lm::WordIndex *i = words; i != words + order; ++i) { 30 | UTIL_THROW_IF(*i >= vocab.Size(), util::Exception, "Vocab ID " << *i << " is larger than the vocab file's maximum of " << vocab.Size() << ". 
Are you sure you have the right order and vocab file for these counts?"); 31 | std::cout << vocab.Lookup(*i) << ' '; 32 | } 33 | // TODO don't use std::cout because it is slow. Add fast uint64_t printing support to FileStream. 34 | std::cout << *reinterpret_cast<const uint64_t*>(words + order) << '\n'; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /lm/builder/hash_gamma.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_HASH_GAMMA__ 2 | #define LM_BUILDER_HASH_GAMMA__ 3 | 4 | #include <stdint.h> 5 | 6 | namespace lm { namespace builder { 7 | 8 | #pragma pack(push) 9 | #pragma pack(4) 10 | 11 | struct HashGamma { 12 | uint64_t hash_value; 13 | float gamma; 14 | }; 15 | 16 | #pragma pack(pop) 17 | 18 | }} // namespaces 19 | #endif // LM_BUILDER_HASH_GAMMA__ 20 | -------------------------------------------------------------------------------- /lm/builder/header_info.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_HEADER_INFO_H 2 | #define LM_BUILDER_HEADER_INFO_H 3 | 4 | #include <string> 5 | #include <vector> 6 | #include <stdint.h> 7 | 8 | namespace lm { namespace builder { 9 | 10 | // Some configuration info that is used to add 11 | // comments to the beginning of an ARPA file 12 | struct HeaderInfo { 13 | std::string input_file; 14 | uint64_t token_count; 15 | std::vector<uint64_t> counts_pruned; 16 | 17 | HeaderInfo() {} 18 | 19 | HeaderInfo(const std::string& input_file_in, uint64_t token_count_in, const std::vector<uint64_t> &counts_pruned_in) 20 | : input_file(input_file_in), token_count(token_count_in), counts_pruned(counts_pruned_in) {} 21 | 22 | // TODO: Add smoothing type 23 | // TODO: More info if multiple models were interpolated 24 | }; 25 | 26 | }} // namespaces 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /lm/builder/initial_probabilities.hh: 
-------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_INITIAL_PROBABILITIES_H 2 | #define LM_BUILDER_INITIAL_PROBABILITIES_H 3 | 4 | #include "discount.hh" 5 | #include "../word_index.hh" 6 | #include "../../util/stream/config.hh" 7 | 8 | #include 9 | 10 | namespace util { namespace stream { class Chains; } } 11 | 12 | namespace lm { 13 | class SpecialVocab; 14 | namespace builder { 15 | 16 | struct InitialProbabilitiesConfig { 17 | // These should be small buffers to keep the adder from getting too far ahead 18 | util::stream::ChainConfig adder_in; 19 | util::stream::ChainConfig adder_out; 20 | // SRILM doesn't normally interpolate unigrams. 21 | bool interpolate_unigrams; 22 | }; 23 | 24 | /* Compute initial (uninterpolated) probabilities 25 | * primary: the normal chain of n-grams. Incoming is context sorted adjusted 26 | * counts. Outgoing has uninterpolated probabilities for use by Interpolate. 27 | * second_in: a second copy of the primary input. Discard the output. 28 | * gamma_out: Computed gamma values are output on these chains in suffix order. 29 | * The values are bare floats and should be buffered for interpolation to 30 | * use. 
31 | */ 32 | void InitialProbabilities( 33 | const InitialProbabilitiesConfig &config, 34 | const std::vector<Discount> &discounts, 35 | util::stream::Chains &primary, 36 | util::stream::Chains &second_in, 37 | util::stream::Chains &gamma_out, 38 | const std::vector<uint64_t> &prune_thresholds, 39 | bool prune_vocab, 40 | const SpecialVocab &vocab); 41 | 42 | } // namespace builder 43 | } // namespace lm 44 | 45 | #endif // LM_BUILDER_INITIAL_PROBABILITIES_H 46 | -------------------------------------------------------------------------------- /lm/builder/interpolate.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_INTERPOLATE_H 2 | #define LM_BUILDER_INTERPOLATE_H 3 | 4 | #include "../common/special.hh" 5 | #include "../word_index.hh" 6 | #include "../../util/stream/multi_stream.hh" 7 | 8 | #include <vector> 9 | 10 | #include <stdint.h> 11 | 12 | namespace lm { namespace builder { 13 | 14 | /* Interpolate step. 15 | * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from 16 | * InitialProbabilities. 17 | * Output: suffix sorted n-grams with complete probability 18 | */ 19 | class Interpolate { 20 | public: 21 | // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might 22 | // be larger when the user specifies a consistent vocabulary size. 
23 | explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials); 24 | 25 | void Run(const util::stream::ChainPositions &positions); 26 | 27 | private: 28 | float uniform_prob_; 29 | util::stream::ChainPositions backoffs_; 30 | const std::vector<uint64_t> prune_thresholds_; 31 | bool prune_vocab_; 32 | bool output_q_; 33 | const SpecialVocab specials_; 34 | }; 35 | 36 | }} // namespaces 37 | #endif // LM_BUILDER_INTERPOLATE_H 38 | -------------------------------------------------------------------------------- /lm/builder/output.cc: -------------------------------------------------------------------------------- 1 | #include "output.hh" 2 | 3 | #include "../common/model_buffer.hh" 4 | #include "../common/print.hh" 5 | #include "../../util/file_stream.hh" 6 | #include "../../util/stream/multi_stream.hh" 7 | 8 | #include <iostream> 9 | 10 | namespace lm { namespace builder { 11 | 12 | OutputHook::~OutputHook() {} 13 | 14 | Output::Output(StringPiece file_base, bool keep_buffer, bool output_q) 15 | : buffer_(file_base, keep_buffer, output_q) {} 16 | 17 | void Output::SinkProbs(util::stream::Chains &chains) { 18 | Apply(PROB_PARALLEL_HOOK, chains); 19 | if (!buffer_.Keep() && !Have(PROB_SEQUENTIAL_HOOK)) { 20 | chains >> util::stream::kRecycle; 21 | chains.Wait(true); 22 | return; 23 | } 24 | buffer_.Sink(chains, header_.counts_pruned); 25 | chains >> util::stream::kRecycle; 26 | chains.Wait(false); 27 | if (Have(PROB_SEQUENTIAL_HOOK)) { 28 | std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; 29 | buffer_.Source(chains); 30 | Apply(PROB_SEQUENTIAL_HOOK, chains); 31 | chains >> util::stream::kRecycle; 32 | chains.Wait(true); 33 | } 34 | } 35 | 36 | void Output::Apply(HookType hook_type, util::stream::Chains &chains) { 37 | for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) { 38 | 
entry->Sink(header_, VocabFile(), chains); 39 | } 40 | } 41 | 42 | void PrintHook::Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains) { 43 | if (verbose_header_) { 44 | util::FileStream out(file_.get(), 50); 45 | out << "# Input file: " << info.input_file << '\n'; 46 | out << "# Token count: " << info.token_count << '\n'; 47 | out << "# Smoothing: Modified Kneser-Ney" << '\n'; 48 | } 49 | chains >> PrintARPA(vocab_file, file_.get(), info.counts_pruned); 50 | } 51 | 52 | }} // namespaces 53 | -------------------------------------------------------------------------------- /lm/builder/output.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_OUTPUT_H 2 | #define LM_BUILDER_OUTPUT_H 3 | 4 | #include "header_info.hh" 5 | #include "../common/model_buffer.hh" 6 | #include "../../util/file.hh" 7 | 8 | #include 9 | #include 10 | 11 | namespace util { namespace stream { class Chains; class ChainPositions; } } 12 | 13 | /* Outputs from lmplz: ARPA, sharded files, etc */ 14 | namespace lm { namespace builder { 15 | 16 | // These are different types of hooks. Values should be consecutive to enable a vector lookup. 17 | enum HookType { 18 | // TODO: counts. 19 | PROB_PARALLEL_HOOK, // Probability and backoff (or just q). Output must process the orders in parallel or there will be a deadlock. 20 | PROB_SEQUENTIAL_HOOK, // Probability and backoff (or just q). Output can process orders any way it likes. This requires writing the data to disk then reading. Useful for ARPA files, which put unigrams first etc. 21 | NUMBER_OF_HOOKS // Keep this last so we know how many values there are. 
22 | }; 23 | 24 | class OutputHook { 25 | public: 26 | explicit OutputHook(HookType hook_type) : type_(hook_type) {} 27 | 28 | virtual ~OutputHook(); 29 | 30 | virtual void Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains) = 0; 31 | 32 | HookType Type() const { return type_; } 33 | 34 | private: 35 | HookType type_; 36 | }; 37 | 38 | class Output : boost::noncopyable { 39 | public: 40 | Output(StringPiece file_base, bool keep_buffer, bool output_q); 41 | 42 | // Takes ownership. 43 | void Add(OutputHook *hook) { 44 | outputs_[hook->Type()].push_back(hook); 45 | } 46 | 47 | bool Have(HookType hook_type) const { 48 | return !outputs_[hook_type].empty(); 49 | } 50 | 51 | int VocabFile() const { return buffer_.VocabFile(); } 52 | 53 | void SetHeader(const HeaderInfo &header) { header_ = header; } 54 | const HeaderInfo &GetHeader() const { return header_; } 55 | 56 | // This is called by the pipeline. 57 | void SinkProbs(util::stream::Chains &chains); 58 | 59 | unsigned int Steps() const { return Have(PROB_SEQUENTIAL_HOOK); } 60 | 61 | private: 62 | void Apply(HookType hook_type, util::stream::Chains &chains); 63 | 64 | ModelBuffer buffer_; 65 | 66 | boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS]; 67 | HeaderInfo header_; 68 | }; 69 | 70 | class PrintHook : public OutputHook { 71 | public: 72 | // Takes ownership 73 | PrintHook(int write_fd, bool verbose_header) 74 | : OutputHook(PROB_SEQUENTIAL_HOOK), file_(write_fd), verbose_header_(verbose_header) {} 75 | 76 | void Sink(const HeaderInfo &info, int vocab_file, util::stream::Chains &chains); 77 | 78 | private: 79 | util::scoped_fd file_; 80 | bool verbose_header_; 81 | }; 82 | 83 | }} // namespaces 84 | 85 | #endif // LM_BUILDER_OUTPUT_H 86 | -------------------------------------------------------------------------------- /lm/builder/payload.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_PAYLOAD_H 2 | #define LM_BUILDER_PAYLOAD_H 3 | 4 | 
#include "../weights.hh" 5 | #include "../word_index.hh" 6 | #include <stdint.h> 7 | 8 | namespace lm { namespace builder { 9 | 10 | struct Uninterpolated { 11 | float prob; // Uninterpolated probability. 12 | float gamma; // Interpolation weight for lower order. 13 | }; 14 | 15 | union BuildingPayload { 16 | uint64_t count; 17 | Uninterpolated uninterp; 18 | ProbBackoff complete; 19 | 20 | /*mjd**********************************************************************/ 21 | bool IsMarked() const { 22 | return count >> (sizeof(count) * 8 - 1); 23 | } 24 | 25 | void Mark() { 26 | count |= (1ULL << (sizeof(count) * 8 - 1)); 27 | } 28 | 29 | void Unmark() { 30 | count &= ~(1ULL << (sizeof(count) * 8 - 1)); 31 | } 32 | 33 | uint64_t UnmarkedCount() const { 34 | return count & ~(1ULL << (sizeof(count) * 8 - 1)); 35 | } 36 | 37 | uint64_t CutoffCount() const { 38 | return IsMarked() ? 0 : UnmarkedCount(); 39 | } 40 | /*mjd**********************************************************************/ 41 | }; 42 | 43 | const WordIndex kBOS = 1; 44 | const WordIndex kEOS = 2; 45 | 46 | }} // namespaces 47 | 48 | #endif // LM_BUILDER_PAYLOAD_H 49 | -------------------------------------------------------------------------------- /lm/builder/pipeline.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_PIPELINE_H 2 | #define LM_BUILDER_PIPELINE_H 3 | 4 | #include "adjust_counts.hh" 5 | #include "initial_probabilities.hh" 6 | #include "header_info.hh" 7 | #include "../lm_exception.hh" 8 | #include "../word_index.hh" 9 | #include "../../util/stream/config.hh" 10 | #include "../../util/file_piece.hh" 11 | 12 | #include <string> 13 | #include <cstddef> 14 | 15 | namespace lm { namespace builder { 16 | 17 | class Output; 18 | 19 | struct PipelineConfig { 20 | std::size_t order; 21 | util::stream::SortConfig sort; 22 | InitialProbabilitiesConfig initial_probs; 23 | util::stream::ChainConfig read_backoffs; 24 | 25 | // Estimated vocabulary size. 
Used for sizing CorpusCount memory and 26 | // initial probing hash table sizing, also in CorpusCount. 27 | lm::WordIndex vocab_estimate; 28 | 29 | // Minimum block size to tolerate. 30 | std::size_t minimum_block; 31 | 32 | // Number of blocks to use. This will be overridden to 1 if everything fits. 33 | std::size_t block_count; 34 | 35 | // n-gram count thresholds for pruning. 0 values means no pruning for 36 | // corresponding n-gram order 37 | std::vector<uint64_t> prune_thresholds; //mjd 38 | bool prune_vocab; 39 | std::string prune_vocab_file; 40 | 41 | /* Renumber the vocabulary the way the trie likes it? */ 42 | bool renumber_vocabulary; 43 | 44 | // What to do with discount failures. 45 | DiscountConfig discount; 46 | 47 | // Compute collapsed q values instead of probability and backoff 48 | bool output_q; 49 | 50 | /* Computing the perplexity of LMs with different vocabularies is hard. For 51 | * example, the lowest perplexity is attained by a unigram model that 52 | * predicts p(<s>) = 1 and has no other vocabulary. Also, linearly 53 | * interpolated models will sum to more than 1 because <s> is duplicated 54 | * (SRI just pretends p(<s>) = 0 for these purposes, which makes it sum to 55 | * 1 but comes with its own problems). This option will make the vocabulary 56 | * a particular size by replicating <unk> multiple times for purposes of 57 | * computing vocabulary size. It has no effect if the actual vocabulary is 58 | * larger. This parameter serves the same purpose as IRSTLM's "dub". 59 | */ 60 | uint64_t vocab_size_for_unk; 61 | 62 | /* What to do the first time <s>, </s>, or <unk> appears in the input. If 63 | * this is anything but THROW_UP, then the symbol will always be treated as 64 | * whitespace. 65 | */ 66 | WarningAction disallowed_symbol_action; 67 | 68 | const std::string &TempPrefix() const { return sort.temp_prefix; } 69 | std::size_t TotalMemory() const { return sort.total_memory; } 70 | }; 71 | 72 | // Takes ownership of text_file and out_arpa. 
73 | void Pipeline(PipelineConfig &config, int text_file, Output &output); 74 | 75 | }} // namespaces 76 | #endif // LM_BUILDER_PIPELINE_H 77 | -------------------------------------------------------------------------------- /lm/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake file was created by Lane Schwartz 2 | 3 | # Explicitly list the source files for this subdirectory 4 | # 5 | # If you add any source files to this subdirectory 6 | # that should be included in the kenlm library, 7 | # (this excludes any unit test files) 8 | # you should add them to the following list: 9 | # 10 | # In order to set correct paths to these files 11 | # in case this variable is referenced by CMake files in the parent directory, 12 | # we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}. 13 | # 14 | set(KENLM_LM_COMMON_SOURCE 15 | ${CMAKE_CURRENT_SOURCE_DIR}/model_buffer.cc 16 | ${CMAKE_CURRENT_SOURCE_DIR}/print.cc 17 | ${CMAKE_CURRENT_SOURCE_DIR}/renumber.cc 18 | ${CMAKE_CURRENT_SOURCE_DIR}/size_option.cc 19 | PARENT_SCOPE) 20 | 21 | if(BUILD_TESTING) 22 | KenLMAddTest(TEST model_buffer_test 23 | LIBRARIES kenlm 24 | TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/test_data) 25 | endif() 26 | -------------------------------------------------------------------------------- /lm/common/joint_order.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_COMMON_JOINT_ORDER_H 2 | #define LM_COMMON_JOINT_ORDER_H 3 | 4 | #include "ngram_stream.hh" 5 | #include "../lm_exception.hh" 6 | 7 | #ifdef DEBUG 8 | #include "../../util/fixed_array.hh" 9 | #include 10 | #endif 11 | 12 | #include 13 | 14 | namespace lm { 15 | 16 | template void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) { 17 | // Allow matching to reference streams[-1]. 18 | util::FixedArray > streams_with_dummy(positions.size() + 1); 19 | // A bogus stream for [-1]. 
20 | streams_with_dummy.push_back(); 21 | for (std::size_t i = 0; i < positions.size(); ++i) { 22 | streams_with_dummy.push_back(positions[i], NGramHeader(NULL, i + 1)); 23 | } 24 | ProxyStream *streams = streams_with_dummy.begin() + 1; 25 | 26 | std::size_t order; 27 | for (order = 0; order < positions.size() && streams[order]; ++order) {} 28 | assert(order); // should always have . 29 | 30 | // Debugging only: call comparison function to sanity check order. 31 | #ifdef DEBUG 32 | util::FixedArray less_compare(order); 33 | for (unsigned i = 0; i < order; ++i) 34 | less_compare.push_back(i + 1); 35 | #endif // DEBUG 36 | 37 | std::size_t current = 0; 38 | while (true) { 39 | // Does the context match the lower one? 40 | if (!memcmp(streams[static_cast(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) { 41 | callback.Enter(current, streams[current].Get()); 42 | // Transition to looking for extensions. 43 | if (++current < order) continue; 44 | } 45 | #ifdef DEBUG 46 | // match_check[current - 1] matches current-grams 47 | // The lower-order stream (which skips fewer current-grams) should always be <= the higher order-stream (which can skip current-grams). 48 | else if (!less_compare[current - 1](streams[static_cast(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset)) { 49 | std::cerr << "Stream out of order detected" << std::endl; 50 | abort(); 51 | } 52 | #endif // DEBUG 53 | // No extension left. 
54 | while(true) { 55 | assert(current > 0); 56 | --current; 57 | callback.Exit(current, streams[current].Get()); 58 | 59 | if (++streams[current]) break; 60 | 61 | UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix"); 62 | 63 | order = current; 64 | if (!order) return; 65 | } 66 | } 67 | } 68 | 69 | } // namespaces 70 | 71 | #endif // LM_COMMON_JOINT_ORDER_H 72 | -------------------------------------------------------------------------------- /lm/common/model_buffer.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_COMMON_MODEL_BUFFER_H 2 | #define LM_COMMON_MODEL_BUFFER_H 3 | 4 | /* Format with separate files in suffix order. Each file contains 5 | * n-grams of the same order. 6 | */ 7 | #include "../word_index.hh" 8 | #include "../../util/file.hh" 9 | #include "../../util/fixed_array.hh" 10 | #include "../../util/string_piece.hh" 11 | 12 | #include 13 | #include 14 | 15 | namespace util { namespace stream { 16 | class Chains; 17 | class Chain; 18 | }} // namespaces 19 | 20 | namespace lm { 21 | 22 | namespace ngram { class State; } 23 | 24 | class ModelBuffer { 25 | public: 26 | // Construct for writing. Must call VocabFile() and fill it with null-delimited vocab words. 27 | ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q); 28 | 29 | // Load from file. 30 | explicit ModelBuffer(StringPiece file_base); 31 | 32 | // Must call VocabFile and populate before calling this function. 33 | void Sink(util::stream::Chains &chains, const std::vector &counts); 34 | 35 | // Read files and write to the given chains. If fewer chains are provided, 36 | // only do the lower orders. 37 | void Source(util::stream::Chains &chains); 38 | 39 | void Source(std::size_t order_minus_1, util::stream::Chain &chain); 40 | 41 | // The order of the n-gram model that is associated with the model buffer. 
42 | std::size_t Order() const { return counts_.size(); } 43 | // Requires Sink or load from file. 44 | const std::vector &Counts() const { 45 | assert(!counts_.empty()); 46 | return counts_; 47 | } 48 | 49 | int VocabFile() const { return vocab_file_.get(); } 50 | 51 | int RawFile(std::size_t order_minus_1) const { 52 | return files_[order_minus_1].get(); 53 | } 54 | 55 | bool Keep() const { return keep_buffer_; } 56 | 57 | // Slowly execute a language model query with binary search. 58 | // This is used by interpolation to gather tuning probabilities rather than 59 | // scanning the files. 60 | float SlowQuery(const ngram::State &context, WordIndex word, ngram::State &out) const; 61 | 62 | private: 63 | const std::string file_base_; 64 | const bool keep_buffer_; 65 | bool output_q_; 66 | std::vector counts_; 67 | 68 | util::scoped_fd vocab_file_; 69 | util::FixedArray files_; 70 | }; 71 | 72 | } // namespace lm 73 | 74 | #endif // LM_COMMON_MODEL_BUFFER_H 75 | -------------------------------------------------------------------------------- /lm/common/model_buffer_test.cc: -------------------------------------------------------------------------------- 1 | #include "model_buffer.hh" 2 | #include "../model.hh" 3 | #include "../state.hh" 4 | 5 | #define BOOST_TEST_MODULE ModelBufferTest 6 | #include 7 | 8 | namespace lm { namespace { 9 | 10 | BOOST_AUTO_TEST_CASE(Query) { 11 | std::string dir("test_data"); 12 | if (boost::unit_test::framework::master_test_suite().argc == 2) { 13 | dir = boost::unit_test::framework::master_test_suite().argv[1]; 14 | } 15 | ngram::Model ref((dir + "/toy0.arpa").c_str()); 16 | #if BYTE_ORDER == LITTLE_ENDIAN 17 | std::string endian = "little"; 18 | #elif BYTE_ORDER == BIG_ENDIAN 19 | std::string endian = "big"; 20 | #else 21 | #error "Unsupported byte order." 
22 | #endif 23 | 24 | ModelBuffer test(dir + "/" + endian + "endian/toy0"); 25 | ngram::State ref_state, test_state; 26 | WordIndex a = ref.GetVocabulary().Index("a"); 27 | BOOST_CHECK_CLOSE( 28 | ref.FullScore(ref.BeginSentenceState(), a, ref_state).prob, 29 | test.SlowQuery(ref.BeginSentenceState(), a, test_state), 30 | 0.001); 31 | BOOST_CHECK_EQUAL((unsigned)ref_state.length, (unsigned)test_state.length); 32 | BOOST_CHECK_EQUAL(ref_state.words[0], test_state.words[0]); 33 | BOOST_CHECK_EQUAL(ref_state.backoff[0], test_state.backoff[0]); 34 | BOOST_CHECK(ref_state == test_state); 35 | 36 | ngram::State ref_state2, test_state2; 37 | WordIndex b = ref.GetVocabulary().Index("b"); 38 | BOOST_CHECK_CLOSE( 39 | ref.FullScore(ref_state, b, ref_state2).prob, 40 | test.SlowQuery(test_state, b, test_state2), 41 | 0.001); 42 | BOOST_CHECK(ref_state2 == test_state2); 43 | BOOST_CHECK_EQUAL(ref_state2.backoff[0], test_state2.backoff[0]); 44 | 45 | BOOST_CHECK_CLOSE( 46 | ref.FullScore(ref_state2, 0, ref_state).prob, 47 | test.SlowQuery(test_state2, 0, test_state), 48 | 0.001); 49 | // The reference does state minimization but this doesn't. 
50 | } 51 | 52 | }} // namespaces 53 | -------------------------------------------------------------------------------- /lm/common/ngram.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_COMMON_NGRAM_H 2 | #define LM_COMMON_NGRAM_H 3 | 4 | #include "../weights.hh" 5 | #include "../word_index.hh" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace lm { 13 | 14 | class NGramHeader { 15 | public: 16 | NGramHeader(void *begin, std::size_t order) 17 | : begin_(static_cast(begin)), end_(begin_ + order) {} 18 | 19 | NGramHeader() : begin_(NULL), end_(NULL) {} 20 | 21 | const uint8_t *Base() const { return reinterpret_cast(begin_); } 22 | uint8_t *Base() { return reinterpret_cast(begin_); } 23 | 24 | void ReBase(void *to) { 25 | std::size_t difference = end_ - begin_; 26 | begin_ = reinterpret_cast(to); 27 | end_ = begin_ + difference; 28 | } 29 | 30 | // These are for the vocab index. 31 | // Lower-case in deference to STL. 32 | const WordIndex *begin() const { return begin_; } 33 | WordIndex *begin() { return begin_; } 34 | const WordIndex *end() const { return end_; } 35 | WordIndex *end() { return end_; } 36 | 37 | std::size_t size() const { return end_ - begin_; } 38 | std::size_t Order() const { return end_ - begin_; } 39 | 40 | private: 41 | WordIndex *begin_, *end_; 42 | }; 43 | 44 | template class NGram : public NGramHeader { 45 | public: 46 | typedef PayloadT Payload; 47 | 48 | NGram() : NGramHeader(NULL, 0) {} 49 | 50 | NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {} 51 | 52 | // Would do operator++ but that can get confusing for a stream. 53 | void NextInMemory() { 54 | ReBase(&Value() + 1); 55 | } 56 | 57 | static std::size_t TotalSize(std::size_t order) { 58 | return order * sizeof(WordIndex) + sizeof(Payload); 59 | } 60 | std::size_t TotalSize() const { 61 | // Compiler should optimize this. 
62 | return TotalSize(Order()); 63 | } 64 | 65 | static std::size_t OrderFromSize(std::size_t size) { 66 | std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); 67 | assert(size == TotalSize(ret)); 68 | return ret; 69 | } 70 | 71 | const Payload &Value() const { return *reinterpret_cast(end()); } 72 | Payload &Value() { return *reinterpret_cast(end()); } 73 | }; 74 | 75 | } // namespace lm 76 | 77 | #endif // LM_COMMON_NGRAM_H 78 | -------------------------------------------------------------------------------- /lm/common/ngram_stream.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_BUILDER_NGRAM_STREAM_H 2 | #define LM_BUILDER_NGRAM_STREAM_H 3 | 4 | #include "ngram.hh" 5 | #include "../../util/stream/chain.hh" 6 | #include "../../util/stream/multi_stream.hh" 7 | #include "../../util/stream/stream.hh" 8 | 9 | #include 10 | 11 | namespace lm { 12 | 13 | template class ProxyStream { 14 | public: 15 | // Make an invalid stream. 16 | ProxyStream() {} 17 | 18 | explicit ProxyStream(const util::stream::ChainPosition &position, const Proxy &proxy = Proxy()) 19 | : proxy_(proxy), stream_(position) { 20 | proxy_.ReBase(stream_.Get()); 21 | } 22 | 23 | Proxy &operator*() { return proxy_; } 24 | const Proxy &operator*() const { return proxy_; } 25 | 26 | Proxy *operator->() { return &proxy_; } 27 | const Proxy *operator->() const { return &proxy_; } 28 | 29 | void *Get() { return stream_.Get(); } 30 | const void *Get() const { return stream_.Get(); } 31 | 32 | operator bool() const { return stream_; } 33 | bool operator!() const { return !stream_; } 34 | void Poison() { stream_.Poison(); } 35 | 36 | ProxyStream &operator++() { 37 | ++stream_; 38 | proxy_.ReBase(stream_.Get()); 39 | return *this; 40 | } 41 | 42 | private: 43 | Proxy proxy_; 44 | util::stream::Stream stream_; 45 | }; 46 | 47 | template class NGramStream : public ProxyStream > { 48 | public: 49 | // Make an invalid stream. 
50 | NGramStream() {} 51 | 52 | explicit NGramStream(const util::stream::ChainPosition &position) : 53 | ProxyStream >(position, NGram(NULL, NGram::OrderFromSize(position.GetChain().EntrySize()))) {} 54 | }; 55 | 56 | template class NGramStreams : public util::stream::GenericStreams > { 57 | private: 58 | typedef util::stream::GenericStreams > P; 59 | public: 60 | NGramStreams() : P() {} 61 | NGramStreams(const util::stream::ChainPositions &positions) : P(positions) {} 62 | }; 63 | 64 | } // namespace 65 | #endif // LM_BUILDER_NGRAM_STREAM_H 66 | -------------------------------------------------------------------------------- /lm/common/print.cc: -------------------------------------------------------------------------------- 1 | #include "print.hh" 2 | 3 | #include "ngram_stream.hh" 4 | #include "../../util/file_stream.hh" 5 | #include "../../util/file.hh" 6 | #include "../../util/mmap.hh" 7 | #include "../../util/scoped.hh" 8 | 9 | #include 10 | #include 11 | 12 | namespace lm { 13 | 14 | VocabReconstitute::VocabReconstitute(int fd) { 15 | uint64_t size = util::SizeOrThrow(fd); 16 | util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_); 17 | const char *const start = static_cast(memory_.get()); 18 | const char *i; 19 | for (i = start; i != start + size; i += strlen(i) + 1) { 20 | map_.push_back(i); 21 | } 22 | // Last one for LookupPiece. 
23 | map_.push_back(i); 24 | } 25 | 26 | namespace { 27 | template void PrintLead(const VocabReconstitute &vocab, ProxyStream &stream, util::FileStream &out) { 28 | out << stream->Value().prob << '\t' << vocab.Lookup(*stream->begin()); 29 | for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { 30 | out << ' ' << vocab.Lookup(*i); 31 | } 32 | } 33 | } // namespace 34 | 35 | void PrintARPA::Run(const util::stream::ChainPositions &positions) { 36 | VocabReconstitute vocab(vocab_fd_); 37 | util::FileStream out(out_fd_); 38 | out << "\\data\\\n"; 39 | for (size_t i = 0; i < positions.size(); ++i) { 40 | out << "ngram " << (i+1) << '=' << counts_[i] << '\n'; 41 | } 42 | out << '\n'; 43 | 44 | for (unsigned order = 1; order < positions.size(); ++order) { 45 | out << "\\" << order << "-grams:" << '\n'; 46 | for (ProxyStream > stream(positions[order - 1], NGram(NULL, order)); stream; ++stream) { 47 | PrintLead(vocab, stream, out); 48 | out << '\t' << stream->Value().backoff << '\n'; 49 | } 50 | out << '\n'; 51 | } 52 | 53 | out << "\\" << positions.size() << "-grams:" << '\n'; 54 | for (ProxyStream > stream(positions.back(), NGram(NULL, positions.size())); stream; ++stream) { 55 | PrintLead(vocab, stream, out); 56 | out << '\n'; 57 | } 58 | out << '\n'; 59 | out << "\\end\\\n"; 60 | } 61 | 62 | } // namespace lm 63 | -------------------------------------------------------------------------------- /lm/common/print.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_COMMON_PRINT_H 2 | #define LM_COMMON_PRINT_H 3 | 4 | #include "../word_index.hh" 5 | #include "../../util/mmap.hh" 6 | #include "../../util/string_piece.hh" 7 | 8 | #include 9 | #include 10 | 11 | namespace util { namespace stream { class ChainPositions; }} 12 | 13 | // Warning: PrintARPA routines read all unigrams before all bigrams before all 14 | // trigrams etc. So if other parts of the chain move jointly, you'll have to 15 | // buffer. 
16 | 17 | namespace lm { 18 | 19 | class VocabReconstitute { 20 | public: 21 | // fd must be alive for life of this object; does not take ownership. 22 | explicit VocabReconstitute(int fd); 23 | 24 | const char *Lookup(WordIndex index) const { 25 | assert(index < map_.size() - 1); 26 | return map_[index]; 27 | } 28 | 29 | StringPiece LookupPiece(WordIndex index) const { 30 | return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]); 31 | } 32 | 33 | std::size_t Size() const { 34 | // There's an extra entry to support StringPiece lengths. 35 | return map_.size() - 1; 36 | } 37 | 38 | private: 39 | util::scoped_memory memory_; 40 | std::vector map_; 41 | }; 42 | 43 | class PrintARPA { 44 | public: 45 | // Does not take ownership of vocab_fd or out_fd. 46 | explicit PrintARPA(int vocab_fd, int out_fd, const std::vector &counts) 47 | : vocab_fd_(vocab_fd), out_fd_(out_fd), counts_(counts) {} 48 | 49 | void Run(const util::stream::ChainPositions &positions); 50 | 51 | private: 52 | int vocab_fd_; 53 | int out_fd_; 54 | std::vector counts_; 55 | }; 56 | 57 | } // namespace lm 58 | #endif // LM_COMMON_PRINT_H 59 | -------------------------------------------------------------------------------- /lm/common/renumber.cc: -------------------------------------------------------------------------------- 1 | #include "renumber.hh" 2 | #include "ngram.hh" 3 | 4 | #include "../../util/stream/stream.hh" 5 | 6 | namespace lm { 7 | 8 | void Renumber::Run(const util::stream::ChainPosition &position) { 9 | for (util::stream::Stream stream(position); stream; ++stream) { 10 | NGramHeader gram(stream.Get(), order_); 11 | for (WordIndex *w = gram.begin(); w != gram.end(); ++w) { 12 | *w = new_numbers_[*w]; 13 | } 14 | } 15 | } 16 | 17 | } // namespace lm 18 | -------------------------------------------------------------------------------- /lm/common/renumber.hh: -------------------------------------------------------------------------------- 1 | /* Map vocab ids. 
This is useful to merge independently collected counts or 2 | * change the vocab ids to the order used by the trie. 3 | */ 4 | #ifndef LM_COMMON_RENUMBER_H 5 | #define LM_COMMON_RENUMBER_H 6 | 7 | #include "../word_index.hh" 8 | 9 | #include 10 | 11 | namespace util { namespace stream { class ChainPosition; }} 12 | 13 | namespace lm { 14 | 15 | class Renumber { 16 | public: 17 | // Assumes the array is large enough to map all words and stays alive while 18 | // the thread is active. 19 | Renumber(const WordIndex *new_numbers, std::size_t order) 20 | : new_numbers_(new_numbers), order_(order) {} 21 | 22 | void Run(const util::stream::ChainPosition &position); 23 | 24 | private: 25 | const WordIndex *new_numbers_; 26 | std::size_t order_; 27 | }; 28 | 29 | } // namespace lm 30 | #endif // LM_COMMON_RENUMBER_H 31 | -------------------------------------------------------------------------------- /lm/common/size_option.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../util/usage.hh" 3 | 4 | namespace lm { 5 | 6 | namespace { 7 | class SizeNotify { 8 | public: 9 | explicit SizeNotify(std::size_t &out) : behind_(out) {} 10 | 11 | void operator()(const std::string &from) { 12 | behind_ = util::ParseSize(from); 13 | } 14 | 15 | private: 16 | std::size_t &behind_; 17 | }; 18 | } 19 | 20 | boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value) { 21 | return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); 22 | } 23 | 24 | } // namespace lm 25 | -------------------------------------------------------------------------------- /lm/common/size_option.hh: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | namespace lm { 7 | 8 | // Create a boost program option for data sizes. This parses sizes like 1T and 10k. 
9 | boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value); 10 | 11 | } // namespace lm 12 | -------------------------------------------------------------------------------- /lm/common/special.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_COMMON_SPECIAL_H 2 | #define LM_COMMON_SPECIAL_H 3 | 4 | #include "../word_index.hh" 5 | 6 | namespace lm { 7 | 8 | class SpecialVocab { 9 | public: 10 | SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {} 11 | 12 | bool IsSpecial(WordIndex word) const { 13 | return word == kUNK || word == bos_ || word == eos_; 14 | } 15 | 16 | WordIndex UNK() const { return kUNK; } 17 | WordIndex BOS() const { return bos_; } 18 | WordIndex EOS() const { return eos_; } 19 | 20 | private: 21 | WordIndex bos_; 22 | WordIndex eos_; 23 | }; 24 | 25 | } // namespace lm 26 | 27 | #endif // LM_COMMON_SPECIAL_H 28 | -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy0.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy0.1 -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy0.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy0.2 -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy0.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy0.3 -------------------------------------------------------------------------------- 
/lm/common/test_data/bigendian/toy0.kenlm_intermediate: -------------------------------------------------------------------------------- 1 | KenLM intermediate binary file 2 | Counts 5 7 7 3 | Payload pb 4 | -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy0.vocab: -------------------------------------------------------------------------------- 1 | ab -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy1.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy1.1 -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy1.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy1.2 -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy1.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/bigendian/toy1.3 -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy1.kenlm_intermediate: -------------------------------------------------------------------------------- 1 | KenLM intermediate binary file 2 | Counts 6 7 6 3 | Payload pb 4 | -------------------------------------------------------------------------------- /lm/common/test_data/bigendian/toy1.vocab: -------------------------------------------------------------------------------- 1 | acb -------------------------------------------------------------------------------- 
/lm/common/test_data/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ../../../../build/bin/lmplz --discount_fallback -o 3 -S 100M --intermediate toy0 --arpa ../toy0.arpa <ab -------------------------------------------------------------------------------- /lm/common/test_data/littleendian/toy1.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/littleendian/toy1.1 -------------------------------------------------------------------------------- /lm/common/test_data/littleendian/toy1.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/littleendian/toy1.2 -------------------------------------------------------------------------------- /lm/common/test_data/littleendian/toy1.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpu/kenlm/4cb443e60b7bf2c0ddf3c745378f76cb59e254e5/lm/common/test_data/littleendian/toy1.3 -------------------------------------------------------------------------------- /lm/common/test_data/littleendian/toy1.kenlm_intermediate: -------------------------------------------------------------------------------- 1 | KenLM intermediate binary file 2 | Counts 6 7 6 3 | Payload pb 4 | -------------------------------------------------------------------------------- /lm/common/test_data/littleendian/toy1.vocab: -------------------------------------------------------------------------------- 1 | acb -------------------------------------------------------------------------------- /lm/common/test_data/toy0.arpa: -------------------------------------------------------------------------------- 1 | \data\ 2 | ngram 1=5 3 | ngram 2=7 4 | ngram 3=7 5 | 6 | \1-grams: 
7 | -0.90309 0 8 | 0 -0.30103 9 | -0.46943438 a -0.30103 10 | -0.5720968 0 11 | -0.5720968 b -0.30103 12 | 13 | \2-grams: 14 | -0.37712017 a -0.30103 15 | -0.37712017 a a -0.30103 16 | -0.2984526 b a -0.30103 17 | -0.58682007 a 0 18 | -0.5220179 b 0 19 | -0.41574955 b -0.30103 20 | -0.58682007 a b -0.30103 21 | 22 | \3-grams: 23 | -0.14885087 a a 24 | -0.33741078 b a a 25 | -0.124077894 b a 26 | -0.2997394 a b a 27 | -0.42082912 b a 28 | -0.397617 a b 29 | -0.20102891 a a b 30 | 31 | \end\ 32 | -------------------------------------------------------------------------------- /lm/common/test_data/toy1.arpa: -------------------------------------------------------------------------------- 1 | \data\ 2 | ngram 1=6 3 | ngram 2=7 4 | ngram 3=6 5 | 6 | \1-grams: 7 | -1 0 8 | 0 -0.30103 9 | -0.6146491 a -0.30103 10 | -0.6146491 0 11 | -0.7659168 c -0.30103 12 | -0.6146491 b -0.30103 13 | 14 | \2-grams: 15 | -0.4301247 a -0.30103 16 | -0.4301247 a a -0.30103 17 | -0.20660876 c 0 18 | -0.5404639 b 0 19 | -0.4740302 c -0.30103 20 | -0.4301247 a b -0.30103 21 | -0.3422159 b b -0.47712123 22 | 23 | \3-grams: 24 | -0.1638568 a a 25 | -0.09113217 c 26 | -0.7462621 b b 27 | -0.1638568 a a b 28 | -0.13823806 a b b 29 | -0.13375957 b b b 30 | 31 | \end\ 32 | -------------------------------------------------------------------------------- /lm/config.cc: -------------------------------------------------------------------------------- 1 | #include "config.hh" 2 | 3 | #include 4 | 5 | namespace lm { 6 | namespace ngram { 7 | 8 | Config::Config() : 9 | show_progress(true), 10 | messages(&std::cerr), 11 | enumerate_vocab(NULL), 12 | unknown_missing(COMPLAIN), 13 | sentence_marker_missing(THROW_UP), 14 | positive_log_probability(THROW_UP), 15 | unknown_missing_logprob(-100.0), 16 | probing_multiplier(1.5), 17 | building_memory(1073741824ULL), // 1 GB 18 | temporary_directory_prefix(""), 19 | arpa_complain(ALL), 20 | write_mmap(NULL), 21 | write_method(WRITE_AFTER), 22 | include_vocab(true), 
23 | rest_function(REST_MAX), 24 | prob_bits(8), 25 | backoff_bits(8), 26 | pointer_bhiksha_bits(22), 27 | load_method(util::POPULATE_OR_READ) {} 28 | 29 | } // namespace ngram 30 | } // namespace lm 31 | -------------------------------------------------------------------------------- /lm/enumerate_vocab.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_ENUMERATE_VOCAB_H 2 | #define LM_ENUMERATE_VOCAB_H 3 | 4 | #include "word_index.hh" 5 | #include "../util/string_piece.hh" 6 | 7 | namespace lm { 8 | 9 | /* If you need the actual strings in the vocabulary, inherit from this class 10 | * and implement Add. Then put a pointer in Config.enumerate_vocab; it does 11 | * not take ownership. Add is called once per vocab word. index starts at 0 12 | * and increases by 1 each time. This is only used by the Model constructor; 13 | * the pointer is not retained by the class. 14 | */ 15 | class EnumerateVocab { 16 | public: 17 | virtual ~EnumerateVocab() {} 18 | 19 | virtual void Add(WordIndex index, const StringPiece &str) = 0; 20 | 21 | protected: 22 | EnumerateVocab() {} 23 | }; 24 | 25 | } // namespace lm 26 | 27 | #endif // LM_ENUMERATE_VOCAB_H 28 | 29 | -------------------------------------------------------------------------------- /lm/facade.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_FACADE_H 2 | #define LM_FACADE_H 3 | 4 | #include "virtual_interface.hh" 5 | #include "../util/string_piece.hh" 6 | 7 | #include 8 | 9 | namespace lm { 10 | namespace base { 11 | 12 | // Common model interface that depends on knowing the specific classes. 13 | // Curiously recurring template pattern. 
14 | template class ModelFacade : public Model { 15 | public: 16 | typedef StateT State; 17 | typedef VocabularyT Vocabulary; 18 | 19 | /* Translate from void* to State */ 20 | FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const { 21 | return static_cast(this)->FullScore( 22 | *reinterpret_cast(in_state), 23 | new_word, 24 | *reinterpret_cast(out_state)); 25 | } 26 | 27 | FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { 28 | return static_cast(this)->FullScoreForgotState( 29 | context_rbegin, 30 | context_rend, 31 | new_word, 32 | *reinterpret_cast(out_state)); 33 | } 34 | 35 | // Default Score function calls FullScore. Model can override this. 36 | float Score(const State &in_state, const WordIndex new_word, State &out_state) const { 37 | return static_cast(this)->FullScore(in_state, new_word, out_state).prob; 38 | } 39 | 40 | float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const { 41 | return static_cast(this)->Score( 42 | *reinterpret_cast(in_state), 43 | new_word, 44 | *reinterpret_cast(out_state)); 45 | } 46 | 47 | const State &BeginSentenceState() const { return begin_sentence_; } 48 | const State &NullContextState() const { return null_context_; } 49 | const Vocabulary &GetVocabulary() const { return *static_cast(&BaseVocabulary()); } 50 | 51 | protected: 52 | ModelFacade() : Model(sizeof(State)) {} 53 | 54 | virtual ~ModelFacade() {} 55 | 56 | // begin_sentence and null_context can disappear after. vocab should stay. 
57 | void Init(const State &begin_sentence, const State &null_context, const Vocabulary &vocab, unsigned char order) { 58 | begin_sentence_ = begin_sentence; 59 | null_context_ = null_context; 60 | begin_sentence_memory_ = &begin_sentence_; 61 | null_context_memory_ = &null_context_; 62 | base_vocab_ = &vocab; 63 | order_ = order; 64 | } 65 | 66 | private: 67 | State begin_sentence_, null_context_; 68 | }; 69 | 70 | } // mamespace base 71 | } // namespace lm 72 | 73 | #endif // LM_FACADE_H 74 | -------------------------------------------------------------------------------- /lm/filter/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake file was created by Lane Schwartz 2 | 3 | # Explicitly list the source files for this subdirectory 4 | # 5 | # If you add any source files to this subdirectory 6 | # that should be included in the kenlm library, 7 | # (this excludes any unit test files) 8 | # you should add them to the following list: 9 | # 10 | # In order to set correct paths to these files 11 | # in case this variable is referenced by CMake files in the parent directory, 12 | # we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}. 13 | # 14 | set(KENLM_FILTER_SOURCE 15 | ${CMAKE_CURRENT_SOURCE_DIR}/arpa_io.cc 16 | ${CMAKE_CURRENT_SOURCE_DIR}/phrase.cc 17 | ${CMAKE_CURRENT_SOURCE_DIR}/vocab.cc 18 | ) 19 | 20 | # Group these objects together for later use. 
21 | # 22 | # Given add_library(foo OBJECT ${my_foo_sources}), 23 | # refer to these objects as $ 24 | # 25 | add_library(kenlm_filter ${KENLM_FILTER_SOURCE}) 26 | target_link_libraries(kenlm_filter PUBLIC kenlm_util) 27 | # Since headers are relative to `include/kenlm` at install time, not just `include` 28 | target_include_directories(kenlm_filter PUBLIC $) 29 | 30 | AddExes(EXES filter phrase_table_vocab 31 | LIBRARIES kenlm_filter kenlm) 32 | 33 | install( 34 | TARGETS kenlm_filter 35 | EXPORT kenlmTargets 36 | RUNTIME DESTINATION bin 37 | LIBRARY DESTINATION lib 38 | ARCHIVE DESTINATION lib 39 | INCLUDES DESTINATION include 40 | ) 41 | -------------------------------------------------------------------------------- /lm/filter/arpa_io.cc: -------------------------------------------------------------------------------- 1 | #include "arpa_io.hh" 2 | #include "../../util/file_piece.hh" 3 | #include "../../util/string_stream.hh" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace lm { 15 | 16 | ARPAInputException::ARPAInputException(const StringPiece &message) throw() { 17 | *this << message; 18 | } 19 | 20 | ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() { 21 | *this << message << " in line " << line; 22 | } 23 | 24 | ARPAInputException::~ARPAInputException() throw() {} 25 | 26 | // Seeking is the responsibility of the caller. 
27 | template void WriteCounts(Stream &out, const std::vector &number) { 28 | out << "\n\\data\\\n"; 29 | for (unsigned int i = 0; i < number.size(); ++i) { 30 | out << "ngram " << i+1 << "=" << number[i] << '\n'; 31 | } 32 | out << '\n'; 33 | } 34 | 35 | size_t SizeNeededForCounts(const std::vector &number) { 36 | util::StringStream stream; 37 | WriteCounts(stream, number); 38 | return stream.str().size(); 39 | } 40 | 41 | bool IsEntirelyWhiteSpace(const StringPiece &line) { 42 | for (size_t i = 0; i < static_cast(line.size()); ++i) { 43 | if (!isspace(line.data()[i])) return false; 44 | } 45 | return true; 46 | } 47 | 48 | ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) 49 | : file_backing_(util::CreateOrThrow(name)), file_(file_backing_.get(), buffer_size) {} 50 | 51 | void ARPAOutput::ReserveForCounts(std::streampos reserve) { 52 | for (std::streampos i = 0; i < reserve; i += std::streampos(1)) { 53 | file_ << '\n'; 54 | } 55 | } 56 | 57 | void ARPAOutput::BeginLength(unsigned int length) { 58 | file_ << '\\' << length << "-grams:" << '\n'; 59 | fast_counter_ = 0; 60 | } 61 | 62 | void ARPAOutput::EndLength(unsigned int length) { 63 | file_ << '\n'; 64 | if (length > counts_.size()) { 65 | counts_.resize(length); 66 | } 67 | counts_[length - 1] = fast_counter_; 68 | } 69 | 70 | void ARPAOutput::Finish() { 71 | file_ << "\\end\\\n"; 72 | file_.seekp(0); 73 | WriteCounts(file_, counts_); 74 | file_.flush(); 75 | } 76 | 77 | } // namespace lm 78 | -------------------------------------------------------------------------------- /lm/filter/count_io.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_FILTER_COUNT_IO_H 2 | #define LM_FILTER_COUNT_IO_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "../../util/file_stream.hh" 9 | #include "../../util/file.hh" 10 | #include "../../util/file_piece.hh" 11 | 12 | namespace lm { 13 | 14 | class CountOutput : boost::noncopyable { 15 | public: 16 | 
explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {} 17 | 18 | void AddNGram(const StringPiece &line) { 19 | file_ << line << '\n'; 20 | } 21 | 22 | template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { 23 | AddNGram(line); 24 | } 25 | 26 | void AddNGram(const StringPiece &ngram, const StringPiece &line) { 27 | AddNGram(line); 28 | } 29 | 30 | private: 31 | util::FileStream file_; 32 | }; 33 | 34 | class CountBatch { 35 | public: 36 | explicit CountBatch(std::streamsize initial_read) 37 | : initial_read_(initial_read) { 38 | buffer_.reserve(initial_read); 39 | } 40 | 41 | void Read(std::istream &in) { 42 | buffer_.resize(initial_read_); 43 | in.read(&*buffer_.begin(), initial_read_); 44 | buffer_.resize(in.gcount()); 45 | char got; 46 | while (in.get(got) && got != '\n') 47 | buffer_.push_back(got); 48 | } 49 | 50 | template void Send(Output &out) { 51 | for (util::TokenIter line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) { 52 | util::TokenIter tabber(*line, '\t'); 53 | if (!tabber) { 54 | std::cerr << "Warning: empty n-gram count line being removed\n"; 55 | continue; 56 | } 57 | util::TokenIter words(*tabber, ' '); 58 | if (!words) { 59 | std::cerr << "Line has a tab but no words.\n"; 60 | continue; 61 | } 62 | out.AddNGram(words, util::TokenIter::end(), *line); 63 | } 64 | } 65 | 66 | private: 67 | std::streamsize initial_read_; 68 | 69 | // This could have been a std::string but that's less happy with raw writes. 
70 | std::vector buffer_; 71 | }; 72 | 73 | template void ReadCount(util::FilePiece &in_file, Output &out) { 74 | try { 75 | while (true) { 76 | StringPiece line = in_file.ReadLine(); 77 | util::TokenIter tabber(line, '\t'); 78 | if (!tabber) { 79 | std::cerr << "Warning: empty n-gram count line being removed\n"; 80 | continue; 81 | } 82 | out.AddNGram(*tabber, line); 83 | } 84 | } catch (const util::EndOfFileException &) {} 85 | } 86 | 87 | } // namespace lm 88 | 89 | #endif // LM_FILTER_COUNT_IO_H 90 | -------------------------------------------------------------------------------- /lm/filter/vocab.cc: -------------------------------------------------------------------------------- 1 | #include "vocab.hh" 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace lm { 9 | namespace vocab { 10 | 11 | void ReadSingle(std::istream &in, boost::unordered_set &out) { 12 | in.exceptions(std::istream::badbit); 13 | std::string word; 14 | while (in >> word) { 15 | out.insert(word); 16 | } 17 | } 18 | 19 | namespace { 20 | bool IsLineEnd(std::istream &in) { 21 | int got; 22 | do { 23 | got = in.get(); 24 | if (!in) return true; 25 | if (got == '\n') return true; 26 | } while (isspace(got)); 27 | in.unget(); 28 | return false; 29 | } 30 | }// namespace 31 | 32 | // Read space separated words in enter separated lines. These lines can be 33 | // very long, so don't read an entire line at a time. 
34 | unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out) { 35 | in.exceptions(std::istream::badbit); 36 | unsigned int sentence = 0; 37 | bool used_id = false; 38 | std::string word; 39 | while (in >> word) { 40 | used_id = true; 41 | std::vector &posting = out[word]; 42 | if (posting.empty() || (posting.back() != sentence)) 43 | posting.push_back(sentence); 44 | if (IsLineEnd(in)) { 45 | ++sentence; 46 | used_id = false; 47 | } 48 | } 49 | return sentence + used_id; 50 | } 51 | 52 | } // namespace vocab 53 | } // namespace lm 54 | -------------------------------------------------------------------------------- /lm/filter/wrapper.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_FILTER_WRAPPER_H 2 | #define LM_FILTER_WRAPPER_H 3 | 4 | #include "../../util/string_piece.hh" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace lm { 11 | 12 | // Provide a single-output filter with the same interface as a 13 | // multiple-output filter so clients code against one interface. 14 | template class BinaryFilter { 15 | public: 16 | // Binary modes are just references (and a set) and it makes the API cleaner to copy them. 
17 | explicit BinaryFilter(Binary binary) : binary_(binary) {} 18 | 19 | template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { 20 | if (binary_.PassNGram(begin, end)) 21 | output.AddNGram(line); 22 | } 23 | 24 | template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { 25 | AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); 26 | } 27 | 28 | void Flush() const {} 29 | 30 | private: 31 | Binary binary_; 32 | }; 33 | 34 | // Wrap another filter to pay attention only to context words 35 | template class ContextFilter { 36 | public: 37 | typedef FilterT Filter; 38 | 39 | explicit ContextFilter(Filter &backend) : backend_(backend) {} 40 | 41 | template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { 42 | // Find beginning of string or last space. 43 | const char *last_space; 44 | for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {} 45 | backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output); 46 | } 47 | 48 | void Flush() const {} 49 | 50 | private: 51 | Filter backend_; 52 | }; 53 | 54 | } // namespace lm 55 | 56 | #endif // LM_FILTER_WRAPPER_H 57 | -------------------------------------------------------------------------------- /lm/fragment_main.cc: -------------------------------------------------------------------------------- 1 | #include "binary_format.hh" 2 | #include "model.hh" 3 | #include "left.hh" 4 | #include "../util/tokenize_piece.hh" 5 | 6 | template void Query(const char *name) { 7 | Model model(name); 8 | std::string line; 9 | lm::ngram::ChartState ignored; 10 | while (getline(std::cin, line)) { 11 | lm::ngram::RuleScore scorer(model, ignored); 12 | for (util::TokenIter i(line, ' '); i; ++i) { 13 | scorer.Terminal(model.GetVocabulary().Index(*i)); 14 | } 15 | std::cout << scorer.Finish() << '\n'; 16 | } 17 | } 18 
| 19 | int main(int argc, char *argv[]) { 20 | if (argc != 2) { 21 | std::cerr << "Expected model file name." << std::endl; 22 | return 1; 23 | } 24 | const char *name = argv[1]; 25 | lm::ngram::ModelType model_type = lm::ngram::PROBING; 26 | lm::ngram::RecognizeBinary(name, model_type); 27 | switch (model_type) { 28 | case lm::ngram::PROBING: 29 | Query(name); 30 | break; 31 | case lm::ngram::REST_PROBING: 32 | Query(name); 33 | break; 34 | default: 35 | std::cerr << "Model type not supported yet." << std::endl; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /lm/interpolate/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Eigen3 less than 3.1.0 has a race condition: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=466 2 | 3 | if(ENABLE_INTERPOLATE) 4 | find_package(Eigen3 3.1.0 CONFIG REQUIRED) 5 | include_directories(${EIGEN3_INCLUDE_DIR}) 6 | 7 | set(KENLM_INTERPOLATE_SOURCE 8 | backoff_reunification.cc 9 | bounded_sequence_encoding.cc 10 | merge_probabilities.cc 11 | merge_vocab.cc 12 | normalize.cc 13 | pipeline.cc 14 | split_worker.cc 15 | tune_derivatives.cc 16 | tune_instances.cc 17 | tune_weights.cc 18 | universal_vocab.cc) 19 | 20 | add_library(kenlm_interpolate ${KENLM_INTERPOLATE_SOURCE}) 21 | target_link_libraries(kenlm_interpolate PUBLIC kenlm Eigen3::Eigen) 22 | # Since headers are relative to `include/kenlm` at install time, not just `include` 23 | target_include_directories(kenlm_interpolate PUBLIC $) 24 | 25 | 26 | find_package(OpenMP) 27 | if (OPENMP_CXX_FOUND) 28 | target_link_libraries(kenlm_interpolate PUBLIC OpenMP::OpenMP_CXX) 29 | endif() 30 | 31 | 32 | set(KENLM_INTERPOLATE_EXES 33 | interpolate 34 | streaming_example) 35 | 36 | set(KENLM_INTERPOLATE_LIBS 37 | kenlm_interpolate) 38 | 39 | AddExes(EXES ${KENLM_INTERPOLATE_EXES} 40 | LIBRARIES ${KENLM_INTERPOLATE_LIBS}) 41 | 42 | install( 43 | TARGETS kenlm_interpolate 44 | EXPORT 
kenlmTargets 45 | RUNTIME DESTINATION bin 46 | LIBRARY DESTINATION lib 47 | ARCHIVE DESTINATION lib 48 | INCLUDES DESTINATION include 49 | ) 50 | 51 | if(BUILD_TESTING) 52 | AddTests(TESTS backoff_reunification_test bounded_sequence_encoding_test merge_vocab_test normalize_test tune_derivatives_test 53 | LIBRARIES ${KENLM_INTERPOLATE_LIBS} Threads::Threads) 54 | 55 | # tune_instances_test needs an extra command line parameter 56 | KenLMAddTest(TEST tune_instances_test 57 | LIBRARIES ${KENLM_INTERPOLATE_LIBS} 58 | TEST_ARGS -- ${CMAKE_CURRENT_SOURCE_DIR}/../common/test_data) 59 | endif() 60 | endif() 61 | -------------------------------------------------------------------------------- /lm/interpolate/backoff_matrix.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_BACKOFF_MATRIX_H 2 | #define LM_INTERPOLATE_BACKOFF_MATRIX_H 3 | 4 | #include 5 | #include 6 | 7 | namespace lm { namespace interpolate { 8 | 9 | class BackoffMatrix { 10 | public: 11 | BackoffMatrix(std::size_t num_models, std::size_t max_order) 12 | : max_order_(max_order), backing_(num_models * max_order) {} 13 | 14 | float &Backoff(std::size_t model, std::size_t order_minus_1) { 15 | return backing_[model * max_order_ + order_minus_1]; 16 | } 17 | 18 | float Backoff(std::size_t model, std::size_t order_minus_1) const { 19 | return backing_[model * max_order_ + order_minus_1]; 20 | } 21 | 22 | private: 23 | const std::size_t max_order_; 24 | std::vector backing_; 25 | }; 26 | 27 | }} // namespaces 28 | 29 | #endif // LM_INTERPOLATE_BACKOFF_MATRIX_H 30 | -------------------------------------------------------------------------------- /lm/interpolate/backoff_reunification.cc: -------------------------------------------------------------------------------- 1 | #include "backoff_reunification.hh" 2 | #include "../common/model_buffer.hh" 3 | #include "../common/ngram_stream.hh" 4 | #include "../common/ngram.hh" 5 | #include "../common/compare.hh" 
6 | 7 | #include 8 | #include 9 | 10 | namespace lm { 11 | namespace interpolate { 12 | 13 | namespace { 14 | class MergeWorker { 15 | public: 16 | MergeWorker(std::size_t order, const util::stream::ChainPosition &prob_pos, 17 | const util::stream::ChainPosition &boff_pos) 18 | : order_(order), prob_pos_(prob_pos), boff_pos_(boff_pos) { 19 | // nothing 20 | } 21 | 22 | void Run(const util::stream::ChainPosition &position) { 23 | lm::NGramStream stream(position); 24 | 25 | lm::NGramStream prob_input(prob_pos_); 26 | util::stream::Stream boff_input(boff_pos_); 27 | for (; prob_input && boff_input; ++prob_input, ++boff_input, ++stream) { 28 | std::copy(prob_input->begin(), prob_input->end(), stream->begin()); 29 | stream->Value().prob = std::min(0.0f, prob_input->Value()); 30 | stream->Value().backoff = *reinterpret_cast(boff_input.Get()); 31 | } 32 | UTIL_THROW_IF2(prob_input || boff_input, 33 | "Streams were not the same size during merging"); 34 | stream.Poison(); 35 | } 36 | 37 | private: 38 | std::size_t order_; 39 | util::stream::ChainPosition prob_pos_; 40 | util::stream::ChainPosition boff_pos_; 41 | }; 42 | } 43 | 44 | // Since we are *adding* something to the output chain here, we pass in the 45 | // chain itself so that we can safely add a new step to the chain without 46 | // creating a deadlock situation (since creating a new ChainPosition will 47 | // make a new input/output pair---we want that position to be created 48 | // *here*, not before). 
49 | void ReunifyBackoff(util::stream::ChainPositions &prob_pos, 50 | util::stream::ChainPositions &boff_pos, 51 | util::stream::Chains &output_chains) { 52 | assert(prob_pos.size() == boff_pos.size()); 53 | 54 | for (size_t i = 0; i < prob_pos.size(); ++i) 55 | output_chains[i] >> MergeWorker(i + 1, prob_pos[i], boff_pos[i]); 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /lm/interpolate/backoff_reunification.hh: -------------------------------------------------------------------------------- 1 | #ifndef KENLM_INTERPOLATE_BACKOFF_REUNIFICATION_ 2 | #define KENLM_INTERPOLATE_BACKOFF_REUNIFICATION_ 3 | 4 | #include "../../util/stream/stream.hh" 5 | #include "../../util/stream/multi_stream.hh" 6 | 7 | namespace lm { 8 | namespace interpolate { 9 | 10 | /** 11 | * The third pass for the offline log-linear interpolation algorithm. This 12 | * reads **suffix-ordered** probability values (ngram-id, float) and 13 | * **suffix-ordered** backoff values (float) and writes the merged contents 14 | * to the output. 
15 | * 16 | * @param prob_pos The chain position for each order from which to read 17 | * the probability values 18 | * @param boff_pos The chain position for each order from which to read 19 | * the backoff values 20 | * @param output_chains The output chains for each order 21 | */ 22 | void ReunifyBackoff(util::stream::ChainPositions &prob_pos, 23 | util::stream::ChainPositions &boff_pos, 24 | util::stream::Chains &output_chains); 25 | } 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /lm/interpolate/bounded_sequence_encoding.cc: -------------------------------------------------------------------------------- 1 | #include "bounded_sequence_encoding.hh" 2 | 3 | #include 4 | 5 | namespace lm { namespace interpolate { 6 | 7 | BoundedSequenceEncoding::BoundedSequenceEncoding(const unsigned char *bound_begin, const unsigned char *bound_end) 8 | : entries_(bound_end - bound_begin) { 9 | std::size_t full = 0; 10 | Entry entry; 11 | entry.shift = 0; 12 | for (const unsigned char *i = bound_begin; i != bound_end; ++i) { 13 | uint8_t length; 14 | if (*i <= 1) { 15 | length = 0; 16 | } else { 17 | length = sizeof(unsigned int) * 8 - __builtin_clz((unsigned int)*i); 18 | } 19 | entry.mask = (1ULL << length) - 1ULL; 20 | if (entry.shift + length > 64) { 21 | entry.shift = 0; 22 | entry.next = true; 23 | ++full; 24 | } else { 25 | entry.next = false; 26 | } 27 | entries_.push_back(entry); 28 | entry.shift += length; 29 | } 30 | byte_length_ = full * sizeof(uint64_t) + (entry.shift + 7) / 8; 31 | first_copy_ = std::min(byte_length_, sizeof(uint64_t)); 32 | // Size of last uint64_t. Zero if empty, otherwise [1,8] depending on mod. 33 | overhang_ = byte_length_ == 0 ? 
0 : ((byte_length_ - 1) % 8 + 1); 34 | } 35 | 36 | }} // namespaces 37 | -------------------------------------------------------------------------------- /lm/interpolate/bounded_sequence_encoding.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H 2 | #define LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H 3 | 4 | /* Encodes fixed-length sequences of integers with known bounds on each entry. 5 | * This is used to encode how far each model has backed off. 6 | * TODO: make this class efficient. Bit-level packing or multiply by bound and 7 | * add. 8 | */ 9 | 10 | #include "../../util/exception.hh" 11 | #include "../../util/fixed_array.hh" 12 | 13 | #include 14 | #include 15 | 16 | namespace lm { 17 | namespace interpolate { 18 | 19 | class BoundedSequenceEncoding { 20 | public: 21 | // Encode [0, bound_begin[0]) x [0, bound_begin[1]) x [0, bound_begin[2]) x ... x [0, *(bound_end - 1)) for entries in the sequence 22 | BoundedSequenceEncoding(const unsigned char *bound_begin, const unsigned char *bound_end); 23 | 24 | std::size_t Entries() const { return entries_.size(); } 25 | 26 | std::size_t EncodedLength() const { return byte_length_; } 27 | 28 | void Encode(const unsigned char *from, void *to_void) const { 29 | uint8_t *to = static_cast(to_void); 30 | uint64_t cur = 0; 31 | for (const Entry *i = entries_.begin(); i != entries_.end(); ++i, ++from) { 32 | if (UTIL_UNLIKELY(i->next)) { 33 | std::memcpy(to, &cur, sizeof(uint64_t)); 34 | to += sizeof(uint64_t); 35 | cur = 0; 36 | } 37 | cur |= static_cast(*from) << i->shift; 38 | } 39 | #if BYTE_ORDER == BIG_ENDIAN 40 | cur <<= (8 - overhang_) * 8; 41 | #endif 42 | memcpy(to, &cur, overhang_); 43 | } 44 | 45 | void Decode(const void *from_void, unsigned char *to) const { 46 | const uint8_t *from = static_cast(from_void); 47 | uint64_t cur = 0; 48 | memcpy(&cur, from, first_copy_); 49 | #if BYTE_ORDER == BIG_ENDIAN 50 | cur >>= (8 - 
first_copy_) * 8; 51 | #endif 52 | for (const Entry *i = entries_.begin(); i != entries_.end(); ++i, ++to) { 53 | if (UTIL_UNLIKELY(i->next)) { 54 | from += sizeof(uint64_t); 55 | cur = 0; 56 | std::memcpy(&cur, from, 57 | std::min(sizeof(uint64_t), static_cast(from_void) + byte_length_ - from)); 58 | #if BYTE_ORDER == BIG_ENDIAN 59 | cur >>= (8 - (static_cast(from_void) + byte_length_ - from)) * 8; 60 | #endif 61 | } 62 | *to = (cur >> i->shift) & i->mask; 63 | } 64 | } 65 | 66 | private: 67 | struct Entry { 68 | bool next; 69 | uint8_t shift; 70 | uint64_t mask; 71 | }; 72 | util::FixedArray entries_; 73 | std::size_t byte_length_; 74 | std::size_t first_copy_; 75 | std::size_t overhang_; 76 | }; 77 | 78 | 79 | }} // namespaces 80 | 81 | #endif // LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H 82 | -------------------------------------------------------------------------------- /lm/interpolate/interpolate_info.hh: -------------------------------------------------------------------------------- 1 | #ifndef KENLM_INTERPOLATE_INTERPOLATE_INFO_H 2 | #define KENLM_INTERPOLATE_INTERPOLATE_INFO_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace lm { 9 | namespace interpolate { 10 | 11 | /** 12 | * Stores relevant info for interpolating several language models, for use 13 | * during the three-pass offline log-linear interpolation algorithm. 14 | */ 15 | struct InterpolateInfo { 16 | /** 17 | * @return the number of models being interpolated 18 | */ 19 | std::size_t Models() const { 20 | return orders.size(); 21 | } 22 | 23 | /** 24 | * The lambda (interpolation weight) for each model. 25 | */ 26 | std::vector lambdas; 27 | 28 | /** 29 | * The maximum ngram order for each model. 
30 | */ 31 | std::vector orders; 32 | }; 33 | } 34 | } 35 | #endif 36 | -------------------------------------------------------------------------------- /lm/interpolate/merge_vocab.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_MERGE_VOCAB_H 2 | #define LM_INTERPOLATE_MERGE_VOCAB_H 3 | 4 | #include "../word_index.hh" 5 | #include "../../util/file.hh" 6 | #include "../../util/fixed_array.hh" 7 | 8 | namespace lm { 9 | 10 | class EnumerateVocab; 11 | 12 | namespace interpolate { 13 | 14 | class UniversalVocab; 15 | 16 | // The combined vocabulary is enumerated with enumerate. 17 | // Returns the size of the combined vocabulary. 18 | // Does not take ownership of vocab_files. 19 | WordIndex MergeVocab(util::FixedArray &vocab_files, UniversalVocab &vocab, EnumerateVocab &enumerate); 20 | 21 | }} // namespaces 22 | 23 | #endif // LM_INTERPOLATE_MERGE_VOCAB_H 24 | -------------------------------------------------------------------------------- /lm/interpolate/normalize.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_NORMALIZE_H 2 | #define LM_INTERPOLATE_NORMALIZE_H 3 | 4 | #include "../../util/fixed_array.hh" 5 | 6 | /* Pass 2: 7 | * - Multiply backoff weights by the backed off probabilities from pass 1. 8 | * - Compute the normalization factor Z. 9 | * - Send Z to the next highest order. 10 | * - Rewind and divide by Z. 11 | */ 12 | 13 | namespace util { namespace stream { 14 | class ChainPositions; 15 | class Chains; 16 | }} // namespaces 17 | 18 | namespace lm { namespace interpolate { 19 | 20 | struct InterpolateInfo; 21 | 22 | void Normalize( 23 | const InterpolateInfo &info, 24 | // Input full models for backoffs. Assumes that renumbering has been done. Suffix order. 25 | util::FixedArray &models_by_order, 26 | // Input PartialProbGamma from MergeProbabilities. Context order. 
27 | util::stream::Chains &merged_probabilities, 28 | // Output NGram with normalized probabilities. Context order. 29 | util::stream::Chains &probabilities_out, 30 | // Output bare floats with backoffs. Note backoffs.size() == order - 1. Suffix order. 31 | util::stream::Chains &backoffs_out); 32 | 33 | }} // namespaces 34 | 35 | #endif // LM_INTERPOLATE_NORMALIZE_H 36 | -------------------------------------------------------------------------------- /lm/interpolate/normalize_test.cc: -------------------------------------------------------------------------------- 1 | #include "normalize.hh" 2 | 3 | #include "interpolate_info.hh" 4 | #include "merge_probabilities.hh" 5 | #include "../common/ngram_stream.hh" 6 | #include "../../util/stream/chain.hh" 7 | #include "../../util/stream/multi_stream.hh" 8 | 9 | #define BOOST_TEST_MODULE NormalizeTest 10 | #include 11 | 12 | namespace lm { namespace interpolate { namespace { 13 | 14 | // log without backoff 15 | const float kInputs[] = {-0.3, 1.2, -9.8, 4.0, -7.0, 0.0}; 16 | 17 | class WriteInput { 18 | public: 19 | WriteInput() {} 20 | void Run(const util::stream::ChainPosition &to) { 21 | util::stream::Stream out(to); 22 | for (WordIndex i = 0; i < sizeof(kInputs) / sizeof(float); ++i, ++out) { 23 | memcpy(out.Get(), &i, sizeof(WordIndex)); 24 | memcpy((uint8_t*)out.Get() + sizeof(WordIndex), &kInputs[i], sizeof(float)); 25 | } 26 | out.Poison(); 27 | } 28 | }; 29 | 30 | void CheckOutput(const util::stream::ChainPosition &from) { 31 | NGramStream in(from); 32 | float sum = 0.0; 33 | for (WordIndex i = 0; i < sizeof(kInputs) / sizeof(float) - 1 /* at the end */; ++i) { 34 | sum += pow(10.0, kInputs[i]); 35 | } 36 | sum = log10(sum); 37 | BOOST_REQUIRE(in); 38 | BOOST_CHECK_CLOSE(kInputs[0] - sum, in->Value(), 0.0001); 39 | BOOST_REQUIRE(++in); 40 | BOOST_CHECK_CLOSE(kInputs[1] - sum, in->Value(), 0.0001); 41 | BOOST_REQUIRE(++in); 42 | BOOST_CHECK_CLOSE(kInputs[2] - sum, in->Value(), 0.0001); 43 | BOOST_REQUIRE(++in); 44 
| BOOST_CHECK_CLOSE(kInputs[3] - sum, in->Value(), 0.0001); 45 | BOOST_REQUIRE(++in); 46 | BOOST_CHECK_CLOSE(kInputs[4] - sum, in->Value(), 0.0001); 47 | BOOST_REQUIRE(++in); 48 | BOOST_CHECK_CLOSE(kInputs[5] - sum, in->Value(), 0.0001); 49 | BOOST_CHECK(!++in); 50 | } 51 | 52 | BOOST_AUTO_TEST_CASE(Unigrams) { 53 | InterpolateInfo info; 54 | info.lambdas.push_back(2.0); 55 | info.lambdas.push_back(-0.1); 56 | info.orders.push_back(1); 57 | info.orders.push_back(1); 58 | 59 | BOOST_CHECK_EQUAL(0, MakeEncoder(info, 1).EncodedLength()); 60 | 61 | // No backoffs. 62 | util::stream::Chains blank(0); 63 | util::FixedArray models_by_order(2); 64 | models_by_order.push_back(blank); 65 | models_by_order.push_back(blank); 66 | 67 | util::stream::Chains merged_probabilities(1); 68 | util::stream::Chains probabilities_out(1); 69 | util::stream::Chains backoffs_out(0); 70 | 71 | merged_probabilities.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float) + sizeof(float), 2, 24)); 72 | probabilities_out.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float), 2, 100)); 73 | 74 | merged_probabilities[0] >> WriteInput(); 75 | Normalize(info, models_by_order, merged_probabilities, probabilities_out, backoffs_out); 76 | 77 | util::stream::ChainPosition checker(probabilities_out[0].Add()); 78 | 79 | merged_probabilities >> util::stream::kRecycle; 80 | probabilities_out >> util::stream::kRecycle; 81 | 82 | CheckOutput(checker); 83 | probabilities_out.Wait(); 84 | } 85 | 86 | }}} // namespaces 87 | -------------------------------------------------------------------------------- /lm/interpolate/pipeline.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_PIPELINE_H 2 | #define LM_INTERPOLATE_PIPELINE_H 3 | 4 | #include "../common/model_buffer.hh" 5 | #include "../../util/fixed_array.hh" 6 | #include "../../util/stream/config.hh" 7 | 8 | #include 9 | #include 10 | 11 | namespace lm { namespace 
interpolate { 12 | 13 | struct Config { 14 | std::vector lambdas; 15 | util::stream::SortConfig sort; 16 | std::size_t BufferSize() const { return sort.buffer_size; } 17 | }; 18 | 19 | void Pipeline(util::FixedArray &models, const Config &config, int write_file); 20 | 21 | }} // namespaces 22 | #endif // LM_INTERPOLATE_PIPELINE_H 23 | -------------------------------------------------------------------------------- /lm/interpolate/split_worker.cc: -------------------------------------------------------------------------------- 1 | #include "split_worker.hh" 2 | #include "../common/ngram.hh" 3 | 4 | namespace lm { 5 | namespace interpolate { 6 | 7 | SplitWorker::SplitWorker(std::size_t order, util::stream::Chain &backoff_chain, 8 | util::stream::Chain &sort_chain) 9 | : order_(order) { 10 | backoff_chain >> backoff_input_; 11 | sort_chain >> sort_input_; 12 | } 13 | 14 | void SplitWorker::Run(const util::stream::ChainPosition &position) { 15 | // input: ngram record (id, prob, and backoff) 16 | // output: a float to the backoff_input stream 17 | // an ngram id and a float to the sort_input stream 18 | for (util::stream::Stream stream(position); stream; ++stream) { 19 | NGram ngram(stream.Get(), order_); 20 | 21 | // write id and prob to the sort stream 22 | float prob = ngram.Value().prob; 23 | lm::WordIndex *out = reinterpret_cast(sort_input_.Get()); 24 | for (const lm::WordIndex *it = ngram.begin(); it != ngram.end(); ++it) { 25 | *out++ = *it; 26 | } 27 | *reinterpret_cast(out) = prob; 28 | ++sort_input_; 29 | 30 | // write backoff to the backoff output stream 31 | float boff = ngram.Value().backoff; 32 | *reinterpret_cast(backoff_input_.Get()) = boff; 33 | ++backoff_input_; 34 | } 35 | sort_input_.Poison(); 36 | backoff_input_.Poison(); 37 | } 38 | 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /lm/interpolate/split_worker.hh: -------------------------------------------------------------------------------- 1 | 
#ifndef KENLM_INTERPOLATE_SPLIT_WORKER_H_ 2 | #define KENLM_INTERPOLATE_SPLIT_WORKER_H_ 3 | 4 | #include "../../util/stream/chain.hh" 5 | #include "../../util/stream/stream.hh" 6 | 7 | namespace lm { 8 | namespace interpolate { 9 | 10 | class SplitWorker { 11 | public: 12 | /** 13 | * Constructs a split worker for a particular order. It writes the 14 | * split-off backoff values to the backoff chain and the ngram id and 15 | * probability to the sort chain for each ngram in the input. 16 | */ 17 | SplitWorker(std::size_t order, util::stream::Chain &backoff_chain, 18 | util::stream::Chain &sort_chain); 19 | 20 | /** 21 | * The callback invoked to handle the input from the ngram intermediate 22 | * files. 23 | */ 24 | void Run(const util::stream::ChainPosition& position); 25 | 26 | private: 27 | /** 28 | * The ngram order we are reading/writing for. 29 | */ 30 | std::size_t order_; 31 | 32 | /** 33 | * The stream to write to for the backoff values. 34 | */ 35 | util::stream::Stream backoff_input_; 36 | 37 | /** 38 | * The stream to write to for the ngram id + probability values. 39 | */ 40 | util::stream::Stream sort_input_; 41 | }; 42 | } 43 | } 44 | #endif 45 | -------------------------------------------------------------------------------- /lm/interpolate/tune_derivatives.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_TUNE_DERIVATIVES_H 2 | #define LM_INTERPOLATE_TUNE_DERIVATIVES_H 3 | 4 | #include "tune_matrix.hh" 5 | 6 | #include 7 | #include 8 | 9 | namespace lm { namespace interpolate { 10 | 11 | class Instances; 12 | 13 | // Given tuning instances and model weights, computes the objective function (log probability), gradient, and Hessian. 14 | // Returns log probability / number of instances. 
15 | Accum Derivatives(Instances &instances /* Doesn't modify but ReadExtensions is lazy */, const Vector &weights, Vector &gradient, Matrix &hessian); 16 | 17 | }} // namespaces 18 | 19 | #endif // LM_INTERPOLATE_TUNE_DERIVATIVES_H 20 | 21 | -------------------------------------------------------------------------------- /lm/interpolate/tune_matrix.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_TUNE_MATRIX_H 2 | #define LM_INTERPOLATE_TUNE_MATRIX_H 3 | 4 | #pragma GCC diagnostic push 5 | #pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains. 6 | #pragma GCC diagnostic ignored "-Wunused-local-typedefs" 7 | #include 8 | #pragma GCC diagnostic pop 9 | 10 | namespace lm { namespace interpolate { 11 | 12 | typedef Eigen::MatrixXf Matrix; 13 | typedef Eigen::VectorXf Vector; 14 | 15 | typedef Matrix::Scalar Accum; 16 | 17 | }} // namespaces 18 | #endif // LM_INTERPOLATE_TUNE_MATRIX_H 19 | -------------------------------------------------------------------------------- /lm/interpolate/tune_weights.cc: -------------------------------------------------------------------------------- 1 | #include "tune_weights.hh" 2 | 3 | #include "tune_derivatives.hh" 4 | #include "tune_instances.hh" 5 | 6 | #pragma GCC diagnostic push 7 | #pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains. 
8 | #pragma GCC diagnostic ignored "-Wunused-local-typedefs" 9 | #include 10 | #pragma GCC diagnostic pop 11 | #include 12 | 13 | #include 14 | 15 | namespace lm { namespace interpolate { 16 | void TuneWeights(int tune_file, const std::vector &model_names, const InstancesConfig &config, std::vector &weights_out) { 17 | Instances instances(tune_file, model_names, config); 18 | Vector weights = Vector::Constant(model_names.size(), 1.0 / model_names.size()); 19 | Vector gradient; 20 | Matrix hessian; 21 | for (std::size_t iteration = 0; iteration < 10 /*TODO fancy stopping criteria */; ++iteration) { 22 | std::cerr << "Iteration " << iteration << ": weights ="; 23 | for (Vector::Index i = 0; i < weights.rows(); ++i) { 24 | std::cerr << ' ' << weights(i); 25 | } 26 | std::cerr << std::endl; 27 | std::cerr << "Perplexity = " << Derivatives(instances, weights, gradient, hessian) << std::endl; 28 | // TODO: 1.0 step size was too big and it kept getting unstable. More math. 29 | weights -= 0.7 * hessian.inverse() * gradient; 30 | } 31 | weights_out.assign(weights.data(), weights.data() + weights.size()); 32 | } 33 | }} // namespaces 34 | -------------------------------------------------------------------------------- /lm/interpolate/tune_weights.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_TUNE_WEIGHTS_H 2 | #define LM_INTERPOLATE_TUNE_WEIGHTS_H 3 | 4 | #include "../../util/string_piece.hh" 5 | 6 | #include 7 | 8 | namespace lm { namespace interpolate { 9 | struct InstancesConfig; 10 | 11 | // Run a tuning loop, producing weights as output. 
12 | void TuneWeights(int tune_file, const std::vector &model_names, const InstancesConfig &config, std::vector &weights); 13 | 14 | }} // namespaces 15 | #endif // LM_INTERPOLATE_TUNE_WEIGHTS_H 16 | -------------------------------------------------------------------------------- /lm/interpolate/universal_vocab.cc: -------------------------------------------------------------------------------- 1 | #include "universal_vocab.hh" 2 | 3 | namespace lm { 4 | namespace interpolate { 5 | 6 | UniversalVocab::UniversalVocab(const std::vector& model_vocab_sizes) { 7 | model_index_map_.resize(model_vocab_sizes.size()); 8 | for (size_t i = 0; i < model_vocab_sizes.size(); ++i) { 9 | model_index_map_[i].resize(model_vocab_sizes[i]); 10 | } 11 | } 12 | 13 | }} // namespaces 14 | -------------------------------------------------------------------------------- /lm/interpolate/universal_vocab.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_INTERPOLATE_UNIVERSAL_VOCAB_H 2 | #define LM_INTERPOLATE_UNIVERSAL_VOCAB_H 3 | 4 | #include "../word_index.hh" 5 | 6 | #include 7 | #include 8 | 9 | namespace lm { 10 | namespace interpolate { 11 | 12 | class UniversalVocab { 13 | public: 14 | explicit UniversalVocab(const std::vector& model_vocab_sizes); 15 | 16 | // GetUniversalIndex takes the model number and index for the specific 17 | // model and returns the universal model number 18 | WordIndex GetUniversalIdx(std::size_t model_num, WordIndex model_word_index) const { 19 | return model_index_map_[model_num][model_word_index]; 20 | } 21 | 22 | const WordIndex *Mapping(std::size_t model) const { 23 | return &*model_index_map_[model].begin(); 24 | } 25 | 26 | WordIndex SlowConvertToModel(std::size_t model, WordIndex index) const { 27 | std::vector::const_iterator i = lower_bound(model_index_map_[model].begin(), model_index_map_[model].end(), index); 28 | if (i == model_index_map_[model].end() || *i != index) return 0; 29 | return i - 
model_index_map_[model].begin(); 30 | } 31 | 32 | void InsertUniversalIdx(std::size_t model_num, WordIndex word_index, 33 | WordIndex universal_word_index) { 34 | model_index_map_[model_num][word_index] = universal_word_index; 35 | } 36 | 37 | private: 38 | std::vector > model_index_map_; 39 | }; 40 | 41 | } // namespace interpolate 42 | } // namespace lm 43 | 44 | #endif // LM_INTERPOLATE_UNIVERSAL_VOCAB_H 45 | -------------------------------------------------------------------------------- /lm/lm_exception.cc: -------------------------------------------------------------------------------- 1 | #include "lm_exception.hh" 2 | 3 | #include 4 | #include 5 | 6 | namespace lm { 7 | 8 | ConfigException::ConfigException() throw() {} 9 | ConfigException::~ConfigException() throw() {} 10 | 11 | LoadException::LoadException() throw() {} 12 | LoadException::~LoadException() throw() {} 13 | 14 | FormatLoadException::FormatLoadException() throw() {} 15 | FormatLoadException::~FormatLoadException() throw() {} 16 | 17 | VocabLoadException::VocabLoadException() throw() {} 18 | VocabLoadException::~VocabLoadException() throw() {} 19 | 20 | SpecialWordMissingException::SpecialWordMissingException() throw() {} 21 | SpecialWordMissingException::~SpecialWordMissingException() throw() {} 22 | 23 | } // namespace lm 24 | -------------------------------------------------------------------------------- /lm/lm_exception.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_LM_EXCEPTION_H 2 | #define LM_LM_EXCEPTION_H 3 | 4 | // Named to avoid conflict with util/exception.hh. 
5 | 6 | #include "../util/exception.hh" 7 | #include "../util/string_piece.hh" 8 | 9 | #include 10 | #include 11 | 12 | namespace lm { 13 | 14 | typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; 15 | 16 | class ConfigException : public util::Exception { 17 | public: 18 | ConfigException() throw(); 19 | ~ConfigException() throw(); 20 | }; 21 | 22 | class LoadException : public util::Exception { 23 | public: 24 | virtual ~LoadException() throw(); 25 | 26 | protected: 27 | LoadException() throw(); 28 | }; 29 | 30 | class FormatLoadException : public LoadException { 31 | public: 32 | FormatLoadException() throw(); 33 | ~FormatLoadException() throw(); 34 | }; 35 | 36 | class VocabLoadException : public LoadException { 37 | public: 38 | virtual ~VocabLoadException() throw(); 39 | VocabLoadException() throw(); 40 | }; 41 | 42 | class SpecialWordMissingException : public VocabLoadException { 43 | public: 44 | explicit SpecialWordMissingException() throw(); 45 | ~SpecialWordMissingException() throw(); 46 | }; 47 | 48 | } // namespace lm 49 | 50 | #endif // LM_LM_EXCEPTION 51 | -------------------------------------------------------------------------------- /lm/max_order.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_MAX_ORDER_H 2 | #define LM_MAX_ORDER_H 3 | /* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. 4 | * If not, this is the default maximum order. 5 | * Having this limit means that State can be 6 | * (kMaxOrder - 1) * sizeof(float) bytes instead of 7 | * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead 8 | */ 9 | #ifndef KENLM_ORDER_MESSAGE 10 | #define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. With cmake:\n cmake -DKENLM_MAX_ORDER=10 ..\nWith Moses:\n bjam --max-kenlm-order=10 -a\nOtherwise, edit lm/max_order.hh." 
11 | #endif 12 | 13 | #endif // LM_MAX_ORDER_H 14 | -------------------------------------------------------------------------------- /lm/model_type.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_MODEL_TYPE_H 2 | #define LM_MODEL_TYPE_H 3 | 4 | namespace lm { 5 | namespace ngram { 6 | 7 | /* Not the best numbering system, but it grew this way for historical reasons 8 | * and I want to preserve existing binary files. */ 9 | typedef enum {PROBING=0, REST_PROBING=1, TRIE=2, QUANT_TRIE=3, ARRAY_TRIE=4, QUANT_ARRAY_TRIE=5} ModelType; 10 | 11 | // Historical names. 12 | const ModelType HASH_PROBING = PROBING; 13 | const ModelType TRIE_SORTED = TRIE; 14 | const ModelType QUANT_TRIE_SORTED = QUANT_TRIE; 15 | const ModelType ARRAY_TRIE_SORTED = ARRAY_TRIE; 16 | const ModelType QUANT_ARRAY_TRIE_SORTED = QUANT_ARRAY_TRIE; 17 | 18 | const static ModelType kQuantAdd = static_cast(QUANT_TRIE - TRIE); 19 | const static ModelType kArrayAdd = static_cast(ARRAY_TRIE - TRIE); 20 | 21 | } // namespace ngram 22 | } // namespace lm 23 | #endif // LM_MODEL_TYPE_H 24 | -------------------------------------------------------------------------------- /lm/return.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_RETURN_H 2 | #define LM_RETURN_H 3 | 4 | #include 5 | 6 | namespace lm { 7 | /* Structure returned by scoring routines. */ 8 | struct FullScoreReturn { 9 | // log10 probability 10 | float prob; 11 | 12 | /* The length of n-gram matched. Do not use this for recombination. 13 | * Consider a model containing only the following n-grams: 14 | * -1 foo 15 | * -3.14 bar 16 | * -2.718 baz -5 17 | * -6 foo bar 18 | * 19 | * If you score ``bar'' then ngram_length is 1 and recombination state is the 20 | * empty string because bar has zero backoff and does not extend to the 21 | * right. 22 | * If you score ``foo'' then ngram_length is 1 and recombination state is 23 | * ``foo''. 
24 | * 25 | * Ideally, keep output states around and compare them. Failing that, 26 | * get out_state.ValidLength() and use that length for recombination. 27 | */ 28 | unsigned char ngram_length; 29 | 30 | /* Left extension information. If independent_left is set, then prob is 31 | * independent of words to the left (up to additional backoff). Otherwise, 32 | * extend_left indicates how to efficiently extend further to the left. 33 | */ 34 | bool independent_left; 35 | uint64_t extend_left; // Defined only if independent_left 36 | 37 | // Rest cost for extension to the left. 38 | float rest; 39 | }; 40 | 41 | } // namespace lm 42 | #endif // LM_RETURN_H 43 | -------------------------------------------------------------------------------- /lm/sizes.cc: -------------------------------------------------------------------------------- 1 | #include "sizes.hh" 2 | #include "model.hh" 3 | #include "../util/file_piece.hh" 4 | 5 | #include 6 | #include 7 | 8 | namespace lm { 9 | namespace ngram { 10 | 11 | void ShowSizes(const std::vector &counts, const lm::ngram::Config &config) { 12 | uint64_t sizes[6]; 13 | sizes[0] = ProbingModel::Size(counts, config); 14 | sizes[1] = RestProbingModel::Size(counts, config); 15 | sizes[2] = TrieModel::Size(counts, config); 16 | sizes[3] = QuantTrieModel::Size(counts, config); 17 | sizes[4] = ArrayTrieModel::Size(counts, config); 18 | sizes[5] = QuantArrayTrieModel::Size(counts, config); 19 | uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); 20 | uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); 21 | uint64_t divide; 22 | char prefix; 23 | if (min_length < (1 << 10) * 10) { 24 | prefix = ' '; 25 | divide = 1; 26 | } else if (min_length < (1 << 20) * 10) { 27 | prefix = 'k'; 28 | divide = 1 << 10; 29 | } else if (min_length < (1ULL << 30) * 10) { 30 | prefix = 'M'; 31 | divide = 1 << 20; 32 | } else { 33 | prefix = 'G'; 34 | divide = 1 << 30; 35 | } 36 | long 
int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); 37 | std::cerr << "Memory estimate for binary LM:\ntype "; 38 | 39 | // right align bytes. 40 | for (long int i = 0; i < length - 2; ++i) std::cerr << ' '; 41 | 42 | std::cerr << prefix << "B\n" 43 | "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" 44 | "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" 45 | "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" 46 | "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" 47 | "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" 48 | "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; 49 | } 50 | 51 | void ShowSizes(const std::vector &counts) { 52 | lm::ngram::Config config; 53 | ShowSizes(counts, config); 54 | } 55 | 56 | void ShowSizes(const char *file, const lm::ngram::Config &config) { 57 | std::vector counts; 58 | util::FilePiece f(file); 59 | lm::ReadARPACounts(f, counts); 60 | ShowSizes(counts, config); 61 | } 62 | 63 | }} //namespaces 64 | -------------------------------------------------------------------------------- /lm/sizes.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_SIZES_H 2 | #define LM_SIZES_H 3 | 4 | #include 5 | 6 | #include 7 | 8 | namespace lm { namespace ngram { 9 | 10 | struct Config; 11 | 12 | void ShowSizes(const std::vector &counts, const lm::ngram::Config &config); 13 | void ShowSizes(const std::vector 
&counts); 14 | void ShowSizes(const char *file, const lm::ngram::Config &config); 15 | 16 | }} // namespaces 17 | #endif // LM_SIZES_H 18 | -------------------------------------------------------------------------------- /lm/value_build.cc: -------------------------------------------------------------------------------- 1 | #include "value_build.hh" 2 | 3 | #include "model.hh" 4 | #include "read_arpa.hh" 5 | 6 | namespace lm { 7 | namespace ngram { 8 | 9 | template LowerRestBuild::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) { 10 | UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes."); 11 | Config for_lower = config; 12 | for_lower.write_mmap = NULL; 13 | for_lower.rest_lower_files.clear(); 14 | 15 | // Unigram models aren't supported, so this is a custom loader. 16 | // TODO: optimize the unigram loading? 17 | { 18 | util::FilePiece uni(config.rest_lower_files[0].c_str()); 19 | std::vector number; 20 | ReadARPACounts(uni, number); 21 | UTIL_THROW_IF(number.size() != 1, FormatLoadException, "Expected the unigram model to have order 1, not " << number.size()); 22 | ReadNGramHeader(uni, 1); 23 | unigrams_.resize(number[0]); 24 | unigrams_[0] = config.unknown_missing_logprob; 25 | PositiveProbWarn warn; 26 | for (uint64_t i = 0; i < number[0]; ++i) { 27 | WordIndex w; 28 | Prob entry; 29 | ReadNGram(uni, 1, vocab, &w, entry, warn); 30 | unigrams_[w] = entry.prob; 31 | } 32 | } 33 | 34 | try { 35 | for (unsigned int i = 2; i < order; ++i) { 36 | models_.push_back(new Model(config.rest_lower_files[i - 1].c_str(), for_lower)); 37 | UTIL_THROW_IF(models_.back()->Order() != i, FormatLoadException, "Lower order file " << config.rest_lower_files[i-1] << " should have order " << i); 38 | } 39 | } catch (...) 
{ 40 | for (typename std::vector::const_iterator i = models_.begin(); i != models_.end(); ++i) { 41 | delete *i; 42 | } 43 | models_.clear(); 44 | throw; 45 | } 46 | 47 | // TODO: force/check same vocab. 48 | } 49 | 50 | template LowerRestBuild::~LowerRestBuild() { 51 | for (typename std::vector::const_iterator i = models_.begin(); i != models_.end(); ++i) { 52 | delete *i; 53 | } 54 | } 55 | 56 | template class LowerRestBuild; 57 | 58 | } // namespace ngram 59 | } // namespace lm 60 | -------------------------------------------------------------------------------- /lm/value_build.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_VALUE_BUILD_H 2 | #define LM_VALUE_BUILD_H 3 | 4 | #include "weights.hh" 5 | #include "word_index.hh" 6 | #include "../util/bit_packing.hh" 7 | 8 | #include 9 | 10 | namespace lm { 11 | namespace ngram { 12 | 13 | struct Config; 14 | struct BackoffValue; 15 | struct RestValue; 16 | 17 | class NoRestBuild { 18 | public: 19 | typedef BackoffValue Value; 20 | 21 | NoRestBuild() {} 22 | 23 | void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 24 | void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} 25 | 26 | template bool MarkExtends(ProbBackoff &weights, const Second &) const { 27 | util::UnsetSign(weights.prob); 28 | return false; 29 | } 30 | 31 | // Probing doesn't need to go back to unigram. 
32 | const static bool kMarkEvenLower = false; 33 | }; 34 | 35 | class MaxRestBuild { 36 | public: 37 | typedef RestValue Value; 38 | 39 | MaxRestBuild() {} 40 | 41 | void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 42 | void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { 43 | weights.rest = weights.prob; 44 | util::SetSign(weights.rest); 45 | } 46 | 47 | bool MarkExtends(RestWeights &weights, const RestWeights &to) const { 48 | util::UnsetSign(weights.prob); 49 | if (weights.rest >= to.rest) return false; 50 | weights.rest = to.rest; 51 | return true; 52 | } 53 | bool MarkExtends(RestWeights &weights, const Prob &to) const { 54 | util::UnsetSign(weights.prob); 55 | if (weights.rest >= to.prob) return false; 56 | weights.rest = to.prob; 57 | return true; 58 | } 59 | 60 | // Probing does need to go back to unigram. 61 | const static bool kMarkEvenLower = true; 62 | }; 63 | 64 | template class LowerRestBuild { 65 | public: 66 | typedef RestValue Value; 67 | 68 | LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); 69 | 70 | ~LowerRestBuild(); 71 | 72 | void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 73 | void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { 74 | typename Model::State ignored; 75 | if (n == 1) { 76 | weights.rest = unigrams_[*vocab_ids]; 77 | } else { 78 | weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; 79 | } 80 | } 81 | 82 | template bool MarkExtends(RestWeights &weights, const Second &) const { 83 | util::UnsetSign(weights.prob); 84 | return false; 85 | } 86 | 87 | const static bool kMarkEvenLower = false; 88 | 89 | std::vector unigrams_; 90 | 91 | std::vector models_; 92 | }; 93 | 94 | } // namespace ngram 95 | } // namespace lm 96 | 97 | #endif // LM_VALUE_BUILD_H 98 | 
-------------------------------------------------------------------------------- /lm/virtual_interface.cc: -------------------------------------------------------------------------------- 1 | #include "virtual_interface.hh" 2 | 3 | #include "lm_exception.hh" 4 | 5 | namespace lm { 6 | namespace base { 7 | 8 | Vocabulary::~Vocabulary() {} 9 | 10 | void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) { 11 | begin_sentence_ = begin_sentence; 12 | end_sentence_ = end_sentence; 13 | not_found_ = not_found; 14 | } 15 | 16 | Model::~Model() {} 17 | 18 | } // namespace base 19 | } // namespace lm 20 | -------------------------------------------------------------------------------- /lm/weights.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_WEIGHTS_H 2 | #define LM_WEIGHTS_H 3 | 4 | // Weights for n-grams. Probability and possibly a backoff. 5 | 6 | namespace lm { 7 | struct Prob { 8 | float prob; 9 | }; 10 | // No inheritance so this will be a POD. 11 | struct ProbBackoff { 12 | float prob; 13 | float backoff; 14 | }; 15 | struct RestWeights { 16 | float prob; 17 | float backoff; 18 | float rest; 19 | }; 20 | 21 | } // namespace lm 22 | #endif // LM_WEIGHTS_H 23 | -------------------------------------------------------------------------------- /lm/word_index.hh: -------------------------------------------------------------------------------- 1 | // Separate header because this is used often. 
2 | #ifndef LM_WORD_INDEX_H 3 | #define LM_WORD_INDEX_H 4 | 5 | #include 6 | 7 | namespace lm { 8 | typedef unsigned int WordIndex; 9 | const WordIndex kMaxWordIndex = UINT_MAX; 10 | const WordIndex kUNK = 0; 11 | } // namespace lm 12 | 13 | typedef lm::WordIndex LMWordIndex; 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /lm/wrappers/README: -------------------------------------------------------------------------------- 1 | This directory is for wrappers around other people's LMs, presenting an interface similar to KenLM's. You will need to have their LM installed. 2 | 3 | NPLM is a work in progress. 4 | -------------------------------------------------------------------------------- /lm/wrappers/nplm.hh: -------------------------------------------------------------------------------- 1 | #ifndef LM_WRAPPERS_NPLM_H 2 | #define LM_WRAPPERS_NPLM_H 3 | 4 | #include "../facade.hh" 5 | #include "../max_order.hh" 6 | #include "../../util/string_piece.hh" 7 | 8 | #include 9 | #include 10 | 11 | /* Wrapper to NPLM "by Ashish Vaswani, with contributions from David Chiang 12 | * and Victoria Fossum." 13 | * http://nlg.isi.edu/software/nplm/ 14 | */ 15 | 16 | namespace nplm { 17 | class vocabulary; 18 | class neuralLM; 19 | } // namespace nplm 20 | 21 | namespace lm { 22 | namespace np { 23 | 24 | class Vocabulary : public base::Vocabulary { 25 | public: 26 | Vocabulary(const nplm::vocabulary &vocab); 27 | 28 | ~Vocabulary(); 29 | 30 | WordIndex Index(const std::string &str) const; 31 | 32 | // TODO: lobby them to support StringPiece 33 | WordIndex Index(const StringPiece &str) const { 34 | return Index(std::string(str.data(), str.size())); 35 | } 36 | 37 | lm::WordIndex NullWord() const { return null_word_; } 38 | 39 | private: 40 | const nplm::vocabulary &vocab_; 41 | 42 | const lm::WordIndex null_word_; 43 | }; 44 | 45 | // Sorry for imposing my limitations on your code. 
46 | #define NPLM_MAX_ORDER 7 47 | 48 | struct State { 49 | WordIndex words[NPLM_MAX_ORDER - 1]; 50 | }; 51 | 52 | class Backend; 53 | 54 | class Model : public lm::base::ModelFacade { 55 | private: 56 | typedef lm::base::ModelFacade P; 57 | 58 | public: 59 | // Does this look like an NPLM? 60 | static bool Recognize(const std::string &file); 61 | 62 | explicit Model(const std::string &file, std::size_t cache_size = 1 << 20); 63 | 64 | ~Model(); 65 | 66 | FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const; 67 | 68 | FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; 69 | 70 | private: 71 | boost::scoped_ptr base_instance_; 72 | 73 | mutable boost::thread_specific_ptr backend_; 74 | 75 | Vocabulary vocab_; 76 | 77 | lm::WordIndex null_word_; 78 | 79 | const std::size_t cache_size_; 80 | }; 81 | 82 | } // namespace np 83 | } // namespace lm 84 | 85 | #endif // LM_WRAPPERS_NPLM_H 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cmake>=3.10"] 3 | -------------------------------------------------------------------------------- /python/BuildStandalone.cmake: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | file(GLOB 4 | KENLM_PYTHON_STANDALONE_SRCS 5 | "util/*.cc" 6 | "lm/*.cc" 7 | "util/double-conversion/*.cc" 8 | "python/*.cc" 9 | ) 10 | 11 | list(FILTER KENLM_PYTHON_STANDALONE_SRCS EXCLUDE REGEX ".*main.cc") 12 | list(FILTER KENLM_PYTHON_STANDALONE_SRCS EXCLUDE REGEX ".*test.cc") 13 | 14 | add_library( 15 | kenlm 16 | SHARED 17 | ${KENLM_PYTHON_STANDALONE_SRCS} 18 | ) 19 | 20 | target_include_directories(kenlm PRIVATE ${PROJECT_SOURCE_DIR}) 21 | 
target_compile_definitions(kenlm PRIVATE KENLM_MAX_ORDER=${KENLM_MAX_ORDER}) 22 | 23 | find_package(ZLIB) 24 | find_package(BZip2) 25 | find_package(LibLZMA) 26 | 27 | if (ZLIB_FOUND) 28 | target_link_libraries(kenlm PRIVATE ${ZLIB_LIBRARIES}) 29 | target_include_directories(kenlm PRIVATE ${ZLIB_INCLUDE_DIRS}) 30 | target_compile_definitions(kenlm PRIVATE HAVE_ZLIB) 31 | endif() 32 | if(BZIP2_FOUND) 33 | target_link_libraries(kenlm PRIVATE ${BZIP2_LIBRARIES}) 34 | target_include_directories(kenlm PRIVATE ${BZIP2_INCLUDE_DIR}) 35 | target_compile_definitions(kenlm PRIVATE HAVE_BZLIB) 36 | endif() 37 | if(LIBLZMA_FOUND) 38 | target_link_libraries(kenlm PRIVATE ${LIBLZMA_LIBRARIES}) 39 | target_include_directories(kenlm PRIVATE ${LIBLZMA_INCLUDE_DIRS}) 40 | target_compile_definitions(kenlm PRIVATE HAVE_LZMA) 41 | endif() 42 | -------------------------------------------------------------------------------- /python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(PythonInterp REQUIRED) 2 | find_package(PythonLibs ${PYTHON_VERSION_STRING} EXACT REQUIRED) 3 | include_directories(${PYTHON_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) 4 | 5 | add_library(kenlm_python MODULE kenlm.cpp score_sentence.cc) 6 | set_target_properties(kenlm_python PROPERTIES OUTPUT_NAME kenlm) 7 | set_target_properties(kenlm_python PROPERTIES PREFIX "") 8 | 9 | if(APPLE) 10 | set_target_properties(kenlm_python PROPERTIES SUFFIX ".so") 11 | elseif(WIN32) 12 | set_target_properties(kenlm_python PROPERTIES SUFFIX ".pyd") 13 | endif() 14 | 15 | target_link_libraries(kenlm_python PUBLIC kenlm) 16 | if(WIN32) 17 | target_link_libraries(kenlm_python PUBLIC ${PYTHON_LIBRARIES}) 18 | elseif(APPLE) 19 | set_target_properties(kenlm_python PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") 20 | endif() 21 | 22 | if (WIN32) 23 | set (PYTHON_SITE_PACKAGES Lib/site-packages) 24 | else () 25 | set (PYTHON_SITE_PACKAGES 
lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}/site-packages) 26 | endif () 27 | 28 | install(TARGETS kenlm_python DESTINATION ${PYTHON_SITE_PACKAGES}) 29 | -------------------------------------------------------------------------------- /python/_kenlm.pxd: -------------------------------------------------------------------------------- 1 | from libcpp cimport bool 2 | 3 | cdef extern from "lm/word_index.hh" namespace "lm": 4 | ctypedef unsigned WordIndex 5 | 6 | cdef extern from "lm/return.hh" namespace "lm": 7 | cdef struct FullScoreReturn: 8 | float prob 9 | unsigned char ngram_length 10 | 11 | cdef extern from "lm/state.hh" namespace "lm::ngram": 12 | cdef cppclass State : 13 | int Compare(const State &other) const 14 | 15 | int hash_value(const State &state) 16 | 17 | cdef extern from "lm/virtual_interface.hh" namespace "lm::base": 18 | cdef cppclass Vocabulary: 19 | WordIndex Index(char*) 20 | WordIndex BeginSentence() 21 | WordIndex EndSentence() 22 | WordIndex NotFound() 23 | 24 | ctypedef Vocabulary const_Vocabulary "const lm::base::Vocabulary" 25 | 26 | cdef cppclass Model: 27 | void BeginSentenceWrite(void *) 28 | void NullContextWrite(void *) 29 | unsigned int Order() 30 | const_Vocabulary& BaseVocabulary() 31 | float BaseScore(void *in_state, WordIndex new_word, void *out_state) 32 | FullScoreReturn BaseFullScore(void *in_state, WordIndex new_word, void *out_state) 33 | 34 | cdef extern from "util/mmap.hh" namespace "util": 35 | cdef enum LoadMethod: 36 | LAZY 37 | POPULATE_OR_LAZY 38 | POPULATE_OR_READ 39 | READ 40 | PARALLEL_READ 41 | 42 | cdef extern from "lm/config.hh" namespace "lm::ngram::Config": 43 | cdef enum ARPALoadComplain: 44 | ALL 45 | EXPENSIVE 46 | NONE 47 | 48 | cdef extern from "lm/config.hh" namespace "lm::ngram": 49 | cdef cppclass Config: 50 | Config() 51 | float probing_multiplier 52 | LoadMethod load_method 53 | bool show_progress 54 | ARPALoadComplain arpa_complain 55 | float unknown_missing_logprob 56 | 57 | cdef 
extern from "lm/model.hh" namespace "lm::ngram": 58 | cdef Model *LoadVirtual(char *, Config &config) except + 59 | #default constructor 60 | cdef Model *LoadVirtual(char *) except + 61 | 62 | cdef extern from "python/score_sentence.hh" namespace "lm::base": 63 | cdef float ScoreSentence(const Model *model, const char *sentence) 64 | -------------------------------------------------------------------------------- /python/example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import kenlm 4 | 5 | LM = os.path.join(os.path.dirname(__file__), '..', 'lm', 'test.arpa') 6 | model = kenlm.LanguageModel(LM) 7 | print('{0}-gram model'.format(model.order)) 8 | 9 | sentence = 'language modeling is fun .' 10 | print(sentence) 11 | print(model.score(sentence)) 12 | 13 | # Check that total full score = direct score 14 | def score(s): 15 | return sum(prob for prob, _, _ in model.full_scores(s)) 16 | 17 | assert (abs(score(sentence) - model.score(sentence)) < 1e-3) 18 | 19 | # Show scores and n-gram matches 20 | words = [''] + sentence.split() + [''] 21 | for i, (prob, length, oov) in enumerate(model.full_scores(sentence)): 22 | print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i+2-length:i+2]))) 23 | if oov: 24 | print('\t"{0}" is an OOV'.format(words[i+1])) 25 | 26 | # Find out-of-vocabulary words 27 | for w in words: 28 | if not w in model: 29 | print('"{0}" is an OOV'.format(w)) 30 | 31 | #Stateful query 32 | state = kenlm.State() 33 | state2 = kenlm.State() 34 | #Use as context. If you don't want , use model.NullContextWrite(state). 35 | model.BeginSentenceWrite(state) 36 | accum = 0.0 37 | accum += model.BaseScore(state, "a", state2) 38 | accum += model.BaseScore(state2, "sentence", state) 39 | #score defaults to bos = True and eos = True. Here we'll check without the end 40 | #of sentence marker. 
41 | assert (abs(accum - model.score("a sentence", eos = False)) < 1e-3) 42 | accum += model.BaseScore(state, "", state2) 43 | assert (abs(accum - model.score("a sentence")) < 1e-3) 44 | -------------------------------------------------------------------------------- /python/score_sentence.cc: -------------------------------------------------------------------------------- 1 | #include "lm/state.hh" 2 | #include "lm/virtual_interface.hh" 3 | #include "util/tokenize_piece.hh" 4 | 5 | #include 6 | #include 7 | 8 | namespace lm { 9 | namespace base { 10 | 11 | float ScoreSentence(const base::Model *model, const char *sentence) { 12 | // TODO: reduce virtual dispatch to one per sentence? 13 | const base::Vocabulary &vocab = model->BaseVocabulary(); 14 | // We know it's going to be a KenLM State. 15 | lm::ngram::State state_vec[2]; 16 | lm::ngram::State *state = &state_vec[0]; 17 | lm::ngram::State *state2 = &state_vec[1]; 18 | model->BeginSentenceWrite(state); 19 | float ret = 0.0; 20 | for (util::TokenIter i(sentence, util::kSpaces); i; ++i) { 21 | lm::WordIndex index = vocab.Index(*i); 22 | ret += model->BaseScore(state, index, state2); 23 | std::swap(state, state2); 24 | } 25 | ret += model->BaseScore(state, vocab.EndSentence(), state2); 26 | return ret; 27 | } 28 | 29 | } // namespace base 30 | } // namespace lm 31 | -------------------------------------------------------------------------------- /python/score_sentence.hh: -------------------------------------------------------------------------------- 1 | // Score an entire sentence splitting on whitespace. This should not be needed 2 | // for C++ users (who should do it themselves), but it's faster for python users. 
3 | #pragma once 4 | 5 | namespace lm { 6 | namespace base { 7 | 8 | class Model; 9 | 10 | float ScoreSentence(const Model *model, const char *sentence); 11 | 12 | } // namespace base 13 | } // namespace lm 14 | -------------------------------------------------------------------------------- /util/bit_packing.cc: -------------------------------------------------------------------------------- 1 | #include "bit_packing.hh" 2 | #include "exception.hh" 3 | 4 | #include 5 | 6 | namespace util { 7 | 8 | namespace { 9 | template struct StaticCheck {}; 10 | template <> struct StaticCheck { typedef bool StaticAssertionPassed; }; 11 | 12 | // If your float isn't 4 bytes, we're hosed. 13 | typedef StaticCheck::StaticAssertionPassed FloatSize; 14 | 15 | } // namespace 16 | 17 | uint8_t RequiredBits(uint64_t max_value) { 18 | if (!max_value) return 0; 19 | uint8_t ret = 1; 20 | while (max_value >>= 1) ++ret; 21 | return ret; 22 | } 23 | 24 | void BitPackingSanity() { 25 | const FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; 26 | if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000"); 27 | char mem[57+8]; 28 | memset(mem, 0, sizeof(mem)); 29 | const uint64_t test57 = 0x123456789abcdefULL; 30 | for (uint64_t b = 0; b < 57 * 8; b += 57) { 31 | WriteInt57(mem, b, 57, test57); 32 | } 33 | for (uint64_t b = 0; b < 57 * 8; b += 57) { 34 | if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) 35 | UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); 36 | } 37 | // TODO: more checks. 
38 | } 39 | 40 | } // namespace util 41 | -------------------------------------------------------------------------------- /util/bit_packing_test.cc: -------------------------------------------------------------------------------- 1 | #include "bit_packing.hh" 2 | 3 | #define BOOST_TEST_MODULE BitPackingTest 4 | #include 5 | 6 | #include 7 | 8 | namespace util { 9 | namespace { 10 | 11 | const uint64_t test57 = 0x123456789abcdefULL; 12 | const uint32_t test25 = 0x1234567; 13 | 14 | BOOST_AUTO_TEST_CASE(ZeroBit57) { 15 | char mem[16]; 16 | memset(mem, 0, sizeof(mem)); 17 | WriteInt57(mem, 0, 57, test57); 18 | BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); 19 | } 20 | 21 | BOOST_AUTO_TEST_CASE(EachBit57) { 22 | char mem[16]; 23 | for (uint8_t b = 0; b < 8; ++b) { 24 | memset(mem, 0, sizeof(mem)); 25 | WriteInt57(mem, b, 57, test57); 26 | BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); 27 | } 28 | } 29 | 30 | BOOST_AUTO_TEST_CASE(Consecutive57) { 31 | char mem[57+8]; 32 | memset(mem, 0, sizeof(mem)); 33 | for (uint64_t b = 0; b < 57 * 8; b += 57) { 34 | WriteInt57(mem, b, 57, test57); 35 | BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); 36 | } 37 | for (uint64_t b = 0; b < 57 * 8; b += 57) { 38 | BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); 39 | } 40 | } 41 | 42 | BOOST_AUTO_TEST_CASE(Consecutive25) { 43 | char mem[25+8]; 44 | memset(mem, 0, sizeof(mem)); 45 | for (uint64_t b = 0; b < 25 * 8; b += 25) { 46 | WriteInt25(mem, b, 25, test25); 47 | BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); 48 | } 49 | for (uint64_t b = 0; b < 25 * 8; b += 25) { 50 | BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); 51 | } 52 | } 53 | 54 | BOOST_AUTO_TEST_CASE(Sanity) { 55 | BitPackingSanity(); 56 | } 57 | 58 | } // namespace 59 | } // namespace util 60 | -------------------------------------------------------------------------------- /util/cat_compressed_main.cc: 
-------------------------------------------------------------------------------- 1 | // Like cat but interprets compressed files. 2 | #include "file.hh" 3 | #include "read_compressed.hh" 4 | 5 | #include 6 | #include 7 | 8 | namespace { 9 | const std::size_t kBufSize = 16384; 10 | void Copy(util::ReadCompressed &from, int to) { 11 | util::scoped_malloc buffer(util::MallocOrThrow(kBufSize)); 12 | while (std::size_t amount = from.Read(buffer.get(), kBufSize)) { 13 | util::WriteOrThrow(to, buffer.get(), amount); 14 | } 15 | } 16 | } // namespace 17 | 18 | int main(int argc, char *argv[]) { 19 | // Lane Schwartz likes -h and --help 20 | for (int i = 1; i < argc; ++i) { 21 | char *arg = argv[i]; 22 | if (!strcmp(arg, "--")) break; 23 | if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) { 24 | std::cerr << 25 | "A cat implementation that interprets compressed files.\n" 26 | "Usage: " << argv[0] << " [file1] [file2] ...\n" 27 | "If no file is provided, then stdin is read.\n"; 28 | return 1; 29 | } 30 | } 31 | 32 | try { 33 | if (argc == 1) { 34 | util::ReadCompressed in(0); 35 | Copy(in, 1); 36 | } else { 37 | for (int i = 1; i < argc; ++i) { 38 | util::ReadCompressed in(util::OpenReadOrThrow(argv[i])); 39 | Copy(in, 1); 40 | } 41 | } 42 | } catch (const std::exception &e) { 43 | std::cerr << e.what() << std::endl; 44 | return 2; 45 | } 46 | return 0; 47 | } 48 | -------------------------------------------------------------------------------- /util/double-conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake file was created by Lane Schwartz 2 | 3 | # Explicitly list the source files for this subdirectory 4 | # 5 | # If you add any source files to this subdirectory 6 | # that should be included in the kenlm library, 7 | # (this excludes any unit test files) 8 | # you should add them to the following list: 9 | # 10 | # In order to allow CMake files in the parent directory 11 | # to see this variable 
definition, we set PARENT_SCOPE. 12 | # 13 | # In order to set correct paths to these files 14 | # when this variable is referenced by CMake files in the parent directory, 15 | # we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}. 16 | # 17 | set(KENLM_UTIL_DOUBLECONVERSION_SOURCE 18 | ${CMAKE_CURRENT_SOURCE_DIR}/bignum-dtoa.cc 19 | ${CMAKE_CURRENT_SOURCE_DIR}/bignum.cc 20 | ${CMAKE_CURRENT_SOURCE_DIR}/cached-powers.cc 21 | ${CMAKE_CURRENT_SOURCE_DIR}/fast-dtoa.cc 22 | ${CMAKE_CURRENT_SOURCE_DIR}/fixed-dtoa.cc 23 | ${CMAKE_CURRENT_SOURCE_DIR}/strtod.cc 24 | ${CMAKE_CURRENT_SOURCE_DIR}/double-to-string.cc 25 | ${CMAKE_CURRENT_SOURCE_DIR}/string-to-double.cc 26 | PARENT_SCOPE) 27 | 28 | -------------------------------------------------------------------------------- /util/double-conversion/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2006-2011, the V8 project authors. All rights reserved. 2 | Redistribution and use in source and binary forms, with or without 3 | modification, are permitted provided that the following conditions are 4 | met: 5 | 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above 9 | copyright notice, this list of conditions and the following 10 | disclaimer in the documentation and/or other materials provided 11 | with the distribution. 12 | * Neither the name of Google Inc. nor the names of its 13 | contributors may be used to endorse or promote products derived 14 | from this software without specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 20 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /util/double-conversion/double-conversion.h: -------------------------------------------------------------------------------- 1 | // Copyright 2012 the V8 project authors. All rights reserved. 2 | // Redistribution and use in source and binary forms, with or without 3 | // modification, are permitted provided that the following conditions are 4 | // met: 5 | // 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above 9 | // copyright notice, this list of conditions and the following 10 | // disclaimer in the documentation and/or other materials provided 11 | // with the distribution. 12 | // * Neither the name of Google Inc. nor the names of its 13 | // contributors may be used to endorse or promote products derived 14 | // from this software without specific prior written permission. 15 | // 16 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 20 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | #ifndef DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ 29 | #define DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ 30 | 31 | #include "string-to-double.h" 32 | #include "double-to-string.h" 33 | 34 | #endif // DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ 35 | -------------------------------------------------------------------------------- /util/double-conversion/fixed-dtoa.h: -------------------------------------------------------------------------------- 1 | // Copyright 2010 the V8 project authors. All rights reserved. 2 | // Redistribution and use in source and binary forms, with or without 3 | // modification, are permitted provided that the following conditions are 4 | // met: 5 | // 6 | // * Redistributions of source code must retain the above copyright 7 | // notice, this list of conditions and the following disclaimer. 8 | // * Redistributions in binary form must reproduce the above 9 | // copyright notice, this list of conditions and the following 10 | // disclaimer in the documentation and/or other materials provided 11 | // with the distribution. 12 | // * Neither the name of Google Inc. nor the names of its 13 | // contributors may be used to endorse or promote products derived 14 | // from this software without specific prior written permission. 
15 | // 16 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 20 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | #ifndef DOUBLE_CONVERSION_FIXED_DTOA_H_ 29 | #define DOUBLE_CONVERSION_FIXED_DTOA_H_ 30 | 31 | #include "utils.h" 32 | 33 | namespace double_conversion { 34 | 35 | // Produces digits necessary to print a given number with 36 | // 'fractional_count' digits after the decimal point. 37 | // The buffer must be big enough to hold the result plus one terminating null 38 | // character. 39 | // 40 | // The produced digits might be too short in which case the caller has to fill 41 | // the gaps with '0's. 42 | // Example: FastFixedDtoa(0.001, 5, ...) is allowed to return buffer = "1", and 43 | // decimal_point = -2. 44 | // Halfway cases are rounded towards +/-Infinity (away from 0). The call 45 | // FastFixedDtoa(0.15, 2, ...) thus returns buffer = "2", decimal_point = 0. 46 | // The returned buffer may contain digits that would be truncated from the 47 | // shortest representation of the input. 48 | // 49 | // This method only works for some parameters. If it can't handle the input it 50 | // returns false. The output is null-terminated when the function succeeds. 
51 | bool FastFixedDtoa(double v, int fractional_count, 52 | Vector buffer, int* length, int* decimal_point); 53 | 54 | } // namespace double_conversion 55 | 56 | #endif // DOUBLE_CONVERSION_FIXED_DTOA_H_ 57 | -------------------------------------------------------------------------------- /util/ersatz_progress.cc: -------------------------------------------------------------------------------- 1 | #include "ersatz_progress.hh" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace util { 9 | 10 | namespace { const unsigned char kWidth = 100; } 11 | 12 | const char kProgressBanner[] = "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; 13 | 14 | ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} 15 | 16 | ErsatzProgress::~ErsatzProgress() { 17 | if (out_) Finished(); 18 | } 19 | 20 | ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message) 21 | : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { 22 | if (!out_) { 23 | next_ = std::numeric_limits::max(); 24 | return; 25 | } 26 | if (!message.empty()) *out_ << message << '\n'; 27 | *out_ << kProgressBanner; 28 | } 29 | 30 | void ErsatzProgress::Milestone() { 31 | if (!out_) { current_ = 0; return; } 32 | if (!complete_) return; 33 | unsigned char stone = std::min(static_cast(kWidth), (current_ * kWidth) / complete_); 34 | 35 | for (; stones_written_ < stone; ++stones_written_) { 36 | (*out_) << '*'; 37 | } 38 | if (stone == kWidth) { 39 | (*out_) << std::endl; 40 | next_ = std::numeric_limits::max(); 41 | out_ = NULL; 42 | } else { 43 | next_ = std::max(next_, ((stone + 1) * complete_ + kWidth - 1) / kWidth); 44 | } 45 | } 46 | 47 | } // namespace util 48 | -------------------------------------------------------------------------------- /util/ersatz_progress.hh: 
#ifndef UTIL_ERSATZ_PROGRESS_H
#define UTIL_ERSATZ_PROGRESS_H

#include <iostream>
#include <string>
#include <stdint.h>

/* Minimal stand-in for boost::progress so the core language model does not
 * depend on boost.  Also supports printing nothing at all.
 */
namespace util {

// Scale line printed above the stars; defined in ersatz_progress.cc.
extern const char kProgressBanner[];

class ErsatzProgress {
  public:
    // Silent progress bar: nothing is ever printed.
    ErsatzProgress();

    // Prints stars to *to as progress approaches complete.  A NULL stream
    // suppresses all output, which makes it easy to forward an optional
    // ostream pointer from another caller.
    explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = "");

#if __cplusplus >= 201103L
    // Move: steal the stream so the moved-from object goes silent and its
    // destructor does not print a spurious finish.
    ErsatzProgress(ErsatzProgress &&from) noexcept : current_(from.current_), next_(from.next_), complete_(from.complete_), stones_written_(from.stones_written_), out_(from.out_) {
      from.out_ = nullptr;
      from.next_ = (uint64_t)-1;
    }
#endif

    ~ErsatzProgress();

    // Advance by one unit of work.
    ErsatzProgress &operator++() {
      if (++current_ >= next_) Milestone();
      return *this;
    }

    // Advance by several units of work at once.
    ErsatzProgress &operator+=(uint64_t amount) {
      if ((current_ += amount) >= next_) Milestone();
      return *this;
    }

    // Jump to an absolute position.
    void Set(uint64_t to) {
      if ((current_ = to) >= next_) Milestone();
    }

    // Force completion (also invoked by the destructor when active).
    void Finished() {
      Set(complete_);
    }

  private:
    // Emits stars / finishes; only called once current_ >= next_.
    void Milestone();

    uint64_t current_, next_, complete_;
    unsigned char stones_written_;
    std::ostream *out_;

    // noncopyable
    ErsatzProgress(const ErsatzProgress &other);
    ErsatzProgress &operator=(const ErsatzProgress &other);
};

} // namespace util

#endif // UTIL_ERSATZ_PROGRESS_H
/* Like std::ofstream but without being incredibly slow. Backed by a raw fd. 2 | * Supports most of the built-in types except for long double. 3 | */ 4 | #ifndef UTIL_FILE_STREAM_H 5 | #define UTIL_FILE_STREAM_H 6 | 7 | #include "fake_ostream.hh" 8 | #include "file.hh" 9 | #include "scoped.hh" 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace util { 17 | 18 | class FileStream : public FakeOStream { 19 | public: 20 | explicit FileStream(int out = -1, std::size_t buffer_size = 8192) 21 | : buf_(util::MallocOrThrow(std::max(buffer_size, kToStringMaxBytes))), 22 | current_(static_cast(buf_.get())), 23 | end_(current_ + std::max(buffer_size, kToStringMaxBytes)), 24 | fd_(out) {} 25 | 26 | #if __cplusplus >= 201103L 27 | FileStream(FileStream &&from) noexcept : buf_(from.buf_.release()), current_(from.current_), end_(from.end_), fd_(from.fd_) { 28 | from.end_ = reinterpret_cast(from.buf_.get()); 29 | from.current_ = from.end_; 30 | } 31 | #endif 32 | 33 | ~FileStream() { 34 | flush(); 35 | } 36 | 37 | void SetFD(int to) { 38 | flush(); 39 | fd_ = to; 40 | } 41 | 42 | FileStream &flush() { 43 | if (current_ != buf_.get()) { 44 | util::WriteOrThrow(fd_, buf_.get(), current_ - (char*)buf_.get()); 45 | current_ = static_cast(buf_.get()); 46 | } 47 | return *this; 48 | } 49 | 50 | // For writes of arbitrary size. 
51 | FileStream &write(const void *data, std::size_t length) { 52 | if (UTIL_LIKELY(current_ + length <= end_)) { 53 | std::memcpy(current_, data, length); 54 | current_ += length; 55 | return *this; 56 | } 57 | flush(); 58 | if (current_ + length <= end_) { 59 | std::memcpy(current_, data, length); 60 | current_ += length; 61 | } else { 62 | util::WriteOrThrow(fd_, data, length); 63 | } 64 | return *this; 65 | } 66 | 67 | FileStream &seekp(uint64_t to) { 68 | flush(); 69 | util::SeekOrThrow(fd_, to); 70 | return *this; 71 | } 72 | 73 | protected: 74 | friend class FakeOStream; 75 | // For writes directly to buffer guaranteed to have amount < buffer size. 76 | char *Ensure(std::size_t amount) { 77 | if (UTIL_UNLIKELY(current_ + amount > end_)) { 78 | flush(); 79 | assert(current_ + amount <= end_); 80 | } 81 | return current_; 82 | } 83 | 84 | void AdvanceTo(char *to) { 85 | current_ = to; 86 | assert(current_ <= end_); 87 | } 88 | 89 | private: 90 | util::scoped_malloc buf_; 91 | char *current_, *end_; 92 | int fd_; 93 | }; 94 | 95 | } // namespace 96 | 97 | #endif 98 | -------------------------------------------------------------------------------- /util/float_to_string.cc: -------------------------------------------------------------------------------- 1 | #include "float_to_string.hh" 2 | 3 | #include "double-conversion/double-conversion.h" 4 | #include "double-conversion/utils.h" 5 | 6 | namespace util { 7 | namespace { 8 | const double_conversion::DoubleToStringConverter kConverter(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0); 9 | } // namespace 10 | 11 | char *ToString(double value, char *to) { 12 | double_conversion::StringBuilder builder(to, ToStringBuf::kBytes); 13 | kConverter.ToShortest(value, &builder); 14 | return &to[builder.position()]; 15 | } 16 | 17 | char *ToString(float value, char *to) { 18 | double_conversion::StringBuilder builder(to, ToStringBuf::kBytes); 19 | kConverter.ToShortestSingle(value, 
&builder); 20 | return &to[builder.position()]; 21 | } 22 | 23 | } // namespace util 24 | -------------------------------------------------------------------------------- /util/float_to_string.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_FLOAT_TO_STRING_H 2 | #define UTIL_FLOAT_TO_STRING_H 3 | 4 | // Just for ToStringBuf 5 | #include "integer_to_string.hh" 6 | 7 | namespace util { 8 | 9 | template <> struct ToStringBuf { 10 | // DoubleToStringConverter::kBase10MaximalLength + 1 for null paranoia. 11 | static const unsigned kBytes = 19; 12 | }; 13 | 14 | // Single wasn't documented in double conversion, so be conservative and 15 | // say the same as double. 16 | template <> struct ToStringBuf { 17 | static const unsigned kBytes = 19; 18 | }; 19 | 20 | char *ToString(double value, char *to); 21 | char *ToString(float value, char *to); 22 | 23 | } // namespace util 24 | 25 | #endif // UTIL_FLOAT_TO_STRING_H 26 | -------------------------------------------------------------------------------- /util/getopt.c: -------------------------------------------------------------------------------- 1 | /* 2 | POSIX getopt for Windows 3 | 4 | AT&T Public License 5 | 6 | Code given out at the 1985 UNIFORUM conference in Dallas. 
/* POSIX getopt for Windows.
 * AT&T Public License.
 * Code given out at the 1985 UNIFORUM conference in Dallas.
 */

#ifndef __GNUC__

#include "getopt.hh"
#include <stdio.h>
#include <string.h>

/* stdio.h normally defines these already; only fill in the gaps instead of
 * unconditionally redefining them. */
#ifndef NULL
#define NULL 0
#endif
#ifndef EOF
#define EOF (-1)
#endif

/* Print "<program><message><option char>" to stderr when opterr permits.
 * The historical version also carried a dead write(2) path and an unused
 * errbuf local; both removed. */
#define ERR(s, c) if (opterr) { \
    fputs(argv[0], stderr); \
    fputs(s, stderr); \
    fputc(c, stderr); }

int opterr = 1;   /* nonzero: emit diagnostics for bad options */
int optind = 1;   /* index of the next argv element to scan */
int optopt;       /* option character that caused the last error */
char *optarg;     /* argument of the last option that takes one */

/* Scan argv for single-character options listed in opts (a ':' after a
 * letter means it takes an argument).  Returns the option character, '?' on
 * error, or EOF when the options are exhausted.  Modernized from the K&R
 * definition to match the ANSI prototype declared in getopt.hh. */
int getopt(int argc, char **argv, char *opts) {
  static int sp = 1;  /* position within the current argv element */
  int c;
  char *cp;

  if (sp == 1) {
    /* Starting a fresh argument: stop at non-options and at "--". */
    if (optind >= argc || argv[optind][0] != '-' || argv[optind][1] == '\0') {
      return EOF;
    } else if (strcmp(argv[optind], "--") == 0) {  /* was: == NULL (int/pointer confusion) */
      optind++;
      return EOF;
    }
  }
  optopt = c = argv[optind][sp];
  if (c == ':' || (cp = strchr(opts, c)) == NULL) {
    ERR(": illegal option -- ", c);
    if (argv[optind][++sp] == '\0') {
      optind++;
      sp = 1;
    }
    return '?';
  }
  if (*++cp == ':') {
    /* Option takes an argument: rest of this element, or the next element. */
    if (argv[optind][sp + 1] != '\0') {
      optarg = &argv[optind++][sp + 1];
    } else if (++optind >= argc) {
      ERR(": option requires an argument -- ", c);
      sp = 1;
      return '?';
    } else {
      optarg = argv[optind++];
    }
    sp = 1;
  } else {
    if (argv[optind][++sp] == '\0') {
      sp = 1;
      optind++;
    }
    optarg = NULL;
  }
  return c;
}

#endif /* __GNUC__ */
/* Fast integer-to-decimal conversion helpers. */
#ifndef UTIL_INTEGER_TO_STRING_H
#define UTIL_INTEGER_TO_STRING_H
#include <cstddef>
#include <stdint.h>

namespace util {

/* Each overload writes the digits of value starting at to and returns a
 * pointer one past the last byte written.  No null terminator is appended.
 */
char *ToString(uint32_t value, char *to);
char *ToString(uint64_t value, char *to);

// Signed flavors are implemented as wrappers around the unsigned ones.
char *ToString(int32_t value, char *to);
char *ToString(int64_t value, char *to);

// 16-bit flavors currently forward to the 32-bit implementations.
char *ToString(uint16_t value, char *to);
char *ToString(int16_t value, char *to);

// Pointers render as hexadecimal (e.g. 0x1a2b).
char *ToString(const void *value, char *to);

// Booleans become a single '0' or '1' character.
inline char *ToString(bool value, char *to) {
  *to++ = '0' + value;
  return to;
}

/* Worst-case number of bytes each ToString overload may write.  Expressed as
 * an enum rather than a static const std::size_t member because g++ 4.9.1
 * mis-handles the latter.
 */
template <class T> struct ToStringBuf;
template <> struct ToStringBuf<bool> {
  enum { kBytes = 1 };
};
template <> struct ToStringBuf<uint16_t> {
  enum { kBytes = 5 };
};
template <> struct ToStringBuf<int16_t> {
  enum { kBytes = 6 };
};
template <> struct ToStringBuf<uint32_t> {
  enum { kBytes = 10 };
};
template <> struct ToStringBuf<int32_t> {
  enum { kBytes = 11 };
};
template <> struct ToStringBuf<uint64_t> {
  enum { kBytes = 20 };
};
template <> struct ToStringBuf<int64_t> {
  // Not a typo: 2^63 has 19 digits, plus one byte for the sign.
  enum { kBytes = 20 };
};

template <> struct ToStringBuf<const void*> {
  // "0x" plus two hex digits per byte: 18 on 64-bit, 10 on 32-bit.
  enum { kBytes = sizeof(const void*) * 2 + 2 };
};

// Maximum over all of the above (float/double elsewhere fit too).
enum { kToStringMaxBytes = 20 };

} // namespace util

#endif // UTIL_INTEGER_TO_STRING_H
(uintptr_t i = 1; i < std::numeric_limits::max() / 10; i *= 10) { 73 | TestValue((const void*)i); 74 | } 75 | for (uintptr_t i = 0; i < 256; ++i) { 76 | TestValue((const void*)i); 77 | TestValue((const void*)(i + 0xf00)); 78 | } 79 | } 80 | 81 | }} // namespaces 82 | -------------------------------------------------------------------------------- /util/joint_sort_test.cc: -------------------------------------------------------------------------------- 1 | #include "joint_sort.hh" 2 | 3 | #define BOOST_TEST_MODULE JointSortTest 4 | #include 5 | 6 | namespace util { namespace { 7 | 8 | BOOST_AUTO_TEST_CASE(just_flip) { 9 | char keys[2]; 10 | int values[2]; 11 | keys[0] = 1; values[0] = 327; 12 | keys[1] = 0; values[1] = 87897; 13 | JointSort(keys + 0, keys + 2, values + 0); 14 | BOOST_CHECK_EQUAL(0, keys[0]); 15 | BOOST_CHECK_EQUAL(87897, values[0]); 16 | BOOST_CHECK_EQUAL(1, keys[1]); 17 | BOOST_CHECK_EQUAL(327, values[1]); 18 | } 19 | 20 | BOOST_AUTO_TEST_CASE(three) { 21 | char keys[3]; 22 | int values[3]; 23 | keys[0] = 1; values[0] = 327; 24 | keys[1] = 2; values[1] = 87897; 25 | keys[2] = 0; values[2] = 10; 26 | JointSort(keys + 0, keys + 3, values + 0); 27 | BOOST_CHECK_EQUAL(0, keys[0]); 28 | BOOST_CHECK_EQUAL(1, keys[1]); 29 | BOOST_CHECK_EQUAL(2, keys[2]); 30 | } 31 | 32 | BOOST_AUTO_TEST_CASE(char_int) { 33 | char keys[4]; 34 | int values[4]; 35 | keys[0] = 3; values[0] = 327; 36 | keys[1] = 1; values[1] = 87897; 37 | keys[2] = 2; values[2] = 10; 38 | keys[3] = 0; values[3] = 24347; 39 | JointSort(keys + 0, keys + 4, values + 0); 40 | BOOST_CHECK_EQUAL(0, keys[0]); 41 | BOOST_CHECK_EQUAL(24347, values[0]); 42 | BOOST_CHECK_EQUAL(1, keys[1]); 43 | BOOST_CHECK_EQUAL(87897, values[1]); 44 | BOOST_CHECK_EQUAL(2, keys[2]); 45 | BOOST_CHECK_EQUAL(10, values[2]); 46 | BOOST_CHECK_EQUAL(3, keys[3]); 47 | BOOST_CHECK_EQUAL(327, values[3]); 48 | } 49 | 50 | BOOST_AUTO_TEST_CASE(swap_proxy) { 51 | char keys[2] = {0, 1}; 52 | int values[2] = {2, 3}; 53 | 
detail::JointProxy first(keys, values); 54 | detail::JointProxy second(keys + 1, values + 1); 55 | swap(first, second); 56 | BOOST_CHECK_EQUAL(1, keys[0]); 57 | BOOST_CHECK_EQUAL(0, keys[1]); 58 | BOOST_CHECK_EQUAL(3, values[0]); 59 | BOOST_CHECK_EQUAL(2, values[1]); 60 | } 61 | 62 | }} // namespace anonymous util 63 | -------------------------------------------------------------------------------- /util/multi_intersection_test.cc: -------------------------------------------------------------------------------- 1 | #include "multi_intersection.hh" 2 | 3 | #define BOOST_TEST_MODULE MultiIntersectionTest 4 | #include 5 | 6 | namespace util { 7 | namespace { 8 | 9 | BOOST_AUTO_TEST_CASE(Empty) { 10 | std::vector > sets; 11 | 12 | sets.push_back(boost::iterator_range(static_cast(NULL), static_cast(NULL))); 13 | BOOST_CHECK(!FirstIntersection(sets)); 14 | } 15 | 16 | BOOST_AUTO_TEST_CASE(Single) { 17 | std::vector nums; 18 | nums.push_back(1); 19 | nums.push_back(4); 20 | nums.push_back(100); 21 | std::vector::const_iterator> > sets; 22 | sets.push_back(nums); 23 | 24 | boost::optional ret(FirstIntersection(sets)); 25 | 26 | BOOST_REQUIRE(ret); 27 | BOOST_CHECK_EQUAL(static_cast(1), *ret); 28 | } 29 | 30 | template boost::iterator_range RangeFromArray(const T (&arr)[len]) { 31 | return boost::iterator_range(arr, arr + len); 32 | } 33 | 34 | BOOST_AUTO_TEST_CASE(MultiNone) { 35 | unsigned int nums0[] = {1, 3, 4, 22}; 36 | unsigned int nums1[] = {2, 5, 12}; 37 | unsigned int nums2[] = {4, 17}; 38 | 39 | std::vector > sets; 40 | sets.push_back(RangeFromArray(nums0)); 41 | sets.push_back(RangeFromArray(nums1)); 42 | sets.push_back(RangeFromArray(nums2)); 43 | 44 | BOOST_CHECK(!FirstIntersection(sets)); 45 | } 46 | 47 | BOOST_AUTO_TEST_CASE(MultiOne) { 48 | unsigned int nums0[] = {1, 3, 4, 17, 22}; 49 | unsigned int nums1[] = {2, 5, 12, 17}; 50 | unsigned int nums2[] = {4, 17}; 51 | 52 | std::vector > sets; 53 | sets.push_back(RangeFromArray(nums0)); 54 | 
sets.push_back(RangeFromArray(nums1)); 55 | sets.push_back(RangeFromArray(nums2)); 56 | 57 | boost::optional ret(FirstIntersection(sets)); 58 | BOOST_REQUIRE(ret); 59 | BOOST_CHECK_EQUAL(static_cast(17), *ret); 60 | } 61 | 62 | } // namespace 63 | } // namespace util 64 | -------------------------------------------------------------------------------- /util/murmur_hash.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_MURMUR_HASH_H 2 | #define UTIL_MURMUR_HASH_H 3 | #include 4 | #include 5 | 6 | namespace util { 7 | 8 | // 64-bit machine version 9 | uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0); 10 | // 32-bit machine version (not the same function as above) 11 | uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0); 12 | // Use the version for this arch. Because the values differ across 13 | // architectures, really only use it for in-memory structures. 14 | uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0); 15 | 16 | } // namespace util 17 | 18 | #endif // UTIL_MURMUR_HASH_H 19 | -------------------------------------------------------------------------------- /util/parallel_read.cc: -------------------------------------------------------------------------------- 1 | #include "parallel_read.hh" 2 | 3 | #include "file.hh" 4 | 5 | #ifdef WITH_THREADS 6 | #include "thread_pool.hh" 7 | 8 | namespace util { 9 | namespace { 10 | 11 | class Reader { 12 | public: 13 | explicit Reader(int fd) : fd_(fd) {} 14 | 15 | struct Request { 16 | void *to; 17 | std::size_t size; 18 | uint64_t offset; 19 | 20 | bool operator==(const Request &other) const { 21 | return (to == other.to) && (size == other.size) && (offset == other.offset); 22 | } 23 | }; 24 | 25 | void operator()(const Request &request) { 26 | util::ErsatzPRead(fd_, request.to, request.size, request.offset); 27 | } 28 | 29 | private: 30 | int fd_; 31 | }; 32 | 33 | } // namespace 34 | 
35 | void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { 36 | Reader::Request poison; 37 | poison.to = NULL; 38 | poison.size = 0; 39 | poison.offset = 0; 40 | unsigned threads = boost::thread::hardware_concurrency(); 41 | if (!threads) threads = 2; 42 | ThreadPool pool(2 /* don't need much of a queue */, threads, fd, poison); 43 | const std::size_t kBatch = 1ULL << 25; // 32 MB 44 | Reader::Request request; 45 | request.to = to; 46 | request.size = kBatch; 47 | request.offset = offset; 48 | for (; amount > kBatch; amount -= kBatch) { 49 | pool.Produce(request); 50 | request.to = reinterpret_cast(request.to) + kBatch; 51 | request.offset += kBatch; 52 | } 53 | request.size = amount; 54 | if (request.size) { 55 | pool.Produce(request); 56 | } 57 | } 58 | 59 | } // namespace util 60 | 61 | #else // WITH_THREADS 62 | 63 | namespace util { 64 | void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { 65 | util::ErsatzPRead(fd, to, amount, offset); 66 | } 67 | } // namespace util 68 | 69 | #endif 70 | -------------------------------------------------------------------------------- /util/parallel_read.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_PARALLEL_READ__ 2 | #define UTIL_PARALLEL_READ__ 3 | 4 | /* Read pieces of a file in parallel. This has a very specific use case: 5 | * reading files from Lustre is CPU bound so multiple threads actually 6 | * increases throughput. Speed matters when an LM takes a terabyte. 
7 | */ 8 | 9 | #include 10 | #include 11 | 12 | namespace util { 13 | void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset); 14 | } // namespace util 15 | 16 | #endif // UTIL_PARALLEL_READ__ 17 | -------------------------------------------------------------------------------- /util/pcqueue_test.cc: -------------------------------------------------------------------------------- 1 | #include "pcqueue.hh" 2 | 3 | #define BOOST_TEST_MODULE PCQueueTest 4 | #include 5 | 6 | namespace util { 7 | namespace { 8 | 9 | BOOST_AUTO_TEST_CASE(SingleThread) { 10 | PCQueue queue(10); 11 | for (int i = 0; i < 10; ++i) { 12 | queue.Produce(i); 13 | } 14 | for (int i = 0; i < 10; ++i) { 15 | BOOST_CHECK_EQUAL(i, queue.Consume()); 16 | } 17 | } 18 | 19 | } 20 | } // namespace util 21 | -------------------------------------------------------------------------------- /util/pool.cc: -------------------------------------------------------------------------------- 1 | #include "pool.hh" 2 | 3 | #include "scoped.hh" 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace util { 10 | 11 | Pool::Pool() { 12 | current_ = NULL; 13 | current_end_ = NULL; 14 | } 15 | 16 | Pool::~Pool() { 17 | FreeAll(); 18 | } 19 | 20 | void Pool::FreeAll() { 21 | for (std::vector::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { 22 | free(*i); 23 | } 24 | free_list_.clear(); 25 | current_ = NULL; 26 | current_end_ = NULL; 27 | } 28 | 29 | void *Pool::More(std::size_t size) { 30 | std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); 31 | uint8_t *ret = static_cast(MallocOrThrow(amount)); 32 | free_list_.push_back(ret); 33 | current_ = ret + size; 34 | current_end_ = ret + amount; 35 | return ret; 36 | } 37 | 38 | } // namespace util 39 | -------------------------------------------------------------------------------- /util/probing_hash_table_test.cc: -------------------------------------------------------------------------------- 1 | #include 
"probing_hash_table.hh" 2 | 3 | #include "murmur_hash.hh" 4 | #include "scoped.hh" 5 | 6 | #define BOOST_TEST_MODULE ProbingHashTableTest 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace util { 16 | namespace { 17 | 18 | struct Entry { 19 | unsigned char key; 20 | typedef unsigned char Key; 21 | 22 | unsigned char GetKey() const { 23 | return key; 24 | } 25 | 26 | void SetKey(unsigned char to) { 27 | key = to; 28 | } 29 | 30 | uint64_t GetValue() const { 31 | return value; 32 | } 33 | 34 | uint64_t value; 35 | }; 36 | 37 | typedef ProbingHashTable > Table; 38 | 39 | BOOST_AUTO_TEST_CASE(simple) { 40 | size_t size = Table::Size(10, 1.2); 41 | boost::scoped_array mem(new char[size]); 42 | memset(mem.get(), 0, size); 43 | 44 | Table table(mem.get(), size); 45 | const Entry *i = NULL; 46 | BOOST_CHECK(!table.Find(2, i)); 47 | Entry to_ins; 48 | to_ins.key = 3; 49 | to_ins.value = 328920; 50 | table.Insert(to_ins); 51 | BOOST_REQUIRE(table.Find(3, i)); 52 | BOOST_CHECK_EQUAL(3, i->GetKey()); 53 | BOOST_CHECK_EQUAL(static_cast(328920), i->GetValue()); 54 | BOOST_CHECK(!table.Find(2, i)); 55 | } 56 | 57 | struct Entry64 { 58 | uint64_t key; 59 | typedef uint64_t Key; 60 | 61 | Entry64() {} 62 | 63 | explicit Entry64(uint64_t key_in) { 64 | key = key_in; 65 | } 66 | 67 | Key GetKey() const { return key; } 68 | void SetKey(uint64_t to) { key = to; } 69 | }; 70 | 71 | struct MurmurHashEntry64 { 72 | std::size_t operator()(uint64_t value) const { 73 | return util::MurmurHash64A(&value, 8); 74 | } 75 | }; 76 | 77 | typedef ProbingHashTable Table64; 78 | 79 | BOOST_AUTO_TEST_CASE(Double) { 80 | for (std::size_t initial = 19; initial < 30; ++initial) { 81 | size_t size = Table64::Size(initial, 1.2); 82 | scoped_malloc mem(MallocOrThrow(size)); 83 | Table64 table(mem.get(), size, std::numeric_limits::max()); 84 | table.Clear(); 85 | for (uint64_t i = 0; i < 19; ++i) { 86 | table.Insert(Entry64(i)); 87 | } 88 | 
table.CheckConsistency(); 89 | mem.call_realloc(table.DoubleTo()); 90 | table.Double(mem.get()); 91 | table.CheckConsistency(); 92 | for (uint64_t i = 20; i < 40 ; ++i) { 93 | table.Insert(Entry64(i)); 94 | } 95 | mem.call_realloc(table.DoubleTo()); 96 | table.Double(mem.get()); 97 | table.CheckConsistency(); 98 | } 99 | } 100 | 101 | } // namespace 102 | } // namespace util 103 | -------------------------------------------------------------------------------- /util/read_compressed.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_READ_COMPRESSED_H 2 | #define UTIL_READ_COMPRESSED_H 3 | 4 | #include "exception.hh" 5 | #include "scoped.hh" 6 | 7 | #include 8 | #include 9 | 10 | namespace util { 11 | 12 | class CompressedException : public Exception { 13 | public: 14 | CompressedException() throw(); 15 | virtual ~CompressedException() throw(); 16 | }; 17 | 18 | class GZException : public CompressedException { 19 | public: 20 | GZException() throw(); 21 | ~GZException() throw(); 22 | }; 23 | 24 | class BZException : public CompressedException { 25 | public: 26 | BZException() throw(); 27 | ~BZException() throw(); 28 | }; 29 | 30 | class XZException : public CompressedException { 31 | public: 32 | XZException() throw(); 33 | ~XZException() throw(); 34 | }; 35 | 36 | class ReadCompressed; 37 | 38 | class ReadBase { 39 | public: 40 | virtual ~ReadBase() {} 41 | 42 | virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0; 43 | 44 | protected: 45 | static void ReplaceThis(ReadBase *with, ReadCompressed &thunk); 46 | 47 | ReadBase *Current(ReadCompressed &thunk); 48 | 49 | static uint64_t &ReadCount(ReadCompressed &thunk); 50 | }; 51 | 52 | class ReadCompressed { 53 | public: 54 | static const std::size_t kMagicSize = 6; 55 | // Must have at least kMagicSize bytes. 56 | static bool DetectCompressedMagic(const void *from); 57 | 58 | // Takes ownership of fd. 
59 | explicit ReadCompressed(int fd); 60 | 61 | // Try to avoid using this. Use the fd instead. 62 | // There is no decompression support for istreams. 63 | explicit ReadCompressed(std::istream &in); 64 | 65 | // Must call Reset later. 66 | ReadCompressed(); 67 | 68 | // Takes ownership of fd. 69 | void Reset(int fd); 70 | 71 | // Same advice as the constructor. 72 | void Reset(std::istream &in); 73 | 74 | std::size_t Read(void *to, std::size_t amount); 75 | 76 | // Repeatedly call read to fill a buffer unless EOF is hit. 77 | // Return number of bytes read. 78 | std::size_t ReadOrEOF(void *const to, std::size_t amount); 79 | 80 | uint64_t RawAmount() const { return raw_amount_; } 81 | 82 | private: 83 | friend class ReadBase; 84 | 85 | scoped_ptr internal_; 86 | 87 | uint64_t raw_amount_; 88 | }; 89 | 90 | } // namespace util 91 | 92 | #endif // UTIL_READ_COMPRESSED_H 93 | -------------------------------------------------------------------------------- /util/scoped.cc: -------------------------------------------------------------------------------- 1 | #include "scoped.hh" 2 | 3 | #include 4 | #if !defined(_WIN32) && !defined(_WIN64) 5 | #include 6 | #endif 7 | 8 | namespace util { 9 | 10 | // TODO: if we're really under memory pressure, don't allocate memory to 11 | // display the error. 
12 | MallocException::MallocException(std::size_t requested) throw() { 13 | *this << "for " << requested << " bytes "; 14 | } 15 | 16 | MallocException::~MallocException() throw() {} 17 | 18 | namespace { 19 | void *InspectAddr(void *addr, std::size_t requested, const char *func_name) { 20 | UTIL_THROW_IF_ARG(!addr && requested, MallocException, (requested), "in " << func_name); 21 | return addr; 22 | } 23 | } // namespace 24 | 25 | void *MallocOrThrow(std::size_t requested) { 26 | return InspectAddr(std::malloc(requested), requested, "malloc"); 27 | } 28 | 29 | void *CallocOrThrow(std::size_t requested) { 30 | return InspectAddr(std::calloc(requested, 1), requested, "calloc"); 31 | } 32 | 33 | void scoped_malloc::call_realloc(std::size_t requested) { 34 | p_ = InspectAddr(std::realloc(p_, requested), requested, "realloc"); 35 | } 36 | 37 | void AdviseHugePages(const void *addr, std::size_t size) { 38 | #if MADV_HUGEPAGE 39 | madvise((void*)addr, size, MADV_HUGEPAGE); 40 | #endif 41 | } 42 | 43 | } // namespace util 44 | -------------------------------------------------------------------------------- /util/sized_iterator_test.cc: -------------------------------------------------------------------------------- 1 | #include "sized_iterator.hh" 2 | 3 | #define BOOST_TEST_MODULE SizedIteratorTest 4 | #include 5 | 6 | namespace util { namespace { 7 | 8 | struct CompareChar { 9 | bool operator()(const void *first, const void *second) const { 10 | return *static_cast(first) < *static_cast(second); 11 | } 12 | }; 13 | 14 | BOOST_AUTO_TEST_CASE(sort) { 15 | char items[3] = {1, 2, 0}; 16 | SizedSort(items, items + 3, 1, CompareChar()); 17 | BOOST_CHECK_EQUAL(0, items[0]); 18 | BOOST_CHECK_EQUAL(1, items[1]); 19 | BOOST_CHECK_EQUAL(2, items[2]); 20 | } 21 | 22 | }} // namespace anonymous util 23 | -------------------------------------------------------------------------------- /util/spaces.cc: -------------------------------------------------------------------------------- 1 
| #include "spaces.hh" 2 | 3 | namespace util { 4 | 5 | // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). 6 | const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; 7 | 8 | } // namespace util 9 | -------------------------------------------------------------------------------- /util/spaces.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_SPACES_H 2 | #define UTIL_SPACES_H 3 | 4 | // bool array of spaces. 5 | 6 | namespace util { 7 | 8 | extern const bool kSpaces[256]; 9 | 10 | } // namespace util 11 | 12 | #endif // UTIL_SPACES_H 13 | -------------------------------------------------------------------------------- /util/stream/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake file was created by Lane Schwartz 2 | 3 | # Explicitly list the source files for this subdirectory 4 | # 5 | # If you add any source files to this subdirectory 6 | # that should be included in the kenlm library, 7 | # (this excludes any unit test files) 8 | # you should add them to the following list: 9 | # 10 | # In order to allow CMake files in the parent directory 11 | # to see this variable definition, we set PARENT_SCOPE. 
12 | # 13 | # In order to set correct paths to these files 14 | # when this variable is referenced by CMake files in the parent directory, 15 | # we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}. 16 | # 17 | set(KENLM_UTIL_STREAM_SOURCE 18 | ${CMAKE_CURRENT_SOURCE_DIR}/chain.cc 19 | ${CMAKE_CURRENT_SOURCE_DIR}/count_records.cc 20 | ${CMAKE_CURRENT_SOURCE_DIR}/io.cc 21 | ${CMAKE_CURRENT_SOURCE_DIR}/line_input.cc 22 | ${CMAKE_CURRENT_SOURCE_DIR}/multi_progress.cc 23 | ${CMAKE_CURRENT_SOURCE_DIR}/rewindable_stream.cc 24 | PARENT_SCOPE) 25 | 26 | 27 | 28 | if(BUILD_TESTING) 29 | # Explicitly list the Boost test files to be compiled 30 | set(KENLM_BOOST_TESTS_LIST 31 | io_test 32 | sort_test 33 | stream_test 34 | rewindable_stream_test 35 | ) 36 | 37 | AddTests(TESTS ${KENLM_BOOST_TESTS_LIST} 38 | LIBRARIES kenlm_util ${Boost_LIBRARIES} Threads::Threads) 39 | endif() 40 | -------------------------------------------------------------------------------- /util/stream/block.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_BLOCK_H 2 | #define UTIL_STREAM_BLOCK_H 3 | 4 | #include 5 | #include 6 | 7 | namespace util { 8 | namespace stream { 9 | 10 | /** 11 | * Encapsulates a block of memory. 12 | */ 13 | class Block { 14 | public: 15 | 16 | /** 17 | * Constructs an empty block. 18 | */ 19 | Block() : mem_(NULL), valid_size_(0) {} 20 | 21 | /** 22 | * Constructs a block that encapsulates a segment of memory. 23 | * 24 | * @param[in] mem The segment of memory to encapsulate 25 | * @param[in] size The size of the memory segment in bytes 26 | */ 27 | Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {} 28 | 29 | /** 30 | * Set the number of bytes in this block that should be interpreted as valid. 31 | * 32 | * @param[in] to Number of bytes 33 | */ 34 | void SetValidSize(std::size_t to) { valid_size_ = to; } 35 | 36 | /** 37 | * Gets the number of bytes in this block that should be interpreted as valid. 
38 | * This is important because read might fill in less than Allocated at EOF. 39 | */ 40 | std::size_t ValidSize() const { return valid_size_; } 41 | 42 | /** Gets a void pointer to the memory underlying this block. */ 43 | void *Get() { return mem_; } 44 | 45 | /** Gets a const void pointer to the memory underlying this block. */ 46 | const void *Get() const { return mem_; } 47 | 48 | 49 | /** 50 | * Gets a const void pointer to the end of the valid section of memory 51 | * encapsulated by this block. 52 | */ 53 | const void *ValidEnd() const { 54 | return reinterpret_cast(mem_) + valid_size_; 55 | } 56 | 57 | /** 58 | * Returns true if this block encapsulates a valid (non-NULL) block of memory. 59 | * 60 | * This method is a user-defined implicit conversion function to boolean; 61 | * among other things, this method enables bare instances of this class 62 | * to be used as the condition of an if statement. 63 | */ 64 | operator bool() const { return mem_ != NULL; } 65 | 66 | /** 67 | * Returns true if this block is empty. 68 | * 69 | * In other words, if Get()==NULL, this method will return true. 70 | */ 71 | bool operator!() const { return mem_ == NULL; } 72 | 73 | private: 74 | friend class Link; 75 | friend class RewindableStream; 76 | 77 | /** 78 | * Points this block's memory at NULL. 79 | * 80 | * This class defines poison as a block whose memory pointer is NULL. 
81 | */ 82 | void SetToPoison() { 83 | mem_ = NULL; 84 | } 85 | 86 | void *mem_; 87 | std::size_t valid_size_; 88 | }; 89 | 90 | } // namespace stream 91 | } // namespace util 92 | 93 | #endif // UTIL_STREAM_BLOCK_H 94 | -------------------------------------------------------------------------------- /util/stream/config.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_CONFIG_H 2 | #define UTIL_STREAM_CONFIG_H 3 | 4 | #include 5 | #include 6 | 7 | namespace util { namespace stream { 8 | 9 | /** 10 | * Represents how a chain should be configured. 11 | */ 12 | struct ChainConfig { 13 | 14 | /** Constructs an configuration with underspecified (or default) parameters. */ 15 | ChainConfig() {} 16 | 17 | /** 18 | * Constructs a chain configuration object. 19 | * 20 | * @param [in] in_entry_size Number of bytes in each record. 21 | * @param [in] in_block_count Number of blocks in the chain. 22 | * @param [in] in_total_memory Total number of bytes available to the chain. 23 | * This value will be divided amongst the blocks in the chain. 24 | */ 25 | ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory) 26 | : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {} 27 | 28 | /** 29 | * Number of bytes in each record. 30 | */ 31 | std::size_t entry_size; 32 | 33 | /** 34 | * Number of blocks in the chain. 35 | */ 36 | std::size_t block_count; 37 | 38 | /** 39 | * Total number of bytes available to the chain. 40 | * This value will be divided amongst the blocks in the chain. 41 | * Chain's constructor will make this a multiple of entry_size. 42 | */ 43 | std::size_t total_memory; 44 | }; 45 | 46 | 47 | /** 48 | * Represents how a sorter should be configured. 49 | */ 50 | struct SortConfig { 51 | 52 | /** Filename prefix where temporary files should be placed. */ 53 | std::string temp_prefix; 54 | 55 | /** Size of each input/output buffer. 
*/ 56 | std::size_t buffer_size; 57 | 58 | /** Total memory to use when running alone. */ 59 | std::size_t total_memory; 60 | }; 61 | 62 | }} // namespaces 63 | #endif // UTIL_STREAM_CONFIG_H 64 | -------------------------------------------------------------------------------- /util/stream/count_records.cc: -------------------------------------------------------------------------------- 1 | #include "count_records.hh" 2 | #include "chain.hh" 3 | 4 | namespace util { namespace stream { 5 | 6 | void CountRecords::Run(const ChainPosition &position) { 7 | for (Link link(position); link; ++link) { 8 | *count_ += link->ValidSize() / position.GetChain().EntrySize(); 9 | } 10 | } 11 | 12 | }} // namespaces 13 | -------------------------------------------------------------------------------- /util/stream/count_records.hh: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace util { namespace stream { 4 | 5 | class ChainPosition; 6 | 7 | class CountRecords { 8 | public: 9 | explicit CountRecords(uint64_t *out) 10 | : count_(out) { 11 | *count_ = 0; 12 | } 13 | 14 | void Run(const ChainPosition &position); 15 | 16 | private: 17 | uint64_t *count_; 18 | }; 19 | 20 | }} // namespaces 21 | -------------------------------------------------------------------------------- /util/stream/io.cc: -------------------------------------------------------------------------------- 1 | #include "io.hh" 2 | 3 | #include "../file.hh" 4 | #include "chain.hh" 5 | 6 | #include 7 | 8 | namespace util { 9 | namespace stream { 10 | 11 | ReadSizeException::ReadSizeException() throw() {} 12 | ReadSizeException::~ReadSizeException() throw() {} 13 | 14 | void Read::Run(const ChainPosition &position) { 15 | const std::size_t block_size = position.GetChain().BlockSize(); 16 | const std::size_t entry_size = position.GetChain().EntrySize(); 17 | for (Link link(position); link; ++link) { 18 | std::size_t got = util::ReadOrEOF(file_, link->Get(), 
block_size); 19 | UTIL_THROW_IF(got % entry_size, ReadSizeException, "File ended with " << got << " bytes, not a multiple of " << entry_size << "."); 20 | if (got == 0) { 21 | link.Poison(); 22 | return; 23 | } else { 24 | link->SetValidSize(got); 25 | } 26 | } 27 | } 28 | 29 | void PRead::Run(const ChainPosition &position) { 30 | scoped_fd owner; 31 | if (own_) owner.reset(file_); 32 | const uint64_t size = SizeOrThrow(file_); 33 | UTIL_THROW_IF(size % static_cast(position.GetChain().EntrySize()), ReadSizeException, "File size " << file_ << " size is " << size << " not a multiple of " << position.GetChain().EntrySize()); 34 | const std::size_t block_size = position.GetChain().BlockSize(); 35 | const uint64_t block_size64 = static_cast(block_size); 36 | Link link(position); 37 | uint64_t offset = 0; 38 | for (; offset + block_size64 < size; offset += block_size64, ++link) { 39 | ErsatzPRead(file_, link->Get(), block_size, offset); 40 | link->SetValidSize(block_size); 41 | } 42 | // size - offset is <= block_size, so it casts to 32-bit fine. 43 | if (size - offset) { 44 | ErsatzPRead(file_, link->Get(), size - offset, offset); 45 | link->SetValidSize(size - offset); 46 | ++link; 47 | } 48 | link.Poison(); 49 | } 50 | 51 | void Write::Run(const ChainPosition &position) { 52 | for (Link link(position); link; ++link) { 53 | WriteOrThrow(file_, link->Get(), link->ValidSize()); 54 | } 55 | } 56 | 57 | void WriteAndRecycle::Run(const ChainPosition &position) { 58 | const std::size_t block_size = position.GetChain().BlockSize(); 59 | for (Link link(position); link; ++link) { 60 | WriteOrThrow(file_, link->Get(), link->ValidSize()); 61 | link->SetValidSize(block_size); 62 | } 63 | } 64 | 65 | void PWrite::Run(const ChainPosition &position) { 66 | uint64_t offset = 0; 67 | for (Link link(position); link; ++link) { 68 | ErsatzPWrite(file_, link->Get(), link->ValidSize(), offset); 69 | offset += link->ValidSize(); 70 | } 71 | // Trim file to size. 
72 | util::ResizeOrThrow(file_, offset); 73 | } 74 | 75 | } // namespace stream 76 | } // namespace util 77 | -------------------------------------------------------------------------------- /util/stream/io.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_IO_H 2 | #define UTIL_STREAM_IO_H 3 | 4 | #include "../exception.hh" 5 | #include "../file.hh" 6 | 7 | namespace util { 8 | namespace stream { 9 | 10 | class ChainPosition; 11 | 12 | class ReadSizeException : public util::Exception { 13 | public: 14 | ReadSizeException() throw(); 15 | ~ReadSizeException() throw(); 16 | }; 17 | 18 | class Read { 19 | public: 20 | explicit Read(int fd) : file_(fd) {} 21 | void Run(const ChainPosition &position); 22 | private: 23 | int file_; 24 | }; 25 | 26 | // Like read but uses pread so that the file can be accessed from multiple threads. 27 | class PRead { 28 | public: 29 | explicit PRead(int fd, bool take_own = false) : file_(fd), own_(take_own) {} 30 | void Run(const ChainPosition &position); 31 | private: 32 | int file_; 33 | bool own_; 34 | }; 35 | 36 | class Write { 37 | public: 38 | explicit Write(int fd) : file_(fd) {} 39 | void Run(const ChainPosition &position); 40 | private: 41 | int file_; 42 | }; 43 | 44 | // It's a common case that stuff is written and then recycled. So rather than 45 | // spawn another thread to Recycle, this combines the two roles. 46 | class WriteAndRecycle { 47 | public: 48 | explicit WriteAndRecycle(int fd) : file_(fd) {} 49 | void Run(const ChainPosition &position); 50 | private: 51 | int file_; 52 | }; 53 | 54 | class PWrite { 55 | public: 56 | explicit PWrite(int fd) : file_(fd) {} 57 | void Run(const ChainPosition &position); 58 | private: 59 | int file_; 60 | }; 61 | 62 | 63 | // Reuse the same file over and over again to buffer output. 
64 | class FileBuffer { 65 | public: 66 | explicit FileBuffer(int fd) : file_(fd) {} 67 | 68 | PWrite Sink() const { 69 | util::SeekOrThrow(file_.get(), 0); 70 | return PWrite(file_.get()); 71 | } 72 | 73 | PRead Source(bool discard = false) { 74 | return PRead(discard ? file_.release() : file_.get(), discard); 75 | } 76 | 77 | uint64_t Size() const { 78 | return SizeOrThrow(file_.get()); 79 | } 80 | 81 | private: 82 | scoped_fd file_; 83 | }; 84 | 85 | } // namespace stream 86 | } // namespace util 87 | #endif // UTIL_STREAM_IO_H 88 | -------------------------------------------------------------------------------- /util/stream/io_test.cc: -------------------------------------------------------------------------------- 1 | #include "io.hh" 2 | 3 | #include "chain.hh" 4 | #include "../file.hh" 5 | 6 | #define BOOST_TEST_MODULE IOTest 7 | #include 8 | 9 | #include 10 | 11 | namespace util { namespace stream { namespace { 12 | 13 | BOOST_AUTO_TEST_CASE(CopyFile) { 14 | std::string temps("io_test_temp"); 15 | 16 | scoped_fd in(MakeTemp(temps)); 17 | for (uint64_t i = 0; i < 100000; ++i) { 18 | WriteOrThrow(in.get(), &i, sizeof(uint64_t)); 19 | } 20 | SeekOrThrow(in.get(), 0); 21 | scoped_fd out(MakeTemp(temps)); 22 | 23 | ChainConfig config; 24 | config.entry_size = 8; 25 | config.total_memory = 1024; 26 | config.block_count = 10; 27 | 28 | Chain(config) >> PRead(in.get()) >> Write(out.get()); 29 | 30 | SeekOrThrow(out.get(), 0); 31 | for (uint64_t i = 0; i < 100000; ++i) { 32 | uint64_t got; 33 | ReadOrThrow(out.get(), &got, sizeof(uint64_t)); 34 | BOOST_CHECK_EQUAL(i, got); 35 | } 36 | } 37 | 38 | }}} // namespaces 39 | -------------------------------------------------------------------------------- /util/stream/line_input.cc: -------------------------------------------------------------------------------- 1 | #include "line_input.hh" 2 | 3 | #include "../exception.hh" 4 | #include "../file.hh" 5 | #include "../read_compressed.hh" 6 | #include "chain.hh" 7 | 8 | 
#include 9 | #include 10 | 11 | namespace util { namespace stream { 12 | 13 | void LineInput::Run(const ChainPosition &position) { 14 | ReadCompressed reader(fd_); 15 | // Holding area for beginning of line to be placed in next block. 16 | std::vector carry; 17 | 18 | for (Link block(position); ; ++block) { 19 | char *to = static_cast(block->Get()); 20 | char *begin = to; 21 | char *end = to + position.GetChain().BlockSize(); 22 | std::copy(carry.begin(), carry.end(), to); 23 | to += carry.size(); 24 | while (to != end) { 25 | std::size_t got = reader.Read(to, end - to); 26 | if (!got) { 27 | // EOF 28 | block->SetValidSize(to - begin); 29 | ++block; 30 | block.Poison(); 31 | return; 32 | } 33 | to += got; 34 | } 35 | 36 | // Find the last newline. 37 | char *newline; 38 | for (newline = to - 1; ; --newline) { 39 | UTIL_THROW_IF(newline < begin, Exception, "Did not find a newline in " << position.GetChain().BlockSize() << " bytes of input of " << NameFromFD(fd_) << ". Is this a text file?"); 40 | if (*newline == '\n') break; 41 | } 42 | 43 | // Copy everything after the last newline to the carry. 44 | carry.clear(); 45 | carry.resize(to - (newline + 1)); 46 | std::copy(newline + 1, to, &*carry.begin()); 47 | 48 | block->SetValidSize(newline + 1 - begin); 49 | } 50 | } 51 | 52 | }} // namespaces 53 | -------------------------------------------------------------------------------- /util/stream/line_input.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_LINE_INPUT_H 2 | #define UTIL_STREAM_LINE_INPUT_H 3 | namespace util {namespace stream { 4 | 5 | class ChainPosition; 6 | 7 | /* Worker that reads input into blocks, ensuring that blocks contain whole 8 | * lines. Assumes that the maximum size of a line is less than the block size 9 | */ 10 | class LineInput { 11 | public: 12 | // Takes ownership upon thread execution. 
13 | explicit LineInput(int fd); 14 | 15 | void Run(const ChainPosition &position); 16 | 17 | private: 18 | int fd_; 19 | }; 20 | 21 | }} // namespaces 22 | #endif // UTIL_STREAM_LINE_INPUT_H 23 | -------------------------------------------------------------------------------- /util/stream/multi_progress.cc: -------------------------------------------------------------------------------- 1 | #include "multi_progress.hh" 2 | 3 | // TODO: merge some functionality with the simple progress bar? 4 | #include "../ersatz_progress.hh" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #if !defined(_WIN32) && !defined(_WIN64) 12 | #include 13 | #endif 14 | 15 | namespace util { namespace stream { 16 | 17 | namespace { 18 | const char kDisplayCharacters[] = "-+*#0123456789"; 19 | 20 | uint64_t Next(unsigned char stone, uint64_t complete) { 21 | return (static_cast(stone + 1) * complete + MultiProgress::kWidth - 1) / MultiProgress::kWidth; 22 | } 23 | 24 | } // namespace 25 | 26 | MultiProgress::MultiProgress() : active_(false), complete_(std::numeric_limits::max()), character_handout_(0) {} 27 | 28 | MultiProgress::~MultiProgress() { 29 | if (active_ && complete_ != std::numeric_limits::max()) 30 | std::cerr << '\n'; 31 | } 32 | 33 | void MultiProgress::Activate() { 34 | active_ = 35 | #if !defined(_WIN32) && !defined(_WIN64) 36 | // Is stderr a terminal? 
37 | (isatty(2) == 1) 38 | #else 39 | true 40 | #endif 41 | ; 42 | } 43 | 44 | void MultiProgress::SetTarget(uint64_t complete) { 45 | if (!active_) return; 46 | complete_ = complete; 47 | if (!complete) complete_ = 1; 48 | memset(display_, 0, sizeof(display_)); 49 | character_handout_ = 0; 50 | std::cerr << kProgressBanner; 51 | } 52 | 53 | WorkerProgress MultiProgress::Add() { 54 | if (!active_) 55 | return WorkerProgress(std::numeric_limits::max(), *this, '\0'); 56 | std::size_t character_index; 57 | { 58 | boost::unique_lock lock(mutex_); 59 | character_index = character_handout_++; 60 | if (character_handout_ == sizeof(kDisplayCharacters) - 1) 61 | character_handout_ = 0; 62 | } 63 | return WorkerProgress(Next(0, complete_), *this, kDisplayCharacters[character_index]); 64 | } 65 | 66 | void MultiProgress::Finished() { 67 | if (!active_ || complete_ == std::numeric_limits::max()) return; 68 | std::cerr << '\n'; 69 | complete_ = std::numeric_limits::max(); 70 | } 71 | 72 | void MultiProgress::Milestone(WorkerProgress &worker) { 73 | if (!active_ || complete_ == std::numeric_limits::max()) return; 74 | unsigned char stone = std::min(static_cast(kWidth), worker.current_ * kWidth / complete_); 75 | for (char *i = &display_[worker.stone_]; i < &display_[stone]; ++i) { 76 | *i = worker.character_; 77 | } 78 | worker.next_ = Next(stone, complete_); 79 | worker.stone_ = stone; 80 | { 81 | boost::unique_lock lock(mutex_); 82 | std::cerr << '\r' << display_ << std::flush; 83 | } 84 | } 85 | 86 | }} // namespaces 87 | -------------------------------------------------------------------------------- /util/stream/multi_progress.hh: -------------------------------------------------------------------------------- 1 | /* Progress bar suitable for chains of workers */ 2 | #ifndef UTIL_STREAM_MULTI_PROGRESS_H 3 | #define UTIL_STREAM_MULTI_PROGRESS_H 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | namespace util { namespace stream { 11 | 12 | class WorkerProgress; 13 | 14 
| class MultiProgress { 15 | public: 16 | static const unsigned char kWidth = 100; 17 | 18 | MultiProgress(); 19 | 20 | ~MultiProgress(); 21 | 22 | // Turns on showing (requires SetTarget too). 23 | void Activate(); 24 | 25 | void SetTarget(uint64_t complete); 26 | 27 | WorkerProgress Add(); 28 | 29 | void Finished(); 30 | 31 | private: 32 | friend class WorkerProgress; 33 | void Milestone(WorkerProgress &worker); 34 | 35 | bool active_; 36 | 37 | uint64_t complete_; 38 | 39 | boost::mutex mutex_; 40 | 41 | // \0 at the end. 42 | char display_[kWidth + 1]; 43 | 44 | std::size_t character_handout_; 45 | 46 | MultiProgress(const MultiProgress &); 47 | MultiProgress &operator=(const MultiProgress &); 48 | }; 49 | 50 | class WorkerProgress { 51 | public: 52 | // Default contrutor must be initialized with operator= later. 53 | WorkerProgress() : parent_(NULL) {} 54 | 55 | // Not threadsafe for the same worker by default. 56 | WorkerProgress &operator++() { 57 | if (++current_ >= next_) { 58 | parent_->Milestone(*this); 59 | } 60 | return *this; 61 | } 62 | 63 | WorkerProgress &operator+=(uint64_t amount) { 64 | current_ += amount; 65 | if (current_ >= next_) { 66 | parent_->Milestone(*this); 67 | } 68 | return *this; 69 | } 70 | 71 | private: 72 | friend class MultiProgress; 73 | WorkerProgress(uint64_t next, MultiProgress &parent, char character) 74 | : current_(0), next_(next), parent_(&parent), stone_(0), character_(character) {} 75 | 76 | uint64_t current_, next_; 77 | 78 | MultiProgress *parent_; 79 | 80 | // Previous milestone reached. 81 | unsigned char stone_; 82 | 83 | // Character to display in bar. 
84 | char character_; 85 | }; 86 | 87 | }} // namespaces 88 | 89 | #endif // UTIL_STREAM_MULTI_PROGRESS_H 90 | -------------------------------------------------------------------------------- /util/stream/rewindable_stream_test.cc: -------------------------------------------------------------------------------- 1 | #include "io.hh" 2 | 3 | #include "rewindable_stream.hh" 4 | #include "../file.hh" 5 | 6 | #define BOOST_TEST_MODULE RewindableStreamTest 7 | #include 8 | 9 | namespace util { 10 | namespace stream { 11 | namespace { 12 | 13 | BOOST_AUTO_TEST_CASE(RewindableStreamTest) { 14 | scoped_fd in(MakeTemp("io_test_temp")); 15 | for (uint64_t i = 0; i < 100000; ++i) { 16 | WriteOrThrow(in.get(), &i, sizeof(uint64_t)); 17 | } 18 | SeekOrThrow(in.get(), 0); 19 | 20 | ChainConfig config; 21 | config.entry_size = 8; 22 | config.total_memory = 100; 23 | config.block_count = 6; 24 | 25 | Chain chain(config); 26 | RewindableStream s; 27 | chain >> Read(in.get()) >> s >> kRecycle; 28 | uint64_t i = 0; 29 | for (; s; ++s, ++i) { 30 | BOOST_CHECK_EQUAL(i, *static_cast(s.Get())); 31 | if (100000UL - i == 2) 32 | s.Mark(); 33 | } 34 | BOOST_CHECK_EQUAL(100000ULL, i); 35 | s.Rewind(); 36 | BOOST_CHECK_EQUAL(100000ULL - 2, *static_cast(s.Get())); 37 | } 38 | 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /util/stream/sort_test.cc: -------------------------------------------------------------------------------- 1 | #include "sort.hh" 2 | 3 | #define BOOST_TEST_MODULE SortTest 4 | #include 5 | 6 | #include 7 | 8 | #include 9 | 10 | namespace util { namespace stream { namespace { 11 | 12 | struct CompareUInt64 : public std::binary_function { 13 | bool operator()(const void *first, const void *second) const { 14 | return *static_cast(first) < *reinterpret_cast(second); 15 | } 16 | }; 17 | 18 | const uint64_t kSize = 100000; 19 | 20 | struct Putter { 21 | Putter(std::vector &shuffled) : shuffled_(shuffled) {} 22 | 23 | 
void Run(const ChainPosition &position) { 24 | Stream put_shuffled(position); 25 | for (uint64_t i = 0; i < shuffled_.size(); ++i, ++put_shuffled) { 26 | *static_cast(put_shuffled.Get()) = shuffled_[i]; 27 | } 28 | put_shuffled.Poison(); 29 | } 30 | std::vector &shuffled_; 31 | }; 32 | 33 | BOOST_AUTO_TEST_CASE(FromShuffled) { 34 | std::vector shuffled; 35 | shuffled.reserve(kSize); 36 | for (uint64_t i = 0; i < kSize; ++i) { 37 | shuffled.push_back(i); 38 | } 39 | std::random_shuffle(shuffled.begin(), shuffled.end()); 40 | 41 | ChainConfig config; 42 | config.entry_size = 8; 43 | config.total_memory = 800; 44 | config.block_count = 3; 45 | 46 | SortConfig merge_config; 47 | merge_config.temp_prefix = "sort_test_temp"; 48 | merge_config.buffer_size = 800; 49 | merge_config.total_memory = 3300; 50 | 51 | Chain chain(config); 52 | chain >> Putter(shuffled); 53 | BlockingSort(chain, merge_config, CompareUInt64(), NeverCombine()); 54 | Stream sorted; 55 | chain >> sorted >> kRecycle; 56 | for (uint64_t i = 0; i < kSize; ++i, ++sorted) { 57 | BOOST_CHECK_EQUAL(i, *static_cast(sorted.Get())); 58 | } 59 | BOOST_CHECK(!sorted); 60 | } 61 | 62 | }}} // namespaces 63 | -------------------------------------------------------------------------------- /util/stream/stream.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_STREAM_H 2 | #define UTIL_STREAM_STREAM_H 3 | 4 | #include "chain.hh" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace util { 12 | namespace stream { 13 | 14 | class Stream : boost::noncopyable { 15 | public: 16 | Stream() : current_(NULL), end_(NULL) {} 17 | 18 | void Init(const ChainPosition &position) { 19 | entry_size_ = position.GetChain().EntrySize(); 20 | block_size_ = position.GetChain().BlockSize(); 21 | block_it_.Init(position); 22 | StartBlock(); 23 | } 24 | 25 | explicit Stream(const ChainPosition &position) { 26 | Init(position); 27 | } 28 | 29 | operator bool() const { 
return current_ != NULL; } 30 | bool operator!() const { return current_ == NULL; } 31 | 32 | const void *Get() const { return current_; } 33 | void *Get() { return current_; } 34 | 35 | void Poison() { 36 | block_it_->SetValidSize(current_ - static_cast(block_it_->Get())); 37 | ++block_it_; 38 | block_it_.Poison(); 39 | } 40 | 41 | Stream &operator++() { 42 | assert(*this); 43 | assert(current_ < end_); 44 | current_ += entry_size_; 45 | if (current_ == end_) { 46 | ++block_it_; 47 | StartBlock(); 48 | } 49 | return *this; 50 | } 51 | 52 | private: 53 | void StartBlock() { 54 | for (; block_it_ && !block_it_->ValidSize(); ++block_it_) {} 55 | current_ = static_cast(block_it_->Get()); 56 | end_ = current_ + block_it_->ValidSize(); 57 | } 58 | 59 | // The following are pointers to raw memory 60 | // current_ is the current record 61 | // end_ is the end of the block (so we know when to move to the next block) 62 | uint8_t *current_, *end_; 63 | 64 | std::size_t entry_size_; 65 | std::size_t block_size_; 66 | 67 | Link block_it_; 68 | }; 69 | 70 | inline Chain &operator>>(Chain &chain, Stream &stream) { 71 | stream.Init(chain.Add()); 72 | return chain; 73 | } 74 | 75 | } // namespace stream 76 | } // namespace util 77 | #endif // UTIL_STREAM_STREAM_H 78 | -------------------------------------------------------------------------------- /util/stream/stream_test.cc: -------------------------------------------------------------------------------- 1 | #include "io.hh" 2 | 3 | #include "stream.hh" 4 | #include "../file.hh" 5 | 6 | #define BOOST_TEST_MODULE StreamTest 7 | #include 8 | 9 | #include 10 | 11 | namespace util { namespace stream { namespace { 12 | 13 | BOOST_AUTO_TEST_CASE(StreamTest) { 14 | scoped_fd in(MakeTemp("io_test_temp")); 15 | for (uint64_t i = 0; i < 100000; ++i) { 16 | WriteOrThrow(in.get(), &i, sizeof(uint64_t)); 17 | } 18 | SeekOrThrow(in.get(), 0); 19 | 20 | ChainConfig config; 21 | config.entry_size = 8; 22 | config.total_memory = 100; 23 | 
config.block_count = 12; 24 | 25 | Stream s; 26 | Chain chain(config); 27 | chain >> Read(in.get()) >> s >> kRecycle; 28 | uint64_t i = 0; 29 | for (; s; ++s, ++i) { 30 | BOOST_CHECK_EQUAL(i, *static_cast(s.Get())); 31 | } 32 | BOOST_CHECK_EQUAL(100000ULL, i); 33 | } 34 | 35 | }}} // namespaces 36 | -------------------------------------------------------------------------------- /util/stream/typed_stream.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STREAM_TYPED_STREAM_H 2 | #define UTIL_STREAM_TYPED_STREAM_H 3 | // A typed wrapper to Stream for POD types. 4 | 5 | #include "stream.hh" 6 | 7 | namespace util { namespace stream { 8 | 9 | template class TypedStream : public Stream { 10 | public: 11 | // After using the default constructor, call Init (in the parent class) 12 | TypedStream() {} 13 | 14 | explicit TypedStream(const ChainPosition &position) : Stream(position) {} 15 | 16 | const T *operator->() const { return static_cast(Get()); } 17 | T *operator->() { return static_cast(Get()); } 18 | 19 | const T &operator*() const { return *static_cast(Get()); } 20 | T &operator*() { return *static_cast(Get()); } 21 | }; 22 | 23 | }} // namespaces 24 | 25 | #endif // UTIL_STREAM_TYPED_STREAM_H 26 | -------------------------------------------------------------------------------- /util/string_piece_hash.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STRING_PIECE_HASH_H 2 | #define UTIL_STRING_PIECE_HASH_H 3 | 4 | #include 5 | #include "have.hh" 6 | #include "string_piece.hh" 7 | 8 | #include 9 | #include 10 | 11 | #ifdef HAVE_ICU 12 | U_NAMESPACE_BEGIN 13 | #endif 14 | 15 | inline size_t hash_value(const StringPiece &str) { 16 | return boost::hash_range(str.data(), str.data() + str.length()); 17 | } 18 | 19 | #ifdef HAVE_ICU 20 | U_NAMESPACE_END 21 | #endif 22 | 23 | /* Support for lookup of StringPiece in boost::unordered_map */ 24 | struct 
StringPieceCompatibleHash : public std::unary_function { 25 | size_t operator()(const StringPiece &str) const { 26 | return hash_value(str); 27 | } 28 | }; 29 | 30 | struct StringPieceCompatibleEquals : public std::binary_function { 31 | bool operator()(const StringPiece &first, const StringPiece &second) const { 32 | return first == second; 33 | } 34 | }; 35 | template typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { 36 | #if BOOST_VERSION < 104200 37 | std::string temp(key.data(), key.size()); 38 | return t.find(temp); 39 | #else 40 | return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); 41 | #endif 42 | } 43 | 44 | template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { 45 | #if BOOST_VERSION < 104200 46 | std::string temp(key.data(), key.size()); 47 | return t.find(temp); 48 | #else 49 | return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); 50 | #endif 51 | } 52 | 53 | #endif // UTIL_STRING_PIECE_HASH_H 54 | -------------------------------------------------------------------------------- /util/string_stream.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_STRING_STREAM_H 2 | #define UTIL_STRING_STREAM_H 3 | 4 | #include "fake_ostream.hh" 5 | 6 | #include 7 | #include 8 | 9 | namespace util { 10 | 11 | class StringStream : public FakeOStream { 12 | public: 13 | StringStream() {} 14 | 15 | StringStream &flush() { return *this; } 16 | 17 | StringStream &write(const void *data, std::size_t length) { 18 | out_.append(static_cast(data), length); 19 | return *this; 20 | } 21 | 22 | const std::string &str() const { return out_; } 23 | 24 | void str(const std::string &val) { out_ = val; } 25 | 26 | void swap(std::string &str) { std::swap(out_, str); } 27 | 28 | protected: 29 | friend class FakeOStream; 30 | char *Ensure(std::size_t amount) { 31 | std::size_t current = out_.size(); 32 | out_.resize(out_.size() + 
amount); 33 | return &out_[current]; 34 | } 35 | 36 | void AdvanceTo(char *to) { 37 | assert(to <= &*out_.end()); 38 | assert(to >= &*out_.begin()); 39 | out_.resize(to - &*out_.begin()); 40 | } 41 | 42 | private: 43 | std::string out_; 44 | }; 45 | 46 | } // namespace 47 | 48 | #endif // UTIL_STRING_STREAM_H 49 | -------------------------------------------------------------------------------- /util/string_stream_test.cc: -------------------------------------------------------------------------------- 1 | #define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE 2 | #define BOOST_TEST_MODULE FakeOStreamTest 3 | 4 | #include "string_stream.hh" 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace util { namespace { 12 | 13 | template void TestEqual(const T value) { 14 | StringStream strme; 15 | strme << value; 16 | BOOST_CHECK_EQUAL(boost::lexical_cast(value), strme.str()); 17 | } 18 | 19 | template void TestCorners() { 20 | TestEqual(std::numeric_limits::max()); 21 | TestEqual(std::numeric_limits::min()); 22 | TestEqual(static_cast(0)); 23 | TestEqual(static_cast(-1)); 24 | TestEqual(static_cast(1)); 25 | } 26 | 27 | BOOST_AUTO_TEST_CASE(Integer) { 28 | TestCorners(); 29 | TestCorners(); 30 | TestCorners(); 31 | 32 | TestCorners(); 33 | TestCorners(); 34 | TestCorners(); 35 | 36 | TestCorners(); 37 | TestCorners(); 38 | TestCorners(); 39 | 40 | TestCorners(); 41 | TestCorners(); 42 | TestCorners(); 43 | 44 | TestCorners(); 45 | TestCorners(); 46 | TestCorners(); 47 | 48 | TestCorners(); 49 | } 50 | 51 | enum TinyEnum { EnumValue }; 52 | 53 | BOOST_AUTO_TEST_CASE(EnumCase) { 54 | TestEqual(EnumValue); 55 | } 56 | 57 | BOOST_AUTO_TEST_CASE(Strings) { 58 | TestEqual("foo"); 59 | const char *a = "bar"; 60 | TestEqual(a); 61 | StringPiece piece("abcdef"); 62 | TestEqual(piece); 63 | TestEqual(StringPiece()); 64 | 65 | char non_const[3]; 66 | non_const[0] = 'b'; 67 | non_const[1] = 'c'; 68 | non_const[2] = 0; 69 | 70 | StringStream out; 71 | out << "a" << 
non_const << 'c'; 72 | BOOST_CHECK_EQUAL("abcc", out.str()); 73 | 74 | // Now test as a separate object. 75 | StringStream stream; 76 | stream << "a" << non_const << 'c' << piece; 77 | BOOST_CHECK_EQUAL("abccabcdef", stream.str()); 78 | } 79 | 80 | }} // namespaces 81 | -------------------------------------------------------------------------------- /util/tokenize_piece_test.cc: -------------------------------------------------------------------------------- 1 | #include "tokenize_piece.hh" 2 | #include "string_piece.hh" 3 | 4 | #define BOOST_TEST_MODULE TokenIteratorTest 5 | #include 6 | 7 | #include 8 | 9 | namespace util { 10 | namespace { 11 | 12 | BOOST_AUTO_TEST_CASE(pipe_pipe_none) { 13 | const char str[] = "nodelimit at all"; 14 | TokenIter it(str, MultiCharacter("|||")); 15 | BOOST_REQUIRE(it); 16 | BOOST_CHECK_EQUAL(StringPiece(str), *it); 17 | ++it; 18 | BOOST_CHECK(!it); 19 | } 20 | BOOST_AUTO_TEST_CASE(pipe_pipe_two) { 21 | const char str[] = "|||"; 22 | TokenIter it(str, MultiCharacter("|||")); 23 | BOOST_REQUIRE(it); 24 | BOOST_CHECK_EQUAL(StringPiece(), *it); 25 | ++it; 26 | BOOST_REQUIRE(it); 27 | BOOST_CHECK_EQUAL(StringPiece(), *it); 28 | ++it; 29 | BOOST_CHECK(!it); 30 | } 31 | 32 | BOOST_AUTO_TEST_CASE(remove_empty) { 33 | const char str[] = "|||"; 34 | TokenIter it(str, MultiCharacter("|||")); 35 | BOOST_CHECK(!it); 36 | } 37 | 38 | BOOST_AUTO_TEST_CASE(remove_empty_keep) { 39 | const char str[] = " |||"; 40 | TokenIter it(str, MultiCharacter("|||")); 41 | BOOST_REQUIRE(it); 42 | BOOST_CHECK_EQUAL(StringPiece(" "), *it); 43 | ++it; 44 | BOOST_CHECK(!it); 45 | } 46 | 47 | } // namespace 48 | } // namespace util 49 | -------------------------------------------------------------------------------- /util/usage.hh: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_USAGE_H 2 | #define UTIL_USAGE_H 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace util { 9 | // Time in seconds since 
process started. Zero on unsupported platforms. 10 | double WallTime(); 11 | 12 | // User + system time, process-wide. 13 | double CPUTime(); 14 | 15 | // User + system time, thread-specific. 16 | double ThreadTime(); 17 | 18 | // Resident usage in bytes. 19 | uint64_t RSSMax(); 20 | 21 | void PrintUsage(std::ostream &to); 22 | 23 | // Determine how much physical memory there is. Return 0 on failure. 24 | uint64_t GuessPhysicalMemory(); 25 | 26 | // Parse a size like unix sort. Sadly, this means the default multiplier is K. 27 | uint64_t ParseSize(const std::string &arg); 28 | 29 | } // namespace util 30 | #endif // UTIL_USAGE_H 31 | --------------------------------------------------------------------------------