├── .dockerignore
├── data
│   └── .gitignore
├── rust-toolchain.toml
├── model.zstd_dict
├── public
│   └── favicon.ico
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── vcs.xml
│   ├── modules.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   └── compchat.iml
├── src
│   ├── component
│   │   ├── mod.rs
│   │   ├── chat_message.rs
│   │   ├── chat.rs
│   │   ├── navbar.rs
│   │   ├── dropdown.rs
│   │   ├── prompt_input.rs
│   │   └── prompt_section.rs
│   ├── lib.rs
│   ├── chat.rs
│   ├── backend
│   │   ├── trainer.rs
│   │   ├── training_options.rs
│   │   ├── tokenizer.rs
│   │   ├── mod.rs
│   │   ├── ensemble_model.rs
│   │   ├── evaluation.rs
│   │   ├── clm_model.rs
│   │   └── dataset.rs
│   ├── main.rs
│   ├── fileserv.rs
│   ├── error_template.rs
│   ├── app.rs
│   ├── model.rs
│   └── tuning.rs
├── .gitignore
├── README.md
├── docker-compose.yml
├── LICENSE
├── Dockerfile
├── train_tokenizer.py
├── tune.py
├── Cargo.toml
├── style
│   └── main.scss
├── tokenizer.json
└── Cargo.lock
/.dockerignore:
--------------------------------------------------------------------------------
1 | target
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/rust-toolchain.toml:
--------------------------------------------------------------------------------
1 | [toolchain]
2 | channel = "nightly"
3 |
--------------------------------------------------------------------------------
/model.zstd_dict:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamuelLess/ChatCLM/HEAD/model.zstd_dict
--------------------------------------------------------------------------------
/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamuelLess/ChatCLM/HEAD/public/favicon.ico
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 |
--------------------------------------------------------------------------------
/src/component/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod chat;
2 | pub mod chat_message;
3 | pub mod dropdown;
4 | pub mod navbar;
5 | pub mod prompt_input;
6 | pub mod prompt_section;
7 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated by Cargo
2 | # will have compiled files and executables
3 | /target/
4 | pkg
5 |
6 | # These are backup files generated by rustfmt
7 | **/*.rs.bk
8 |
9 | # node e2e test tools and outputs
10 | node_modules/
11 | test-results/
12 | end2end/playwright-report/
13 | playwright/.cache/
14 |
15 | wandb/
16 | *.checkpoint
17 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![feature(lazy_cell)]
2 |
3 | pub mod app;
4 | pub mod chat;
5 | pub mod component;
6 | pub mod error_template;
7 | #[cfg(feature = "ssr")]
8 | pub mod fileserv;
9 | pub mod model;
10 | #[cfg(feature = "ssr")]
11 | pub mod backend;
12 |
13 |
14 | #[cfg(feature = "hydrate")]
15 | #[wasm_bindgen::prelude::wasm_bindgen]
16 | pub fn hydrate() {
17 | use crate::app::*;
18 | console_error_panic_hook::set_once();
19 | leptos::mount_to_body(App);
20 | }
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # chatCLM
2 |
3 | [chatCLM.xyz](https://chatclm.xyz)
4 |
5 | The CLM describes itself as follows:
6 | > You are chatCLM. A zstd compression based LM. You are chatCLM. A zstd compression based LM. You are chatCLM. A zstd compression based LM.
7 |
8 |
9 | ## Overview
10 | Welcome to _chatCLM_, a Rust-based project utilizing the [leptos](https://leptos.dev) framework to build a compression-based Large Language Model (LLM) using zstd and the OpenAI tokenizer.
11 | This project aims to create an efficient, high-performance LLM by leveraging the power of compression algorithms.
12 |
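13 | ## How it works
14 | 
15 | There are no learned weights: the "model" is a zstd dictionary trained on tokenized text, following the old observation that a good compressor is a good predictor. To choose the next token, the backend compresses the prompt plus each candidate continuation against the trained dictionary and prefers whichever compresses best. The sketch below illustrates the idea; it is not the actual backend (which drives `zstd_sys` directly), and the `zstd` crate calls and names here are illustrative:
16 | 
17 | ```rust
18 | use std::io;
19 | 
20 | /// Rank candidate continuations by how small `prompt + candidate` compresses
21 | /// with the trained dictionary; the smallest output wins.
22 | fn next_token(prompt: &[u8], candidates: &[Vec<u8>], dict: &[u8]) -> io::Result<usize> {
23 |     let mut compressor = zstd::bulk::Compressor::with_dictionary(3, dict)?;
24 |     let mut best = (0usize, usize::MAX);
25 |     for (i, cand) in candidates.iter().enumerate() {
26 |         let mut data = prompt.to_vec();
27 |         data.extend_from_slice(cand);
28 |         // shorter compressed output means higher likelihood under the dictionary
29 |         let len = compressor.compress(&data)?.len();
30 |         if len < best.1 {
31 |             best = (i, len);
32 |         }
33 |     }
34 |     Ok(best.0)
35 | }
36 | ```
37 | 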
--------------------------------------------------------------------------------
/src/component/chat_message.rs:
--------------------------------------------------------------------------------
1 | use crate::chat::Message;
2 | use leptos::{component, view, IntoView, Show};
3 | 
4 | #[component]
5 | pub fn ChatMessage(msg: Message) -> impl IntoView {
6 |     let is_user_msg = msg.is_user_msg();
7 | 
8 |     // The original markup was stripped during extraction; this is a minimal
9 |     // reconstruction (class names and sender labels are assumptions) that
10 |     // keeps the surviving `{msg.message}` binding.
11 |     view! {
12 |         <div class="chat-message" class:user=is_user_msg>
13 |             <span class="sender">
14 |                 <Show when=move || is_user_msg fallback=|| "ChatCLM">
15 |                     "You"
16 |                 </Show>
17 |             </span>
18 |             <p>{msg.message}</p>
19 |         </div>
20 |     }
21 | }
22 | 
--------------------------------------------------------------------------------
/src/component/chat.rs:
--------------------------------------------------------------------------------
1 | use crate::chat::ChatHistory;
2 | use crate::component::chat_message::ChatMessage;
3 | use leptos::{component, view, For, IntoView, ReadSignal};
4 | 
5 | #[component]
6 | pub fn Chat(chat: ReadSignal<ChatHistory>) -> impl IntoView {
7 |     // The original markup was stripped during extraction; this is a minimal
8 |     // reconstruction (class name and `For` key are assumptions): render every
9 |     // message in the history in order.
10 |     view! {
11 |         <div class="chat">
12 |             <For
13 |                 each=move || chat().messages.into_iter().enumerate()
14 |                 key=|(index, _)| *index
15 |                 children=|(_, msg)| view! { <ChatMessage msg=msg/> }
16 |             />
17 |         </div>
18 |     }
19 | }
20 | 
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | web:
3 | build: .
4 | expose:
5 | - "8080"
6 | labels:
7 | - "traefik.enable=true"
8 | - "traefik.http.routers.chatclm.rule=Host(`chatclm.nglodny.de`) || Host(`chatclm.xyz`)"
9 | - "traefik.http.routers.chatclm.tls=true"
10 | - "traefik.http.routers.chatclm.tls.certresolver=letsencrypt"
11 | - "traefik.http.middlewares.chatclm.compress=true"
12 |       - "traefik.http.routers.chatclm.middlewares=chatclm@docker"
13 | - "traefik.http.routers.chatclm.entrypoints=web,websecure"
14 | networks:
15 | - traefik
16 | networks:
17 | traefik:
18 | external: true
19 |
20 |
--------------------------------------------------------------------------------
/.idea/compchat.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/component/navbar.rs:
--------------------------------------------------------------------------------
1 | use crate::component::dropdown::Dropdown;
2 | use crate::model::FrontendModel;
3 | use leptos::{component, view, IntoView, ReadSignal, WriteSignal};
4 | 
5 | #[component]
6 | pub fn NavBar(
7 |     selected_model_index: ReadSignal<usize>,
8 |     set_selected_model_index: WriteSignal<usize>,
9 | ) -> impl IntoView {
10 |     // The original markup was stripped during extraction; this is a minimal
11 |     // reconstruction (structure and class names are assumptions) wiring the
12 |     // model dropdown to the selection signals.
13 |     view! {
14 |         <nav class="navbar">
15 |             <span class="brand">"chatCLM"</span>
16 |             <Dropdown
17 |                 options=[
18 |                     FrontendModel::ChatCLM1_0.name(),
19 |                     FrontendModel::ChatGPT3_5.name(),
20 |                     FrontendModel::ChatGPT4o.name(),
21 |                     FrontendModel::ChatRandom.name(),
22 |                 ]
23 |                 selected_option_index=selected_model_index
24 |                 set_selected_option_index=set_selected_model_index
25 |             />
26 |         </nav>
27 |     }
28 | }
29 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <https://unlicense.org>
25 |
--------------------------------------------------------------------------------
/src/chat.rs:
--------------------------------------------------------------------------------
1 | #[derive(Debug, Clone, PartialEq, Eq)]
2 | pub enum Sender {
3 | User,
4 | ChatCLM,
5 | }
6 |
7 | #[derive(Debug, Clone, PartialEq, Eq)]
8 | pub struct Message {
9 | pub message: String,
10 | pub time_iso: String,
11 | pub sender: Sender,
12 | }
13 |
14 | impl Message {
15 | pub fn is_user_msg(&self) -> bool {
16 | self.sender == Sender::User
17 | }
18 |
19 | pub fn new(message: String, sender: Sender) -> Self {
20 | Self {
21 | message,
22 | time_iso: chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string(),
23 | sender,
24 | }
25 | }
26 | }
27 |
28 | #[derive(Debug, Clone, PartialEq, Eq, Default)]
29 | pub struct ChatHistory {
30 | pub messages: Vec<Message>,
31 | }
32 |
33 | impl ChatHistory {
34 | fn add_message(&mut self, message: Message) {
35 | self.messages.push(message);
36 | }
37 |
38 | pub fn new_server_message(&mut self, message: String) {
39 | self.add_message(Message::new(message, Sender::ChatCLM));
40 | }
41 |
42 | pub fn new_user_message(&mut self, message: String) {
43 | self.add_message(Message::new(message, Sender::User));
44 | }
45 |
46 | pub fn replace_last_server_message(&mut self, message: String) {
47 | if let Some(last_message) = self.messages.last_mut() {
48 | if last_message.sender == Sender::ChatCLM {
49 | last_message.message = message;
50 | }
51 | }
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/src/backend/trainer.rs:
--------------------------------------------------------------------------------
1 | use std::ffi::{c_uint, c_void};
2 | use itertools::Itertools;
3 | use zstd_sys::{ZDICT_isError, ZDICT_optimizeTrainFromBuffer_fastCover};
4 | use crate::backend::{BYTES_PER_TOKEN, Token, tokens_to_bytes};
5 | use crate::backend::clm_model::ClmModel;
6 | use crate::backend::training_options::TrainingOptions;
7 |
8 | pub fn train_model<'a>(input_tokens: &Vec<Vec<Token>>, training_options: &TrainingOptions) -> ClmModel<'a> {
9 |
10 | if input_tokens.is_empty() {
11 | return ClmModel::from_buffer(vec![]);
12 | }
13 |
14 | let raw_data = input_tokens.iter().flat_map(tokens_to_bytes).collect_vec();
15 | let sizes = input_tokens.iter().map(|x| x.len() * BYTES_PER_TOKEN).collect_vec();
16 | let buffer_size = (raw_data.len() as f64 * training_options.dictionary_size_percentage) as usize;
17 | assert_eq!(sizes.iter().sum::<usize>(), raw_data.len(), "Sizes sum doesn't match raw data size");
18 | let mut buffer = vec![0u8; buffer_size];
19 | let mut parameters = training_options.to_zdict_params();
20 | let size;
21 | unsafe {
22 | size = ZDICT_optimizeTrainFromBuffer_fastCover(
23 | buffer.as_mut_ptr() as *mut c_void,
24 | buffer_size,
25 | raw_data.as_ptr() as *mut c_void,
26 | sizes.as_ptr(),
27 | sizes.len() as c_uint,
28 | &mut parameters,
29 | );
30 |
31 | if ZDICT_isError(size) != 0 {
32 | panic!("Failed to train dictionary");
33 | }
34 | }
35 | buffer.resize(size, 0);
36 | ClmModel::from_buffer(buffer)
37 | }
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | use chatclm::app::App;
2 |
3 | #[cfg(feature = "ssr")]
4 | use chatclm::fileserv::file_and_error_handler;
5 |
6 |
7 | #[cfg(feature = "ssr")]
8 | #[tokio::main]
9 | async fn leptos_main() {
10 | use axum::Router;
11 | use leptos::*;
12 | use leptos_axum::{generate_route_list, LeptosRoutes};
13 |
14 | // Setting get_configuration(None) means we'll be using cargo-leptos's env values
15 | // For deployment these variables are:
16 | // <https://github.com/leptos-rs/start-axum#executing-a-server-on-a-remote-machine-without-the-toolchain>
17 | // Alternately a file can be specified such as Some("Cargo.toml")
18 | // The file would need to be included with the executable when moved to deployment
19 | let conf = get_configuration(None).await.unwrap();
20 | let leptos_options = conf.leptos_options;
21 | let addr = leptos_options.site_addr;
22 | let routes = generate_route_list(App);
23 |
24 | // build our application with a route
25 | let app = Router::new()
26 | .leptos_routes(&leptos_options, routes, App)
27 | .fallback(file_and_error_handler)
28 | .with_state(leptos_options);
29 |
30 | let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
31 | logging::log!("listening on http://{}", &addr);
32 | axum::serve(listener, app.into_make_service())
33 | .await
34 | .unwrap();
35 | }
36 | #[cfg(feature = "ssr")]
37 | fn main() {
38 | leptos_main();
39 | }
40 |
41 | #[cfg(not(feature = "ssr"))]
42 | pub fn main() {
43 | // no client-side main function
44 | // unless we want this to work with e.g., Trunk for a purely client-side app
45 | // see lib.rs for hydration function instead
46 | }
47 |
--------------------------------------------------------------------------------
/src/component/dropdown.rs:
--------------------------------------------------------------------------------
1 | use leptos::{
2 |     component, create_signal, view, For, IntoView, ReadSignal, SignalSet, SignalUpdate, WriteSignal,
3 | };
4 | 
5 | #[component]
6 | pub fn Dropdown<const N: usize>(
7 |     options: [&'static str; N],
8 |     selected_option_index: ReadSignal<usize>,
9 |     set_selected_option_index: WriteSignal<usize>,
10 | ) -> impl IntoView {
11 |     let (is_open, set_open) = create_signal(false);
12 | 
13 |     // The original markup was stripped during extraction; this is a minimal
14 |     // reconstruction (class names are assumptions): a button showing the
15 |     // current option and a toggled list that selects an option on click.
16 |     view! {
17 |         <div class="dropdown">
18 |             <button on:click=move |_| set_open.update(|open| *open = !*open)>
19 |                 {move || options[selected_option_index()]}
20 |             </button>
21 |             <ul class="dropdown-options" class:open=move || is_open()>
22 |                 <For
23 |                     each=move || options.into_iter().enumerate()
24 |                     key=|(index, _)| *index
25 |                     children=move |(index, option)| {
26 |                         view! {
27 |                             <li on:click=move |_| {
28 |                                 set_selected_option_index.set(index);
29 |                                 set_open.set(false);
30 |                             }>
31 |                                 {option}
32 |                             </li>
33 |                         }
34 |                     }
35 |                 />
36 |             </ul>
37 |         </div>
38 |     }
39 | }
40 | 
--------------------------------------------------------------------------------
/src/backend/training_options.rs:
--------------------------------------------------------------------------------
1 | use std::ffi::c_int;
2 |
3 | pub struct TrainingOptions {
4 | pub d: u32,
5 | pub f: u32,
6 | pub k: u32,
7 | pub steps: u32,
8 | pub nb_threads: u32,
9 | pub split_point: f64,
10 | pub accel: u32,
11 | pub shrink_dict: u32,
12 | pub shrink_dict_max_regression: u32,
13 | pub compression_level: u32,
14 | pub dictionary_size_percentage: f64 /* 0.0 to 1.0, how big the dictionary should be compared to the input data */,
15 | pub ensemble_size: usize, /* number of models to train */
16 | }
17 |
18 | impl TrainingOptions {
19 | pub fn new() -> Self {
20 | TrainingOptions {
21 | d: 8,
22 | f: 25,
23 | k: 50,
24 | steps: 4,
25 | nb_threads: 8,
26 | split_point: 0.0,
27 | accel: 1,
28 | shrink_dict: 0,
29 | shrink_dict_max_regression: 0,
30 | compression_level: 3,
31 | dictionary_size_percentage: 1.0,
32 | ensemble_size: 1,
33 | }
34 | }
35 |
36 | pub fn to_zdict_params(&self) -> zstd_sys::ZDICT_fastCover_params_t {
37 | zstd_sys::ZDICT_fastCover_params_t {
38 | k: self.k,
39 | d: self.d,
40 | f: self.f,
41 | steps: self.steps,
42 | nbThreads: self.nb_threads,
43 | splitPoint: self.split_point,
44 | accel: self.accel,
45 | shrinkDict: self.shrink_dict,
46 | shrinkDictMaxRegression: self.shrink_dict_max_regression,
47 | zParams: zstd_sys::ZDICT_params_t {
48 | compressionLevel: self.compression_level as c_int,
49 | notificationLevel: 2,
50 | dictID: 0,
51 | },
52 | }
53 | }
54 |
55 | pub fn default() -> Self {
56 | TrainingOptions::new()
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Get started with a build env with Rust nightly
2 | FROM rustlang/rust:nightly-bookworm as builder
3 |
4 | # If you’re using stable, use this instead
5 | # FROM rust:1.74-bullseye as builder
6 |
7 | # Install cargo-binstall, which makes it easier to install other
8 | # cargo extensions like cargo-leptos
9 | RUN wget https://github.com/cargo-bins/cargo-binstall/releases/latest/download/cargo-binstall-x86_64-unknown-linux-musl.tgz
10 | RUN tar -xvf cargo-binstall-x86_64-unknown-linux-musl.tgz
11 | RUN cp cargo-binstall /usr/local/cargo/bin
12 |
13 | # Install cargo-leptos
14 | RUN cargo binstall cargo-leptos@0.2.17 -y
15 |
16 | # Add the WASM target
17 | RUN rustup target add wasm32-unknown-unknown
18 |
19 | # Make an /app dir, which everything will eventually live in
20 | RUN mkdir -p /app
21 | WORKDIR /app
22 | # Ignore the ./target dir
23 | COPY . .
24 |
25 | # Build the app
26 | RUN cargo leptos build --release -vv
27 |
28 | FROM debian:bookworm-slim as runtime
29 | WORKDIR /app
30 | RUN apt-get update -y \
31 | && apt-get install -y --no-install-recommends openssl ca-certificates \
32 | && apt-get autoremove -y \
33 | && apt-get clean -y \
34 | && rm -rf /var/lib/apt/lists/*
35 |
36 | # Copy the server binary to the /app directory
37 | COPY --from=builder /app/target/release/chatclm /app/
38 |
39 | # /target/site contains our JS/WASM/CSS, etc.
40 | COPY --from=builder /app/target/site /app/site
41 |
42 | # Copy Cargo.toml if it’s needed at runtime
43 | COPY --from=builder /app/Cargo.toml /app/
44 |
45 | # Copy the zstd dictionary as it's needed at runtime
46 | COPY --from=builder /app/model.zstd_dict /app/
47 |
48 | # Set any required env variables
49 | ENV RUST_LOG="info"
50 | ENV LEPTOS_SITE_ADDR="0.0.0.0:8080"
51 | ENV LEPTOS_SITE_ROOT="site"
52 | EXPOSE 8080
53 |
54 | # -- NB: update binary name from "chatclm" to match your app name in Cargo.toml --
55 | # Run the server
56 | CMD ["/app/chatclm"]
57 |
--------------------------------------------------------------------------------
/src/fileserv.rs:
--------------------------------------------------------------------------------
1 | use crate::app::App;
2 | use axum::response::Response as AxumResponse;
3 | use axum::{
4 | body::Body,
5 | extract::State,
6 | http::{Request, Response, StatusCode},
7 | response::IntoResponse,
8 | };
9 | use leptos::*;
10 | use tower::ServiceExt;
11 | use tower_http::services::ServeDir;
12 |
13 | pub async fn file_and_error_handler(
14 | State(options): State<LeptosOptions>,
15 | req: Request<Body>,
16 | ) -> AxumResponse {
17 | let root = options.site_root.clone();
18 | let (parts, body) = req.into_parts();
19 |
20 | let mut static_parts = parts.clone();
21 | static_parts.headers.clear();
22 | if let Some(encodings) = parts.headers.get("accept-encoding") {
23 | static_parts
24 | .headers
25 | .insert("accept-encoding", encodings.clone());
26 | }
27 |
28 | let res = get_static_file(Request::from_parts(static_parts, Body::empty()), &root)
29 | .await
30 | .unwrap();
31 |
32 | if res.status() == StatusCode::OK {
33 | res.into_response()
34 | } else {
35 | let handler = leptos_axum::render_app_to_stream(options.to_owned(), App);
36 | handler(Request::from_parts(parts, body))
37 | .await
38 | .into_response()
39 | }
40 | }
41 |
42 | async fn get_static_file(
43 | request: Request<Body>,
44 | root: &str,
45 | ) -> Result<Response<Body>, (StatusCode, String)> {
46 | // `ServeDir` implements `tower::Service` so we can call it with `tower::ServiceExt::oneshot`
47 | // This path is relative to the cargo root
48 | match ServeDir::new(root)
49 | .precompressed_gzip()
50 | .precompressed_br()
51 | .oneshot(request)
52 | .await
53 | {
54 | Ok(res) => Ok(res.into_response()),
55 | Err(err) => Err((
56 | StatusCode::INTERNAL_SERVER_ERROR,
57 | format!("Error serving files: {err}"),
58 | )),
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/src/component/prompt_input.rs:
--------------------------------------------------------------------------------
1 | use leptos::{
2 |     component, create_node_ref, create_signal, html, view, Callback, IntoView, NodeRef, Show,
3 | };
4 | 
5 | #[component]
6 | pub fn PromptInput(on_submit: Callback<String>) -> impl IntoView {
7 |     let (show_placeholder, set_show_placeholder) = create_signal(true);
8 |     let textarea_element: NodeRef<html::Div> = create_node_ref();
9 | 
10 |     let submit = move || {
11 |         let input = textarea_element().unwrap().inner_text().trim().to_string();
12 |         if input.is_empty() {
13 |             return;
14 |         }
15 | 
16 |         on_submit(input);
17 |         textarea_element().unwrap().set_inner_html("");
18 |     };
19 | 
20 |     // The original markup was stripped during extraction; this is a minimal
21 |     // reconstruction (class names, labels, key handling, and the contenteditable
22 |     // <div> target are assumptions inferred from `inner_text`/`set_inner_html`).
23 |     view! {
24 |         <div class="prompt-input">
25 |             <Show when=move || show_placeholder()>
26 |                 <span class="placeholder">"Message ChatCLM"</span>
27 |             </Show>
28 |             <div
29 |                 class="prompt-textarea"
30 |                 contenteditable="true"
31 |                 node_ref=textarea_element
32 |                 on:input=move |_| {
33 |                     set_show_placeholder(textarea_element().unwrap().inner_text().is_empty());
34 |                 }
35 |                 on:keydown=move |ev| {
36 |                     if ev.key() == "Enter" && !ev.shift_key() {
37 |                         ev.prevent_default();
38 |                         submit();
39 |                     }
40 |                 }
41 |             ></div>
42 |             <button class="send" on:click=move |_| submit()>"Send"</button>
43 |         </div>
44 |     }
45 | }
46 | 
--------------------------------------------------------------------------------
/train_tokenizer.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | from tokenizers.normalizers import *
3 | from tokenizers import Tokenizer, decoders, pre_tokenizers
4 | from tokenizers.models import BPE
5 | from tokenizers.pre_tokenizers import ByteLevel
6 | from tokenizers.trainers import BpeTrainer
7 |
8 | files = ["data/tokenizer_data.txt"]
9 |
10 | def data_iterator():
11 |     # Remove all characters that occur at most twice in the entire dataset
12 | chars_used_once = set()
13 | char_counter = {}
14 |
15 | for file in files:
16 | with open(file, "r") as f:
17 | for line in tqdm.tqdm(f.readlines(), desc=f"Counting characters {file}"):
18 | for char in line:
19 | if char in char_counter:
20 | char_counter[char] += 1
21 | else:
22 | char_counter[char] = 1
23 |
24 | for char, count in char_counter.items():
25 | if count <= 2:
26 | chars_used_once.add(char)
27 |
28 |     print("Removing characters:")
29 | print(chars_used_once)
30 |
31 | for file in files:
32 | with open(file, "r") as f:
33 | for line in tqdm.tqdm(f.readlines(), desc=f"Processing {file}"):
34 | line_without_chars = "".join([char for char in line if char not in chars_used_once])
35 | yield line_without_chars
36 |
37 | tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
38 |
39 | trainer = BpeTrainer(special_tokens=["[UNK]"], vocab_size = 255, show_progress=True, end_of_word_suffix="")
40 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence( [pre_tokenizers.WhitespaceSplit() ])
41 | tokenizer.normalizer = Sequence([Lowercase(), NFD(), StripAccents()])
42 | tokenizer.train_from_iterator(data_iterator(), trainer=trainer)
43 | tokenizer.decoder = decoders.BPEDecoder()
44 |
45 | tokenizer.save("tokenizer.json")
46 |
47 | encoded = tokenizer.encode("This is a test!")
48 | print(encoded.ids)
49 |
50 | decoded = tokenizer.decode(encoded.ids)
51 | print(decoded)
52 |
53 | # Try to load the tokenizer
54 | tokenizer = Tokenizer.from_file("tokenizer.json")
55 | encoded = tokenizer.encode("This is á test!")
56 | print(encoded.ids)
57 | decoded = tokenizer.decode(encoded.ids)
58 | print(decoded)
59 |
--------------------------------------------------------------------------------
/tune.py:
--------------------------------------------------------------------------------
1 | # Adapter between wandb sweeps and the Rust tuning binary
2 | import argparse
3 | import json
4 | import subprocess
5 |
6 | import wandb
7 |
8 | sweep_config = {
9 | "method": "bayes",
10 | "metric": {
11 | "goal": "maximize",
12 | "name": "train_inf_gain"
13 | },
14 | "parameters": {
15 | "datasetSize": {
16 | "distribution": "int_uniform",
17 | "max": 180_000_000,
18 | "min": 1_000_000,
19 | },
20 | "dictionarySizePercentage" : {
21 | "distribution": "uniform",
22 | "max": 1,
23 | "min": 0
24 | },
25 |
26 | "compressionLevel": {
27 | "distribution": "int_uniform",
28 | "max": 8,
29 | "min": 1
30 | },
31 | "d": {
32 | "values": [6, 8]
33 | },
34 | "f": {
35 | "distribution": "int_uniform",
36 | "max": 26,
37 | "min": 5
38 | },
39 | "k": {
40 | "distribution": "int_uniform",
41 | "max": 2048,
42 | "min": 16
43 | },
44 | }
45 |
46 | }
47 |
48 | def evaluate():
49 | run = wandb.init()
50 |
51 | conf_dict = {}
52 | for param in sweep_config["parameters"].keys():
53 | conf_dict[param] = wandb.config[param]
54 |
55 | # Create a json object with the arguments
56 | arguments = json.dumps(conf_dict)
57 | print("Starting with arguments:")
58 | print(arguments)
59 | # Start the binary, write the arguments to its stdin and read the output
60 | command = ["target/release/tuning"]
61 |
62 | process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
63 | output, stderr = process.communicate(arguments)
64 | print(output)
65 | print(stderr)
66 |
67 |
68 |
69 | # Parse the last line as json
70 | try:
71 | output = json.loads(output.splitlines()[-1])
72 | wandb.log(output)
73 | except json.JSONDecodeError:
74 | print("Error decoding json")
75 | wandb.log({})
76 | except IndexError:
77 | print("Error parsing output")
78 | wandb.log({})
79 |
80 |
81 |
82 |
83 | # Initialize a new sweep
84 | sweep_id = wandb.sweep(sweep_config)
85 | wandb.agent(sweep_id=sweep_id, function=evaluate)
86 |
87 |
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------
/src/component/prompt_section.rs:
--------------------------------------------------------------------------------
1 | use crate::chat::ChatHistory;
2 | use crate::component::prompt_input::PromptInput;
3 | use crate::model::{cut_prompt, get_next_token};
4 | use leptos::{
5 |     component, spawn_local, view, Callback, IntoView, ReadSignal, SignalUpdate, WriteSignal,
6 | };
7 | 
8 | #[component]
9 | pub fn PromptSection(
10 |     set_chat: WriteSignal<ChatHistory>,
11 |     selected_model_index: ReadSignal<usize>,
12 | ) -> impl IntoView {
13 |     // The original body was stripped during extraction; this is a sketch of
14 |     // the surviving logic (class names and the exact use of `cut_prompt` are
15 |     // assumptions): submit the prompt, then keep requesting the next token
16 |     // from the server, updating the last ChatCLM message as tokens arrive.
17 |     let on_submit = Callback::new(move |input: String| {
18 |         set_chat.update(|chat| chat.new_user_message(input.clone()));
19 |         set_chat.update(|chat| chat.new_server_message(String::new()));
20 |         spawn_local(async move {
21 |             let mut text = cut_prompt(input);
22 |             while let Ok(Some(next)) =
23 |                 get_next_token(selected_model_index(), text.clone()).await
24 |             {
25 |                 text = next;
26 |                 let shown = text.clone();
27 |                 set_chat.update(|chat| chat.replace_last_server_message(shown));
28 |             }
29 |         });
30 |     });
31 | 
32 |     view! {
33 |         <section class="prompt-section">
34 |             <PromptInput on_submit=on_submit/>
35 |             <p class="disclaimer">"ChatCLM can make mistakes. Check important info."</p>
36 |         </section>
37 |     }
38 | }
39 | 
--------------------------------------------------------------------------------
/src/error_template.rs:
--------------------------------------------------------------------------------
1 | use http::status::StatusCode;
2 | use leptos::*;
3 | use thiserror::Error;
4 | 
5 | #[derive(Clone, Debug, Error)]
6 | pub enum AppError {
7 |     #[error("Not Found")]
8 |     NotFound,
9 | }
10 | 
11 | impl AppError {
12 |     pub fn status_code(&self) -> StatusCode {
13 |         match self {
14 |             AppError::NotFound => StatusCode::NOT_FOUND,
15 |         }
16 |     }
17 | }
18 | 
19 | // A basic function to display errors served by the error boundaries.
20 | // Feel free to do more complicated things here than just displaying the error.
21 | #[component]
22 | pub fn ErrorTemplate(
23 |     #[prop(optional)] outside_errors: Option<Errors>,
24 |     #[prop(optional)] errors: Option<RwSignal<Errors>>,
25 | ) -> impl IntoView {
26 |     let errors = match outside_errors {
27 |         Some(e) => create_rw_signal(e),
28 |         None => match errors {
29 |             Some(e) => e,
30 |             None => panic!("No Errors found and we expected errors!"),
31 |         },
32 |     };
33 |     // Get Errors from Signal
34 |     let errors = errors.get_untracked();
35 | 
36 |     // Downcast lets us take a type that implements `std::error::Error`
37 |     let errors: Vec<AppError> = errors
38 |         .into_iter()
39 |         .filter_map(|(_k, v)| v.downcast_ref::<AppError>().cloned())
40 |         .collect();
41 |     println!("Errors: {errors:#?}");
42 | 
43 |     // Only the response code for the first error is actually sent from the server
44 |     // this may be customized by the specific application
45 |     #[cfg(feature = "ssr")]
46 |     {
47 |         use leptos_axum::ResponseOptions;
48 |         let response = use_context::<ResponseOptions>();
49 |         if let Some(response) = response {
50 |             response.set_status(errors[0].status_code());
51 |         }
52 |     }
53 | 
54 |     // Markup reconstructed following the leptos start-axum starter template,
55 |     // which this file otherwise matches; the stripped original may differ.
56 |     view! {
57 |         <h1>{if errors.len() > 1 { "Errors" } else { "Error" }}</h1>
58 |         <For
59 |             each=move || errors.clone().into_iter().enumerate()
60 |             key=|(index, _error)| *index
61 |             children=move |error| {
62 |                 let error_string = error.1.to_string();
63 |                 let error_code = error.1.status_code();
64 |                 view! {
65 |                     <h2>{error_code.to_string()}</h2>
66 |                     <p>"Error: " {error_string}</p>
67 |                 }
68 |             }
69 |         />
70 |     }
71 | }
72 | 
--------------------------------------------------------------------------------
/src/app.rs:
--------------------------------------------------------------------------------
1 | use crate::chat::ChatHistory;
2 | use crate::component::chat::Chat;
3 | use crate::component::navbar::NavBar;
4 | use crate::component::prompt_section::PromptSection;
5 | use crate::error_template::{AppError, ErrorTemplate};
6 | use leptos::*;
7 | use leptos_meta::*;
8 | use leptos_router::*;
9 | 
10 | #[component]
11 | pub fn App() -> impl IntoView {
12 |     // Provides context that manages stylesheets, titles, meta tags, etc.
13 |     provide_meta_context();
14 | 
15 |     // create chat as reactive signal object
16 |     let (chat, set_chat) = create_signal(ChatHistory::default());
17 |     let (selected_model_index, set_selected_model_index) = create_signal(0usize);
18 |     // fill with dummy data
19 |     set_chat.update(|chat| {
20 |         chat.new_server_message("Welcome to ChatCLM!".to_string());
21 |         chat.new_user_message("Hello!".to_string());
22 |         chat.new_server_message("Type a message and press Enter to chat.".to_string());
23 |     });
24 | 
25 |     // Markup below is reconstructed along the lines of the leptos start-axum
26 |     // template; the stylesheet href and title text are assumptions.
27 |     view! {
28 |         <Stylesheet id="leptos" href="/pkg/chatclm.css"/>
29 | 
30 |         // sets the document title
31 |         <Title text="chatCLM"/>
32 | 
33 |         // content for this welcome page
34 |         <Router fallback=|| {
35 |             let mut outside_errors = Errors::default();
36 |             outside_errors.insert_with_default_key(AppError::NotFound);
37 |             view! { <ErrorTemplate outside_errors/> }.into_view()
38 |         }>
39 |             <main>
40 |                 <Routes>
41 |                     <Route
42 |                         path=""
43 |                         view=move || {
44 |                             view! {
45 |                                 <HomePage
46 |                                     chat=chat
47 |                                     set_chat=set_chat
48 |                                     selected_model_index=selected_model_index
49 |                                     set_selected_model_index=set_selected_model_index
50 |                                 />
51 |                             }
52 |                         }
53 |                     />
54 |                 </Routes>
55 |             </main>
56 |         </Router>
57 |     }
58 | }
59 | 
60 | #[component]
61 | fn HomePage(
62 |     chat: ReadSignal<ChatHistory>,
63 |     set_chat: WriteSignal<ChatHistory>,
64 |     selected_model_index: ReadSignal<usize>,
65 |     set_selected_model_index: WriteSignal<usize>,
66 | ) -> impl IntoView {
67 |     // Markup stripped during extraction; minimal reconstruction wiring the
68 |     // three top-level components together.
69 |     view! {
70 |         <NavBar
71 |             selected_model_index=selected_model_index
72 |             set_selected_model_index=set_selected_model_index
73 |         />
74 |         <Chat chat=chat/>
75 |         <PromptSection set_chat=set_chat selected_model_index=selected_model_index/>
76 |     }
77 | }
78 | 
--------------------------------------------------------------------------------
/src/backend/tokenizer.rs:
--------------------------------------------------------------------------------
1 | use itertools::Itertools;
2 | use tiktoken_rs::{CoreBPE, p50k_base};
3 | use tokenizers::tokenizer::Tokenizer;
4 |
5 | use crate::backend::{MAX_TOKEN, Token};
6 |
7 | static TOKENIZER_PATH: &str = "tokenizer.json";
8 |
9 | #[derive(Clone)]
10 | pub enum ClmTokenizer {
11 | GPT2(CoreBPE),
12 | Custom(Tokenizer),
13 | }
14 |
15 | impl ClmTokenizer {
16 | #[allow(dead_code)]
17 | fn new_gpt2() -> Self {
18 | let tokenizer = p50k_base().unwrap();
19 | ClmTokenizer::GPT2(tokenizer)
20 | }
21 |
22 | pub fn get_max_token(&self) -> Token {
23 | match self {
24 | ClmTokenizer::GPT2(_) => MAX_TOKEN as Token,
25 | ClmTokenizer::Custom(tokenizer) => (tokenizer.get_vocab_size(true) - 1) as Token,
26 | }
27 | }
28 |
29 | pub(crate) fn new_custom() -> Self {
30 | let tokenizer = Tokenizer::from_file(TOKENIZER_PATH).unwrap();
31 | ClmTokenizer::Custom(tokenizer)
32 | }
33 |
34 | pub(crate) fn encode(&self, text: &str) -> Vec<Token> {
35 | match self {
36 | ClmTokenizer::GPT2(tokenizer) => tokenizer.encode_ordinary(text).iter().map(|&x| x as Token).collect(),
37 | ClmTokenizer::Custom(tokenizer) => {
38 | let encoding = tokenizer.encode(text, false).unwrap();
39 | let ids = encoding.get_ids();
40 | ids.iter().map(|&x| x as Token).collect()
41 | }
42 | }
43 | }
44 |
45 | pub(crate) fn decode(&self, tokens: Vec<Token>) -> String {
46 | match self {
47 | ClmTokenizer::GPT2(tokenizer) => tokenizer.decode(tokens.iter().map(|&x| x as usize).collect_vec()).unwrap(),
48 | ClmTokenizer::Custom(tokenizer) => {
49 | tokenizer.decode(tokens.iter().map(|&x| x as u32).collect_vec().as_ref(), false).unwrap()
50 | }
51 | }
52 | }
53 | }
54 |
55 | #[cfg(test)]
56 | mod tests {
57 | use crate::backend::tokenizer::ClmTokenizer;
58 |
59 | #[test]
60 | fn custom_tokenizer_works() {
61 | let tokenizer = ClmTokenizer::new_custom();
62 | let encoding = tokenizer.encode("Hello, world!");
63 | let decoded = tokenizer.decode(encoding);
64 | assert_eq!(decoded, "hello, world!");
65 | }
66 |
67 | #[test]
68 | fn gpt2_tokenizer_works() {
69 | let tokenizer = ClmTokenizer::new_gpt2();
70 | let encoding = tokenizer.encode("Hello, world!");
71 | let decoded = tokenizer.decode(encoding);
72 | assert_eq!(decoded, "Hello, world!");
73 | }
74 |
75 | #[test]
76 | fn custom_tokenizer_special_characters() {
77 | let tokenizer = ClmTokenizer::new_custom();
78 | let encoding = tokenizer.encode("Hello, world! 🌍");
79 | let decoded = tokenizer.decode(encoding);
80 | assert_eq!(decoded, "hello, world! [UNK]");
81 | }
82 | }
83 |
84 |
--------------------------------------------------------------------------------
/src/model.rs:
--------------------------------------------------------------------------------
1 | use std::sync::LazyLock;
2 | use leptos::{server, ServerFnError};
3 | #[cfg(feature = "ssr")]
4 | use crate::backend::clm_model::ClmModel;
5 | #[derive(Copy, Clone)]
6 | pub enum FrontendModel {
7 | ChatCLM1_0,
8 | ChatGPT3_5,
9 | ChatGPT4o,
10 | ChatRandom,
11 | }
12 |
13 | impl FrontendModel {
14 | pub fn name(&self) -> &'static str {
15 | match self {
16 | FrontendModel::ChatCLM1_0 => "ChatCLM 0.1-pre-alpha",
17 | FrontendModel::ChatGPT3_5 => "ChatGPT 3.5",
18 | FrontendModel::ChatGPT4o => "ChatGPT 4o",
19 | FrontendModel::ChatRandom => "ChatRandom",
20 | }
21 | }
22 |
23 | pub fn from_index(index: usize) -> Self {
24 | match index {
25 | 0 => FrontendModel::ChatCLM1_0,
26 | 1 => FrontendModel::ChatGPT3_5,
27 | 2 => FrontendModel::ChatGPT4o,
28 | 3 => FrontendModel::ChatRandom,
29 | _ => panic!("Invalid model index"),
30 | }
31 | }
32 |
33 |
34 | #[cfg(feature = "ssr")]
35 | pub async fn predict_next_token(model_idx: usize, prompt: String) -> Option<String> {
36 | match Self::from_index(model_idx) {
37 | FrontendModel::ChatCLM1_0 => chat_clm_next_token(prompt).await,
38 | FrontendModel::ChatGPT4o => gpt4o_next_token(prompt).await,
39 | FrontendModel::ChatRandom => random_next_token(prompt).await,
40 | _ => random_next_token(prompt).await,
41 | }
42 | }
43 | }
44 |
45 | pub async fn random_next_token(prompt: String) -> Option<String> {
46 | // end the reply randomly (1-in-7 chance per call)
47 | let random_number = rand::random::<u32>() % 7 + 1;
48 | if random_number == 1 {
49 | None
50 | } else {
51 | Some(format!("{} next", prompt))
52 | }
53 | }
54 |
55 | #[cfg(feature = "ssr")]
56 | static CLM: LazyLock<ClmModel<'static>> = LazyLock::new(|| ClmModel::from_checkpoint("clm_model.bin"));
57 |
58 | #[cfg(feature = "ssr")]
59 | pub async fn chat_clm_next_token(prompt: String) -> Option<String> {
60 |
61 | if prompt.len() > 250 {
62 | return None;
63 | }
64 |
65 | Some(CLM.predict_next(prompt, 1, 3))
66 |
67 | }
68 |
69 | pub async fn gpt4o_next_token(_prompt: String) -> Option<String> {
70 | /*let client = Client::new();
71 |
72 | let request = CreateCompletionRequestArgs::default()
73 | .model("gpt-3.5-turbo-instruct")
74 | .prompt(prompt)
75 | .max_tokens(40_u16)
76 | .build().unwrap();
77 |
78 | let result = client.completions().create(request).await;
79 | if let Ok(response) = result {
80 | return Some(response.choices.clone()[0].clone().text);
81 | }*/
82 | None
83 | }
84 |
85 | #[server(GetNextToken, "/api")]
86 | pub async fn get_next_token(
87 | model_idx: usize,
88 | prompt: String,
89 | ) -> Result<Option<String>, ServerFnError> {