├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── LICENSE-3rdparty.csv ├── README.md ├── benches └── bench.rs ├── src ├── bin │ ├── assess.rs │ └── create_train_dataset.rs ├── lib.rs └── weights.rs └── train.ipynb /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Enable version updates for Cargo 4 | - package-ecosystem: "cargo" 5 | # Look for `Cargo.toml` and `Cargo.lock` files in the `root` directory 6 | directory: "/" 7 | # Check for updates once a week 8 | schedule: 9 | interval: "weekly" 10 | # Specify labels for pull requests 11 | labels: 12 | - "dependencies" 13 | - "rust" 14 | # Allow up to 10 open pull requests 15 | open-pull-requests-limit: 10 16 | # Create pull requests for patch and minor updates 17 | versioning-strategy: auto 18 | # Create a group of dependencies to be updated together 19 | groups: 20 | dependencies: 21 | patterns: 22 | - "*" 23 | 24 | # Enable version updates for GitHub Actions 25 | - package-ecosystem: "github-actions" 26 | directory: "/" 27 | schedule: 28 | interval: "weekly" 29 | labels: 30 | - "dependencies" 31 | - "github-actions" 32 | # Allow up to 5 open pull requests for GitHub Actions 33 | open-pull-requests-limit: 5 -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | test: 14 | name: Run tests 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Install Rust toolchain 20 | uses: dtolnay/rust-toolchain@stable 21 | 22 | - name: Cache dependencies 23 | uses: actions/cache@v4 24 | with: 25 | path: | 26 | ~/.cargo/registry 27 | ~/.cargo/git 28 | target 29 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 30 | 31 | - name: Run tests 32 | run: cargo test --verbose 33 | 34 | formatting: 35 | name: Check formatting 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v4 39 | 40 | - name: Install Rust toolchain 41 | uses: dtolnay/rust-toolchain@stable 42 | with: 43 | components: rustfmt 44 | 45 | - name: Check formatting 46 | run: cargo fmt --all -- --check 47 | 48 | clippy: 49 | name: Lint with Clippy 50 | runs-on: ubuntu-latest 51 | steps: 52 | - uses: actions/checkout@v4 53 | 54 | - name: Install Rust toolchain 55 | uses: dtolnay/rust-toolchain@stable 56 | with: 57 | components: clippy 58 | 59 | - name: Cache dependencies 60 | uses: actions/cache@v4 61 | with: 62 | path: | 63 | ~/.cargo/registry 64 | ~/.cargo/git 65 | target 66 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 67 | 68 | - name: Run Clippy 69 | run: cargo clippy -- -D warnings 70 | 71 | thirdparty-license: 72 | name: Check Datadog third-party license file 73 | runs-on: ubuntu-latest 74 | steps: 75 | - uses: actions/checkout@v4 76 | 77 | - name: Install Rust toolchain 78 | uses: dtolnay/rust-toolchain@stable 79 | 80 | - name: Cache cargo tools 81 | uses: actions/cache@v4 82 | with: 83 | path: ~/.cargo/bin 84 | key: ${{ runner.os }}-cargo-tools-${{ hashFiles('**/Cargo.lock') }} 85 | 86 | - name: Install dd-rust-license-tool 87 | run: dd-rust-license-tool --help || cargo install dd-rust-license-tool 88 | 89 | - name: Check Datadog third-party license file 90 | run: 
dd-rust-license-tool check -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "whichlang" 3 | version = "0.1.1" 4 | edition = "2021" 5 | authors = ["Quickwit, Inc. "] 6 | description = "A blazingly fast and lightweight language detection library for Rust." 7 | keywords = ["nlp", "lang", "whichlang", "language", "text-processing"] 8 | categories = ["text-processing", "algorithms"] 9 | repository = "https://github.com/quickwit-oss/whichlang" 10 | homepage = "https://github.com/quickwit-oss/whichlang" 11 | documentation = "https://docs.rs/whichlang" 12 | readme = "README.md" 13 | license = "MIT" 14 | 15 | [dev-dependencies] 16 | criterion = "0.5" 17 | 18 | [[bench]] 19 | name = "bench" 20 | harness = false 21 | 22 | [[bin]] 23 | name = "create_train_dataset" 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 by Quickwit Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /LICENSE-3rdparty.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quickwit-oss/whichlang/93e37265c7d2026658f0a70453fe8f7dd67d3de8/LICENSE-3rdparty.csv -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Whichlang 2 | 3 | This is a language detection library, aiming for both precision and performance. 4 | 5 | # Why build this? 6 | While building [Quickwit](https://github.com/quickwit-oss/quickwit), a search engine tailored for log and tracing data, we found ourselves needing a light, fast, and precise language detection library in Rust that works well with our high throughput requirement. The full story and how it works are detailed in this [blog post](https://quickwit.io/blog/whichlang-language-detection-library). 7 | 8 | # Features 9 | 10 | - No dependency 11 | - Throughput above 100 MB/s for short and long strings. 
12 | - Good accuracy (99.5% on our validation dataset, though accuracy depends heavily on the length of the input) 13 | - Supported languages: Arabic, Dutch, English, French, German, Hindi, Italian, Japanese, Korean, Mandarin, Portuguese, Russian, Spanish, Swedish, Turkish, and Vietnamese. 14 | 15 | # How does it work? 16 | 17 | It uses a multiclass logistic regression model over: 18 | - 2-, 3-, and 4-grams of ASCII letters 19 | - codepoint / 128 20 | - a coarse classification of codepoints, based on frequent accented letters and Japanese/CJK Unicode block boundaries. 21 | 22 | We use the hashing trick to project these features onto a space of size `4_096`. 23 | 24 | The logistic regression is trained in the attached Python notebook (`train.ipynb`), 25 | and the resulting model is used to generate `weights.rs`. 26 | 27 | # Comparison with [Whatlang](https://github.com/greyblake/whatlang-rs) 28 | 29 | The following compares throughput using the simple benchmark found in this repository, and accuracy using the [whatlang-accuracy-benchmark](https://github.com/evanxg852000/whatlang-accuracy-benchmark). Overall, Whichlang is about 10x faster and slightly more accurate than Whatlang. 30 | 31 | ### Throughput 32 | 33 | To generate the throughput numbers, we ported the benchmark available in [this repository](https://github.com/quickwit-oss/whichlang/blob/main/benches/bench.rs) to run against Whatlang. Please check this [repository](https://github.com/evanxg852000/whatlang-accuracy-benchmark) to see our changes. 34 | 35 | | | Processing Time (µs) | Throughput (MiB/s) | 36 | | ------------------------- | -------------------- | ------------------ | 37 | | whatlang/short | 16.62 | 1.66 | 38 | | whatlang/long | 62.00 | 9.42 | 39 | | whichlang/short | 0.26 | 105.69 | 40 | | whichlang/long | 5.21 | 112.31 | 41 | 42 | ### Accuracy 43 | 44 | 45 | To generate the accuracy numbers, we changed the [whatlang-accuracy-benchmark](https://github.com/whatlang/whatlang-accuracy-benchmark) to add support for Whichlang. Given that Whatlang supports more languages, we used its FilterList feature to restrict its analysis to the languages supported by Whichlang. We also used the `trigram` method in Whatlang. Please check this [repository](https://github.com/evanxg852000/whatlang-accuracy-benchmark) to see our changes.
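For reference, this is the API being measured: a minimal usage sketch based on the public functions exercised by this repository's tests and benchmarks (`detect_language`, `Lang`, `three_letter_code`).

```rust
use whichlang::{detect_language, Lang};

fn main() {
    // `detect_language` scores the text against every supported language and
    // returns the best-scoring `Lang` variant (English is the fallback for empty input).
    let lang: Lang = detect_language("Bonjour joyeux contribuable");
    assert_eq!(lang, Lang::Fra);
    // Each variant exposes a three-letter language code.
    println!("{}", lang.three_letter_code()); // prints "fra"
}
```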
46 | 47 | ``` 48 | Crate: Whatlang 49 | AVG: 91.69% 50 | 51 | | LANG | AVG | <= 20 | 21-50 | 51-100 | > 100 | 52 | |------------|--------|---------|--------|--------|---------| 53 | | Arabic | 99.68% | 99.51% | 99.64% | 99.83% | 99.76% | 54 | | Mandarin | 96.09% | 97.54% | 96.92% | 95.45% | 94.43% | 55 | | German | 88.57% | 70.00% | 88.53% | 96.61% | 99.16% | 56 | | English | 85.99% | 57.82% | 88.37% | 97.97% | 99.78% | 57 | | French | 90.88% | 72.84% | 92.51% | 98.54% | 99.65% | 58 | | Hindi | 99.80% | 100.00% | 99.83% | 99.78% | 99.61% | 59 | | Italian | 87.75% | 66.67% | 87.74% | 97.04% | 99.54% | 60 | | Japanese | 94.37% | 93.97% | 96.04% | 94.30% | 93.18% | 61 | | Korean | 99.17% | 98.88% | 99.69% | 99.44% | 98.66% | 62 | | Dutch | 89.68% | 72.13% | 89.78% | 97.40% | 99.40% | 63 | | Portuguese | 88.08% | 72.90% | 85.76% | 95.22% | 98.44% | 64 | | Russian | 99.98% | 100.00% | 99.96% | 99.98% | 100.00% | 65 | | Spanish | 82.91% | 55.45% | 82.24% | 94.85% | 99.10% | 66 | | Swedish | 84.16% | 58.33% | 83.78% | 96.35% | 98.18% | 67 | | Turkish | 86.73% | 61.01% | 88.94% | 97.32% | 99.63% | 68 | | Vietnamese | 93.23% | 82.84% | 92.96% | 97.88% | 99.24% | 69 | | AVG | 91.69% | 78.74% | 92.04% | 97.37% | 98.61% | 70 | ``` 71 | 72 | ``` 73 | Crate: Whichlang 74 | AVG: 97.03% 75 | 76 | | LANG | AVG | <= 20 | 21-50 | 51-100 | > 100 | 77 | |------------|---------|---------|---------|---------|---------| 78 | | Arabic | 100.00% | 100.00% | 100.00% | 100.00% | 100.00% | 79 | | Mandarin | 98.65% | 98.69% | 98.48% | 98.55% | 98.87% | 80 | | German | 94.20% | 80.00% | 97.47% | 99.49% | 99.84% | 81 | | English | 97.15% | 91.84% | 97.25% | 99.57% | 99.93% | 82 | | French | 97.59% | 93.83% | 97.61% | 99.20% | 99.71% | 83 | | Hindi | 100.00% | 100.00% | 100.00% | 100.00% | 100.00% | 84 | | Italian | 97.20% | 93.06% | 97.33% | 98.85% | 99.57% | 85 | | Japanese | 94.92% | 88.95% | 95.14% | 97.74% | 97.85% | 86 | | Korean | 99.83% | 99.44% | 99.98% | 99.97% | 99.94% | 87 | | Dutch | 97.08% | 92.84% | 96.98% | 98.91% | 99.60% | 88 | | Portuguese | 94.07% | 83.87% | 94.89% | 98.18% | 99.36% | 89 | | Russian | 99.92% | 99.69% | 99.99% | 100.00% | 100.00% | 90 | | Spanish | 92.12% | 76.36% | 93.78% | 98.65% | 99.70% | 91 | | Swedish | 95.37% | 90.28% | 94.94% | 97.76% | 98.51% | 92 | | Turkish | 95.51% | 88.24% | 98.11% | 98.38% | 97.33% | 93 | | Vietnamese | 98.79% | 96.57% | 98.87% | 99.77% | 99.96% | 94 | | AVG | 97.03% | 92.10% | 97.55% | 99.06% | 99.39% | 95 | ``` 96 | -------------------------------------------------------------------------------- /benches/bench.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; 2 | 3 | // A random ascii string of length 100 chars. 4 | const ASCII_SHORT: &str = "It is a long established fact"; 5 | const ASCII_MEDIUM: &str = "It is a long established fact that a reader will be distracted by the readable content of a page when looking at its layout. The point of using Lorem Ipsum is that it has a more-or-less normal distribution of letters, as opposed to using 'Content here, content here', making it look like readable English. Many desktop publishing packages and web page editors now use Lorem Ipsum as their default model text, and a search for 'lorem ipsum' will uncover many web sites still in their infancy. 
Various versions have evolved over the years, sometimes by accident, sometimes on purpose (injected humour and the like)."; 6 | const JP_SHORT: &str = "日本ごです。 とても素敵な言葉ですね"; 7 | const JP_MEDIUM: &str = "日本ごです。 和名の由来は、太陽の動きにつれてその方向を追うように花が回るといわれたことから。ただしこの動きは生長に伴うものであるため、実際に太陽を追って動くのは生長が盛んな若い時期だけである。若いヒマワリの茎の上部の葉は太陽に正対になるように動き、朝には東を向いていたのが夕方には西を向く。日没後はまもなく起きあがり、夜明け前にはふたたび東に向く。この運動はつぼみを付ける頃まで続くが、つぼみが大きくなり花が開く素敵な言葉ですね."; 8 | 9 | pub fn criterion_benchmark(c: &mut Criterion) { 10 | let mut group = c.benchmark_group("whichlang"); 11 | group 12 | .throughput(Throughput::Bytes(ASCII_SHORT.len() as u64)) 13 | .bench_with_input("inference_short", ASCII_SHORT, |b, text| { 14 | b.iter(|| whichlang::detect_language(black_box(text))); 15 | }); 16 | group 17 | .throughput(Throughput::Bytes(ASCII_MEDIUM.len() as u64)) 18 | .bench_with_input("inference_long", ASCII_MEDIUM, |b, text| { 19 | b.iter(|| whichlang::detect_language(black_box(text))); 20 | }); 21 | group 22 | .throughput(Throughput::Bytes(JP_SHORT.len() as u64)) 23 | .bench_with_input("inference_jp_short", JP_SHORT, |b, text| { 24 | b.iter(|| whichlang::detect_language(black_box(text))); 25 | }); 26 | group 27 | .throughput(Throughput::Bytes(JP_MEDIUM.len() as u64)) 28 | .bench_with_input("inference_jp_medium", JP_MEDIUM, |b, text| { 29 | b.iter(|| whichlang::detect_language(black_box(text))); 30 | }); 31 | } 32 | 33 | criterion_group!(benches, criterion_benchmark); 34 | criterion_main!(benches); 35 | -------------------------------------------------------------------------------- /src/bin/assess.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::BufRead; 3 | use std::io::BufReader; 4 | 5 | use whichlang::Lang; 6 | 7 | fn main() -> io::Result<()> { 8 | let label_codes: Vec<&'static str> = whichlang::LANGUAGES 9 | .iter() 10 | .map(|lang| lang.three_letter_code()) 11 | .collect(); 12 | let mut line = String::new(); 13 | let stdin = io::stdin(); 14 | let mut stdinlocked = BufReader::new(stdin.lock()); 15 | let mut total = 0; 16 | let mut error = 0; 17 | loop { 18 | line.clear(); 19 | 20 | if stdinlocked.read_line(&mut line)? 
== 0 { 21 | break; 22 | } 23 | let trimmed_line = line.trim_end().trim_matches('\\'); 24 | let id_label_sentence: Vec<&str> = trimmed_line.splitn(3, "\t").collect(); 25 | if !label_codes.contains(&id_label_sentence[1]) { 26 | continue; 27 | } 28 | let detected: Lang = whichlang::detect_language(id_label_sentence[2]); 29 | total += 1; 30 | if detected.three_letter_code() != id_label_sentence[1] { 31 | error += 1; 32 | println!( 33 | "{} {} {} : {}", 34 | id_label_sentence[0], 35 | id_label_sentence[1], 36 | id_label_sentence[2], 37 | detected.three_letter_code() 38 | ); 39 | println!("precision: {}", 1.0 - (error as f64 / total as f64)); 40 | } 41 | } 42 | Ok(()) 43 | } 44 | -------------------------------------------------------------------------------- /src/bin/create_train_dataset.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::BufRead; 3 | use std::io::BufReader; 4 | use std::io::BufWriter; 5 | use std::io::Write; 6 | 7 | use whichlang::emit_tokens; 8 | use whichlang::DIMENSION; 9 | 10 | fn main() -> io::Result<()> { 11 | let language_codes = whichlang::LANGUAGES 12 | .iter() 13 | .map(|lang| lang.three_letter_code()) 14 | .collect::<Vec<&str>>(); 15 | let mut line = String::new(); 16 | let mut features: Vec<u32> = vec![0; whichlang::DIMENSION]; 17 | let stdin = io::stdin(); 18 | let mut stdinlocked = BufReader::new(stdin.lock()); 19 | let stdout = io::stdout(); 20 | let mut stdoutlock = BufWriter::new(stdout.lock()); 21 | write!(&mut stdoutlock, "id,label")?; 22 | for i in 0..whichlang::DIMENSION { 23 | write!(&mut stdoutlock, ",feature{i}")?; 24 | } 25 | writeln!(&mut stdoutlock)?; 26 | loop { 27 | line.clear(); 28 | if stdinlocked.read_line(&mut line)? == 0 { 29 | break; 30 | } 31 | let trimmed_line = line.trim_end().trim_matches('\\'); 32 | features.fill(0); 33 | let id_label_sentence: Vec<&str> = trimmed_line.splitn(3, "\t").collect(); 34 | if !language_codes.contains(&id_label_sentence[1]) { 35 | continue; 36 | } 37 | let sentence: &str = id_label_sentence[2]; 38 | emit_tokens(sentence, |token| { 39 | features[token.to_hash() as usize % DIMENSION] += 1; 40 | }); 41 | write!( 42 | &mut stdoutlock, 43 | "{},{}", 44 | id_label_sentence[0], id_label_sentence[1] 45 | )?; 46 | for &feature in &features { 47 | write!(&mut stdoutlock, ",{feature}")?; 48 | } 49 | writeln!(&mut stdoutlock)?; 50 | line.clear(); 51 | } 52 | Ok(()) 53 | } 54 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub use crate::weights::{Lang, LANGUAGES}; 2 | 3 | #[allow(clippy::all)] 4 | mod weights; 5 | 6 | const NUM_LANGUAGES: usize = LANGUAGES.len(); 7 | 8 | #[doc(hidden)] 9 | pub const DIMENSION: usize = 1 << 12; 10 | const BIGRAM_MASK: u32 = (1 << 16) - 1; 11 | const TRIGRAM_MASK: u32 = (1 << 24) - 1; 12 | 13 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 14 | #[doc(hidden)] 15 | pub enum Feature { 16 | AsciiNGram(u32), 17 | Unicode(char), 18 | UnicodeClass(char), 19 | } 20 | 21 | const SEED: u32 = 3_242_157_231u32; 22 | 23 | #[inline(always)] 24 | fn murmurhash2(mut k: u32, seed: u32) -> u32 { 25 | const M: u32 = 0x5bd1_e995; 26 | let mut h: u32 = seed; 27 | k = k.wrapping_mul(M); 28 | k ^= k >> 24; 29 | k = k.wrapping_mul(M); 30 | h = h.wrapping_mul(M); 31 | h ^= k; 32 | h ^= h >> 13; 33 | h = h.wrapping_mul(M); 34 | h ^ (h >> 15) 35 | } 36 | 37 | impl Feature { 38 | #[inline(always)] 39 | pub fn to_hash(&self) -> u32 { 40 | 
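// Each feature kind is hashed with murmurhash2 under its own seed (SEED, SEED ^ 2, SEED ^ 4),
// so an ASCII n-gram, a codepoint bucket, and a codepoint class sharing the same integer value
// are not systematically mapped to the same hash.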
match self { 41 | Feature::AsciiNGram(ngram) => murmurhash2(*ngram, SEED), 42 | Feature::Unicode(chr) => murmurhash2(*chr as u32 / 128, SEED ^ 2), 43 | Feature::UnicodeClass(chr) => murmurhash2(classify_codepoint(*chr), SEED ^ 4), 44 | } 45 | } 46 | } 47 | 48 | pub fn detect_language(text: &str) -> Lang { 49 | let mut scores: [f32; NUM_LANGUAGES] = Default::default(); 50 | let mut num_features: u32 = 0; 51 | emit_tokens( 52 | text, 53 | #[inline(always)] 54 | |token| { 55 | num_features += 1u32; 56 | let bucket = token.to_hash() % DIMENSION as u32; 57 | let idx = bucket as usize * NUM_LANGUAGES; 58 | let per_language_scores = &weights::WEIGHTS[idx..idx + NUM_LANGUAGES]; 59 | for i in 0..NUM_LANGUAGES { 60 | scores[i] += per_language_scores[i]; 61 | } 62 | }, 63 | ); 64 | if num_features == 0 { 65 | // By default, we return English 66 | return Lang::Eng; 67 | } 68 | 69 | let sqrt_inv_num_features = 1.0f32 / (num_features as f32).sqrt(); 70 | #[allow(clippy::needless_range_loop)] 71 | for i in 0..NUM_LANGUAGES { 72 | // Ok so the sqrt(num_features) is not really the norm, but whatever. 73 | scores[i] = scores[i] * sqrt_inv_num_features + weights::INTERCEPTS[i]; 74 | } 75 | 76 | let lang_id = scores 77 | .iter() 78 | .enumerate() 79 | .max_by(|(_, &score_left), (_, &score_right)| score_left.partial_cmp(&score_right).unwrap()) 80 | .map(|(pos, _val)| pos) 81 | .unwrap(); 82 | weights::LANGUAGES[lang_id] 83 | } 84 | 85 | #[doc(hidden)] 86 | pub fn emit_tokens(text: &str, mut listener: impl FnMut(Feature)) { 87 | let mut prev = ' ' as u32; 88 | let mut num_previous_ascii_chr = 1; 89 | for chr in text.chars() { 90 | let code = chr.to_ascii_lowercase() as u32; 91 | if !chr.is_ascii() { 92 | listener(Feature::Unicode(chr)); 93 | listener(Feature::UnicodeClass(chr)); 94 | num_previous_ascii_chr = 0; 95 | continue; 96 | } 97 | prev = prev << 8 | code; 98 | match num_previous_ascii_chr { 99 | 0 => { 100 | num_previous_ascii_chr = 1; 101 | } 102 | 1 => { 103 | listener(Feature::AsciiNGram(prev & BIGRAM_MASK)); 104 | num_previous_ascii_chr = 2; 105 | } 106 | 2 => { 107 | listener(Feature::AsciiNGram(prev & BIGRAM_MASK)); 108 | listener(Feature::AsciiNGram(prev & TRIGRAM_MASK)); 109 | num_previous_ascii_chr = 3; 110 | } 111 | 3 => { 112 | listener(Feature::AsciiNGram(prev & BIGRAM_MASK)); 113 | listener(Feature::AsciiNGram(prev & TRIGRAM_MASK)); 114 | listener(Feature::AsciiNGram(prev)); 115 | } 116 | _ => { 117 | unreachable!(); 118 | } 119 | } 120 | if !chr.is_alphanumeric() { 121 | prev = ' ' as u32; 122 | } 123 | } 124 | } 125 | 126 | const JP_PUNCT_START: u32 = 0x3000; 127 | const JP_PUNCT_END: u32 = 0x303f; 128 | const JP_HIRAGANA_START: u32 = 0x3040; 129 | const JP_HIRAGANA_END: u32 = 0x309f; 130 | const JP_KATAKANA_START: u32 = 0x30a0; 131 | const JP_KATAKANA_END: u32 = 0x30ff; 132 | const CJK_KANJI_START: u32 = 0x4e00; 133 | const CJK_KANJI_END: u32 = 0x9faf; 134 | const JP_HALFWIDTH_KATAKANA_START: u32 = 0xff61; 135 | const JP_HALFWIDTH_KATAKANA_END: u32 = 0xff90; 136 | 137 | fn classify_codepoint(chr: char) -> u32 { 138 | [ 139 | 160, 140 | 161, 141 | 171, 142 | 172, 143 | 173, 144 | 174, 145 | 187, 146 | 192, 147 | 196, 148 | 199, 149 | 200, 150 | 201, 151 | 202, 152 | 205, 153 | 214, 154 | 220, 155 | 223, 156 | 224, 157 | 225, 158 | 226, 159 | 227, 160 | 228, 161 | 231, 162 | 232, 163 | 233, 164 | 234, 165 | 235, 166 | 236, 167 | 237, 168 | 238, 169 | 239, 170 | 242, 171 | 243, 172 | 244, 173 | 245, 174 | 246, 175 | 249, 176 | 250, 177 | 251, 178 | 252, 179 | 333, 180 | 339, 181 | 
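// The entries above are frequent Latin-1 punctuation marks and accented letters; the constants
// below mark the boundaries of Japanese/CJK Unicode blocks. The array must remain sorted, since
// binary_search (with unwrap_or_else on the insertion point) maps a codepoint to its class index.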
JP_PUNCT_START, 182 | JP_PUNCT_END, 183 | JP_HIRAGANA_START, 184 | JP_HIRAGANA_END, 185 | JP_KATAKANA_START, 186 | JP_KATAKANA_END, 187 | CJK_KANJI_START, 188 | CJK_KANJI_END, 189 | JP_HALFWIDTH_KATAKANA_START, 190 | JP_HALFWIDTH_KATAKANA_END, 191 | ] 192 | .binary_search(&(chr as u32)) 193 | .unwrap_or_else(|pos| pos) as u32 194 | } 195 | 196 | #[cfg(test)] 197 | mod tests { 198 | use crate::detect_language; 199 | use crate::emit_tokens; 200 | use crate::Feature; 201 | use crate::Lang; 202 | 203 | fn ascii_ngram_feature(text: &str) -> Feature { 204 | assert!(text.is_ascii()); 205 | let mut bytes: [u8; 4] = [0u8; 4]; 206 | assert!(text.len() <= 4); 207 | bytes[4 - text.len()..].copy_from_slice(text.as_bytes()); 208 | Feature::AsciiNGram(u32::from_be_bytes(bytes)) 209 | } 210 | 211 | #[test] 212 | fn test_emit_tokens() { 213 | let mut tokens = Vec::new(); 214 | emit_tokens("hello こん!", |token| tokens.push(token)); 215 | assert_eq!( 216 | &tokens, 217 | &[ 218 | ascii_ngram_feature(" h"), 219 | ascii_ngram_feature("he"), 220 | ascii_ngram_feature(" he"), 221 | ascii_ngram_feature("el"), 222 | ascii_ngram_feature("hel"), 223 | ascii_ngram_feature(" hel"), 224 | ascii_ngram_feature("ll"), 225 | ascii_ngram_feature("ell"), 226 | ascii_ngram_feature("hell"), 227 | ascii_ngram_feature("lo"), 228 | ascii_ngram_feature("llo"), 229 | ascii_ngram_feature("ello"), 230 | Feature::Unicode(' '), 231 | Feature::UnicodeClass(' '), 232 | Feature::Unicode('こ'), 233 | Feature::UnicodeClass('こ'), 234 | Feature::Unicode('ん'), 235 | Feature::UnicodeClass('ん'), 236 | Feature::Unicode('!'), 237 | Feature::UnicodeClass('!'), 238 | ] 239 | ); 240 | } 241 | 242 | #[test] 243 | fn test_empty_str() { 244 | assert_eq!(detect_language(""), Lang::Eng); 245 | } 246 | 247 | #[test] 248 | fn test_detect_language() { 249 | // English 250 | assert_eq!(detect_language("Hello, happy tax payer"), Lang::Eng); 251 | // French 252 | assert_eq!(detect_language("Bonjour joyeux contribuable"), Lang::Fra); 253 | // German 254 | assert_eq!(detect_language("Hallo glücklicher Steuerzahler"), Lang::Deu); 255 | // Japanese 256 | assert_eq!(detect_language("こんにちは幸せな税金納め"), Lang::Jpn); 257 | // Mandarin chinese 258 | assert_eq!(detect_language("你好幸福的纳税人"), Lang::Cmn); 259 | // Turkish 260 | assert_eq!(detect_language("Merhaba, mutlu vergi mükellefi"), Lang::Tur); 261 | // Dutch 262 | assert_eq!(detect_language("Hallo, blije belastingbetaler"), Lang::Nld); 263 | // Korean 264 | assert_eq!(detect_language("안녕하세요 행복한 납세자입니다"), Lang::Kor); 265 | // Italian 266 | assert_eq!(detect_language("Ciao, felice contribuente!"), Lang::Ita); 267 | // Spanish 268 | assert_eq!(detect_language("Hola feliz contribuyente"), Lang::Spa); 269 | assert_eq!(detect_language("¡Hola!"), Lang::Spa); 270 | // Portuguese 271 | assert_eq!(detect_language("Olá feliz contribuinte"), Lang::Por); 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e34afb69-8f87-4073-9560-cb80c67e6ae5", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import numpy as np\n", 13 | "import sklearn\n", 14 | "from sklearn import linear_model\n", 15 | "import polars\n", 16 | "from sklearn.model_selection import train_test_split\n", 17 | "from scipy import sparse" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 24, 
23 | "id": "b02b8789-e927-4df5-96e1-d26706ff5fa2", 24 | "metadata": { 25 | "tags": [] 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "(49998, 4098)\n", 33 | "(49998, 4098)\n", 34 | "(49998, 4098)\n", 35 | "(49998, 4098)\n", 36 | "(49998, 4098)\n", 37 | "(49998, 4098)\n", 38 | "(49998, 4098)\n", 39 | "(49998, 4098)\n", 40 | "(49998, 4098)\n", 41 | "(49998, 4098)\n", 42 | "(49998, 4098)\n", 43 | "(49998, 4098)\n", 44 | "(599976, 4096)\n", 45 | "(599976,)\n", 46 | "(449982, 4096) (149994, 4096)\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "\n", 52 | "df_it = polars.read_csv_batched(\"train.csv\", has_header=True)\n", 53 | "ys = []\n", 54 | "Xs = []\n", 55 | "idxs = []\n", 56 | "for i in range(12):\n", 57 | " df = df_it.next_batches(1)[0]\n", 58 | " df[0, 0:10]\n", 59 | " print(df.shape)\n", 60 | " y = df[:, 1].to_numpy()\n", 61 | " ys.append(y)\n", 62 | " X = sparse.csr_matrix(np.float32(df[:, 2:].to_numpy()))\n", 63 | " X = sklearn.preprocessing.normalize(X)\n", 64 | " Xs.append(X)\n", 65 | " idx = df[:, 0]\n", 66 | " idxs.append(idx)\n", 67 | "\n", 68 | "X = sparse.vstack(Xs)\n", 69 | "y = np.hstack(ys)\n", 70 | "\n", 71 | "del Xs\n", 72 | "del ys\n", 73 | "\n", 74 | "print(X.shape)\n", 75 | "print(y.shape)\n", 76 | "(X_train, X_test, y_train, y_test) = train_test_split(X, y, shuffle=False)\n", 77 | "del X\n", 78 | "del y\n", 79 | "n_train = X_train.shape[0]\n", 80 | "print(X_train.shape, X_test.shape)\n", 81 | "\n", 82 | "\n", 83 | "# print(df.shape)\n", 84 | "# print(df[:, 2:].mean())\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 26, 90 | "id": "4e6d2504-69d1-4135-844f-1d80a1d96e04", 91 | "metadata": { 92 | "tags": [] 93 | }, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "160, 171, 173, 187, 192, 196, 199, 200, 201, 202, 205, 214, 220, 223, 224, 225, 226, 227, 228, 231, 232, 233, 234, 235, 236, 237, 238, 239, 242, 243, 244, 245, 246, 249, 250, 251, 252, 333, 339, 8201, 8211, 8212, 8217, 8220, 8221, 8222, 8239\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "from collections import Counter\n", 105 | "counter = Counter()\n", 106 | "\n", 107 | "c = 0\n", 108 | "for line in open(\"dataset/archive/sentences.prepared.csv\"):\n", 109 | " (rid, lang, sentence) = line.strip().split(\"\\t\", 2)\n", 110 | " if lang not in {\"fra\",\"eng\", \"ita\", \"deu\", \"esp\", \"por\"}:\n", 111 | " continue\n", 112 | " c += 1\n", 113 | " if c > 100_000:\n", 114 | " break\n", 115 | " for chr in sentence:\n", 116 | " if ord(chr) > 128:\n", 117 | " counter[chr] += 1\n", 118 | "letters = sorted(ord(letter) for (letter, count) in counter.most_common(100) if count >= 10)\n", 119 | "print(\", \".join(map(str, letters)))\n", 120 | "#print(letters)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 27, 126 | "id": "5f7408bc-da6b-4608-84e0-af67b203a748", 127 | "metadata": { 128 | "tags": [] 129 | }, 130 | "outputs": [ 131 | { 132 | "name": "stderr", 133 | "output_type": "stream", 134 | "text": [ 135 | "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", 136 | " This problem is unconstrained.\n" 137 | ] 138 | }, 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "RUNNING THE L-BFGS-B CODE\n", 144 | "\n", 145 | " * * *\n", 146 | "\n", 147 | "Machine precision = 2.220D-16\n", 148 | " N = 65552 M = 10\n", 149 | "\n", 150 | "At X0 0 variables are exactly at the bounds\n", 151 | "\n", 152 | 
"At iterate 0 f= 1.24762D+06 |proj g|= 2.45139D+04\n", 153 | "\n", 154 | "At iterate 50 f= 1.21037D+04 |proj g|= 3.55325D+02\n", 155 | "\n", 156 | "At iterate 100 f= 4.94206D+03 |proj g|= 1.13277D+02\n", 157 | "\n", 158 | "At iterate 150 f= 3.82850D+03 |proj g|= 1.43956D+01\n", 159 | "\n", 160 | "At iterate 200 f= 3.43115D+03 |proj g|= 1.99731D+01\n", 161 | "\n", 162 | "At iterate 250 f= 3.29184D+03 |proj g|= 1.39921D+01\n", 163 | "\n", 164 | "At iterate 300 f= 3.23884D+03 |proj g|= 4.21846D+00\n", 165 | "\n", 166 | "At iterate 350 f= 3.21968D+03 |proj g|= 1.86934D+00\n", 167 | "\n", 168 | "At iterate 400 f= 3.21289D+03 |proj g|= 7.79969D-01\n", 169 | "\n", 170 | "At iterate 450 f= 3.21017D+03 |proj g|= 1.19910D+00\n", 171 | "\n", 172 | "At iterate 500 f= 3.20925D+03 |proj g|= 7.50076D-01\n", 173 | "\n", 174 | " * * *\n", 175 | "\n", 176 | "Tit = total number of iterations\n", 177 | "Tnf = total number of function evaluations\n", 178 | "Tnint = total number of segments explored during Cauchy searches\n", 179 | "Skip = number of BFGS updates skipped\n", 180 | "Nact = number of active bounds at final generalized Cauchy point\n", 181 | "Projg = norm of the final projected gradient\n", 182 | "F = final function value\n", 183 | "\n", 184 | " * * *\n", 185 | "\n", 186 | " N Tit Tnf Tnint Skip Nact Projg F\n", 187 | "65552 500 549 1 0 0 7.501D-01 3.209D+03\n", 188 | " F = 3209.2533390579683 \n", 189 | "\n", 190 | "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT \n", 191 | "\n", 192 | "\n", 193 | "\n", 194 | "\n", 195 | "\n", 196 | "----------\n", 197 | "64\n" 198 | ] 199 | }, 200 | { 201 | "name": "stderr", 202 | "output_type": "stream", 203 | "text": [ 204 | "/home/fulmicoton/miniconda3/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", 205 | "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", 206 | "\n", 207 | "Increase the number of iterations (max_iter) or scale the data as shown in:\n", 208 | " https://scikit-learn.org/stable/modules/preprocessing.html\n", 209 | "Please also refer to the documentation for alternative solver options:\n", 210 | " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", 211 | " n_iter_i = _check_optimize_result(\n", 212 | "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.0min finished\n" 213 | ] 214 | }, 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "0.9997377672884693\n", 220 | "0.9961465125271678\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "for C in [64]:\n", 226 | " model = sklearn.linear_model.LogisticRegression(max_iter=500, penalty='l2', multi_class='multinomial', C=C, verbose=1, class_weight='balanced') #, l1_ratio=0.1,) # penalty='elasticnet', solver='saga') \n", 227 | " model.fit(X_train, y_train)\n", 228 | " print(\"\\n\\n\\n\\n\\n----------\")\n", 229 | " print(C)\n", 230 | " print((model.predict(X_train) == y_train).mean())\n", 231 | " print((model.predict(X_test) == y_test).mean())" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 21, 237 | "id": "d47b1c8a-2d26-41aa-a537-91c86ecf77c6", 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n", 247 | " 'spa' 'swe' 'tur' 'vie']\n", 248 | "[[13558 13 4 2 26 1 0 3]\n", 249 | " [ 20 36656 23 10 27 10 0 20]\n", 250 | " [ 2 9 11541 9 2 5 0 11]\n", 251 | " [ 1 10 25 18634 4 28 0 41]\n", 252 | " [ 14 9 2 4 3663 1 0 0]\n", 253 | " [ 3 3 16 28 0 9036 0 104]\n", 254 | " [ 0 0 0 0 0 0 21220 0]\n", 255 | " [ 1 4 10 29 1 68 0 8532]]\n" 256 | ] 257 | }, 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "array([[ 209, 0, 0],\n", 262 | " [ 0, 5127, 3],\n", 263 | " [ 0, 1, 1704]])" 264 | ] 265 | }, 266 | "execution_count": 21, 267 | "metadata": {}, 268 | "output_type": "execute_result" 269 | } 270 | ], 271 | "source": [ 272 | "import time\n", 273 | "time.sleep(100)\n", 274 | "print(\"start\")from sklearn import metrics\n", 275 | "print(model.classes_)\n", 276 | "print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n", 277 | "\n", 278 | "sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 30, 284 | "id": "6ba8d76f-4dde-45d9-b07d-b12b4355b0d6", 285 | "metadata": { 286 | "tags": [] 287 | }, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n", 294 | " 'spa' 'swe' 'tur' 'vie']\n", 295 | "[[13565 7 2 2 22 3 0 3]\n", 296 | " [ 13 36698 8 12 20 7 0 9]\n", 297 | " [ 1 6 11555 7 1 4 0 5]\n", 298 | " [ 3 6 20 18666 4 22 0 29]\n", 299 | " [ 11 12 2 4 3665 5 0 1]\n", 300 | " [ 4 3 6 22 1 9064 0 89]\n", 301 | " [ 0 0 0 0 0 0 21220 0]\n", 302 | " [ 2 3 9 34 0 58 0 8539]]\n" 303 | ] 304 | }, 305 | { 306 | "data": { 307 | "text/plain": [ 308 | "array([[ 209, 0, 0],\n", 309 | " [ 0, 5128, 3],\n", 310 | " [ 0, 2, 1703]])" 311 | ] 312 | }, 313 | "execution_count": 30, 314 | "metadata": {}, 315 | "output_type": "execute_result" 316 | } 317 | ], 318 | "source": [ 319 | "\n", 320 | "from sklearn 
import metrics\n", 321 | "print(model.classes_)\n", 322 | "print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n", 323 | "\n", 324 | "sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 26, 330 | "id": "0efb14c6-50db-4bd6-bb33-cd2dcee69493", 331 | "metadata": { 332 | "tags": [] 333 | }, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n", 340 | " 'spa' 'swe' 'tur' 'vie']\n", 341 | "[[13562 11 2 3 21 2 0 3]\n", 342 | " [ 14 36684 11 13 21 5 0 15]\n", 343 | " [ 1 6 11551 8 2 4 0 7]\n", 344 | " [ 3 6 21 18662 5 22 0 30]\n", 345 | " [ 11 12 2 4 3666 5 0 0]\n", 346 | " [ 4 3 8 23 1 9060 0 90]\n", 347 | " [ 0 0 0 0 0 0 21220 0]\n", 348 | " [ 2 3 8 35 0 57 0 8539]]\n" 349 | ] 350 | }, 351 | { 352 | "data": { 353 | "text/plain": [ 354 | "array([[ 209, 0, 0],\n", 355 | " [ 0, 5128, 3],\n", 356 | " [ 0, 1, 1704]])" 357 | ] 358 | }, 359 | "execution_count": 26, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "from sklearn import metrics\n", 366 | "print(model.classes_)\n", 367 | "print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n", 368 | "\n", 369 | "sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 16, 375 | "id": "c4252d1b-1ed4-4ffb-8f6a-0196b32622f9", 376 | "metadata": { 377 | "tags": [] 378 | }, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n", 385 | " 'spa' 'swe' 'tur' 'vie']\n", 386 | "[[ 6866 8 2 6 10 2 0 7]\n", 387 | " [ 4 18245 22 11 30 3 0 22]\n", 388 | " [ 0 2 5838 9 3 7 0 12]\n", 389 | " [ 1 6 14 9446 0 27 0 40]\n", 390 | " [ 10 7 1 3 1770 0 0 4]\n", 391 | " [ 1 2 10 31 4 4517 0 58]\n", 392 | " [ 0 0 0 0 0 0 10485 0]\n", 393 | " [ 2 4 10 34 1 43 0 4133]]\n" 394 | ] 395 | }, 396 | { 397 | "data": { 398 | "text/plain": [ 399 | "array([[ 101, 0, 0],\n", 400 | " [ 0, 2561, 5],\n", 401 | " [ 0, 0, 812]])" 402 | ] 403 | }, 404 | "execution_count": 16, 405 | "metadata": {}, 406 | "output_type": "execute_result" 407 | } 408 | ], 409 | "source": [ 410 | "from sklearn import metrics\n", 411 | "print(model.classes_)\n", 412 | "print(sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n", 413 | "\n", 414 | "sklearn.metrics.confusion_matrix(y_test, model.predict(X_test), labels=['kor', 'jpn', 'cmn'])" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 27, 420 | "id": "476077a6-2110-447a-8197-f0848777a508", 421 | "metadata": { 422 | "tags": [] 423 | }, 424 | "outputs": [ 425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "(565,)\n", 430 | "(75000,)\n", 431 | "ita por 8484775\n", 432 | "ita spa 3553712\n", 433 | "ita por 2614727\n", 434 | "ita spa 3213159\n", 435 | "tur nld 4129221\n", 436 | "por spa 6992976\n", 437 | "por spa 7888818\n", 438 | "deu fra 843713\n", 439 | "eng swe 1164430\n", 440 | "por spa 972364\n" 441 | ] 442 | } 443 | ], 444 | "source": [ 
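"# Inspect the first ten misclassified test sentences: true label, predicted label, and dataset row id.\n",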
445 | "y_predict = model.predict(X_test)\n", 446 | "print(np.where((y_predict == y_test) == False)[0].shape)\n", 447 | "print(y_predict.shape)\n", 448 | "i = 0 \n", 449 | "for row in list(np.where((y_predict == y_test) == False))[0]:\n", 450 | " i += 1\n", 451 | " print(y_test[row], y_predict[row], idx[int(n_train + row)])\n", 452 | " if i == 10:\n", 453 | " break" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 31, 459 | "id": "ac15d1d4-c839-4e96-89cc-724af639d231", 460 | "metadata": { 461 | "tags": [] 462 | }, 463 | "outputs": [ 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "(16, 4096)\n", 469 | "-0.0047384733\n", 470 | "[-4.73847311e-03 -2.28266937e-03 -5.36394365e-01 5.88218350e-01\n", 471 | " -1.03133002e-01 -5.35878639e-04 -2.13580000e+00 -2.86555783e-03\n", 472 | " -1.19892844e-03 -3.64568806e-01 -3.39155160e-01 -5.44463552e-04\n", 473 | " 8.38150022e-01 1.53047103e+00 5.64755284e-01 -3.03773764e-02]\n" 474 | ] 475 | } 476 | ], 477 | "source": [ 478 | "(LANG, DIM) = model.coef_.shape\n", 479 | "print(model.coef_.shape)\n", 480 | "coef = np.float32(model.coef_)\n", 481 | "\n", 482 | "print(coef[0,0])\n", 483 | "print(model.coef_[:,0])\n", 484 | "\n", 485 | "f = open(\"src/weights.rs\", \"w\")\n", 486 | "\n", 487 | "f.write(\"#[derive(Clone, Copy, Debug, Eq, PartialEq)]\\n\")\n", 488 | "f.write(\"pub enum Lang {\\n\")\n", 489 | "for lang in model.classes_:\n", 490 | " f.write(\"\\t%s,\\n\" % lang.capitalize(),)\n", 491 | "f.write(\"}\\n\\n\")\n", 492 | "\n", 493 | "f.write(\"\"\"\n", 494 | "impl Lang {\n", 495 | " pub fn three_letter_code(self)-> &'static str {\n", 496 | " match self {\n", 497 | "\"\"\")\n", 498 | "for lang in model.classes_:\n", 499 | " f.write(\"\\t\\t\\tLang::%s => \\\"%s\\\",\\n\" % (lang.capitalize(), lang))\n", 500 | "f.write(\"\\t\\t}\\t}\\n}\\n\\n\\n\")\n", 501 | "\n", 502 | "\n", 503 | "f.write(\"pub const LANGUAGES: [Lang; %d] = [\\n\\t\" % LANG)\n", 504 | "for lang in model.classes_:\n", 505 | " f.write(\"Lang::%s, \" % lang.capitalize())\n", 506 | "f.write(\"];\\n\\n\")\n", 507 | "\n", 508 | "f.write(\"pub const WEIGHTS: [f32; %d] = [\\n\" % (LANG * DIM))\n", 509 | "for i in range(DIM):\n", 510 | " f.write(\"\\t\")\n", 511 | " for val in coef[:, i]:\n", 512 | " f.write(\"%f, \" % val)\n", 513 | " f.write(\"\\n\")\n", 514 | "f.write(\"];\\n\\n\")\n", 515 | "\n", 516 | "\n", 517 | "f.write(\"pub const INTERCEPTS: [f32; %d] = [\\n\\t\" % LANG)\n", 518 | "for val in model.intercept_:\n", 519 | " f.write(\"%f, \" % val)\n", 520 | "f.write(\"];\\n\\n\")\n", 521 | "\n", 522 | "\n", 523 | "f.flush()\n", 524 | "f.close()\n" 525 | ] 526 | } 527 | ], 528 | "metadata": { 529 | "kernelspec": { 530 | "display_name": "Python 3 (ipykernel)", 531 | "language": "python", 532 | "name": "python3" 533 | }, 534 | "language_info": { 535 | "codemirror_mode": { 536 | "name": "ipython", 537 | "version": 3 538 | }, 539 | "file_extension": ".py", 540 | "mimetype": "text/x-python", 541 | "name": "python", 542 | "nbconvert_exporter": "python", 543 | "pygments_lexer": "ipython3", 544 | "version": "3.9.12" 545 | } 546 | }, 547 | "nbformat": 4, 548 | "nbformat_minor": 5 549 | } 550 | --------------------------------------------------------------------------------