├── .gitignore ├── src ├── lib.rs ├── utils.rs ├── wordclusters.rs ├── errors.rs ├── vectorreader.rs └── wordvectors.rs ├── CHANGELOG.rst ├── Cargo.toml ├── examples └── example.rs ├── .travis.yml ├── LICENSE ├── README.md └── tests └── tests.rs /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | **/*.rs.bk 3 | Cargo.lock 4 | target 5 | example -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate byteorder; 2 | 3 | pub mod vectorreader; 4 | pub mod wordvectors; 5 | pub mod wordclusters; 6 | mod utils; 7 | pub mod errors; 8 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Repository Change Log 2 | --------------------- 3 | 4 | All notable changes to this project will be documented in this file. 5 | 6 | [v0.3.3] 7 | ======== 8 | * Fixed deprecation warnings 9 | 10 | [v0.3.2] 11 | ======== 12 | * Optimized iterating over large sets of vectors 13 | 14 | [v0.3.1] 15 | ======== 16 | * Added ability to load models from different sources 17 | 18 | [v0.3.0] 19 | ======== 20 | * Recognize binary vector files without line breaks. 21 | * Improved ``analogy`` performance 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "word2vec" 3 | version = "0.3.3" 4 | authors = ["Dima Kudosh ", "Sebastian Humenda"] 5 | description = "Rust interface to word2vec."
6 | documentation = "https://github.com/DimaKudosh/word2vec/wiki" 7 | homepage = "https://github.com/DimaKudosh/word2vec" 8 | repository = "https://github.com/DimaKudosh/word2vec" 9 | keywords = ["word2vec"] 10 | license = "MIT" 11 | include = [ 12 | "**/*.rs", 13 | "Cargo.toml", 14 | ] 15 | 16 | [dependencies] 17 | byteorder = "1" 18 | 19 | 20 | [[test]] 21 | name = "tests" 22 | -------------------------------------------------------------------------------- /examples/example.rs: -------------------------------------------------------------------------------- 1 | extern crate word2vec; 2 | 3 | fn main(){ 4 | let model = word2vec::wordvectors::WordVector::load_from_binary( 5 | "vectors.bin").expect("Unable to load word vector model"); 6 | println!("{:?}", model.cosine("snow", 10)); 7 | let positive = vec!["woman", "king"]; 8 | let negative = vec!["man"]; 9 | println!("{:?}", model.analogy(positive, negative, 10)); 10 | 11 | let clusters = word2vec::wordclusters::WordClusters::load_from_file( 12 | "classes.txt").expect("Unable to load word clusters"); 13 | println!("{:?}", clusters.get_cluster("belarus")); 14 | println!("{:?}", clusters.get_words_on_cluster(6)); 15 | } 16 | 17 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | pub fn dot_product(arr1: &Vec, arr2: &Vec) -> f32 { 2 | let mut result: f32 = 0.0; 3 | for (elem1, elem2) in arr1.iter().zip(arr2.iter()) { 4 | result += elem1 * elem2; 5 | } 6 | return result; 7 | } 8 | /// Scale `vector` in place to unit Euclidean (L2) length; an all-zero vector is left unchanged. 9 | pub fn vector_norm(vector: &mut Vec) { 10 | let norm = vector.iter().fold(0f32, |sum, &x| sum + (x * x)).sqrt(); 11 | if norm == 0.0 { return; } // zero vector: scaling by 1/0 would fill it with NaNs, so leave it as-is 12 | let inv_norm = 1.0 / norm; 13 | for x in vector.iter_mut() { (*x) *= inv_norm; } 14 | } 15 | 16 | /// Get the mean (average) of the given Iterator of numbers 17 | pub fn mean>(numbers: Iterable) -> f32 { 18 | let (sum, count) = numbers.fold((0f32, 0), |(sum, count), x| (sum + x, count + 1)); 19 | sum / (count as
f32) 20 | } 21 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: rust 3 | addons: 4 | apt: 5 | packages: 6 | - libcurl4-openssl-dev 7 | - libelf-dev 8 | - libdw-dev 9 | - binutils-dev # optional: only required for the --verify flag of coveralls 10 | - wget 11 | rust: 12 | - stable 13 | - beta 14 | - nightly 15 | before_script: 16 | - pip install 'travis-cargo<0.2' --user 17 | - export PATH=$HOME/.local/bin/:$PATH 18 | - wget -c -O vectors.bin 'https://www.dropbox.com/s/y4ls6yd4k0wbzhp/vectors.bin?dl=0' 19 | script: 20 | - travis-cargo build 21 | - travis-cargo test 22 | - travis-cargo bench 23 | - travis-cargo --only stable doc 24 | after_success: 25 | - travis-cargo --only stable doc-upload 26 | - travis-cargo coveralls --no-sudo --verify 27 | 28 | env: 29 | global: 30 | - TRAVIS_CARGO_NIGHTLY_FEATURE=nightly 31 | 32 | matrix: 33 | allow_failures: 34 | - rust: nightly 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015-2016 Kevin B. Knapp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # word2vec [![Build Status](https://travis-ci.org/DimaKudosh/word2vec.svg?branch=master)](https://travis-ci.org/DimaKudosh/word2vec) 2 | 3 | Rust interface to word2vec word vectors. 4 | 5 | This crate provides a way to read a trained word vector file from word2vec. 6 | It doesn't provide model training and hence requires an already trained model.
7 | 8 | 9 | ## Documentation 10 | Documentation is available at https://github.com/DimaKudosh/word2vec/wiki 11 | 12 | ## Example 13 | 14 | Add this to your `Cargo.toml`: 15 | 16 | ``` 17 | [dependencies] 18 | # … 19 | word2vec = "0.3.3" 20 | ``` 21 | 22 | Example for word similarity and word clusters: 23 | 24 | ```rust 25 | extern crate word2vec; 26 | 27 | fn main(){ 28 | let model = word2vec::wordvectors::WordVector::load_from_binary( 29 | "vectors.bin").expect("Unable to load word vector model"); 30 | println!("{:?}", model.cosine("snow", 10)); 31 | let positive = vec!["woman", "king"]; 32 | let negative = vec!["man"]; 33 | println!("{:?}", model.analogy(positive, negative, 10)); 34 | 35 | let clusters = word2vec::wordclusters::WordClusters::load_from_file( 36 | "classes.txt").expect("Unable to load word clusters"); 37 | println!("{:?}", clusters.get_cluster("belarus")); 38 | println!("{:?}", clusters.get_words_on_cluster(6)); 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /src/wordclusters.rs: -------------------------------------------------------------------------------- 1 | use std::io::prelude::*; 2 | use std::io::BufReader; 3 | use std::fs::File; 4 | use std::collections::HashMap; 5 | use errors::Word2VecError; 6 | 7 | 8 | pub struct WordClusters { 9 | clusters: HashMap>, 10 | } 11 | 12 | 13 | impl WordClusters { 14 | pub fn load_from_file(file_name: &str) -> Result { 15 | let file = File::open(file_name)?; 16 | let reader = BufReader::new(file); 17 | 18 | return WordClusters::load_from_reader(reader) 19 | } 20 | 21 | pub fn load_from_reader(mut reader: R) -> Result { 22 | let mut buffer = String::new(); 23 | let mut clusters: HashMap> = HashMap::new(); 24 | while reader.read_line(&mut buffer)?
> 0 { 25 | { 26 | let mut iter = buffer.split_whitespace(); 27 | let word = iter.next().unwrap(); 28 | let cluster_number = iter.next().unwrap().trim().parse::().ok().unwrap(); 29 | let cluster = clusters.entry(cluster_number).or_insert(Vec::new()); 30 | cluster.push(word.to_string()); 31 | } 32 | buffer.clear(); 33 | } 34 | Ok(WordClusters { clusters }) 35 | } 36 | 37 | pub fn get_words_on_cluster(&self, index: i32) -> Option<&Vec> { 38 | self.clusters.get(&index) 39 | } 40 | 41 | pub fn get_cluster(&self, word: &str) -> Option<&i32> { 42 | let word = word.to_string(); 43 | for (key, val) in self.clusters.iter() { 44 | if val.contains(&word) { 45 | return Some(key); 46 | } 47 | } 48 | None 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/errors.rs: -------------------------------------------------------------------------------- 1 | use std::error; 2 | use std::fmt; 3 | use std::io; 4 | use std::string::FromUtf8Error; 5 | 6 | 7 | /// Common error type for errors concerning loading and processing binary word vectors 8 | /// 9 | /// This error type mostly wraps I/O and encoding errors, but also adds crate-specific error 10 | /// variants. 
11 | #[derive(Debug)] 12 | pub enum Word2VecError { 13 | Io(io::Error), 14 | Decode(FromUtf8Error), 15 | WrongHeader, 16 | } 17 | 18 | impl error::Error for Word2VecError { 19 | fn description(&self) -> &str { 20 | match *self { 21 | Word2VecError::Decode(ref err) => err.description(), 22 | Word2VecError::Io(ref err) => err.description(), 23 | Word2VecError::WrongHeader => "Wrong header format", 24 | } 25 | } 26 | 27 | fn source(&self) -> Option<&(dyn error::Error + 'static)> { 28 | match *self { 29 | Word2VecError::Decode(ref e) => e.source(), 30 | Word2VecError::Io(ref e) => e.source(), 31 | _ => None, 32 | } 33 | } 34 | } 35 | 36 | impl fmt::Display for Word2VecError { 37 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 38 | match *self { 39 | Word2VecError::Io(ref err) => write!(f, "IO error: {}", err), 40 | Word2VecError::Decode(ref err) => write!(f, "Decode error: {}", err), 41 | Word2VecError::WrongHeader => write!(f, "Wrong header length."), 42 | } 43 | } 44 | } 45 | 46 | impl From for Word2VecError { 47 | fn from(err: io::Error) -> Word2VecError { 48 | Word2VecError::Io(err) 49 | } 50 | } 51 | 52 | impl From for Word2VecError { 53 | fn from(err: FromUtf8Error) -> Word2VecError { 54 | Word2VecError::Decode(err) 55 | } 56 | } 57 | 58 | -------------------------------------------------------------------------------- /tests/tests.rs: -------------------------------------------------------------------------------- 1 | extern crate word2vec; 2 | use word2vec::wordvectors::WordVector; 3 | 4 | 5 | const PATH: &'static str = "vectors.bin"; 6 | 7 | 8 | #[test] 9 | fn test_word_cosine() { 10 | let model = WordVector::load_from_binary(PATH).unwrap(); 11 | let res = model.cosine("winter", 10).expect("word not found in vocabulary"); 12 | assert_eq!(res.len(), 10); 13 | let only_words: Vec<&str> = res.iter().map(|x| x.0.as_ref()).collect(); 14 | assert!(!only_words.contains(&"winter")) 15 | } 16 | 17 | 18 | #[test] 19 | fn test_unexisting_word_cosine() { 20 | let 
model = WordVector::load_from_binary(PATH).unwrap(); 21 | let result = model.cosine("somenotexistingword", 10); 22 | match result { 23 | Some(_) => assert!(false), 24 | None => assert!(true), 25 | } 26 | } 27 | 28 | 29 | #[test] 30 | fn test_word_analogy() { 31 | let model = WordVector::load_from_binary(PATH).unwrap(); 32 | let mut pos = Vec::new(); 33 | pos.push("woman"); 34 | pos.push("king"); 35 | let mut neg = Vec::new(); 36 | neg.push("man"); 37 | let res = model.analogy(pos, neg, 10).expect("couldn't find all of the given words"); 38 | assert_eq!(res.len(), 10); 39 | let only_words: Vec<&str> = res.iter().map(|x| x.0.as_ref()).collect(); 40 | assert!(!only_words.contains(&"woman")); 41 | assert!(!only_words.contains(&"king")); 42 | assert!(!only_words.contains(&"man")); 43 | } 44 | 45 | 46 | #[test] 47 | fn test_word_analogy_with_empty_params() { 48 | let model = WordVector::load_from_binary(PATH).unwrap(); 49 | let result = model.analogy(Vec::new(), Vec::new(), 10); 50 | match result { 51 | Some(_) => assert!(false), 52 | None => assert!(true), 53 | } 54 | } 55 | 56 | #[test] 57 | fn test_word_count_is_correctly_returned() { 58 | let v = WordVector::load_from_binary(PATH).unwrap(); 59 | assert_eq!(v.word_count(), 71291); 60 | } 61 | -------------------------------------------------------------------------------- /src/vectorreader.rs: -------------------------------------------------------------------------------- 1 | use std::io::BufRead; 2 | 3 | use byteorder::{ReadBytesExt, LittleEndian}; 4 | 5 | use errors::Word2VecError; 6 | 7 | pub struct WordVectorReader { 8 | vocabulary_size: usize, 9 | vector_size: usize, 10 | reader: R, 11 | 12 | ended_early: bool, 13 | vectors_read: usize 14 | } 15 | 16 | impl WordVectorReader { 17 | 18 | pub fn vocabulary_size(&self) -> usize { 19 | return self.vocabulary_size; 20 | } 21 | 22 | pub fn vector_size(&self) -> usize { 23 | return self.vector_size; 24 | } 25 | 26 | pub fn new_from_reader(mut reader: R) -> Result, 
Word2VecError> { 27 | 28 | // Read UTF8 header string from start of file 29 | let mut header = String::new(); 30 | reader.read_line(&mut header)?; 31 | 32 | //Parse 2 integers, separated by whitespace 33 | let header_info = header.split_whitespace() 34 | .filter_map(|x| x.parse::().ok()) 35 | .take(2) 36 | .collect::>(); 37 | if header_info.len() != 2 { 38 | return Err(Word2VecError::WrongHeader); 39 | } 40 | 41 | //We've successfully read the header, ready to read vectors 42 | return Ok(WordVectorReader { 43 | vocabulary_size: header_info[0], 44 | vector_size: header_info[1], 45 | vectors_read: 0, 46 | ended_early: false, 47 | reader, 48 | }); 49 | } 50 | 51 | } 52 | 53 | impl Iterator for WordVectorReader { 54 | type Item = (String, Vec); 55 | 56 | fn next(&mut self) -> Option<(String, Vec)> { 57 | 58 | if self.vectors_read == self.vocabulary_size { 59 | return None; 60 | } 61 | 62 | // Read the bytes of the word string 63 | let mut word_bytes: Vec = Vec::new(); 64 | if let Err(_) = self.reader.read_until(b' ', &mut word_bytes) { 65 | // End the stream if a read error occured 66 | self.ended_early = true; 67 | return None; 68 | } 69 | 70 | // trim newlines, some vector files have newlines in front of a new word, others don't 71 | let word = match String::from_utf8(word_bytes) { 72 | Err(_) => { 73 | self.ended_early = true; 74 | return None 75 | }, 76 | Ok(word) => word.trim().into(), 77 | }; 78 | 79 | // Read floats of the vector 80 | let mut vector: Vec = Vec::with_capacity(self.vector_size); 81 | for _ in 0 .. 
self.vector_size { 82 | match self.reader.read_f32::() { 83 | Err(_) => { 84 | self.ended_early = true; 85 | return None 86 | }, 87 | Ok(value) => vector.push(value) 88 | } 89 | } 90 | 91 | self.vectors_read += 1; 92 | return Some((word, vector)) 93 | 94 | } 95 | } -------------------------------------------------------------------------------- /src/wordvectors.rs: -------------------------------------------------------------------------------- 1 | use std::io::prelude::*; 2 | use std::io::BufReader; 3 | use std::fs::File; 4 | use std::cmp::Ordering; 5 | use utils; 6 | use errors::Word2VecError; 7 | use vectorreader::WordVectorReader; 8 | 9 | /// Representation of a word vector space 10 | /// 11 | /// Each word of a vocabulary is represented by a vector. All words span a vector space. This data 12 | /// structure manages this vector space of words. 13 | pub struct WordVector { 14 | vocabulary: Vec<(String, Vec)>, 15 | vector_size: usize, 16 | } 17 | 18 | impl WordVector { 19 | 20 | /// Load a word vector space from file 21 | /// 22 | /// Word2vec is able to store the word vectors in a binary file. This function parses the file 23 | /// and loads the vectors into RAM. 24 | pub fn load_from_binary(file_name: &str) -> Result { 25 | let file = File::open(file_name)?; 26 | let reader = BufReader::new(file); 27 | 28 | return WordVector::load_from_reader(reader); 29 | } 30 | 31 | /// Load a word vector space from a reader 32 | /// 33 | /// Word2vec is able to store the word vectors in a binary format. This function parses the bytes in that format 34 | /// and loads the vectors into RAM. 
35 | pub fn load_from_reader(mut reader: R) -> Result { 36 | 37 | let reader = WordVectorReader::new_from_reader(reader)?; 38 | let vector_size = reader.vector_size(); 39 | 40 | let mut vocabulary: Vec<(String, Vec)> = Vec::with_capacity(reader.vocabulary_size()); 41 | for item in reader { 42 | 43 | let (word, mut vector) = item; 44 | utils::vector_norm(&mut vector); 45 | 46 | vocabulary.push((word, vector)); 47 | } 48 | 49 | Ok(WordVector { 50 | vocabulary, 51 | vector_size, 52 | }) 53 | } 54 | 55 | fn get_index(&self, word: &str) -> Option { 56 | self.vocabulary.iter().position(|x| x.0.as_str() == word) 57 | } 58 | 59 | /// Get word vector for the given word. 60 | pub fn get_vector(&self, word: &str) -> Option<&Vec> { 61 | let index = self.get_index(word); 62 | match index { 63 | Some(val) => Some(&self.vocabulary[val].1), 64 | None => None, 65 | } 66 | } 67 | 68 | /// Find the `n` words most similar to `word`, ranked by cosine similarity. 69 | /// 70 | /// Vectors are normalized when the model is loaded, so the dot product of the requested word 71 | /// with every other word is their cosine similarity. Results are sorted from most to least 72 | /// similar; the match of `word` with itself is dropped. Returns `None` for unknown words. 73 | pub fn cosine(&self, word: &str, n: usize) -> Option> { 74 | let word_vector = self.get_vector(word); 75 | match word_vector { 76 | Some(val) => { // save index and cosine distance to current word 77 | let mut metrics: Vec<(usize, f32)> = Vec::with_capacity(self.vocabulary.len()); 78 | metrics.extend(self.vocabulary.iter().enumerate().
79 | map(|(i, other_val)| 80 | (i, utils::dot_product(&other_val.1, val)))); 81 | metrics.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); // most similar first 82 | Some(metrics.iter().filter(|&&(idx, _)| self.vocabulary[idx].0 != word).take(n).map(|&(idx, dist)| // drop the query word; take(n) never runs past the end, unlike slicing 83 | (self.vocabulary[idx].0.clone(), dist)).collect()) 84 | } 85 | None => None, 86 | } 87 | } 88 | /// Average the vectors of `pos` and the negated vectors of `neg`, then return the `n` words (excluding all input words) whose vectors have the highest dot product with that mean. Unknown input words are ignored; returns `None` if none of them is in the vocabulary. 89 | pub fn analogy(&self, pos: Vec<&str>, neg: Vec<&str>, n: usize) -> Option> { 90 | let mut vectors: Vec> = Vec::new(); 91 | let mut exclude: Vec = Vec::new(); 92 | for word in pos { 93 | exclude.push(word.to_string()); 94 | match self.get_vector(word) { 95 | Some(val) => vectors.push(val.clone()), 96 | None => {} 97 | } 98 | } 99 | for word in neg.iter() { 100 | exclude.push(word.to_string()); 101 | match self.get_vector(word) { 102 | Some(val) => vectors.push(val.iter().map(|x| -x).collect::>()), 103 | None => {} 104 | } 105 | } 106 | if vectors.is_empty() { // no known input word: the mean below would be 0/0 = NaN in every component 107 | return None; 108 | } 109 | let mut mean: Vec = Vec::with_capacity(self.vector_size); 110 | for i in 0..self.vector_size { 111 | mean.push(utils::mean(vectors.iter().map(|v| v[i]))); 112 | } 113 | let mut metrics: Vec<(&String, f32)> = Vec::new(); 114 | for word in self.vocabulary.iter() { 115 | metrics.push((&word.0, utils::dot_product(&word.1, &mean))); 116 | } 117 | metrics.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); 118 | metrics.retain(|x| !exclude.contains(&x.0)); 119 | Some(metrics.iter().take(n).map(|&(x,y)| (x.clone(), y)).collect()) 120 | } 121 | 122 | /// Get the number of all known words from the vocabulary. 123 | pub fn word_count(&self) -> usize { 124 | self.vocabulary.len() 125 | } 126 | 127 | /// Return the number of columns of the word vector. 128 | pub fn get_col_count(&self) -> usize { 129 | self.vector_size // size == column count 130 | } 131 | 132 | /// Get all known words from the vocabulary.
133 | pub fn get_words(& self) -> Words { 134 | Words::new(&self.vocabulary) 135 | } 136 | } 137 | 138 | pub struct Words<'parent> { 139 | words: &'parent Vec<(String, Vec)>, 140 | index: usize, 141 | } 142 | 143 | impl<'a> Words<'a> { 144 | fn new(x: &'a Vec<(String, Vec)>) -> Words<'a> { 145 | Words { 146 | words: x, 147 | index: 0, 148 | } 149 | } 150 | } 151 | 152 | impl<'a> Iterator for Words<'a> { 153 | type Item = String; 154 | 155 | fn next(&mut self) -> Option { 156 | if self.index >= self.words.len() { 157 | return None; 158 | } 159 | self.index += 1; 160 | Some(self.words[self.index - 1].0.clone()) 161 | } 162 | } 163 | --------------------------------------------------------------------------------