├── bindings └── dart │ ├── ios │ ├── Assets │ │ └── .gitkeep │ ├── Classes │ │ ├── XaynAiFfiDartPlugin.h │ │ ├── XaynAiFfiDartPlugin.m │ │ └── SwiftXaynAiFfiDartPlugin.swift │ ├── .gitignore │ └── xayn_ai_ffi_dart.podspec │ ├── lib │ ├── src │ │ ├── common │ │ │ ├── ffi │ │ │ │ └── .gitkeep │ │ │ ├── utils.dart │ │ │ ├── reranker │ │ │ │ ├── utils.dart │ │ │ │ ├── analytics.dart │ │ │ │ ├── mode.dart │ │ │ │ ├── ai.dart │ │ │ │ └── debug.dart │ │ │ ├── data │ │ │ │ └── document.dart │ │ │ └── result │ │ │ │ └── outcomes.dart │ │ ├── web │ │ │ ├── result │ │ │ │ ├── fault.dart │ │ │ │ ├── outcomes.dart │ │ │ │ └── error.dart │ │ │ ├── worker │ │ │ │ ├── message │ │ │ │ │ ├── utils.dart │ │ │ │ │ └── response.dart │ │ │ │ └── worker.dart │ │ │ ├── data │ │ │ │ ├── document.dart │ │ │ │ └── history.dart │ │ │ ├── reranker │ │ │ │ ├── analytics.dart │ │ │ │ └── data_provider.dart │ │ │ └── ffi │ │ │ │ └── library.dart │ │ └── mobile │ │ │ ├── ffi │ │ │ └── library.dart │ │ │ ├── result │ │ │ ├── slice.dart │ │ │ ├── fault.dart │ │ │ ├── outcomes.dart │ │ │ └── error.dart │ │ │ ├── reranker │ │ │ ├── data_provider.dart │ │ │ ├── analytics.dart │ │ │ └── bytes.dart │ │ │ └── data │ │ │ ├── document.dart │ │ │ └── history.dart │ └── package.dart │ ├── example │ ├── assets │ │ ├── ltr_v0000 │ │ │ └── .gitkeep │ │ ├── qambert_v0001 │ │ │ └── .gitkeep │ │ ├── smbert_v0001 │ │ │ └── .gitkeep │ │ ├── wasm_bindings │ │ │ └── .gitkeep │ │ └── call_data │ │ │ └── example2.json │ ├── ios │ │ ├── Runner │ │ │ ├── Runner-Bridging-Header.h │ │ │ ├── Assets.xcassets │ │ │ │ ├── LaunchImage.imageset │ │ │ │ │ ├── LaunchImage.png │ │ │ │ │ └── Contents.json │ │ │ │ └── AppIcon.appiconset │ │ │ │ │ ├── Icon-App-20x20@2x.png │ │ │ │ │ └── Contents.json │ │ │ ├── AppDelegate.swift │ │ │ ├── Base.lproj │ │ │ │ ├── Main.storyboard │ │ │ │ └── LaunchScreen.storyboard │ │ │ └── Info.plist │ │ ├── Flutter │ │ │ ├── Debug.xcconfig │ │ │ ├── Release.xcconfig │ │ │ └── AppFrameworkInfo.plist │ │ 
├── Runner.xcodeproj │ │ │ └── project.xcworkspace │ │ │ │ ├── contents.xcworkspacedata │ │ │ │ └── xcshareddata │ │ │ │ ├── WorkspaceSettings.xcsettings │ │ │ │ └── IDEWorkspaceChecks.plist │ │ ├── Runner.xcworkspace │ │ │ ├── contents.xcworkspacedata │ │ │ └── xcshareddata │ │ │ │ ├── WorkspaceSettings.xcsettings │ │ │ │ └── IDEWorkspaceChecks.plist │ │ ├── .gitignore │ │ ├── Podfile.lock │ │ └── Podfile │ ├── android │ │ ├── gradle.properties │ │ ├── app │ │ │ ├── src │ │ │ │ ├── main │ │ │ │ │ ├── res │ │ │ │ │ │ ├── mipmap-mdpi │ │ │ │ │ │ │ └── ic_launcher.png │ │ │ │ │ │ ├── drawable │ │ │ │ │ │ │ └── launch_background.xml │ │ │ │ │ │ └── values │ │ │ │ │ │ │ └── styles.xml │ │ │ │ │ ├── kotlin │ │ │ │ │ │ └── com │ │ │ │ │ │ │ └── example │ │ │ │ │ │ │ └── xayn_ai_ffi_dart_example │ │ │ │ │ │ │ └── MainActivity.kt │ │ │ │ │ └── AndroidManifest.xml │ │ │ │ ├── debug │ │ │ │ │ └── AndroidManifest.xml │ │ │ │ └── profile │ │ │ │ │ └── AndroidManifest.xml │ │ │ └── build.gradle │ │ ├── gradle │ │ │ └── wrapper │ │ │ │ └── gradle-wrapper.properties │ │ ├── .gitignore │ │ ├── settings.gradle │ │ └── build.gradle │ ├── lib │ │ ├── debug │ │ │ ├── print.dart │ │ │ └── mobile │ │ │ │ └── print.dart │ │ └── data_provider │ │ │ ├── data_provider.dart │ │ │ └── web.dart │ ├── web │ │ └── index.html │ ├── .metadata │ ├── server.py │ ├── .gitignore │ ├── pubspec.yaml │ └── flutter_run_web.sh │ ├── android │ ├── settings.gradle │ ├── gradle.properties │ ├── .gitignore │ ├── src │ │ └── main │ │ │ ├── AndroidManifest.xml │ │ │ └── kotlin │ │ │ └── com │ │ │ └── xayn │ │ │ └── xayn_ai_ffi_dart │ │ │ └── XaynAiFfiDartPlugin.kt │ ├── gradle │ │ └── wrapper │ │ │ └── gradle-wrapper.properties │ └── build.gradle │ ├── build.yaml │ ├── ffigen_mobile.yaml │ ├── ffigen_common.yaml │ ├── .gitignore │ ├── analysis_options.yaml │ ├── xayn_ai_ffi_dart.iml │ ├── pubspec.yaml │ └── test │ └── mobile │ ├── reranker │ └── bytes_test.dart │ ├── result │ └── fault_test.dart │ └── utils.dart 
├── test-utils ├── src │ ├── test │ │ ├── mod.rs │ │ └── ltr.rs │ ├── bench │ │ ├── mod.rs │ │ └── matmul.rs │ ├── example │ │ ├── mod.rs │ │ └── validate.rs │ ├── ltr.rs │ ├── lib.rs │ ├── smbert.rs │ ├── qambert.rs │ ├── kpe.rs │ └── asset.rs └── Cargo.toml ├── data ├── bundler_config │ ├── .gitignore │ ├── package.json │ └── webpack.config.js └── asset_templates │ ├── base_assets.dart.tmpl │ └── web_assets.dart.tmpl ├── Cross.toml ├── xayn-ai ├── src │ ├── data │ │ └── mod.rs │ ├── error.rs │ ├── embedding │ │ ├── mod.rs │ │ └── smbert.rs │ ├── ranker │ │ ├── mod.rs │ │ └── document.rs │ ├── ltr │ │ └── list_net │ │ │ ├── mod.rs │ │ │ ├── optimizer.rs │ │ │ └── tests │ │ │ └── mod.rs │ ├── tests │ │ ├── mem_db.rs │ │ └── mod.rs │ ├── lib.rs │ └── coi │ │ └── mod.rs └── Cargo.toml ├── dev-tool ├── src │ ├── exit_code.rs │ ├── list_net.rs │ ├── utils.rs │ ├── main.rs │ └── call_data │ │ └── mod.rs └── Cargo.toml ├── layer ├── src │ └── lib.rs └── Cargo.toml ├── rustfmt.toml ├── bors.toml ├── .gitignore ├── rubert-tokenizer ├── README.md ├── Cargo.toml └── src │ ├── pre_tokenizer │ ├── string.rs │ └── mod.rs │ ├── tokenizer.rs │ ├── lib.rs │ └── model │ └── string.rs ├── xayn-ai-ffi ├── src │ └── lib.rs ├── Cargo.toml ├── cbindgen.toml └── build.rs ├── xayn-ai-ffi-c ├── src │ ├── lib.rs │ ├── data │ │ └── mod.rs │ ├── reranker │ │ ├── mod.rs │ │ └── analytics.rs │ └── result │ │ ├── mod.rs │ │ └── fault.rs ├── cbindgen.toml ├── Cargo.toml └── build.rs ├── Cargo.toml ├── kpe ├── examples │ └── kpe.rs ├── src │ ├── lib.rs │ ├── model │ │ └── mod.rs │ └── tokenizer │ │ └── mod.rs ├── Cargo.toml ├── benches │ └── kpe.rs └── extract_params.py ├── xayn-ai-ffi-wasm ├── src │ ├── error.rs │ └── lib.rs └── Cargo.toml ├── .ci ├── generate-flutter-ffi │ └── action.yml ├── copy-headers │ └── action.yml ├── install-wasm-pack │ └── action.yml ├── install-wasm-opt │ └── action.yml ├── install-cargo-sort │ └── action.yml ├── install-gomplate │ └── action.yml └── 
build-asset-artifacts │ └── action.yml ├── assets_manifest.json ├── .github ├── dependabot.yml └── workflows │ ├── labeler.yml │ ├── ci.yml │ └── audit.yml ├── generate_assets_metadata.sh ├── download_data.sh ├── scripts └── upload_assets.sh ├── rubert ├── examples │ └── mbert.rs ├── src │ ├── lib.rs │ └── config.rs └── Cargo.toml └── prepare_data.sh /bindings/dart/ios/Assets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/ffi/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test-utils/src/test/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod ltr; 2 | -------------------------------------------------------------------------------- /bindings/dart/example/assets/ltr_v0000/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/bundler_config/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | -------------------------------------------------------------------------------- /test-utils/src/bench/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod matmul; 2 | -------------------------------------------------------------------------------- /bindings/dart/example/assets/qambert_v0001/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bindings/dart/example/assets/smbert_v0001/.gitkeep: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /bindings/dart/example/assets/wasm_bindings/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test-utils/src/example/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod validate; 2 | -------------------------------------------------------------------------------- /Cross.toml: -------------------------------------------------------------------------------- 1 | [build.env] 2 | passthrough = ["CARGO_INCREMENTAL", "RUSTFLAGS"] 3 | -------------------------------------------------------------------------------- /bindings/dart/android/settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'xayn_ai_ffi_dart' 2 | -------------------------------------------------------------------------------- /xayn-ai/src/data/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod document; 2 | pub(crate) mod document_data; 3 | -------------------------------------------------------------------------------- /xayn-ai/src/error.rs: -------------------------------------------------------------------------------- 1 | // temporary dummy error 2 | pub type Error = anyhow::Error; 3 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Runner-Bridging-Header.h: -------------------------------------------------------------------------------- 1 | #import "GeneratedPluginRegistrant.h" 2 | -------------------------------------------------------------------------------- /xayn-ai/src/embedding/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod qambert; 2 | pub(crate) mod smbert; 3 | pub(crate) mod utils; 4 | 
-------------------------------------------------------------------------------- /bindings/dart/android/gradle.properties: -------------------------------------------------------------------------------- 1 | org.gradle.jvmargs=-Xmx1536M 2 | android.useAndroidX=true 3 | android.enableJetifier=true 4 | -------------------------------------------------------------------------------- /dev-tool/src/exit_code.rs: -------------------------------------------------------------------------------- 1 | pub const NO_ERROR: i32 = 0; 2 | pub const FATAL_ERROR: i32 = 1; 3 | pub const NON_FATAL_ERROR: i32 = 2; 4 | -------------------------------------------------------------------------------- /bindings/dart/example/android/gradle.properties: -------------------------------------------------------------------------------- 1 | org.gradle.jvmargs=-Xmx1536M 2 | android.useAndroidX=true 3 | android.enableJetifier=true 4 | -------------------------------------------------------------------------------- /data/bundler_config/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "devDependencies": { 3 | "webpack": "5.70.0", 4 | "webpack-cli": "4.9.2" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /bindings/dart/android/.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .gradle 3 | /local.properties 4 | /.idea/workspace.xml 5 | /.idea/libraries 6 | .DS_Store 7 | /build 8 | /captures 9 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Flutter/Debug.xcconfig: -------------------------------------------------------------------------------- 1 | #include? 
"Pods/Target Support Files/Pods-Runner/Pods-Runner.debug.xcconfig" 2 | #include "Generated.xcconfig" 3 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Flutter/Release.xcconfig: -------------------------------------------------------------------------------- 1 | #include? "Pods/Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig" 2 | #include "Generated.xcconfig" 3 | -------------------------------------------------------------------------------- /bindings/dart/android/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | -------------------------------------------------------------------------------- /bindings/dart/ios/Classes/XaynAiFfiDartPlugin.h: -------------------------------------------------------------------------------- 1 | #import 2 | 3 | @interface XaynAiFfiDartPlugin : NSObject 4 | @end 5 | 6 | #include "XaynAiFfiDart.h" 7 | -------------------------------------------------------------------------------- /layer/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! AI model building blocks. 
2 | #![forbid(unsafe_op_in_unsafe_fn)] 3 | 4 | pub mod activation; 5 | pub mod conv; 6 | pub mod dense; 7 | pub mod io; 8 | pub mod utils; 9 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | # requires nightly rustfmt until the options are stabilized 2 | format_code_in_doc_comments = true 3 | imports_granularity = "Crate" 4 | imports_layout = "HorizontalVertical" 5 | -------------------------------------------------------------------------------- /bors.toml: -------------------------------------------------------------------------------- 1 | status = [ "ci-ok"] 2 | timeout_sec = 18000 # five hours 3 | required_approvals = 1 4 | use_squash_merge = true 5 | delete_merged_branches = true 6 | update_base_for_deletes = true 7 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/main/res/mipmap-mdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xaynetwork/xayn_ai/HEAD/bindings/dart/example/android/app/src/main/res/mipmap-mdpi/ic_launcher.png -------------------------------------------------------------------------------- /bindings/dart/build.yaml: -------------------------------------------------------------------------------- 1 | targets: 2 | $default: 3 | builders: 4 | json_serializable: 5 | options: 6 | field_rename: snake 7 | explicit_to_json: true 8 | any_map: true 9 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Assets.xcassets/LaunchImage.imageset/LaunchImage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xaynetwork/xayn_ai/HEAD/bindings/dart/example/ios/Runner/Assets.xcassets/LaunchImage.imageset/LaunchImage.png 
-------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-20x20@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xaynetwork/xayn_ai/HEAD/bindings/dart/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-20x20@2x.png -------------------------------------------------------------------------------- /bindings/dart/example/lib/debug/print.dart: -------------------------------------------------------------------------------- 1 | import 'package:flutter/material.dart' show debugPrint; 2 | 3 | /// `debugPrint` for long text. 4 | void debugPrintLongText(String text) { 5 | debugPrint(text, wrapWidth: 80); 6 | } 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | /out/* 3 | /.vscode/ 4 | __pycache__ 5 | 6 | # Data 7 | /data/* 8 | !/data/asset_templates 9 | !/data/bundler_config 10 | !/data/snippets_for_example_data_generation.json 11 | 12 | # MacOS files 13 | .DS_Store 14 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /bindings/dart/example/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | xayn_search_web 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /rubert-tokenizer/README.md: -------------------------------------------------------------------------------- 1 | This crate is a fork of: 
https://github.com/huggingface/tokenizers. 2 | 3 | We removed all the cli parts, the multithreaded features and every non-BERT-related part to make it more compact and 4 | to allow it to work on wasm. 5 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/main/kotlin/com/example/xayn_ai_ffi_dart_example/MainActivity.kt: -------------------------------------------------------------------------------- 1 | package com.xayn.xayn_ai_ffi_dart_example 2 | 3 | import io.flutter.embedding.android.FlutterActivity 4 | 5 | class MainActivity: FlutterActivity() { 6 | } 7 | -------------------------------------------------------------------------------- /bindings/dart/android/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | zipStoreBase=GRADLE_USER_HOME 4 | zipStorePath=wrapper/dists 5 | distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip 6 | -------------------------------------------------------------------------------- /bindings/dart/ffigen_mobile.yaml: -------------------------------------------------------------------------------- 1 | name: 'XaynAiFfi' 2 | description: 'Bindings to the xayn-ai-ffi-c library.' 3 | output: 'lib/src/mobile/ffi/genesis.dart' 4 | headers: 5 | entry-points: 6 | - 'ios/Classes/XaynAiFfiDart.h' 7 | include-directives: 8 | - 'ios/Classes/XaynAiFfiDart.h' 9 | -------------------------------------------------------------------------------- /bindings/dart/ffigen_common.yaml: -------------------------------------------------------------------------------- 1 | name: 'XaynAiFfi' 2 | description: 'Bindings to the xayn-ai-ffi library.' 
3 | output: 'lib/src/common/ffi/genesis.dart' 4 | headers: 5 | entry-points: 6 | - 'ios/Classes/XaynAiFfiCommon.h' 7 | include-directives: 8 | - 'ios/Classes/XaynAiFfiCommon.h' 9 | -------------------------------------------------------------------------------- /xayn-ai-ffi/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Common FFI components for the Xayn AI. 2 | #![cfg_attr( 3 | doc, 4 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 5 | )] 6 | #![forbid(unsafe_op_in_unsafe_fn)] 7 | 8 | mod error; 9 | 10 | pub use crate::error::{CCode, Error}; 11 | -------------------------------------------------------------------------------- /bindings/dart/example/android/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Fri Jun 23 08:50:38 CEST 2017 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip 7 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! C FFI for the Xayn AI. 
2 | #![cfg_attr( 3 | doc, 4 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 5 | )] 6 | #![forbid(unsafe_op_in_unsafe_fn)] 7 | 8 | pub mod data; 9 | pub mod reranker; 10 | pub mod result; 11 | mod slice; 12 | pub mod utils; 13 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Assets.xcassets/LaunchImage.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "filename" : "LaunchImage.png", 6 | "scale" : "1x" 7 | } 8 | ], 9 | "info" : { 10 | "version" : 1, 11 | "author" : "xcode" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | PreviewsEnabled 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /bindings/dart/example/android/.gitignore: -------------------------------------------------------------------------------- 1 | gradle-wrapper.jar 2 | /.gradle 3 | /captures/ 4 | /gradlew 5 | /gradlew.bat 6 | /local.properties 7 | GeneratedPluginRegistrant.java 8 | 9 | # Remember to never publicly share your keystore. 
10 | # See https://flutter.dev/docs/deployment/android#reference-the-keystore-from-the-app 11 | key.properties 12 | -------------------------------------------------------------------------------- /test-utils/src/example/validate.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::{resolve_path, DATA_DIR}; 4 | 5 | const ASSET: &str = "ted_talk_transcripts.csv"; 6 | 7 | /// Resolves the path to the MBert validation transcripts. 8 | pub fn transcripts() -> Result { 9 | resolve_path(&[DATA_DIR, ASSET]) 10 | } 11 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | PreviewsEnabled 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "size" : "20x20", 5 | "idiom" : "iphone", 6 | "filename" : "Icon-App-20x20@2x.png", 7 | "scale" : "2x" 8 | } 9 | ], 10 | "info" : { 11 | "version" : 1, 12 | "author" : "xcode" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/src/data/mod.rs: -------------------------------------------------------------------------------- 1 | //! I/O types for reranking. 
2 | 3 | pub(crate) mod document; 4 | pub(crate) mod history; 5 | pub(crate) mod outcomes; 6 | 7 | #[cfg(doc)] 8 | pub use self::{ 9 | document::{CDocument, CDocuments}, 10 | history::{CHistories, CHistory}, 11 | outcomes::{reranking_outcomes_drop, CRerankingOutcomes}, 12 | }; 13 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/cbindgen.toml: -------------------------------------------------------------------------------- 1 | # cbindgen config options: https://github.com/eqrion/cbindgen/blob/master/docs.md#cbindgentoml 2 | 3 | language = "C" 4 | autogen_warning = "/* Warning, this file is autogenerated by cbindgen. Don't modify this manually. */" 5 | include_version = true 6 | sys_includes = ["stdint.h"] 7 | includes = ["XaynAiFfiCommon.h"] 8 | no_includes = true 9 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "dev-tool", 4 | "kpe", 5 | "layer", 6 | "rubert", 7 | "rubert-tokenizer", 8 | "test-utils", 9 | "xayn-ai", 10 | "xayn-ai-ffi", 11 | "xayn-ai-ffi-c", 12 | "xayn-ai-ffi-wasm", 13 | ] 14 | resolver = "2" 15 | 16 | [workspace.metadata] 17 | # minimum supported rust version 18 | msrv = "1.55.0" 19 | -------------------------------------------------------------------------------- /bindings/dart/example/.metadata: -------------------------------------------------------------------------------- 1 | # This file tracks properties of this Flutter project. 2 | # Used by Flutter tool to assess capabilities and perform upgrades etc. 
3 | # 4 | # This file should be version controlled and should not be manually edited. 5 | 6 | version: 7 | revision: 0941968447ea8058e56e1479f7e53147149b739e 8 | channel: beta 9 | 10 | project_type: app 11 | -------------------------------------------------------------------------------- /test-utils/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "test-utils" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | publish = false 7 | 8 | [dependencies] 9 | float-cmp = "0.9.0" 10 | # to be kept in sync with rubert 11 | ndarray = "=0.15.3" 12 | serde = { version = "1.0.136", features = ["derive"] } 13 | serde_json = "1.0.79" 14 | -------------------------------------------------------------------------------- /test-utils/src/ltr.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::resolve_asset; 4 | 5 | /// Resolves the path to the LTR model. 6 | pub fn model() -> Result { 7 | resolve_asset("ltrModel") 8 | } 9 | 10 | #[cfg(test)] 11 | mod tests { 12 | use super::*; 13 | 14 | #[test] 15 | fn test_model() { 16 | assert!(model().is_ok()); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /bindings/dart/example/lib/data_provider/data_provider.dart: -------------------------------------------------------------------------------- 1 | import 'package:xayn_ai_ffi_dart/package.dart' show SetupData; 2 | 3 | /// Prepares and returns the data that is needed to init [`XaynAi`]. 
4 | Future getInputData() async => 5 | throw UnsupportedError('Unsupported platform.'); 6 | 7 | String joinPaths(List paths) { 8 | return paths.where((e) => e.isNotEmpty).join('/'); 9 | } 10 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/debug/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/profile/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /rubert-tokenizer/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rubert-tokenizer" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | license = "Apache-2.0" 7 | 8 | [dependencies] 9 | displaydoc = "0.2.3" 10 | num-traits = "0.2.14" 11 | regex = "1.5.4" 12 | regex-syntax = "0.6.25" 13 | smallstr = "0.3.0" 14 | thiserror = "1.0.30" 15 | unicode-normalization-alignments = "0.1.12" 16 | unicode_categories = "0.1.1" 17 | -------------------------------------------------------------------------------- /xayn-ai-ffi/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xayn-ai-ffi" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | derive_more = { version = "0.99.17", default-features = false, features = ["display"] } 9 | serde = { version = "1.0.136", features = ["derive"] } 10 | serde_repr = "0.1.7" 11 | xayn-ai = { path = "../xayn-ai" } 12 | 13 | [build-dependencies] 14 | cbindgen = "=0.20.0" 15 | -------------------------------------------------------------------------------- 
/test-utils/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The single source of truth for all data paths and other test utilities. 2 | #![forbid(unsafe_op_in_unsafe_fn)] 3 | 4 | mod approx_eq; 5 | mod asset; 6 | pub mod bench; 7 | pub mod example; 8 | pub mod kpe; 9 | pub mod ltr; 10 | pub mod qambert; 11 | pub mod smbert; 12 | pub mod test; 13 | 14 | pub use crate::approx_eq::ApproxEqIter; 15 | #[doc(hidden)] // required for standalone export of assert_approx_eq! 16 | pub use float_cmp::approx_eq; 17 | -------------------------------------------------------------------------------- /layer/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "layer" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | bincode = "1.3.3" 9 | displaydoc = "0.2.3" 10 | # to be kept in sync with tract-core 11 | ndarray = "=0.15.3" 12 | rand = "0.8.5" 13 | rand_distr = "0.4.3" 14 | serde = { version = "1.0.136", features = ["derive"] } 15 | thiserror = "1.0.30" 16 | 17 | [dev-dependencies] 18 | test-utils = { path = "../test-utils" } 19 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/AppDelegate.swift: -------------------------------------------------------------------------------- 1 | import UIKit 2 | import Flutter 3 | 4 | @UIApplicationMain 5 | @objc class AppDelegate: FlutterAppDelegate { 6 | override func application( 7 | _ application: UIApplication, 8 | didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? 
9 | ) -> Bool { 10 | GeneratedPluginRegistrant.register(with: self) 11 | return super.application(application, didFinishLaunchingWithOptions: launchOptions) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/src/reranker/mod.rs: -------------------------------------------------------------------------------- 1 | //! The AI and its I/O types. 2 | 3 | pub(crate) mod ai; 4 | pub(crate) mod analytics; 5 | pub(crate) mod bytes; 6 | 7 | #[cfg(doc)] 8 | pub use self::{ 9 | ai::{ 10 | xaynai_analytics, 11 | xaynai_drop, 12 | xaynai_faults, 13 | xaynai_new, 14 | xaynai_rerank, 15 | xaynai_serialize, 16 | CXaynAi, 17 | }, 18 | analytics::{analytics_drop, CAnalytics}, 19 | bytes::{bytes_drop, bytes_new, CBytes}, 20 | }; 21 | -------------------------------------------------------------------------------- /data/bundler_config/webpack.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | mode: "production", 3 | target: "webworker", 4 | output: { 5 | filename: "genesis.js", 6 | library: { 7 | name: "xayn_ai_ffi_wasm", 8 | type: "self", 9 | }, 10 | clean: true, 11 | }, 12 | module: { 13 | rules: [ 14 | { 15 | test: /\.wasm/, 16 | type: "asset/resource", 17 | generator: { 18 | filename: "[name][ext]", 19 | }, 20 | }, 21 | ], 22 | }, 23 | }; 24 | -------------------------------------------------------------------------------- /kpe/examples/kpe.rs: -------------------------------------------------------------------------------- 1 | use kpe::{Config, Pipeline}; 2 | 3 | use test_utils::kpe::*; 4 | 5 | fn main() -> Result<(), Box> { 6 | let config = 7 | Config::from_files(vocab()?, bert()?, cnn()?, classifier()?)?.with_token_size(128)?; 8 | 9 | let kpe = Pipeline::from(config)?; 10 | 11 | let key_phrases = kpe.run("This sequence will be split into key phrases.")?; 12 | println!("{:?}", key_phrases); 13 | assert_eq!(key_phrases.len(), 30); 14 | 15 | Ok(()) 16 
| } 17 | -------------------------------------------------------------------------------- /test-utils/src/bench/matmul.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::{resolve_path, DATA_DIR}; 4 | 5 | const ASSET: &str = "bench_matmul_v0000"; 6 | 7 | /// Resolves the path to the matrix multiplication benchmark data. 8 | pub fn data_dir() -> Result { 9 | resolve_path(&[DATA_DIR, ASSET]) 10 | } 11 | 12 | #[cfg(test)] 13 | mod tests { 14 | use super::*; 15 | 16 | #[test] 17 | fn test_data_dir() { 18 | assert!(data_dir().is_ok()); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /xayn-ai-ffi/cbindgen.toml: -------------------------------------------------------------------------------- 1 | # cbindgen config options: https://github.com/eqrion/cbindgen/blob/master/docs.md#cbindgentoml 2 | 3 | language = "C" 4 | autogen_warning = "/* Warning, this file is autogenerated by cbindgen. Don't modify this manually. 
*/" 5 | include_version = true 6 | sys_includes = ["stdint.h"] 7 | no_includes = true 8 | 9 | [parse] 10 | parse_deps = true 11 | include = ["xayn-ai"] 12 | 13 | [export] 14 | include = ["CCode", "DayOfWeek", "Relevance", "RerankMode", "UserAction", "UserFeedback",] 15 | -------------------------------------------------------------------------------- /xayn-ai/src/ranker/mod.rs: -------------------------------------------------------------------------------- 1 | mod context; 2 | mod document; 3 | mod public; 4 | mod system; 5 | 6 | pub use self::{ 7 | document::Document, 8 | public::{Builder, Ranker}, 9 | }; 10 | pub use crate::{ 11 | coi::{ 12 | config::{Config as CoiSystemConfig, Error as CoiSystemConfigError}, 13 | key_phrase::KeyPhrase, 14 | }, 15 | embedding::utils::{cosine_similarity, pairwise_cosine_similarity, ArcEmbedding, Embedding}, 16 | DocumentId, 17 | }; 18 | pub use rubert::AveragePooler; 19 | -------------------------------------------------------------------------------- /xayn-ai/src/ltr/list_net/mod.rs: -------------------------------------------------------------------------------- 1 | //! ListNet implementation using the NdArray crate. 2 | 3 | mod data; 4 | mod model; 5 | mod optimizer; 6 | #[cfg(test)] 7 | mod tests; 8 | mod trainer; 9 | 10 | pub use self::{ 11 | data::{ 12 | prepare_inputs, 13 | prepare_target_prob_dist, 14 | DataSource, 15 | GradientSet, 16 | SampleOwned, 17 | SampleView, 18 | }, 19 | model::ListNet, 20 | optimizer::MiniBatchSgd, 21 | trainer::{ListNetTrainer, TrainingController}, 22 | }; 23 | -------------------------------------------------------------------------------- /bindings/dart/.gitignore: -------------------------------------------------------------------------------- 1 | .dart_tool/ 2 | build/ 3 | .packages 4 | pubspec.lock 5 | 6 | # When we push to the release repository 7 | # we need to upload also the following files. 
8 | # DELETE_AFTER_THIS_IN_RELEASE 9 | android/src/main/jniLibs/ 10 | ios/Classes/XaynAiFfiCommon.h 11 | ios/Classes/XaynAiFfiDart.h 12 | ios/libxayn_ai_ffi_c*.a 13 | lib/src/common/ffi/genesis.dart 14 | lib/src/common/reranker/assets.dart 15 | lib/src/web/reranker/assets.dart 16 | lib/src/mobile/ffi/genesis.dart 17 | # Ignore generated dart files. 18 | *.g.dart 19 | -------------------------------------------------------------------------------- /bindings/dart/example/android/settings.gradle: -------------------------------------------------------------------------------- 1 | include ':app' 2 | 3 | def localPropertiesFile = new File(rootProject.projectDir, "local.properties") 4 | def properties = new Properties() 5 | 6 | assert localPropertiesFile.exists() 7 | localPropertiesFile.withReader("UTF-8") { reader -> properties.load(reader) } 8 | 9 | def flutterSdkPath = properties.getProperty("flutter.sdk") 10 | assert flutterSdkPath != null, "flutter.sdk not set in local.properties" 11 | apply from: "$flutterSdkPath/packages/flutter_tools/gradle/app_plugin_loader.gradle" 12 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/main/res/drawable/launch_background.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 12 | 13 | -------------------------------------------------------------------------------- /xayn-ai-ffi-wasm/src/error.rs: -------------------------------------------------------------------------------- 1 | use wasm_bindgen::JsValue; 2 | use xayn_ai_ffi::Error; 3 | 4 | /// An interface to convert results into JS compatible results. 5 | pub trait IntoJsResult { 6 | /// Converts an error into a JS compatible value, while any value stays untouched. 
7 | fn into_js_result(self) -> Result; 8 | } 9 | 10 | impl IntoJsResult for Result { 11 | fn into_js_result(self) -> Result { 12 | self.map_err(|e| JsValue::from_serde(&e).expect("Failed to serialize the error")) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /bindings/dart/ios/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vagrant/ 3 | .sconsign.dblite 4 | .svn/ 5 | 6 | .DS_Store 7 | *.swp 8 | profile 9 | 10 | DerivedData/ 11 | build/ 12 | GeneratedPluginRegistrant.h 13 | GeneratedPluginRegistrant.m 14 | 15 | .generated/ 16 | 17 | *.pbxuser 18 | *.mode1v3 19 | *.mode2v3 20 | *.perspectivev3 21 | 22 | !default.pbxuser 23 | !default.mode1v3 24 | !default.mode2v3 25 | !default.perspectivev3 26 | 27 | xcuserdata 28 | 29 | *.moved-aside 30 | 31 | *.pyc 32 | *sync/ 33 | Icon? 34 | .tags* 35 | 36 | /Flutter/Generated.xcconfig 37 | /Flutter/flutter_export_environment.sh -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/utils.dart: -------------------------------------------------------------------------------- 1 | abstract class ToJson { 2 | /// Serializes a dart object into a JSON object. 3 | Map toJson(); 4 | } 5 | 6 | /// Throws an assertion error in debug mode if the left-hand side is not equal to the right-hand 7 | /// side. 8 | void assertEq(T lhs, T rhs) { 9 | assert(lhs == rhs, 'equality assertion failed: $lhs != $rhs'); 10 | } 11 | 12 | /// Throws an assertion error in debug mode if the left-hand side is equal to the right-hand side. 13 | void assertNeq(T lhs, T rhs) { 14 | assert(lhs != rhs, 'inequality assertion failed: $lhs == $rhs'); 15 | } 16 | -------------------------------------------------------------------------------- /kpe/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The KPE pipeline extracts key phrases from a sequence. 
2 | //! 3 | //! See `examples/` for a usage example. 4 | #![cfg_attr( 5 | doc, 6 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 7 | )] 8 | #![forbid(unsafe_op_in_unsafe_fn)] 9 | 10 | mod config; 11 | mod model; 12 | mod pipeline; 13 | mod tokenizer; 14 | 15 | pub use crate::{ 16 | config::{Config, ConfigError}, 17 | pipeline::{Pipeline, PipelineError}, 18 | tokenizer::key_phrase::RankedKeyPhrases, 19 | }; 20 | 21 | #[cfg(doc)] 22 | pub use crate::{model::ModelError, tokenizer::TokenizerError}; 23 | -------------------------------------------------------------------------------- /.ci/generate-flutter-ffi/action.yml: -------------------------------------------------------------------------------- 1 | name: 'generate flutter ffi' 2 | description: 'Generates flutter ffi' 3 | inputs: 4 | dart-ws: 5 | description: 'The Dart workspace' 6 | required: true 7 | runs: 8 | using: "composite" 9 | steps: 10 | - shell: bash 11 | working-directory: ${{ inputs.dart-ws }} 12 | run: | 13 | flutter pub run ffigen --config ffigen_common.yaml 14 | flutter pub run ffigen --config ffigen_mobile.yaml 15 | grep --fixed-strings --invert-match "import 'dart:ffi' as ffi;" lib/src/common/ffi/genesis.dart > genesis && mv genesis lib/src/common/ffi/genesis.dart 16 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/result/fault.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library fault; 3 | 4 | import 'package:js/js.dart' show anonymous, JS; 5 | 6 | @JS() 7 | @anonymous 8 | class JsFault { 9 | external String get message; 10 | 11 | external factory JsFault({ 12 | // ignore: unused_element 13 | int code, 14 | // ignore: unused_element 15 | String message, 16 | }); 17 | } 18 | 19 | extension ToStrings on List { 20 | /// Gets the messages of the faults. 
21 | List toStrings() => List.generate( 22 | length, 23 | (i) => this[i].message, 24 | growable: false, 25 | ); 26 | } 27 | -------------------------------------------------------------------------------- /test-utils/src/smbert.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::resolve_asset; 4 | 5 | /// Resolves the path to the SMBert vocabulary. 6 | pub fn vocab() -> Result { 7 | resolve_asset("smbertVocab") 8 | } 9 | 10 | /// Resolves the path to the SMBert model. 11 | pub fn model() -> Result { 12 | resolve_asset("smbertModel") 13 | } 14 | 15 | #[cfg(test)] 16 | mod tests { 17 | use super::*; 18 | 19 | #[test] 20 | fn test_vocab() { 21 | assert!(vocab().is_ok()); 22 | } 23 | 24 | #[test] 25 | fn test_model() { 26 | assert!(model().is_ok()); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /assets_manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_assets": [ 3 | { 4 | "id": "smbertVocab", 5 | "url_suffix": "smbert_v0001/vocab.txt" 6 | }, 7 | { 8 | "id": "smbertModel", 9 | "url_suffix": "smbert_v0001/smbert-quantized.onnx", 10 | "chunk_size": "11MB" 11 | }, 12 | { 13 | "id": "qambertVocab", 14 | "url_suffix": "qambert_v0001/vocab.txt" 15 | }, 16 | { 17 | "id": "qambertModel", 18 | "url_suffix": "qambert_v0001/qambert.onnx", 19 | "chunk_size": "11MB" 20 | }, 21 | { 22 | "id": "ltrModel", 23 | "url_suffix": "ltr_v0000/ltr.binparams" 24 | } 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "cargo" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | day: "monday" 8 | time: "09:00" 9 | timezone: "Europe/Berlin" 10 | 11 | - package-ecosystem: 
"github-actions" 12 | directory: "/" 13 | schedule: 14 | interval: "weekly" 15 | day: "monday" 16 | time: "09:00" 17 | timezone: "Europe/Berlin" 18 | 19 | - package-ecosystem: "npm" 20 | directory: "/data/bundler_config" 21 | schedule: 22 | interval: "weekly" 23 | day: "monday" 24 | time: "09:00" 25 | timezone: "Europe/Berlin" 26 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/reranker/utils.dart: -------------------------------------------------------------------------------- 1 | import 'dart:math' show max, min; 2 | 3 | /// Maximum number of threads to be used for multithreaded features. 4 | const int maxNumberOfThreads = 16; 5 | 6 | /// Selects the number of threads used by the [`XaynAi`] thread pool. 7 | /// 8 | /// On a single core system the thread pool consists of only one thread. 9 | /// On a multicore system the thread pool consists of 10 | /// (the number of logical cores - 1) threads, but at most [`maxNumberOfThreads`] 11 | /// threads and at least one thread. 
12 | int selectThreadPoolSize(int numberOfProcessors) => 13 | min(max(numberOfProcessors - 1, 1), maxNumberOfThreads); 14 | -------------------------------------------------------------------------------- /bindings/dart/analysis_options.yaml: -------------------------------------------------------------------------------- 1 | include: package:pedantic/analysis_options.1.11.0.yaml 2 | 3 | analyzer: 4 | strong-mode: 5 | implicit-casts: false 6 | implicit-dynamic: false 7 | exclude: 8 | - lib/src/common/ffi/genesis.dart 9 | - lib/src/mobile/ffi/genesis.dart 10 | - '**/*.g.dart' 11 | 12 | linter: 13 | rules: 14 | - camel_case_types 15 | - camel_case_extensions 16 | - library_names 17 | - file_names 18 | - library_prefixes 19 | - non_constant_identifier_names 20 | - constant_identifier_names 21 | - directives_ordering 22 | - curly_braces_in_flow_control_structures 23 | - prefer_single_quotes 24 | -------------------------------------------------------------------------------- /bindings/dart/ios/Classes/XaynAiFfiDartPlugin.m: -------------------------------------------------------------------------------- 1 | #import "XaynAiFfiDartPlugin.h" 2 | #if __has_include() 3 | #import 4 | #else 5 | // Support project import fallback if the generated compatibility header 6 | // is not copied when this plugin is created as a library. 
7 | // https://forums.swift.org/t/swift-static-libraries-dont-copy-generated-objective-c-header/19816 8 | #import "xayn_ai_ffi_dart-Swift.h" 9 | #endif 10 | 11 | @implementation XaynAiFfiDartPlugin 12 | + (void)registerWithRegistrar:(NSObject*)registrar { 13 | [SwiftXaynAiFfiDartPlugin registerWithRegistrar:registrar]; 14 | } 15 | @end 16 | -------------------------------------------------------------------------------- /kpe/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kpe" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | derive_more = { version = "0.99.17", default-features = false, features = ["deref", "from"] } 9 | displaydoc = "0.2.3" 10 | layer = { path = "../layer" } 11 | # to be kept in sync with tract-core 12 | ndarray = "=0.15.3" 13 | rubert-tokenizer = { path = "../rubert-tokenizer" } 14 | thiserror = "1.0.29" 15 | tract-onnx = "0.16.1" 16 | 17 | [dev-dependencies] 18 | criterion = { version = "0.3.5", features = ["html_reports"] } 19 | test-utils = { path = "../test-utils" } 20 | 21 | [[bench]] 22 | name = "kpe" 23 | harness = false 24 | bench = false 25 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xayn-ai-ffi-c" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | uuid = "0.8.2" 9 | xayn-ai = { path = "../xayn-ai" } 10 | xayn-ai-ffi = { path = "../xayn-ai-ffi" } 11 | 12 | # multithreaded feature 13 | rayon = { version = "1.5.1", optional = true } 14 | 15 | [dev-dependencies] 16 | itertools = "0.10.3" 17 | tempfile = "3.3.0" 18 | test-utils = { path = "../test-utils" } 19 | 20 | [build-dependencies] 21 | cbindgen = "=0.20.0" 22 | 23 | [lib] 24 | crate-type = ["cdylib", "staticlib"] 25 | 26 | [features] 27 
| default = ["multithreaded"] 28 | multithreaded = ["rayon", "xayn-ai/multithreaded"] 29 | -------------------------------------------------------------------------------- /bindings/dart/example/android/build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | ext.kotlin_version = '1.3.50' 3 | repositories { 4 | google() 5 | jcenter() 6 | } 7 | 8 | dependencies { 9 | classpath 'com.android.tools.build:gradle:4.1.0' 10 | classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" 11 | } 12 | } 13 | 14 | allprojects { 15 | repositories { 16 | google() 17 | jcenter() 18 | } 19 | } 20 | 21 | rootProject.buildDir = '../build' 22 | subprojects { 23 | project.buildDir = "${rootProject.buildDir}/${project.name}" 24 | project.evaluationDependsOn(':app') 25 | } 26 | 27 | task clean(type: Delete) { 28 | delete rootProject.buildDir 29 | } 30 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/.gitignore: -------------------------------------------------------------------------------- 1 | *.mode1v3 2 | *.mode2v3 3 | *.moved-aside 4 | *.pbxuser 5 | *.perspectivev3 6 | **/*sync/ 7 | .sconsign.dblite 8 | .tags* 9 | **/.vagrant/ 10 | **/DerivedData/ 11 | Icon? 12 | **/Pods/ 13 | **/.symlinks/ 14 | profile 15 | xcuserdata 16 | **/.generated/ 17 | Flutter/App.framework 18 | Flutter/Flutter.framework 19 | Flutter/Flutter.podspec 20 | Flutter/Generated.xcconfig 21 | Flutter/ephemeral/ 22 | Flutter/app.flx 23 | Flutter/app.zip 24 | Flutter/flutter_assets/ 25 | Flutter/flutter_export_environment.sh 26 | ServiceDefinitions.json 27 | Runner/GeneratedPluginRegistrant.* 28 | 29 | # Exceptions to above rules. 
30 | !default.mode1v3 31 | !default.mode2v3 32 | !default.pbxuser 33 | !default.perspectivev3 34 | -------------------------------------------------------------------------------- /dev-tool/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "dev-tool" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | anyhow = "1.0.56" 9 | base64 = "0.13.0" 10 | bincode = "1.3.3" 11 | csv = "1.1.6" 12 | displaydoc = "0.2.3" 13 | env_logger = "0.9.0" 14 | indicatif = "0.16.2" 15 | itertools = "0.10.3" 16 | layer = { path = "../layer" } 17 | log = "0.4.14" 18 | ndarray = "0.15.3" 19 | rand = "0.8.5" 20 | rayon = "1.5.1" 21 | serde = { version = "1.0.136", features = ["derive"] } 22 | serde_json = "1.0.79" 23 | structopt = "0.3.26" 24 | thiserror = "1.0.30" 25 | uuid = "0.8.2" 26 | xayn-ai = { path = "../xayn-ai" } 27 | 28 | [dev-dependencies] 29 | test-utils = { path = "../test-utils" } 30 | -------------------------------------------------------------------------------- /generate_assets_metadata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Generates the metadata of the assets (ai and wasm). If an asset of the `data_assets` array 4 | # contains a `chunk_size` key, the script splits the asset into chunks where each chunk has 5 | # a maximum size of `chunk_size`. The format of the `chunk_size` value is equivalent to the 6 | # `SIZE` argument in `split` or `gsplit` on macOS. See `split`/`gsplit` man page for more details. 7 | # 8 | # The script needs to be executed in the root of the repository. 
9 | # 10 | # Usage: 11 | # ./generate_assets_metadata [] 12 | set -e 13 | 14 | source $(dirname "$0")/scripts/assets_generation_lib.sh 15 | 16 | gen_data_and_wasm_assets_metadata "$1" "$2" 17 | generate_dart_assets_manifest 18 | -------------------------------------------------------------------------------- /dev-tool/src/list_net.rs: -------------------------------------------------------------------------------- 1 | #![cfg(not(tarpaulin))] 2 | use anyhow::Error; 3 | use structopt::StructOpt; 4 | 5 | use self::{convert::ConvertCmd, evaluate::EvaluateCmd, train::TrainCmd}; 6 | 7 | mod cli_callbacks; 8 | mod convert; 9 | mod data_source; 10 | mod evaluate; 11 | mod train; 12 | 13 | /// Commands related to training ListNet (train, convert, evaluate). 14 | #[derive(StructOpt, Debug)] 15 | pub enum ListNetCmd { 16 | Convert(ConvertCmd), 17 | Train(TrainCmd), 18 | Evaluate(EvaluateCmd), 19 | } 20 | 21 | impl ListNetCmd { 22 | pub fn run(self) -> Result { 23 | use ListNetCmd::*; 24 | match self { 25 | Convert(cmd) => cmd.run(), 26 | Train(cmd) => cmd.run(), 27 | Evaluate(cmd) => cmd.run(), 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /data/asset_templates/base_assets.dart.tmpl: -------------------------------------------------------------------------------- 1 | /// Warning, this file is autogenerated. Don't modify this manually. 2 | 3 | part of 'data_provider.dart'; 4 | 5 | /// Base assets that are required for both mobile and web. 6 | /// 7 | /// The checksum is the sha256 hash of the asset. 8 | /// To calculate the checksum run 'shasum -a 256 vocab.txt'. 
9 | final baseAssets = { 10 | {{- range (ds "assets_manifest").assets }} 11 | AssetType.{{.id}}: Asset('{{.url_suffix}}', Checksum('{{.checksum}}'), [ 12 | {{ range .fragments }} 13 | Fragment('{{.url_suffix}}', Checksum('{{.checksum}}')), 14 | {{ end }} 15 | ]), 16 | {{- end}} 17 | }; 18 | 19 | enum AssetType { 20 | {{- range (ds "assets_manifest").assets }} 21 | {{.id}}, 22 | {{- end}} 23 | wasmModule, 24 | wasmScript, 25 | webWorkerScript, 26 | } 27 | -------------------------------------------------------------------------------- /test-utils/src/qambert.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::resolve_asset; 4 | 5 | /// Resolves the path to the QAMBert vocabulary. 6 | pub fn vocab() -> Result { 7 | resolve_asset("qambertVocab") 8 | } 9 | 10 | /// Resolves the path to the QAMBert model. 11 | pub fn model() -> Result { 12 | resolve_asset("qambertModel") 13 | } 14 | 15 | /// Resolves the path to the quantized QAMBert model. 16 | pub fn model_quant() -> Result { 17 | Ok(resolve_asset("qambertModel")?.with_file_name("qambert-quant.onnx")) 18 | } 19 | 20 | #[cfg(test)] 21 | mod tests { 22 | use super::*; 23 | 24 | #[test] 25 | fn test_vocab() { 26 | assert!(vocab().is_ok()); 27 | } 28 | 29 | #[test] 30 | fn test_model() { 31 | assert!(model().is_ok()); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/ffi/library.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show DynamicLibrary; 2 | import 'dart:io' show Platform; 3 | 4 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' show XaynAiFfi; 5 | 6 | /// Opens the platform dependent Rust library. 
7 | DynamicLibrary _open() { 8 | if (Platform.isAndroid) { 9 | return DynamicLibrary.open('libxayn_ai_ffi_c.so'); 10 | } 11 | if (Platform.isIOS) { 12 | return DynamicLibrary.process(); 13 | } 14 | if (Platform.isLinux) { 15 | return DynamicLibrary.open('../../target/debug/libxayn_ai_ffi_c.so'); 16 | } 17 | if (Platform.isMacOS) { 18 | return DynamicLibrary.open('../../target/debug/libxayn_ai_ffi_c.dylib'); 19 | } 20 | throw UnsupportedError('Unsupported platform.'); 21 | } 22 | 23 | /// The handle to the C-FFI of the Rust library. 24 | final ffi = XaynAiFfi(_open()); 25 | -------------------------------------------------------------------------------- /.ci/copy-headers/action.yml: -------------------------------------------------------------------------------- 1 | name: 'copy ffi header files' 2 | description: 'Copies ffi header files' 3 | inputs: 4 | working-directory: 5 | description: 'The working directory' 6 | required: true 7 | dart-ws: 8 | description: 'The Dart workspace' 9 | required: true 10 | runs: 11 | using: "composite" 12 | steps: 13 | - shell: bash 14 | working-directory: ${{ inputs.working-directory }} 15 | run: | 16 | cp ios/Classes/XaynAiFfiCommon.h ${{ inputs.dart-ws }}/ios/Classes 17 | cp ios/Classes/XaynAiFfiDart.h ${{ inputs.dart-ws }}/ios/Classes 18 | cp lib/src/common/ffi/genesis.dart ${{ inputs.dart-ws }}/lib/src/common/ffi 19 | cp lib/src/mobile/ffi/genesis.dart ${{ inputs.dart-ws }}/lib/src/mobile/ffi 20 | find lib/ -type f -regex ".*\.g\.dart" -exec cp --parents '{}' ${{ inputs.dart-ws }}/ \; 21 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Podfile.lock: -------------------------------------------------------------------------------- 1 | PODS: 2 | - Flutter (1.0.0) 3 | - path_provider (0.0.1): 4 | - Flutter 5 | - xayn_ai_ffi_dart (0.0.1): 6 | - Flutter 7 | 8 | DEPENDENCIES: 9 | - Flutter (from `Flutter`) 10 | - path_provider (from `.symlinks/plugins/path_provider/ios`) 11 | 
- xayn_ai_ffi_dart (from `.symlinks/plugins/xayn_ai_ffi_dart/ios`) 12 | 13 | EXTERNAL SOURCES: 14 | Flutter: 15 | :path: Flutter 16 | path_provider: 17 | :path: ".symlinks/plugins/path_provider/ios" 18 | xayn_ai_ffi_dart: 19 | :path: ".symlinks/plugins/xayn_ai_ffi_dart/ios" 20 | 21 | SPEC CHECKSUMS: 22 | Flutter: 50d75fe2f02b26cc09d224853bb45737f8b3214a 23 | path_provider: abfe2b5c733d04e238b0d8691db0cfd63a27a93c 24 | xayn_ai_ffi_dart: 8b41d449eea984489dad516f22ae1423f3795aa1 25 | 26 | PODFILE CHECKSUM: 35bffc6892c3d164e5534753ef0ec939638cfaf8 27 | 28 | COCOAPODS: 1.11.2 29 | -------------------------------------------------------------------------------- /xayn-ai/src/tests/mem_db.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | 3 | use crate::{ 4 | reranker::database::{Database, RerankerData}, 5 | Error, 6 | }; 7 | 8 | pub(crate) struct MemDb { 9 | data: RefCell>, 10 | } 11 | 12 | impl MemDb { 13 | pub(crate) fn new() -> Self { 14 | Self { 15 | data: RefCell::new(None), 16 | } 17 | } 18 | 19 | pub(crate) fn from_data(data: RerankerData) -> Self { 20 | Self { 21 | data: RefCell::new(data.into()), 22 | } 23 | } 24 | } 25 | 26 | impl Database for MemDb { 27 | fn serialize(&self, _data: &RerankerData) -> Result, Error> { 28 | unimplemented!("mocked database does not have a serialized representation") 29 | } 30 | 31 | fn load_data(&self) -> Result, Error> { 32 | Ok(self.data.borrow().clone()) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/result/outcomes.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library outcomes; 3 | 4 | import 'package:js/js.dart' show anonymous, JS; 5 | 6 | import 'package:xayn_ai_ffi_dart/src/common/result/outcomes.dart' 7 | show RerankingOutcomes; 8 | 9 | @JS() 10 | @anonymous 11 | class JsRerankingOutcomes { 12 | // ignore: 
non_constant_identifier_names 13 | external List final_ranking; 14 | 15 | // ignore: non_constant_identifier_names 16 | external List? qambert_similarities; 17 | 18 | // ignore: non_constant_identifier_names 19 | external List? context_scores; 20 | } 21 | 22 | extension ToRerankingOutcomes on JsRerankingOutcomes { 23 | /// Creates reranking outcomes from the current JS representation. 24 | RerankingOutcomes toRerankingOutcomes() => RerankingOutcomes.fromParts( 25 | final_ranking, 26 | qambert_similarities, 27 | context_scores, 28 | ); 29 | } 30 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Flutter/AppFrameworkInfo.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CFBundleDevelopmentRegion 6 | en 7 | CFBundleExecutable 8 | App 9 | CFBundleIdentifier 10 | io.flutter.flutter.app 11 | CFBundleInfoDictionaryVersion 12 | 6.0 13 | CFBundleName 14 | App 15 | CFBundlePackageType 16 | FMWK 17 | CFBundleShortVersionString 18 | 1.0 19 | CFBundleSignature 20 | ???? 21 | CFBundleVersion 22 | 1.0 23 | MinimumOSVersion 24 | 12.1 25 | 26 | 27 | -------------------------------------------------------------------------------- /test-utils/src/test/ltr.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::{resolve_path, DATA_DIR}; 4 | 5 | const ASSET: &str = "ltr_test_data_v0001"; 6 | 7 | /// Resolves the path to the LTR feature extraction test data. 8 | pub fn feature_extraction_test_cases() -> Result { 9 | resolve_path(&[DATA_DIR, ASSET, "feature_extraction"]) 10 | } 11 | 12 | /// Resolves the path to the intermediate LTR test model. 
13 | pub fn training_intermediates() -> Result { 14 | resolve_path(&[DATA_DIR, ASSET, "check_training_intermediates.binparams"]) 15 | } 16 | 17 | #[cfg(test)] 18 | mod tests { 19 | use super::*; 20 | 21 | #[test] 22 | fn test_feature_extraction_test_cases() { 23 | assert!(feature_extraction_test_cases().is_ok()); 24 | } 25 | 26 | #[test] 27 | fn test_training_intermediates() { 28 | assert!(training_intermediates().is_ok()); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /.ci/install-wasm-pack/action.yml: -------------------------------------------------------------------------------- 1 | name: 'wasm-pack' 2 | description: 'Installs wasm-pack' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - shell: bash 7 | working-directory: ${{ runner.temp }} 8 | env: 9 | LINUX_URL: https://github.com/rustwasm/wasm-pack/releases/download/v0.10.1/wasm-pack-v0.10.1-x86_64-unknown-linux-musl.tar.gz 10 | # `shasum -a 256 wasm-pack` 11 | LINUX_CHECKSUM: f6eddf40f7fae0676c8cec4bff0b9f2315cf082ae5e24fab869377c2ee3a601c 12 | run: | 13 | if [ ${{ runner.os }} == "Linux" ]; then 14 | URL=${{ env.LINUX_URL }} 15 | CHECKSUM=${{ env.LINUX_CHECKSUM }} 16 | else 17 | echo "::error wasm-pack for ${{ runner.os }} is not supported" 18 | exit 1 19 | fi 20 | 21 | wget -q -O - $URL | tar xvzf - --strip-components 1 22 | echo "$CHECKSUM *wasm-pack" | shasum -c - 23 | mv wasm-pack $HOME/.cargo/bin/ 24 | -------------------------------------------------------------------------------- /bindings/dart/android/build.gradle: -------------------------------------------------------------------------------- 1 | group 'com.xayn.xayn_ai_ffi_dart' 2 | version '1.0-SNAPSHOT' 3 | 4 | buildscript { 5 | ext.kotlin_version = '1.3.50' 6 | repositories { 7 | google() 8 | jcenter() 9 | } 10 | 11 | dependencies { 12 | classpath 'com.android.tools.build:gradle:4.1.0' 13 | classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" 14 | } 15 | } 16 | 17 | 
rootProject.allprojects { 18 | repositories { 19 | google() 20 | jcenter() 21 | } 22 | } 23 | 24 | apply plugin: 'com.android.library' 25 | apply plugin: 'kotlin-android' 26 | 27 | android { 28 | compileSdkVersion 30 29 | 30 | sourceSets { 31 | main.java.srcDirs += 'src/main/kotlin' 32 | } 33 | defaultConfig { 34 | minSdkVersion 21 35 | } 36 | } 37 | 38 | dependencies { 39 | implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version" 40 | } 41 | -------------------------------------------------------------------------------- /bindings/dart/lib/package.dart: -------------------------------------------------------------------------------- 1 | export 'src/common/data/document.dart' show Document; 2 | export 'src/common/data/history.dart' 3 | show UserFeedback, History, Relevance, DayOfWeek, UserAction; 4 | export 'src/common/reranker/ai.dart' 5 | if (dart.library.io) 'src/mobile/reranker/ai.dart' 6 | if (dart.library.js) 'src/web/reranker/ai.dart' show XaynAi; 7 | export 'src/common/reranker/analytics.dart' show Analytics; 8 | export 'src/common/reranker/data_provider.dart' 9 | show Asset, AssetType, Feature, Fragment, WebFeature; 10 | export 'src/common/reranker/data_provider.dart' 11 | if (dart.library.io) 'src/mobile/reranker/data_provider.dart' 12 | if (dart.library.js) 'src/web/reranker/data_provider.dart' 13 | show getAssets, SetupData; 14 | export 'src/common/reranker/debug.dart' show RerankDebugCallData; 15 | export 'src/common/reranker/mode.dart' show RerankMode; 16 | export 'src/common/result/outcomes.dart' show RerankingOutcomes; 17 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/result/slice.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr, Uint16Pointer, FloatPointer; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 4 | show CBoxedSlice_f32, CBoxedSlice_u16; 5 | 6 | extension 
BoxedSliceU16List on CBoxedSlice_u16 { 7 | /// Converts a `CBoxedSlice` to an `Uint16List`. 8 | List? toList() { 9 | if (data == nullptr) { 10 | return null; 11 | } else if (len == 0) { 12 | return List.empty(); 13 | } else { 14 | return data.asTypedList(len).toList(growable: false); 15 | } 16 | } 17 | } 18 | 19 | extension BoxedSliceF32List on CBoxedSlice_f32 { 20 | /// Converts a `CBoxedSlice` to a `Float32List`. 21 | List? toList() { 22 | if (data == nullptr) { 23 | return null; 24 | } else if (len == 0) { 25 | return List.empty(); 26 | } else { 27 | return data.asTypedList(len).toList(growable: false); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /bindings/dart/example/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from http import server 3 | 4 | class MyHTTPRequestHandler(server.SimpleHTTPRequestHandler): 5 | def end_headers(self): 6 | self.send_header("Cross-Origin-Embedder-Policy", "require-corp") 7 | self.send_header("Cross-Origin-Opener-Policy", "same-origin") 8 | 9 | server.SimpleHTTPRequestHandler.end_headers(self) 10 | 11 | def main(data_dir): 12 | import os 13 | os.chdir(data_dir) 14 | print("Now hosting web example at http://localhost:8000") 15 | print("DO NOT USE AN IP ADDRESS TO OPEN THE WEB SITE, IT WILL NOT WORK") 16 | httpd = server.HTTPServer(('localhost', 8000), MyHTTPRequestHandler) 17 | httpd.serve_forever() 18 | 19 | if __name__ == '__main__': 20 | import sys 21 | if len(sys.argv) == 2 and sys.argv[1] not in ["-h", "--help"]: 22 | main(data_dir=sys.argv[1]) 23 | else: 24 | print(f"Usage {sys.argv[0]} ", file=sys.stderr) 25 | -------------------------------------------------------------------------------- /kpe/benches/kpe.rs: -------------------------------------------------------------------------------- 1 | //! Run as `cargo bench --bench kpe`. 
2 | 3 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 4 | 5 | use kpe::{Config, Pipeline}; 6 | use test_utils::kpe::{bert, classifier, cnn, vocab}; 7 | 8 | fn bench_kpe(manager: &mut Criterion) { 9 | let config = Config::from_files( 10 | vocab().unwrap(), 11 | bert().unwrap(), 12 | cnn().unwrap(), 13 | classifier().unwrap(), 14 | ) 15 | .unwrap() 16 | .with_token_size(128) 17 | .unwrap(); 18 | let pipeline = Pipeline::from(config).unwrap(); 19 | 20 | let sequence = "This sequence will be split into key phrases."; 21 | manager.bench_function("KPE", |bencher| { 22 | bencher.iter(|| pipeline.run(black_box(sequence)).unwrap()) 23 | }); 24 | } 25 | 26 | criterion_group! { 27 | name = bench; 28 | config = Criterion::default(); 29 | targets = 30 | bench_kpe, 31 | } 32 | 33 | criterion_main! { 34 | bench, 35 | } 36 | -------------------------------------------------------------------------------- /xayn-ai/src/ltr/list_net/optimizer.rs: -------------------------------------------------------------------------------- 1 | use super::{data::GradientSet, model::ListNet}; 2 | 3 | /// Optimizer which applies gradients in a specific way to the current list net instance. 4 | pub trait Optimizer { 5 | /// Runs the next optimization step by applying the given gradients on the given ListNet. 
6 | fn apply_gradients(&mut self, list_net: &mut ListNet, batch_of_gradient_sets: Vec); 7 | } 8 | 9 | /// Mini-Batch Stochastic Gradient Descent 10 | #[derive(Clone)] 11 | pub struct MiniBatchSgd { 12 | pub learning_rate: f32, 13 | } 14 | 15 | impl Optimizer for MiniBatchSgd { 16 | fn apply_gradients( 17 | &mut self, 18 | list_net: &mut ListNet, 19 | batch_of_gradient_sets: Vec, 20 | ) { 21 | if let Some(mut gradient_set) = GradientSet::mean_of(batch_of_gradient_sets) { 22 | gradient_set *= -self.learning_rate; 23 | list_net.add_gradients(gradient_set); 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /bindings/dart/ios/xayn_ai_ffi_dart.podspec: -------------------------------------------------------------------------------- 1 | # 2 | # To learn more about a Podspec see http://guides.cocoapods.org/syntax/podspec.html. 3 | # Run `pod lib lint xayn_ai_ffi_dart.podspec` to validate before publishing. 4 | # 5 | Pod::Spec.new do |s| 6 | s.name = 'xayn_ai_ffi_dart' 7 | s.version = '0.0.1' 8 | s.summary = 'XaynAI flutter plugin project.' 9 | s.description = <<-DESC 10 | XaynAI plugin project. 11 | DESC 12 | s.homepage = 'http://xayn.com' 13 | s.license = { :file => '../../../LICENSE' } 14 | s.author = { 'Xayn' => 'engineering@xaynet.dev' } 15 | s.source = { :path => '.' } 16 | s.source_files = 'Classes/**/*' 17 | s.vendored_libraries = "**/*.a" 18 | s.dependency 'Flutter' 19 | s.platform = :ios, '9.0' 20 | 21 | # Flutter.framework does not contain a i386 slice. 
22 | s.pod_target_xcconfig = { 'DEFINES_MODULE' => 'YES', 'EXCLUDED_ARCHS[sdk=iphonesimulator*]' => 'i386' } 23 | s.swift_version = '5.0' 24 | end 25 | -------------------------------------------------------------------------------- /bindings/dart/xayn_ai_ffi_dart.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /bindings/dart/ios/Classes/SwiftXaynAiFfiDartPlugin.swift: -------------------------------------------------------------------------------- 1 | import Flutter 2 | import UIKit 3 | 4 | public class SwiftXaynAiFfiDartPlugin: NSObject, FlutterPlugin { 5 | public static func register(with registrar: FlutterPluginRegistrar) { 6 | let channel = FlutterMethodChannel(name: "xayn_ai_ffi_dart", binaryMessenger: registrar.messenger()) 7 | let instance = SwiftXaynAiFfiDartPlugin() 8 | registrar.addMethodCallDelegate(instance, channel: channel) 9 | } 10 | 11 | public func handle(_ call: FlutterMethodCall, result: @escaping FlutterResult) { 12 | result("iOS " + UIDevice.current.systemVersion) 13 | } 14 | 15 | // The Xcode toolchain won't include the shared library in the build 16 | // process unless a method from the library is invoked. So, this 17 | // call to a component directly related to the ai is just done to ensure 18 | // that the library is included, independent of the actual return value 19 | // or failure. 
20 | public func enforceBinding(){ 21 | xaynai_new(nil, nil, nil, nil, nil, nil, nil) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/reranker/data_provider.dart: -------------------------------------------------------------------------------- 1 | import 'package:xayn_ai_ffi_dart/src/common/reranker/data_provider.dart' 2 | as common show Asset, AssetType, baseAssets, MobileFeature, SetupData; 3 | 4 | /// Returns a map of all assets required for initializing [`XaynAi`]. 5 | Map getAssets( 6 | {Set features = const {}}) => 7 | common.baseAssets; 8 | 9 | /// Data that is required to initialize [`XaynAi`]. 10 | class SetupData implements common.SetupData { 11 | late String smbertVocab; 12 | late String smbertModel; 13 | late String qambertVocab; 14 | late String qambertModel; 15 | late String ltrModel; 16 | 17 | SetupData(Map assets) { 18 | smbertVocab = assets[common.AssetType.smbertVocab]!; 19 | smbertModel = assets[common.AssetType.smbertModel]!; 20 | qambertVocab = assets[common.AssetType.qambertVocab]!; 21 | qambertModel = assets[common.AssetType.qambertModel]!; 22 | ltrModel = assets[common.AssetType.ltrModel]!; 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /kpe/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod bert; 2 | pub mod classifier; 3 | pub mod cnn; 4 | 5 | use std::io::Error as IoError; 6 | 7 | use displaydoc::Display; 8 | use ndarray::ShapeError; 9 | use thiserror::Error; 10 | use tract_onnx::prelude::TractError; 11 | 12 | use layer::{conv::ConvError, io::LoadingLayerFailed}; 13 | 14 | /// The potential errors of the models. 
15 | #[derive(Debug, Display, Error)] 16 | pub enum ModelError { 17 | /// Failed to read the onnx model: {0} 18 | Read(#[from] IoError), 19 | 20 | /// Failed to run a tract operation: {0} 21 | Tract(#[from] TractError), 22 | 23 | /// Invalid array shapes: {0} 24 | Shape(#[from] ShapeError), 25 | 26 | /// Failed to read or run the CNN model: {0} 27 | Cnn(#[from] ConvError), 28 | 29 | /// Failed to read the Classifier model: {0} 30 | Classifier(#[from] LoadingLayerFailed), 31 | 32 | /// Remaining parameters must be used: {0:?} 33 | UnusedParams(Vec), 34 | 35 | /// The sequence must contain at least `KEY_PHRASE_SIZE` valid words 36 | NotEnoughWords, 37 | } 38 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -eu 4 | 5 | # We can't use `pushd` or `readlink -f` so we 6 | # fall back to this. 7 | CALLING_BASE_DIR="$(pwd -L)" 8 | 9 | # path to the directory where this file is 10 | SELF_DIR_PATH="$(dirname "$0")" 11 | 12 | # in this way we can call the script from different directory 13 | # but the data should go in the correct destination 14 | DATA_DIR="$SELF_DIR_PATH/data" 15 | 16 | CHECKSUM_FILE="sha256sums" 17 | 18 | download() 19 | { 20 | cd "$CALLING_BASE_DIR" 21 | NAME="$1" 22 | VERSION="$2" 23 | ARCHIVE_BASENAME="${NAME}_$VERSION" 24 | ARCHIVE_NAME="$ARCHIVE_BASENAME.tgz" 25 | URL="http://s3-de-central.profitbricks.com/xayn-yellow-bert/$NAME/$ARCHIVE_NAME" 26 | 27 | curl "$URL" -o "$DATA_DIR/$ARCHIVE_NAME" 28 | 29 | cd "$DATA_DIR" 30 | tar -zxf "$ARCHIVE_NAME" 31 | 32 | # check content 33 | cd "$ARCHIVE_BASENAME" 34 | shasum -c "$CHECKSUM_FILE" 35 | } 36 | 37 | download smbert v0001 38 | download qambert v0001 39 | download ltr v0000 40 | download bench_matmul v0000 41 | download ltr_test_data v0001 42 | download kpe v0001 43 | -------------------------------------------------------------------------------- 
/bindings/dart/pubspec.yaml: -------------------------------------------------------------------------------- 1 | name: 'xayn_ai_ffi_dart' 2 | version: '4.0.0' 3 | 4 | environment: 5 | sdk: '>=2.14.4 <3.0.0' 6 | flutter: '>=2.5.3 <3.0.0' 7 | 8 | dependencies: 9 | async: '^2.8.1' 10 | ffi: '^1.1.2' 11 | flutter: 12 | sdk: 'flutter' 13 | js: '^0.6.3' 14 | json_annotation: '^4.1.0' 15 | meta: '^1.7.0' 16 | hex: '^0.2.0' 17 | 18 | dev_dependencies: 19 | build_runner: '^2.1.2' 20 | json_serializable: '^5.0.2' 21 | ffigen: '^4.0.0' 22 | flutter_test: 23 | sdk: 'flutter' 24 | pedantic: '^1.11.1' 25 | 26 | # The following section is specific to Flutter. 27 | flutter: 28 | # This section identifies this Flutter project as a plugin project. 29 | # The 'pluginClass' and Android 'package' identifiers should not ordinarily 30 | # be modified. They are used by the tooling to maintain consistency when 31 | # adding or updating assets for this project. 32 | plugin: 33 | platforms: 34 | android: 35 | package: 'com.xayn.xayn_ai_ffi_dart' 36 | pluginClass: 'XaynAiFfiDartPlugin' 37 | ios: 38 | pluginClass: 'XaynAiFfiDartPlugin' 39 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/main/res/values/styles.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 9 | 15 | 18 | 19 | -------------------------------------------------------------------------------- /.ci/install-wasm-opt/action.yml: -------------------------------------------------------------------------------- 1 | name: 'wasm-opt' 2 | description: 'Installs wasm-opt' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - shell: bash 7 | working-directory: ${{ runner.temp }} 8 | env: 9 | LINUX_URL: https://github.com/WebAssembly/binaryen/releases/download/version_101/binaryen-version_101-x86_64-linux.tar.gz 10 | # 
https://github.com/WebAssembly/binaryen/releases/download/version_101/binaryen-version_101-x86_64-linux.tar.gz.sha256 11 | LINUX_CHECKSUM: 20d0b19ca716c51d927f181802125f04d5685250c8a22ec3022ac28bf4f20c57 12 | run: | 13 | if [ ${{ runner.os }} == "Linux" ]; then 14 | URL=${{ env.LINUX_URL }} 15 | CHECKSUM=${{ env.LINUX_CHECKSUM }} 16 | INSTALL_PATH=$HOME/.local/bin 17 | else 18 | echo "::error wasm-opt for ${{ runner.os }} is not supported" 19 | exit 1 20 | fi 21 | 22 | wget -q -O binaryen.tar.gz $URL 23 | echo "$CHECKSUM *binaryen.tar.gz" | shasum -c - 24 | tar xvzf binaryen.tar.gz --strip-components 1 25 | mkdir -p $INSTALL_PATH 26 | mv bin/wasm-opt $INSTALL_PATH 27 | -------------------------------------------------------------------------------- /xayn-ai/src/tests/mod.rs: -------------------------------------------------------------------------------- 1 | mod mem_db; 2 | mod systems; 3 | mod utils; 4 | 5 | pub(crate) use self::{ 6 | mem_db::MemDb, 7 | systems::{mocked_smbert_system, MockCommonSystems}, 8 | utils::{ 9 | data_with_rank, 10 | document_history, 11 | documents_from_ids, 12 | documents_from_words, 13 | documents_with_embeddings_from_ids, 14 | documents_with_embeddings_from_snippet_and_query, 15 | expected_rerank_unchanged, 16 | from_ids, 17 | history_for_prev_docs, 18 | mock_uuid, 19 | neg_cois_from_words, 20 | neg_cois_from_words_with_ids, 21 | pos_cois_from_words, 22 | pos_cois_from_words_v0, 23 | pos_cois_from_words_v1, 24 | pos_cois_from_words_v2, 25 | pos_cois_from_words_with_ids, 26 | }, 27 | }; 28 | 29 | #[cfg(test)] 30 | pub(crate) use crate::reranker::{ 31 | database::MockDatabase, 32 | systems::{ 33 | MockAnalyticsSystem, 34 | MockCoiSystem, 35 | MockContextSystem, 36 | MockLtrSystem, 37 | MockQAMBertSystem, 38 | MockSMBertSystem, 39 | }, 40 | }; 41 | -------------------------------------------------------------------------------- /bindings/dart/example/assets/call_data/example2.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "rerank_mode": 3, 3 | "histories": [ 4 | { 5 | "id": "fcb6a685-eb92-4d36-8686-000000000000", 6 | "relevance": 2, 7 | "user_feedback": 1, 8 | "rank": 0, 9 | "user_action": 2, 10 | "session": "fcb6a685-eb92-4d36-8686-a00000000010", 11 | "query_count": 1, 12 | "query_id": "fcb6a685-eb92-4d36-8686-b00000000011", 13 | "query_words": "transport", 14 | "day": 0, 15 | "url": "url", 16 | "domain": "dom" 17 | } 18 | ], 19 | "documents": [ 20 | { 21 | "id": "fcb6a685-eb92-4d36-8686-000000000000", 22 | "title": "apple", 23 | "snippet": "snippet of apple", 24 | "rank": 0, 25 | "session": "fcb6a685-eb92-4d36-8686-a00000000010", 26 | "query_count": 2, 27 | "query_id": "fcb6a685-eb92-4d36-8686-b00000000012", 28 | "query_words": "bunny", 29 | "url": "url", 30 | "domain": "dom" 31 | } 32 | ], 33 | "serialized_state": null 34 | } 35 | -------------------------------------------------------------------------------- /bindings/dart/test/mobile/reranker/bytes_test.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr; 2 | import 'dart:typed_data' show Uint8List; 3 | 4 | import 'package:flutter_test/flutter_test.dart' 5 | show equals, expect, group, isEmpty, isNot, test; 6 | 7 | import 'package:xayn_ai_ffi_dart/src/mobile/reranker/bytes.dart' show Bytes; 8 | 9 | void main() { 10 | group('Bytes', () { 11 | test('list', () { 12 | final list = Uint8List.fromList(List.generate(10, (i) => i)); 13 | final bytes = Bytes.fromList(list); 14 | expect(bytes.toList(), equals(list)); 15 | bytes.free(); 16 | }); 17 | 18 | test('null', () { 19 | final bytes = Bytes(nullptr); 20 | expect(bytes.toList(), isEmpty); 21 | }); 22 | 23 | test('empty', () { 24 | final bytes = Bytes.fromList(Uint8List(0)); 25 | expect(bytes.toList(), isEmpty); 26 | bytes.free(); 27 | }); 28 | 29 | test('free', () { 30 | final bytes = 31 | 
Bytes.fromList(Uint8List.fromList(List.generate(10, (i) => i))); 32 | expect(bytes.ptr, isNot(equals(nullptr))); 33 | bytes.free(); 34 | expect(bytes.ptr, equals(nullptr)); 35 | }); 36 | }); 37 | } 38 | -------------------------------------------------------------------------------- /xayn-ai-ffi-wasm/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! WASM FFI for the Xayn AI. 2 | #![cfg_attr( 3 | doc, 4 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 5 | )] 6 | #![forbid(unsafe_op_in_unsafe_fn)] 7 | // TODO: remove clippy lint once wasm-bindgen fixes the regression from 0.2.79 8 | #![allow(clippy::unused_unit)] 9 | 10 | #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] 11 | use ::{ 12 | js_sys::Promise, 13 | wasm_bindgen::{prelude::wasm_bindgen, JsValue}, 14 | }; 15 | 16 | #[cfg(not(tarpaulin))] 17 | mod ai; 18 | #[cfg(not(tarpaulin))] 19 | mod error; 20 | 21 | #[cfg(all(not(tarpaulin), doc))] 22 | pub use crate::ai::WXaynAi; 23 | 24 | /// Reexport to allow initialization of a WebWorker based on the rayon thread pool. 25 | #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] 26 | pub use wasm_bindgen_rayon::init_thread_pool; 27 | 28 | /// Stub which is used when the wasm blob was compiled without the `multithreaded` feature. 
29 | #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] 30 | #[wasm_bindgen(js_name = initThreadPool)] 31 | pub fn init_thread_pool(_num_threads: usize) -> Promise { 32 | Promise::resolve(&JsValue::UNDEFINED) 33 | } 34 | -------------------------------------------------------------------------------- /bindings/dart/example/.gitignore: -------------------------------------------------------------------------------- 1 | # Miscellaneous 2 | *.class 3 | *.log 4 | *.pyc 5 | *.swp 6 | .DS_Store 7 | .atom/ 8 | .buildlog/ 9 | .history 10 | .svn/ 11 | 12 | # IntelliJ related 13 | *.iml 14 | *.ipr 15 | *.iws 16 | .idea/ 17 | 18 | # The .vscode folder contains launch configuration and tasks you configure in 19 | # VS Code which you may wish to be included in version control, so this line 20 | # is commented out by default. 21 | #.vscode/ 22 | 23 | # Flutter/Dart/Pub related 24 | **/doc/api/ 25 | **/ios/Flutter/.last_build_id 26 | .dart_tool/ 27 | .flutter-plugins 28 | .flutter-plugins-dependencies 29 | .packages 30 | .pub-cache/ 31 | .pub/ 32 | /build/ 33 | 34 | # Web related 35 | lib/generated_plugin_registrant.dart 36 | 37 | # Symbolication related 38 | app.*.symbols 39 | 40 | # Obfuscation related 41 | app.*.map.json 42 | 43 | # Android Studio will place build artifacts here 44 | /android/app/debug 45 | /android/app/profile 46 | /android/app/release 47 | 48 | # assets 49 | assets/ltr_v0000/* 50 | !assets/ltr_v0000/.gitkeep 51 | assets/qambert_v0001/* 52 | !assets/qambert_v0001/.gitkeep 53 | assets/smbert_v0001/* 54 | !assets/smbert_v0001/.gitkeep 55 | assets/wasm_bindings/* 56 | !assets/wasm_bindings/.gitkeep 57 | -------------------------------------------------------------------------------- /test-utils/src/kpe.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Result, path::PathBuf}; 2 | 3 | use crate::asset::{resolve_path, DATA_DIR}; 4 | 5 | const ASSET: &str = "kpe_v0001"; 6 | 7 | /// Resolves the 
path to the Bert vocabulary. 8 | pub fn vocab() -> Result { 9 | resolve_path(&[DATA_DIR, ASSET, "vocab.txt"]) 10 | } 11 | 12 | /// Resolves the path to the Bert model. 13 | pub fn bert() -> Result { 14 | resolve_path(&[DATA_DIR, ASSET, "bert-quantized.onnx"]) 15 | } 16 | 17 | /// Resolves the path to the CNN model. 18 | pub fn cnn() -> Result { 19 | resolve_path(&[DATA_DIR, ASSET, "cnn.binparams"]) 20 | } 21 | 22 | /// Resolves the path to the Classifier model. 23 | pub fn classifier() -> Result { 24 | resolve_path(&[DATA_DIR, ASSET, "classifier.binparams"]) 25 | } 26 | 27 | #[cfg(test)] 28 | mod tests { 29 | use super::*; 30 | 31 | #[test] 32 | fn test_vocab() { 33 | assert!(vocab().is_ok()); 34 | } 35 | 36 | #[test] 37 | fn test_bert() { 38 | assert!(bert().is_ok()); 39 | } 40 | 41 | #[test] 42 | fn test_cnn() { 43 | assert!(cnn().is_ok()); 44 | } 45 | 46 | #[test] 47 | fn test_classifier() { 48 | assert!(classifier().is_ok()); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/reranker/analytics.dart: -------------------------------------------------------------------------------- 1 | import 'package:json_annotation/json_annotation.dart' show JsonSerializable; 2 | import 'package:meta/meta.dart' show immutable; 3 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' show ToJson; 4 | 5 | part 'analytics.g.dart'; 6 | 7 | /// The analytics of the Xayn AI. 8 | @immutable 9 | @JsonSerializable() 10 | class Analytics implements ToJson { 11 | /// The nDCG@k score between the LTR ranking and the relevance based ranking. 12 | final double ndcgLtr; 13 | 14 | /// The nDCG@k score between the Context ranking and the relevance based ranking. 15 | final double ndcgContext; 16 | 17 | /// The nDCG@k score between the initial ranking and the relevance based ranking. 18 | final double ndcgInitialRanking; 19 | 20 | /// The nDCG@k score between the final ranking and the relevance based ranking. 
21 | final double ndcgFinalRanking; 22 | 23 | /// Creates the analytics from the individual values. 24 | Analytics( 25 | this.ndcgLtr, 26 | this.ndcgContext, 27 | this.ndcgInitialRanking, 28 | this.ndcgFinalRanking, 29 | ); 30 | 31 | factory Analytics.fromJson(Map json) => _$AnalyticsFromJson(json); 32 | 33 | @override 34 | Map toJson() => _$AnalyticsToJson(this); 35 | } 36 | -------------------------------------------------------------------------------- /data/asset_templates/web_assets.dart.tmpl: -------------------------------------------------------------------------------- 1 | /// Warning, this file is autogenerated. Don't modify this manually. 2 | 3 | part of 'data_provider.dart'; 4 | 5 | {{ range $key, $value := (ds "assets_manifest").wasm_assets }} 6 | final _{{ print "wasm " $key | strings.CamelCase}} = { 7 | common.AssetType.wasmModule: common.Asset('{{$value.module.url_suffix}}', common.Checksum('{{$value.module.checksum}}'), []), 8 | common.AssetType.wasmScript: common.Asset('{{$value.script.url_suffix}}', common.Checksum('{{$value.script.checksum}}'), []), 9 | common.AssetType.webWorkerScript: common.Asset('{{$value.web_worker.url_suffix}}', common.Checksum('{{$value.web_worker.checksum}}'), []), 10 | }; 11 | {{- end }} 12 | 13 | /// Returns the most suitable wasm assets for the given features. 
14 | Map getWasmAssets( 15 | Set features) { 16 | {{- if has (ds "assets_manifest").wasm_assets "multithreaded" }} 17 | if (features.containsAll(_multithreaded)) { 18 | return _wasmMultithreaded; 19 | } 20 | {{- end }} 21 | 22 | {{- if has (ds "assets_manifest").wasm_assets "sequential" }} 23 | return _wasmSequential; 24 | {{- else }} 25 | throw UnsupportedError('No suitable WASM assets are available'); 26 | {{- end }} 27 | } 28 | -------------------------------------------------------------------------------- /scripts/upload_assets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Generates a manifest file for the data assets in `assets_manifest.json` and uploads them to the 4 | # given bucket URL. If an asset of the `data_assets` array contains a `chunk_size` key, the script 5 | # splits the asset into chunks where each chunk has a maximum size of `chunk_size`. The format of 6 | # the `chunk_size` value is equivalent to the `SIZE` argument in `split` or `gsplit` on macOS. 7 | # See `split`/`gsplit` man page for more details. The manifest file is written to the `out` 8 | # directory. 
9 | # 10 | # Usage: 11 | # ./upload_assets 12 | set -e 13 | 14 | source $(dirname "$0")/assets_generation_lib.sh 15 | 16 | ASSET_METADATA=$(gen_data_assets_metadata_only "$1") 17 | BUCKET_URL="$2" 18 | 19 | for ASSET in $(cat "$ASSET_METADATA" | jq -c '.upload[]'); do 20 | ASSET_URL_SUFFIX=$(echo $ASSET | jq -r '.url_suffix') 21 | ASSET_PATH=$(echo $ASSET | jq -r '.path') 22 | s3cmd sync -v --acl-public --guess-mime-type --no-mime-magic --skip-existing $ASSET_PATH ${BUCKET_URL}/$ASSET_URL_SUFFIX 23 | done 24 | 25 | ASSET_MANIFEST="$(dirname "$ASSET_METADATA")/asset_manifest.json" 26 | jq 'del(.upload)' "$ASSET_METADATA" > $ASSET_MANIFEST 27 | 28 | echo "assets manifest path: $ASSET_MANIFEST" 29 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/worker/message/utils.dart: -------------------------------------------------------------------------------- 1 | import 'dart:html' show MessagePort; 2 | import 'dart:typed_data' show Uint8List; 3 | 4 | import 'package:json_annotation/json_annotation.dart' show JsonConverter; 5 | 6 | /// A [Uint8List] from/to JSON converter. 7 | class Uint8ListConverter implements JsonConverter { 8 | const Uint8ListConverter(); 9 | 10 | @override 11 | Uint8List fromJson(Uint8List json) { 12 | return json; 13 | } 14 | 15 | @override 16 | Uint8List toJson(Uint8List object) { 17 | return object; 18 | } 19 | } 20 | 21 | /// A `Uint8List?` from/to JSON converter. 22 | class Uint8ListMaybeNullConverter 23 | implements JsonConverter { 24 | const Uint8ListMaybeNullConverter(); 25 | 26 | @override 27 | Uint8List? fromJson(Uint8List? json) { 28 | return json; 29 | } 30 | 31 | @override 32 | Uint8List? toJson(Uint8List? object) { 33 | return object; 34 | } 35 | } 36 | 37 | /// A `MessagePort?` from/to JSON converter. 38 | class MessagePortConverter 39 | implements JsonConverter { 40 | const MessagePortConverter(); 41 | 42 | @override 43 | MessagePort? fromJson(MessagePort? 
json) { 44 | return json; 45 | } 46 | 47 | @override 48 | MessagePort? toJson(MessagePort? object) { 49 | return object; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/data/document.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library document; 3 | 4 | import 'package:js/js.dart' show anonymous, JS; 5 | 6 | import 'package:xayn_ai_ffi_dart/src/common/data/document.dart' show Document; 7 | 8 | @JS() 9 | @anonymous 10 | class JsDocument { 11 | external factory JsDocument({ 12 | String id, 13 | int rank, 14 | String title, 15 | String snippet, 16 | String session, 17 | // ignore: non_constant_identifier_names 18 | int query_count, 19 | // ignore: non_constant_identifier_names 20 | String query_id, 21 | // ignore: non_constant_identifier_names 22 | String query_words, 23 | String url, 24 | String domain, 25 | }); 26 | } 27 | 28 | extension ToJsDocuments on List { 29 | /// Creates JS documents from the current documents. 30 | List toJsDocuments() => List.generate( 31 | length, 32 | (i) => JsDocument( 33 | id: this[i].id, 34 | rank: this[i].rank, 35 | title: this[i].title, 36 | snippet: this[i].snippet, 37 | session: this[i].session, 38 | query_count: this[i].queryCount, 39 | query_id: this[i].queryId, 40 | query_words: this[i].queryWords, 41 | url: this[i].url, 42 | domain: this[i].domain, 43 | ), 44 | growable: false, 45 | ); 46 | } 47 | -------------------------------------------------------------------------------- /xayn-ai/src/ranker/document.rs: -------------------------------------------------------------------------------- 1 | use chrono::NaiveDateTime; 2 | 3 | use crate::{embedding::utils::Embedding, DocumentId}; 4 | 5 | pub trait Document { 6 | /// Gets the document id. 7 | fn id(&self) -> DocumentId; 8 | 9 | /// Gets the SMBert embedding of the document. 
10 | fn smbert_embedding(&self) -> &Embedding; 11 | 12 | /// Gets the publishing date. 13 | fn date_published(&self) -> NaiveDateTime; 14 | } 15 | 16 | #[cfg(test)] 17 | pub(super) struct TestDocument { 18 | pub(super) id: DocumentId, 19 | pub(super) smbert_embedding: Embedding, 20 | pub(super) date_published: NaiveDateTime, 21 | } 22 | 23 | #[cfg(test)] 24 | impl TestDocument { 25 | pub(super) fn new(id: u128, embedding: impl Into, date_published: &str) -> Self { 26 | Self { 27 | id: DocumentId::from_u128(id), 28 | smbert_embedding: embedding.into(), 29 | date_published: NaiveDateTime::parse_from_str(date_published, "%Y-%m-%d %H:%M:%S") 30 | .unwrap(), 31 | } 32 | } 33 | } 34 | 35 | #[cfg(test)] 36 | impl Document for TestDocument { 37 | fn id(&self) -> DocumentId { 38 | self.id 39 | } 40 | 41 | fn smbert_embedding(&self) -> &Embedding { 42 | &self.smbert_embedding 43 | } 44 | 45 | fn date_published(&self) -> NaiveDateTime { 46 | self.date_published 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/reranker/analytics.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library analytics; 3 | 4 | import 'package:js/js.dart' show anonymous, JS; 5 | 6 | import 'package:xayn_ai_ffi_dart/src/common/reranker/analytics.dart' 7 | show Analytics; 8 | 9 | @JS() 10 | @anonymous 11 | class JsAnalytics { 12 | // ignore: non_constant_identifier_names 13 | external double get ndcg_ltr; 14 | 15 | // ignore: non_constant_identifier_names 16 | external double get ndcg_context; 17 | 18 | // ignore: non_constant_identifier_names 19 | external double get ndcg_initial_ranking; 20 | 21 | // ignore: non_constant_identifier_names 22 | external double get ndcg_final_ranking; 23 | 24 | external factory JsAnalytics({ 25 | // ignore: non_constant_identifier_names, unused_element 26 | double ndcg_ltr, 27 | // ignore: non_constant_identifier_names, unused_element 28 | 
double ndcg_context, 29 | // ignore: non_constant_identifier_names, unused_element 30 | double ndcg_initial_ranking, 31 | // ignore: non_constant_identifier_names, unused_element 32 | double ndcg_final_ranking, 33 | }); 34 | } 35 | 36 | extension ToAnalytics on JsAnalytics { 37 | /// Creates analytics from the current JS analytics. 38 | Analytics toAnalytics() => Analytics( 39 | ndcg_ltr, 40 | ndcg_context, 41 | ndcg_initial_ranking, 42 | ndcg_final_ranking, 43 | ); 44 | } 45 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/ffi/library.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library library; 3 | 4 | import 'dart:html' show WorkerGlobalScope; 5 | 6 | import 'package:js/js.dart' show JS; 7 | import 'package:js/js_util.dart' show promiseToFuture; 8 | 9 | import 'package:xayn_ai_ffi_dart/src/common/reranker/utils.dart' 10 | show selectThreadPoolSize; 11 | 12 | @JS('Promise') 13 | class _Promise {} 14 | 15 | @JS('WebAssembly.Exports') 16 | class Wasm {} 17 | 18 | @JS('xayn_ai_ffi_wasm.default') 19 | external _Promise _init([ 20 | // ignore: non_constant_identifier_names 21 | dynamic module_or_path, 22 | ]); 23 | 24 | @JS('xayn_ai_ffi_wasm.initThreadPool') 25 | external _Promise _initThreadPool(int numberOfThreads); 26 | 27 | /// Initializes the wasm module. 28 | /// 29 | /// If `moduleOrPath` is a `RequestInfo` or `URL`, makes a request and 30 | /// for everything else, calls `WebAssembly.instantiate` directly. 31 | Future init([dynamic moduleOrPath]) async { 32 | final wasm = await promiseToFuture(_init(moduleOrPath)); 33 | 34 | // Most devices have 4+ hardware threads, but if the browser doesn't support 35 | // the property it's probably old so we default to 2. 36 | var hardwareThreads = selectThreadPoolSize( 37 | WorkerGlobalScope.instance.navigator.hardwareConcurrency ?? 
2); 38 | 39 | await promiseToFuture(_initThreadPool(hardwareThreads)); 40 | return wasm; 41 | } 42 | -------------------------------------------------------------------------------- /.ci/install-cargo-sort/action.yml: -------------------------------------------------------------------------------- 1 | name: 'cargo-sort' 2 | description: 'Installs cargo-sort' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - shell: bash 7 | working-directory: ${{ runner.temp }} 8 | env: 9 | LINUX_URL: https://github.com/DevinR528/cargo-sort/releases/download/v1.0.5/cargo-sort-x86_64-unknown-linux-gnu.tar.gz 10 | LINUX_CHECKSUM: ad909aed897f0eb4cda43f3884a5f7a4d403b7b8f1645e8f16ead33a7bbbf79cdf0ed85382777c70b7823d0d9f291dfd63dbbd63a8b8f5853acc00bbb3e8aa61 11 | MACOS_URL: https://github.com/DevinR528/cargo-sort/releases/download/v1.0.5/cargo-sort-x86_64-apple-darwin.tar.gz 12 | MACOS_CHECKSUM: b838f6333a47a649b2ea17e50803fe8dc89885570c3c2a97de9e92679f517b052de561db1a302aa265ad93c8db4731de4c032a0a84f5ec62ae5ff5f09693de4d 13 | run: | 14 | if [ ${{ runner.os }} == "Linux" ]; then 15 | URL=${{ env.LINUX_URL }} 16 | CHECKSUM=${{ env.LINUX_CHECKSUM }} 17 | elif [ ${{ runner.os }} == "macOS" ]; then 18 | URL=${{ env.MACOS_URL }} 19 | CHECKSUM=${{ env.MACOS_CHECKSUM }} 20 | else 21 | echo "::error cargo-sort for ${{ runner.os }} is not supported" 22 | exit 1 23 | fi 24 | 25 | wget -q -O cargo-sort $URL 26 | echo "$CHECKSUM *cargo-sort" | shasum -c - 27 | tar -xf cargo-sort 28 | chmod u+x cargo-sort 29 | mv cargo-sort $HOME/.cargo/bin 30 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/reranker/mode.dart: -------------------------------------------------------------------------------- 1 | import 'package:json_annotation/json_annotation.dart' show JsonValue; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/common/ffi/genesis.dart' as ffi 4 | show RerankMode; 5 | 6 | /// Rerank mode 7 | enum RerankMode { 8 | 
@JsonValue(ffi.RerankMode.StandardNews) 9 | standardNews, 10 | @JsonValue(ffi.RerankMode.PersonalizedNews) 11 | personalizedNews, 12 | @JsonValue(ffi.RerankMode.StandardSearch) 13 | standardSearch, 14 | @JsonValue(ffi.RerankMode.PersonalizedSearch) 15 | personalizedSearch, 16 | } 17 | 18 | extension RerankModeToInt on RerankMode { 19 | /// Gets the discriminant. 20 | int toInt() { 21 | // We can't use `_$RerankModeEnumMap` as it only gets generated for 22 | // files which have a `@JsonSerializable` type containing the enum. 23 | // You can't make enums `@JsonSerializable`. Given that `RerankMode` 24 | // has only few variants and rarely changes we just write this switch 25 | // statement by hand. 26 | switch (this) { 27 | case RerankMode.standardNews: 28 | return ffi.RerankMode.StandardNews; 29 | case RerankMode.personalizedNews: 30 | return ffi.RerankMode.PersonalizedNews; 31 | case RerankMode.standardSearch: 32 | return ffi.RerankMode.StandardSearch; 33 | case RerankMode.personalizedSearch: 34 | return ffi.RerankMode.PersonalizedSearch; 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /xayn-ai/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![forbid(unsafe_op_in_unsafe_fn)] 2 | 3 | mod analytics; 4 | mod coi; 5 | mod context; 6 | mod data; 7 | mod embedding; 8 | mod error; 9 | mod ltr; 10 | pub mod ranker; 11 | mod reranker; 12 | mod utils; 13 | 14 | pub use crate::{ 15 | analytics::Analytics, 16 | coi::CoiId, 17 | data::document::{ 18 | DayOfWeek, 19 | Document, 20 | DocumentHistory, 21 | DocumentId, 22 | QueryId, 23 | Relevance, 24 | RerankingOutcomes, 25 | SessionId, 26 | UserAction, 27 | UserFeedback, 28 | }, 29 | error::Error, 30 | reranker::{ 31 | public::{Builder, Reranker}, 32 | RerankMode, 33 | }, 34 | }; 35 | 36 | // We need to re-export these, since they encapsulate the arguments 37 | // required for pipeline construction, and are passed to 
builders. 38 | pub use kpe::Config as KpeConfig; 39 | pub use rubert::{QAMBertConfig, SMBertConfig}; 40 | 41 | #[cfg(test)] 42 | mod tests; 43 | 44 | // we need to export rstest_reuse from the root for it to work. 45 | // `use rstest_reuse` will trigger `clippy::single_component_path_imports` 46 | // which is not possible to silence. 47 | #[cfg(test)] 48 | #[allow(unused_imports)] 49 | #[rustfmt::skip] 50 | pub(crate) use rstest_reuse as rstest_reuse; 51 | 52 | // Reexport for the dev-tool 53 | #[doc(hidden)] 54 | pub use crate::ltr::{list_net, list_net_training_data_from_history}; 55 | -------------------------------------------------------------------------------- /.ci/install-gomplate/action.yml: -------------------------------------------------------------------------------- 1 | name: 'gomplate' 2 | description: 'Installs Gomplate' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - shell: bash 7 | working-directory: ${{ runner.temp }} 8 | env: 9 | # https://github.com/hairyhenderson/gomplate/releases/download/v3.10.0/checksums-v3.10.0_sha256.txt 10 | LINUX_URL: https://github.com/hairyhenderson/gomplate/releases/download/v3.10.0/gomplate_linux-amd64 11 | LINUX_CHECKSUM: eec0f85433c9c8aad93e8cd84c79d238f436b3e62f35b15471f5929bc741763a 12 | MACOS_URL: https://github.com/hairyhenderson/gomplate/releases/download/v3.10.0/gomplate_darwin-amd64 13 | MACOS_CHECKSUM: 9eb031e2c32226708a7a67cd8e5139fea9c9dbe0fed0c2a5959d224e8a6353e0 14 | run: | 15 | if [ ${{ runner.os }} == "Linux" ]; then 16 | URL=${{ env.LINUX_URL }} 17 | CHECKSUM=${{ env.LINUX_CHECKSUM }} 18 | INSTALL_PATH=$HOME/.local/bin 19 | elif [ ${{ runner.os }} == "macOS" ]; then 20 | URL=${{ env.MACOS_URL }} 21 | CHECKSUM=${{ env.MACOS_CHECKSUM }} 22 | INSTALL_PATH=$HOME/bin 23 | else 24 | echo "::error gomplate for ${{ runner.os }} is not supported" 25 | exit 1 26 | fi 27 | 28 | wget -q -O gomplate $URL 29 | echo "$CHECKSUM *gomplate" | shasum -c - 30 | chmod u+x gomplate 31 | mkdir -p $INSTALL_PATH 32 | mv 
gomplate $INSTALL_PATH 33 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/reranker/analytics.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr, Pointer, StructPointer; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/common/reranker/analytics.dart' 4 | show Analytics; 5 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' show CAnalytics; 6 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/library.dart' show ffi; 7 | 8 | /// The raw analytics. 9 | class AnalyticsBuilder { 10 | final Pointer _cAnalytics; 11 | bool _freed = false; 12 | 13 | /// Creates the analytics from a pointer. 14 | /// 15 | /// This constructor never throws an exception. 16 | AnalyticsBuilder(this._cAnalytics); 17 | 18 | /// Builds the analytics from raw. 19 | Analytics? build() { 20 | if (_freed) { 21 | throw StateError('CAnalytics already freed'); 22 | } else if (_cAnalytics == nullptr) { 23 | // Analytics might not be provided 24 | return null; 25 | } else { 26 | final cval = _cAnalytics.ref; 27 | return Analytics( 28 | cval.ndcg_ltr, 29 | cval.ndcg_context, 30 | cval.ndcg_initial_ranking, 31 | cval.ndcg_final_ranking, 32 | ); 33 | } 34 | } 35 | 36 | /// Frees the memory. 
37 | void free() { 38 | if (!_freed) { 39 | _freed = true; 40 | // drop impl's are nullptr safe, but we don't want to call into ffi in tests 41 | if (_cAnalytics != nullptr) ffi.analytics_drop(_cAnalytics); 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /dev-tool/src/utils.rs: -------------------------------------------------------------------------------- 1 | use indicatif::{ProgressBar, ProgressStyle}; 2 | 3 | pub(crate) fn progress_spin_until_done(msg: &'static str, func: impl FnOnce() -> R) -> R { 4 | let progress_bar = ProgressBar::new_spinner() 5 | .with_style(ProgressStyle::default_bar().template("{msg}: {elapsed:>10} {spinner:.green}")); 6 | progress_bar.set_message(msg); 7 | progress_bar.enable_steady_tick(100); 8 | let res = func(); 9 | progress_bar.finish(); 10 | res 11 | } 12 | 13 | /// Provides functionality to (de-)serialize optional bytes as base64 string. 14 | /// 15 | /// Use it with the `#[serde(with=serde_opt_bytes_as_base64)]` annotation. 16 | pub mod serde_opt_bytes_as_base64 { 17 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 18 | 19 | pub fn serialize(bytes: &Option>, serializer: S) -> Result 20 | where 21 | S: Serializer, 22 | { 23 | let encoded = bytes.as_ref().map(base64::encode); 24 | encoded.serialize(serializer) 25 | } 26 | 27 | pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> 28 | where 29 | D: Deserializer<'de>, 30 | { 31 | if let Some(encoded) = Option::::deserialize(deserializer)? 
{ 32 | base64::decode(encoded) 33 | .map(Some) 34 | .map_err(::custom) 35 | } else { 36 | Ok(None) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /dev-tool/src/main.rs: -------------------------------------------------------------------------------- 1 | #![forbid(unsafe_op_in_unsafe_fn)] 2 | #![cfg(not(tarpaulin))] 3 | #![cfg(not(any(target_os = "android", target_os = "ios")))] 4 | use std::{process::exit, time::Instant}; 5 | 6 | use anyhow::Error; 7 | use indicatif::HumanDuration; 8 | use log::error; 9 | use structopt::StructOpt; 10 | 11 | use crate::exit_code::{FATAL_ERROR, NO_ERROR}; 12 | 13 | mod call_data; 14 | mod exit_code; 15 | mod list_net; 16 | mod utils; 17 | 18 | /// Tooling for the developers of XaynAi. 19 | #[derive(StructOpt, Debug)] 20 | enum CommandArgs { 21 | CallData(call_data::CallDataCmd), 22 | ListNet(list_net::ListNetCmd), 23 | } 24 | 25 | impl CommandArgs { 26 | fn run(self) -> Result { 27 | use CommandArgs::*; 28 | 29 | match self { 30 | CallData(cmd) => cmd.run(), 31 | ListNet(cmd) => cmd.run(), 32 | } 33 | } 34 | } 35 | 36 | fn main() { 37 | env_logger::init(); 38 | 39 | let start_time = Instant::now(); 40 | 41 | let exit_code = match CommandArgs::from_args().run() { 42 | Ok(exit_code) => exit_code, 43 | Err(error) => { 44 | error!("FATAL: {}\n{:?}", error, error); 45 | FATAL_ERROR 46 | } 47 | }; 48 | 49 | let duration = HumanDuration(start_time.elapsed()); 50 | if exit_code == NO_ERROR { 51 | eprintln!("DONE ({})", duration); 52 | } else { 53 | eprintln!("EXIT WITH ERRORS ({})", duration); 54 | } 55 | exit(exit_code); 56 | } 57 | -------------------------------------------------------------------------------- /rubert-tokenizer/src/pre_tokenizer/string.rs: -------------------------------------------------------------------------------- 1 | use crate::normalizer::string::NormalizedString; 2 | 3 | /// A pre-tokenized sequence. 
4 | pub struct PreTokenizedString { 5 | pub original: String, 6 | pub splits: Vec, 7 | } 8 | 9 | impl From for PreTokenizedString { 10 | fn from(sequence: NormalizedString) -> Self { 11 | Self { 12 | original: sequence.original.clone(), 13 | splits: vec![sequence], 14 | } 15 | } 16 | } 17 | 18 | impl PreTokenizedString { 19 | /// Splits wrt the function. 20 | /// 21 | /// The function takes a normalized sequence and returns an iterator over normalized 22 | /// subsequences. The combined normalized subsequences must have the same original sequence as 23 | /// the normalized sequence. 24 | pub fn split(mut self, f: F) -> Self 25 | where 26 | F: Fn(usize, NormalizedString) -> S, 27 | S: IntoIterator, 28 | { 29 | // new_splits is at least as big as self.splits 30 | let mut new_splits = Vec::with_capacity(self.splits.len()); 31 | for (i, original_split) in self.splits.drain(..).enumerate() { 32 | new_splits.extend( 33 | f(i, original_split) 34 | .into_iter() 35 | .filter(|split| !split.normalized.is_empty()), 36 | ); 37 | } 38 | self.splits = new_splits; 39 | 40 | self 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /rubert/examples/mbert.rs: -------------------------------------------------------------------------------- 1 | //! Run as `cargo run --example mbert ` with ``: 2 | //! - `s` for SMBert 3 | //! 
- `qa` for QAMBert 4 | 5 | use rubert::{Config, FirstPooler, Pipeline, QAMBertConfig, SMBertConfig}; 6 | use test_utils::{qambert, smbert}; 7 | 8 | fn main() { 9 | let (embedding, size) = match std::env::args().nth(1).unwrap().as_str() { 10 | "s" => { 11 | let config: SMBertConfig<_> = 12 | Config::from_files(smbert::vocab().unwrap(), smbert::model().unwrap()) 13 | .unwrap() 14 | .with_pooling(FirstPooler) 15 | .with_token_size(64) 16 | .unwrap(); 17 | 18 | let mbert = Pipeline::from(config).unwrap(); 19 | ( 20 | mbert.run("This is a sequence.").unwrap(), 21 | mbert.embedding_size(), 22 | ) 23 | } 24 | "qa" => { 25 | let config: QAMBertConfig<_> = 26 | Config::from_files(qambert::vocab().unwrap(), qambert::model().unwrap()) 27 | .unwrap() 28 | .with_pooling(FirstPooler); 29 | 30 | let mbert = Pipeline::from(config).unwrap(); 31 | ( 32 | mbert.run("This is a sequence.").unwrap(), 33 | mbert.embedding_size(), 34 | ) 35 | } 36 | _ => panic!("unknown MBert kind"), 37 | }; 38 | println!("{}", *embedding); 39 | assert_eq!(embedding.shape(), [size]); 40 | } 41 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/src/result/mod.rs: -------------------------------------------------------------------------------- 1 | //! Error handling types. 2 | 3 | pub(crate) mod error; 4 | pub(crate) mod fault; 5 | 6 | use std::panic::{catch_unwind, UnwindSafe}; 7 | 8 | use xayn_ai_ffi::Error; 9 | 10 | #[cfg(not(doc))] 11 | use crate::result::error::CError; 12 | use crate::utils::IntoRaw; 13 | 14 | #[cfg(doc)] 15 | pub use self::{ 16 | error::{error_message_drop, CError}, 17 | fault::{faults_drop, CFaults}, 18 | }; 19 | 20 | /// Calls a callback which returns a result. 21 | /// 22 | /// Catches an unwinding panic with optional error handling: 23 | /// - Ok: returns `T`'s FFI value. 24 | /// - Error/Panic: returns `T`'s default FFI value and optionally reports an error. 
25 | pub(crate) fn call_with_result(call: F, error: Option<&mut CError>) -> T::Value 26 | where 27 | F: UnwindSafe + FnOnce() -> Result, 28 | T: IntoRaw, 29 | E: Into, 30 | { 31 | match catch_unwind(call) { 32 | Ok(Ok(value)) => { 33 | if let Some(error) = error { 34 | *error = Error::none().into_raw(); 35 | } 36 | value.into_raw() 37 | } 38 | Ok(Err(cause)) => { 39 | if let Some(error) = error { 40 | *error = cause.into().into_raw(); 41 | } 42 | T::Value::default() 43 | } 44 | Err(cause) => { 45 | if let Some(error) = error { 46 | *error = Error::panic(cause).into_raw(); 47 | } 48 | T::Value::default() 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /bindings/dart/example/lib/debug/mobile/print.dart: -------------------------------------------------------------------------------- 1 | import 'package:flutter/material.dart' show debugPrint; 2 | 3 | final _printChunkedRegex = RegExp(r'(.{0,80})'); 4 | 5 | /// `debugPrint` for long text. 6 | /// 7 | /// Workaround for problems with string concatenation. 8 | /// 9 | /// When using `print` or `debugPrint`, longer messages might 10 | /// get truncated, at least when running it on an android 11 | /// phone over adb using `flutter run`. 12 | /// 13 | /// Using `debugPrint(test, wrapWidth: 1024)` is a workaround 14 | /// which often works, but not always because: 15 | /// 16 | /// - It uses full word line wrapping, and as such won't work 17 | /// if there are any longer "words" (like base64 blobs). 18 | /// 19 | /// - It seems (unclear) that the limit of `1024` might not 20 | /// always be small enough. 21 | /// 22 | /// So this is an "ad-hoc" workaround: 23 | /// 24 | /// - We split the string into chunks of 80 characters, the 25 | /// simplest way to do so is using the given regex. 26 | /// 27 | /// - If this wasn't just some ad-hoc debug helper we probably 28 | /// would implement splitting by at most 80 bytes at character 29 | /// boundary. 
30 | /// 31 | /// - 80 is an arbitrarily chosen value which is not too large even 32 | /// with unicode and works well with "small" terminals. 33 | void debugPrintLongText(String text) { 34 | debugPrint('--- START Chunks ----'); 35 | _printChunkedRegex 36 | .allMatches(text) 37 | .forEach((match) => debugPrint(match.group(0)!)); 38 | debugPrint('--- END Chunks ----'); 39 | } 40 | -------------------------------------------------------------------------------- /rubert/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr( 2 | doc, 3 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 4 | )] 5 | #![forbid(unsafe_op_in_unsafe_fn)] 6 | //! The RuBert pipeline computes embeddings of sequences. 7 | //! 8 | //! Sequences are anything string-like and can also be single words or snippets. The embeddings are 9 | //! f32-arrays and their shape depends on the pooling strategy. 10 | //! 11 | //! See the example in this crate for usage details. 12 | 13 | mod config; 14 | mod model; 15 | mod pipeline; 16 | mod pooler; 17 | mod tokenizer; 18 | 19 | pub use crate::{ 20 | config::{Config, ConfigError}, 21 | model::kinds, 22 | pipeline::{Pipeline, PipelineError}, 23 | pooler::{ 24 | ArcEmbedding1, 25 | ArcEmbedding2, 26 | AveragePooler, 27 | Embedding1, 28 | Embedding2, 29 | FirstPooler, 30 | NonePooler, 31 | }, 32 | }; 33 | 34 | /// A sentence (embedding) multilingual Bert pipeline. 35 | #[allow(clippy::upper_case_acronyms)] 36 | pub type SMBert = Pipeline; 37 | pub type SMBertConfig<'a, P> = Config<'a, kinds::SMBert, P>; 38 | 39 | /// A question answering (embedding) multilingual Bert pipeline. 
40 | #[allow(clippy::upper_case_acronyms)] 41 | pub type QAMBert = Pipeline; 42 | pub type QAMBertConfig<'a, P> = Config<'a, kinds::QAMBert, P>; 43 | 44 | #[cfg(doc)] 45 | pub use crate::{ 46 | model::{BertModel, ModelError}, 47 | pooler::{Embedding, PoolerError}, 48 | tokenizer::TokenizerError, 49 | }; 50 | -------------------------------------------------------------------------------- /rubert/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rubert" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | derive_more = { version = "0.99.17", default-features = false, features = ["deref", "from"] } 9 | displaydoc = "0.2.3" 10 | float-cmp = "0.9.0" 11 | # to be kept in sync with tract-core 12 | ndarray = { version = "=0.15.3", features = ["serde"] } 13 | rubert-tokenizer = { path = "../rubert-tokenizer" } 14 | serde = { version = "1.0.136", features = ["derive"] } 15 | thiserror = "1.0.30" 16 | tract-onnx = "0.16.1" 17 | 18 | # features 19 | criterion = { version = "0.3.5", features = ["html_reports"], optional = true } 20 | csv = { version = "1.1.6", optional = true } 21 | indicatif = { version = "0.16.2", optional = true } 22 | onnxruntime = { version = "0.0.13", optional = true } 23 | rayon = { version = "1.5.1", optional = true } 24 | 25 | [dev-dependencies] 26 | test-utils = { path = "../test-utils" } 27 | 28 | [features] 29 | bench = ["criterion", "onnxruntime", "rayon"] 30 | validate = ["csv", "indicatif", "onnxruntime"] 31 | 32 | [[example]] 33 | name = "mbert" 34 | 35 | [[example]] 36 | name = "validate" 37 | required-features = ["validate"] 38 | 39 | [[bench]] 40 | name = "matmul" 41 | harness = false 42 | bench = false 43 | required-features = ["bench"] 44 | 45 | [[bench]] 46 | name = "mbert" 47 | harness = false 48 | bench = false 49 | required-features = ["bench"] 50 | 51 | [[bench]] 52 | name = "multithreaded" 53 | harness = false 54 
| bench = false 55 | required-features = ["bench"] 56 | -------------------------------------------------------------------------------- /xayn-ai/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xayn-ai" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | anyhow = "1.0.56" 9 | bincode = "1.3.3" 10 | chrono = { version = "0.4.19", default-features = false } 11 | derivative = "2.2.0" 12 | derive_more = { version = "0.99.17", default-features = false, features = ["deref", "display", "from", "into"] } 13 | displaydoc = "0.2.3" 14 | itertools = "0.10.3" 15 | kpe = { path = "../kpe" } 16 | layer = { path = "../layer" } 17 | lazy_static = "1.4.0" 18 | # to be kept in sync with rubert 19 | ndarray = "=0.15.3" 20 | # TODO: use version 1.0.5 once it is released 21 | obake = { git = "https://github.com/doctorn/obake", rev = "d6bea07e355ca4adf353c1e627f13c8c3286361b" } 22 | rand = "0.8.5" 23 | regex = { version = "1.5.5", features = ["unicode-gencat"] } 24 | rubert = { path = "../rubert" } 25 | serde = { version = "1.0.136", features = ["derive"] } 26 | serde_repr = "0.1.7" 27 | smallvec = "1.8.0" 28 | thiserror = "1.0.30" 29 | uuid = { version = "0.8.2", features = ["serde", "wasm-bindgen", "v4"] } 30 | 31 | # multithreaded feature 32 | rayon = { version = "1.5.1", optional = true } 33 | 34 | [target.'cfg(target_arch = "wasm32")'.dependencies] 35 | js-sys = "0.3.56" 36 | 37 | [dev-dependencies] 38 | csv = "1.1.6" 39 | mockall = "0.11.0" 40 | once_cell = "1.10.0" 41 | paste = "1.0.7" 42 | rstest = "0.12.0" 43 | rstest_reuse = "0.3.0" 44 | serde_json = "1.0.79" 45 | test-utils = { path = "../test-utils" } 46 | 47 | [features] 48 | multithreaded = ["rayon"] 49 | -------------------------------------------------------------------------------- /xayn-ai-ffi-wasm/Cargo.toml: -------------------------------------------------------------------------------- 1 | 
[package] 2 | name = "xayn-ai-ffi-wasm" 3 | version = "0.1.0" 4 | authors = ["Xayn Engineering "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | console_error_panic_hook = "0.1.7" 9 | getrandom = { version = "0.2.5", features = ["js"] } 10 | js-sys = "0.3.56" 11 | serde = { version = "1.0.136", features = ["derive"] } 12 | wasm-bindgen = { version = "=0.2.79", features = ["serde-serialize"] } 13 | xayn-ai-ffi = { path = "../xayn-ai-ffi" } 14 | 15 | # We use the "atomics" `target_feature` to enable parallelism instead of a 16 | # crate feature. This is necessary, as using a "normal" feature will break 17 | # "cargo clippy --all-targets --all-features" and similar. Furthermore we 18 | # always want to use parallelism if our target supports it, which this 19 | # setups represents fairly well. 20 | [target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies] 21 | wasm-bindgen-rayon = "1.0.3" 22 | xayn-ai = { path = "../xayn-ai", features = ["multithreaded"] } 23 | 24 | [target.'cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))'.dependencies] 25 | xayn-ai = { path = "../xayn-ai" } 26 | 27 | [dev-dependencies] 28 | itertools = "0.10.3" 29 | wasm-bindgen-test = "0.3.29" 30 | 31 | [lib] 32 | crate-type = ["cdylib"] 33 | 34 | [package.metadata.wasm-pack.profile.profiling] 35 | # -g is required to keep the original function names 36 | wasm-opt = ['-O', '-g'] 37 | 38 | [package.metadata.wasm-pack.profile.release] 39 | wasm-opt = ['-Oz'] 40 | 41 | [features] 42 | default = ["browser"] 43 | browser = [] 44 | node = [] 45 | -------------------------------------------------------------------------------- /xayn-ai/src/coi/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod config; 2 | pub(crate) mod key_phrase; 3 | mod merge; 4 | pub(crate) mod point; 5 | mod relevance; 6 | mod stats; 7 | mod system; 8 | mod utils; 9 | 10 | #[cfg(test)] 11 | pub(crate) use self::{ 12 | 
system::{compute_coi, update_user_interests, CoiSystemError}, 13 | utils::tests::{create_neg_cois, create_pos_cois}, 14 | }; 15 | pub(crate) use merge::reduce_cois; 16 | pub(crate) use point::find_closest_coi; 17 | pub(crate) use relevance::RelevanceMap; 18 | pub(crate) use stats::compute_coi_decay_factor; 19 | pub(crate) use system::{CoiSystem, NeutralCoiSystem}; 20 | 21 | use derive_more::From; 22 | use displaydoc::Display; 23 | use serde::{Deserialize, Serialize}; 24 | use thiserror::Error; 25 | use uuid::Uuid; 26 | 27 | use crate::embedding::utils::ArcEmbedding; 28 | #[cfg(test)] 29 | use crate::tests::mock_uuid; 30 | 31 | #[repr(transparent)] // needed for FFI 32 | #[derive( 33 | Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, Serialize, Deserialize, From, 34 | )] 35 | pub struct CoiId(Uuid); 36 | 37 | #[cfg(test)] 38 | impl CoiId { 39 | /// Creates a mocked CoI id from a mocked UUID, cf. [`mock_uuid()`]. 40 | pub(crate) const fn mocked(sub_id: usize) -> Self { 41 | Self(mock_uuid(sub_id)) 42 | } 43 | } 44 | 45 | #[derive(Debug, Display, Error)] 46 | pub(crate) enum CoiError { 47 | /// A key phrase is empty 48 | EmptyKeyPhrase, 49 | /// A key phrase has non-finite embedding values: {0:#?} 50 | NonFiniteKeyPhrase(ArcEmbedding), 51 | /// A computed relevance score isn't finite. 
52 | NonFiniteRelevance, 53 | } 54 | -------------------------------------------------------------------------------- /xayn-ai-ffi-c/build.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | env, 3 | fs::read_dir, 4 | path::{Path, PathBuf}, 5 | }; 6 | 7 | use cbindgen::{generate_with_config, Config}; 8 | 9 | // cargo doesn't check directories recursively so we have to do it by hand, also emitting a 10 | // rerun-if line cancels the default rerun for changes in the crate directory 11 | fn cargo_rerun_if_changed(entry: impl AsRef) { 12 | let entry = entry.as_ref(); 13 | if entry.is_dir() { 14 | for entry in read_dir(entry).expect("Failed to read dir.") { 15 | cargo_rerun_if_changed(entry.expect("Failed to read entry.").path()); 16 | } 17 | } else { 18 | println!("cargo:rerun-if-changed={}", entry.display()); 19 | } 20 | } 21 | 22 | fn main() { 23 | let crate_dir = PathBuf::from( 24 | env::var("CARGO_MANIFEST_DIR").expect("Failed to read CARGO_MANIFEST_DIR env."), 25 | ); 26 | 27 | let config_file = crate_dir.join("cbindgen.toml"); 28 | let header_file = crate_dir 29 | .parent() 30 | .unwrap() 31 | .join("bindings") 32 | .join("dart") 33 | .join("ios") 34 | .join("Classes") 35 | .join("XaynAiFfiDart.h"); 36 | 37 | cargo_rerun_if_changed(crate_dir.join("src")); 38 | cargo_rerun_if_changed(crate_dir.join("Cargo.toml")); 39 | cargo_rerun_if_changed(config_file.as_path()); 40 | 41 | let config = Config::from_file(config_file).expect("Failed to read config."); 42 | generate_with_config(crate_dir, config) 43 | .expect("Failed to generate bindings.") 44 | .write_to_file(header_file); 45 | } 46 | -------------------------------------------------------------------------------- /bindings/dart/android/src/main/kotlin/com/xayn/xayn_ai_ffi_dart/XaynAiFfiDartPlugin.kt: -------------------------------------------------------------------------------- 1 | package com.xayn.xayn_ai_ffi_dart 2 | 3 | import androidx.annotation.NonNull 4 
| 5 | import io.flutter.embedding.engine.plugins.FlutterPlugin 6 | import io.flutter.plugin.common.MethodCall 7 | import io.flutter.plugin.common.MethodChannel 8 | import io.flutter.plugin.common.MethodChannel.MethodCallHandler 9 | import io.flutter.plugin.common.MethodChannel.Result 10 | import io.flutter.plugin.common.PluginRegistry.Registrar 11 | 12 | /** XaynAiFfiDartPlugin */ 13 | class XaynAiFfiDartPlugin: FlutterPlugin, MethodCallHandler { 14 | /// The MethodChannel that will the communication between Flutter and native Android 15 | /// 16 | /// This local reference serves to register the plugin with the Flutter Engine and unregister it 17 | /// when the Flutter Engine is detached from the Activity 18 | private lateinit var channel : MethodChannel 19 | 20 | override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { 21 | channel = MethodChannel(flutterPluginBinding.binaryMessenger, "xayn_ai_ffi_dart") 22 | channel.setMethodCallHandler(this) 23 | } 24 | 25 | override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) { 26 | if (call.method == "getPlatformVersion") { 27 | result.success("Android ${android.os.Build.VERSION.RELEASE}") 28 | } else { 29 | result.notImplemented() 30 | } 31 | } 32 | 33 | override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) { 34 | channel.setMethodCallHandler(null) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: PR WIP Label Assigner 2 | 3 | on: 4 | pull_request: 5 | types: [opened, converted_to_draft, ready_for_review] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | pull-requests: write 10 | 11 | jobs: 12 | draft_PR: 13 | if: (github.event.pull_request.draft == true) 14 | runs-on: ubuntu-20.04 15 | name: Add WIP label 16 | steps: 17 | - name: Add WIP Label 18 | uses: 
buildsville/add-remove-label@6008d7bd99d3baeb7c04033584e68f8ec80b198b # v1 19 | with: 20 | token: ${{secrets.GITHUB_TOKEN}} 21 | label: "WIP ⏳" 22 | type: add 23 | 24 | - name: Remove Ready for Review Label 25 | uses: buildsville/add-remove-label@6008d7bd99d3baeb7c04033584e68f8ec80b198b # v1 26 | with: 27 | token: ${{secrets.GITHUB_TOKEN}} 28 | label: "Ready for review ✅" 29 | type: remove 30 | 31 | ready_for_review_PR: 32 | if: (github.event.pull_request.draft == false) 33 | runs-on: ubuntu-20.04 34 | name: Remove label 35 | steps: 36 | - name: Remove WIP Label 37 | uses: buildsville/add-remove-label@6008d7bd99d3baeb7c04033584e68f8ec80b198b # v1 38 | if: contains(github.event.pull_request.labels.*.name, 'WIP ⏳') 39 | with: 40 | token: ${{secrets.GITHUB_TOKEN}} 41 | label: "WIP ⏳" 42 | type: remove 43 | 44 | - name: Add Ready for Review Label 45 | uses: buildsville/add-remove-label@6008d7bd99d3baeb7c04033584e68f8ec80b198b # v1 46 | with: 47 | token: ${{secrets.GITHUB_TOKEN}} 48 | label: "Ready for review ✅" 49 | type: add 50 | -------------------------------------------------------------------------------- /dev-tool/src/call_data/mod.rs: -------------------------------------------------------------------------------- 1 | #![cfg(not(tarpaulin))] 2 | use std::{ 3 | fs::File, 4 | io::{self, BufReader, BufWriter}, 5 | path::Path, 6 | }; 7 | 8 | use anyhow::Error; 9 | use serde::{Deserialize, Serialize}; 10 | use structopt::StructOpt; 11 | 12 | use xayn_ai::{Document, DocumentHistory, RerankMode}; 13 | 14 | use self::{generate::GenerateCallDataCmd, run::RunCallDataCmd}; 15 | use crate::utils::serde_opt_bytes_as_base64; 16 | 17 | mod generate; 18 | mod run; 19 | 20 | /// Commands related to training ListNet (train, convert, evaluate). 
21 | #[derive(StructOpt, Debug)] 22 | pub enum CallDataCmd { 23 | Generate(GenerateCallDataCmd), 24 | Run(RunCallDataCmd), 25 | } 26 | 27 | impl CallDataCmd { 28 | pub fn run(self) -> Result { 29 | use CallDataCmd::*; 30 | match self { 31 | Generate(cmd) => cmd.run(), 32 | Run(cmd) => cmd.run(), 33 | } 34 | } 35 | } 36 | 37 | #[derive(Deserialize, Serialize)] 38 | struct CallData { 39 | rerank_mode: RerankMode, 40 | histories: Vec, 41 | documents: Vec, 42 | #[serde(with = "serde_opt_bytes_as_base64")] 43 | serialized_state: Option>, 44 | } 45 | 46 | impl CallData { 47 | fn load_from_file(path: &Path) -> Result { 48 | let file = File::open(path)?; 49 | let reader = BufReader::new(file); 50 | serde_json::from_reader(reader).map_err(Into::into) 51 | } 52 | 53 | fn save_to_file(&self, path: &Path) -> Result<(), io::Error> { 54 | let file = File::create(path)?; 55 | let writer = BufWriter::new(file); 56 | serde_json::to_writer_pretty(writer, self).map_err(Into::into) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Base.lproj/Main.storyboard: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CFBundleDevelopmentRegion 6 | $(DEVELOPMENT_LANGUAGE) 7 | CFBundleExecutable 8 | $(EXECUTABLE_NAME) 9 | CFBundleIdentifier 10 | $(PRODUCT_BUNDLE_IDENTIFIER) 11 | CFBundleInfoDictionaryVersion 12 | 6.0 13 | CFBundleName 14 | xayn_ai_ffi_dart_example 15 | CFBundlePackageType 16 | APPL 17 | CFBundleShortVersionString 18 | $(FLUTTER_BUILD_NAME) 19 | CFBundleSignature 20 | ???? 
21 | CFBundleVersion 22 | $(FLUTTER_BUILD_NUMBER) 23 | LSRequiresIPhoneOS 24 | 25 | UILaunchStoryboardName 26 | LaunchScreen 27 | UIMainStoryboardFile 28 | Main 29 | UISupportedInterfaceOrientations 30 | 31 | UIInterfaceOrientationPortrait 32 | UIInterfaceOrientationLandscapeLeft 33 | UIInterfaceOrientationLandscapeRight 34 | 35 | UISupportedInterfaceOrientations~ipad 36 | 37 | UIInterfaceOrientationPortrait 38 | UIInterfaceOrientationPortraitUpsideDown 39 | UIInterfaceOrientationLandscapeLeft 40 | UIInterfaceOrientationLandscapeRight 41 | 42 | UIViewControllerBasedStatusBarAppearance 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Dev CI 2 | 3 | on: 4 | push: 5 | branches-ignore: 6 | # we push on release only from staging and release represent 7 | # a snapshot of staging in a given point in time 8 | - 'master' 9 | - 'release' 10 | - '_bors*' 11 | 12 | concurrency: 13 | group: ${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | dev-ci: 21 | uses: ./.github/workflows/ci_reusable_wf.yml 22 | 23 | # this is an helper that needs all the real leafs of the workflow. 
24 | # It makes easier notify_staging_failure because we only need to check 25 | # for this job 26 | ci-ok: 27 | name: ci-ok 28 | needs: dev-ci 29 | runs-on: ubuntu-20.04 30 | steps: 31 | - run: echo "Helper job" 32 | 33 | notify-staging-failure: 34 | name: notify-staging-failure 35 | needs: ci-ok 36 | # always() allows to run even if ci-ok is not successful 37 | # we only want this to run on the staging branch 38 | if: always() && github.ref == 'refs/heads/staging' 39 | runs-on: ubuntu-20.04 40 | steps: 41 | - name: Notify failure 42 | if: needs.ci-ok.result != 'success' 43 | uses: 8398a7/action-slack@a74b761b4089b5d730d813fbedcd2ec5d394f3af # v3.13.0 44 | with: 45 | status: custom 46 | fields: workflow, repo 47 | custom_payload: | 48 | { 49 | attachments: [{ 50 | title: 'Staging CI failed :warning:', 51 | color: 'danger', 52 | text: `CI: ${process.env.AS_WORKFLOW}\nRepository: ${process.env.AS_REPO}`, 53 | }] 54 | } 55 | env: 56 | SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} 57 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/data/history.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library history; 3 | 4 | import 'package:js/js.dart' show anonymous, JS; 5 | 6 | import 'package:xayn_ai_ffi_dart/src/common/data/history.dart' 7 | show 8 | FeedbackToInt, 9 | History, 10 | RelevanceToInt, 11 | DayOfWeekToInt, 12 | UserActionToInt; 13 | 14 | @JS() 15 | @anonymous 16 | class JsHistory { 17 | external factory JsHistory({ 18 | String id, 19 | int relevance, 20 | // ignore: non_constant_identifier_names 21 | int user_feedback, 22 | String session, 23 | // ignore: non_constant_identifier_names 24 | int query_count, 25 | // ignore: non_constant_identifier_names 26 | String query_id, 27 | // ignore: non_constant_identifier_names 28 | String query_words, 29 | int day, 30 | String url, 31 | String domain, 32 | int rank, 33 | // ignore: 
non_constant_identifier_names 34 | int user_action, 35 | }); 36 | } 37 | 38 | extension ToJsHistories on List { 39 | /// Creates JS histories from the current histories. 40 | List toJsHistories() => List.generate( 41 | length, 42 | (i) => JsHistory( 43 | id: this[i].id, 44 | relevance: this[i].relevance.toInt(), 45 | user_feedback: this[i].userFeedback.toInt(), 46 | session: this[i].session, 47 | query_count: this[i].queryCount, 48 | query_id: this[i].queryId, 49 | query_words: this[i].queryWords, 50 | day: this[i].day.toInt(), 51 | url: this[i].url, 52 | domain: this[i].domain, 53 | rank: this[i].rank.toInt(), 54 | user_action: this[i].userAction.toInt(), 55 | ), 56 | growable: false, 57 | ); 58 | } 59 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/result/error.dart: -------------------------------------------------------------------------------- 1 | @JS() 2 | library error; 3 | 4 | import 'package:js/js.dart' show JS; 5 | import 'package:js/js_util.dart' show getProperty, hasProperty; 6 | 7 | import 'package:xayn_ai_ffi_dart/src/common/result/error.dart' 8 | show Code, IntToCode, XaynAiException; 9 | 10 | class XaynAiError extends Error { 11 | final int code; 12 | final String message; 13 | 14 | static bool isXaynAiError(Object o) { 15 | return hasProperty(o, 'code') && hasProperty(o, 'message'); 16 | } 17 | 18 | XaynAiError(this.code, this.message); 19 | } 20 | 21 | extension ObjectToXaynAiError on Object { 22 | XaynAiError toXaynAiError() => XaynAiError( 23 | getProperty(this, 'code') as int, 24 | getProperty(this, 'message') as String, 25 | ); 26 | } 27 | 28 | extension XaynAiErrorToException on XaynAiError { 29 | /// Creates an exception from the error information. 
30 | XaynAiException toException() => XaynAiException(code.toCode(), message); 31 | } 32 | 33 | @JS('WebAssembly.RuntimeError') 34 | // see: https://github.com/lexaknyazev/wasm.dart/blob/a6c93afea4732c140f1f61f144795961c42c8613/wasm_interop/lib/wasm_interop.dart#L718 35 | external Function get runtimeError; 36 | 37 | class RuntimeError extends Error { 38 | final String message; 39 | 40 | RuntimeError(this.message); 41 | } 42 | 43 | extension ObjectToRuntimeError on Object { 44 | RuntimeError toRuntimeError() => 45 | RuntimeError(getProperty(this, 'message') as String); 46 | } 47 | 48 | extension RuntimeErrorToException on RuntimeError { 49 | /// Creates an exception with a [`Code.panic`] from the JS runtime error. 50 | XaynAiException toException() => XaynAiException( 51 | Code.panic, 52 | 'JS WebAssembly RuntimeError: $message', 53 | ); 54 | } 55 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Podfile: -------------------------------------------------------------------------------- 1 | # Uncomment this line to define a global platform for your project 2 | platform :ios, '12.1' 3 | 4 | # CocoaPods analytics sends network stats synchronously affecting flutter build latency. 5 | ENV['COCOAPODS_DISABLE_STATS'] = 'true' 6 | 7 | project 'Runner', { 8 | 'Debug' => :debug, 9 | 'Profile' => :release, 10 | 'Release' => :release, 11 | } 12 | 13 | def flutter_root 14 | generated_xcode_build_settings_path = File.expand_path(File.join('..', 'Flutter', 'Generated.xcconfig'), __FILE__) 15 | unless File.exist?(generated_xcode_build_settings_path) 16 | raise "#{generated_xcode_build_settings_path} must exist. 
If you're running pod install manually, make sure flutter pub get is executed first" 17 | end 18 | 19 | File.foreach(generated_xcode_build_settings_path) do |line| 20 | matches = line.match(/FLUTTER_ROOT\=(.*)/) 21 | return matches[1].strip if matches 22 | end 23 | raise "FLUTTER_ROOT not found in #{generated_xcode_build_settings_path}. Try deleting Generated.xcconfig, then run flutter pub get" 24 | end 25 | 26 | require File.expand_path(File.join('packages', 'flutter_tools', 'bin', 'podhelper'), flutter_root) 27 | 28 | flutter_ios_podfile_setup 29 | 30 | target 'Runner' do 31 | use_frameworks! 32 | use_modular_headers! 33 | 34 | flutter_install_all_ios_pods File.dirname(File.realpath(__FILE__)) 35 | end 36 | 37 | post_install do |installer| 38 | installer.pods_project.targets.each do |target| 39 | flutter_additional_ios_build_settings(target) 40 | target.build_configurations.each do |config| 41 | config.build_settings['ENABLE_BITCODE'] = 'NO' 42 | config.build_settings["EXCLUDED_ARCHS"] = "armv7" 43 | config.build_settings['EXCLUDED_ARCHS[sdk=iphonesimulator*]'] = 'i386 arm64' 44 | end 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /xayn-ai-ffi/build.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | env, 3 | fs::read_dir, 4 | path::{Path, PathBuf}, 5 | }; 6 | 7 | use cbindgen::{generate_with_config, Config}; 8 | 9 | // cargo doesn't check directories recursively so we have to do it by hand, also emitting a 10 | // rerun-if line cancels the default rerun for changes in the crate directory 11 | fn cargo_rerun_if_changed(entry: impl AsRef) { 12 | let entry = entry.as_ref(); 13 | if entry.is_dir() { 14 | for entry in read_dir(entry).expect("Failed to read dir.") { 15 | cargo_rerun_if_changed(entry.expect("Failed to read entry.").path()); 16 | } 17 | } else { 18 | println!("cargo:rerun-if-changed={}", entry.display()); 19 | } 20 | } 21 | 22 | fn main() { 23 | 
let crate_dir = PathBuf::from( 24 | env::var("CARGO_MANIFEST_DIR").expect("Failed to read CARGO_MANIFEST_DIR env."), 25 | ); 26 | 27 | let config_file = crate_dir.join("cbindgen.toml"); 28 | let header_file = crate_dir 29 | .parent() 30 | .unwrap() 31 | .join("bindings") 32 | .join("dart") 33 | .join("ios") 34 | .join("Classes") 35 | .join("XaynAiFfiCommon.h"); 36 | 37 | cargo_rerun_if_changed(crate_dir.join("src")); 38 | cargo_rerun_if_changed(crate_dir.join("Cargo.toml")); 39 | cargo_rerun_if_changed( 40 | crate_dir 41 | .parent() 42 | .unwrap() 43 | .join("xayn-ai") 44 | .join("src") 45 | .join("data"), 46 | ); 47 | cargo_rerun_if_changed(config_file.as_path()); 48 | 49 | let config = Config::from_file(config_file).expect("Failed to read config."); 50 | generate_with_config(crate_dir, config) 51 | .expect("Failed to generate bindings.") 52 | .write_to_file(header_file); 53 | } 54 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/result/fault.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr, Pointer, StructPointer; 2 | 3 | import 'package:ffi/ffi.dart' show Utf8, Utf8Pointer; 4 | 5 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' 6 | show assertEq, assertNeq; 7 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 8 | show CBoxedSlice_CError; 9 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/library.dart' show ffi; 10 | 11 | /// The Xayn Ai faults. 12 | class Faults { 13 | Pointer _faults; 14 | 15 | /// Creates the faults. 16 | /// 17 | /// This constructor never throws an exception. 18 | Faults(this._faults); 19 | 20 | /// Converts the faults to a list. 21 | List toList() { 22 | assert( 23 | _faults == nullptr || _faults.ref.data != nullptr, 24 | 'unexpected faults pointer state', 25 | ); 26 | 27 | return _faults == nullptr 28 | ? 
List.empty() 29 | : List.generate( 30 | _faults.ref.len, 31 | (i) { 32 | if (_faults.ref.data[i].message == nullptr) { 33 | return ''; 34 | } else { 35 | assertNeq(_faults.ref.data[i].message.ref.data, nullptr); 36 | assertEq( 37 | _faults.ref.data[i].message.ref.len, 38 | _faults.ref.data[i].message.ref.data.cast().length + 1, 39 | ); 40 | 41 | return _faults.ref.data[i].message.ref.data 42 | .cast() 43 | .toDartString(); 44 | } 45 | }, 46 | growable: false, 47 | ); 48 | } 49 | 50 | /// Frees the memory. 51 | void free() { 52 | if (_faults != nullptr) { 53 | ffi.faults_drop(_faults); 54 | _faults = nullptr; 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /bindings/dart/example/pubspec.yaml: -------------------------------------------------------------------------------- 1 | name: xayn_ai_ffi_dart_example 2 | description: Demonstrates how to use the xayn_ai_ffi_dart plugin. 3 | 4 | # The following line prevents the package from being accidentally published to 5 | # pub.dev using `pub publish`. This is preferred for private packages. 6 | publish_to: 'none' # Remove this line if you wish to publish to pub.dev 7 | 8 | environment: 9 | sdk: '>=2.14.0 <3.0.0' 10 | flutter: '>=2.5.0 <3.0.0' 11 | 12 | dependencies: 13 | flutter: 14 | sdk: 'flutter' 15 | path_provider: '^2.0.3' 16 | stats: '2.0.0' 17 | crypto: '^3.0.1' 18 | 19 | xayn_ai_ffi_dart: 20 | # When depending on this package from a real application you should use: 21 | # xayn_ai_ffi_dart: ^x.y.z 22 | # See https://dart.dev/tools/pub/dependencies#version-constraints 23 | # The example app is bundled with the plugin so we use a path dependency on 24 | # the parent directory to use the current plugin's version. 25 | path: '../' 26 | 27 | # The following adds the Cupertino Icons font to your application. 28 | # Use with the CupertinoIcons class for iOS style icons. 
29 | cupertino_icons: '^1.0.3' 30 | 31 | dev_dependencies: 32 | flutter_test: 33 | sdk: 'flutter' 34 | pedantic: '^1.11.1' 35 | 36 | # For information on the generic Dart part of this file, see the 37 | # following page: https://dart.dev/tools/pub/pubspec 38 | 39 | # The following section is specific to Flutter. 40 | flutter: 41 | 42 | # The following line ensures that the Material Icons font is 43 | # included with your application, so that you can use the icons in 44 | # the material Icons class. 45 | uses-material-design: true 46 | 47 | assets: 48 | - 'assets/smbert_v0001/' 49 | - 'assets/qambert_v0001/' 50 | - 'assets/ltr_v0000/' 51 | - 'assets/call_data/' 52 | - 'assets/wasm_bindings/' 53 | -------------------------------------------------------------------------------- /.github/workflows/audit.yml: -------------------------------------------------------------------------------- 1 | name: Audit for Security Vulnerabilities (master) 2 | 3 | on: 4 | schedule: 5 | - cron: '00 08 * * mon-fri' 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | audit: 12 | name: Rust Audit 13 | runs-on: ubuntu-20.04 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@a12a3943b4bdde767164f792f33f40b04645d846 # v3.0.0 17 | with: 18 | ref: master 19 | 20 | - name: Run cargo-audit 21 | id: cargo-audit 22 | continue-on-error: true 23 | run: | 24 | OUTCOME=0 25 | echo 'CARGO_AUDIT<> $GITHUB_ENV 26 | (((((cargo audit --deny warnings -q 2>&1; echo $? 
>&3) | sed 's/`/\\`/g' >&4) 3>&1) | (read xs; exit $xs)) 4>&1) >> $GITHUB_ENV || OUTCOME=1 27 | echo 'EOF' >> $GITHUB_ENV 28 | exit $OUTCOME 29 | 30 | - name: Notify on Slack 31 | uses: 8398a7/action-slack@a74b761b4089b5d730d813fbedcd2ec5d394f3af # v3.13.0 32 | if: steps.cargo-audit.outcome != 'success' 33 | with: 34 | status: custom 35 | fields: workflow, repo 36 | custom_payload: | 37 | { 38 | "text": ":package::mag:cargo audit", 39 | "blocks": [ 40 | { 41 | "type": "section", 42 | "text": { 43 | "type": "mrkdwn", 44 | "text": `Workflow: ${process.env.AS_WORKFLOW}\nRepository: ${process.env.AS_REPO}\nRef: \`master\`` 45 | } 46 | }, 47 | ], 48 | "attachments": [{ 49 | "color": "danger", 50 | "text": `\`\`\`${{ env.CARGO_AUDIT }}\`\`\``, 51 | }] 52 | } 53 | env: 54 | SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} 55 | -------------------------------------------------------------------------------- /prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # This script takes as input a directory, the name of the archive and a version. 3 | # It creates an archive in the correct format in the current directory 4 | # and adds the necessary information to verify its content. 5 | # The archive will contain the directory name and the provided version. 6 | # If the option --upload is provided the script will upload the archive to the s3 bucket. 7 | 8 | # ./prepare_data.sh ./bert_models rubert v0 will generate an archive rubert_v0.tgz 9 | # with one directory rubert_v0 that contains the files that are present in ./bert_models. 
10 | 11 | # directory to prepare for upload 12 | DIR_PATH=$1 13 | shift 14 | NAME=$1 15 | shift 16 | VERSION=$1 17 | shift 18 | 19 | while [ $# -gt 0 ]; do 20 | opt="$1" 21 | shift 22 | 23 | case $opt in 24 | --upload) 25 | UPLOAD=true 26 | ;; 27 | esac 28 | done 29 | 30 | DIR_PATH=$(pwd)/$DIR_PATH 31 | DIR_NAME=$(basename $DIR_PATH) 32 | ARCHIVE_BASENAME="${NAME}_$VERSION" 33 | ARCHIVE_NAME="$ARCHIVE_BASENAME.tgz" 34 | URL="s3://xayn-yellow-bert/$NAME/$ARCHIVE_NAME" 35 | CHECKSUM_FILE="sha256sums" 36 | 37 | CURRENT_DIR=$(pwd) 38 | cd $DIR_PATH 39 | 40 | # create a directory with the expected name 41 | TMP_DIR=$(mktemp -d) 42 | cd $TMP_DIR 43 | cp -r $DIR_PATH . 44 | if [ $DIR_NAME != $ARCHIVE_BASENAME ]; then 45 | mv $DIR_NAME $ARCHIVE_BASENAME 46 | fi 47 | TO_ARCHIVE="$TMP_DIR/$ARCHIVE_BASENAME" 48 | 49 | # compute checksum file 50 | cd $ARCHIVE_BASENAME 51 | rm -f $CHECKSUM_FILE 52 | find . -type f -not -iname $CHECKSUM_FILE -not -name ".DS_Store" -print0 | xargs -0 shasum -a 256 > $CHECKSUM_FILE 53 | 54 | cd $CURRENT_DIR 55 | 56 | # prepare archive 57 | tar czf $ARCHIVE_NAME --exclude ".DS_Store" -C $TMP_DIR $ARCHIVE_BASENAME 58 | rm -rf $TMP_DIR 59 | 60 | if [ "$UPLOAD" = true ]; then 61 | s3cmd put --acl-public --guess-mime-type $ARCHIVE_NAME $URL 62 | fi 63 | -------------------------------------------------------------------------------- /bindings/dart/test/mobile/result/fault_test.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show AllocatorAlloc, nullptr, Pointer, StructPointer, Uint8; 2 | 3 | import 'package:ffi/ffi.dart' show malloc, StringUtf8Pointer, Utf8, Utf8Pointer; 4 | import 'package:flutter_test/flutter_test.dart' 5 | show equals, expect, group, isEmpty, test; 6 | 7 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 8 | show CError, CBoxedSlice_CError, CBoxedSlice_u8; 9 | import 'package:xayn_ai_ffi_dart/src/mobile/result/fault.dart' show Faults; 10 | 11 | void main() { 
12 | group('Faults', () { 13 | test('to list', () { 14 | final faults = List.generate(10, (i) => 'fault $i', growable: false); 15 | final faultsPtr = malloc.call(); 16 | faultsPtr.ref.data = malloc.call(faults.length); 17 | faultsPtr.ref.len = faults.length; 18 | faults.asMap().forEach((i, fault) { 19 | faultsPtr.ref.data[i].message = malloc.call(); 20 | faultsPtr.ref.data[i].message.ref.data = 21 | fault.toNativeUtf8().cast(); 22 | faultsPtr.ref.data[i].message.ref.len = 23 | faultsPtr.ref.data[i].message.ref.data.cast().length + 1; 24 | }); 25 | expect(Faults(faultsPtr).toList(), equals(faults)); 26 | for (var i = 0; i < faults.length; i++) { 27 | malloc.free(faultsPtr.ref.data[i].message.ref.data); 28 | malloc.free(faultsPtr.ref.data[i].message); 29 | } 30 | malloc.free(faultsPtr.ref.data); 31 | malloc.free(faultsPtr); 32 | }); 33 | 34 | test('null', () { 35 | final faults = Faults(nullptr); 36 | expect(faults.toList(), isEmpty); 37 | }); 38 | 39 | test('empty', () { 40 | final faultsPtr = malloc.call(); 41 | faultsPtr.ref.data = Pointer.fromAddress(16); 42 | faultsPtr.ref.len = 0; 43 | expect(Faults(faultsPtr).toList(), isEmpty); 44 | malloc.free(faultsPtr); 45 | }); 46 | }); 47 | } 48 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/build.gradle: -------------------------------------------------------------------------------- 1 | def localProperties = new Properties() 2 | def localPropertiesFile = rootProject.file('local.properties') 3 | if (localPropertiesFile.exists()) { 4 | localPropertiesFile.withReader('UTF-8') { reader -> 5 | localProperties.load(reader) 6 | } 7 | } 8 | 9 | def flutterRoot = localProperties.getProperty('flutter.sdk') 10 | if (flutterRoot == null) { 11 | throw new GradleException("Flutter SDK not found. 
Define location with flutter.sdk in the local.properties file.") 12 | } 13 | 14 | def flutterVersionCode = localProperties.getProperty('flutter.versionCode') 15 | if (flutterVersionCode == null) { 16 | flutterVersionCode = '1' 17 | } 18 | 19 | def flutterVersionName = localProperties.getProperty('flutter.versionName') 20 | if (flutterVersionName == null) { 21 | flutterVersionName = '1.0' 22 | } 23 | 24 | apply plugin: 'com.android.application' 25 | apply plugin: 'kotlin-android' 26 | apply from: "$flutterRoot/packages/flutter_tools/gradle/flutter.gradle" 27 | 28 | android { 29 | compileSdkVersion 30 30 | 31 | sourceSets { 32 | main.java.srcDirs += 'src/main/kotlin' 33 | } 34 | 35 | defaultConfig { 36 | // TODO: Specify your own unique Application ID (https://developer.android.com/studio/build/application-id.html). 37 | applicationId "com.xayn.xayn_ai_ffi_dart_example" 38 | minSdkVersion 21 39 | targetSdkVersion 30 40 | versionCode flutterVersionCode.toInteger() 41 | versionName flutterVersionName 42 | } 43 | 44 | buildTypes { 45 | release { 46 | // TODO: Add your own signing config for the release build. 47 | // Signing with the debug keys for now, so `flutter run --release` works. 48 | signingConfig signingConfigs.debug 49 | } 50 | } 51 | } 52 | 53 | flutter { 54 | source '../..' 55 | } 56 | 57 | dependencies { 58 | implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version" 59 | } 60 | -------------------------------------------------------------------------------- /test-utils/src/asset.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | env::var_os, 4 | fs::File, 5 | io::{BufReader, Error, ErrorKind, Result}, 6 | path::{Path, PathBuf}, 7 | }; 8 | 9 | use serde::Deserialize; 10 | use serde_json::from_reader; 11 | 12 | pub const DATA_DIR: &str = "data"; 13 | 14 | /// Resolves the path to the requested data relative to the workspace directory. 
15 | pub fn resolve_path(path: &[impl AsRef]) -> Result { 16 | let manifest = var_os("CARGO_MANIFEST_DIR") 17 | .ok_or_else(|| Error::new(ErrorKind::NotFound, "missing CARGO_MANIFEST_DIR"))?; 18 | let workspace = PathBuf::from(manifest) 19 | .parent() 20 | .ok_or_else(|| Error::new(ErrorKind::NotFound, "missing cargo workspace dir"))? 21 | .to_path_buf(); 22 | 23 | path.iter() 24 | .fold(workspace, |path, component| path.join(component)) 25 | .canonicalize() 26 | } 27 | 28 | #[derive(Deserialize)] 29 | struct Asset { 30 | #[serde(rename(deserialize = "id"))] 31 | name: String, 32 | url_suffix: String, 33 | } 34 | 35 | #[derive(Deserialize)] 36 | struct Assets { 37 | data_assets: Vec, 38 | } 39 | 40 | /// Reads the asset paths from the static assets file. 41 | fn read_assets() -> Result> { 42 | from_reader::<_, Assets>(BufReader::new(File::open(resolve_path(&[ 43 | "assets_manifest.json", 44 | ])?)?)) 45 | .map(|assets| { 46 | assets 47 | .data_assets 48 | .into_iter() 49 | .map(|asset| (asset.name, [DATA_DIR, &asset.url_suffix].iter().collect())) 50 | .collect() 51 | }) 52 | .map_err(|error| Error::new(ErrorKind::InvalidData, error.to_string())) 53 | } 54 | 55 | /// Resolves the path to the requested asset relative to the workspace directory. 56 | pub fn resolve_asset(asset: &str) -> Result { 57 | resolve_path(&[read_assets()? 58 | .get(asset) 59 | .ok_or_else(|| Error::new(ErrorKind::NotFound, format!("missing asset '{}'", asset)))?]) 60 | } 61 | -------------------------------------------------------------------------------- /rubert-tokenizer/src/tokenizer.rs: -------------------------------------------------------------------------------- 1 | use num_traits::{FromPrimitive, Num}; 2 | 3 | use crate::{ 4 | model::Model, 5 | normalizer::Normalizer, 6 | post_tokenizer::{encoding::Encoding, padding::Padding, truncation::Truncation, PostTokenizer}, 7 | pre_tokenizer::PreTokenizer, 8 | }; 9 | 10 | /// A Bert tokenizer. 
11 | /// 12 | /// Can be created via the [`Builder`] and consists of a Bert normalizer, a Bert pre-tokenizer, a 13 | /// Bert word piece model and a Bert post-tokenizer including truncation and padding strategies. 14 | /// 15 | /// [`Builder`]: crate::Builder 16 | #[derive(Debug)] 17 | pub struct Tokenizer { 18 | pub(crate) normalizer: Normalizer, 19 | pub(crate) pre_tokenizer: PreTokenizer, 20 | pub(crate) model: Model, 21 | pub(crate) post_tokenizer: PostTokenizer, 22 | pub(crate) truncation: Truncation, 23 | pub(crate) padding: Padding, 24 | } 25 | 26 | impl Tokenizer { 27 | /// Encodes the sequence. 28 | pub fn encode(&self, sequence: impl AsRef) -> Encoding 29 | where 30 | N: Num + FromPrimitive + Copy, 31 | { 32 | let sequence = self.normalizer.normalize(sequence); 33 | let sequence = self.pre_tokenizer.pre_tokenize(sequence); 34 | let sequence = self.model.tokenize(sequence); 35 | 36 | let encoding = self.truncation.truncate(sequence.into()); 37 | let encoding = self.post_tokenizer.post_tokenize(encoding); 38 | self.padding.pad(encoding) 39 | } 40 | 41 | /// Decodes the encoding with optional cleanup. 42 | pub fn decode(&self, encoding: &Encoding, cleanup: bool) -> String { 43 | encoding.decode( 44 | self.post_tokenizer.cls_token.as_str(), 45 | self.post_tokenizer.sep_token.as_str(), 46 | self.padding.pad_token(), 47 | self.model.unk_token.as_str(), 48 | self.model.prefix.as_str(), 49 | cleanup, 50 | ) 51 | } 52 | 53 | /// Gets the number of entries in the vocabulary. 
54 | pub fn vocab_size(&self) -> usize { 55 | self.model.vocab.len() 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/reranker/data_provider.dart: -------------------------------------------------------------------------------- 1 | import 'dart:typed_data' show Uint8List; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/common/reranker/data_provider.dart' 4 | as common 5 | show 6 | Asset, 7 | AssetType, 8 | baseAssets, 9 | // ignore: unused_shown_name 10 | Checksum, 11 | SetupData, 12 | WebFeature; 13 | 14 | part 'assets.dart'; 15 | 16 | // ignore: unused_element 17 | final _multithreaded = { 18 | common.WebFeature.bulkMemory, 19 | common.WebFeature.mutableGlobals, 20 | common.WebFeature.threads, 21 | }; 22 | 23 | /// Returns a map of all assets required for initializing [`XaynAi`]. 24 | Map getAssets( 25 | {Set features = const {}}) => 26 | {...common.baseAssets, ...getWasmAssets(features)}; 27 | 28 | /// Data that is required to initialize [`XaynAi`]. 29 | class SetupData implements common.SetupData { 30 | late Uint8List smbertVocab; 31 | late Uint8List smbertModel; 32 | late Uint8List qambertVocab; 33 | late Uint8List qambertModel; 34 | late Uint8List ltrModel; 35 | late Uint8List wasmModule; 36 | late String wasmScript; 37 | late String webWorkerScript; 38 | 39 | SetupData(Map assets) { 40 | smbertVocab = assets[common.AssetType.smbertVocab]! as Uint8List; 41 | smbertModel = assets[common.AssetType.smbertModel]! as Uint8List; 42 | qambertVocab = assets[common.AssetType.qambertVocab]! as Uint8List; 43 | qambertModel = assets[common.AssetType.qambertModel]! as Uint8List; 44 | ltrModel = assets[common.AssetType.ltrModel]! as Uint8List; 45 | wasmModule = assets[common.AssetType.wasmModule]! as Uint8List; 46 | 47 | // The wasm script url needs to be relative to the web-worker url. 48 | // https://github.com/xaynetwork/xayn_ai/pull/272 explains it in more detail. 
49 | final wasmScriptRelative = assets[common.AssetType.wasmScript]! as String; 50 | wasmScript = wasmScriptRelative.split('/').last; 51 | 52 | webWorkerScript = assets[common.AssetType.webWorkerScript]! as String; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /kpe/src/tokenizer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod encoding; 2 | pub mod key_phrase; 3 | 4 | use std::io::BufRead; 5 | 6 | use displaydoc::Display; 7 | use rubert_tokenizer::{Builder, BuilderError, Padding, Tokenizer as BertTokenizer, Truncation}; 8 | use thiserror::Error; 9 | 10 | /// A pre-configured Bert tokenizer for key phrase extraction. 11 | #[derive(Debug)] 12 | pub struct Tokenizer { 13 | tokenizer: BertTokenizer, 14 | key_phrase_max_count: Option, 15 | key_phrase_min_score: Option, 16 | } 17 | 18 | /// The potential errors of the tokenizer. 19 | #[derive(Debug, Display, Error, PartialEq)] 20 | pub enum TokenizerError { 21 | /// Failed to build the tokenizer: {0} 22 | Builder(#[from] BuilderError), 23 | } 24 | 25 | impl Tokenizer { 26 | /// Creates a tokenizer from a vocabulary. 27 | /// 28 | /// Can be set to keep accents and to lowercase the sequences. Requires the maximum number of 29 | /// tokens per tokenized sequence, which applies to padding and truncation and includes special 30 | /// tokens as well. 31 | /// 32 | /// Optionally takes an upper count for the number of returned key phrases as well as a lower 33 | /// threshold for the scores of returned key phrases. 34 | pub fn new( 35 | vocab: impl BufRead, 36 | accents: bool, 37 | lowercase: bool, 38 | token_size: usize, 39 | key_phrase_max_count: Option, 40 | key_phrase_min_score: Option, 41 | ) -> Result { 42 | let tokenizer = Builder::new(vocab)? 
43 | .with_normalizer(true, false, accents, lowercase) 44 | .with_model("[UNK]", "##", 100) 45 | .with_post_tokenizer("[CLS]", "[SEP]") 46 | .with_truncation(Truncation::fixed(token_size, 0)) 47 | .with_padding(Padding::fixed(token_size, "[PAD]")) 48 | .build()?; 49 | 50 | Ok(Tokenizer { 51 | tokenizer, 52 | key_phrase_max_count, 53 | key_phrase_min_score, 54 | }) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /xayn-ai/src/embedding/smbert.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use derive_more::{Deref, From}; 4 | use ndarray::arr1; 5 | #[cfg(feature = "multithreaded")] 6 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; 7 | 8 | use crate::{ 9 | data::document_data::{DocumentDataWithDocument, DocumentDataWithSMBert, SMBertComponent}, 10 | error::Error, 11 | reranker::systems::SMBertSystem, 12 | }; 13 | 14 | #[derive(Clone, Deref, From)] 15 | pub struct SMBert(Arc); 16 | 17 | impl SMBertSystem for SMBert { 18 | fn compute_embedding( 19 | &self, 20 | documents: &[DocumentDataWithDocument], 21 | ) -> Result, Error> { 22 | #[cfg(not(feature = "multithreaded"))] 23 | let documents = documents.iter(); 24 | #[cfg(feature = "multithreaded")] 25 | let documents = documents.into_par_iter(); 26 | 27 | documents 28 | .map(|document| { 29 | let embedding = self.run(document.document_content.title.as_str()); 30 | embedding 31 | .map(|embedding| { 32 | DocumentDataWithSMBert::from_document( 33 | document, 34 | SMBertComponent { embedding }, 35 | ) 36 | }) 37 | .map_err(Into::into) 38 | }) 39 | .collect() 40 | } 41 | } 42 | 43 | /// SMBert system to run when SMBert is disabled 44 | #[allow(clippy::upper_case_acronyms)] 45 | pub struct NeutralSMBert; 46 | 47 | impl SMBertSystem for NeutralSMBert { 48 | fn compute_embedding( 49 | &self, 50 | documents: &[DocumentDataWithDocument], 51 | ) -> Result, Error> { 52 | Ok(documents 53 | .iter() 54 | 
.map(|document| { 55 | DocumentDataWithSMBert::from_document( 56 | document, 57 | SMBertComponent { 58 | embedding: arr1(&[]).into(), 59 | }, 60 | ) 61 | }) 62 | .collect()) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /bindings/dart/example/flutter_run_web.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -eu 4 | 5 | # This provides a way to run flutter web which is compatible with using 6 | # threads in wasm. In difference to e.g. `flutter run -d Chrome` this doesn't 7 | # support live reloading and similar features. The reason for this is that 8 | # we need to run a custom server to set the required http headers. 9 | 10 | error() { 11 | echo $1 >&2 12 | exit 1 13 | } 14 | 15 | ROOT="$(dirname $0)" 16 | cd "$ROOT" 17 | BUILD_OUT="./build/web" 18 | CANVASKIT_OUT="$BUILD_OUT/canvaskit" 19 | 20 | # The default canvaskit is hosted on `https://unpkg.com/` but that CDN 21 | # doesn't yet set the Cross-Origin-Resource-Policy header. This makes it 22 | # unusable if our site uses `Cross-Origin-Embedder-Policy: require-corp`. 23 | # But we need to set that header to be able to use `SharedArrayBuffer`. 24 | # 25 | # Issue: https://github.com/mjackson/unpkg/issues/290 26 | flutter build web \ 27 | --dart-define=FLUTTER_WEB_CANVASKIT_URL=./canvaskit/ 28 | 29 | # Fetch JS libs from CDN to avoid problems with COOP/COEP 30 | 31 | # Usage: download_unpkg [(|"") []] 32 | # 33 | # Downloads a file from the unpkg CDN. 34 | # 35 | # Output folders will be created like necessary. 36 | download_unpkg() { 37 | if [ -z "$4" ]; then 38 | OUT="$3" 39 | else 40 | OUT="$4/$3" 41 | fi 42 | mkdir -p "$(dirname "$OUT")" 43 | curl "https://unpkg.com/$1@$2${5:-}/$3" > "$OUT" 44 | } 45 | 46 | # Downloads a file from the canvaskit project which was placed in the `bin/` directory. 
47 | download_canvaskit_file() { 48 | OUT="$CANVASKIT_OUT/$1" 49 | download_unpkg canvaskit-wasm 0.28.1 $1 $CANVASKIT_OUT /bin 50 | } 51 | if [ ! -e "$CANVASKIT_OUT/canvaskit.js" ]; then 52 | echo "Downloading canvaskit" >&2 53 | mkdir -p "$CANVASKIT_OUT" 54 | download_canvaskit_file "canvaskit.js" 55 | download_canvaskit_file "canvaskit.wasm" 56 | download_canvaskit_file "profiling/canvaskit.js" 57 | download_canvaskit_file "profiling/canvaskit.wasm" 58 | fi 59 | 60 | echo "Running Server" 61 | /usr/bin/env python3 "./server.py" "$BUILD_OUT" 62 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/worker/worker.dart: -------------------------------------------------------------------------------- 1 | import 'dart:async' show StreamController; 2 | import 'dart:html' show DedicatedWorkerGlobalScope; 3 | 4 | import 'package:xayn_ai_ffi_dart/src/common/result/error.dart' 5 | show XaynAiException; 6 | import 'package:xayn_ai_ffi_dart/src/web/ffi/ai.dart' as ffi show XaynAi; 7 | import 'package:xayn_ai_ffi_dart/src/web/worker/message/request.dart' 8 | show Request; 9 | import 'package:xayn_ai_ffi_dart/src/web/worker/message/response.dart' 10 | show Response; 11 | import 'package:xayn_ai_ffi_dart/src/web/worker/method_handler.dart' 12 | show MethodHandler, SendResponse; 13 | 14 | void main() async { 15 | try { 16 | await handleRequests(); 17 | } catch (error) { 18 | print('Web worker error while handling a request: $error'); 19 | } 20 | } 21 | 22 | /// A small wrapper around [DedicatedWorkerGlobalScope.onMessage]. 23 | /// [DedicatedWorkerGlobalScope.onMessage] does not seem to behave like a 24 | /// real Dart [Stream]. When used in the `await for` loop in the 25 | /// [handleRequests] function below, it sometimes loses/discards messages. 
26 | class MessageStream { 27 | final _workerOnMessage = DedicatedWorkerGlobalScope.instance.onMessage; 28 | final _incoming = StreamController(); 29 | 30 | Stream get incoming => _incoming.stream; 31 | 32 | MessageStream() { 33 | _workerOnMessage.listen((event) => _incoming.add(event.data as T)); 34 | } 35 | } 36 | 37 | /// A function that handles the incoming [Request]s. 38 | /// [Request]s are processed sequentially in the order in which they arrived. An [XaynAiException] is answered with an exception [Response]; any other error is only printed (no response is sent), and an error thrown by [Request.fromJson] is not caught here and ends the loop. 39 | Future handleRequests() async { 40 | final messageStream = MessageStream(); 41 | const methodHandler = MethodHandler(); 42 | ffi.XaynAi? ai; 43 | 44 | await for (final json in messageStream.incoming) { 45 | final request = Request.fromJson(json); 46 | try { 47 | ai = await methodHandler[request.method](ai, request); 48 | } on XaynAiException catch (exception) { 49 | request.sender.sendResponse(Response.fromException(exception)); 50 | } catch (error) { 51 | print( 52 | 'Web worker error while handling the method call `${request.method}`: $error'); 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/result/outcomes.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr, Pointer, StructPointer; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/common/result/outcomes.dart' 4 | show RerankingOutcomes; 5 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 6 | show CRerankingOutcomes; 7 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/library.dart' show ffi; 8 | import 'package:xayn_ai_ffi_dart/src/mobile/result/slice.dart' 9 | show BoxedSliceF32List, BoxedSliceU16List; 10 | 11 | class RerankingOutcomesBuilder { 12 | final Pointer _cOutcomes; 13 | bool _freed = false; 14 | 15 | RerankingOutcomesBuilder(this._cOutcomes); 16 | 17 | /// Build the `RerankingOutcomes` based on the passed in pointer.
18 | /// 19 | /// This should be called in a `try {} finally {}` block where in 20 | /// the `finally` block `builder.free()` is called. 21 | /// 22 | /// If this is called and the pointer with which this instance was 23 | /// created was a `nullptr` a [StateError] is thrown. As you *should* only 24 | /// call this after checking the error codes this should not happen in 25 | /// practice. 26 | RerankingOutcomes build() { 27 | if (_freed) { 28 | throw StateError('CRerankingOutcomes have already been freed'); 29 | } else if (_cOutcomes == nullptr) { 30 | throw StateError( 31 | 'Error codes should be checked before building outcomes from C.'); 32 | } else { 33 | final finalRanking = _cOutcomes.ref.final_ranking.toList(); 34 | if (finalRanking == null) { 35 | throw ArgumentError('Final rankings outcome was null.'); 36 | } 37 | final contextScores = _cOutcomes.ref.context_scores.toList(); 38 | final qaMBertSimilarities = _cOutcomes.ref.qambert_similarities.toList(); 39 | 40 | return RerankingOutcomes.fromParts( 41 | finalRanking, 42 | contextScores, 43 | qaMBertSimilarities, 44 | ); 45 | } 46 | } 47 | 48 | /// Free the wrapped C struct by calling the Rust FFI; further calls are no-ops. 49 | void free() { 50 | if (!_freed) { 51 | _freed = true; 52 | // can always be called with a nullptr, but in test we don't want to call ffi 53 | if (_cOutcomes != nullptr) ffi.reranking_outcomes_drop(_cOutcomes); 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/data/document.dart: -------------------------------------------------------------------------------- 1 | import 'package:json_annotation/json_annotation.dart' show JsonSerializable; 2 | import 'package:meta/meta.dart' show immutable; 3 | 4 | part 'document.g.dart'; 5 | 6 | /// The document. 7 | @JsonSerializable() 8 | @immutable 9 | class Document { 10 | /// Unique identifier of the document. 11 | final String id; 12 | 13 | /// Text title of the document.
14 | final String title; 15 | 16 | /// Text snippet of the document. 17 | final String snippet; 18 | 19 | /// Position of the document from the source. 20 | final int rank; 21 | 22 | /// Session of the document. 23 | final String session; 24 | 25 | /// Query count within session. 26 | final int queryCount; 27 | 28 | /// Query identifier of the document. 29 | final String queryId; 30 | 31 | /// Query of the document. 32 | final String queryWords; 33 | 34 | /// URL of the document. 35 | final String url; 36 | 37 | /// Domain of the document. 38 | final String domain; 39 | 40 | /// Creates the document; throws an [ArgumentError] if [id], [session], [queryId], [queryWords], [url] or [domain] is empty, or if [rank] or [queryCount] is negative. 41 | Document({ 42 | required this.id, 43 | required this.title, 44 | required this.snippet, 45 | required this.rank, 46 | required this.session, 47 | required this.queryCount, 48 | required this.queryId, 49 | required this.queryWords, 50 | required this.url, 51 | required this.domain, 52 | }) { 53 | if (id.isEmpty) { 54 | throw ArgumentError('empty document id'); 55 | } 56 | if (rank.isNegative) { 57 | throw ArgumentError('negative document rank'); 58 | } 59 | if (session.isEmpty) { 60 | throw ArgumentError('empty session id'); 61 | } 62 | if (queryCount < 0) { 63 | throw ArgumentError('negative query count'); 64 | } 65 | if (queryId.isEmpty) { 66 | throw ArgumentError('empty query id'); 67 | } 68 | if (queryWords.isEmpty) { 69 | throw ArgumentError('empty query words'); 70 | } 71 | if (url.isEmpty) { 72 | throw ArgumentError('empty document url'); 73 | } 74 | if (domain.isEmpty) { 75 | throw ArgumentError('empty document domain'); 76 | } 77 | } 78 | 79 | factory Document.fromJson(Map json) => 80 | _$DocumentFromJson(json); 81 | 82 | Map toJson() => _$DocumentToJson(this); 83 | } 84 | -------------------------------------------------------------------------------- /.ci/build-asset-artifacts/action.yml: -------------------------------------------------------------------------------- 1 | name: 'build asset artifacts' 2 | description: 'Builds asset artifacts' 3 | inputs:
4 | dart-ws: 5 | description: 'The Dart workspace' 6 | required: true 7 | wasm-out-dir-path: 8 | description: 'The relative path (wrt the repository root) of the WASM artifacts.' 9 | required: false 10 | outputs: 11 | dart-base-assets: 12 | description: "The relative path (wrt the repository root) of the Dart base assets manifest." 13 | value: ${{ steps.artifact-paths.outputs.dart-base-assets }} 14 | dart-web-assets: 15 | description: "The relative path (wrt the repository root) of the Dart web assets manifest." 16 | value: ${{ steps.artifact-paths.outputs.dart-web-assets }} 17 | json-metadata: 18 | description: "The relative path (wrt the repository root) of the JSON manifest." 19 | value: ${{ steps.artifact-paths.outputs.json-metadata }} 20 | chunks-dir: 21 | description: "The relative path (wrt the repository root) of the directory where the chunks of the data assets are stored." 22 | value: ${{ steps.artifact-paths.outputs.chunks-dir }} 23 | runs: 24 | using: "composite" 25 | steps: 26 | - id: artifact-paths 27 | shell: bash 28 | run: | 29 | echo "::set-output name=dart-base-assets::${{ inputs.dart-ws }}/lib/src/common/reranker/assets.dart" 30 | echo "::set-output name=dart-web-assets::${{ inputs.dart-ws }}/lib/src/web/reranker/assets.dart" 31 | echo "::set-output name=json-metadata::out/assets_metadata.json" 32 | echo "::set-output name=chunks-dir::data/chunks" 33 | 34 | - shell: bash 35 | run: | 36 | if [ "${{ runner.os }}" = "macOS" ]; then 37 | # installs gsplit 38 | brew install coreutils 39 | fi 40 | 41 | bash generate_assets_metadata.sh assets_manifest.json ${{ inputs.wasm-out-dir-path }} 42 | 43 | echo "::group::Dart base assets manifest" 44 | cat "${{ steps.artifact-paths.outputs.dart-base-assets }}" 45 | echo "::endgroup::" 46 | echo "::group::Dart web assets manifest" 47 | cat "${{ steps.artifact-paths.outputs.dart-web-assets }}" 48 | echo "::endgroup::" 49 | echo "::group::JSON metadata" 50 | cat "${{
steps.artifact-paths.outputs.json-metadata }}" 51 | echo "::endgroup::" 52 | -------------------------------------------------------------------------------- /bindings/dart/example/android/app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 6 | 13 | 17 | 21 | 26 | 30 | 31 | 32 | 33 | 34 | 35 | 37 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/reranker/bytes.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show nullptr, Pointer, StructPointer, Uint8Pointer; 2 | import 'dart:typed_data' show Uint8List; 3 | 4 | import 'package:flutter/foundation.dart' show listEquals; 5 | 6 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' show assertNeq; 7 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 8 | show CBoxedSlice_u8; 9 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/library.dart' show ffi; 10 | import 'package:xayn_ai_ffi_dart/src/mobile/result/error.dart' show XaynAiError; 11 | 12 | /// A bytes buffer. 13 | class Bytes { 14 | late Pointer _bytes; 15 | 16 | /// Creates the bytes buffer from a pointer. 17 | /// 18 | /// This constructor never throws an exception. 19 | Bytes(this._bytes); 20 | 21 | /// Creates the bytes buffer from a list. 22 | /// 23 | /// This constructor throws the exception built by [XaynAiError.toException] if `ffi.bytes_new` reports an error. 24 | Bytes.fromList(Uint8List bytes) { 25 | final error = XaynAiError(); 26 | 27 | _bytes = ffi.bytes_new(bytes.length, error.ptr); 28 | try { 29 | if (error.isError()) { 30 | throw error.toException(); 31 | } 32 | assertNeq(_bytes, nullptr); 33 | assert(listEquals( 34 | _bytes.ref.data.asTypedList(_bytes.ref.len), 35 | Uint8List(bytes.length), 36 | )); 37 | } finally { 38 | error.free(); 39 | } 40 | 41 | bytes.asMap().forEach((i, byte) { 42 | _bytes.ref.data[i] = byte; 43 | }); 44 | } 45 | 46 | /// Gets the pointer.
47 | Pointer get ptr => _bytes; 48 | 49 | /// Converts the buffer to a list. 50 | Uint8List toList() { 51 | assert( 52 | _bytes == nullptr || _bytes.ref.data != nullptr, 53 | 'unexpected bytes pointer state', 54 | ); 55 | 56 | if (_bytes == nullptr) { 57 | return Uint8List(0); 58 | } else { 59 | // Copy into Dart-managed memory so the list stays valid after [free]. 60 | return Uint8List.fromList(_bytes.ref.data.asTypedList(_bytes.ref.len)); 61 | } 62 | } 63 | 64 | /// Frees the memory; calling it again is a no-op. 65 | void free() { 66 | assert( 67 | _bytes == nullptr || _bytes.ref.data != nullptr, 68 | 'unexpected bytes pointer state', 69 | ); 70 | 71 | if (_bytes != nullptr) { 72 | ffi.bytes_drop(_bytes); 73 | _bytes = nullptr; 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/result/outcomes.dart: -------------------------------------------------------------------------------- 1 | import 'package:json_annotation/json_annotation.dart' show JsonSerializable; 2 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' show ToJson; 3 | 4 | part 'outcomes.g.dart'; 5 | 6 | /// Type containing all reranking outcomes. 7 | /// 8 | /// Some of the outcomes can be empty if they 9 | /// had not been calculated. This can happen due to 10 | /// configurations when running rerank or if a 11 | /// non-panic error happened during execution and 12 | /// as such only partial results are available. 13 | /// 14 | /// Note that `finalRanks` is empty if and only if there 15 | /// had been no input documents. 16 | @JsonSerializable() 17 | class RerankingOutcomes implements ToJson { 18 | /// The final ranking in order of the input documents. 19 | /// 20 | /// Should only be empty if there were no input documents. 21 | final List finalRanks; 22 | 23 | /// The QA-mBERT similarities in order of the input documents. 24 | /// 25 | /// Can be empty if not calculated. 26 | final List?
qaMBertSimilarities; 27 | 28 | /// The context scores for all documents in order of the input documents. 29 | /// 30 | /// Can be empty if not calculated. 31 | final List? contextScores; 32 | 33 | RerankingOutcomes( 34 | this.finalRanks, this.qaMBertSimilarities, this.contextScores); 35 | 36 | /// Create a new instance from its parts. 37 | /// 38 | /// Besides for testing this should ONLY be used by the `mobile/` and `web/` 39 | /// FFI binding. 40 | /// Throws an [ArgumentError] if a non-null outcome's length differs from the length of [finalRanks]. Note that the parameter order differs from the unnamed constructor. 41 | RerankingOutcomes.fromParts( 42 | this.finalRanks, 43 | this.contextScores, 44 | this.qaMBertSimilarities, 45 | ) { 46 | checkOutcomeLength(contextScores, finalRanks, 'contextScores'); 47 | checkOutcomeLength(qaMBertSimilarities, finalRanks, 'qaMBertSimilarities'); 48 | } 49 | 50 | factory RerankingOutcomes.fromJson(Map json) => 51 | _$RerankingOutcomesFromJson(json); 52 | 53 | @override 54 | Map toJson() => _$RerankingOutcomesToJson(this); 55 | } 56 | /// Throws an [ArgumentError] if [outcome] is non-null and its length differs from [base]'s. NOTE(review): the message mentions "0", but a non-null empty [outcome] is rejected unless [base] is empty as well; confirm intent. 57 | void checkOutcomeLength(List? outcome, List base, String name) { 58 | if (outcome == null) { 59 | return; 60 | } 61 | final outLen = outcome.length; 62 | final baseLen = base.length; 63 | if (outLen != baseLen) { 64 | throw ArgumentError( 65 | 'Invalid Outcome length for $name: len=$outLen expected 0 or $baseLen'); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/data/document.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show AllocatorAlloc, nullptr, Pointer, StructPointer, Uint8; 2 | 3 | import 'package:ffi/ffi.dart' show malloc, StringUtf8Pointer; 4 | 5 | import 'package:xayn_ai_ffi_dart/src/common/data/document.dart' show Document; 6 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 7 | show CDocument, CDocuments; 8 | 9 | /// The raw documents. 10 | class Documents { 11 | late Pointer _docs; 12 | 13 | /// Creates the documents. 14 | /// 15 | /// This constructor never throws an exception.
16 | Documents(List documents) { 17 | _docs = malloc.call(); 18 | _docs.ref.len = documents.length; 19 | if (documents.isEmpty) { 20 | _docs.ref.data = nullptr; 21 | } else { 22 | _docs.ref.data = malloc.call(_docs.ref.len); 23 | documents.asMap().forEach((i, document) { 24 | var cdoc = _docs.ref.data[i]; 25 | cdoc.id = document.id.toNativeUtf8().cast(); 26 | cdoc.title = document.title.toNativeUtf8().cast(); 27 | cdoc.snippet = document.snippet.toNativeUtf8().cast(); 28 | cdoc.rank = document.rank; 29 | cdoc.session = document.session.toNativeUtf8().cast(); 30 | cdoc.query_count = document.queryCount; 31 | cdoc.query_id = document.queryId.toNativeUtf8().cast(); 32 | cdoc.query_words = document.queryWords.toNativeUtf8().cast(); 33 | cdoc.url = document.url.toNativeUtf8().cast(); 34 | cdoc.domain = document.domain.toNativeUtf8().cast(); 35 | }); 36 | } 37 | } 38 | 39 | /// Gets the pointer. 40 | Pointer get ptr => _docs; 41 | 42 | /// Frees every native string, the document array and the container struct allocated by the constructor; calling it again is a no-op. 43 | void free() { 44 | if (_docs != nullptr) { 45 | if (_docs.ref.data != nullptr) { 46 | for (var i = 0; i < _docs.ref.len; i++) { 47 | var cdoc = _docs.ref.data[i]; 48 | malloc.free(cdoc.id); 49 | malloc.free(cdoc.title); 50 | malloc.free(cdoc.snippet); 51 | malloc.free(cdoc.session); 52 | malloc.free(cdoc.query_id); 53 | malloc.free(cdoc.query_words); 54 | malloc.free(cdoc.url); 55 | malloc.free(cdoc.domain); 56 | } 57 | malloc.free(_docs.ref.data); 58 | } 59 | malloc.free(_docs); 60 | _docs = nullptr; 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /bindings/dart/example/lib/data_provider/web.dart: -------------------------------------------------------------------------------- 1 | import 'dart:html' show window; 2 | import 'dart:typed_data' show ByteBuffer, BytesBuilder, Uint8List; 3 | 4 | import 'package:xayn_ai_ffi_dart/package.dart' 5 | show AssetType, getAssets, SetupData, WebFeature; 6 | 7 | import
'package:xayn_ai_ffi_dart_example/data_provider/data_provider.dart' 8 | show joinPaths; 9 | 10 | const _baseAssetUrl = './assets/assets'; 11 | 12 | /// Prepares and returns the data that is needed to init [`XaynAi`]. 13 | Future getInputData() async { 14 | final fetched = {}; 15 | final features = {}; 16 | 17 | // uncomment the following section to load the multithreaded version 18 | // final features = { 19 | // WebFeature.bulkMemory, 20 | // WebFeature.mutableGlobals, 21 | // WebFeature.threads, 22 | // }; 23 | 24 | for (var asset in getAssets(features: features).entries) { 25 | if (asset.value.fragments.isEmpty) { 26 | final path = joinPaths([_baseAssetUrl, asset.value.urlSuffix]); 27 | // We also load the wasm/worker script here in order to check its integrity/checksum. 28 | // The browser keeps it in cache so `injectWasmScript` does not download it again. 29 | final data = await _fetchAsset(path, asset.value.checksum.checksumSri); 30 | 31 | if (asset.key == AssetType.webWorkerScript || 32 | asset.key == AssetType.wasmScript) { 33 | fetched.putIfAbsent(asset.key, () => path); 34 | } else { 35 | fetched.putIfAbsent(asset.key, () => data); 36 | } 37 | } else { 38 | final builder = BytesBuilder(copy: false); 39 | for (var fragment in asset.value.fragments) { 40 | final path = joinPaths([_baseAssetUrl, fragment.urlSuffix]); 41 | final part = await _fetchAsset(path, fragment.checksum.checksumSri); 42 | builder.add(part); 43 | } 44 | fetched.putIfAbsent(asset.key, () => builder.takeBytes()); 45 | } 46 | } 47 | 48 | return SetupData(fetched); 49 | } 50 | /// Fetches [url] with the Subresource Integrity [checksum] and returns the body bytes; any failure is turned into an error future that names the [url]. 51 | Future _fetchAsset(String url, String checksum) async { 52 | try { 53 | final dynamic response = 54 | await window.fetch(url, {'integrity': checksum}); 55 | final arrayBuffer = await response.arrayBuffer() as ByteBuffer; 56 | return Uint8List.view(arrayBuffer); 57 | } catch (e) { 58 | return Future.error('error loading asset: $url, error: $e'); 59 | } 60 | } 61 |
-------------------------------------------------------------------------------- /bindings/dart/test/mobile/utils.dart: -------------------------------------------------------------------------------- 1 | import 'package:flutter_test/flutter_test.dart' 2 | show Matcher, predicate, throwsA; 3 | 4 | import 'package:xayn_ai_ffi_dart/src/common/data/document.dart' show Document; 5 | import 'package:xayn_ai_ffi_dart/src/common/data/history.dart' 6 | show UserFeedback, History, Relevance, DayOfWeek, UserAction; 7 | import 'package:xayn_ai_ffi_dart/src/common/result/error.dart' 8 | show Code, XaynAiException; 9 | import 'package:xayn_ai_ffi_dart/src/mobile/reranker/data_provider.dart' 10 | show getAssets, SetupData; 11 | 12 | SetupData mkSetupData() { 13 | return SetupData({ 14 | for (final asset in getAssets().entries) 15 | asset.key: '../../data/' + asset.value.urlSuffix 16 | }); 17 | } 18 | 19 | Document mkTestDoc(String id, String title, int rank) => Document( 20 | id: id, 21 | title: title, 22 | snippet: 'snippet of $title', 23 | rank: rank, 24 | session: 'fcb6a685-eb92-4d36-8686-000000000000', 25 | queryCount: 1, 26 | queryId: 'fcb6a685-eb92-4d36-8686-000000000000', 27 | queryWords: 'query words', 28 | url: 'url', 29 | domain: 'domain', 30 | ); 31 | 32 | History mkTestHist(String id, Relevance relevance, UserFeedback feedback) => 33 | History( 34 | id: id, 35 | relevance: relevance, 36 | userFeedback: feedback, 37 | session: 'fcb6a685-eb92-4d36-8686-000000000000', 38 | queryCount: 1, 39 | queryId: 'fcb6a685-eb92-4d36-8686-000000000000', 40 | queryWords: 'query words', 41 | day: DayOfWeek.mon, 42 | url: 'url', 43 | domain: 'domain', 44 | rank: 0, 45 | userAction: UserAction.miss, 46 | ); 47 | 48 | final histories = [ 49 | mkTestHist('fcb6a685-eb92-4d36-8686-000000000000', Relevance.low, 50 | UserFeedback.irrelevant), 51 | mkTestHist('fcb6a685-eb92-4d36-8686-000000000001', Relevance.high, 52 | UserFeedback.relevant), 53 | ]; 54 | 55 | final documents = [ 56 |
mkTestDoc('fcb6a685-eb92-4d36-8686-000000000000', 'abc', 0), 57 | mkTestDoc('fcb6a685-eb92-4d36-8686-000000000001', 'def', 1), 58 | mkTestDoc('fcb6a685-eb92-4d36-8686-000000000002', 'ghi', 2), 59 | ]; 60 | /// Matches a thrown [XaynAiException] with the given [code] and a non-empty string representation. 61 | Matcher throwsXaynAiException(Code code) => throwsA( 62 | predicate( 63 | (exception) => 64 | exception is XaynAiException && 65 | exception.code == code && 66 | exception.toString().isNotEmpty, 67 | ), 68 | ); 69 | -------------------------------------------------------------------------------- /bindings/dart/example/ios/Runner/Base.lproj/LaunchScreen.storyboard: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /xayn-ai/src/ltr/list_net/tests/mod.rs: -------------------------------------------------------------------------------- 1 | use std::f32::consts::SQRT_2; 2 | 3 | use once_cell::sync::Lazy; 4 | 5 | use layer::{activation::ActivationFunction, dense::Dense}; 6 | use test_utils::{assert_approx_eq, ltr::model}; 7 | 8 | use super::{model::ListNet, *}; 9 | 10 | mod inference; 11 | mod training; 12 | 13 | static LIST_NET: Lazy = 14 | Lazy::new(|| ListNet::deserialize_from_file(model().unwrap()).unwrap()); 15 | 16 | #[test] 17 | fn test_chunk_size_is_valid() { 18 | assert_eq!(ListNet::CHUNK_SIZE * 2, ListNet::INPUT_NR_DOCUMENTS); 19 | } 20 | 21 | #[test] 22 | fn test_random_weights_initialization() { 23 | let ListNet { 24 | dense1, 25 | dense2, 26 | scores, 27 | prob_dist, 28 | } = ListNet::new_with_random_weights(); 29 | // Every fresh layer must have zero biases and weights within [-2 * sqrt(2 / n), 2 * sqrt(2 / n)], where n = weights.shape()[0]. 30 | test_layer(&dense1); 31 | test_layer(&dense2); 32 | test_layer(&scores); 33 | test_layer(&prob_dist); 34 | 35 | fn test_layer(layer: &Dense>) { 36 | for b in layer.bias().iter() { 37 | assert_approx_eq!(f32, b, 0.0, ulps = 9) 38 | } 39 | let
weights = layer.weights(); 40 | let std = SQRT_2 / (weights.shape()[0] as f32).sqrt(); 41 | let limit = 2. * std; 42 | for &w in weights.iter() { 43 | assert!( 44 | -limit <= w && w <= limit, 45 | "out of bound weight: {} <= {} <= {}", 46 | -limit, 47 | w, 48 | limit 49 | ); 50 | } 51 | } 52 | } 53 | 54 | #[test] 55 | fn test_serialize_deserialize_list_net() { 56 | let list_net = ListNet::new_with_random_weights(); 57 | let mut buffer = Vec::new(); 58 | list_net.clone().serialize_into(&mut buffer).unwrap(); 59 | let list_net2 = ListNet::deserialize_from(&*buffer).unwrap(); 60 | assert_approx_eq!(f32, list_net.dense1.weights(), list_net2.dense1.weights()); 61 | assert_approx_eq!(f32, list_net.dense1.bias(), list_net2.dense1.bias()); 62 | assert_approx_eq!(f32, list_net.dense2.weights(), list_net2.dense2.weights()); 63 | assert_approx_eq!(f32, list_net.dense2.bias(), list_net2.dense2.bias()); 64 | assert_approx_eq!(f32, list_net.scores.weights(), list_net2.scores.weights()); 65 | assert_approx_eq!(f32, list_net.scores.bias(), list_net2.scores.bias()); 66 | assert_approx_eq!( 67 | f32, 68 | list_net.prob_dist.weights(), 69 | list_net2.prob_dist.weights() 70 | ); 71 | assert_approx_eq!(f32, list_net.prob_dist.bias(), list_net2.prob_dist.bias()); 72 | } 73 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/data/history.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show AllocatorAlloc, nullptr, Pointer, StructPointer, Uint8; 2 | 3 | import 'package:ffi/ffi.dart' show malloc, StringUtf8Pointer; 4 | 5 | import 'package:xayn_ai_ffi_dart/src/common/data/history.dart' 6 | show 7 | FeedbackToInt, 8 | History, 9 | RelevanceToInt, 10 | DayOfWeekToInt, 11 | UserActionToInt; 12 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' 13 | show CHistories, CHistory; 14 | 15 | /// The raw document histories.
16 | class Histories { 17 | late Pointer _hists; 18 | 19 | /// Creates the document histories. 20 | /// 21 | /// This constructor never throws an exception. 22 | Histories(List histories) { 23 | _hists = malloc.call(); 24 | _hists.ref.len = histories.length; 25 | if (histories.isEmpty) { 26 | _hists.ref.data = nullptr; 27 | } else { 28 | _hists.ref.data = malloc.call(_hists.ref.len); 29 | histories.asMap().forEach((i, history) { 30 | var chist = _hists.ref.data[i]; 31 | chist.id = history.id.toNativeUtf8().cast(); 32 | chist.relevance = history.relevance.toInt(); 33 | chist.user_feedback = history.userFeedback.toInt(); 34 | chist.session = history.session.toNativeUtf8().cast(); 35 | chist.query_count = history.queryCount; 36 | chist.query_id = history.queryId.toNativeUtf8().cast(); 37 | chist.query_words = history.queryWords.toNativeUtf8().cast(); 38 | chist.day = history.day.toInt(); 39 | chist.url = history.url.toNativeUtf8().cast(); 40 | chist.domain = history.domain.toNativeUtf8().cast(); 41 | chist.rank = history.rank; 42 | chist.user_action = history.userAction.toInt(); 43 | }); 44 | } 45 | } 46 | 47 | /// Gets the pointer. 48 | Pointer get ptr => _hists; 49 | 50 | /// Frees every native string, the history array and the container struct allocated by the constructor; calling it again is a no-op. 51 | void free() { 52 | if (_hists != nullptr) { 53 | if (_hists.ref.data != nullptr) { 54 | for (var i = 0; i < _hists.ref.len; i++) { 55 | var chist = _hists.ref.data[i]; 56 | malloc.free(chist.id); 57 | malloc.free(chist.session); 58 | malloc.free(chist.query_id); 59 | malloc.free(chist.query_words); 60 | malloc.free(chist.url); 61 | malloc.free(chist.domain); 62 | } 63 | malloc.free(_hists.ref.data); 64 | } 65 | malloc.free(_hists); 66 | _hists = nullptr; 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /kpe/extract_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Call like `python3 extract_params.py model_in_path [params_out_dir]`.
3 | 4 | See `xayn_ai/extract_listnet_parameters.py` for constraints. 5 | """ 6 | 7 | from pathlib import Path 8 | import sys 9 | from torch import load, Tensor 10 | from typing import BinaryIO, Dict 11 | 12 | CNN: Dict[str, str] = { 13 | "cnn2gram.cnn_list.0.weight": "conv_1/weights", 14 | "cnn2gram.cnn_list.0.bias": "conv_1/bias", 15 | "cnn2gram.cnn_list.1.weight": "conv_2/weights", 16 | "cnn2gram.cnn_list.1.bias": "conv_2/bias", 17 | "cnn2gram.cnn_list.2.weight": "conv_3/weights", 18 | "cnn2gram.cnn_list.2.bias": "conv_3/bias", 19 | "cnn2gram.cnn_list.3.weight": "conv_4/weights", 20 | "cnn2gram.cnn_list.3.bias": "conv_4/bias", 21 | "cnn2gram.cnn_list.4.weight": "conv_5/weights", 22 | "cnn2gram.cnn_list.4.bias": "conv_5/bias", 23 | } 24 | 25 | CLASSIFIER: Dict[str, str] = { 26 | "classifier.weight": "dense/weights", 27 | "classifier.bias": "dense/bias", 28 | } 29 | 30 | BYTE_ORDER: str = "little" 31 | 32 | def write_integer(file: BinaryIO, integer: int): 33 | file.write(integer.to_bytes(8, BYTE_ORDER)) 34 | 35 | def write_string(file: BinaryIO, string: str): 36 | bytes = string.encode("utf-8") 37 | write_integer(file, len(bytes)) 38 | file.write(bytes) 39 | 40 | def write_tensor(file: BinaryIO, tensor: Tensor, transpose: bool): 41 | array = tensor.numpy() 42 | if transpose: 43 | array = array.transpose() 44 | array = array.astype( 45 | array.dtype.newbyteorder(BYTE_ORDER), 46 | order="C", 47 | casting="equiv", 48 | copy=False, 49 | ) 50 | 51 | write_integer(file, len(array.shape)) 52 | for value in array.shape: 53 | write_integer(file, value) 54 | 55 | write_integer(file, array.size) 56 | file.write(array.tobytes(order="C")) 57 | 58 | def write_layers(file: BinaryIO, layers: Dict[str, str], state: Dict[str, Tensor], transpose: bool): 59 | write_integer(file, len(layers)) 60 | for name, tensor in state.items(): 61 | if name in layers: 62 | write_string(file, layers[name]) 63 | write_tensor(file, tensor, transpose) 64 | 65 | if __name__ == "__main__": 66 | model = 
Path(sys.argv[1]).resolve() 67 | if len(sys.argv) > 2: 68 | params = Path(sys.argv[2]).resolve() 69 | else: 70 | params = model.parent 71 | 72 | state = load(model)["state_dict"] 73 | with open(params.joinpath("cnn.binparams"), "wb") as cnn: 74 | write_layers(cnn, CNN, state, False) 75 | with open(params.joinpath("classifier.binparams"), "wb") as classifier: 76 | write_layers(classifier, CLASSIFIER, state, True) 77 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/reranker/ai.dart: -------------------------------------------------------------------------------- 1 | import 'dart:typed_data' show Uint8List; 2 | 3 | import 'package:xayn_ai_ffi_dart/src/common/data/document.dart' show Document; 4 | import 'package:xayn_ai_ffi_dart/src/common/data/history.dart' show History; 5 | import 'package:xayn_ai_ffi_dart/src/common/reranker/analytics.dart' 6 | show Analytics; 7 | import 'package:xayn_ai_ffi_dart/src/common/reranker/data_provider.dart' 8 | show SetupData; 9 | import 'package:xayn_ai_ffi_dart/src/common/reranker/mode.dart' show RerankMode; 10 | import 'package:xayn_ai_ffi_dart/src/common/result/outcomes.dart' 11 | show RerankingOutcomes; 12 | 13 | /// The Xayn AI. 14 | class XaynAi { 15 | /// Creates and initializes the Xayn AI from a given state. 16 | /// 17 | /// Requires the necessary [SetupData] and the state. 18 | /// It will throw an error if the provided state is empty. 19 | static Future restore(SetupData data, Uint8List serialized) async { 20 | throw UnsupportedError('Unsupported platform.'); 21 | } 22 | 23 | /// Creates and initializes the Xayn AI. 24 | /// 25 | /// Requires the necessary [SetupData] for the AI. 26 | static Future create(SetupData data) async { 27 | throw UnsupportedError('Unsupported platform.'); 28 | } 29 | 30 | /// Reranks the documents. 31 | /// 32 | /// The list of ranks is in the same order as the documents. 
33 | Future rerank( 34 | RerankMode mode, 35 | List histories, 36 | List documents, 37 | ) async => 38 | throw UnsupportedError('Unsupported platform.'); 39 | 40 | /// Serializes the current state of the reranker. 41 | Future serialize() async => 42 | throw UnsupportedError('Unsupported platform.'); 43 | 44 | /// Retrieves faults which might occur during reranking. 45 | /// 46 | /// Faults can range from warnings to errors which are handled in some default way internally. 47 | Future> faults() async => 48 | throw UnsupportedError('Unsupported platform.'); 49 | 50 | /// Retrieves the analytics which were collected in the penultimate reranking. 51 | Future analytics() async => 52 | throw UnsupportedError('Unsupported platform.'); 53 | 54 | /// Serializes the synchronizable data of the reranker. 55 | Future syncdataBytes() async => 56 | throw UnsupportedError('Unsupported platform.'); 57 | 58 | /// Synchronizes the internal data of the reranker with another. 59 | Future synchronize(Uint8List serialized) async => 60 | throw UnsupportedError('Unsupported platform.'); 61 | 62 | /// Frees the memory. 63 | Future free() async => throw UnsupportedError('Unsupported platform.'); 64 | } 65 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/web/worker/message/response.dart: -------------------------------------------------------------------------------- 1 | import 'dart:typed_data' show Uint8List; 2 | 3 | import 'package:json_annotation/json_annotation.dart' show JsonSerializable; 4 | 5 | import 'package:xayn_ai_ffi_dart/src/common/reranker/analytics.dart' 6 | show Analytics; 7 | import 'package:xayn_ai_ffi_dart/src/common/result/error.dart' 8 | show XaynAiException; 9 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' show ToJson; 10 | import 'package:xayn_ai_ffi_dart/src/web/worker/message/utils.dart' 11 | show Uint8ListConverter; 12 | 13 | part 'response.g.dart'; 14 | 15 | /// The kind of the [Response].
16 | enum Result { 17 | ok, 18 | exception, 19 | } 20 | 21 | /// A Response object that holds the result of the method invocation. 22 | @JsonSerializable() 23 | class Response implements ToJson { 24 | final Result kind; 25 | final Map? result; 26 | 27 | static Response fromResult(R result) => 28 | Response(Result.ok, result.toJson()); 29 | static Response fromException(XaynAiException exception) => 30 | Response(Result.exception, exception.toJson()); 31 | 32 | static const ok = Response(Result.ok, null); 33 | 34 | const Response(this.kind, this.result); 35 | 36 | bool isException() => kind == Result.exception; 37 | 38 | factory Response.fromJson(Map json) => _$ResponseFromJson(json); 39 | 40 | @override 41 | Map toJson() => _$ResponseToJson(this); 42 | } 43 | 44 | /// A response that contains a [Uint8List] result. 45 | @JsonSerializable() 46 | class Uint8ListResponse implements ToJson { 47 | @Uint8ListConverter() 48 | final Uint8List data; 49 | 50 | Uint8ListResponse(this.data); 51 | 52 | factory Uint8ListResponse.fromJson(Map json) => 53 | _$Uint8ListResponseFromJson(json); 54 | 55 | @override 56 | Map toJson() => _$Uint8ListResponseToJson(this); 57 | } 58 | 59 | /// A response that holds the result of the `Method.faults` invocation. 60 | @JsonSerializable() 61 | class FaultsResponse implements ToJson { 62 | final List faults; 63 | 64 | FaultsResponse(this.faults); 65 | 66 | factory FaultsResponse.fromJson(Map json) => _$FaultsResponseFromJson(json); 67 | 68 | @override 69 | Map toJson() => _$FaultsResponseToJson(this); 70 | } 71 | 72 | /// A response that holds the result of the `Method.analytics` invocation. 73 | @JsonSerializable() 74 | class AnalyticsResponse implements ToJson { 75 | Analytics?
analytics; 76 | 77 | AnalyticsResponse(this.analytics); 78 | 79 | factory AnalyticsResponse.fromJson(Map json) => 80 | _$AnalyticsResponseFromJson(json); 81 | 82 | @override 83 | Map toJson() => _$AnalyticsResponseToJson(this); 84 | } 85 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/mobile/result/error.dart: -------------------------------------------------------------------------------- 1 | import 'dart:ffi' show AllocatorAlloc, nullptr, Pointer, StructPointer; 2 | 3 | import 'package:ffi/ffi.dart' show malloc, Utf8, Utf8Pointer; 4 | 5 | import 'package:xayn_ai_ffi_dart/src/common/ffi/genesis.dart' show CCode; 6 | import 'package:xayn_ai_ffi_dart/src/common/result/error.dart' 7 | show IntToCode, XaynAiException; 8 | import 'package:xayn_ai_ffi_dart/src/common/utils.dart' show assertNeq; 9 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/genesis.dart' show CError; 10 | import 'package:xayn_ai_ffi_dart/src/mobile/ffi/library.dart' show ffi; 11 | 12 | /// The Xayn AI error information. 13 | class XaynAiError { 14 | late Pointer _error; 15 | 16 | /// Creates the error information initialized to success. 17 | /// 18 | /// This constructor never throws an exception. 19 | XaynAiError() { 20 | _error = malloc.call(); 21 | _error.ref.code = CCode.None; 22 | _error.ref.message = nullptr; 23 | } 24 | 25 | /// Gets the pointer. 26 | Pointer get ptr => _error; 27 | 28 | /// Checks for a fault code. 29 | bool isFault() { 30 | assertNeq(_error, nullptr); 31 | return _error.ref.code == CCode.Fault; 32 | } 33 | 34 | /// Checks for an irrecoverable error code. 35 | bool isPanic() { 36 | assertNeq(_error, nullptr); 37 | return _error.ref.code == CCode.Panic; 38 | } 39 | 40 | /// Checks for the `CCode.None` (no error) code. 41 | bool isNone() { 42 | assertNeq(_error, nullptr); 43 | return _error.ref.code == CCode.None; 44 | } 45 | 46 | /// Checks for an error code (both recoverable and irrecoverable).
47 | bool isError() => !isNone() && !isFault(); 48 | 49 | /// Creates an exception from the error information. 50 | XaynAiException toException() { 51 | assertNeq(_error, nullptr); 52 | assert( 53 | _error.ref.message == nullptr || 54 | (_error.ref.message.ref.data != nullptr && 55 | _error.ref.message.ref.len == 56 | _error.ref.message.ref.data.cast().length + 1), 57 | 'unexpected error pointer state', 58 | ); 59 | 60 | final code = _error.ref.code.toCode(); 61 | final message = _error.ref.message == nullptr 62 | ? '' 63 | : _error.ref.message.ref.data.cast().toDartString(); 64 | 65 | return XaynAiException(code, message); 66 | } 67 | 68 | /// Frees the memory. 69 | void free() { 70 | assert( 71 | _error == nullptr || 72 | _error.ref.message == nullptr || 73 | (_error.ref.message.ref.data != nullptr && 74 | _error.ref.message.ref.len == 75 | _error.ref.message.ref.data.cast().length + 1), 76 | 'unexpected error pointer state', 77 | ); 78 | 79 | if (_error != nullptr) { 80 | ffi.error_message_drop(_error); 81 | malloc.free(_error); 82 | _error = nullptr; 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /rubert-tokenizer/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A Bert tokenizer which converts sequences into encodings. 2 | //! 3 | //! This is a very condensed and heavily refactored version of [huggingface's `tokenizers`] crate. 4 | //! 5 | //! The tokenizer is based on a word piece vocabulary and consists of a Bert normalizer, a Bert 6 | //! pre-tokenizer, a Bert word piece model and a Bert post-tokenizer including truncation and 7 | //! padding strategies. The encodings can be of any numerical data type which implements 8 | //! [`Num`]` + `[`FromPrimitive`]` + `[`Copy`]. 9 | //! 10 | //! The normalizer is configurable by: 11 | //! - Cleans any control characters and replaces all sorts of whitespace by ` `. 12 | //! 
- Separates Chinese characters by whitespace so they get split. 13 | //! - Keeps accents of characters. 14 | //! - Lowercases characters. 15 | //! 16 | //! The pre-tokenizer is not configurable. 17 | //! 18 | //! The word piece model is configurable by: 19 | //! - The unknown token. 20 | //! - The continuing subword prefix. 21 | //! - The maximum number of characters per word. 22 | //! 23 | //! The post-tokenizer is configurable by: 24 | //! - The class token. 25 | //! - The separation token. 26 | //! - A truncation strategy. 27 | //! - A padding strategy. 28 | //! 29 | //! ```no_run 30 | //! use rubert_tokenizer::{Builder, Padding, Truncation}; 31 | //! 32 | //! fn main() -> Result<(), Box> { 33 | //! let tokenizer = Builder::::from_file("vocab.txt")? 34 | //! .with_normalizer(true, true, false, true) 35 | //! .with_model("[UNK]", "##", 100) 36 | //! .with_post_tokenizer("[CLS]", "[SEP]") 37 | //! .with_truncation(Truncation::fixed(128, 0)) 38 | //! .with_padding(Padding::fixed(128, "[PAD]")) 39 | //! .build()?; 40 | //! 41 | //! let encoding = tokenizer.encode("This îs ã séquènce."); 42 | //! assert_eq!(tokenizer.decode(&encoding, true), "this is a sequence."); 43 | //! 44 | //! Ok(()) 45 | //! } 46 | //! ``` 47 | //! 48 | //! [huggingface's `tokenizers`]: https://crates.io/crates/tokenizers 49 | //! [`Num`]: num_traits::Num 50 | //! 
[`FromPrimitive`]: num_traits::FromPrimitive 51 | #![cfg_attr( 52 | doc, 53 | forbid(rustdoc::broken_intra_doc_links, rustdoc::private_intra_doc_links) 54 | )] 55 | #![forbid(unsafe_op_in_unsafe_fn)] 56 | 57 | mod builder; 58 | mod model; 59 | mod normalizer; 60 | mod post_tokenizer; 61 | mod pre_tokenizer; 62 | mod tokenizer; 63 | 64 | pub use crate::{ 65 | builder::{Builder, BuilderError}, 66 | model::ModelError, 67 | normalizer::string::Offsets, 68 | post_tokenizer::{ 69 | encoding::Encoding, 70 | padding::{Padding, PaddingError}, 71 | truncation::{Truncation, TruncationError}, 72 | PostTokenizerError, 73 | }, 74 | tokenizer::Tokenizer, 75 | }; 76 | 77 | /// A stack allocated string with a maximum length of eight bytes. 78 | type SmallString = smallstr::SmallString<[u8; 8]>; 79 | -------------------------------------------------------------------------------- /rubert-tokenizer/src/pre_tokenizer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod string; 2 | 3 | use unicode_categories::UnicodeCategories; 4 | 5 | use crate::{ 6 | normalizer::string::{NormalizedString, SplitDelimiter}, 7 | pre_tokenizer::string::PreTokenizedString, 8 | }; 9 | 10 | /// A Bert pre-tokenizer. 11 | #[derive(Debug)] 12 | pub struct PreTokenizer; 13 | 14 | impl PreTokenizer { 15 | /// Pre-tokenizes the sequence. 
16 | pub(crate) fn pre_tokenize(&self, sequence: NormalizedString) -> PreTokenizedString { 17 | PreTokenizedString::from(sequence) 18 | .split(|_, sequence| sequence.split(char::is_whitespace, SplitDelimiter::Remove)) 19 | .split(|_, sequence| { 20 | sequence.split( 21 | |c: char| c.is_ascii_punctuation() || c.is_punctuation(), 22 | SplitDelimiter::Isolate, 23 | ) 24 | }) 25 | } 26 | } 27 | 28 | #[cfg(test)] 29 | mod tests { 30 | use super::*; 31 | use crate::normalizer::string::Offsets; 32 | 33 | fn assert_eq(actual: PreTokenizedString, expected: Vec<(&str, Offsets)>) { 34 | assert_eq!(actual.splits.len(), expected.len()); 35 | for (split, (word, offset)) in actual.splits.iter().zip(expected) { 36 | assert_eq!(split.normalized, word); 37 | assert_eq!(split.offset + split.alignments.first().unwrap().0, offset.0); 38 | assert_eq!(split.offset + split.alignments.last().unwrap().1, offset.1); 39 | } 40 | } 41 | 42 | #[test] 43 | fn test_basic() { 44 | let pre_tokenized = PreTokenizer.pre_tokenize("Hey friend! 
How are you?!?".into()); 45 | let expected = vec![ 46 | ("Hey", Offsets(0, 3)), 47 | ("friend", Offsets(4, 10)), 48 | ("!", Offsets(10, 11)), 49 | ("How", Offsets(16, 19)), 50 | ("are", Offsets(20, 23)), 51 | ("you", Offsets(24, 27)), 52 | ("?", Offsets(27, 28)), 53 | ("!", Offsets(28, 29)), 54 | ("?", Offsets(29, 30)), 55 | ]; 56 | assert_eq(pre_tokenized, expected); 57 | } 58 | 59 | #[test] 60 | fn test_chinese() { 61 | let sequence = "野口里佳 Noguchi Rika"; 62 | let normalized = NormalizedString::from(sequence).transform( 63 | sequence.chars().flat_map(|c| { 64 | if (c as usize) > 0x4E00 { 65 | vec![(' ', 0), (c, 1), (' ', 1)] 66 | } else { 67 | vec![(c, 0)] 68 | } 69 | }), 70 | 0, 71 | ); 72 | let pre_tokenized = PreTokenizer.pre_tokenize(normalized); 73 | let expected = vec![ 74 | ("野", Offsets(0, 3)), 75 | ("口", Offsets(3, 6)), 76 | ("里", Offsets(6, 9)), 77 | ("佳", Offsets(9, 12)), 78 | ("Noguchi", Offsets(13, 20)), 79 | ("Rika", Offsets(21, 25)), 80 | ]; 81 | assert_eq(pre_tokenized, expected); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /bindings/dart/lib/src/common/reranker/debug.dart: -------------------------------------------------------------------------------- 1 | import 'dart:convert' show base64Decode, base64Encode, jsonDecode, jsonEncode; 2 | import 'dart:typed_data' show Uint8List; 3 | 4 | import 'package:json_annotation/json_annotation.dart' 5 | show JsonKey, JsonSerializable; 6 | 7 | import 'package:xayn_ai_ffi_dart/src/common/data/document.dart' show Document; 8 | import 'package:xayn_ai_ffi_dart/src/common/data/history.dart' show History; 9 | import 'package:xayn_ai_ffi_dart/src/common/reranker/mode.dart' show RerankMode; 10 | 11 | part 'debug.g.dart'; 12 | 13 | /// Bundle of all reranking call data used for debugging purpose. 14 | /// 15 | /// This combines the data used for a reranking call including 16 | /// the documents, history and (optionally) the serialized state. 
17 | ///
18 | /// It provides a to/from JSON serialization and is meant for Team Blue
19 | /// so that they can store JSON blobs of the reranking call data for
20 | /// debugging purposes.
21 | ///
22 | /// Furthermore, both the Dart example/benchmark app and the dev-tool
23 | /// provide ways to run reranking based on the serialized call data
24 | /// and can be used for debugging.
25 | ///
26 | /// To make this possible, the (JSON) serialization format is kept
27 | /// identical between Dart (native, JS) and Rust. This also
28 | /// means that all fields are renamed to snake-case. E.g. `serializedState`
29 | /// gets encoded as `serialized_state`.
30 | ///
31 | @JsonSerializable()
32 | class RerankDebugCallData {
33 |   /// The mode which was used to run the reranking.
34 |   RerankMode rerankMode; // NOTE(review): unlike the other fields this is not `final`; confirm mutation is intended.
35 |
36 |   /// History used for a reranking call.
37 |   final List histories;
38 |
39 |   /// Documents used for a reranking call.
40 |   final List documents;
41 |
42 |   /// Serialized state which should be used to run a reranking call.
43 |   ///
44 |   /// This is normally the state *before* histories/documents
45 |   /// were used for a reranking call.
46 |   ///
47 |   /// This field is JSON encoded as a base64 string (or `null` when absent).
48 |   @JsonKey(toJson: _optBytesToBase64, fromJson: _optBase64ToBytes)
49 |   final Uint8List? serializedState;
50 |
51 |   /// Creates a new instance.
52 |   RerankDebugCallData({
53 |     required this.rerankMode,
54 |     required this.histories,
55 |     required this.documents,
56 |     this.serializedState,
57 |   });
58 |
59 |   /// Creates an instance from a JSON map.
60 |   factory RerankDebugCallData.fromJson(Map json) =>
61 |       _$RerankDebugCallDataFromJson(json);
62 |
63 |   /// Creates an instance from a JSON string.
64 |   factory RerankDebugCallData.fromJsonString(String json) =>
65 |       RerankDebugCallData.fromJson(jsonDecode(json) as Map);
66 |
67 |   /// Creates a JSON map based on this instance.
68 |   ///
69 |   /// Serialized state is included as a base64 encoded string.
70 |   Map toJson() => _$RerankDebugCallDataToJson(this);
71 |
72 |   /// Creates a JSON string based on this instance.
73 |   String toJsonString() => jsonEncode(toJson());
74 | }
75 |
76 | String? _optBytesToBase64(Uint8List? bytes) => // null-preserving base64 encoding
77 |     bytes == null ? null : base64Encode(bytes);
78 |
79 | Uint8List? _optBase64ToBytes(String? base64) => // null-preserving base64 decoding
80 |     base64 == null ? null : base64Decode(base64);
81 |
--------------------------------------------------------------------------------
/xayn-ai-ffi-c/src/result/fault.rs:
--------------------------------------------------------------------------------
1 | use xayn_ai::Error;
2 | use xayn_ai_ffi::CCode;
3 |
4 | use crate::{result::error::CError, slice::CBoxedSlice, utils::IntoRaw};
5 |
6 | /// The Xayn AI faults, stored as the display messages of the errors.
7 | pub struct Faults(Vec);
8 |
9 | /// A raw slice of faults.
10 | pub type CFaults = CBoxedSlice;
11 |
12 | impl From<&[Error]> for Faults {
13 |     fn from(faults: &[Error]) -> Self {
14 |         Self(faults.iter().map(ToString::to_string).collect())
15 |     }
16 | }
17 |
18 | unsafe impl IntoRaw for Faults
19 | where
20 |     CFaults: Sized,
21 | {
22 |     // Safety:
23 |     // CFaults is sized, hence Box is representable as a *mut CFaults and
24 |     // Option> is eligible for the nullable pointer optimization.
25 |     type Value = Option>;
26 |
27 |     #[inline]
28 |     fn into_raw(self) -> Self::Value {
29 |         let faults = self
30 |             .0
31 |             .into_iter()
32 |             .map(|message| CCode::Fault.with_context(message).into_raw()) // each message becomes a `CCode::Fault` error
33 |             .collect::>();
34 |         Some(Box::new(faults.into_boxed_slice().into())) // always `Some`, even for zero faults
35 |     }
36 | }
37 |
38 | /// Frees the memory of the faults.
39 | ///
40 | /// # Safety
41 | /// The behavior is undefined if:
42 | /// - A non-null `faults` doesn't point to memory allocated by [`xaynai_faults()`].
43 | /// - A non-null `faults` is freed more than once.
44 | /// - A non-null `faults` is accessed after being freed.
45 | ///
46 | /// [`xaynai_faults()`]: crate::reranker::ai::xaynai_faults
47 | #[no_mangle]
48 | pub unsafe extern "C" fn faults_drop(_faults: Option>) {} // the owned argument is dropped when this empty body returns
49 |
50 | #[cfg(test)]
51 | mod tests {
52 |     use itertools::izip;
53 |
54 |     use super::*;
55 |
56 |     struct TestFaults(Vec);
57 |
58 |     impl Default for TestFaults {
59 |         fn default() -> Self { // ten distinct errors: "fault 0" .. "fault 9"
60 |             Self(
61 |                 (0..10)
62 |                     .map(|idx| Error::msg(format!("fault {}", idx)))
63 |                     .collect(),
64 |             )
65 |         }
66 |     }
67 |
68 |     #[test]
69 |     fn test_from_faults() {
70 |         let buffer = TestFaults::default().0;
71 |         let faults = Faults::from(buffer.as_slice());
72 |         assert_eq!(faults.0.len(), buffer.len());
73 |         for (fault, error) in izip!(faults.0, buffer) {
74 |             assert_eq!(fault, error.to_string());
75 |         }
76 |     }
77 |
78 |     #[test]
79 |     fn test_from_empty() {
80 |         let faults = Faults::from(Vec::new().as_slice());
81 |         assert!(faults.0.is_empty());
82 |     }
83 |
84 |     #[test]
85 |     fn test_into_raw() {
86 |         let buffer = TestFaults::default();
87 |         let faults = Faults::from(buffer.0.as_slice()).into_raw().unwrap();
88 |
89 |         for (fault, error) in izip!(faults.as_slice(), buffer.0) {
90 |             assert_eq!(fault.code, CCode::Fault);
91 |             assert_eq!(fault.message.as_ref().unwrap().as_str(), error.to_string(),);
92 |         }
93 |     }
94 |
95 |     #[test]
96 |     fn test_into_empty() {
97 |         let faults = Faults(Vec::new()).into_raw().unwrap();
98 |         assert!(faults.as_slice().is_empty());
99 |     }
100 | }
101 |
--------------------------------------------------------------------------------
/xayn-ai-ffi-c/src/reranker/analytics.rs:
--------------------------------------------------------------------------------
1 | use crate::utils::IntoRaw;
2 |
3 | /// The analytics of the reranker.
4 | pub(super) struct Analytics(pub(crate) Option);
5 |
6 | /// The raw, C-compatible (`#[repr(C)]`) analytics of the reranker.
7 | #[repr(C)]
8 | pub struct CAnalytics {
9 |     /// The nDCG@k score between the LTR ranking and the relevance based ranking.
10 |     pub ndcg_ltr: f32,
11 |     /// The nDCG@k score between the Context ranking and the relevance based ranking.
12 |     pub ndcg_context: f32,
13 |     /// The nDCG@k score between the initial ranking and the relevance based ranking.
14 |     pub ndcg_initial_ranking: f32,
15 |     /// The nDCG@k score between the final ranking and the relevance based ranking.
16 |     pub ndcg_final_ranking: f32,
17 | }
18 |
19 | unsafe impl IntoRaw for Analytics
20 | where
21 |     CAnalytics: Sized,
22 | {
23 |     // Safety:
24 |     // CAnalytics is sized, hence Box is representable as a *mut CAnalytics and
25 |     // Option> is eligible for the nullable pointer optimization.
26 |     type Value = Option>;
27 |
28 |     #[inline]
29 |     fn into_raw(self) -> Self::Value {
30 |         self.0.map(|analytics| { // `None` stays `None`, i.e. a null pointer for C callers
31 |             Box::new(CAnalytics {
32 |                 ndcg_ltr: analytics.ndcg_ltr,
33 |                 ndcg_context: analytics.ndcg_context,
34 |                 ndcg_initial_ranking: analytics.ndcg_initial_ranking,
35 |                 ndcg_final_ranking: analytics.ndcg_final_ranking,
36 |             })
37 |         })
38 |     }
39 | }
40 |
41 | /// Frees the memory of the analytics.
42 | ///
43 | /// # Safety
44 | /// The behavior is undefined if:
45 | /// - A non-null `analytics` doesn't point to memory allocated by [`xaynai_analytics()`].
46 | /// - A non-null `analytics` is freed more than once.
47 | /// - A non-null `analytics` is accessed after being freed.
48 | ///
49 | /// [`xaynai_analytics()`]: crate::reranker::ai::xaynai_analytics
50 | #[no_mangle]
51 | pub unsafe extern "C" fn analytics_drop(_analytics: Option>) {} // the owned argument is dropped when this empty body returns
52 |
53 | #[cfg(test)]
54 | mod tests {
55 |     use test_utils::assert_approx_eq;
56 |
57 |     use super::*;
58 |
59 |     #[test]
60 |     fn test_convert_some_analytics_to_c_analytics() {
61 |         let analytics = Analytics(Some(xayn_ai::Analytics {
62 |             ndcg_ltr: 0.25,
63 |             ndcg_context: 0.75,
64 |             ndcg_initial_ranking: 1.125,
65 |             ndcg_final_ranking: 2.825,
66 |         }));
67 |
68 |         let c_analytics = analytics.into_raw().unwrap();
69 |
70 |         assert_approx_eq!(f32, c_analytics.ndcg_ltr, 0.25, ulps = 0);
71 |         assert_approx_eq!(f32, c_analytics.ndcg_context, 0.75, ulps = 0);
72 |         assert_approx_eq!(f32, c_analytics.ndcg_initial_ranking, 1.125, ulps = 0);
73 |         assert_approx_eq!(f32, c_analytics.ndcg_final_ranking, 2.825, ulps = 0);
74 |
75 |         unsafe {
76 |             analytics_drop(Some(c_analytics));
77 |         }
78 |     }
79 |
80 |     #[test]
81 |     fn test_convert_none_analytics_to_c_analytics() {
82 |         let analytics = Analytics(None);
83 |         let c_analytics = analytics.into_raw();
84 |
85 |         assert!(c_analytics.is_none());
86 |
87 |         unsafe {
88 |             analytics_drop(c_analytics);
89 |         }
90 |     }
91 | }
92 |
--------------------------------------------------------------------------------
/rubert/src/config.rs:
--------------------------------------------------------------------------------
1 | use std::{
2 |     fs::File,
3 |     io::{BufRead, BufReader, Read},
4 |     marker::PhantomData,
5 |     path::Path,
6 | };
7 |
8 | use displaydoc::Display;
9 | use thiserror::Error;
10 |
11 | use crate::{model::BertModel, NonePooler};
12 |
13 | #[derive(Debug, Display, Error)] // the variant doc comments below are the `Display` messages (displaydoc)
14 | pub enum ConfigError {
15 |     /// The token size must be greater than two to allow for special tokens
16 |     TokenSize,
17 |     /// Failed to load a data file: {0}
18 |     DataFile(#[from] std::io::Error),
19 | }
20 |
21 | pub struct Config<'a, K, P> {
22 |     pub(crate) model_kind: PhantomData,
23 |     pub(crate) vocab: Box,
24 |     pub(crate) model: Box,
25 |     pub(crate) accents: bool,
26 |     pub(crate) lowercase: bool,
27 |     pub(crate) token_size: usize,
28 |     pub(crate) pooler: P,
29 | }
30 |
31 | impl<'a, K: BertModel> Config<'a, K, NonePooler> {
32 |     pub fn from_readers( // defaults: no accents, lowercasing, 128 tokens, no pooling
33 |         vocab: Box,
34 |         model: Box,
35 |     ) -> Self {
36 |         Config {
37 |             model_kind: Default::default(),
38 |             vocab,
39 |             model,
40 |             accents: false,
41 |             lowercase: true,
42 |             token_size: 128,
43 |             pooler: NonePooler,
44 |         }
45 |     }
46 |
47 |     pub fn from_files( // opens both files buffered; I/O errors become `ConfigError::DataFile` via `?`
48 |         vocab: impl AsRef,
49 |         model: impl AsRef,
50 |     ) -> Result {
51 |         let vocab = Box::new(BufReader::new(File::open(vocab)?));
52 |         let model = Box::new(BufReader::new(File::open(model)?));
53 |         Ok(Self::from_readers(vocab, model))
54 |     }
55 | }
56 |
57 | impl<'a, K: BertModel, P> Config<'a, K, P> {
58 |     /// Sets whether the tokenizer keeps accents.
59 |     ///
60 |     /// Defaults to `false`.
61 |     pub fn with_accents(mut self, accents: bool) -> Self {
62 |         self.accents = accents;
63 |         self
64 |     }
65 |
66 |     /// Sets whether the tokenizer lowercases.
67 |     ///
68 |     /// Defaults to `true`.
69 |     pub fn with_lowercase(mut self, lowercase: bool) -> Self {
70 |         self.lowercase = lowercase;
71 |         self
72 |     }
73 |
74 |     /// Sets the token size for the tokenizer and the model.
75 |     ///
76 |     /// Defaults to `128`.
77 |     ///
78 |     /// # Errors
79 |     /// Fails if `size` is not contained in the model's [`BertModel::TOKEN_RANGE`].
80 |     pub fn with_token_size(mut self, size: usize) -> Result {
81 |         if K::TOKEN_RANGE.contains(&size) {
82 |             self.token_size = size;
83 |             Ok(self)
84 |         } else {
85 |             Err(ConfigError::TokenSize)
86 |         }
87 |     }
88 |
89 |     /// Sets pooling for the model.
90 |     ///
91 |     /// Defaults to `NonePooler`.
92 |     pub fn with_pooling(self, pooler: NP) -> Config<'a, K, NP> {
93 |         Config {
94 |             vocab: self.vocab,
95 |             model: self.model,
96 |             model_kind: self.model_kind,
97 |             accents: self.accents,
98 |             lowercase: self.lowercase,
99 |             token_size: self.token_size,
100 |             pooler,
101 |         }
102 |     }
103 | }
104 |
--------------------------------------------------------------------------------
/rubert-tokenizer/src/model/string.rs:
--------------------------------------------------------------------------------
1 | use std::borrow::Cow;
2 |
3 | use crate::{
4 |     model::Model,
5 |     normalizer::string::{NormalizedString, Offsets},
6 |     pre_tokenizer::string::PreTokenizedString,
7 | };
8 |
9 | /// A token relative to a sequence.
10 | pub struct Token {
11 |     pub id: N,
12 |     pub value: String,
13 |     pub offsets: Offsets,
14 | }
15 |
16 | /// A subpart of a normalized string.
17 | pub struct Split {
18 |     pub normalized: NormalizedString,
19 |     pub tokens: Vec>,
20 | }
21 |
22 | impl From for Split {
23 |     fn from(string: NormalizedString) -> Self {
24 |         Self {
25 |             normalized: string,
26 |             tokens: Vec::new(),
27 |         }
28 |     }
29 | }
30 |
31 | /// A tokenized sequence.
32 | pub struct TokenizedString {
33 |     pub splits: Vec>,
34 | }
35 |
36 | impl From for TokenizedString {
37 |     fn from(string: PreTokenizedString) -> Self {
38 |         Self {
39 |             splits: string.splits.into_iter().map(Into::into).collect(),
40 |         }
41 |     }
42 | }
43 |
44 | impl TokenizedString
45 | where
46 |     N: Copy,
47 | {
48 |     /// Tokenizes each split greedily into the longest matching word pieces of the model's vocabulary.
49 |     pub fn tokenize(mut self, model: &Model) -> Self {
50 |         self.splits.iter_mut().for_each(|split| {
51 |             let string = split.normalized.normalized.as_str();
52 |             let len = string.len();
53 |             if string.chars().count() > model.max_chars { // overlong words become a single unknown token
54 |                 split.tokens = vec![Token {
55 |                     id: model.unk_id,
56 |                     value: model.unk_token.to_string(),
57 |                     offsets: Offsets(0, len),
58 |                 }]
59 |             } else {
60 |                 let mut start = 0;
61 |                 while start < len {
62 |                     let mut end = len;
63 |                     start = loop {
64 |                         if start >= end { // no vocab entry matches: the whole word becomes unknown
65 |                             split.tokens = vec![Token {
66 |                                 id: model.unk_id,
67 |                                 value: model.unk_token.to_string(),
68 |                                 offsets: Offsets(0, len),
69 |                             }];
70 |                             return; // leaves the closure, i.e. continues with the next split
71 |                         }
72 |
73 |                         let sub_str = if start > 0 { // continuation pieces carry the subword prefix
74 |                             Cow::Owned([model.prefix.as_str(), &string[start..end]].join(""))
75 |                         } else {
76 |                             Cow::Borrowed(&string[start..end])
77 |                         };
78 |
79 |                         if let Some(id) = model.vocab.get(sub_str.as_ref()) {
80 |                             split.tokens.push(Token {
81 |                                 id: *id,
82 |                                 value: sub_str.into_owned(),
83 |                                 offsets: Offsets(start, end),
84 |                             });
85 |                             break end;
86 |                         } else {
87 |                             end -= sub_str.chars().last().map_or(1, |c| c.len_utf8()); // shrink by one char to stay on a UTF-8 boundary
88 |                         }
89 |                     }
90 |                 }
91 |             };
92 |         });
93 |
94 |         self
95 |     }
96 | }
97 |
--------------------------------------------------------------------------------