├── .clang-format ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── README.md ├── extern ├── catch.hpp ├── mew.h └── tblr.h ├── koan.cpp ├── koan.png ├── koan ├── cli.h ├── def.h ├── indexmap.h ├── reader.h ├── sample.h ├── sigmoid.h ├── timer.h ├── trainer.h └── util.h ├── tests ├── test_gradcheck.cpp └── test_utils.cpp ├── word2vec_train_times_cbow.png └── word2vec_train_times_sg.png /.clang-format: -------------------------------------------------------------------------------- 1 | Language: Cpp 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: Align 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlines: Right 7 | AlignOperands: true 8 | AlignTrailingComments: true 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: true 11 | AllowShortCaseLabelsOnASingleLine: true 12 | AllowShortFunctionsOnASingleLine: Inline 13 | AllowShortIfStatementsOnASingleLine: true 14 | AllowShortLoopsOnASingleLine: true 15 | AlwaysBreakAfterDefinitionReturnType: None 16 | AlwaysBreakAfterReturnType: None 17 | AlwaysBreakBeforeMultilineStrings: false 18 | AlwaysBreakTemplateDeclarations: true 19 | BinPackArguments: false 20 | BinPackParameters: false 21 | BraceWrapping: 22 | AfterClass: false 23 | AfterControlStatement: false 24 | AfterEnum: false 25 | AfterFunction: false 26 | AfterNamespace: false 27 | AfterObjCDeclaration: false 28 | AfterStruct: false 29 | AfterUnion: false 30 | BeforeCatch: false 31 | BeforeElse: false 32 | IndentBraces: false 33 | SplitEmptyFunction: false 34 | SplitEmptyRecord: false 35 | SplitEmptyNamespace: false 36 | BreakBeforeBinaryOperators: None 37 | BreakBeforeBraces: Custom 38 | BreakBeforeInheritanceComma: false 39 | BreakBeforeTernaryOperators: true 40 | BreakConstructorInitializersBeforeComma: false 41 | BreakConstructorInitializers: BeforeColon 42 | BreakAfterJavaFieldAnnotations: false 43 | BreakStringLiterals: false 44 | ColumnLimit: 80 45 | 
CommentPragmas: '' 46 | CompactNamespaces: false 47 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 48 | ConstructorInitializerIndentWidth: 4 49 | ContinuationIndentWidth: 4 50 | Cpp11BracedListStyle: true 51 | DerivePointerAlignment: false 52 | DisableFormat: false 53 | ExperimentalAutoDetectBinPacking: false 54 | FixNamespaceComments: true 55 | IncludeCategories: 56 | - Regex: 'koan' 57 | Priority: 1 58 | - Regex: '\.' 59 | Priority: 2 60 | IndentCaseLabels: false 61 | IndentWidth: 2 62 | IndentWrappedFunctionNames: false 63 | KeepEmptyLinesAtTheStartOfBlocks: true 64 | MaxEmptyLinesToKeep: 1 65 | NamespaceIndentation: None 66 | PenaltyBreakAssignment: 2 67 | PenaltyBreakBeforeFirstCallParameter: 19 68 | PenaltyBreakComment: 300 69 | PenaltyBreakFirstLessLess: 120 70 | PenaltyBreakString: 1000 71 | PenaltyExcessCharacter: 1000000 72 | PenaltyReturnTypeOnItsOwnLine: 60 73 | PointerAlignment: Left 74 | ReflowComments: true 75 | SortIncludes: true 76 | SortUsingDeclarations: true 77 | SpaceAfterCStyleCast: false 78 | SpaceAfterTemplateKeyword: true 79 | SpaceBeforeAssignmentOperators: true 80 | SpaceBeforeParens: ControlStatements 81 | SpaceInEmptyParentheses: false 82 | SpacesBeforeTrailingComments: 1 83 | SpacesInAngles: false 84 | SpacesInContainerLiterals: true 85 | SpacesInCStyleCastParentheses: false 86 | SpacesInParentheses: false 87 | SpacesInSquareBrackets: false 88 | Standard: Cpp11 89 | TabWidth: 2 90 | UseTab: Never 91 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "eigen"] 2 | path = eigen 3 | url = https://gitlab.com/libeigen/eigen 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(koan) 4 | 5 | 
set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED True) 7 | 8 | add_executable(koan koan.cpp) 9 | add_executable(test_utils tests/test_utils.cpp) 10 | add_executable(test_gradcheck tests/test_gradcheck.cpp) 11 | 12 | include_directories("${PROJECT_SOURCE_DIR}/") 13 | include_directories("${PROJECT_SOURCE_DIR}/eigen/") 14 | 15 | target_include_directories(test_utils PUBLIC "${PROJECT_SOURCE_DIR}/extern") 16 | target_include_directories(test_gradcheck PUBLIC "${PROJECT_SOURCE_DIR}/extern") 17 | 18 | add_compile_options(-Wall -Wextra -Werror) 19 | 20 | if(KOAN_ENABLE_ZIP) 21 | target_compile_options(koan PUBLIC -Ofast -march=native -mtune=native -DKOAN_ENABLE_ZIP) 22 | else() 23 | target_compile_options(koan PUBLIC -Ofast -march=native -mtune=native) 24 | endif() 25 | 26 | set(THREADS_PREFER_PTHREAD_FLAG ON) 27 | find_package(Threads REQUIRED) 28 | 29 | if(KOAN_ENABLE_ZIP) 30 | find_package(ZLIB REQUIRED) 31 | target_link_libraries(koan PRIVATE Threads::Threads ZLIB::ZLIB) 32 | else() 33 | target_link_libraries(koan PRIVATE Threads::Threads) 34 | endif() 35 | 36 | 37 | install(TARGETS koan DESTINATION bin) 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2020 Bloomberg Finance L.P. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX = g++ 2 | 3 | BUILD_PATH = ./build 4 | EIGEN_INCLUDE = ./eigen/ 5 | 6 | INCLUDES = -I$(EIGEN_INCLUDE) -I./ 7 | 8 | CXXFLAGS = -std=c++17 -Wall -Wextra -pthread 9 | 10 | ZIPFLAGS = -lz -DKOAN_ENABLE_ZIP 11 | 12 | OPTFLAGS = -Ofast -march=native -mtune=native 13 | DEBUGFLAGS = -g -O0 14 | 15 | build_path: 16 | @mkdir -p $(BUILD_PATH) 17 | 18 | % : %.cpp build_path 19 | $(CXX) $< $(CXXFLAGS) ${ZIPFLAGS} $(OPTFLAGS) $(INCLUDES) -o $(BUILD_PATH)/$@ 20 | 21 | debug : koan.cpp build_path 22 | $(CXX) $< $(CXXFLAGS) ${ZIPFLAGS} $(DEBUGFLAGS) $(INCLUDES) -o $(BUILD_PATH)/koan 23 | 24 | test_utils : tests/test_utils.cpp build_path 25 | $(CXX) $< $(CXXFLAGS) ${ZIPFLAGS} $(DEBUGFLAGS) $(INCLUDES) -I./extern/ -o $(BUILD_PATH)/test_utils 26 | 27 | test_gradcheck : tests/test_gradcheck.cpp build_path 28 | $(CXX) $< $(CXXFLAGS) ${ZIPFLAGS} $(DEBUGFLAGS) $(INCLUDES) -I./extern/ -o $(BUILD_PATH)/test_gradcheck 29 | 30 | all: koan test_utils test_gradcheck 31 | 32 
| clean: 33 | rm -rf $(BUILD_PATH) 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | > ... the Zen attitude is that words and truth are incompatible, or at least that no words can capture truth. 4 | > 5 | > Douglas R. Hofstadter 6 | 7 | A word2vec negative sampling implementation with correct CBOW update. kōan only depends on Eigen. 8 | 9 | _Authors_: Ozan İrsoy, Adrian Benton, Karl Stratos 10 | 11 | Thanks to Cyril Khazan for helping kōan better scale to many threads. 12 | 13 | ## Menu 14 | 15 | - [Rationale](#rationale) 16 | - [Building](#building) 17 | - [Quick start](#quick-start) 18 | - [Installation](#installation) 19 | - [License](#license) 20 | 21 | ## Rationale 22 | 23 | Although continuous bag of words (CBOW) embeddings can be trained more quickly than skipgram (SG) embeddings, it is a common belief that SG embeddings tend to perform better in practice. This was observed by the original authors of Word2Vec [1] and also in subsequent work [2]. However, we found that popular implementations of word2vec with negative sampling such as [word2vec](https://github.com/tmikolov/word2vec/) and [gensim](https://github.com/RaRe-Technologies/gensim/) do not implement the CBOW update correctly, thus potentially leading to misconceptions about the performance of CBOW embeddings when trained correctly. 24 | 25 | We release kōan so that others can efficiently train CBOW embeddings using the corrected weight update. See this [technical report](https://arxiv.org/abs/2012.15332) for benchmarks of kōan vs. gensim word2vec negative sampling implementations. If you use kōan to learn word embeddings for your own work, please cite: 26 | 27 | > Ozan İrsoy, Adrian Benton, and Karl Stratos. "Corrected CBOW Performs as well as Skip-gram." The 2nd Workshop on Insights from Negative Results in NLP (__2021__). 28 | 29 | [1] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems, pages 3111–3119, 2013. 
30 | 31 | [2] Karl Stratos, Michael Collins, and Daniel Hsu. Model-based word embeddings from decompositions of count matrices. In Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing 32 | (Volume 1: Long Papers), pages 1282–1291, 2015. 33 | 34 | See [here](https://doi.org/10.5281/zenodo.5542319) for kōan embeddings trained on the English cleaned common crawl corpus (C4). 35 | 36 | ## Building 37 | 38 | You need a C++17 supporting compiler to build koan (tested with g++ 7.5.0, 8.4.0, 9.3.0, and clang 11.0.3). 39 | 40 | To build koan and all tests: 41 | ``` 42 | mkdir build 43 | cd build 44 | cmake .. 45 | cmake --build ./ 46 | ``` 47 | 48 | Run tests with (assuming you are still under `build`): 49 | ``` 50 | ./test_gradcheck 51 | ./test_utils 52 | ``` 53 | 54 | ## Installation 55 | 56 | Installation is as simple as placing the koan binary on your `PATH` 57 | (you might need sudo): 58 | 59 | ``` 60 | cmake --install ./ 61 | ``` 62 | 63 | ## Quick Start 64 | 65 | To train word embeddings on [Wikitext-2](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/), first clone and build koan: 66 | 67 | ``` 68 | git clone --recursive git@github.com:bloomberg/koan.git 69 | cd koan 70 | mkdir build 71 | cd build 72 | cmake .. && cmake --build ./ 73 | cd .. 
74 | ``` 75 | 76 | Download and unzip the Wikitext-2 corpus: 77 | 78 | ``` 79 | curl https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip --output wikitext-2-v1.zip 80 | unzip wikitext-2-v1.zip 81 | head -n 5 ./wikitext-2/wiki.train.tokens 82 | ``` 83 | 84 | And learn CBOW embeddings on the training fold with: 85 | 86 | ``` 87 | ./build/koan -V 2000000 \ 88 | --epochs 10 \ 89 | --dim 300 \ 90 | --negatives 5 \ 91 | --context-size 5 \ 92 | -l 0.075 \ 93 | --threads 16 \ 94 | --cbow true \ 95 | --min-count 2 \ 96 | --file ./wikitext-2/wiki.train.tokens 97 | ``` 98 | 99 | or skipgram embeddings by running with `--cbow false`. `./build/koan --help` for a full list of command-line arguments and descriptions. Learned embeddings will be saved to `embeddings_${CURRENT_TIMESTAMP}.txt` in the present working directory. 100 | 101 | ## License 102 | 103 | Please read the [LICENSE](LICENSE) file. 104 | 105 | ## Benchmarks 106 | 107 |

108 | 109 | See the [report](https://arxiv.org/abs/2012.15332) for more details. 110 | -------------------------------------------------------------------------------- /extern/mew.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef MEW_H 18 | #define MEW_H 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | namespace mew { 31 | 32 | using Strings = std::vector; 33 | using StringsList = std::vector; 34 | 35 | // double precision seconds 36 | using Duration = std::chrono::duration>; 37 | 38 | enum AnimationStyle : unsigned short { 39 | Ellipsis, 40 | Clock, 41 | Moon, 42 | Earth, 43 | Bar, 44 | Square, 45 | }; 46 | 47 | const static StringsList animation_stills_{ 48 | {". ", ".. 
", "..."}, 49 | {"🕐", "🕜", "🕑", "🕝", "🕒", "🕞", "🕓", "🕟", "🕔", "🕠", "🕕", "🕡", 50 | "🕖", "🕢", "🕗", "🕣", "🕘", "🕤", "🕙", "🕥", "🕚", "🕦", "🕛", "🕧"}, 51 | {"🌕", "🌖", "🌗", "🌘", "🌑", "🌒", "🌓", "🌔"}, 52 | {"🌎", "🌍", "🌏"}, 53 | {"-", "/", "|", "\\"}, 54 | {"▖", "▘", "▝", "▗"}, 55 | }; 56 | 57 | enum ProgressBarStyle : unsigned short { Bars, Blocks, Arrow }; 58 | 59 | const static StringsList progress_partials_{ 60 | {"|"}, 61 | {"▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"}, 62 | {">", "="}, 63 | }; 64 | 65 | enum class Speed : unsigned short { None, Last, Overall, Both }; 66 | 67 | template 68 | class AsyncDisplay { 69 | private: 70 | Duration period_; 71 | std::unique_ptr displayer_; 72 | std::condition_variable completion_; 73 | std::mutex completion_m_; 74 | bool complete_ = false; 75 | 76 | std::string message_; 77 | std::ostream& out_; 78 | 79 | // Render animation, progress bar, etc. Needs to be specialized. 80 | void render_(std::ostream& out) { static_cast(*this).render_(out); } 81 | 82 | // Display message (maybe with animation, progress bar, etc) 83 | void display_() { 84 | out_ << "\r"; 85 | render_(out_); 86 | out_ << std::flush; 87 | } 88 | 89 | protected: 90 | void render_message_(std::ostream& out) { 91 | if (not message_.empty()) { out << message_ << " "; } 92 | } 93 | 94 | public: 95 | AsyncDisplay(std::string message = "", 96 | double period = 1, 97 | std::ostream& out = std::cout) 98 | : period_(period), message_(std::move(message)), out_(out) {} 99 | 100 | AsyncDisplay(std::string message, 101 | Duration period, 102 | std::ostream& out = std::cout) 103 | : period_(period), message_(std::move(message)), out_(out) {} 104 | 105 | AsyncDisplay(const AsyncDisplay& other) 106 | : period_(other.period_), 107 | complete_(other.complete_), 108 | message_(other.message_), 109 | out_(other.out_) {} 110 | 111 | AsyncDisplay(AsyncDisplay&& other) 112 | : period_(std::move(other.period_)), 113 | complete_(std::move(other.complete_)), 114 | 
message_(std::move(other.message_)), 115 | out_(other.out_) {} 116 | 117 | void start() { 118 | displayer_ = std::make_unique([this]() { 119 | display_(); 120 | while (true) { 121 | std::unique_lock lock(completion_m_); 122 | completion_.wait_for(lock, period_); 123 | display_(); 124 | if (complete_) { break; } 125 | } 126 | }); 127 | } 128 | 129 | void done() { 130 | if (not displayer_) { return; } // already done() before; noop 131 | { 132 | std::lock_guard lock(completion_m_); 133 | complete_ = true; 134 | } 135 | completion_.notify_all(); 136 | displayer_->join(); 137 | displayer_.reset(); 138 | out_ << std::endl; 139 | } 140 | 141 | template 142 | friend class Composite; 143 | }; 144 | 145 | class Animation : public AsyncDisplay { 146 | private: 147 | unsigned short frame_ = 0; 148 | const Strings& stills_; 149 | 150 | void render_(std::ostream& out) { 151 | this->render_message_(out); 152 | out << stills_[frame_] << " "; 153 | frame_ = (frame_ + 1) % stills_.size(); 154 | } 155 | 156 | public: 157 | using Style = AnimationStyle; 158 | 159 | Animation(std::string message = "", 160 | Style style = Ellipsis, 161 | double period = 1, 162 | std::ostream& out = std::cout) 163 | : AsyncDisplay(message, period, out), 164 | stills_(animation_stills_[static_cast(style)]) {} 165 | 166 | Animation(const Animation&) = default; 167 | Animation(Animation&&) = default; 168 | 169 | friend class AsyncDisplay; 170 | template 171 | friend class Composite; 172 | }; 173 | 174 | template 175 | class Composite : public AsyncDisplay> { 176 | private: 177 | LeftDisplay left_; 178 | RightDisplay right_; 179 | 180 | void render_(std::ostream& out) { 181 | left_.render_(out); 182 | out << " "; 183 | right_.render_(out); 184 | } 185 | 186 | public: 187 | Composite(LeftDisplay left, RightDisplay right) 188 | : AsyncDisplay>(left.message_, 189 | left.period_, 190 | left.out_), 191 | left_(std::move(left)), 192 | right_(std::move(right)) {} 193 | 194 | friend class AsyncDisplay>; 195 | 
template 196 | friend class Composite; 197 | }; 198 | 199 | template 200 | auto operator|(LeftDisplay left, RightDisplay right) { 201 | return Composite(std::move(left), 202 | std::move(right)); 203 | } 204 | 205 | template 206 | struct ProgressTraits { 207 | using value_type = Progress; 208 | }; 209 | 210 | template 211 | struct ProgressTraits> { 212 | using value_type = Progress; 213 | }; 214 | 215 | template 216 | class Speedometer { 217 | private: 218 | Progress& progress_; // Current amount of work done 219 | Speed speed_; // Time period to compute speed over 220 | std::string unit_of_speed_; // unit (message) to display alongside speed 221 | 222 | using Clock = std::chrono::system_clock; 223 | using Time = std::chrono::time_point; 224 | 225 | Time start_time_, last_start_time_; 226 | typename ProgressTraits::value_type last_progress_{0}; 227 | 228 | public: 229 | void render_speed(std::ostream& out) { 230 | if (speed_ != Speed::None) { 231 | std::stringstream ss; // use local stream to avoid disturbing `out` with 232 | // std::fixed and std::setprecision 233 | Duration dur = (Clock::now() - start_time_); 234 | Duration dur2 = (Clock::now() - last_start_time_); 235 | 236 | auto speed = double(progress_) / dur.count(); 237 | auto speed2 = double(progress_ - last_progress_) / dur2.count(); 238 | 239 | ss << std::fixed << std::setprecision(2) << "("; 240 | if (speed_ == Speed::Overall or speed_ == Speed::Both) { ss << speed; } 241 | if (speed_ == Speed::Both) { ss << " | "; } 242 | if (speed_ == Speed::Last or speed_ == Speed::Both) { ss << speed2; } 243 | ss << " " << unit_of_speed_ << ") "; 244 | 245 | out << ss.str(); 246 | 247 | last_progress_ = progress_; 248 | last_start_time_ = Clock::now(); 249 | } 250 | } 251 | 252 | void start() { start_time_ = Clock::now(); } 253 | 254 | Speedometer(Progress& progress, Speed speed, std::string unit_of_speed) 255 | : progress_(progress), 256 | speed_(speed), 257 | unit_of_speed_(std::move(unit_of_speed)) {} 258 | }; 259 
| 260 | template 261 | class CounterDisplay : public AsyncDisplay> { 262 | private: 263 | Progress& progress_; // current amount of work done 264 | Speedometer speedom_; 265 | 266 | void render_counts_(std::ostream& out) { 267 | std::stringstream ss; 268 | if (std::is_floating_point::value) { 269 | ss << std::fixed << std::setprecision(2); 270 | } 271 | ss << progress_ << " "; 272 | out << ss.str(); 273 | } 274 | 275 | private: 276 | void render_(std::ostream& out) { 277 | this->render_message_(out); 278 | render_counts_(out); 279 | speedom_.render_speed(out); 280 | } 281 | 282 | public: 283 | CounterDisplay(Progress& progress, 284 | std::string message = "", 285 | std::string unit_of_speed = "", 286 | Speed speed = Speed::None, 287 | double period = 0.1, 288 | std::ostream& out = std::cout) 289 | : AsyncDisplay>(std::move(message), period, out), 290 | progress_(progress), 291 | speedom_(progress, speed, std::move(unit_of_speed)) {} 292 | 293 | void start() { 294 | static_cast>&>(*this).start(); 295 | speedom_.start(); 296 | } 297 | 298 | friend class AsyncDisplay>; 299 | template 300 | friend class Composite; 301 | }; 302 | 303 | template 304 | auto Counter(Progress& progress, Args&&... 
args) { 305 | return CounterDisplay(progress, std::forward(args)...); 306 | } 307 | 308 | template 309 | class ProgressBarDisplay : public AsyncDisplay> { 310 | private: 311 | Speedometer speedom_; 312 | Progress& progress_; // work done so far 313 | const static size_t width_ = 30; // width of progress bar 314 | size_t total_; // total work to be done 315 | bool counts_; // whether to display counts 316 | 317 | const Strings& partials_; // progress bar display strings 318 | 319 | void render_progress_bar_(std::ostream& out) { 320 | size_t on = width_ * progress_ / total_; 321 | size_t partial = 322 | partials_.size() * width_ * progress_ / total_ - partials_.size() * on; 323 | if (on >= width_) { 324 | on = width_; 325 | partial = 0; 326 | } 327 | assert(partial != partials_.size()); 328 | size_t off = width_ - on - size_t(partial > 0); 329 | 330 | // draw progress bar 331 | out << "|"; 332 | for (size_t i = 0; i < on; i++) { out << partials_.back(); } 333 | if (partial > 0) { out << partials_.at(partial - 1); } 334 | out << std::string(off, ' ') << "| "; 335 | } 336 | 337 | void render_counts_(std::ostream& out) { 338 | if (counts_) { out << progress_ << "/" << total_ << " "; } 339 | } 340 | 341 | void render_percentage_(std::ostream& out) { 342 | std::stringstream ss; 343 | ss << std::fixed << std::setprecision(2); 344 | ss.width(6); 345 | ss << std::right << progress_ * 100. 
/ total_ << "% "; 346 | out << ss.str(); 347 | } 348 | 349 | void render_(std::ostream& out) { 350 | this->render_message_(out); 351 | render_percentage_(out); 352 | render_progress_bar_(out); 353 | render_counts_(out); 354 | speedom_.render_speed(out); 355 | } 356 | 357 | public: 358 | using Style = ProgressBarStyle; 359 | 360 | ProgressBarDisplay(Progress& progress, 361 | size_t total, 362 | std::string message = "", 363 | std::string unit_of_speed = "", 364 | Speed speed = Speed::None, 365 | bool counts = true, 366 | Style style = Blocks, 367 | double period = 0.1, 368 | std::ostream& out = std::cout) 369 | : AsyncDisplay>(std::move(message), 370 | period, 371 | out), 372 | speedom_(progress, speed, std::move(unit_of_speed)), 373 | progress_(progress), 374 | total_(total), 375 | counts_(counts), 376 | partials_(progress_partials_[static_cast(style)]) {} 377 | 378 | void start() { 379 | static_cast>&>(*this).start(); 380 | speedom_.start(); 381 | } 382 | 383 | friend class AsyncDisplay>; 384 | template 385 | friend class Composite; 386 | }; 387 | 388 | template 389 | auto ProgressBar(Progress& progress, Args&&... args) { 390 | return ProgressBarDisplay(progress, std::forward(args)...); 391 | } 392 | 393 | } // namespace mew 394 | 395 | #endif 396 | -------------------------------------------------------------------------------- /extern/tblr.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef TBLR_H 18 | #define TBLR_H 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | namespace tblr { 32 | 33 | enum Align : char { Left = 'l', Center = 'c', Right = 'r' }; 34 | enum LineSplitter { SingleLine, Naive, Space }; 35 | 36 | // typedefs 37 | using Row = std::vector; 38 | using Widths = std::vector; 39 | using Aligns = std::vector; 40 | 41 | // Class to end a row 42 | class Endr {}; 43 | const static Endr endr, endl; 44 | 45 | // Is first byte of a UTF8 character 46 | inline bool is_first_byte(const char& c) { 47 | // https://stackoverflow.com/a/4063229 48 | return (c & 0xc0) != 0x80; 49 | } 50 | 51 | // UTF8 length of a string 52 | inline size_t len(const std::string& s) { 53 | return std::count_if(s.begin(), s.end(), is_first_byte); 54 | } 55 | 56 | // UTF8 aware substring 57 | inline std::string 58 | substr(const std::string& s, size_t left = 0, size_t size = -1) { 59 | auto i = s.begin(); 60 | for (left++; i != s.end() and (left -= is_first_byte(*i)); i++) {} 61 | auto pos = i; 62 | for (size++; i != s.end() and (size -= is_first_byte(*i)); i++) {} 63 | return s.substr(pos - s.begin(), i - pos); 64 | } 65 | 66 | // Helper class to repeatedly use << to 67 | // construct a single table cell 68 | class Cell { 69 | private: 70 | std::stringstream ss_; 71 | 72 | public: 73 | template 74 | friend Cell& operator<<(Cell& c, const T& x) { 75 | c.ss_ << x; 76 | return c; 77 | } 78 | template 79 | friend Cell&& operator<<(Cell&& c, const T& x) { 80 | c.ss_ << x; 81 | return std::move(c); 82 | } 83 | std::string str() const { return ss_.str(); } 84 | }; 85 | 86 | // Delimiters for table layout 87 | struct ColSeparators { 88 | std::string left = ""; 89 | std::string mid = " "; 90 | std::string right = ""; 91 | }; 92 | 93 | class 
RowSeparator { 94 | public: 95 | virtual void print(std::ostream& out, 96 | const Widths& spec_widths, 97 | const Widths& widths, 98 | const Aligns& aligns) const = 0; 99 | virtual ~RowSeparator() {} 100 | }; 101 | 102 | // A row separator that does not align to columns (e.g. Latex's \hline) 103 | class RowSeparatorFlat : public RowSeparator { 104 | private: 105 | std::string sepr_; 106 | 107 | public: 108 | RowSeparatorFlat(std::string sepr = "") : sepr_(std::move(sepr)) {} 109 | 110 | void print(std::ostream& out, 111 | const Widths& /*spec_widths*/, 112 | const Widths& /*widths*/, 113 | const Aligns& /*aligns*/) const override { 114 | out << sepr_ << std::endl; 115 | } 116 | }; 117 | 118 | // Empty row separator 119 | class RowSeparatorEmpty : public RowSeparator { 120 | public: 121 | void print(std::ostream& /*out*/, 122 | const Widths& /*spec_widths*/, 123 | const Widths& /*widths*/, 124 | const Aligns& /*aligns*/) const override {} 125 | }; 126 | 127 | // A row separator that aligns to each cell/column (e.g. Markdown) 128 | class RowSeparatorColwise : public RowSeparator { 129 | private: 130 | ColSeparators col_sepr_; 131 | std::string filler_; 132 | 133 | public: 134 | RowSeparatorColwise(ColSeparators csep = {}, std::string fill = " ") 135 | : col_sepr_(std::move(csep)), filler_(std::move(fill)) { 136 | assert(not filler_.empty()); 137 | } 138 | 139 | void print(std::ostream& out, 140 | const Widths& spec_widths, 141 | const Widths& widths, 142 | const Aligns& /*aligns*/) const override { 143 | static auto extend = [](const std::string& s, const size_t width) { 144 | std::string rval; 145 | size_t lens = len(s); 146 | for (size_t _ = 0; _ < width / lens; _++) { rval += s; } 147 | rval += substr(s, 0, width % lens); 148 | return rval; 149 | }; 150 | 151 | out << col_sepr_.left; 152 | for (size_t i = 0; i < widths.size(); i++) { 153 | if (i > 0) { out << col_sepr_.mid; } 154 | size_t width = (i < spec_widths.size() and spec_widths[i] > 0) 155 | ? 
spec_widths[i] 156 | : widths[i]; 157 | out << extend(filler_, width); 158 | } 159 | out << col_sepr_.right << std::endl; 160 | } 161 | }; 162 | 163 | struct RowSeparators { 164 | std::shared_ptr top = std::make_shared(); 165 | std::shared_ptr header_mid = 166 | std::make_shared(); 167 | std::shared_ptr mid = std::make_shared(); 168 | std::shared_ptr bottom = std::make_shared(); 169 | }; 170 | 171 | struct Layout { 172 | ColSeparators col_sepr; 173 | RowSeparators row_sepr; 174 | }; 175 | 176 | // Main Table class 177 | class Table { 178 | public: 179 | using Row = std::vector; 180 | using Grid = std::vector; 181 | 182 | private: 183 | Grid data_; 184 | Row cur_row_; 185 | 186 | // Layout parameters and specs 187 | Widths spec_widths_; 188 | Aligns spec_aligns_; 189 | LineSplitter split_ = Naive; 190 | Layout layout_; 191 | 192 | // State 193 | Widths widths_; 194 | // bool printed_any_row_ = false; //TODO: to be used in online mode 195 | 196 | std::stringstream ss_; 197 | 198 | // Helpers 199 | static void aligned_print_(std::ostream& out, 200 | const std::string& s, 201 | size_t width, 202 | Align align); 203 | static std::string print_(std::ostream& out, 204 | const std::string& s, 205 | size_t width, 206 | Align align, 207 | LineSplitter ls); 208 | void print_row_(std::ostream& out, const Row& row) const; 209 | Row print_row_line_(std::ostream& out, const Row& row) const; 210 | 211 | public: 212 | Table& widths(Widths widths_) { 213 | spec_widths_ = std::move(widths_); 214 | return *this; 215 | } 216 | Table& aligns(Aligns aligns_) { 217 | spec_aligns_ = std::move(aligns_); 218 | return *this; 219 | } 220 | Table& multiline(LineSplitter mline) { 221 | split_ = std::move(mline); 222 | return *this; 223 | } 224 | Table& layout(Layout layout) { 225 | layout_ = std::move(layout); 226 | return *this; 227 | } 228 | Table& precision(const int n) { 229 | ss_ << std::setprecision(n); 230 | return *this; 231 | } 232 | Table& fixed() { 233 | ss_ << std::fixed; 234 | 
return *this; 235 | } 236 | 237 | template 238 | Table& operator<<(const T& x); 239 | void print(std::ostream& out = std::cout) const; 240 | }; 241 | 242 | template 243 | Table& Table::operator<<(const T& x) { 244 | // insert the value into the table as a string 245 | ss_ << x; 246 | cur_row_.push_back(ss_.str()); 247 | 248 | widths_.resize(std::max(widths_.size(), cur_row_.size()), 0); 249 | size_t& width = widths_[cur_row_.size() - 1]; 250 | for (std::string s; std::getline(ss_, s); width = std::max(width, len(s))) {} 251 | 252 | ss_.str(""); 253 | ss_.clear(); 254 | 255 | return *this; 256 | } 257 | 258 | template <> 259 | inline Table& Table::operator<<(const Endr&) { 260 | data_.push_back(std::move(cur_row_)); 261 | return *this; 262 | } 263 | 264 | template <> 265 | inline Table& Table::operator<<(const Cell& c) { 266 | return *this << c.str(); 267 | } 268 | 269 | // Preconditions: 270 | // - Single line (does not have \n in it) 271 | // - len(s) <= width 272 | inline void Table::aligned_print_(std::ostream& out, 273 | const std::string& s, 274 | size_t width, 275 | Align align) { 276 | size_t lens = len(s); 277 | assert(lens <= width and 278 | s.find('\n') == std::string::npos); // paranoid ¯\_(ツ)_/¯ 279 | 280 | if (align == Left) { 281 | out << s << std::string(width - lens, ' '); 282 | } else if (align == Center) { 283 | out << std::string((width - lens) / 2, ' ') << s 284 | << std::string((width - lens + 1) / 2, ' '); 285 | } else if (align == Right) { 286 | out << std::string(width - lens, ' ') << s; 287 | } 288 | } 289 | 290 | // print a string s in the given width and alignment, 291 | // return the remaining suffix string that did not fit 292 | inline std::string Table::print_(std::ostream& out, 293 | const std::string& s, 294 | size_t width, 295 | Align align, 296 | LineSplitter ls) { 297 | std::string head = s; 298 | std::string tail = ""; 299 | 300 | // split by '\n' 301 | size_t pos = s.find('\n'); 302 | if (pos != std::string::npos) { 303 | head = 
s.substr(0, pos); 304 | tail = s.substr(pos + 1); 305 | } 306 | 307 | // split by width 308 | if (len(head) > width) { 309 | head = substr(s, 0, width); 310 | tail = substr(s, width); 311 | if (ls == Space) { 312 | // split by space 313 | pos = head.rfind(' '); 314 | if (pos != std::string::npos) { 315 | head = s.substr(0, pos); 316 | tail = s.substr(pos + 1); 317 | } 318 | } 319 | } 320 | 321 | aligned_print_(out, head, width, align); 322 | return (ls == SingleLine) ? "" : tail; 323 | } 324 | 325 | inline Table::Row Table::print_row_line_(std::ostream& out, 326 | const Row& row) const { 327 | Row rval; 328 | out << layout_.col_sepr.left; 329 | for (size_t i = 0; i < row.size(); i++) { 330 | if (i > 0) { out << layout_.col_sepr.mid; } 331 | size_t width = (i < spec_widths_.size() and spec_widths_[i] > 0) 332 | ? spec_widths_[i] 333 | : widths_[i]; 334 | Align align = (i < spec_aligns_.size()) ? spec_aligns_[i] : Left; 335 | rval.push_back(print_(out, row[i], width, align, split_)); 336 | } 337 | out << layout_.col_sepr.right << std::endl; 338 | return rval; 339 | } 340 | 341 | inline void Table::print_row_(std::ostream& out, const Row& row) const { 342 | static auto empty = [](const Row& row) { 343 | return std::all_of( 344 | row.begin(), row.end(), std::mem_fn(&std::string::empty)); 345 | }; 346 | 347 | Row rval = row; 348 | while (not empty(rval = print_row_line_(out, rval))) {} 349 | } 350 | 351 | inline void Table::print(std::ostream& out) const { 352 | auto& row_sepr = layout_.row_sepr; 353 | row_sepr.top->print(out, spec_widths_, widths_, spec_aligns_); 354 | for (size_t i = 0; i < data_.size(); i++) { 355 | if (i == 1) { 356 | row_sepr.header_mid->print(out, spec_widths_, widths_, spec_aligns_); 357 | } else if (i > 1) { 358 | row_sepr.mid->print(out, spec_widths_, widths_, spec_aligns_); 359 | } 360 | print_row_(out, data_[i]); 361 | } 362 | row_sepr.bottom->print(out, spec_widths_, widths_, spec_aligns_); 363 | } 364 | 365 | inline std::ostream& 
operator<<(std::ostream& os, const Table& t) { 366 | t.print(os); 367 | return os; 368 | } 369 | 370 | // Predefined Layouts 371 | 372 | inline Layout simple_border(std::string left, 373 | std::string center, 374 | std::string right, 375 | std::string top, 376 | std::string header_mid, 377 | std::string mid, 378 | std::string bottom) { 379 | ColSeparators cs{std::move(left), std::move(center), std::move(right)}; 380 | RowSeparators rs{ 381 | std::make_shared(cs, std::move(top)), 382 | std::make_shared(cs, std::move(header_mid)), 383 | std::make_shared(cs, std::move(mid)), 384 | std::make_shared(cs, std::move(bottom))}; 385 | return {std::move(cs), std::move(rs)}; 386 | } 387 | 388 | inline Layout simple_border(std::string left, 389 | std::string center, 390 | std::string right, 391 | std::string header_mid) { 392 | ColSeparators cs{std::move(left), std::move(center), std::move(right)}; 393 | RowSeparators rs{ 394 | std::make_shared(), 395 | std::make_shared(cs, std::move(header_mid)), 396 | std::make_shared(), 397 | std::make_shared()}; 398 | return {std::move(cs), std::move(rs)}; 399 | } 400 | 401 | inline Layout 402 | simple_border(std::string left, std::string center, std::string right) { 403 | ColSeparators cs{std::move(left), std::move(center), std::move(right)}; 404 | RowSeparators rs{std::make_shared(), 405 | std::make_shared(), 406 | std::make_shared(), 407 | std::make_shared()}; 408 | return {std::move(cs), std::move(rs)}; 409 | } 410 | 411 | inline Layout markdown() { 412 | return simple_border("| ", " | ", " |", "-"); 413 | } 414 | 415 | inline Layout indented_list() { 416 | return simple_border(" ", " ", ""); 417 | } 418 | 419 | class LatexHeader : public RowSeparator { 420 | public: 421 | void print(std::ostream& out, 422 | const Widths& /*spec_widths*/, 423 | const Widths& /*widths*/, 424 | const Aligns& aligns) const override { 425 | out << R"(\begin{tabular}{)"; 426 | for (auto& a : aligns) { out << (char)a; } 427 | out << "}" << std::endl << 
R"(\hline)" << std::endl; 428 | } 429 | }; 430 | 431 | inline Layout latex() { 432 | ColSeparators cs{"", " & ", " \\\\"}; 433 | RowSeparators rs{ 434 | std::make_shared(), 435 | std::make_shared("\\hline"), 436 | std::make_shared(), 437 | std::make_shared("\\hline\n\\end{tabular}")}; 438 | return {std::move(cs), std::move(rs)}; 439 | } 440 | 441 | } // namespace tblr 442 | 443 | #endif 444 | -------------------------------------------------------------------------------- /koan.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include 31 | 32 | #include "extern/mew.h" 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | 42 | using namespace koan; 43 | 44 | auto build_vocab(const std::vector& fnames, 45 | const std::string& read_mode, 46 | bool enforce_max_line_length, 47 | bool no_progress) { 48 | std::unordered_map freqs; 49 | freqs.reserve(INITIAL_INDEX_SIZE); 50 | 51 | unsigned long long lines = 0; 52 | auto counter = 53 | mew::Counter(lines, "Building vocab", "lines/s", mew::Speed::Last, 1.); 54 | if (no_progress) { 55 | std::cout << "Building vocab..." << std::endl; 56 | } else { 57 | counter.start(); 58 | } 59 | 60 | Timer t; 61 | std::vector s; 62 | s.reserve(100); 63 | 64 | readlines( 65 | fnames, 66 | [&](const std::string_view& line) { 67 | s.clear(); 68 | split(s, line, ' '); 69 | for (auto& w : s) { freqs[w]++; } 70 | lines++; 71 | }, 72 | read_mode, 73 | enforce_max_line_length); 74 | 75 | if (not no_progress) { counter.done(); } 76 | std::cout << "Done in " << unsigned(t.s()) << "s." << std::endl; 77 | 78 | return std::make_tuple(freqs, lines); 79 | } 80 | 81 | void save_vocab_file( 82 | const std::string& vocab_load_path, 83 | const std::vector& ordered_vocab, 84 | const std::unordered_map& freqs) { 85 | std::cout << "Saving vocab file..." << std::endl; 86 | 87 | FILE* out = fopen(vocab_load_path.c_str(), "w"); 88 | KOAN_ASSERT(out); 89 | std::string buf; 90 | buf.reserve(MAX_LINE_LEN); 91 | for (auto& w : ordered_vocab) { 92 | buf.clear(); 93 | buf += w; 94 | buf += " "; 95 | buf += std::to_string(freqs.at(w)); 96 | buf += "\n"; 97 | fputs(buf.data(), out); 98 | } 99 | fclose(out); 100 | std::cout << "Done." 
<< std::endl; 101 | } 102 | 103 | auto load_vocab_file(const std::string& vocab_load_path) { 104 | std::vector ordered_vocab; 105 | std::unordered_map freqs; 106 | 107 | std::vector s; 108 | s.reserve(2); 109 | unsigned long long last = std::numeric_limits::max(); 110 | 111 | std::cout << "Loading vocab file " + vocab_load_path + " ..." << std::endl; 112 | readlines( 113 | vocab_load_path, 114 | [&](const std::string_view& line) { 115 | s.clear(); 116 | split(s, line, ' '); 117 | KOAN_ASSERT(s.size() == 2, 118 | "Unexpected number of columns in vocab file!"); 119 | auto& word = s[0]; 120 | auto freq = std::stoull(s[1]); 121 | if (word == UNKSTR) { 122 | KOAN_ASSERT(ordered_vocab.empty(), 123 | "Only the first line of vocab file can be UNKSTR!"); 124 | } else { 125 | KOAN_ASSERT(freq <= last, 126 | "Vocab file should be in descending frequency order (" 127 | "except for UNKSTR, which should be at the top if it " 128 | "exists)!"); 129 | last = freq; 130 | } 131 | ordered_vocab.push_back(word); 132 | freqs[word] = freq; 133 | }, 134 | "text", 135 | true); 136 | std::cout << "Done." << std::endl; 137 | 138 | return std::make_tuple(ordered_vocab, freqs); 139 | } 140 | 141 | auto load_pretrained_embeddings(const std::string& pretrained_path, 142 | const std::string& read_mode, 143 | unsigned dim, 144 | bool enforce_max_line_length, 145 | bool no_progress) { 146 | std::unordered_map pretrained_table; 147 | long unsigned lines = 0; 148 | 149 | auto counter = mew::Counter( 150 | lines, "Reading pretrained embeddings", "lines/s", mew::Speed::Last, 1.); 151 | if (no_progress) { 152 | std::cout << "Reading pretrained embeddings..." 
<< std::endl; 153 | } else { 154 | counter.start(); 155 | } 156 | 157 | std::vector s; 158 | s.reserve(100); 159 | 160 | readlines( 161 | pretrained_path, 162 | [&](const std::string_view& line) { 163 | s.clear(); 164 | split(s, line, ' '); 165 | KOAN_ASSERT(dim == (s.size() - 1), 166 | "Specified dimension doesn't match pretrained table!"); 167 | auto& word = s[0]; 168 | KOAN_ASSERT(pretrained_table.find(word) == pretrained_table.end(), 169 | "Pretrained table has duplicate entries!"); 170 | Vector v(dim); 171 | for (Vector::Index i = 0; i < v.size(); i++) { 172 | v[i] = std::stof(s[i + 1]); 173 | } 174 | pretrained_table.emplace(word, std::move(v)); 175 | lines++; 176 | }, 177 | read_mode, 178 | enforce_max_line_length); 179 | 180 | counter.done(); 181 | return pretrained_table; 182 | } 183 | 184 | int main(int argc, char** argv) { 185 | srand(123457); 186 | std::vector fnames; 187 | unsigned dim = 200; 188 | unsigned ctxs = 5; 189 | unsigned negatives = 5; 190 | unsigned num_threads = 1; 191 | unsigned epochs = 1; 192 | unsigned min_count = 1; 193 | bool discard = true; 194 | bool cbow = false; 195 | bool use_bad_update = false; 196 | Real downsample_th = 1e-3; 197 | Real init_lr = 0.025; // If cbow, initial learning rate 0.075 recommended. 
198 | Real min_lr = 1e-4; 199 | Real ns_exponent = 0.75; 200 | size_t vocab_size = std::numeric_limits::max(); 201 | std::string vocab_load_path = ""; 202 | unsigned long long total_sentences = 0; 203 | size_t buffer_size = 500'000; 204 | std::string embedding_path = ""; 205 | bool shuffle = false; 206 | bool no_progress = false; 207 | bool partitioned = false; 208 | bool enforce_max_line_length = false; 209 | 210 | std::string pretrained_path; 211 | std::string continue_vocab = "union"; 212 | std::string read_mode = "auto"; 213 | 214 | unsigned start_lr_schedule_epoch = 0; 215 | unsigned max_lr_schedule_epochs = 0; 216 | 217 | Args args; 218 | args.add(fnames, "f,files", "paths", "Paths to training files", Required); 219 | args.add(dim, "d,dim", "n", "Word vector dimension"); 220 | args.add(ctxs, 221 | "c,context-size", 222 | "n", 223 | "One sided context size, excluding the center word"); 224 | args.add(negatives, 225 | "n,negatives", 226 | "n", 227 | "Number of negative samples for each positive"); 228 | args.add(init_lr, 229 | "l,learning-rate", 230 | "x", 231 | "(Starting) learning rate. 0.025 for skipgram and 0.075 " 232 | "for cbow is recommended.", 233 | SuggestRange(1e-3, 1e-1)); 234 | args.add(min_lr, 235 | "m,min-learning-rate", 236 | "x", 237 | "Minimum (ending) learning rate when linearly scheduling " 238 | "learning rate", 239 | SuggestRange(0., 1e-4)); 240 | args.add(min_count, 241 | "k,min-count", 242 | "n", 243 | "Do not use word identities if raw frequency count is less " 244 | "than n. 
See --discard"); 245 | args.add(discard, 246 | "i,discard", 247 | "true|false", 248 | "If true, discard rare words (see --min-count) else, " 249 | "convert them to UNK"); 250 | args.add(cbow, 251 | "b,cbow", 252 | "true|false", 253 | "If true, use cbow loss instead of skipgram"); 254 | args.add(use_bad_update, 255 | "u,use-bad-update", 256 | "true|false", 257 | "If true, use faulty CBOW update"); 258 | args.add( 259 | downsample_th, "o,downsample-threshold", "x", "Downsample threshold"); 260 | args.add(ns_exponent, 261 | "x,ns-exponent", 262 | "x", 263 | "Exponent for negative sampling distribution", 264 | RequireRange(0., 1.)); 265 | args.add(epochs, "e,epochs", "n", "Training epochs"); 266 | args.add(vocab_size, 267 | "V,vocab-size", 268 | "n", 269 | "Vocabulary size to pick top n words instead of all"); 270 | args.add(vocab_load_path, 271 | "a,vocab-load-path", 272 | "path", 273 | "If passed, load vocabulary from file and skip vocab build. " 274 | "If passed, continue_vocab option is ignored."); 275 | args.add(total_sentences, 276 | "I,total-sentences", 277 | "n", 278 | "If loading vocab from file (see vocab-path option), use this value " 279 | "as total number of sentences to measure percent completion."); 280 | args.add(num_threads, "t,threads", "n", "Number of worker threads"); 281 | args.add(buffer_size, 282 | "B,buffer-size", 283 | "n", 284 | "Buffer size in number of sentences. Memory footprint is in the " 285 | "order of buffer-size × avg. length of sentence. Larger buffer-size " 286 | "is bigger memory footprint but better shuffling."); 287 | args.add(embedding_path, 288 | "p,embedding-path", 289 | "path", 290 | "Path embeddings should be saved to. Defaults to saving to a file " 291 | "named 'embeddings_${CURRENT_DATETIME}.txt'. 
A vocab file is stored " 292 | "using the same path with additonal '.vocab' suffix."); 293 | args.add(pretrained_path, 294 | "r,pretrained-path", 295 | "path", 296 | "If passed (nonempty), continue training from an existing " 297 | "embedding table (also see continue-vocab)"); 298 | args.add(continue_vocab, 299 | "v,continue-vocab", 300 | "old|new|union", 301 | "Which vocab to use when continuing training (see " 302 | "pretrained-path), old: from pretrained table, new: " 303 | "from data, union: combined", 304 | RequireFromSet({"old", "new", "union"})); 305 | args.add(read_mode, 306 | "read-mode", 307 | #ifdef KOAN_ENABLE_ZIP 308 | "text|gzip|auto", 309 | "Force reading training files as text/gzip.", 310 | RequireFromSet({"text", "gzip", "auto"})); 311 | #else 312 | "text|auto", 313 | "Reading from gzipped files is not supported. " 314 | "Build koan with KOAN_ENABLE_ZIP.", 315 | RequireFromSet({"text", "auto"})); 316 | #endif 317 | args.add(shuffle, 318 | "s,shuffle-sentences", 319 | "true|false", 320 | "If true, will shuffle sentences in a batch before allocating " 321 | "to worker threads rather than assigning them consecutively " 322 | "to threads"); 323 | args.add(partitioned, 324 | "L,partitioned", 325 | "true|false", 326 | "If true, use the partitioned version of main parallel for loop. " 327 | "Can be faster due to a lack of std::atomic use, but also slower " 328 | "due to workers with less work waiting for others. Changes " 329 | "sentence processing order."); 330 | args.add(start_lr_schedule_epoch, 331 | "S,start-lr-schedule-epoch", 332 | "n", 333 | "Schedule learning rate as if training starts from n-th epoch " 334 | "instead of 0th."); 335 | args.add(max_lr_schedule_epochs, 336 | "E,max-lr-schedule-epochs", 337 | "n", 338 | "Schedule learning rate as if training will last for n epochs " 339 | "instead of what is specified by \"epochs\" option. 
Zero default " 340 | "makes it the same as \"start-lr-schedule-epoch + epochs\"."); 341 | args.add_flag(no_progress, 342 | "P,no-progress", 343 | "If passed, do not display counters and progress bars."); 344 | args.add_flag(enforce_max_line_length, 345 | "!,enforce-max-line-length", 346 | "If passed, will throw an error if any line in training file " 347 | "is longer than " + 348 | std::to_string(MAX_LINE_LEN) + 349 | " characters. Otherwise, will silently " 350 | "truncate any lines to this value."); 351 | 352 | args.add_help(); 353 | args.parse(argc, argv); 354 | 355 | // Validate arguments 356 | KOAN_ASSERT(epochs > 0); 357 | KOAN_ASSERT(max_lr_schedule_epochs == 0 or max_lr_schedule_epochs >= epochs); 358 | if (max_lr_schedule_epochs == 0) { 359 | max_lr_schedule_epochs = start_lr_schedule_epoch + epochs; 360 | } 361 | KOAN_ASSERT(start_lr_schedule_epoch < max_lr_schedule_epochs); 362 | 363 | if (not vocab_load_path.empty()) { 364 | KOAN_ASSERT(min_count == 1, 365 | "\"-k,--min-count\" should not be passed in " 366 | "when preloading vocabulary!"); 367 | KOAN_ASSERT(vocab_size == std::numeric_limits::max(), 368 | "\"-V,--vocab-size\" should not be passed in when preloading " 369 | "vocabulary!"); 370 | } 371 | if (total_sentences > 0) { 372 | KOAN_ASSERT(not vocab_load_path.empty(), 373 | "\"-I,--total-sentences\" should not be passed when not " 374 | "preloading a vocabulary file!"); 375 | } 376 | 377 | if (embedding_path.empty()) { 378 | embedding_path = "embeddings_" + date_time("%F_%T") + ".txt"; 379 | } 380 | 381 | Table table, ctx, local(num_threads, Vector::Zero(dim)); 382 | std::vector ordered_vocab; 383 | IndexMap word_map; // ordered_vocab will own the 384 | // actual strings. 
385 | 386 | std::unordered_map pretrained_table; 387 | 388 | if (not pretrained_path.empty()) { 389 | pretrained_table = load_pretrained_embeddings( 390 | pretrained_path, read_mode, dim, enforce_max_line_length, no_progress); 391 | } 392 | 393 | bool read_whole_data = false; 394 | 395 | std::unordered_map freqs; 396 | 397 | if (vocab_load_path.empty()) { // build vocab from corpus 398 | std::tie(freqs, total_sentences) = 399 | build_vocab(fnames, read_mode, enforce_max_line_length, no_progress); 400 | 401 | if (not discard) { 402 | ordered_vocab.push_back(UNKSTR); 403 | freqs[UNKSTR] = 0; 404 | } 405 | 406 | // if a word in old vocab did not appear in corpus, assume a frequency count 407 | // of min_count 408 | if (continue_vocab == "old" or continue_vocab == "union") { 409 | for (auto& p : pretrained_table) { 410 | if (freqs.find(p.first) == freqs.end()) { freqs[p.first] = min_count; } 411 | } 412 | } 413 | 414 | if (continue_vocab == "old") { 415 | for (auto& p : pretrained_table) { 416 | if (freqs[p.first] >= min_count) { ordered_vocab.push_back(p.first); } 417 | } 418 | } else { // continue_vocab == "new" or "union" 419 | for (auto& [word, count] : freqs) { 420 | if (count >= min_count) { ordered_vocab.push_back(word); } 421 | } 422 | } 423 | 424 | size_t begin_offset = discard ? 0 : 1; // keep UNK at 0 if exists 425 | std::sort(ordered_vocab.begin() + begin_offset, 426 | ordered_vocab.end(), 427 | [&](auto& a, auto& b) { return freqs[a] > freqs[b]; }); 428 | 429 | // Resize if vocab is bigger than specified size 430 | if (vocab_size < ordered_vocab.size()) { ordered_vocab.resize(vocab_size); } 431 | 432 | KOAN_ASSERT(ordered_vocab.size() < std::numeric_limits::max(), 433 | "Vocab is too big for Word type! 
Either shrink vocab, or use " 434 | "bigger Word type."); 435 | 436 | save_vocab_file(embedding_path + ".vocab", ordered_vocab, freqs); 437 | } else { 438 | std::tie(ordered_vocab, freqs) = load_vocab_file(vocab_load_path); 439 | if (ordered_vocab.front() == UNKSTR) { 440 | discard = false; 441 | } else { 442 | discard = true; 443 | } 444 | } 445 | 446 | for (const auto& w : ordered_vocab) { 447 | word_map.insert(std::string_view(w)); 448 | assert(word_map.lookup(w) == table.size()); 449 | assert(word_map.lookup(w) == ctx.size()); 450 | table.push_back(Vector::Zero(dim)); 451 | ctx.push_back(Vector::Zero(dim)); 452 | } 453 | 454 | if (total_sentences > 0) { 455 | std::cout << "Total training sentences: " << total_sentences << std::endl; 456 | } 457 | 458 | if (total_sentences > 0 and buffer_size > total_sentences) { 459 | std::cerr << "WARNING: Buffer size is larger than the total number" 460 | " of sentences in the corpus -- will load entire dataset" 461 | " into memory once instead of streaming.\n"; 462 | read_whole_data = true; 463 | } 464 | 465 | unsigned long long tot = 0; // total count of all words 466 | std::vector prob(ordered_vocab.size()); // filter probs 467 | std::vector neg_prob(ordered_vocab.size()); // neg sampling probs 468 | 469 | if (not discard) { freqs[UNKSTR] = 0; } 470 | for (Word w = 0; w < prob.size(); w++) { 471 | auto count = freqs.at(std::string(word_map.reverse_lookup(w))); 472 | prob[w] = neg_prob[w] = count; 473 | tot += count; 474 | } 475 | 476 | // Maybe filter words by frequency 477 | // - 478 | // https://github.com/svn2github/word2vec/blob/99e546e27cae10aa20209dae1ed98716ac9022e9/word2vec.c#L396 479 | // - 480 | // https://github.com/RaRe-Technologies/gensim/blob/e859c11f6f57bf3c883a718a9ab7067ac0c2d4cf/gensim/models/word2vec.py#L1536 481 | for (auto& p : prob) { 482 | p = p / tot; 483 | p = 1. 
- sqrt(downsample_th / p) - 484 | downsample_th / p; // probability of discarding 485 | } 486 | 487 | // Compute negative sampling probs 488 | // https://github.com/RaRe-Technologies/gensim/blob/e859c11f6f57bf3c883a718a9ab7067ac0c2d4cf/gensim/models/word2vec.py#L1608 489 | { 490 | std::transform(neg_prob.begin(), 491 | neg_prob.end(), 492 | neg_prob.begin(), 493 | [ns_exponent](auto& x) { return std::pow(x, ns_exponent); }); 494 | Real total = std::accumulate(neg_prob.begin(), neg_prob.end(), 0.); 495 | std::transform(neg_prob.begin(), 496 | neg_prob.end(), 497 | neg_prob.begin(), 498 | [total](auto& x) { return x / total; }); 499 | } 500 | 501 | // Randomly initialize embeddings for words not present in pretrained_table 502 | for (size_t w = 0; w < table.size(); w++) { 503 | std::string word(word_map.reverse_lookup(w)); 504 | if (pretrained_table.find(word) != pretrained_table.end()) { 505 | table[w] = std::move(pretrained_table[word]); 506 | } else { 507 | table[w].setRandom(); 508 | table[w] *= (0.5 / dim); 509 | } 510 | ctx[w].setZero(); 511 | } 512 | // pretrained_table not needed after here, save memory 513 | pretrained_table.clear(); 514 | 515 | Trainer::Params params{ 516 | .dim = dim, 517 | .ctxs = ctxs, 518 | .negatives = negatives, 519 | .threads = num_threads, 520 | .use_bad_update = use_bad_update, 521 | }; 522 | 523 | Trainer trainer(params, table, ctx, prob, neg_prob); 524 | std::mt19937 g(12345); 525 | 526 | std::atomic tokens{0}, sents{0}, total_tokens{0}; 527 | std::atomic curr_lr{0}; 528 | 529 | Sentences sentences; 530 | 531 | Timer t; 532 | std::unique_ptr reader; 533 | if (read_whole_data) { 534 | reader = std::make_unique( 535 | word_map, fnames, discard, read_mode, enforce_max_line_length); 536 | } else { 537 | reader = std::make_unique(word_map, 538 | fnames, 539 | buffer_size, 540 | discard, 541 | read_mode, 542 | enforce_max_line_length); 543 | } 544 | 545 | if (total_sentences == 0) { 546 | std::cerr << "WARN: Total number of sentences 
is unknown, therefore " 547 | "learning rate scheduling and progress bar display are " 548 | "disabled. If you want to enable, feed it in via " 549 | "\"-I,--total-sentences\" option." 550 | << std::endl; 551 | } 552 | 553 | for (size_t e = 0; e < epochs; e++) { 554 | std::atomic filtered_tokens_in_epoch{0}, total_tokens_in_epoch{0}; 555 | 556 | tokens = 0; 557 | sents = 0; 558 | size_t global_i = 0; 559 | 560 | std::cout << "Epoch " << e << std::endl; 561 | 562 | auto bar = mew::ProgressBar(sents, total_sentences, "Sents:") | 563 | mew::Counter(tokens, "Toks:", "tok/s", mew::Speed::Last) | 564 | mew::Counter(curr_lr, "LR:", "", mew::Speed::None); 565 | auto ctr = mew::Counter(sents, "Sents:", "lin/s", mew::Speed::Last) | 566 | mew::Counter(tokens, "Toks:", "tok/s", mew::Speed::Last) | 567 | mew::Counter(curr_lr, "LR:", "", mew::Speed::None); 568 | if (not no_progress) { 569 | if (total_sentences > 0) { 570 | bar.start(); 571 | } else { // We don't know what the total is, so start a counter instead 572 | ctr.start(); 573 | } 574 | } 575 | 576 | while (reader->get_next(sentences)) { 577 | std::vector perm(sentences.size()); 578 | std::iota(perm.begin(), perm.end(), 0); 579 | 580 | if (shuffle) { std::shuffle(perm.begin(), perm.end(), g); } 581 | 582 | auto work = [&](size_t i, size_t tid) { 583 | auto& s = sentences[perm[i]]; 584 | 585 | // linear learning rate scheduling 586 | // https://github.com/RaRe-Technologies/gensim/blob/374de281b27f21fac4df20c315ee07caafb279c0/gensim/models/base_any2vec.py#L1083 587 | Real lr = init_lr; 588 | if (total_sentences > 0) { 589 | Real lr_sched = 590 | Real(e + start_lr_schedule_epoch) / max_lr_schedule_epochs + 591 | (Real(i + global_i) / total_sentences) / max_lr_schedule_epochs; 592 | lr = init_lr - (init_lr - min_lr) * lr_sched; 593 | } 594 | curr_lr = lr; 595 | 596 | size_t remaining_toks = trainer.train(s, tid, lr, cbow); 597 | sents++; 598 | tokens += remaining_toks; 599 | total_tokens += remaining_toks; 600 | 
filtered_tokens_in_epoch += remaining_toks; 601 | total_tokens_in_epoch += s.size(); 602 | }; 603 | 604 | if (partitioned) { 605 | parallel_for_partitioned(0, sentences.size(), work, num_threads); 606 | } else { 607 | parallel_for(0, sentences.size(), work, num_threads); 608 | } 609 | 610 | global_i += sentences.size(); 611 | } 612 | 613 | bar.done(); 614 | ctr.done(); 615 | 616 | std::cout << std::fixed << std::setprecision(2) 617 | << 100. * filtered_tokens_in_epoch / total_tokens_in_epoch 618 | << "% of tokens were retained while filtering." << std::endl; 619 | } 620 | auto total_secs = t.s(); 621 | std::cout << "Took " << unsigned(total_secs) << "s. (excluding vocab build)" 622 | << std::endl 623 | << "Overall speed was " << total_tokens / total_secs << " toks/s" 624 | << std::endl; 625 | 626 | { 627 | std::cout << "Saving to " << embedding_path << std::endl; 628 | FILE* out = fopen(embedding_path.c_str(), "w"); 629 | KOAN_ASSERT(out); 630 | std::string buf; 631 | buf.reserve(MAX_LINE_LEN); 632 | for (auto& w : word_map.keys()) { 633 | buf.clear(); 634 | buf += w; 635 | auto v = table[word_map.lookup(w)]; 636 | for (int j = 0; j < v.size(); j++) { 637 | buf += " "; 638 | buf += std::to_string(v(j)); 639 | } 640 | buf += "\n"; 641 | fputs(buf.data(), out); 642 | } 643 | fclose(out); 644 | } 645 | } 646 | -------------------------------------------------------------------------------- /koan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloomberg/koan/c22fccdd6f359b5e7f4889b9969486ed27f76894/koan.png -------------------------------------------------------------------------------- /koan/cli.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 
6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | // Tiny command line parsing utilities 18 | 19 | #ifndef KOAN_CLI_H 20 | #define KOAN_CLI_H 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include "extern/tblr.h" 31 | 32 | #include "util.h" 33 | 34 | namespace koan { 35 | namespace internal { 36 | 37 | namespace fromstr { // Convert from string to things 38 | 39 | template 40 | T to(const std::string& s); 41 | 42 | template <> 43 | inline std::string to(const std::string& s) { 44 | return s; 45 | } 46 | 47 | template <> 48 | inline float to(const std::string& s) { 49 | return std::stof(s); 50 | } 51 | 52 | template <> 53 | inline double to(const std::string& s) { 54 | return std::stod(s); 55 | } 56 | 57 | template <> 58 | inline unsigned to(const std::string& s) { 59 | return std::stoul(s); 60 | } 61 | 62 | template <> 63 | inline int to(const std::string& s) { 64 | return std::stoi(s); 65 | } 66 | 67 | template <> 68 | inline long to(const std::string& s) { 69 | return std::stol(s); 70 | } 71 | 72 | template <> 73 | inline unsigned long to(const std::string& s) { 74 | return std::stoul(s); 75 | } 76 | 77 | template <> 78 | inline unsigned long long to(const std::string& s) { 79 | return std::stoull(s); 80 | } 81 | 82 | template <> 83 | inline bool to(const std::string& s) { 84 | if (s == "true" or s == "True" or s == "1") { return true; } 85 | if (s == "false" or s == "False" or s == "0") { return false; } 86 | throw std::runtime_error("Unexpected boolean string: " + s); 87 | 
return false; 88 | } 89 | 90 | } // namespace fromstr 91 | 92 | template 93 | std::ostream& operator<<(std::ostream& out, const std::vector& v) { 94 | if (!v.empty()) { 95 | out << '['; 96 | std::copy(v.begin(), v.end(), std::ostream_iterator(out, ", ")); 97 | out << "\b\b]"; 98 | } 99 | return out; 100 | } 101 | 102 | template 103 | std::string tostr(const T& x) { 104 | std::ostringstream ss; 105 | ss << x; 106 | return ss.str(); 107 | } 108 | 109 | } // namespace internal 110 | 111 | enum Require { Optional, Required }; 112 | 113 | class Args; 114 | 115 | namespace internal { 116 | 117 | class ValidityBase { 118 | protected: 119 | const bool throw_; 120 | const std::string namestr_; 121 | 122 | public: 123 | ValidityBase(bool throws, std::string namestr) 124 | : throw_(throws), namestr_(std::move(namestr)) {} 125 | virtual ~ValidityBase() = default; 126 | 127 | virtual std::string str() const = 0; 128 | 129 | friend class koan::Args; 130 | }; 131 | 132 | class ArgBase { 133 | public: 134 | virtual ~ArgBase() {} 135 | ArgBase(std::string descr, 136 | std::vector names, 137 | std::string value, 138 | Require required, 139 | std::unique_ptr validity = nullptr, 140 | bool nargs = false) 141 | : descr_(std::move(descr)), 142 | names_(std::move(names)), 143 | value_(std::move(value)), 144 | required_(required), 145 | nargs_(nargs), 146 | validity_(std::move(validity)) {} 147 | virtual std::string value_str() const = 0; 148 | virtual bool is_flag() const = 0; 149 | bool is_nary() { return nargs_; } 150 | 151 | protected: 152 | std::string descr_; // description of the option in the helpstr 153 | std::vector names_; // names of the option 154 | std::string value_; // value name in the helpstr 155 | Require required_ = Optional; // Required vs Optional 156 | bool parsed_ = false; // is this already parsed 157 | bool nargs_ = false; // Whether this accepts more than 1 argument 158 | std::unique_ptr validity_; // checks if parsed value is valid 159 | 160 | void parse(const 
std::string& value) { 161 | parse_(value); 162 | parsed_ = true; 163 | } 164 | 165 | private: 166 | virtual void parse_(const std::string& value) = 0; 167 | 168 | friend class koan::Args; 169 | }; 170 | 171 | template 172 | class Validity : public ValidityBase { 173 | public: 174 | Validity(bool throws, std::string namestr) 175 | : ValidityBase(throws, std::move(namestr)) {} 176 | 177 | virtual void check(const T&) const { 178 | throw std::runtime_error("A validity checker is not implemented for " 179 | "this type!"); 180 | } 181 | 182 | std::string str() const override { return ""; } 183 | 184 | void warn_or_throw(const T& val, bool is_range) const { 185 | std::string adj = throw_ ? "required" : "suggested"; 186 | std::string noun = is_range ? "range" : "set"; 187 | std::string msg = "Value " + tostr(val) + " for " + namestr_ + 188 | " is not in " + adj + " " + noun + ": " + str(); 189 | if (throw_) { 190 | throw std::runtime_error(msg); 191 | } else { 192 | std::cerr << "Warning: " + msg << std::endl; 193 | } 194 | } 195 | }; 196 | 197 | // Check if parsed argument value is in the set 198 | template 199 | class InSet : public Validity { 200 | private: 201 | const std::vector candidates_; 202 | 203 | public: 204 | InSet(bool throws, std::vector candidates, std::string namestr) 205 | : Validity(throws, std::move(namestr)), 206 | candidates_(std::move(candidates)) {} 207 | void check(const T& val) const override { 208 | if (std::find(candidates_.begin(), candidates_.end(), val) == 209 | candidates_.end()) { 210 | this->warn_or_throw(val, false); 211 | } 212 | } 213 | std::string str() const override { 214 | std::string s = "{"; 215 | size_t i = 0; 216 | while (i < candidates_.size()) { 217 | std::string value_str = tostr(candidates_[i]); 218 | i++; 219 | if (value_str.size() + s.size() > 20) { break; } 220 | if (i > 1) { s += ", "; } 221 | s += value_str; 222 | } 223 | if (i == candidates_.size()) { 224 | s += "}"; 225 | } else { 226 | s += ", ...}"; 227 | } 228 | 
return s; 229 | } 230 | }; 231 | 232 | // Check if parsed argument value is in the range (inclusive) 233 | template 234 | class InRange : public Validity { 235 | private: 236 | const T left_, right_; 237 | 238 | public: 239 | InRange(bool throws, const T& left, const T& right, std::string namestr) 240 | : Validity(throws, std::move(namestr)), left_(left), right_(right) {} 241 | void check(const T& val) const override { 242 | if (not(left_ <= val and val <= right_)) { this->warn_or_throw(val, true); } 243 | } 244 | std::string str() const override { 245 | return "[" + tostr(left_) + ", " + tostr(right_) + "]"; 246 | } 247 | }; 248 | 249 | template 250 | class Arg : public ArgBase { 251 | protected: 252 | T& dest_; 253 | 254 | public: 255 | Arg(T& dest, 256 | std::string descr, 257 | std::vector names, 258 | std::string value, 259 | Require required = Optional, 260 | std::unique_ptr validity = nullptr) 261 | : ArgBase(std::move(descr), 262 | std::move(names), 263 | std::move(value), 264 | required, 265 | std::move(validity)), 266 | dest_(dest) {} 267 | 268 | bool is_flag() const override { return false; } 269 | 270 | std::string value_str() const override { return tostr(dest_); } 271 | 272 | private: 273 | void parse_(const std::string& value) override { 274 | dest_ = internal::fromstr::to(value); 275 | if (validity_) { dynamic_cast&>(*validity_).check(dest_); } 276 | } 277 | }; 278 | 279 | template 280 | class Arg> : public ArgBase { 281 | protected: 282 | std::vector& dest_; 283 | 284 | public: 285 | Arg>(std::vector& dest, 286 | std::string descr, 287 | std::vector names, 288 | std::string value, 289 | Require required = Optional, 290 | std::unique_ptr validity = nullptr) 291 | : ArgBase(std::move(descr), 292 | std::move(names), 293 | std::move(value), 294 | required, 295 | std::move(validity), 296 | true), 297 | dest_(dest) {} 298 | 299 | bool is_flag() const override { return false; } 300 | 301 | std::string value_str() const override { return tostr(dest_); } 
302 | 303 | private: 304 | void parse_(const std::string& value) override { 305 | T arg = internal::fromstr::to(value); 306 | if (validity_) { dynamic_cast&>(*validity_).check(arg); } 307 | 308 | dest_.push_back(arg); 309 | } 310 | }; 311 | 312 | template <> 313 | class Arg : public ArgBase { 314 | protected: 315 | bool& dest_; 316 | bool is_flag_ = false; 317 | 318 | public: 319 | Arg(bool& dest, 320 | std::string descr, 321 | std::vector names, 322 | std::string value, 323 | Require required = Optional, 324 | std::unique_ptr> /*validity*/ = nullptr, 325 | bool is_flag = false) 326 | : ArgBase(std::move(descr), std::move(names), std::move(value), required), 327 | dest_(dest), 328 | is_flag_(is_flag) {} 329 | 330 | bool is_flag() const override { return is_flag_; } 331 | 332 | std::string value_str() const override { return (dest_ ? "true" : "false"); } 333 | 334 | private: 335 | void parse_(const std::string& value) override { 336 | if (is_flag_) { 337 | dest_ = true; 338 | } else { 339 | dest_ = internal::fromstr::to(value); 340 | } 341 | } 342 | }; 343 | 344 | template 345 | class Arg> : public ArgBase { 346 | // Functional args are flags that are used to perform actions 347 | protected: 348 | std::function f_; 349 | 350 | public: 351 | Arg>(std::function f, 352 | std::string descr, 353 | std::vector names, 354 | std::string value) 355 | : ArgBase(std::move(descr), std::move(names), std::move(value), Optional), 356 | f_(f) {} 357 | bool is_flag() const override { return true; } 358 | std::string value_str() const override { return ""; } 359 | 360 | private: 361 | void parse_(const std::string&) override { f_(); } 362 | }; 363 | 364 | template 365 | struct Range { 366 | bool throws; 367 | T left, right; 368 | }; 369 | 370 | template 371 | struct Set { 372 | bool throws; 373 | std::vector candidates; 374 | }; 375 | 376 | } // namespace internal 377 | 378 | template 379 | auto RequireRange(const T& left, const T& right) { 380 | return internal::Range{/*throws*/ 
true, left, right}; 381 | } 382 | 383 | template 384 | auto SuggestRange(const T& left, const T& right) { 385 | return internal::Range{/*throws*/ false, left, right}; 386 | } 387 | 388 | template 389 | auto RequireFromSet(const std::vector& candidates) { 390 | return internal::Set{/*throws*/ true, candidates}; 391 | } 392 | 393 | template 394 | auto SuggestFromSet(const std::vector& candidates) { 395 | return internal::Set{/*throws*/ false, candidates}; 396 | } 397 | 398 | template 399 | auto RequireFromSet(std::initializer_list candidates) { 400 | return internal::Set{/*throws*/ true, candidates}; 401 | } 402 | 403 | template 404 | auto SuggestFromSet(std::initializer_list candidates) { 405 | return internal::Set{/*throws*/ false, candidates}; 406 | } 407 | 408 | class Args { 409 | public: 410 | struct ParseError : public std::runtime_error { 411 | using std::runtime_error::runtime_error; 412 | }; 413 | 414 | private: 415 | using ArgBase = internal::ArgBase; 416 | template 417 | using Arg = internal::Arg; 418 | template 419 | using Validity = internal::Validity; 420 | 421 | std::vector> positional_args_; 422 | std::vector> named_args_; 423 | std::map name2i_; 424 | bool has_help_ = false; 425 | std::string program_name_ = "program"; 426 | 427 | static bool is_name(const std::string& value, std::string& name); 428 | static std::vector validate_names(const std::string& namestr); 429 | static void ensure(bool predicate, const std::string& msg); 430 | std::string helpstr() const; 431 | 432 | public: 433 | void parse(int argc, char** argv); 434 | void parse(const std::vector& argv); 435 | 436 | // add option (named arg) 437 | template 438 | void add(T& dest, 439 | const std::string& namestr, 440 | const std::string& value, 441 | const std::string& descr, 442 | Require require = Optional); 443 | 444 | template 445 | void add(T& dest, 446 | const std::string& namestr, 447 | const std::string& value, 448 | const std::string& descr, 449 | const internal::Range& range, 450 
| Require require = Optional); 451 | 452 | template 453 | void add(T& dest, 454 | const std::string& namestr, 455 | const std::string& value, 456 | const std::string& descr, 457 | const internal::Set& range, 458 | Require require = Optional); 459 | 460 | // add positional arg 461 | template 462 | void add(T& dest, const std::string& value, const std::string& descr); 463 | 464 | // add option as a flag 465 | void add_flag(bool& dest, 466 | const std::string& namestr, 467 | const std::string& descr, 468 | Require require = Optional); 469 | 470 | // add helpstring flag (-?, -h, --help) 471 | void add_help(); 472 | }; 473 | 474 | inline void Args::ensure(bool predicate, const std::string& msg) { 475 | if (not predicate) { throw ParseError(msg.c_str()); } 476 | } 477 | 478 | inline bool Args::is_name(const std::string& value, std::string& name) { 479 | if (value.size() >= 2 and value[0] == '-' and value[1] == '-') { 480 | name = value.substr(2, value.size()); 481 | ensure(not name.empty(), "Prefix `--` not followed by an option!"); 482 | return true; 483 | } 484 | if (value.size() >= 1 and value[0] == '-') { 485 | name = value.substr(1, value.size()); 486 | ensure(not name.empty(), "Prefix `-` not followed by an option!"); 487 | ensure(name.size() == 1, 488 | "Options prefixed by `-` have to be short names! 
" 489 | "Did you mean `--" + 490 | name + "`?"); 491 | return true; 492 | } 493 | return false; 494 | } 495 | 496 | inline std::string join(const std::vector& strings, 497 | const std::string& delim) { 498 | std::string s; 499 | for (size_t i = 0; i < strings.size(); i++) { 500 | if (i > 0) { s += delim; } 501 | s += strings[i]; 502 | } 503 | return s; 504 | } 505 | 506 | inline void Args::parse(const std::vector& argv) { 507 | size_t i = 0; 508 | size_t positional_i = 0; 509 | 510 | while (i < argv.size()) { 511 | std::string name; 512 | if (is_name(argv[i], name)) { // a named argument (option) 513 | ensure(name2i_.find(name) != name2i_.end(), "Unexpected option: " + name); 514 | auto& opt = *named_args_.at(name2i_.at(name)); 515 | ensure(not opt.parsed_, "Option `" + name + "` is multiply given!"); 516 | if (opt.is_flag()) { 517 | opt.parse(""); 518 | i++; 519 | } else if (opt.is_nary()) { 520 | size_t j = i + 1; 521 | std::string tmpname; 522 | while (j < argv.size() && !is_name(argv[j], tmpname)) { 523 | opt.parse(argv.at(j)); 524 | j++; 525 | } 526 | 527 | i = j; 528 | } else { 529 | ensure((i + 1) < argv.size(), 530 | "Option `" + name + "` is missing value!"); 531 | opt.parse(argv.at(i + 1)); 532 | i += 2; 533 | } 534 | } else { // a positional argument 535 | ensure(positional_i < positional_args_.size(), 536 | "Unexpected positional argument: " + argv[i]); 537 | positional_args_.at(positional_i)->parse(argv[i]); 538 | i++; 539 | positional_i++; 540 | } 541 | } 542 | 543 | // check if all required args are parsed 544 | for (auto& arg : positional_args_) { 545 | ensure(arg->parsed_, 546 | "Required positional argument <" + arg->value_ + 547 | "> is not provided!"); 548 | } 549 | for (auto& arg : named_args_) { 550 | if (arg->required_) { 551 | ensure(arg->parsed_, 552 | "Required option `" + join(arg->names_, ", ") + " <" + 553 | arg->value_ + ">` is not provided!"); 554 | } 555 | } 556 | } 557 | 558 | inline void Args::parse(int argc, char** argv) { 559 | 
program_name_ = argv[0]; 560 | std::vector argv_; 561 | for (int i = 1; i < argc; i++) { argv_.push_back(argv[i]); } 562 | parse(argv_); 563 | } 564 | 565 | inline void Args::add_help() { 566 | const static std::vector names({"?", "h", "help"}); 567 | for (auto& name : names) { 568 | ensure(name2i_.find(name) == name2i_.end(), 569 | "Option `" + name + "` is multiply defined!"); 570 | name2i_[name] = named_args_.size(); 571 | } 572 | named_args_.push_back(std::make_unique>>( 573 | [this]() { 574 | std::cout << helpstr() << std::flush; 575 | exit(0); 576 | }, 577 | "print this help message and quit", 578 | names, 579 | "")); 580 | has_help_ = true; 581 | } 582 | 583 | inline std::vector 584 | Args::validate_names(const std::string& namestr) { 585 | auto names = split(namestr, ','); 586 | ensure(not names.empty(), "Option name cannot be empty!"); 587 | ensure(names.size() <= 2, 588 | "Option names can be one short and one long at most! " 589 | "E.g. \"o,option\" or \"o\" or \"option\"."); 590 | if (names.size() == 2) { // ether specify "o,option", 591 | auto& short_name = names[0]; 592 | auto& long_name = names[1]; 593 | ensure(short_name.size() == 1 and long_name.size() > 1, 594 | "Multiple form option names should be first short then long! " 595 | "E.g. \"o,option\"."); 596 | } else { // or "o" only or "option" only. 
597 | ; 598 | } 599 | 600 | return names; 601 | } 602 | 603 | template 604 | inline void Args::add(T& dest, 605 | const std::string& namesstr, 606 | const std::string& value, 607 | const std::string& descr, 608 | Require required) { 609 | auto names = validate_names(namesstr); 610 | for (auto& name : names) { 611 | ensure(name2i_.find(name) == name2i_.end(), 612 | "Option `" + name + "` is multiply defined!"); 613 | name2i_[name] = named_args_.size(); 614 | } 615 | named_args_.push_back( 616 | std::make_unique>(dest, descr, names, value, required)); 617 | } 618 | 619 | template 620 | inline void Args::add(T& dest, 621 | const std::string& namesstr, 622 | const std::string& value, 623 | const std::string& descr, 624 | const internal::Range& range, 625 | Require required) { 626 | auto names = validate_names(namesstr); 627 | for (auto& name : names) { 628 | ensure(name2i_.find(name) == name2i_.end(), 629 | "Option `" + name + "` is multiply defined!"); 630 | name2i_[name] = named_args_.size(); 631 | } 632 | named_args_.push_back(std::make_unique>( 633 | dest, 634 | descr, 635 | names, 636 | value, 637 | required, 638 | std::make_unique>( 639 | range.throws, range.left, range.right, namesstr))); 640 | } 641 | 642 | template 643 | inline void Args::add(T& dest, 644 | const std::string& namesstr, 645 | const std::string& value, 646 | const std::string& descr, 647 | const internal::Set& set, 648 | Require required) { 649 | auto names = validate_names(namesstr); 650 | for (auto& name : names) { 651 | ensure(name2i_.find(name) == name2i_.end(), 652 | "Option `" + name + "` is multiply defined!"); 653 | name2i_[name] = named_args_.size(); 654 | } 655 | std::vector candidate_set; 656 | for (auto& val : set.candidates) { candidate_set.push_back(val); } 657 | named_args_.push_back( 658 | std::make_unique>(dest, 659 | descr, 660 | names, 661 | value, 662 | required, 663 | std::make_unique>( 664 | set.throws, candidate_set, namesstr))); 665 | } 666 | 667 | template 668 | inline 
void 669 | Args::add(T& dest, const std::string& value, const std::string& descr) { 670 | positional_args_.emplace_back(new Arg(dest, descr, {}, value, Required)); 671 | // I cannot use make_unique because I made the ctor protected. Is this OK? 672 | } 673 | 674 | inline void Args::add_flag(bool& dest, 675 | const std::string& namestr, 676 | const std::string& descr, 677 | Require require) { 678 | ensure(not dest, 679 | "Optional boolean flags need to default to false, " 680 | "since the action is `store true`!"); 681 | auto names = validate_names(namestr); 682 | for (auto& name : names) { 683 | ensure(name2i_.find(name) == name2i_.end(), 684 | "Option `" + name + "` is multiply defined!"); 685 | name2i_[name] = named_args_.size(); 686 | } 687 | named_args_.emplace_back( 688 | new Arg(dest, descr, names, "", require, nullptr, true)); 689 | } 690 | 691 | inline std::string Args::helpstr() const { 692 | auto table = []() { 693 | tblr::Table t; 694 | t.widths({0, 50}).multiline(tblr::Space).layout(tblr::indented_list()); 695 | return t; 696 | }; 697 | 698 | std::stringstream ss; 699 | 700 | ss << "Usage:\n " << program_name_; 701 | for (auto& arg_ : positional_args_) { ss << " <" << arg_->value_ << ">"; } 702 | ss << " options\n"; 703 | 704 | if (not positional_args_.empty() or not named_args_.empty()) { 705 | ss << "\nwhere "; 706 | } 707 | 708 | if (not positional_args_.empty()) { 709 | auto t = table(); 710 | for (auto& arg_ : positional_args_) { 711 | t << (tblr::Cell() << "<" << arg_->value_ << ">") << arg_->descr_ 712 | << tblr::endr; 713 | } 714 | ss << "positional arguments are:\n" << t << "\n"; 715 | } 716 | 717 | std::vector required_opts, optional_opts; 718 | for (size_t i = 0; i < named_args_.size(); i++) { 719 | if (named_args_[i]->required_) { 720 | required_opts.push_back(i); 721 | } else { 722 | optional_opts.push_back(i); 723 | } 724 | } 725 | 726 | auto make_option_str = [&](auto& arg) { 727 | auto names = arg->names_; 728 | for (auto& name : names) 
{ 729 | if (name.size() == 1) { 730 | name = "-" + name; 731 | } else { 732 | name = "--" + name; 733 | } 734 | } 735 | return join(names, ", "); 736 | }; 737 | 738 | if (not required_opts.empty()) { 739 | auto t = table(); 740 | for (size_t i = 0; i < required_opts.size(); i++) { 741 | auto& arg_ = named_args_.at(required_opts[i]); 742 | 743 | auto option = (tblr::Cell() << make_option_str(arg_)); 744 | if (not arg_->is_flag()) { option << " <" << arg_->value_ << ">"; } 745 | auto descr = (tblr::Cell() << arg_->descr_); 746 | 747 | if (arg_->validity_) { 748 | if (arg_->validity_->throw_) { 749 | descr << " (required in "; 750 | } else { 751 | descr << " (suggested in "; 752 | } 753 | descr << arg_->validity_->str(); 754 | descr << ")"; 755 | } 756 | 757 | t << option << descr << tblr::endr; 758 | } 759 | 760 | ss << "required options are:\n" << t << "\n"; 761 | } 762 | 763 | if (not optional_opts.empty() or has_help_) { 764 | auto t = table(); 765 | if (has_help_) { 766 | t << "-?, -h, --help" << named_args_[name2i_.at("h")]->descr_ 767 | << tblr::endr; 768 | } 769 | 770 | for (size_t i = 0; i < optional_opts.size(); i++) { 771 | auto& arg_ = named_args_.at(optional_opts[i]); 772 | 773 | if (has_help_ and &named_args_[name2i_.at("h")] == &arg_) { continue; } 774 | 775 | auto option = (tblr::Cell() << make_option_str(arg_)); 776 | if (not arg_->is_flag()) { option << " <" << arg_->value_ << ">"; } 777 | auto descr = (tblr::Cell() << arg_->descr_); 778 | if (arg_->is_flag()) { 779 | descr << " (flag)"; 780 | } else { 781 | descr << " (default: " << arg_->value_str(); 782 | if (arg_->validity_) { 783 | if (arg_->validity_->throw_) { 784 | descr << ", required in "; 785 | } else { 786 | descr << ", suggested in "; 787 | } 788 | descr << arg_->validity_->str(); 789 | } 790 | descr << ")"; 791 | } 792 | 793 | t << option << descr << tblr::endr; 794 | } 795 | ss << "optional options are:\n" << t << "\n"; 796 | } 797 | return ss.str(); 798 | } 799 | 800 | } // namespace 
koan 801 | 802 | #endif 803 | -------------------------------------------------------------------------------- /koan/def.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef KOAN_DEF_H 18 | #define KOAN_DEF_H 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | #include 25 | 26 | namespace koan { 27 | 28 | #ifdef KOAN_GRAD_CHECKING 29 | using Real = double; 30 | #else 31 | using Real = float; 32 | #endif 33 | constexpr Real operator"" _R(long double d) { 34 | return d; 35 | } 36 | constexpr Real operator"" _R(unsigned long long d) { 37 | return d; 38 | } 39 | 40 | using Vector = Eigen::Matrix; 41 | using Table = std::vector; 42 | 43 | using Word = unsigned; 44 | using Sentence = std::vector; 45 | using Sentences = std::vector; 46 | 47 | const static std::string UNKSTR = "___UNK___"; 48 | const static std::string_view UNK(UNKSTR); 49 | 50 | const static size_t INITIAL_INDEX_SIZE = 30000000; 51 | const static size_t INITIAL_SENTENCE_LEN = 1000; 52 | const static int MAX_LINE_LEN = 1000000; 53 | 54 | // based on the first nonzero entry in the sigmoid approx. 
table 55 | const static Real MIN_SIGMOID_IN_LOSS = 0.000340641; 56 | 57 | } // namespace koan 58 | 59 | #endif 60 | -------------------------------------------------------------------------------- /koan/indexmap.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef KOAN_INDEXMAP_H 18 | #define KOAN_INDEXMAP_H 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | #include "def.h" 25 | 26 | namespace koan { 27 | 28 | /// Used to store vocabulary map from words to index, and the reverse. 
29 | template 30 | class IndexMap { 31 | private: 32 | std::unordered_map k2i_; 33 | std::vector i2k_; 34 | 35 | public: 36 | IndexMap() { 37 | k2i_.reserve(INITIAL_INDEX_SIZE); 38 | i2k_.reserve(INITIAL_INDEX_SIZE); 39 | } 40 | IndexMap(const std::unordered_set& keys) { 41 | k2i_.reserve(INITIAL_INDEX_SIZE); 42 | i2k_.reserve(INITIAL_INDEX_SIZE); 43 | for (const auto& key : keys) { i2k_.push_back(key); } 44 | for (size_t i = 0; i < i2k_.size(); i++) { k2i_[i2k_[i]] = i; } 45 | } 46 | 47 | void insert(const Key& key) { 48 | auto elt = k2i_.emplace(key, i2k_.size()); 49 | if (elt.second) { i2k_.push_back(key); } 50 | } 51 | 52 | const std::vector& keys() const { return i2k_; } 53 | 54 | bool has(const Key& key) const { return k2i_.find(key) != k2i_.end(); } 55 | 56 | size_t size() const { return i2k_.size(); } 57 | 58 | void clear() { 59 | k2i_.clear(); 60 | i2k_.clear(); 61 | } 62 | 63 | auto find(const Key& key) const { return k2i_.find(key); } 64 | auto end() const { return k2i_.end(); } 65 | size_t lookup(const Key& key) const { return k2i_.at(key); } 66 | size_t operator[](const Key& key) const { return lookup(key); } 67 | 68 | const Key& reverse_lookup(size_t i) const { return i2k_.at(i); } 69 | const Key& operator()(size_t i) const { return reverse_lookup(i); } 70 | }; 71 | 72 | } // namespace koan 73 | 74 | #endif 75 | -------------------------------------------------------------------------------- /koan/reader.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 
6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef KOAN_READER_H 18 | #define KOAN_READER_H 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "def.h" 29 | #include "indexmap.h" 30 | #include "util.h" 31 | 32 | #ifdef KOAN_ENABLE_ZIP 33 | #include "zlib.h" 34 | #endif 35 | 36 | namespace koan { 37 | 38 | /// Abstraction over type of file to train on. 39 | class TrainFileHandler { 40 | protected: 41 | const std::string& fname_; 42 | 43 | public: 44 | TrainFileHandler(const std::string& fname) : fname_(fname) {} 45 | 46 | virtual char* gets(char* buf, int len) = 0; 47 | virtual void close() = 0; 48 | 49 | virtual ~TrainFileHandler() = default; 50 | }; 51 | 52 | /// Reads plain text files 53 | class TextFileHandler : public TrainFileHandler { 54 | private: 55 | FILE* f; 56 | 57 | public: 58 | TextFileHandler(const std::string& fname) : TrainFileHandler(fname) { 59 | f = fopen(fname.c_str(), "r"); 60 | KOAN_ASSERT(f != nullptr, 61 | "Could not open input file '" + fname + 62 | "' -- make sure it exists."); 63 | } 64 | 65 | char* gets(char* buf, int len) override { return fgets(buf, len, f); } 66 | 67 | void close() override { fclose(f); } 68 | }; 69 | 70 | #ifdef KOAN_ENABLE_ZIP 71 | /// Reads gzipped files 72 | class GzipFileHandler : public TrainFileHandler { 73 | private: 74 | gzFile f; 75 | 76 | public: 77 | GzipFileHandler(const std::string& fname) : TrainFileHandler(fname) { 78 | f = gzopen(fname.c_str(), "r"); 79 | 80 | KOAN_ASSERT(f != nullptr, 81 | "Could not open 
input file '" + fname + 82 | "' -- make sure it exists."); 83 | } 84 | 85 | char* gets(char* buf, int len) override { return gzgets(f, buf, len); } 86 | 87 | void close() override { gzclose(f); } 88 | }; 89 | #endif 90 | 91 | std::unique_ptr getfilehandler(const std::string& fname, 92 | const std::string& read_mode) { 93 | 94 | #ifdef KOAN_ENABLE_ZIP 95 | bool is_ext_gzip = 96 | fname.size() >= 3 and fname.compare(fname.size() - 3, 3, ".gz") == 0; 97 | 98 | if (read_mode == "gzip" or (is_ext_gzip && read_mode == "auto")) { 99 | return std::make_unique(fname); 100 | } 101 | #endif 102 | 103 | return std::make_unique(fname); 104 | } 105 | 106 | /// Read lines from a training file and process each using function f. Each 107 | /// separate sequence (e.g., sentence/paragraph) should be separated by a 108 | /// newline. 109 | /// 110 | /// @param[in] fname path to dataset to read 111 | /// @param[in] f function to process each line of input file 112 | /// @param[in] read_mode how to read from each file. Respected if compiled with 113 | /// KOAN_ENABLE_ZIP, otherwise assumes all are plain text files. 114 | /// @tparam: F is a callable on const(std::string_view&). 115 | template 116 | void readlines(const std::vector& fnames, 117 | F f, 118 | std::string read_mode, 119 | bool assert_no_long_lines = false) { 120 | for (const std::string& fname : fnames) { 121 | auto fhandler = getfilehandler(fname, read_mode); 122 | 123 | std::unique_ptr line_c_str(new char[MAX_LINE_LEN]()); 124 | while (fhandler->gets(line_c_str.get(), MAX_LINE_LEN) != nullptr) { 125 | auto line = std::string_view(line_c_str.get()); 126 | 127 | if (assert_no_long_lines) { 128 | KOAN_ASSERT(line.back() == '\n', 129 | "No end-of-line char! 
A line in input " 130 | "data might be too long in file '" + 131 | fname + "'"); 132 | } 133 | 134 | line.remove_suffix(1); // remove \n 135 | f(line); 136 | } 137 | 138 | fhandler->close(); 139 | } 140 | } 141 | 142 | template 143 | void readlines(const std::string& fname, 144 | F f, 145 | std::string read_mode, 146 | bool assert_no_long_lines) { 147 | const std::vector fname_vec{fname}; 148 | readlines(fname_vec, f, read_mode, assert_no_long_lines); 149 | } 150 | 151 | /// Abstract class for reading from a pre-tokenized file. 152 | class Reader { 153 | protected: 154 | bool discard_; // discard OOV words instead of replacing with UNK 155 | bool assert_no_long_lines_; // whether to throw on lines > MAX_LINE_LEN chars 156 | 157 | std::vector fnames_; 158 | std::string read_mode_; 159 | 160 | // buffers reused to avoid wasteful allocs 161 | std::vector words_; 162 | 163 | IndexMap& word_map_; 164 | 165 | /// Split a sequence into tokens by space. Handle out-of-vocabulary words 166 | /// based on the discard flag. 167 | /// 168 | /// @param[in] line string_view of a line in the input file. Corresponds to a 169 | /// single sequence. 170 | /// @returns a vector of token indices for this line 171 | Sentence parseline(const std::string_view& line) { 172 | Sentence s; 173 | 174 | words_.clear(); 175 | split(words_, line, ' '); 176 | 177 | s.reserve(words_.size()); 178 | for (size_t t = 0; t < words_.size(); t++) { 179 | const auto index = word_map_.find(words_[t]); 180 | 181 | if (index == word_map_.end()) { 182 | if (not discard_) { s.push_back(word_map_.lookup(UNK)); } 183 | } else { 184 | s.push_back(index->second); 185 | } 186 | } 187 | return s; 188 | } 189 | 190 | public: 191 | /// 192 | /// @param[in] word_map vocabulary 193 | /// @param[in] fname input file path 194 | /// @param[in] discard flag to toggle between discarding OOV words or 195 | /// replacing them with UNK 196 | /// @param[in] read_mode define behavior for reading from files. 
"text": 197 | /// treat all files as plain text; "gzip": treat all files as gzipped; "auto": 198 | /// treat *.gz as gzipped, otherwise plain text 199 | Reader(IndexMap& word_map, 200 | std::vector& fnames, 201 | bool discard, 202 | std::string read_mode, 203 | bool assert_no_long_lines = false) 204 | : discard_(discard), 205 | assert_no_long_lines_(assert_no_long_lines), 206 | fnames_(fnames), 207 | read_mode_(read_mode), 208 | word_map_(word_map) { 209 | words_.reserve(100); 210 | } 211 | virtual ~Reader() = default; 212 | 213 | virtual bool get_next(Sentences&) = 0; 214 | }; 215 | 216 | /// Reader used when one can store the entire training set in memory. 217 | class OnceReader : public Reader { 218 | private: 219 | bool read_ = false; 220 | bool fake_reached_eof_ = false; 221 | 222 | public: 223 | using Reader::Reader; 224 | 225 | /// Read everything once at the first call, otherwise do nothing as sentences 226 | /// are already populated. 227 | /// 228 | /// @param[in] s list of sentences to be populated 229 | /// @returns whether we actually read from the file (the first call) 230 | bool get_next(Sentences& s) override { 231 | if (not read_) { 232 | readlines( 233 | fnames_, 234 | [&](const std::string_view& line) { s.push_back(parseline(line)); }, 235 | read_mode_, 236 | assert_no_long_lines_); 237 | 238 | read_ = true; 239 | } 240 | fake_reached_eof_ = not fake_reached_eof_; 241 | return fake_reached_eof_; 242 | } 243 | }; 244 | 245 | /// A reader to be used when you cannot store the entire training set in memory. 246 | class AsyncReader : public Reader { 247 | private: 248 | size_t buffer_size_; 249 | 250 | std::unique_ptr in_; // handler of current file, track where 251 | // we left off 252 | size_t path_idx_ = 0; // index into which file we are reading from 253 | std::unique_ptr line_c_str_ = nullptr; 254 | Sentences read_buffer_; 255 | 256 | std::unique_ptr reader_; 257 | bool reached_eof_ = false; // reached EOF in current call to get_next(). 
258 | bool reached_eofs_ = false; // reached EOF for the last file in current call 259 | // to get_next(). 260 | bool reached_eofs_prev_ = false; // reached EOF in previous call to get_next() 261 | // it needs to return false to reset the 262 | // loop, similar to 263 | // std::getline(ifstream, line). 264 | 265 | public: 266 | /// 267 | /// @param[in] word_map vocabulary 268 | /// @param[in] fname input file path 269 | /// @param[in] buffer_size number of lines to read into memory at once 270 | /// @param[in] discard flag to toggle between discarding OOV words or 271 | /// replacing them with UNK 272 | AsyncReader(IndexMap& word_map, 273 | std::vector& fnames, 274 | size_t buffer_size, 275 | bool discard, 276 | const std::string& read_mode, 277 | bool assert_no_long_lines) 278 | : Reader(word_map, fnames, discard, read_mode, assert_no_long_lines), 279 | buffer_size_(buffer_size), 280 | path_idx_(0) { 281 | 282 | in_ = getfilehandler(fnames_[path_idx_], read_mode_); 283 | line_c_str_ = std::unique_ptr(new char[MAX_LINE_LEN]()); 284 | start_reader(); 285 | } 286 | 287 | ~AsyncReader() { 288 | join_reader(); 289 | in_->close(); 290 | } 291 | 292 | /// Initialize reader by populating the line buffer. 
293 | void start_reader() { 294 | read_buffer_.clear(); 295 | read_buffer_.reserve(buffer_size_); 296 | reached_eofs_ = false; 297 | 298 | reader_ = std::make_unique([this]() { 299 | while (read_buffer_.size() < buffer_size_) { 300 | reached_eof_ = in_->gets(line_c_str_.get(), MAX_LINE_LEN) == nullptr; 301 | if (reached_eof_) { 302 | // Reset file ptr to beginning of next file 303 | in_->close(); 304 | path_idx_ = (path_idx_ + 1) % fnames_.size(); 305 | 306 | if (path_idx_ == 0) { reached_eofs_ = true; } 307 | 308 | in_ = getfilehandler(fnames_[path_idx_], read_mode_); 309 | break; 310 | } 311 | 312 | Sentence s = parseline(line_c_str_.get()); 313 | read_buffer_.push_back(std::move(s)); 314 | } 315 | }); 316 | } 317 | 318 | void join_reader() { reader_->join(); } 319 | 320 | bool get_next(Sentences& s) override { 321 | // We want to return false when we cannot read at *current* invocation, 322 | // which means we reached EOF in previous invocation. reached_eof_prev_ 323 | // keeps track of that. 324 | if (reached_eofs_prev_) { 325 | reached_eofs_prev_ = false; 326 | return false; 327 | } 328 | 329 | join_reader(); 330 | 331 | reached_eofs_prev_ = reached_eofs_; 332 | s = std::move(read_buffer_); 333 | read_buffer_ = Sentences(); 334 | 335 | // While returning the batch of sentences, also immediately start reading 336 | // the next batch (read_buffer_) in the background 337 | start_reader(); 338 | 339 | return true; 340 | } 341 | }; 342 | 343 | } // namespace koan 344 | 345 | #endif 346 | -------------------------------------------------------------------------------- /koan/sample.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 
6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef KOAN_SAMPLE_H 18 | #define KOAN_SAMPLE_H 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | #include "def.h" 25 | #include "util.h" 26 | 27 | namespace koan { 28 | 29 | /// Algorithm to sample from a fixed categorical distribution in constant time. 30 | /// Implements Vose's Alias Method as described in: 31 | /// https://www.keithschwarz.com/darts-dice-coins/ 32 | class AliasSampler { 33 | public: 34 | using Index = size_t; 35 | 36 | private: 37 | std::vector alias_; // alias class for each bucket 38 | std::vector prob_; // threshold for selecting the alias class 39 | std::uniform_int_distribution macro_dist_; 40 | std::uniform_real_distribution micro_dist_; 41 | std::minstd_rand rng_; 42 | size_t n_; 43 | 44 | /// Initialize alias table. 
Steps correspond to those listed in 45 | /// "Algorithm: Vose's Alias Method" of 46 | /// https://www.keithschwarz.com/darts-dice-coins/ 47 | /// 48 | /// @param[in] probs multinomial distribution to represent 49 | void init_alias_table(const std::vector& probs) { 50 | // Ensure this is a valid probability distribution 51 | KOAN_ASSERT(std::all_of( 52 | probs.begin(), probs.end(), [](Real p) { return p >= 0.0; })); 53 | Real probSum = std::accumulate(probs.begin(), probs.end(), 0.0); 54 | KOAN_ASSERT((0.9999 <= probSum) and (probSum <= 1.0001)); 55 | 56 | // Step 2 57 | std::vector small, large; 58 | 59 | // Step 3 60 | std::vector scaledProbs = probs; 61 | for (size_t i = 0; i < scaledProbs.size(); ++i) { scaledProbs[i] *= n_; } 62 | 63 | // Step 4 64 | for (size_t i = 0; i < scaledProbs.size(); ++i) { 65 | Real p_i = scaledProbs[i]; 66 | 67 | if (p_i < 1.0) { 68 | small.push_back(i); 69 | } else { 70 | large.push_back(i); 71 | } 72 | } 73 | 74 | // Step 5 75 | Index l, g; 76 | 77 | while (not(small.empty() or large.empty())) { 78 | l = small.back(); 79 | g = large.back(); 80 | small.pop_back(); 81 | large.pop_back(); 82 | 83 | prob_[l] = scaledProbs[l]; 84 | alias_[l] = g; 85 | scaledProbs[g] = (scaledProbs[g] + scaledProbs[l]) - 1; 86 | if (scaledProbs[g] < 1.0) { 87 | small.push_back(g); 88 | } else { 89 | large.push_back(g); 90 | } 91 | } 92 | 93 | // Step 6 94 | while (not large.empty()) { 95 | g = large.front(); 96 | large.erase(large.begin()); 97 | prob_[g] = 1.0; 98 | } 99 | 100 | // Step 7 101 | while (not small.empty()) { 102 | l = small.front(); 103 | small.erase(small.begin()); 104 | prob_[l] = 1.0; 105 | } 106 | } 107 | 108 | public: 109 | AliasSampler(const std::vector& probs) 110 | : alias_(probs.size(), 0), 111 | prob_(probs.size(), 0.0), 112 | macro_dist_(1, probs.size()), 113 | micro_dist_(0.0, 1.0), 114 | rng_(), 115 | n_(probs.size()) { 116 | init_alias_table(probs); 117 | } 118 | 119 | void set_seed(unsigned seed) { rng_.seed(seed); } 120 | 
121 | Index sample() { 122 | Index bucket = macro_dist_(rng_) - 1; 123 | Real r = micro_dist_(rng_); 124 | if (r <= prob_[bucket]) { 125 | return bucket; 126 | } else { 127 | return alias_[bucket]; 128 | } 129 | } 130 | 131 | size_t num_classes() { return n_; } 132 | }; 133 | 134 | } // namespace koan 135 | 136 | #endif 137 | -------------------------------------------------------------------------------- /koan/sigmoid.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 15 | */ 16 | 17 | #ifndef KOAN_SIGMOID_H 18 | #define KOAN_SIGMOID_H 19 | 20 | #include 21 | #include 22 | #include "def.h" 23 | 24 | namespace koan { 25 | 26 | /// Sigmoid. Defaults to table lookup implementation unless checking gradient 27 | /// numerically. 28 | /// 29 | /// @param[in] x logit 30 | /// @returns $\sigma(x)$ 31 | Real sigmoid(Real x) { 32 | // Based on sigmoid(x) == tanh(x/2)/2 + 1/2 33 | // std::tanh can handle extremes correctly out-of-the-box, i.e. 34 | // tanh(-Inf) = -1 and tanh(Inf) = 1 instead of Inf or NaN. 
35 | #ifdef KOAN_GRAD_CHECKING 36 | return std::fma(std::tanh(x * .5_R), .5_R, .5_R); 37 | #else 38 | static constexpr Real factor = 64_R, window = 8_R; 39 | static const auto table = [&]() { 40 | std::array ret; 41 | std::generate(ret.begin(), ret.end(), [i = -factor * window]() mutable { 42 | return std::fma(std::tanh(i++ / factor * .5_R), .5_R, .5_R); 43 | }); 44 | ret.front() = 0_R; 45 | ret.back() = 1_R; 46 | return ret; 47 | }(); 48 | static constexpr Real lo = -window, hi = window; 49 | static constexpr Real m = factor, a = factor * window; 50 | return table[size_t(std::fma(std::clamp(x, lo, hi), m, a))]; 51 | #endif 52 | }; 53 | 54 | } // namespace koan 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /koan/timer.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 3 | ** 4 | ** Licensed under the Apache License, Version 2.0 (the "License"); 5 | ** you may not use this file except in compliance with the License. 6 | ** You may obtain a copy of the License at 7 | ** 8 | ** http://www.apache.org/licenses/LICENSE-2.0 9 | ** 10 | ** Unless required by applicable law or agreed to in writing, software 11 | ** distributed under the License is distributed on an "AS IS" BASIS, 12 | ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | ** See the License for the specific language governing permissions and 14 | ** limitations under the License. 
/*
** Copyright 2020 Bloomberg Finance L.P.
**
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**
**     http://www.apache.org/licenses/LICENSE-2.0
**
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
*/

#ifndef KOAN_TIMER_H
#define KOAN_TIMER_H

#include <chrono>

namespace koan {

/// Tiny stopwatch: records its construction instant and reports the elapsed
/// wall-clock time in seconds on demand. Backed by steady_clock, so it is
/// immune to system clock adjustments.
class Timer {
private:
  using Clock = std::chrono::steady_clock;
  Clock::time_point start_ = Clock::now();

public:
  /// @returns seconds elapsed since this Timer was constructed
  long double s() const {
    const auto elapsed = Clock::now() - start_;
    return std::chrono::duration_cast<std::chrono::duration<long double>>(
               elapsed)
        .count();
  }
};

} // namespace koan

#endif
] 38 | unsigned ctxs = 5; 39 | 40 | // Number of negative targets for each positive target for negative 41 | // sampling 42 | unsigned negatives = 5; 43 | 44 | // Number of worker threads. This is mainly used to initialize necessary 45 | // data structures to avoid race conditions. Multithreading itself is 46 | // done outside of the Trainer class. 47 | unsigned threads = 8; 48 | 49 | bool use_bad_update = false; 50 | }; 51 | 52 | private: 53 | // Members 54 | Params params_; 55 | // Defines the probability of skipping each word, to downsample highly 56 | // frequent words 57 | std::vector filter_probs_; 58 | std::vector scratch_; // one per thread 59 | std::vector scratch2_; // one per thread 60 | std::vector gens_; // one per thread 61 | std::vector> dists_; // one per thread 62 | std::vector neg_samplers_; // one per thread 63 | 64 | Table& table_; // Input word embeddings (syn1) 65 | Table& ctx_; // Output word embeddings (syn0) 66 | 67 | public: 68 | /// Create trainer 69 | /// 70 | /// @param[in] params training parameters 71 | /// @param[in] table initial input word embeddings (syn1) 72 | /// @param[in] ctx initial output word embeddings (syn0) 73 | /// @param[in] filter_probs probability of skipping each word, downsampling 74 | /// frequent words 75 | /// @param[in] neg_probs negative sampling probability over vocabulary 76 | Trainer(Params params, 77 | Table& table, 78 | Table& ctx, 79 | std::vector filter_probs, 80 | const std::vector& neg_probs) 81 | : params_(params), 82 | filter_probs_(std::move(filter_probs)), 83 | scratch_(params_.threads), 84 | scratch2_(params_.threads), 85 | neg_samplers_(params_.threads, neg_probs), 86 | table_(table), 87 | ctx_(ctx) { 88 | for (unsigned i = 0; i < params_.threads; i++) { 89 | gens_.emplace_back(123457 + i); 90 | dists_.emplace_back(0., 1.); 91 | } 92 | } 93 | 94 | // Operations 95 | 96 | /// Update embeddings for a single input sentence, center word, and context 97 | /// window according to Continuous bag of 
words (CBOW) objective by negative 98 | /// sampling. 99 | /// 100 | /// @param[in] sent input sentence 101 | /// @param[in] center_idx index of the center word 102 | /// @param[in] left index of the leftmost context word (inclusive) 103 | /// @param[in] right index of the rightmost context word (exclusive) 104 | /// @param[in] tid thread index 105 | /// @param[in] lr current learning rate 106 | /// @param[in] compute_loss whether to also compute and return the CBOW loss. 107 | /// Used for numerically checking gradient. If false, will return 0.0 108 | Real cbow_update(const Sentence& sent, 109 | size_t center_idx, 110 | size_t left, 111 | size_t right, 112 | size_t tid, 113 | Real lr, 114 | bool compute_loss = false) { 115 | // ISSUE: Neither Mikolov's word2vec nor gensim seems to use the correct 116 | // gradient which requires normalization by the number of contexts (see 117 | // below). 118 | // 119 | // https://github.com/tmikolov/word2vec/blob/20c129af10659f7c50e86e3be406df663beff438/word2vec.c#L460 120 | // https://github.com/RaRe-Technologies/gensim/issues/697 121 | Real loss = 0; 122 | auto& center_word = ctx_[sent[center_idx]]; 123 | Vector& avg = scratch_[tid]; 124 | Vector& source_idx_grad = scratch2_[tid]; 125 | avg = Vector::Zero(center_word.size()); 126 | source_idx_grad = Vector::Zero(center_word.size()); 127 | 128 | // collect embeddings for context words 129 | static thread_local std::vector sources; 130 | sources.clear(); 131 | sources.reserve(right - left - 1); 132 | 133 | for (size_t source_idx = left; source_idx < right; source_idx++) { 134 | if (source_idx != center_idx) { 135 | auto& v = table_[sent[source_idx]]; 136 | avg += v; 137 | sources.push_back(&v); 138 | } 139 | } 140 | 141 | Real num_source_ids = static_cast(sources.size()); 142 | if (num_source_ids > 0.) 
{ 143 | avg /= num_source_ids; 144 | 145 | // Update for positive sample 146 | // forward pass 147 | Real sig_pos = sigmoid(avg.dot(center_word)); 148 | if (compute_loss) { 149 | loss -= std::log(std::max(sig_pos, MIN_SIGMOID_IN_LOSS)); 150 | } 151 | // backward pass 152 | if (sig_pos < 1.) { 153 | if (params_.use_bad_update) { 154 | // ISSUE above, typical, wrong update! 155 | source_idx_grad += center_word * ((sig_pos - 1.) * lr); 156 | } else { 157 | // ISSUE above, must normalize by number of 158 | // context words when updating context embeddings 159 | source_idx_grad += 160 | center_word * ((sig_pos - 1.) * lr) / num_source_ids; 161 | } 162 | center_word -= avg * ((sig_pos - 1.) * lr); 163 | } 164 | 165 | // Updates for negative samples 166 | for (unsigned i = 0; i < params_.negatives; i++) { 167 | Word random_idx = neg_samplers_[tid].sample(); 168 | if (random_idx == center_idx) { continue; } 169 | auto& rw = ctx_[random_idx]; // random word 170 | // forward 171 | Real sig_neg = sigmoid(avg.dot(rw)); 172 | if (compute_loss) { 173 | loss -= std::log(std::max(1._R - sig_neg, MIN_SIGMOID_IN_LOSS)); 174 | } 175 | // backward 176 | if (sig_neg > 0.) { 177 | if (params_.use_bad_update) { 178 | // ISSUE above, typical, wrong update! 179 | source_idx_grad += rw * (sig_neg * lr); 180 | } else { 181 | // ISSUE above 182 | source_idx_grad += rw * (sig_neg * lr) / num_source_ids; 183 | } 184 | rw -= avg * (sig_neg * lr); 185 | } 186 | } 187 | for (auto source : sources) { // update each source (context) 188 | *source -= source_idx_grad; 189 | } 190 | } 191 | 192 | return loss; 193 | } 194 | 195 | /// Update embeddings for a single input sentence, center word, and context 196 | /// window according to Skipgram (SG) objective by negative sampling. 
197 | /// 198 | /// @param[in] sent input sentence 199 | /// @param[in] center_idx index of the source center word 200 | /// @param[in] left index of the leftmost context word to predict (inclusive) 201 | /// @param[in] right index of the rightmost context word to predict 202 | /// (exclusive) 203 | /// @param[in] tid thread index 204 | /// @param[in] lr current learning rate 205 | /// @param[in] compute_loss whether to also compute and return the SG loss. 206 | /// Used for numerically checking the gradient. If false, will return 0.0 207 | Real sg_update(const Sentence& sent, 208 | size_t center_idx, 209 | size_t left, 210 | size_t right, 211 | size_t tid, 212 | Real lr, 213 | bool compute_loss = false) { 214 | Real loss = 0; 215 | auto& center_word = table_.at(sent[center_idx]); 216 | auto& cw_local = scratch_[tid]; 217 | cw_local = Vector::Zero(center_word.size()); 218 | 219 | // Predict each context word given the center 220 | for (size_t target_idx = left; target_idx < right; target_idx++) { 221 | if (target_idx != center_idx) { 222 | auto& target_word = ctx_.at(sent[target_idx]); 223 | // Update for positive sample 224 | // forward pass 225 | Real sig_pos = sigmoid(center_word.dot(target_word)); 226 | if (compute_loss) { 227 | loss -= std::log(std::max(sig_pos, MIN_SIGMOID_IN_LOSS)); 228 | } 229 | // backward pass 230 | if (sig_pos < 1.) { 231 | cw_local -= target_word * ((sig_pos - 1.) * lr); 232 | target_word -= center_word * ((sig_pos - 1.) * lr); 233 | } 234 | 235 | // Update for negative samples 236 | for (unsigned i = 0; i < params_.negatives; i++) { 237 | Word random_i = neg_samplers_[tid].sample(); 238 | auto& random_word = ctx_.at(random_i); // random word 239 | // forward 240 | Real sig_neg = sigmoid(center_word.dot(random_word)); 241 | if (compute_loss) { 242 | loss -= std::log(std::max(1 - sig_neg, MIN_SIGMOID_IN_LOSS)); 243 | } 244 | // backward 245 | if (sig_neg > 0.) 
{ 246 | cw_local -= random_word * (sig_neg * lr); 247 | random_word -= center_word * (sig_neg * lr); 248 | } 249 | } 250 | } 251 | } 252 | // cw_local itself is a descent direction, so sign is += 253 | center_word += cw_local; 254 | return loss; 255 | } 256 | 257 | /// Update embeddings for an entire sentence: treat each word as the center in 258 | /// turn (modulo downsampled tokens), and sample variable context width. 259 | /// 260 | /// @param[in] sent_raw input sentence 261 | /// @param[in] tid thread index 262 | /// @param[in] lr learning rate for this instance 263 | /// @param[in] cbow true if using CBOW loss, else SG 264 | /// @returns number of tokens in the sentence after downsampling 265 | size_t train(const Sentence& sent_raw, size_t tid, Real lr, bool cbow) { 266 | static thread_local Sentence sent(INITIAL_SENTENCE_LEN); 267 | sent.clear(); 268 | sent.reserve(sent_raw.size()); 269 | for (auto& w : sent_raw) { // prob.at(w) is prob. to discard w 270 | if (dists_[tid](gens_[tid]) >= filter_probs_.at(w)) { sent.push_back(w); } 271 | } 272 | 273 | for (size_t center_idx = 0; center_idx < sent.size(); center_idx++) { 274 | // Sample a contexts width from 1 to maximum context width 275 | size_t ctxs = 1 + (gens_[tid]() % params_.ctxs); 276 | size_t left = center_idx > ctxs ? center_idx - ctxs : 0; 277 | size_t right = std::min(center_idx + ctxs + 1, sent.size()); 278 | 279 | if (cbow) { // cbow loss 280 | cbow_update(sent, center_idx, left, right, tid, lr); 281 | } else { // skipgram loss 282 | sg_update(sent, center_idx, left, right, tid, lr); 283 | } 284 | } 285 | 286 | return sent.size(); 287 | } 288 | }; 289 | 290 | } // namespace koan 291 | 292 | #endif 293 | -------------------------------------------------------------------------------- /koan/util.h: -------------------------------------------------------------------------------- 1 | /* 2 | ** Copyright 2020 Bloomberg Finance L.P. 
/*
** Copyright 2020 Bloomberg Finance L.P.
**
** Licensed under the Apache License, Version 2.0 (the "License");
** you may not use this file except in compliance with the License.
** You may obtain a copy of the License at
**
**     http://www.apache.org/licenses/LICENSE-2.0
**
** Unless required by applicable law or agreed to in writing, software
** distributed under the License is distributed on an "AS IS" BASIS,
** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
** See the License for the specific language governing permissions and
** limitations under the License.
*/

// NOTE(review): the original include targets were stripped during
// extraction; this list is reconstructed from the symbols used below and in
// the rest of util.h — verify against the repository.
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstring>
#include <ctime>
#include <stdexcept>
#include <string>
#include <string_view>
#include <thread>
#include <vector>

namespace koan {

/// Format the current local time.
///
/// @param[in] format strftime-style format string
/// @return formatted date/time, truncated to at most 50 characters
std::string date_time(const std::string& format) {
  std::string ret(50, char());
  std::time_t tt = std::time(nullptr);
  ret.resize(std::strftime(
      ret.data(), ret.size(), format.c_str(), std::localtime(&tt)));
  return ret;
}

/// Split `s` on `delim`, appending non-empty pieces to `ret` as views into
/// `s` (no copies). Consecutive delimiters produce no empty tokens.
///
/// @param[out] ret views into `s`, one per non-empty token
/// @param[in] s input sequence of contiguous characters
/// @param[in] delim delimiter character
template <typename IN>
void split(std::vector<std::string_view>& ret, const IN& s, char delim = ' ') {
  auto beg = s.begin();
  while (beg < s.end()) {
    auto end = std::find(beg, s.end(), delim);
    if (beg != end) { ret.emplace_back(&*beg, end - beg); }
    // BUGFIX: the original did `beg = ++end;`, incrementing the past-the-end
    // iterator (UB) when no delimiter remained. Stop explicitly instead.
    if (end == s.end()) { break; }
    beg = end + 1;
  }
}

/// Convenience overload of split() returning a fresh vector of views.
template <typename IN>
auto split(const IN& s, char delim = ' ') {
  std::vector<std::string_view> ret;
  split(ret, s, delim);
  return ret;
}

} // namespace koan
namespace koan {

/// Parallel for implementation without any explicit allocation of elements
/// per thread: workers atomically claim the next unprocessed index until the
/// range is drained (good load balancing for uneven work items).
///
/// @param[in] begin start index
/// @param[in] end end index (exclusive)
/// @param[in] f function to process each element
/// @param[in] num_threads number of threads to run
/// @tparam F callable that takes size_t elt_idx, size_t thread_idx
template <typename F>
void parallel_for(size_t begin, size_t end, F f, size_t num_threads = 8) {
  if (num_threads == 0) { num_threads = 1; } // always spawn at least one worker
  std::vector<std::thread> threads(num_threads);
  std::atomic<size_t> i{begin};
  for (size_t ti = 0; ti < num_threads; ti++) {
    // `end` is a cheap scalar: capture it by value (the original captured it
    // by reference, which only risks dangling if the callsite ever changes).
    threads[ti] = std::thread([ti, &i, &f, end]() {
      while (true) {
        size_t i_ = i++; // atomically claim the next element
        if (i_ >= end) { break; }
        f(i_, ti);
      }
    });
  }

  for (auto& t : threads) { t.join(); }
}

/// Parallel for implementation where each thread is allotted its own batch of
/// elements to process up front.
///
/// @param[in] begin start index
/// @param[in] end end index (exclusive)
/// @param[in] f function to process each element
/// @param[in] num_threads number of threads to run
/// @param[in] consecutive_alloc if true, allocate a contiguous block of
/// elements to each thread; otherwise allocate round-robin with stride
/// num_threads
/// @tparam F callable that takes size_t elt_idx, size_t thread_idx
template <typename F>
void parallel_for_partitioned(size_t begin,
                              size_t end,
                              F f,
                              size_t num_threads = 8,
                              bool consecutive_alloc = true) {
  if (num_threads == 0) { num_threads = 1; } // guard the division below
  size_t total_size = end - begin;
  size_t batch_size = total_size / num_threads;
  std::vector<std::thread> threads(num_threads);
  for (size_t ti = 0; ti < num_threads; ti++) {
    if (consecutive_alloc) {
      // Thread ti owns [begin + ti*batch, begin + (ti+1)*batch); the last
      // thread also absorbs the remainder so the whole range is covered.
      threads[ti] = std::thread([ti, &f, begin, end, batch_size, num_threads]() {
        size_t batch_start = begin + ti * batch_size;
        size_t batch_end =
            ti < (num_threads - 1) ? begin + (ti + 1) * batch_size : end;
        for (size_t i = batch_start; i < batch_end; ++i) { f(i, ti); }
      });
    } else {
      // Round-robin: thread ti processes begin+ti, begin+ti+T, ...
      threads[ti] = std::thread([ti, &f, begin, end, num_threads]() {
        for (size_t i = begin + ti; i < end; i += num_threads) { f(i, ti); }
      });
    }
  }

  for (auto& t : threads) { t.join(); }
}

/// Project-local exception type so callers can catch koan failures
/// specifically (thrown by KOAN_ASSERT below).
class RuntimeError : public std::runtime_error {
public:
  using runtime_error::runtime_error;
};

} // namespace koan

#define KOAN_OVERLOAD(_1, _2, MACRO, ...) MACRO

#define KOAN_ASSERT(...) \
  KOAN_OVERLOAD(__VA_ARGS__, KOAN_ASSERT2, KOAN_ASSERT1)(__VA_ARGS__)

#define KOAN_ASSERT2(statement, message) \
  if (!(statement)) { throw koan::RuntimeError(message); }

#define KOAN_ASSERT1(statement) \
  KOAN_ASSERT2(statement, "Assertion " #statement " failed!")
32 | for (size_t i = 0; i < word_map.size(); i++) { 33 | table.push_back(Vector::Random(dim)); 34 | ctx.push_back(Vector::Random(dim)); 35 | } 36 | 37 | Trainer t( 38 | Trainer::Params{.dim = dim, .ctxs = 5, .negatives = 1, .threads = 1}, 39 | table, 40 | ctx, 41 | filter_probs, 42 | neg_probs); 43 | 44 | // Keep a copy of original weights 45 | Table table_orig(table), ctx_orig(ctx); 46 | 47 | t.cbow_update(sent, 48 | /*center*/ 1, 49 | /*left*/ 0, 50 | /*right*/ 3, 51 | /*tid*/ 0, 52 | /*lr*/ 1, 53 | /*compute_loss*/ true); 54 | 55 | // analytic gradients 56 | Table table_agrad(table), ctx_agrad(ctx); 57 | for (size_t i = 0; i < word_map.size(); i++) { 58 | table_agrad[i] = table_orig[i] - table[i]; 59 | ctx_agrad[i] = ctx_orig[i] - ctx[i]; 60 | } 61 | 62 | // Compute numeric gradients for every parameter 63 | Table table_ngrad(table_orig), ctx_ngrad(ctx_orig); 64 | 65 | table = table_orig; 66 | ctx = ctx_orig; 67 | 68 | // Two-sided numerical gradient: 69 | // http://deeplearning.stanford.edu/tutorial/supervised/DebuggingGradientChecking/ 70 | for (auto tab : {&table, &ctx}) { 71 | for (size_t i = 0; i < word_map.size(); i++) { 72 | for (unsigned j = 0; j < dim; j++) { 73 | const static Real eps = 1e-4; 74 | Real tmp = tab->at(i)[j]; 75 | tab->at(i)[j] += eps; 76 | Real loss_up = t.cbow_update(sent, 77 | /*center*/ 1, 78 | /*left*/ 0, 79 | /*right*/ 3, 80 | /*tid*/ 0, 81 | /*lr*/ 1, 82 | /*compute_loss*/ true); 83 | table = table_orig; 84 | ctx = ctx_orig; 85 | 86 | tab->at(i)[j] = tmp - eps; 87 | Real loss_down = t.cbow_update(sent, 88 | /*center*/ 1, 89 | /*left*/ 0, 90 | /*right*/ 3, 91 | /*tid*/ 0, 92 | /*lr*/ 1, 93 | /*compute_loss*/ true); 94 | table = table_orig; 95 | ctx = ctx_orig; 96 | 97 | Real num_grad = (loss_up - loss_down) / (2 * eps); 98 | if (tab == &table) { 99 | table_ngrad[i][j] = num_grad; 100 | } else { 101 | ctx_ngrad[i][j] = num_grad; 102 | } 103 | } 104 | } 105 | } 106 | 107 | // compare numeric and analytical gradients 108 | for (size_t 
i = 0; i < word_map.size(); i++) { 109 | for (unsigned j = 0; j < dim; j++) { 110 | CHECK(table_agrad[i][j] == Approx(table_ngrad[i][j])); 111 | CHECK(ctx_agrad[i][j] == Approx(ctx_ngrad[i][j])); 112 | } 113 | } 114 | } 115 | 116 | TEST_CASE("Skipgram", "[grad]") { 117 | static_assert(std::is_same::value); 118 | 119 | Table table, ctx; 120 | unsigned dim = 5; 121 | 122 | IndexMap word_map; 123 | word_map.insert("hello"); 124 | word_map.insert("world"); 125 | word_map.insert("!"); 126 | word_map.insert("."); 127 | 128 | // Prevent trainer from randomly dropping any word. 129 | std::vector filter_probs{0, 0, 0, 0}; 130 | 131 | // Force trainer to sample "." as the negative word. 132 | std::vector neg_probs{0, 0, 0, 1}; 133 | 134 | Sentence sent{0, 1}; // hello world 135 | 136 | // Randomly initialize center and context word embeddings. 137 | for (size_t i = 0; i < word_map.size(); i++) { 138 | table.push_back(Vector::Random(dim)); 139 | ctx.push_back(Vector::Random(dim)); 140 | } 141 | 142 | Trainer t( 143 | Trainer::Params{.dim = dim, .ctxs = 5, .negatives = 1, .threads = 1}, 144 | table, 145 | ctx, 146 | filter_probs, 147 | neg_probs); 148 | 149 | // Keep a copy of original weights 150 | Table table_orig(table), ctx_orig(ctx); 151 | 152 | t.sg_update(sent, 153 | /*center*/ 1, 154 | /*left*/ 0, 155 | /*right*/ 2, 156 | /*tid*/ 0, 157 | /*lr*/ 1, 158 | /*compute_loss*/ true); 159 | 160 | // analytic gradients 161 | Table table_agrad(table), ctx_agrad(ctx); 162 | for (size_t i = 0; i < word_map.size(); i++) { 163 | table_agrad[i] = table_orig[i] - table[i]; 164 | ctx_agrad[i] = ctx_orig[i] - ctx[i]; 165 | } 166 | 167 | // Compute numeric gradients for every parameter 168 | Table table_ngrad(table_orig), ctx_ngrad(ctx_orig); 169 | 170 | table = table_orig; 171 | ctx = ctx_orig; 172 | 173 | // Two-sided numerical gradient: 174 | // http://deeplearning.stanford.edu/tutorial/supervised/DebuggingGradientChecking/ 175 | for (auto tab : {&table, &ctx}) { 176 | for (size_t i 
= 0; i < word_map.size(); i++) { 177 | for (unsigned j = 0; j < dim; j++) { 178 | const static Real eps = 1e-4; 179 | Real tmp = tab->at(i)[j]; 180 | tab->at(i)[j] += eps; 181 | Real loss_up = t.sg_update(sent, 182 | /*center*/ 1, 183 | /*left*/ 0, 184 | /*right*/ 2, 185 | /*tid*/ 0, 186 | /*lr*/ 1, 187 | /*compute_loss*/ true); 188 | table = table_orig; 189 | ctx = ctx_orig; 190 | 191 | tab->at(i)[j] = tmp - eps; 192 | Real loss_down = t.sg_update(sent, 193 | /*center*/ 1, 194 | /*left*/ 0, 195 | /*right*/ 2, 196 | /*tid*/ 0, 197 | /*lr*/ 1, 198 | /*compute_loss*/ true); 199 | table = table_orig; 200 | ctx = ctx_orig; 201 | 202 | Real num_grad = (loss_up - loss_down) / (2 * eps); 203 | if (tab == &table) { 204 | table_ngrad[i][j] = num_grad; 205 | } else { 206 | ctx_ngrad[i][j] = num_grad; 207 | } 208 | } 209 | } 210 | } 211 | 212 | // compare numeric and analytical gradients 213 | for (size_t i = 0; i < word_map.size(); i++) { 214 | for (unsigned j = 0; j < dim; j++) { 215 | CHECK(table_agrad[i][j] == Approx(table_ngrad[i][j])); 216 | CHECK(ctx_agrad[i][j] == Approx(ctx_ngrad[i][j])); 217 | } 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /tests/test_utils.cpp: -------------------------------------------------------------------------------- 1 | #define CATCH_CONFIG_MAIN // so that Catch is responsible for main() 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace koan; 14 | 15 | /// Draw n samples from sampler and record empirical distribution over classes. 
16 | /// 17 | /// @param[in] sampler Already-initialized alias sampler 18 | /// @param[in] n Number of samples to draw 19 | /// @return empirical distribution for sampler 20 | std::vector sample_dist(AliasSampler sampler, size_t n = 10000000) { 21 | std::vector dist(sampler.num_classes(), 0.0); 22 | 23 | for (size_t i = 0; i < n; ++i) { ++dist[sampler.sample()]; } 24 | 25 | for (size_t i = 0; i < dist.size(); ++i) { dist[i] /= n; } 26 | 27 | return dist; 28 | } 29 | 30 | /// Test whether the probability of selecting any class from multinomial d2 is 31 | /// within 1% of the probability of selecting it under d1. 32 | /// 33 | /// @param[in] d1 First distribution 34 | /// @param[in] d2 Other distribution 35 | /// @return whether distributions are sufficiently close 36 | bool dists_are_close(const std::vector& d1, const std::vector& d2) { 37 | REQUIRE(d1.size() == d2.size()); 38 | 39 | for (size_t i = 0; i < d1.size(); ++i) { 40 | if (std::abs(d1[i] - d2[i]) >= (d1[i] * 0.01)) { return false; } 41 | } 42 | 43 | return true; 44 | } 45 | 46 | TEST_CASE("AliasSampler", "[sample]") { 47 | std::vector probs1(2, 0.5); 48 | AliasSampler sampler1(probs1); 49 | 50 | /// Make sure alias sampler faithfully represents multinomial distributions 51 | /// where all classes are equally probable. 52 | SECTION("Balanced binary distribution") { 53 | CHECK(dists_are_close(probs1, sample_dist(sampler1))); 54 | } 55 | 56 | std::vector probs2(10, 0.1); 57 | AliasSampler sampler2(probs2); 58 | 59 | SECTION("Balanced 10-class") { 60 | CHECK(dists_are_close(probs2, sample_dist(sampler2))); 61 | } 62 | 63 | std::vector probs3(50, 0.02); 64 | AliasSampler sampler3(probs3); 65 | 66 | SECTION("Balanced 50-class") { 67 | CHECK(dists_are_close(probs3, sample_dist(sampler3))); 68 | } 69 | 70 | /// Make sure alias sampler faithfully represents multinomial distributions 71 | /// where some classes are much more probable than others. 
72 | std::vector probs4{0.1, 0.9}; 73 | AliasSampler sampler4(probs4); 74 | 75 | SECTION("Unbalanced binary") { 76 | CHECK(dists_are_close(probs4, sample_dist(sampler4))); 77 | } 78 | 79 | std::vector probs5{ 80 | 0.02, 0.02, 0.02, 0.02, 0.02, 0.1, 0.2, 0.2, 0.2, 0.2}; 81 | AliasSampler sampler5(probs5); 82 | 83 | SECTION("Unbalanced 10-class") { 84 | CHECK(dists_are_close(probs5, sample_dist(sampler5))); 85 | } 86 | } 87 | 88 | TEST_CASE("IndexMap", "[indexmap]") { 89 | IndexMap imap; 90 | 91 | imap.insert("hello"); 92 | imap.insert("world"); 93 | 94 | CHECK(imap.size() == 2); 95 | CHECK(imap.has("hello")); 96 | CHECK(imap.has("world")); 97 | CHECK(not imap.has("!")); 98 | 99 | CHECK(imap.lookup("hello") == 0); 100 | CHECK(imap.lookup("world") == 1); 101 | CHECK(imap.reverse_lookup(0) == "hello"); 102 | CHECK(imap.reverse_lookup(1) == "world"); 103 | 104 | CHECK_THROWS(imap.lookup("!")); 105 | CHECK_THROWS(imap.reverse_lookup(2)); 106 | 107 | SECTION("Insert new") { 108 | imap.insert("!"); 109 | 110 | CHECK(imap.size() == 3); 111 | CHECK(imap.has("!")); 112 | CHECK(imap.lookup("!") == 2); 113 | CHECK(imap.reverse_lookup(2) == "!"); 114 | } 115 | 116 | SECTION("Insert dupe") { 117 | imap.insert("hello"); 118 | 119 | CHECK(imap.size() == 2); 120 | CHECK(imap.has("hello")); 121 | CHECK(imap.has("world")); 122 | CHECK(imap.lookup("hello") == 0); 123 | CHECK(imap.lookup("world") == 1); 124 | CHECK(imap.reverse_lookup(0) == "hello"); 125 | CHECK(imap.reverse_lookup(1) == "world"); 126 | } 127 | 128 | SECTION("Clear") { 129 | imap.clear(); 130 | 131 | CHECK(imap.size() == 0); 132 | CHECK(not imap.has("hello")); 133 | CHECK(not imap.has("world")); 134 | CHECK_THROWS(imap.lookup("hello")); 135 | CHECK_THROWS(imap.lookup("world")); 136 | CHECK_THROWS(imap.reverse_lookup(0)); 137 | CHECK_THROWS(imap.reverse_lookup(1)); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /word2vec_train_times_cbow.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloomberg/koan/c22fccdd6f359b5e7f4889b9969486ed27f76894/word2vec_train_times_cbow.png -------------------------------------------------------------------------------- /word2vec_train_times_sg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloomberg/koan/c22fccdd6f359b5e7f4889b9969486ed27f76894/word2vec_train_times_sg.png --------------------------------------------------------------------------------