├── .github └── workflows │ ├── codeql-analysis.yml │ └── maven.yml ├── .gitignore ├── LICENSE ├── README.md ├── pom.xml ├── src └── de │ └── jungblut │ └── glove │ ├── GloveRandomAccessReader.java │ ├── GloveStreamReader.java │ ├── GloveWriter.java │ ├── examples │ ├── NearestNeighbourMain.java │ ├── TextToBinaryConverterMain.java │ └── VectorLookupMain.java │ ├── impl │ ├── CachedGloveBinaryRandomAccessReader.java │ ├── GloveBinaryRandomAccessReader.java │ ├── GloveBinaryReader.java │ ├── GloveBinaryWriter.java │ └── GloveTextReader.java │ └── util │ ├── IOUtils.java │ ├── StringVectorPair.java │ └── WritableUtils.java └── test └── de └── jungblut └── glove ├── GloveTestUtils.java └── impl ├── GloveBinaryRandomAccessReaderTest.java ├── GloveBinaryReaderWriterTest.java └── GloveTextReaderTest.java /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | name: "CodeQL" 7 | 8 | on: 9 | push: 10 | branches: [master] 11 | pull_request: 12 | # The branches below must be a subset of the branches above 13 | branches: [master] 14 | schedule: 15 | - cron: '0 4 * * 1' 16 | 17 | jobs: 18 | analyze: 19 | name: Analyze 20 | runs-on: ubuntu-latest 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | # Override automatic language detection by changing the below list 26 | # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] 27 | language: ['java'] 28 | # Learn more... 
29 | # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection 30 | 31 | steps: 32 | - name: Checkout repository 33 | uses: actions/checkout@v2 34 | with: 35 | # We must fetch at least the immediate parents so that if this is 36 | # a pull request then we can checkout the head. 37 | fetch-depth: 2 38 | 39 | # If this run was triggered by a pull request event, then checkout 40 | # the head of the pull request instead of the merge commit. 41 | - run: git checkout HEAD^2 42 | if: ${{ github.event_name == 'pull_request' }} 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v1 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 51 | # Prefix the list here with "+" to use these queries and those in the config file. 52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v1 58 | 59 | # ℹ️ Command-line programs to run using the OS shell. 
60 | # 📚 https://git.io/JvXDl 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v1 72 | -------------------------------------------------------------------------------- /.github/workflows/maven.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: Java CI with Maven 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up JDK 1.8 20 | uses: actions/setup-java@v1 21 | with: 22 | java-version: 1.8 23 | - name: Build with Maven 24 | run: mvn clean package install -Dgpg.skip=true 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | 3 | # Mobile Tools for Java (J2ME) 4 | .mtj.tmp/ 5 | 6 | # Package Files # 7 | *.jar 8 | *.war 9 | *.ear 10 | 11 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 12 | hs_err_pid* 13 | 14 | /bin 15 | /.settings 16 | /.idea 17 | /.project 18 | /.classpath 19 | /target 20 | /lib 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, 
REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This project is a convenience Java wrapper around GloVe word vectors and a converter to more space-efficient binary files. It also includes random-access lookup for very large numbers of vectors on disk. 2 | 3 | Maven 4 | ----- 5 | 6 | If you use Maven, you can get the latest release using the following dependency: 7 | 8 | ```xml 9 | <dependency> 10 | <groupId>de.jungblut.glove</groupId> 11 | <artifactId>glove</artifactId> 12 | <version>0.3</version> 13 | </dependency> 14 | ``` 15 | 16 | To use the library from the command line (with a fat jar), build it using the instructions at the bottom. 17 | If you'd like to use it directly from code, the examples below show which classes to look into. 18 | 19 | Converting the text files to binary 20 | ----------------------------------- 21 | 22 | To get the most out of the library and save disk space, you should convert the text file to binary first.
23 | This can be done by supplying the text file and an output folder: 24 | 25 | > java -cp glove-0.3-jar-with-dependencies.jar de.jungblut.glove.examples.TextToBinaryConverterMain glove-vectors.txt glove-binary 26 | 27 | Now you should have a "glove-binary" folder with two files in it, a smaller "dict.bin" and a bigger "vectors.bin". 28 | 29 | 30 | Using it for random access reads 31 | ----------------------------------- 32 | 33 | You can use the GloveRandomAccessReader to get the vector for a string fast and without loading all the vectors into memory. 34 | 35 | Using my math library it is also easy to do the typical vector computations. 36 | 37 | ```java 38 | 39 | GloveRandomAccessReader db = new GloveBinaryRandomAccessReader( 40 | Paths.get("glove-binary")); 41 | 42 | DoubleVector king = db.get("king"); 43 | DoubleVector man = db.get("man"); 44 | 45 | DoubleVector queen = db.get("queen"); 46 | DoubleVector woman = db.get("woman"); 47 | 48 | CosineDistance cos = new CosineDistance(); 49 | 50 | DoubleVector diff = king.subtract(man).add(woman); 51 | 52 | double dist = cos.measureDistance(diff, queen); 53 | System.out.println("dist queen = " + dist); 54 | 55 | dist = cos.measureDistance(diff, db.get("royal")); 56 | System.out.println("dist royal = " + dist); 57 | 58 | ``` 59 | 60 | You can execute the above using 61 | 62 | > java -cp glove-0.3-jar-with-dependencies.jar de.jungblut.glove.examples.VectorLookupMain glove-binary 63 | 64 | Output is: 65 | 66 | ``` 67 | dist queen = 0.24690873337939978 68 | dist royal = 0.30120073399624914 69 | 70 | ``` 71 | 72 | Nearest Neighbour Queries 73 | ------------------------- 74 | 75 | You can also do efficient nearest neighbour queries using a KD-Tree. The full code can be found in de.jungblut.glove.examples.NearestNeighbourMain. 
76 | 77 | You can also run it with an "interactive" menu like this: 78 | 79 | > java -cp glove-0.3-jar-with-dependencies.jar de.jungblut.glove.examples.NearestNeighbourMain glove-binary 80 | 81 | Keep in mind that this needs quite a lot of memory, because all vectors are loaded into the KD-tree, but the queries are fast. 82 | 83 | Some example output on the small Twitter file: 84 | 85 | ``` 86 | Reading... 87 | Balancing the KD tree... 88 | Finished, input your word to find its nearest neighbours 89 | rt 90 | Searching....done. Took 850 millis. 91 | 1.3282014981649486 92 | : 2.0140673085358696 93 | ? 2.439117425083601 94 | --- 2.441400270469679 95 | " 2.45020842228818 96 | 97 | yolo 98 | Searching....done. Took 758 millis. 99 | wtf 2.1469371352219953 100 | swag 2.1752410454311986 101 | lolz 2.2263784996001705 102 | loser 2.2666981122295806 103 | ew 2.308925723645761 104 | 105 | ``` 106 | 107 | 108 | Binary File Layout 109 | ================== 110 | 111 | Dictionary 112 | ---------- 113 | 114 | The dictionary file contains (per string-vector pair): 115 | - the string, in the format of DataOutput.writeUTF (a 2-byte length followed by modified UTF-8 bytes) 116 | - the offset at which the vector starts in the vector file, encoded as a vlong 117 | 118 | 119 | Vectors 120 | ------- 121 | 122 | The vector file contains (per string-vector pair): 123 | - the vector content, encoded as a sequence of 4-byte floats 124 | - each float is converted with Float.floatToIntBits and stored as a big-endian int 125 | 126 | 127 | License 128 | =================== 129 | 130 | Since I am an Apache committer, I consider everything in this repository 131 | licensed under the Apache 2.0 license, although I haven't put the usual header into the source files. 132 | 133 | If something is not licensed under Apache 2.0, there is a reference or an additional license header included in the specific source file. 134 | 135 | 136 | Build 137 | =================== 138 | 139 | To build this library locally, you need at least Java 8.
140 | 141 | You can simply build with: 142 | 143 | > mvn clean package install 144 | 145 | The created jars contain debuggable code, sources, and javadocs. 146 | 147 | If you want to skip the test cases, you can use: 148 | 149 | > mvn clean package install -DskipTests 150 | 151 | If you want to skip the signing process, you can use: 152 | 153 | > mvn clean package install -Dgpg.skip=true 154 | 155 | 156 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | de.jungblut.glove 5 | glove 6 | Glove Utilities 7 | some utils to read and write glove files 8 | https://github.com/thomasjungblut/glove 9 | 0.4-SNAPSHOT 10 | jar 11 | 12 | 3.0 13 | 14 | 15 | 16 | UTF-8 17 | 18 | 19 | 20 | 21 | Apache 2 22 | http://www.apache.org/licenses/LICENSE-2.0.txt 23 | 24 | 25 | 26 | 27 | 28 | Thomas Jungblut 29 | thomas.jungblut@gmail.com 30 | tjungblut 31 | http://codingwiththomas.blogspot.com/ 32 | 33 | 34 | 35 | 36 | scm:git:git@github.com:thomasjungblut/glove.git 37 | scm:git:git@github.com:thomasjungblut/glove.git 38 | git@github.com:thomasjungblut/glove.git 39 | HEAD 40 | 41 | 42 | 43 | 44 | ossrh 45 | https://oss.sonatype.org/content/repositories/snapshots 46 | 47 | 48 | 49 | 50 | 51 | com.google.guava 52 | guava 53 | 29.0-jre 54 | 55 | 56 | net.sf.trove4j 57 | trove4j 58 | 3.0.2 59 | 60 | 61 | de.jungblut.common 62 | thomasjungblut-common 63 | 1.0 64 | 65 | 66 | de.jungblut.math 67 | tjungblut-math 68 | 1.2 69 | 70 | 71 | 72 | 73 | src/ 74 | test/ 75 | 76 | 77 | src/ 78 | 79 | 80 | glove-${project.version} 81 | 82 | 83 | org.apache.maven.plugins 84 | maven-compiler-plugin 85 | 2.3.2 86 | 87 | 1.8 88 | 1.8 89 | true 90 | true 91 | 92 | 93 | 94 | org.apache.maven.plugins 95 | maven-source-plugin 96 | 2.2.1 97 | 98 | 99 | attach-sources 100 | 101 | jar-no-fork 102 | 103 | 104 | 105 | 106 | 107 | org.apache.maven.plugins 108 | maven-javadoc-plugin
109 | 2.9.1 110 | 111 | 112 | attach-javadocs 113 | 114 | jar 115 | 116 | 117 | 118 | 119 | -Xdoclint:none 120 | 121 | 122 | 123 | org.apache.maven.plugins 124 | maven-surefire-plugin 125 | 2.11 126 | 127 | 128 | org.apache.maven.surefire 129 | surefire-junit47 130 | 2.11 131 | 132 | 133 | 134 | methods 135 | 8 136 | pertest 137 | 138 | 139 | 140 | test 141 | install 142 | 143 | test 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | org.apache.maven.plugins 152 | maven-shade-plugin 153 | 2.4.1 154 | 155 | true 156 | jar-with-dependencies 157 | 158 | 159 | 160 | package 161 | 162 | shade 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | org.sonatype.plugins 171 | nexus-staging-maven-plugin 172 | 1.6.3 173 | true 174 | 175 | ossrh 176 | https://oss.sonatype.org/ 177 | true 178 | 179 | 180 | 181 | org.apache.maven.plugins 182 | maven-gpg-plugin 183 | 1.5 184 | 185 | 186 | sign-artifacts 187 | verify 188 | 189 | sign 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/GloveRandomAccessReader.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove; 2 | 3 | import java.io.IOException; 4 | 5 | import de.jungblut.math.DoubleVector; 6 | 7 | public interface GloveRandomAccessReader { 8 | 9 | /** 10 | * @return true if the glove reader contains this word. 11 | */ 12 | public boolean contains(String word); 13 | 14 | /** 15 | * @return the word or null if it doesn't exists. 
16 | */ 17 | public DoubleVector get(String word) throws IOException; 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/GloveStreamReader.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Path; 5 | import java.util.stream.Stream; 6 | 7 | import de.jungblut.glove.util.StringVectorPair; 8 | 9 | public interface GloveStreamReader { 10 | 11 | /** 12 | * Streams over the glove file/directory in the given path. 13 | * 14 | * @param input the path to the glove files or directory (defined by the 15 | * implementation). 16 | * @return a lazy evaluated stream of the glove file. 17 | * @throws IOException file not found, or io errors. 18 | */ 19 | public Stream stream(Path input) throws IOException; 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/GloveWriter.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Path; 5 | import java.util.stream.Stream; 6 | 7 | import de.jungblut.glove.util.StringVectorPair; 8 | 9 | public interface GloveWriter { 10 | 11 | /** 12 | * A writer for a stream of StringVectorPairs. This is mainly used to rewrite 13 | * text to binary files or vice versa. 14 | * 15 | * @param stream the stream of elements to write. 16 | * @param output depending on the implementation, either a file or a folder. 17 | * @throws IOException file/directory doesn't exist, isn't writable, or other 18 | * io errors. 
19 | */ 20 | public void writeStream(Stream stream, Path output) 21 | throws IOException; 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/examples/NearestNeighbourMain.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.examples; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Path; 5 | import java.nio.file.Paths; 6 | import java.util.Collections; 7 | import java.util.List; 8 | import java.util.Scanner; 9 | import java.util.stream.Stream; 10 | 11 | import de.jungblut.datastructure.KDTree; 12 | import de.jungblut.datastructure.KDTree.VectorDistanceTuple; 13 | import de.jungblut.glove.GloveRandomAccessReader; 14 | import de.jungblut.glove.impl.CachedGloveBinaryRandomAccessReader; 15 | import de.jungblut.glove.impl.GloveBinaryRandomAccessReader; 16 | import de.jungblut.glove.impl.GloveBinaryReader; 17 | import de.jungblut.glove.util.StringVectorPair; 18 | import de.jungblut.math.DoubleVector; 19 | 20 | public class NearestNeighbourMain { 21 | 22 | public static void main(String[] args) throws IOException { 23 | 24 | if (args.length != 1) { 25 | System.err.println("first argument needs to be the binary glove folder"); 26 | System.exit(1); 27 | } 28 | 29 | Path dir = Paths.get(args[0]); 30 | 31 | System.out.println("Reading..."); 32 | GloveRandomAccessReader reader = new CachedGloveBinaryRandomAccessReader( 33 | new GloveBinaryRandomAccessReader(dir), 100l); 34 | final KDTree tree = new KDTree<>(); 35 | 36 | try (Stream stream = new GloveBinaryReader().stream(dir)) { 37 | stream.forEach((pair) -> { 38 | tree.add(pair.vector, pair.value); 39 | }); 40 | 41 | } 42 | 43 | System.out.println("Balancing the KD tree..."); 44 | tree.balanceBySort(); 45 | 46 | System.out 47 | .println("Finished, input your word to find its nearest neighbours"); 48 | 49 | @SuppressWarnings("resource") 50 | Scanner scanner = new Scanner(System.in); 
51 | 52 | while (true) { 53 | String nextLine = scanner.nextLine(); 54 | 55 | if (nextLine.equals("q")) { 56 | return; 57 | } 58 | 59 | DoubleVector v = reader.get(nextLine); 60 | if (v == null) { 61 | System.err.println("doesn't exist"); 62 | } else { 63 | System.out.print("Searching...."); 64 | long start = System.currentTimeMillis(); 65 | List> nearestNeighbours = tree 66 | .getNearestNeighbours(v, 6); 67 | 68 | // sort and remove the one that we searched for 69 | Collections.sort(nearestNeighbours, Collections.reverseOrder()); 70 | // the best hit is usually the same item with distance 0 71 | if (nearestNeighbours.get(0).getValue().equals(nextLine)) { 72 | nearestNeighbours.remove(0); 73 | } 74 | 75 | System.out.println("done. Took " + (System.currentTimeMillis() - start) 76 | + " millis."); 77 | for (VectorDistanceTuple tuple : nearestNeighbours) { 78 | System.out.println(tuple.getValue() + "\t" + tuple.getDistance()); 79 | } 80 | System.out.println(); 81 | } 82 | 83 | } 84 | 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/examples/TextToBinaryConverterMain.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.examples; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Paths; 5 | import java.util.stream.Stream; 6 | 7 | import de.jungblut.glove.impl.GloveBinaryWriter; 8 | import de.jungblut.glove.impl.GloveTextReader; 9 | import de.jungblut.glove.util.StringVectorPair; 10 | 11 | public class TextToBinaryConverterMain { 12 | 13 | public static void main(String[] args) throws IOException { 14 | 15 | if (args.length != 2) { 16 | System.err 17 | .println("first argument needs to be the glove text file, the second needs to be the output folder of the binary files."); 18 | System.exit(1); 19 | } 20 | 21 | GloveTextReader reader = new GloveTextReader(); 22 | Stream stream = reader.stream(Paths.get(args[0])); 23 | 
GloveBinaryWriter writer = new GloveBinaryWriter(); 24 | writer.writeStream(stream, Paths.get(args[1])); 25 | 26 | } 27 | 28 | } 29 |
--------------------------------------------------------------------------------
/src/de/jungblut/glove/examples/VectorLookupMain.java:
--------------------------------------------------------------------------------
package de.jungblut.glove.examples;

import java.io.IOException;
import java.nio.file.Paths;

import de.jungblut.distance.CosineDistance;
import de.jungblut.glove.GloveRandomAccessReader;
import de.jungblut.glove.impl.GloveBinaryRandomAccessReader;
import de.jungblut.math.DoubleVector;

/**
 * Demo of the "king - man + woman" analogy: looks up a few word vectors in a
 * binary GloVe folder and prints cosine distances of the result to "queen" and
 * "royal".
 */
public class VectorLookupMain {

  public static void main(String[] args) throws IOException {

    if (args.length != 1) {
      System.err
          .println("only argument should be the path to the binary glove folder");
      System.exit(1);
    }

    // the reader keeps the vectors file open, so release it when done
    try (GloveBinaryRandomAccessReader db = new GloveBinaryRandomAccessReader(
        Paths.get(args[0]))) {

      DoubleVector king = lookup(db, "king");
      DoubleVector man = lookup(db, "man");

      DoubleVector queen = lookup(db, "queen");
      DoubleVector woman = lookup(db, "woman");

      CosineDistance cos = new CosineDistance();

      DoubleVector diff = king.subtract(man).add(woman);

      double dist = cos.measureDistance(diff, queen);
      System.out.println("dist queen = " + dist);

      dist = cos.measureDistance(diff, lookup(db, "royal"));
      System.out.println("dist royal = " + dist);
    }
  }

  /**
   * Returns the vector of the given word. Exits with an error message when the
   * word is unknown, instead of failing later with a NullPointerException.
   */
  private static DoubleVector lookup(GloveRandomAccessReader db, String word)
      throws IOException {
    DoubleVector vector = db.get(word);
    if (vector == null) {
      System.err.println("word not found in the vocabulary: " + word);
      System.exit(1);
    }
    return vector;
  }

}

--------------------------------------------------------------------------------
/src/de/jungblut/glove/impl/CachedGloveBinaryRandomAccessReader.java:
--------------------------------------------------------------------------------
package de.jungblut.glove.impl;

import java.io.IOException;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import de.jungblut.glove.GloveRandomAccessReader;
import de.jungblut.math.DoubleVector;

/**
 * A {@link GloveRandomAccessReader} decorator that keeps up to a fixed number
 * of vectors in an in-memory Guava cache, so repeated lookups of the same word
 * do not hit the wrapped reader again.
 */
public class CachedGloveBinaryRandomAccessReader implements
    GloveRandomAccessReader {

  private final GloveRandomAccessReader reader;
  private final Cache<String, DoubleVector> cache;

  /**
   * @param reader the reader that serves cache misses.
   * @param maxCacheSize the maximum number of vectors held in the cache.
   */
  public CachedGloveBinaryRandomAccessReader(GloveRandomAccessReader reader,
      long maxCacheSize) {
    this.reader = reader;
    this.cache = CacheBuilder.newBuilder().maximumSize(maxCacheSize).build();
  }

  @Override
  public boolean contains(String word) {
    return reader.contains(word);
  }

  /**
   * Returns the vector for the word, loading it from the wrapped reader and
   * caching it on a miss.
   *
   * @return the vector, or null if the wrapped reader doesn't contain the word.
   */
  @Override
  public DoubleVector get(String word) throws IOException {

    DoubleVector ret = cache.getIfPresent(word);

    if (ret == null && reader.contains(word)) {
      ret = reader.get(word);
      // Guava caches reject null values, so only cache real vectors
      if (ret != null) {
        cache.put(word, ret);
      }
    }

    return ret;
  }

}

--------------------------------------------------------------------------------
/src/de/jungblut/glove/impl/GloveBinaryRandomAccessReader.java:
--------------------------------------------------------------------------------
package de.jungblut.glove.impl;

import gnu.trove.map.hash.TObjectLongHashMap;

import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Path;

import de.jungblut.glove.GloveRandomAccessReader;
import de.jungblut.glove.util.WritableUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

/**
 * Random access reader for the binary format written by
 * {@link GloveBinaryWriter}. The whole dictionary (word to byte offset) is
 * loaded into memory, and single vectors are read on demand from the vectors
 * file. That file stays open until {@link #close()} is called.
 */
public class GloveBinaryRandomAccessReader implements GloveRandomAccessReader,
    Closeable {

  private final TObjectLongHashMap<String> dictMap = new TObjectLongHashMap<>();
  private RandomAccessFile raf;
  // byte size of every vector block, -1 while unknown (empty dictionary)
  private long size;

  public GloveBinaryRandomAccessReader(Path gloveBinaryFolder)
      throws IOException {

    Path dict = gloveBinaryFolder.resolve(GloveBinaryWriter.DICT_FILE);
    Path vectors = gloveBinaryFolder.resolve(GloveBinaryWriter.VECTORS_FILE);

    long lastOffset = initLookup(dict);
    initBufferedFile(vectors);

    // a single-entry dictionary has no second offset to derive the block size
    // from, so its only block spans the rest of the vectors file
    if (size == -1 && !dictMap.isEmpty()) {
      size = raf.length() - lastOffset;
      if (size < 0) {
        raf.close();
        throw new IOException(
            "Vectors file is shorter than the dictionary offset " + lastOffset);
      }
    }
  }

  private void initBufferedFile(Path vectors) throws FileNotFoundException {
    raf = new RandomAccessFile(vectors.toFile(), "r");
  }

  /**
   * Loads all dictionary entries into {@link #dictMap} and validates that the
   * offsets start at 0 and are spaced by one constant block size, which is
   * stored in {@link #size}.
   *
   * @return the last offset read, or -1 if the dictionary was empty.
   * @throws IOException if the dictionary can't be read or is corrupted.
   */
  private long initLookup(Path dict) throws IOException {
    long lastBlock = -1;
    size = -1;
    try (DataInputStream in = new DataInputStream(new BufferedInputStream(
        new FileInputStream(dict.toFile())))) {

      boolean first = true;
      while (true) {
        String s = in.readUTF();
        long off = WritableUtils.readVLong(in);

        if (first) {
          // the writer always places the first vector at the start of the file
          if (off != 0) {
            throw new IOException(
                "Dictionary is corrupted, the first offset must be 0 but was "
                    + off);
          }
          first = false;
        } else {
          if (size == -1) {
            size = off - lastBlock;
          }
          if (off - lastBlock != size) {
            throw new IOException(
                "Dictionary is corrupted, blocking isn't exact. Expected blocks of "
                    + size);
          }
        }

        lastBlock = off;
        dictMap.put(s, off);
      }

    } catch (EOFException e) {
      // expected, the dictionary is read until its end
    }
    return lastBlock;
  }

  @Override
  public boolean contains(String word) {
    return dictMap.containsKey(word);
  }

  /**
   * @return the vector of the given word, or null if the word is unknown.
   */
  @Override
  public DoubleVector get(String word) throws IOException {

    if (!contains(word)) {
      return null;
    }

    long offset = dictMap.get(word);

    // page the block in, read from it and wrap as a vector
    MappedByteBuffer buf = raf.getChannel()
        .map(MapMode.READ_ONLY, offset, size);

    return parse(buf);
  }

  /** Decodes one block of big-endian float bits into a dense vector. */
  private DoubleVector parse(MappedByteBuffer buf) {
    int dim = (int) (size / 4);
    DoubleVector v = new DenseDoubleVector(dim);

    for (int i = 0; i < dim; i++) {
      v.set(i, Float.intBitsToFloat(buf.getInt()));
    }

    return v;
  }

  /** Releases the file handle of the vectors file. */
  @Override
  public void close() throws IOException {
    raf.close();
  }
}

--------------------------------------------------------------------------------
/src/de/jungblut/glove/impl/GloveBinaryReader.java:
--------------------------------------------------------------------------------
package de.jungblut.glove.impl;

import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;

import de.jungblut.glove.GloveStreamReader;
import de.jungblut.glove.util.IOUtils;
import de.jungblut.glove.util.StringVectorPair;
import de.jungblut.glove.util.WritableUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

/**
 * Sequentially streams every word vector of a binary GloVe folder written by
 * {@link GloveBinaryWriter}, validating the dictionary against the vectors
 * file along the way. Inconsistencies surface as IllegalArgumentException
 * while consuming the stream. Close the returned stream to release both files.
 */
public class GloveBinaryReader implements GloveStreamReader {

  @Override
  public Stream<StringVectorPair> stream(Path gloveBinaryFolder)
      throws IOException {

    Path dict = gloveBinaryFolder.resolve(GloveBinaryWriter.DICT_FILE);
    Path vectors = gloveBinaryFolder.resolve(GloveBinaryWriter.VECTORS_FILE);

    final DataInputStream in = new DataInputStream(new BufferedInputStream(
        new FileInputStream(dict.toFile())));
    final BufferedInputStream vec;
    try {
      vec = new BufferedInputStream(new FileInputStream(vectors.toFile()));
    } catch (IOException e) {
      // don't leak the already opened dictionary
      IOUtils.cleanup(in);
      throw e;
    }

    FilesIterator filesIterator = new FilesIterator(in, vec, vectors);

    return StreamSupport.stream(
        Spliterators.spliteratorUnknownSize(filesIterator, Spliterator.ORDERED),
        false).onClose(() -> IOUtils.cleanup(filesIterator));
  }

  /**
   * Reads dictionary entries and their vectors in lockstep. The block size is
   * derived from the first two dictionary offsets and then enforced for every
   * following entry.
   */
  private static final class FilesIterator extends
      AbstractIterator<StringVectorPair> implements Closeable {

    final DataInputStream dict;
    final BufferedInputStream vec;
    final Path vectorsFile;

    long blockSize = -1;
    long currentOffset = -1;
    // word read ahead while determining the block size, returned next
    String second = null;

    FilesIterator(DataInputStream dict, BufferedInputStream vec,
        Path vectorsFile) {
      this.dict = dict;
      this.vec = vec;
      this.vectorsFile = vectorsFile;
    }

    @Override
    protected StringVectorPair computeNext() {
      try {

        if (second != null) {
          String tmp = second;
          second = null;
          return new StringVectorPair(tmp, readVec());
        }

        String word = dict.readUTF();
        long off = WritableUtils.readVLong(dict);
        Preconditions.checkArgument(off >= 0,
            "Offset was negative! Dictionary seems corrupted.");
        if (blockSize == -1) {
          determineBlockSize(off);
        } else {
          // check block size consistency, the message is only built on failure
          Preconditions.checkArgument((currentOffset + blockSize) == off,
              "Can't read different block sizes! Expected %s but was %s.",
              blockSize, off - currentOffset);
          currentOffset = off;
        }

        return new StringVectorPair(word, readVec());

      } catch (EOFException e) {
        // expected end of the dictionary, the vector file must be exhausted too
        try {
          Preconditions
              .checkArgument(
                  vec.read() == -1,
                  "Vector file has more bytes than expected, dictionary seems inconsistent to the vector file");
        } catch (IOException e1) {
          throw new UncheckedIOException(e1);
        }
        return endOfData();
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
    }

    /**
     * Derives the block size from the first entry (at firstOffset) and the one
     * following it, which is kept in {@link #second} for the next call. A
     * dictionary with a single entry has no following offset, so its block
     * spans the rest of the vectors file.
     */
    private void determineBlockSize(long firstOffset) throws IOException {
      // the writer always places the first vector at the start of the file
      Preconditions.checkArgument(firstOffset == 0,
          "First offset must be 0 but was %s. Dictionary seems corrupted.",
          firstOffset);

      if (atEndOfDictionary()) {
        blockSize = Files.size(vectorsFile) - firstOffset;
        currentOffset = firstOffset;
      } else {
        String word2 = dict.readUTF();
        long off2 = WritableUtils.readVLong(dict);
        Preconditions.checkArgument(off2 >= 0,
            "Offset was negative! Dictionary seems corrupted.");
        blockSize = off2 - firstOffset;
        second = word2;
        currentOffset = off2;
      }

      // vectors are stored as 4-byte floats, anything else would misalign reads
      Preconditions.checkArgument(blockSize % 4 == 0,
          "Block size %s is not a multiple of 4 bytes. Dictionary seems corrupted.",
          blockSize);
    }

    /** Peeks one byte (mark/reset on the buffered stream) to detect a clean EOF. */
    private boolean atEndOfDictionary() throws IOException {
      dict.mark(1);
      if (dict.read() == -1) {
        return true;
      }
      dict.reset();
      return false;
    }

    /** Reads the next block from the vectors file and decodes it. */
    private DoubleVector readVec() throws IOException {
      int dim = (int) (blockSize / 4);
      byte[] buf = new byte[dim * 4];

      // a single read() may return fewer bytes than requested without being
      // at the end of the file, so keep reading until the block is full or EOF
      int read = 0;
      while (read < buf.length) {
        int n = vec.read(buf, read, buf.length - read);
        if (n == -1) {
          break;
        }
        read += n;
      }
      Preconditions.checkArgument(read == buf.length,
          "Couldn't read the next %s bytes from the file, vector file seems truncated",
          buf.length);

      ByteBuffer wrap = ByteBuffer.wrap(buf);
      DoubleVector v = new DenseDoubleVector(dim);
      for (int i = 0; i < dim; i++) {
        v.set(i, Float.intBitsToFloat(wrap.getInt()));
      }

      return v;
    }

    @Override
    public void close() throws IOException {
      IOUtils.cleanup(dict, vec);
    }

  }

}
-------------------------------------------------------------------------------- /src/de/jungblut/glove/impl/GloveBinaryWriter.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.impl; 2 | 3 | import java.io.BufferedOutputStream; 4 | import java.io.ByteArrayOutputStream; 5 | import java.io.DataOutput; 6 | import java.io.DataOutputStream; 7 | import java.io.FileOutputStream; 8 | import java.io.IOException; 9 | import java.nio.file.Files; 10 | import java.nio.file.Path; 11 | import java.util.Iterator; 12 | import java.util.stream.Stream; 13 | 14 | import com.google.common.base.Preconditions; 15 | 16 | import de.jungblut.glove.GloveWriter; 17 | import de.jungblut.glove.util.StringVectorPair; 18 | import de.jungblut.glove.util.WritableUtils; 19 | import de.jungblut.math.DoubleVector; 20 | 21 | public class GloveBinaryWriter implements GloveWriter { 22 | 23 | public static final String VECTORS_FILE = "vectors.bin"; 24 | public static final String DICT_FILE = "dict.bin"; 25 | 26 | @Override 27 | public void writeStream(Stream stream, Path outputFolder) 28 | throws IOException { 29 | 30 | Files.createDirectories(outputFolder); 31 | 32 | try (DataOutputStream dict = new DataOutputStream(new BufferedOutputStream( 33 | new FileOutputStream(outputFolder.resolve(DICT_FILE).toFile())))) { 34 | 35 | try (BufferedOutputStream vec = new BufferedOutputStream( 36 | new FileOutputStream(outputFolder.resolve(VECTORS_FILE).toFile()))) { 37 | 38 | Iterator iterator = stream.iterator(); 39 | 40 | long blockSize = -1; 41 | long offset = 0; 42 | ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); 43 | while (iterator.hasNext()) { 44 | byteBuffer.reset(); 45 | StringVectorPair pair = iterator.next(); 46 | dict.writeUTF(pair.value); 47 | WritableUtils.writeVLong(dict, offset); 48 | 49 | try (DataOutputStream out = new DataOutputStream(byteBuffer)) { 50 | writeVectorData(pair.vector, out); 51 | } 52 | 53 | byte[] buf 
= byteBuffer.toByteArray(); 54 | if (blockSize == -1) { 55 | blockSize = buf.length; 56 | } 57 | 58 | if (blockSize != buf.length) { 59 | Preconditions 60 | .checkArgument( 61 | blockSize == buf.length, 62 | String 63 | .format( 64 | "Can't write different block size! Expected %d but was %d. " 65 | + "This happened because the vectors in the stream had different dimensions.", 66 | blockSize, buf.length)); 67 | } 68 | 69 | vec.write(buf); 70 | 71 | offset += buf.length; 72 | } 73 | } 74 | } 75 | } 76 | 77 | private void writeVectorData(DoubleVector v, DataOutput out) 78 | throws IOException { 79 | 80 | for (int i = 0; i < v.getDimension(); i++) { 81 | float f = (float) v.get(i); 82 | int var = Float.floatToIntBits(f); 83 | out.writeInt(var); 84 | } 85 | 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/impl/GloveTextReader.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.impl; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Files; 5 | import java.nio.file.Path; 6 | import java.util.regex.Pattern; 7 | import java.util.stream.Stream; 8 | 9 | import com.google.common.base.Preconditions; 10 | 11 | import de.jungblut.glove.GloveStreamReader; 12 | import de.jungblut.glove.util.StringVectorPair; 13 | import de.jungblut.math.DoubleVector; 14 | import de.jungblut.math.dense.DenseDoubleVector; 15 | 16 | public class GloveTextReader implements GloveStreamReader { 17 | 18 | private static final Pattern SPLIT_WHITESPACE = Pattern.compile(" "); 19 | 20 | @Override 21 | public Stream stream(Path input) throws IOException { 22 | 23 | final Stream lines = Files.lines(input); 24 | int[] expectedSize = new int[] { -1 }; 25 | Stream pairs = lines.map((line) -> process(line)).map( 26 | (pair) -> { 27 | Preconditions.checkNotNull(pair.value, "word was null"); 28 | if (expectedSize[0] == -1) { 29 | expectedSize[0] = 
pair.vector.getDimension(); 30 | } else { 31 | Preconditions.checkArgument( 32 | expectedSize[0] == pair.vector.getDimension(), 33 | "found inconsistency. Expected size " + expectedSize[0] 34 | + " but found " + pair.vector.getDimension()); 35 | } 36 | return pair; 37 | }); 38 | 39 | pairs.onClose(() -> lines.close()); 40 | 41 | return pairs; 42 | } 43 | 44 | private StringVectorPair process(String line) { 45 | String[] split = SPLIT_WHITESPACE.split(line); 46 | String name = split[0]; 47 | 48 | DoubleVector vec = new DenseDoubleVector(split.length - 1); 49 | for (int i = 1; i < split.length; i++) { 50 | vec.set(i - 1, Double.parseDouble(split[i])); 51 | } 52 | 53 | return new StringVectorPair(name, vec); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/util/IOUtils.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.util; 2 | 3 | import java.io.IOException; 4 | 5 | 6 | /** 7 | * An utility class for I/O related functionality. 8 | */ 9 | /* Taken from org.apache.hadoop.commons and modified */ 10 | public class IOUtils { 11 | 12 | 13 | /** 14 | * Close the Closeable objects and ignore any {@link IOException} or null pointers. Must 15 | * only be used for cleanup in exception handlers. 16 | * 17 | * @param closeables the objects to close 18 | */ 19 | public static void cleanup(java.io.Closeable... 
closeables) { 20 | for (java.io.Closeable c : closeables) { 21 | if (c != null) { 22 | try { 23 | c.close(); 24 | } catch (IOException e) { 25 | System.err.println("Exception in closing " + c); 26 | System.err.println(e.getMessage()); 27 | } 28 | } 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/util/StringVectorPair.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.util; 2 | 3 | import de.jungblut.math.DoubleVector; 4 | 5 | public class StringVectorPair { 6 | 7 | public final String value; 8 | public final DoubleVector vector; 9 | 10 | public StringVectorPair(String value, DoubleVector vector) { 11 | this.value = value; 12 | this.vector = vector; 13 | } 14 | 15 | @Override 16 | public String toString() { 17 | return "StringVectorPair [word=" + value + ", vector=" + vector + "]"; 18 | } 19 | 20 | @Override 21 | public int hashCode() { 22 | final int prime = 31; 23 | int result = 1; 24 | result = prime * result + ((value == null) ? 
0 : value.hashCode()); 25 | return result; 26 | } 27 | 28 | @Override 29 | public boolean equals(Object obj) { 30 | if (this == obj) 31 | return true; 32 | if (obj == null) 33 | return false; 34 | if (getClass() != obj.getClass()) 35 | return false; 36 | StringVectorPair other = (StringVectorPair) obj; 37 | if (value == null) { 38 | if (other.value != null) 39 | return false; 40 | } else if (!value.equals(other.value)) 41 | return false; 42 | return true; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/de/jungblut/glove/util/WritableUtils.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.util; 2 | 3 | import java.io.DataInput; 4 | import java.io.DataOutput; 5 | import java.io.IOException; 6 | 7 | /* Taken from org.apache.hadoop.commons and modified */ 8 | public final class WritableUtils { 9 | 10 | /** 11 | * Serializes a long to a binary stream with zero-compressed encoding. For -112 <= i <= 127, only 12 | * one byte is used with the actual value. For other values of i, the first byte value indicates 13 | * whether the long is positive or negative, and the number of bytes that follow. If the first 14 | * byte value v is between -113 and -120, the following long is positive, with number of bytes 15 | * that follow are -(v+112). If the first byte value v is between -121 and -128, the following 16 | * long is negative, with number of bytes that follow are -(v+120). Bytes are stored in the 17 | * high-non-zero-byte-first order. 
18 | * 19 | * @param stream Binary output stream 20 | * @param i Long to be serialized 21 | * @throws java.io.IOException 22 | */ 23 | public static void writeVLong(DataOutput stream, long i) throws IOException { 24 | if (i >= -112 && i <= 127) { 25 | stream.writeByte((byte) i); 26 | return; 27 | } 28 | 29 | int len = -112; 30 | if (i < 0) { 31 | i ^= -1L; // take one's complement' 32 | len = -120; 33 | } 34 | 35 | long tmp = i; 36 | while (tmp != 0) { 37 | tmp = tmp >> 8; 38 | len--; 39 | } 40 | 41 | stream.writeByte((byte) len); 42 | 43 | len = len < -120 ? -(len + 120) : -(len + 112); 44 | 45 | for (int idx = len; idx != 0; idx--) { 46 | int shiftbits = (idx - 1) * 8; 47 | long mask = 0xFFL << shiftbits; 48 | stream.writeByte((byte) ((i & mask) >> shiftbits)); 49 | } 50 | } 51 | 52 | /** 53 | * Reads a zero-compressed encoded long from input stream and returns it. 54 | * 55 | * @param stream Binary input stream 56 | * @throws java.io.IOException 57 | * @return deserialized long from stream. 58 | */ 59 | public static long readVLong(DataInput stream) throws IOException { 60 | byte firstByte = stream.readByte(); 61 | int len = decodeVIntSize(firstByte); 62 | if (len == 1) { 63 | return firstByte; 64 | } 65 | long i = 0; 66 | for (int idx = 0; idx < len - 1; idx++) { 67 | byte b = stream.readByte(); 68 | i = i << 8; 69 | i = i | b & 0xFF; 70 | } 71 | return isNegativeVInt(firstByte) ? 
i ^ -1L : i; 72 | } 73 | 74 | /** 75 | * Parse the first byte of a vint/vlong to determine the number of bytes 76 | * 77 | * @param value the first byte of the vint/vlong 78 | * @return the total number of bytes (1 to 9) 79 | */ 80 | public static int decodeVIntSize(byte value) { 81 | if (value >= -112) { 82 | return 1; 83 | } else if (value < -120) { 84 | return -119 - value; 85 | } 86 | return -111 - value; 87 | } 88 | 89 | /** 90 | * Given the first byte of a vint/vlong, determine the sign 91 | * 92 | * @param value the first byte 93 | * @return is the value negative 94 | */ 95 | public static boolean isNegativeVInt(byte value) { 96 | return value < -120 || value >= -112 && value < 0; 97 | } 98 | 99 | 100 | } 101 | -------------------------------------------------------------------------------- /test/de/jungblut/glove/GloveTestUtils.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove; 2 | 3 | import java.io.IOException; 4 | import java.nio.charset.Charset; 5 | import java.nio.file.Files; 6 | import java.nio.file.Path; 7 | import java.util.List; 8 | import java.util.stream.Collectors; 9 | import java.util.stream.IntStream; 10 | import java.util.stream.Stream; 11 | 12 | import org.junit.Assert; 13 | 14 | import de.jungblut.datastructure.ArrayJoiner; 15 | import de.jungblut.glove.impl.GloveBinaryWriter; 16 | import de.jungblut.glove.util.StringVectorPair; 17 | import de.jungblut.math.dense.DenseDoubleVector; 18 | 19 | public class GloveTestUtils { 20 | 21 | public static Path createTemporaryTestBinaryFile(int dimension, int numVectors) 22 | throws IOException { 23 | 24 | Stream randomStream = GloveTestUtils.getWordVectorStream( 25 | dimension, numVectors); 26 | 27 | return writeStreamToBinaryFile(randomStream); 28 | } 29 | 30 | public static Path writeStreamToBinaryFile( 31 | Stream randomStream) throws IOException { 32 | GloveBinaryWriter writer = new GloveBinaryWriter(); 33 | Path outputFolder = 
GloveTestUtils.createTemporaryOutputFolder(); 34 | 35 | writer.writeStream(randomStream, outputFolder); 36 | return outputFolder; 37 | } 38 | 39 | public static Path createTemporaryCorruptedTestBinaryFiles() 40 | throws IOException { 41 | Stream input = IntStream.range(0, 5).mapToObj( 42 | (i) -> new StringVectorPair("" + i, new DenseDoubleVector(i))); 43 | return writeStreamToBinaryFile(input); 44 | } 45 | 46 | public static Path createTemporaryOutputFolder() throws IOException { 47 | Path tmp = Files.createTempDirectory("glove-test-dir"); 48 | tmp.toFile().deleteOnExit(); 49 | return tmp; 50 | } 51 | 52 | public static Path createTemporaryTestTextFile(int dimension) 53 | throws IOException { 54 | Stream input = getWordVectorStream(dimension); 55 | return createTemporaryTestTextFileWithContent(input); 56 | } 57 | 58 | public static Path createTemporaryCorruptedTestTextFile() throws IOException { 59 | Stream input = IntStream.range(0, 5).mapToObj( 60 | (i) -> new StringVectorPair("" + i, new DenseDoubleVector(i))); 61 | return createTemporaryTestTextFileWithContent(input); 62 | } 63 | 64 | public static Stream getWordVectorStream(int dimension) { 65 | return getWordVectorStream(dimension, 15); 66 | } 67 | 68 | public static Stream getWordVectorStream(int dimension, 69 | int numElements) { 70 | return IntStream.range(0, numElements).mapToObj((i) -> { 71 | DenseDoubleVector vec = new DenseDoubleVector(dimension); 72 | for (int x = 0; x < dimension; x++) { 73 | vec.set(x, x); 74 | } 75 | return new StringVectorPair("" + i, vec); 76 | }); 77 | } 78 | 79 | public static Path createTemporaryTestTextFileWithContent( 80 | Stream input) throws IOException { 81 | 82 | Path tmp = Files.createTempFile("glove-test", ".txt"); 83 | Files.write( 84 | tmp, 85 | input.map( 86 | (pair) -> pair.value + " " 87 | + ArrayJoiner.on(" ").join(pair.vector.toArray())).collect( 88 | Collectors.toList()), Charset.forName("UTF-8")); 89 | 90 | tmp.toFile().deleteOnExit(); 91 | 92 | return tmp; 
93 | } 94 | 95 | public static void checkWordVectorResult(int dim, int numElements, 96 | List list) { 97 | 98 | Assert.assertEquals(numElements, list.size()); 99 | int start = 0; 100 | for (StringVectorPair v : list) { 101 | checkSingleWordVector(dim, start++, v); 102 | } 103 | } 104 | 105 | public static void checkSingleWordVector(int dim, int start, 106 | StringVectorPair v) { 107 | String name = start + ""; 108 | Assert.assertEquals(name, v.value); 109 | Assert.assertEquals(dim, v.vector.getDimension()); 110 | for (int i = 0; i < v.vector.getDimension(); i++) { 111 | Assert.assertEquals(i, (int) v.vector.get(i)); 112 | } 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /test/de/jungblut/glove/impl/GloveBinaryRandomAccessReaderTest.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.impl; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Path; 5 | 6 | import org.junit.Assert; 7 | import org.junit.Test; 8 | 9 | import de.jungblut.glove.GloveTestUtils; 10 | import de.jungblut.glove.util.StringVectorPair; 11 | import de.jungblut.math.DoubleVector; 12 | 13 | public class GloveBinaryRandomAccessReaderTest { 14 | 15 | @Test 16 | public void testNormalRandomReading() throws IOException { 17 | int dim = 10; 18 | int numElements = 100; 19 | Path folder = GloveTestUtils 20 | .createTemporaryTestBinaryFile(dim, numElements); 21 | 22 | GloveBinaryRandomAccessReader reader = new GloveBinaryRandomAccessReader( 23 | folder); 24 | 25 | for (int i = 0; i < numElements; i++) { 26 | String s = i + ""; 27 | Assert.assertTrue("didn't contain word=" + s, reader.contains(s)); 28 | DoubleVector vec = reader.get(s); 29 | Assert.assertNotNull("vector null, despite contains returned true! 
word=" 30 | + s, vec); 31 | 32 | GloveTestUtils 33 | .checkSingleWordVector(dim, i, new StringVectorPair(s, vec)); 34 | 35 | } 36 | 37 | } 38 | 39 | @Test 40 | public void testNotContainedValues() throws IOException { 41 | int dim = 10; 42 | int numElements = 100; 43 | Path folder = GloveTestUtils 44 | .createTemporaryTestBinaryFile(dim, numElements); 45 | 46 | GloveBinaryRandomAccessReader reader = new GloveBinaryRandomAccessReader( 47 | folder); 48 | 49 | Assert.assertNull("contained lolol", reader.get("lolol")); 50 | Assert.assertNull("contained omg", reader.get("omg")); 51 | Assert.assertNull("contained 101", reader.get("101")); 52 | Assert.assertNull("contained -1", reader.get("-1")); 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /test/de/jungblut/glove/impl/GloveBinaryReaderWriterTest.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.impl; 2 | 3 | import java.io.BufferedOutputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.nio.file.Path; 8 | import java.util.List; 9 | import java.util.Random; 10 | import java.util.stream.Collectors; 11 | 12 | import org.junit.Test; 13 | 14 | import de.jungblut.glove.GloveTestUtils; 15 | import de.jungblut.glove.util.StringVectorPair; 16 | import de.jungblut.glove.util.WritableUtils; 17 | 18 | public class GloveBinaryReaderWriterTest { 19 | 20 | @Test 21 | public void testNormalFileWritingAndReading() throws IOException { 22 | 23 | int dim = 10; 24 | int numElements = 100; 25 | Path folder = GloveTestUtils 26 | .createTemporaryTestBinaryFile(dim, numElements); 27 | GloveBinaryReader reader = new GloveBinaryReader(); 28 | List list = reader.stream(folder).collect( 29 | Collectors.toList()); 30 | 31 | GloveTestUtils.checkWordVectorResult(dim, numElements, list); 32 | 33 | } 34 | 35 | @Test(expected = 
IllegalArgumentException.class) 36 | public void testWritingCorruptedStream() throws IOException { 37 | GloveTestUtils.createTemporaryCorruptedTestBinaryFiles(); 38 | } 39 | 40 | @Test(expected = IllegalArgumentException.class) 41 | public void testReadingCorruptedFileTooSmall() throws IOException { 42 | int dim = 10; 43 | int numElements = 100; 44 | Path folder = GloveTestUtils 45 | .createTemporaryTestBinaryFile(dim, numElements); 46 | corruptVectorFile(folder, 2048); 47 | 48 | GloveBinaryReader reader = new GloveBinaryReader(); 49 | reader.stream(folder).forEach((v) -> v.vector.sum()); 50 | } 51 | 52 | @Test(expected = IllegalArgumentException.class) 53 | public void testReadingCorruptedFileTooBig() throws IOException { 54 | int dim = 10; 55 | int numElements = 100; 56 | Path folder = GloveTestUtils 57 | .createTemporaryTestBinaryFile(dim, numElements); 58 | corruptVectorFile(folder, 1024 * 512); 59 | 60 | GloveBinaryReader reader = new GloveBinaryReader(); 61 | reader.stream(folder).forEach((v) -> v.vector.sum()); 62 | } 63 | 64 | @Test(expected = IllegalArgumentException.class) 65 | public void testReadingCorruptedDictionaryDifferentNumElements() 66 | throws IOException { 67 | int dim = 10; 68 | int numElements = 100; 69 | Path folder = GloveTestUtils 70 | .createTemporaryTestBinaryFile(dim, numElements); 71 | corruptDictionaryFile(folder, dim, 50, false, false); 72 | GloveBinaryReader reader = new GloveBinaryReader(); 73 | reader.stream(folder).forEach((v) -> v.vector.sum()); 74 | } 75 | 76 | @Test(expected = IllegalArgumentException.class) 77 | public void testReadingCorruptedDictionaryNegativeOffset() throws IOException { 78 | int dim = 10; 79 | int numElements = 100; 80 | Path folder = GloveTestUtils 81 | .createTemporaryTestBinaryFile(dim, numElements); 82 | corruptDictionaryFile(folder, dim, numElements, true, true); 83 | GloveBinaryReader reader = new GloveBinaryReader(); 84 | reader.stream(folder).forEach((v) -> v.vector.sum()); 85 | } 86 | 87 | 
@Test(expected = IllegalArgumentException.class) 88 | public void testReadingCorruptedDictionaryCorruptedOffset() 89 | throws IOException { 90 | int dim = 10; 91 | int numElements = 100; 92 | Path folder = GloveTestUtils 93 | .createTemporaryTestBinaryFile(dim, numElements); 94 | corruptDictionaryFile(folder, dim, numElements, true, false); 95 | GloveBinaryReader reader = new GloveBinaryReader(); 96 | reader.stream(folder).forEach((v) -> v.vector.sum()); 97 | } 98 | 99 | private void corruptDictionaryFile(Path folder, int dim, int items, 100 | boolean corruptOffsets, boolean negativeOffset) throws IOException { 101 | try (DataOutputStream dict = new DataOutputStream(new BufferedOutputStream( 102 | new FileOutputStream(folder.resolve(GloveBinaryWriter.DICT_FILE) 103 | .toFile())))) { 104 | for (int i = 0; i < items; i++) { 105 | dict.writeUTF(i + ""); 106 | long off = i * dim * 4; 107 | if (corruptOffsets) { 108 | if (negativeOffset) { 109 | off = -115; 110 | } else { 111 | off = off + 2; 112 | } 113 | } 114 | WritableUtils.writeVLong(dict, off); 115 | } 116 | } 117 | 118 | } 119 | 120 | public void corruptVectorFile(Path folder, int size) throws IOException { 121 | // corrupt the binary file 122 | try (BufferedOutputStream vec = new BufferedOutputStream( 123 | new FileOutputStream(folder.resolve(GloveBinaryWriter.VECTORS_FILE) 124 | .toFile()))) { 125 | 126 | // write some random garbage 127 | byte[] buf = new byte[size]; 128 | Random r = new Random(); 129 | for (int i = 0; i < buf.length; i++) { 130 | buf[i] = (byte) r.nextInt(Byte.MAX_VALUE); 131 | } 132 | vec.write(buf); 133 | 134 | } 135 | } 136 | 137 | } 138 | -------------------------------------------------------------------------------- /test/de/jungblut/glove/impl/GloveTextReaderTest.java: -------------------------------------------------------------------------------- 1 | package de.jungblut.glove.impl; 2 | 3 | import java.io.IOException; 4 | import java.nio.file.Path; 5 | import java.util.List; 6 | 
import java.util.stream.Collectors; 7 | import java.util.stream.Stream; 8 | 9 | import org.junit.Test; 10 | 11 | import de.jungblut.glove.GloveTestUtils; 12 | import de.jungblut.glove.util.StringVectorPair; 13 | 14 | public class GloveTextReaderTest { 15 | 16 | @Test 17 | public void testNormalFile() throws IOException { 18 | final int dim = 5; 19 | Path in = GloveTestUtils.createTemporaryTestTextFile(5); 20 | 21 | GloveTextReader reader = new GloveTextReader(); 22 | Stream stream = reader.stream(in); 23 | List collected = stream.collect(Collectors.toList()); 24 | 25 | GloveTestUtils.checkWordVectorResult(dim, 15, collected); 26 | } 27 | 28 | @Test(expected = IllegalArgumentException.class) 29 | public void readBlowsUpOnCorruption() throws IOException { 30 | Path in = GloveTestUtils.createTemporaryCorruptedTestTextFile(); 31 | GloveTextReader reader = new GloveTextReader(); 32 | Stream stream = reader.stream(in); 33 | stream.forEach((v) -> System.out.println(v)); 34 | } 35 | } 36 | --------------------------------------------------------------------------------