├── .clang-format ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── CMakeLists.txt ├── DEVELOPERS.md ├── LICENSE ├── README.md ├── asset ├── demo_chat.png └── demo_chat_website.png ├── examples ├── example.py ├── simple-webchat │ ├── requirements.txt │ └── streamlit_demo.py └── webchat │ ├── README.md │ ├── __init__.py │ ├── process_web.py │ ├── requirements.txt │ └── streamlit_webchat.py ├── model └── .gitignore ├── pyproject.toml ├── requirements.txt ├── scripts └── release.sh ├── setup.py ├── src ├── _pygemma │ ├── __init__.pyi │ ├── gemma_binding.cpp │ └── gemma_binding.h └── pygemma │ ├── __init__.py │ └── gemma.py └── tests └── test_gemma.py /.clang-format: -------------------------------------------------------------------------------- 1 | Language: Cpp 2 | BasedOnStyle: Google 3 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | download-model: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Download model from Kaggle 17 | run: | 18 | mkdir -p model 19 | curl -L -u ${{ secrets.KAGGLE_USERNAME }}:${{ secrets.KAGGLE_KEY }} \ 20 | -o model/model.tar.gz \ 21 | "https://www.kaggle.com/api/v1/models/google/gemma/gemmaCpp/2b-it-mqa/1/download" 22 | shell: bash 23 | - name: Upload model as artifact 24 | uses: actions/upload-artifact@v4 25 | with: 26 | name: model 27 | path: model/ 28 | 29 | tests: 30 | needs: download-model 31 | runs-on: ${{ matrix.os }} 32 | strategy: 33 | matrix: 34 | os: [ubuntu-latest, windows-latest, macos-latest] 35 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 36 | 37 | steps: 38 | - uses: actions/checkout@v3 39 | - name: Set up Python ${{ matrix.python-version }} 40 | uses: actions/setup-python@v4 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | - name: Download model artifact 44 | uses: actions/download-artifact@v4 45 | with: 46 | name: model 47 | path: model/ 48 | - name: Uncompress model files 49 | run: | 50 | tar -xzf model/model.tar.gz -C model 51 | rm model/model.tar.gz 52 | shell: bash 53 | - name: Install dependencies 54 | run: | 55 | python -m pip install --upgrade pip 56 | pip install .[test] -v 57 | - name: Test with pytest 58 | run: pytest tests/ 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Vscode 163 | .vscode/ 164 | 165 | # Project 166 | .pre-commit-config.yaml 167 | models/ 168 | fixed_wheels 169 | playground 170 | db.json 171 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11) # or whatever minimum version you require 2 | 3 | include(FetchContent) 4 | 5 | project(gemma_cpp_python) 6 | 7 | set(CMAKE_CXX_STANDARD 17) 8 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 9 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 10 | 11 | FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 7122afed5a89c082fac028ab152cc50af3e57386) 12 | FetchContent_MakeAvailable(gemma) 13 | 14 | FetchContent_Declare(pybind11 GIT_REPOSITORY https://github.com/pybind/pybind11.git GIT_TAG v2.10.4) 15 | FetchContent_MakeAvailable(pybind11) 16 | 17 | # Create the Python module 18 | pybind11_add_module(_pygemma src/_pygemma/gemma_binding.cpp) 19 | 20 | target_link_libraries(_pygemma PRIVATE libgemma) 21 | 22 | FetchContent_GetProperties(gemma) 23 | target_include_directories(_pygemma PRIVATE ${gemma_SOURCE_DIR}) 24 | -------------------------------------------------------------------------------- /DEVELOPERS.md: -------------------------------------------------------------------------------- 1 | ## 🤝 Contributing 2 | Contributions are welcome. Please clone the repository, push your changes to a new branch, and submit a pull request. 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nam D. Tran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gemma-cpp-python: Python Bindings for [gemma.cpp](https://github.com/google/gemma.cpp) 2 | 3 | **Latest Version: v0.1.3.post3** 4 | - Fixed absolute path for libsentencepiece.0.0.0.dylib 5 | - Interface changes due to updates in gemma.cpp. 6 | - Enhanced user experience for ease of use 🙏. Give it a try! 
7 | 
8 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) 
9 | 
10 | `gemma-cpp-python` provides Python bindings for [gemma.cpp](https://github.com/google/gemma.cpp), Google's lightweight C++ inference engine for the Gemma models, bringing fast local inference to Python. 
11 | 
12 | ## 🙏 Acknowledgments 
13 | Special thanks to the creators and contributors of [gemma.cpp](https://github.com/google/gemma.cpp) for their foundational work. 
14 | 
15 | ## 💬 Demo Chat and Chat with Website! 
16 | Check out the chat demo included in the `examples` directory! This interactive interface showcases how you can engage in real-time conversations with the Gemma model. 
17 | 
18 | For the Chat with Website demo, please visit the [tutorial](examples/webchat/README.md) for more details. 
19 | 
20 | 
21 | ### Using Gemma to chat with a website 
22 | ![Gemma Cpp Python Chat with Website Demo](asset/demo_chat_website.png) 
23 | 
24 | ### Chat with Gemma 
25 | ![Gemma Cpp Python Chat Demo](asset/demo_chat.png) 
26 | 
27 | 
28 | ## 🛠 Installation 
29 | `Prerequisites`: Ensure Python 3.8+ and pip are installed. 
30 | 
31 | `System requirements`: For now, it has only been tested on Unix-like platforms and macOS. Please visit the [gemma.cpp installation](https://github.com/google/gemma.cpp?tab=readme-ov-file#system-requirements) for more details. 
32 | 
33 | `Models`: pygemma currently supports the 2b-it-sfp model. To obtain the model weights and tokenizer, [please visit here](https://github.com/google/gemma.cpp?tab=readme-ov-file#step-1-obtain-model-weights-and-tokenizer-from-kaggle-or-hugging-face-hub) 
34 | 
35 | ### Install from PyPI 
36 | For a quick setup, install directly from PyPI: 
37 | ```bash 
38 | pip install pygemma==0.1.3 
39 | ``` 
40 | 
41 | ### For Developers: Install from Source 
42 | To install the latest version or for development purposes: 
43 | 
44 | 1. Clone the repo and enter the directory: 
45 | ```bash 
46 | git clone https://github.com/namtranase/gemma-cpp-python.git 
47 | cd gemma-cpp-python 
48 | ``` 
49 | 
50 | 2. Install Python dependencies and pygemma: 
51 | ```bash 
52 | pip install . 
53 | ``` 
54 | 
55 | ## 🖥 Usage 
56 | 
57 | To actually run the model, you first need to download the model weights and tokenizer as described in the [gemma.cpp](https://github.com/google/gemma.cpp?tab=readme-ov-file#step-1-obtain-model-weights-and-tokenizer-from-kaggle) repo. 
58 | 
59 | For usage examples, refer to `tests/test_gemma.py` and `examples/example.py`. Here's a quick start: 
60 | ```python 
61 | from pygemma import Gemma, ModelType, ModelTraining 
62 | 
63 | gemma = Gemma( 
64 |     tokenizer_path="/path/to/tokenizer.spm", 
65 |     compressed_weights_path="/path/to/compressed_weights.sbs", 
66 |     model_type=ModelType.Gemma2B, 
67 |     model_training=ModelTraining.GEMMA_IT, 
68 | ) 
69 | print(gemma("Write a poem")) 
70 | ```
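71 | 
72 | You can also tokenize and detokenize text directly — a minimal sketch of the `tokenize`/`detokenize` API, mirroring `tests/test_gemma.py` (`gemma` is the instance constructed above): 
73 | ```python 
74 | tokens = gemma.tokenize("Hello world!")  # prepends the BOS token by default 
75 | assert tokens[0] == gemma.bos_token 
76 | text = gemma.detokenize(tokens)  # back to "Hello world!" 
77 | tokens_no_bos = gemma.tokenize("Hello world!", add_bos=False) 
78 | ```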
79 | 
80 | To run the demo on your local machine: 
81 | ```bash 
82 | cd gemma-cpp-python/examples/simple-webchat 
83 | pip install -r requirements.txt 
84 | streamlit run streamlit_demo.py 
85 | ``` 
86 | 
87 | ## 🤝 Contributing 
88 | Contributions are welcome. Please clone the repository, push your changes to a new branch, and submit a pull request. 
89 | 
90 | ## License 
91 | gemma-cpp-python is MIT licensed. See the LICENSE file for details. 
92 | 
--------------------------------------------------------------------------------
/asset/demo_chat.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/namtranase/gemma-cpp-python/9164d4e4a57d34e41c7ad2296175e00e5fee01ec/asset/demo_chat.png
--------------------------------------------------------------------------------
/asset/demo_chat_website.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/namtranase/gemma-cpp-python/9164d4e4a57d34e41c7ad2296175e00e5fee01ec/asset/demo_chat_website.png
--------------------------------------------------------------------------------
/examples/example.py:
--------------------------------------------------------------------------------
1 | from pygemma import Gemma, ModelType, ModelTraining 
2 | from time import time 
3 | 
4 | TOKENIZER_PATH = "../model/tokenizer.spm" 
5 | COMPRESSED_WEIGHTS_PATH = "../model/2b-it-mqa.sbs" 
6 | MODEL_TYPE = ModelType.Gemma2B 
7 | MODEL_TRAINING = ModelTraining.GEMMA_IT 
8 | 
9 | 
10 | def main(): 
11 |     gemma = Gemma( 
12 |         tokenizer_path=TOKENIZER_PATH, 
13 |         compressed_weights_path=COMPRESSED_WEIGHTS_PATH, 
14 |         model_type=MODEL_TYPE, 
15 |         model_training=MODEL_TRAINING, 
16 |     ) 
17 | 
18 |     start_time = time() 
19 |     res = gemma("Hello world!") 
20 |     print(f"Generated: {res}") 
21 | 
22 |     print(f"Elapsed time: {time() - start_time:.2f}s") 
23 | 
24 | 
25 | if __name__ == "__main__": 
26 |     main() 
27 | 
--------------------------------------------------------------------------------
/examples/simple-webchat/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit 
2 | 
--------------------------------------------------------------------------------
/examples/simple-webchat/streamlit_demo.py:
--------------------------------------------------------------------------------
1 | import streamlit as st 
2 | from pygemma import Gemma, ModelType, ModelTraining 
3 | 
4 | st.set_page_config(page_title="Gemma 💬") 
5 | st.title("Gemma Cpp Python Chat Demo 🎈") 
6 | 
7 | # Initialize session state for the model and its load status 
8 | if "model_loaded" not in st.session_state: 
9 |     st.session_state["model_loaded"] = False 
10 |     st.session_state["gemma"] = None 
11 |     st.session_state["messages"] = [ 
12 |         {"role": "assistant", "content": "How may I help you?"} 
13 |     ] 
14 | 
15 | 
16 | @st.cache_resource 
17 | def load_gemma_model(tokenizer_path, weights_path, model_type): 
18 |     # The binding takes enums, while the sidebar provides a string such as 
19 |     # "2b-it"; derive the model size and training variant from that string. 
20 |     gemma = Gemma( 
21 |         tokenizer_path=tokenizer_path, 
22 |         compressed_weights_path=weights_path, 
23 |         model_type=ModelType.Gemma7B if model_type.startswith("7b") else ModelType.Gemma2B, 
24 |         model_training=ModelTraining.GEMMA_IT if model_type.endswith("it") else ModelTraining.GEMMA_PT, 
25 |     ) 
26 |     return gemma
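27 | 
28 | # Note: st.cache_resource keeps the Gemma instance above alive across 
29 | # Streamlit reruns, so the large weights are loaded once rather than on 
30 | # every user interaction.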
Click "Load Model" to load the model.') 47 | 48 | st.markdown( 49 | "📖 Check the detail at [gemma-cpp-python](https://github.com/namtranase/gemma-cpp-python)!" 50 | ) 51 | 52 | # Store LLM generated responses 53 | if "messages" not in st.session_state: 54 | st.session_state.messages = [ 55 | {"role": "assistant", "content": "How may I help you?"} 56 | ] 57 | 58 | # Store LLM generated responses 59 | if "messages" not in st.session_state.keys(): 60 | st.session_state.messages = [ 61 | {"role": "assistant", "content": "How may I help you?"} 62 | ] 63 | 64 | # Display chat messages 65 | for message in st.session_state.messages: 66 | with st.chat_message(message["role"]): 67 | st.write(message["content"]) 68 | 69 | # Function for generating LLM response 70 | def generate_response(prompt_input): 71 | # Hugging Face Login 72 | if st.session_state.model_loaded: 73 | return st.session_state.gemma.completion(prompt_input) 74 | else: 75 | return "Please load the model first." 76 | 77 | 78 | # User-provided prompt 79 | if prompt := st.chat_input(disabled=not (st.session_state["model_loaded"])): 80 | st.session_state.messages.append({"role": "user", "content": prompt}) 81 | with st.chat_message("user"): 82 | st.write(prompt) 83 | 84 | # Generate a new response if last message is not from assistant 85 | if st.session_state.messages[-1]["role"] != "assistant": 86 | with st.chat_message("assistant"): 87 | with st.spinner("Thinking..."): 88 | response = generate_response(prompt) 89 | st.write(response) 90 | message = {"role": "assistant", "content": response} 91 | st.session_state.messages.append(message) 92 | -------------------------------------------------------------------------------- /examples/webchat/README.md: -------------------------------------------------------------------------------- 1 | ## 🌐 Chat with Website Feature 2 | 3 | Gemma branches out into the web with our `Chat with Website` feature. Strike up a conversation with any website and let Gemma extract the essence for a delightful chat experience. 4 | 5 | Special thank to the authors of repo: [scrapeGPT](https://github.com/LexiestLeszek/scrapeGPT), we based on scrapeGPT to build our demo! 6 | 7 | ### Quick Start 8 | 9 | 1. Launch the Gemma Chat Demo. 10 | 2. Plug in the tokenizer and weights paths, set the model type in the sidebar, and hit 'Load Model'. 11 | 3. Navigate to 'Website Processing', input a website URL, and press 'Process Website'. 12 | 4. Enjoy the interactive session as Gemma digests web content for a smart chat. 13 | 14 | Dive into a seamless blend of AI interaction and web content with just a few clicks! 
15 | 16 | ```bash 17 | # To get started right away: 18 | cd examples/webchat 19 | pip install -r requirements.txt 20 | streamlit run streamlit_webchat.py 21 | ``` 22 | 23 | 24 | ![Chat with Website Demo](../../asset/demo_chat_website.png) 25 | -------------------------------------------------------------------------------- /examples/webchat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/namtranase/gemma-cpp-python/9164d4e4a57d34e41c7ad2296175e00e5fee01ec/examples/webchat/__init__.py -------------------------------------------------------------------------------- /examples/webchat/process_web.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from datetime import datetime 4 | from io import BytesIO 5 | from urllib.parse import urljoin, urlparse 6 | 7 | import requests 8 | from bs4 import BeautifulSoup 9 | from fp.fp import FreeProxy 10 | from langchain.text_splitter import RecursiveCharacterTextSplitter 11 | from langchain_community.embeddings import HuggingFaceEmbeddings 12 | from langchain_community.vectorstores import Chroma 13 | from PyPDF2 import PdfReader 14 | 15 | # Proxy init 16 | def get_proxy(): 17 | print("Starting proxy ...") 18 | proxy_url = FreeProxy( 19 | country_id=[ 20 | "US", 21 | "CA", 22 | "FR", 23 | "NZ", 24 | "SE", 25 | "PT", 26 | "CZ", 27 | "NL", 28 | "ES", 29 | "SK", 30 | "UK", 31 | "PL", 32 | "IT", 33 | "DE", 34 | "AT", 35 | "JP", 36 | ], 37 | https=True, 38 | rand=True, 39 | timeout=3, 40 | ).get() 41 | proxy_obj = {"server": proxy_url, "username": "", "password": ""} 42 | 43 | print(f"Proxy generated: {proxy_url}") 44 | 45 | return proxy_obj 46 | 47 | 48 | def save_to_db(text, url): 49 | timestamp = datetime.now().isoformat() 50 | # Load existing data from db.json 51 | try: 52 | with open("db.json", "r") as f: 53 | data = json.load(f) 54 | except FileNotFoundError: 55 | data = [] 56 | 57 | # Create a new entry with the domain name as key 58 | website = {"date": timestamp, "text": text} 59 | new_entry = {"start_url": url, "data": website} 60 | 61 | # Append new entry to the data list 62 | data.append(new_entry) 63 | 64 | # Write data back to db.json 65 | with open("db.json", "w") as f: 66 | json.dump(data, f, indent=4) 67 | 68 | 69 | def scrape_webpages(urls, proxy): 70 | print("Scraping text from webpages from each of the links ...") 71 | scraped_texts = [] 72 | for url in urls: 73 | try: 74 | if url.endswith(".pdf"): 75 | response = requests.get(url, proxies=proxy) 76 | reader = PdfReader(BytesIO(response.content)) 77 | number_of_pages = len(reader.pages) 78 | 79 | for p in range(number_of_pages): 80 | 81 | page = reader.pages[p] 82 | text = page.extract_text() 83 | scraped_texts.append(text) 84 | else: 85 | page = requests.get(url, proxies=proxy) 86 | soup = BeautifulSoup(page.content, "html.parser") 87 | text = " ".join([p.get_text() for p in soup.find_all("p")]) 88 | scraped_texts.append(text) 89 | 90 | except Exception as e: 91 | print(f"Failed to scrape {url}: {e}") 92 | 93 | all_scraped_text = "\n".join(scraped_texts) 94 | print("Finished scraping the text from webpages!") 95 | return all_scraped_text 96 | 97 | 98 | def get_domain(url): 99 | return urlparse(url).netloc 100 | 101 | 102 | def get_robots_file(url, proxy): 103 | robots_url = urljoin(url, "/robots.txt") 104 | try: 105 | response = requests.get(robots_url, proxies=proxy) 106 | return response.text 107 | except Exception as e: 108 | print(f"Error 
fetching robots.txt: {e}") 
109 |         return None 
110 | 
111 | 
112 | def parse_robots(content): 
113 |     # This function assumes simple rules without wildcards, comments, etc. 
114 |     # For a full parser, consider using a library like urllib.robotparser. 
115 |     disallowed = [] 
116 |     for line in (content or "").splitlines():  # get_robots_file may return None 
117 |         if line.startswith("Disallow:"): 
118 |             path = line[len("Disallow:") :].strip() 
119 |             disallowed.append(path) 
120 |     return disallowed 
121 | 
122 | 
123 | def is_allowed(url, disallowed_paths, base_domain): 
124 |     parsed_url = urlparse(url) 
125 |     if parsed_url.netloc != base_domain: 
126 |         return False 
127 |     for path in disallowed_paths: 
128 |         if parsed_url.path.startswith(path): 
129 |             return False 
130 |     return True 
131 | 
132 | 
133 | def scrape_site_links(url, proxy): 
134 |     visited_links = set() 
135 |     not_visited_links = set() 
136 |     to_visit = [url] 
137 |     base_domain = get_domain(url) 
138 |     disallowed_paths = parse_robots(get_robots_file(url, proxy)) 
139 |     last_found_time = time.time()  # Track the last time a link was found 
140 | 
141 |     while to_visit: 
142 |         # Break the loop if 15 seconds have passed without finding a new link 
143 |         if time.time() - last_found_time > 15: 
144 |             print("FINISHED scraping the links") 
145 |             break 
146 | 
147 |         current_url = to_visit.pop(0) 
148 |         if current_url not in visited_links and is_allowed( 
149 |             current_url, disallowed_paths, base_domain 
150 |         ): 
151 |             visited_links.add(current_url) 
152 |             try: 
153 |                 print(f"{current_url}") 
154 |                 response = requests.get(current_url, proxies=proxy) 
155 |                 soup = BeautifulSoup(response.text, "html.parser") 
156 |                 for link in soup.find_all("a", href=True): 
157 |                     new_url = urljoin(current_url, link["href"]) 
158 |                     if new_url not in visited_links: 
159 |                         to_visit.append(new_url) 
160 |                         last_found_time = time.time()  # Update the last found time 
161 |             except Exception as e: 
162 |                 print(f" !!! COULD NOT VISIT: {current_url}: {e}")
163 |                 not_visited_links.add(current_url) 
164 | 
165 |     return visited_links 
166 | 
167 | 
168 | class WebProcesser: 
169 |     def __init__(self) -> None: 
170 |         self.chunk_size = 500 
171 |         self.chunk_overlap = 100 
172 |         self.text_splitter = RecursiveCharacterTextSplitter( 
173 |             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap 
174 |         ) 
175 |         self.embedding = HuggingFaceEmbeddings( 
176 |             model_name="sentence-transformers/all-MiniLM-L6-v2" 
177 |         ) 
178 |         self.db = None 
179 |         self.retriever = None 
180 |         self.db_path = "db.json" 
181 |         db_file = json.dumps([]) 
182 |         with open(self.db_path, "w") as outfile: 
183 |             outfile.write(db_file) 
184 | 
185 |     def init_db_website(self, url): 
186 |         web_text = "" 
187 |         try: 
188 |             with open(self.db_path, "r") as f: 
189 |                 data = json.load(f) 
190 |             for entry in data: 
191 |                 if ( 
192 |                     url in entry["start_url"] 
193 |                 ):  # TODO: only reuse entries scraped today, not older ones 
194 |                     print("Website is already scraped today!") 
195 |                     web_text = entry["data"]["text"] 
196 |         except FileNotFoundError: 
197 |             data = [] 
198 |         # Scrape only if the website is not already in the db 
199 |         if not web_text: 
200 |             proxy = get_proxy() 
201 |             # Scrape all the links from the given start URL using the proxy 
202 |             all_links = scrape_site_links(url, proxy) 
203 | 
204 |             # Scrape the content from all the links obtained, using the proxy 
205 |             web_text = scrape_webpages(all_links, proxy) 
206 |             save_to_db(web_text, url) 
207 | 
208 |         documents = self.text_splitter.split_text(str(web_text)) 
209 |         self.db = Chroma.from_texts(documents, embedding=self.embedding) 
210 |         self.retriever = self.db.as_retriever(search_kwargs={"k": 3}) 
211 |         return True 
212 | 
213 |     def get_context(self, question): 
214 |         """Retrieve the chunks most relevant to the question.""" 
215 |         print("Embedding model started ...") 
216 |         context = self.retriever.get_relevant_documents(question) 
217 |         print(f"Embedding model returned: {context}") 
218 | 
219 |         return context 
220 | 
--------------------------------------------------------------------------------
/examples/webchat/requirements.txt:
--------------------------------------------------------------------------------
1 | aiogram==2.22.1 
2 | beautifulsoup4==4.11.1 
3 | free_proxy==1.1.1 
4 | langchain==0.1.6 
5 | langchain_community==0.0.19 
6 | PyPDF2==3.0.1 
7 | sentence-transformers==2.6.0 
8 | chromadb==0.4.24 
9 | 
--------------------------------------------------------------------------------
/examples/webchat/streamlit_webchat.py:
--------------------------------------------------------------------------------
1 | import streamlit as st 
2 | from pygemma import Gemma, ModelType, ModelTraining 
3 | from process_web import WebProcesser 
4 | 
5 | 
6 | st.set_page_config(page_title="Chat with Website 💬") 
7 | st.title("Gemma Cpp Python Chat with Website Demo 🎈") 
8 | 
9 | # Initialize session state for the model and its load status 
10 | if "model_loaded" not in st.session_state: 
11 |     st.session_state["model_loaded"] = False 
12 |     st.session_state["website_loaded"] = False 
13 |     st.session_state["gemma"] = None 
14 |     st.session_state["web_processor"] = None 
15 |     st.session_state["messages"] = [ 
16 |         {"role": "assistant", "content": "How may I help you?"} 
17 |     ] 
18 | 
19 | 
20 | @st.cache_resource 
21 | def load_gemma_model(tokenizer_path, weights_path, model_type): 
22 |     # The binding takes enums, while the sidebar provides a string such as 
23 |     # "2b-it"; derive the model size and training variant from that string. 
24 |     gemma = Gemma( 
25 |         tokenizer_path=tokenizer_path, 
26 |         compressed_weights_path=weights_path, 
27 |         model_type=ModelType.Gemma7B if model_type.startswith("7b") else ModelType.Gemma2B, 
28 |         model_training=ModelTraining.GEMMA_IT if model_type.endswith("it") else ModelTraining.GEMMA_PT, 
29 |     ) 
30 |     return gemma 
31 | 
32 | 
33 | @st.cache_resource 
34 | def load_web_processer(): 
35 |     web_processor = WebProcesser() 
36 |     return web_processor
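37 | 
38 | 
39 | # Note: the app below implements a simple retrieval-augmented generation flow: 
40 | #   1. WebProcesser scrapes the site's pages and indexes text chunks in Chroma. 
41 | #   2. get_context() retrieves the top-3 chunks relevant to the user's question. 
42 | #   3. generate_response() stuffs those chunks into the prompt sent to Gemma.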
43 | 
44 | # Sidebar for model configuration 
45 | with st.sidebar: 
46 |     st.title("Gemma Config") 
47 |     tokenizer_path = st.text_input( 
48 |         "Tokenizer path", value="tokenizer.spm", placeholder="tokenizer.spm" 
49 |     ) 
50 |     weights_path = st.text_input( 
51 |         "Compressed weights path", value="2b-it-sfp.sbs", placeholder="2b-it-sfp.sbs" 
52 |     ) 
53 |     model_type = st.text_input("Model type", value="2b-it", placeholder="2b-it") 
54 | 
55 |     # Load model button in the sidebar 
56 |     if st.button("Load Model"): 
57 |         st.session_state["gemma"] = load_gemma_model( 
58 |             tokenizer_path, weights_path, model_type 
59 |         ) 
60 |         st.session_state["model_loaded"] = True 
61 | 
62 |     # Indicate whether the model is loaded 
63 |     if st.session_state["model_loaded"]: 
64 |         st.sidebar.success("Model Loaded Successfully!") 
65 |     else: 
66 |         st.sidebar.warning('Model Not Loaded. Click "Load Model" to load the model.') 
67 | 
68 |     st.markdown("## Website Processing") 
69 |     website_url = st.text_input( 
70 |         "Website URL", 
71 |         value="https://namtranase.github.io/terminalmind/", 
72 |         placeholder="Enter website URL", 
73 |     ) 
74 | 
75 |     if st.button("Process Website"): 
76 |         if website_url: 
77 |             st.session_state["web_processor"] = load_web_processer() 
78 |             # Scrape the site and build the vector index for retrieval 
79 |             st.session_state["web_processor"].init_db_website(website_url) 
80 |             st.session_state["website_loaded"] = True 
81 |             st.sidebar.success("Website processed successfully, now you can ask!") 
82 |         else: 
83 |             st.sidebar.error("Please enter a valid website URL.") 
84 | 
85 |     st.markdown( 
86 |         "📖 Check the detail at [gemma-cpp-python](https://github.com/namtranase/gemma-cpp-python)!" 
87 |     ) 
88 | 
89 | # Store LLM generated responses 
90 | if "messages" not in st.session_state: 
91 |     st.session_state.messages = [ 
92 |         {"role": "assistant", "content": "How may I help you?"} 
93 |     ] 
94 | 
95 | # Display chat messages 
96 | for message in st.session_state.messages: 
97 |     with st.chat_message(message["role"]): 
98 |         st.write(message["content"]) 
99 | 
100 | # Function for generating LLM response 
101 | def generate_response(prompt_input): 
102 |     if st.session_state.model_loaded and st.session_state.website_loaded: 
103 |         context = st.session_state["web_processor"].get_context(prompt_input) 
104 |         prompt = f"""Use the following pieces of context to answer the question at the end. 
105 | Context: {context}.\n 
106 | Question: {prompt_input} 
107 | Helpful Answer:""" 
108 |         return st.session_state.gemma(prompt) 
109 |     else: 
110 |         return "Please load the model and process a website first."
111 | 
112 | 
113 | # User-provided prompt 
114 | if prompt := st.chat_input(disabled=not (st.session_state["model_loaded"])): 
115 |     st.session_state.messages.append({"role": "user", "content": prompt}) 
116 |     with st.chat_message("user"): 
117 |         st.write(prompt) 
118 | 
119 | # Generate a new response if last message is not from assistant 
120 | if st.session_state.messages[-1]["role"] != "assistant": 
121 |     with st.chat_message("assistant"): 
122 |         with st.spinner("Thinking..."): 
123 |             response = generate_response(prompt) 
124 |             st.write(response) 
125 |     message = {"role": "assistant", "content": response} 
126 |     st.session_state.messages.append(message) 
127 | 
--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/namtranase/gemma-cpp-python/9164d4e4a57d34e41c7ad2296175e00e5fee01ec/model/.gitignore
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system] 
2 | requires = ["setuptools", "wheel", "cmake"] 
3 | build-backend = "setuptools.build_meta" 
4 | 
5 | [project] 
6 | name = "pygemma" 
7 | version = "0.1.3.post3" 
8 | authors = [ 
9 |     {name = "Nam Tran", email = "namtran.ase@gmail.com"}, 
10 | ] 
11 | description = "Python bindings for the gemma.cpp library" 
12 | readme = "README.md" 
13 | license = { text = "MIT" } 
14 | classifiers = [ 
15 |     "Programming Language :: Python :: 3", 
16 |     "License :: OSI Approved :: MIT License", 
17 |     "Operating System :: OS Independent", 
18 | ] 
19 | requires-python = ">=3.8" 
20 | 
21 | [project.optional-dependencies] 
22 | test = [ 
23 |     "pytest>=8.1.1" 
24 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/namtranase/gemma-cpp-python/9164d4e4a57d34e41c7ad2296175e00e5fee01ec/requirements.txt
--------------------------------------------------------------------------------
/scripts/release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash 
2 | 
3 | # Remove previous distribution files 
4 | rm -rf dist/* 
5 | 
6 | # Build the distribution 
7 | python3 setup.py sdist bdist_wheel 
8 | 
9 | # Check the operating system 
10 | OS="$(uname)" 
11 | if [ "$OS" = "Darwin" ]; then 
12 |     # macOS specific commands 
13 |     delocate-wheel -w fixed_wheels -v dist/*.whl 
14 |     mv fixed_wheels/*.whl dist/ 
15 | elif [ "$OS" = "Linux" ]; then 
16 |     # Linux specific commands (if any can be added here) 
17 |     echo "Linux OS detected. No additional steps required for Linux."
18 | fi 19 | 20 | # Upload the distribution to PyPI 21 | twine upload dist/* -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import subprocess 4 | import sys 5 | 6 | from setuptools import Extension, setup 7 | from setuptools.command.build_ext import build_ext 8 | 9 | 10 | class CMakeExtension(Extension): 11 | def __init__(self, name, sourcedir=""): 12 | Extension.__init__(self, name, sources=[]) 13 | self.sourcedir = os.path.abspath(sourcedir) 14 | 15 | 16 | class CMakeBuild(build_ext): 17 | def run(self): 18 | try: 19 | out = subprocess.check_output(["cmake", "--version"]) 20 | except OSError: 21 | raise RuntimeError( 22 | "CMake must be installed to build the following extensions: " 23 | + ", ".join(e.name for e in self.extensions) 24 | ) 25 | 26 | for ext in self.extensions: 27 | self.build_extension(ext) 28 | 29 | def build_extension(self, ext): 30 | extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) 31 | os.makedirs(self.build_temp, exist_ok=True) 32 | 33 | cmake_args = [ 34 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", 35 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE={extdir}", 36 | f"-DPYTHON_EXECUTABLE={sys.executable}", 37 | ] 38 | 39 | # Allow gemma.cpp to be built on Windows with ClangCL 40 | # Refer to https://github.com/google/gemma.cpp/pull/6 41 | if platform.system() == "Windows": 42 | cmake_args += ["-T", "ClangCL"] 43 | 44 | cfg = "Debug" if self.debug else "Release" 45 | build_args = ["--config", cfg] 46 | 47 | if platform.system() == "Windows": 48 | build_args += ["--", "/m:12"] 49 | else: 50 | build_args += ["--", "-j12"] 51 | 52 | subprocess.check_call( 53 | ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp 54 | ) 55 | subprocess.check_call( 56 | ["cmake", "--build", ".", "--target", ext.name] + build_args, 57 | cwd=self.build_temp, 58 | ) 59 | 60 | 61 | setup( 62 | ext_modules=[CMakeExtension("_pygemma")], 63 | cmdclass=dict(build_ext=CMakeBuild), 64 | ) 65 | -------------------------------------------------------------------------------- /src/_pygemma/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | class GemmaModel: 4 | def __init__( 5 | self, 6 | tokenizer_path: str, 7 | compressed_weights_path: str, 8 | model_type: int, 9 | model_training: int, 10 | n_threads: int, 11 | ) -> None: 12 | pass 13 | 14 | @property 15 | def bos_token(self) -> int: ... 16 | @property 17 | def eos_token(self) -> int: ... 18 | def generate( 19 | self, 20 | prompt: str, 21 | max_tokens: int, 22 | max_generated_tokens: int, 23 | temperature: float, 24 | seed: int, 25 | verbosity: int, 26 | ) -> str: ... 27 | def tokenize( 28 | self, 29 | text: str, 30 | add_bos: bool = True, 31 | ) -> List[int]: ... 32 | def detokenize( 33 | self, 34 | tokens: List[int], 35 | ) -> str: ... 
36 | 
--------------------------------------------------------------------------------
/src/_pygemma/gemma_binding.cpp:
--------------------------------------------------------------------------------
1 | #include "gemma_binding.h" 
2 | 
3 | #include <pybind11/pybind11.h> 
4 | #include <pybind11/stl.h> 
5 | 
6 | #include <random> 
7 | 
8 | #include "gemma.h"  // Gemma 
9 | #include "util/app.h" 
10 | #include "util/args.h" 
11 | 
12 | namespace py = pybind11; 
13 | 
14 | GemmaModel::GemmaModel(const char *tokenizer_path_str, 
15 |                        const char *compressed_weights_path_str, 
16 |                        int model_type_id, int training_id, int num_threads) { 
17 |   const gcpp::Path tokenizer_path = gcpp::Path{tokenizer_path_str}; 
18 |   const gcpp::Path compressed_weights_path = 
19 |       gcpp::Path{compressed_weights_path_str}; 
20 |   const gcpp::Path weights_path = gcpp::Path{""}; 
21 | 
22 |   this->model_type = static_cast<gcpp::Model>(model_type_id); 
23 |   this->model_training = static_cast<gcpp::ModelTraining>(training_id); 
24 |   this->num_threads = static_cast<size_t>(num_threads); 
25 | 
26 |   pool = new hwy::ThreadPool(num_threads); 
27 | 
28 |   kv_cache = CreateKVCache(model_type); 
29 | 
30 |   // For many-core, pinning threads to cores helps. 
31 |   if (this->num_threads > 10) { 
32 |     gcpp::PinThreadToCore(this->num_threads - 1);  // Main thread 
33 | 
34 |     pool->Run(0, pool->NumThreads(), [](uint64_t /*task*/, size_t thread) { 
35 |       gcpp::PinThreadToCore(thread); 
36 |     }); 
37 |   } 
38 | 
39 |   model = new gcpp::Gemma(tokenizer_path, compressed_weights_path, weights_path, 
40 |                           model_type, model_training, *pool); 
41 | } 
42 | 
43 | GemmaModel::~GemmaModel() { delete model; } 
44 | 
45 | int GemmaModel::get_bos_token() const { return bos_token; } 
46 | 
47 | int GemmaModel::get_eos_token() const { return eos_token; } 
48 | 
49 | std::vector<int> GemmaModel::tokenize(const std::string &text, 
50 |                                       const bool add_bos) { 
51 |   std::vector<int> tokens; 
52 | 
53 |   if (!model->Tokenizer()->Encode(text, &tokens).ok()) { 
54 |     throw std::runtime_error("Tokenization failed"); 
55 |   } 
56 | 
57 |   if (add_bos) { 
58 |     tokens.insert(tokens.begin(), bos_token); 
59 |   } 
60 | 
61 |   return tokens; 
62 | } 
63 | 
64 | std::string GemmaModel::detokenize(const std::vector<int> &tokens) { 
65 |   std::string text; 
66 |   if (!model->Tokenizer()->Decode(tokens, &text).ok()) { 
67 |     throw std::runtime_error("Detokenization failed"); 
68 |   } 
69 |   return text; 
70 | } 
71 | 
72 | std::string GemmaModel::generate(const std::string &prompt_string, 
73 |                                  size_t max_tokens, size_t max_generated_tokens, 
74 |                                  float temperature, uint_fast32_t seed, 
75 |                                  int verbosity) { 
76 |   size_t pos = 0;  // KV Cache position 
77 | 
78 |   // Initialize random number generator 
79 |   std::mt19937 gen; 
80 |   gen.seed(seed); 
81 | 
82 |   const std::string formatted = [&]() { 
83 |     if (model_training == gcpp::ModelTraining::GEMMA_IT) { 
84 |       return "<start_of_turn>user\n" + prompt_string + 
85 |              "<end_of_turn>\n<start_of_turn>model\n"; 
86 |     } 
87 |     return prompt_string; 
88 |   }(); 
89 | 
90 |   std::vector<int> tokens = tokenize(formatted, true); 
91 |   size_t ntokens = tokens.size(); 
92 | 
93 |   std::string completion; 
94 | 
95 |   // This callback function gets invoked every time a token is generated 
96 |   const gcpp::StreamFunc stream_token = [this, &pos, &gen, &ntokens, 
97 |                                          tokenizer = model->Tokenizer(), 
98 |                                          &completion](int token, float) { 
99 |     ++pos; 
100 |     if (pos < ntokens) { 
101 |       // Still consuming the prompt tokens; nothing to emit yet 
102 |     } else if (token != this->eos_token) { 
103 |       std::string token_text; 
104 |       HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok()); 
105 |       completion += token_text; 
106 |     } 
107 |     return true; 
108 |   }; 
109 | 
110 |   gcpp::GenerateGemma(*model, 
111 |                       {.max_tokens = max_tokens, 
112 |                        .max_generated_tokens = max_generated_tokens, 
113 |                        .temperature = temperature, 
114 |                        .verbosity = verbosity}, 
115 |                       tokens, /*KV cache position = */ 0, kv_cache, *pool, 
116 |                       stream_token, gen); 
117 | 
118 |   return completion; 
119 | }
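120 | 
121 | // The pybind11 module below exposes GemmaModel to Python as 
122 | // `_pygemma.GemmaModel`. The py::arg(...) names give the bound methods 
123 | // Python keyword arguments; src/pygemma/gemma.py wraps this class with a 
124 | // friendlier, enum-typed interface.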
125 | 
126 | PYBIND11_MODULE(_pygemma, m) { 
127 |   m.doc() = "Python binding for gemma.cpp"; 
128 | 
129 |   py::class_<GemmaModel>(m, "GemmaModel") 
130 |       .def_property_readonly("bos_token", &GemmaModel::get_bos_token, 
131 |                              "Get the BOS token") 
132 |       .def_property_readonly("eos_token", &GemmaModel::get_eos_token, 
133 |                              "Get the EOS token") 
134 |       .def(py::init<const char *, const char *, int, int, int>(), 
135 |            py::arg("tokenizer_path"), py::arg("compressed_weights_path"), 
136 |            py::arg("model_type"), py::arg("model_training"), 
137 |            py::arg("num_threads"), "Initialize the Gemma model") 
138 |       .def("tokenize", &GemmaModel::tokenize, py::arg("text"), 
139 |            py::arg("add_bos"), 
140 |            "Tokenize the input text and return the tokenized text") 
141 |       .def("detokenize", &GemmaModel::detokenize, py::arg("tokens"), 
142 |            "Detokenize the input tokens and return the detokenized text") 
143 |       .def("generate", &GemmaModel::generate, py::arg("prompt"), 
144 |            py::arg("max_tokens"), py::arg("max_generated_tokens"), 
145 |            py::arg("temperature"), py::arg("seed"), py::arg("verbosity"), 
146 |            "Generate text based on the input prompt"); 
147 | }
--------------------------------------------------------------------------------
/src/_pygemma/gemma_binding.h:
--------------------------------------------------------------------------------
1 | #pragma once 
2 | 
3 | #include "gemma.h"  // Gemma 
4 | 
5 | class GemmaModel { 
6 |  private: 
7 |   gcpp::Gemma *model; 
8 |   gcpp::Model model_type; 
9 |   gcpp::ModelTraining model_training; 
10 | 
11 |   size_t num_threads; 
12 |   hwy::ThreadPool *pool; 
13 |   gcpp::KVCache kv_cache; 
14 | 
15 |   const int eos_token = 1; 
16 |   const int bos_token = 2; 
17 | 
18 |  public: 
19 |   GemmaModel(const char *tokenizer_path_str, 
20 |              const char *compressed_weights_path_str, int model_type_id, 
21 |              int training_id, int num_threads); 
22 | 
23 |   ~GemmaModel(); 
24 | 
25 |   int get_bos_token() const; 
26 | 
27 |   int get_eos_token() const; 
28 | 
29 |   std::vector<int> tokenize(const std::string &text, const bool add_bos = true); 
30 | 
31 |   std::string detokenize(const std::vector<int> &tokens); 
32 | 
33 |   std::string generate(const std::string &prompt, size_t max_tokens, 
34 |                        size_t max_generated_tokens, float temperature, 
35 |                        uint_fast32_t seed, int verbosity); 
36 | };
--------------------------------------------------------------------------------
/src/pygemma/__init__.py:
--------------------------------------------------------------------------------
1 | from .gemma import * 
2 | 
--------------------------------------------------------------------------------
/src/pygemma/gemma.py:
--------------------------------------------------------------------------------
1 | import multiprocessing 
2 | import os 
3 | from enum import Enum 
4 | import random 
5 | from typing import List, Optional 
6 | 
7 | import _pygemma 
8 | 
9 | 
10 | class ModelTraining(Enum): 
11 |     GEMMA_IT = 0 
12 |     GEMMA_PT = 1 
13 | 
14 | 
15 | class ModelType(Enum): 
16 |     Gemma2B = 0 
17 |     Gemma7B = 1 
18 | 
19 | 
20 | class Gemma: 
21 |     def __init__( 
22 |         self, 
23 |         *, 
24 |         tokenizer_path: str, 
25 |         compressed_weights_path: str, 
26 |         model_type: ModelType, 
27 |         model_training: ModelTraining, 
28 |         n_threads: Optional[int] = None, 
29 |     ): 
30 |         self.tokenizer_path = tokenizer_path 
31 |         self.compressed_weights_path = compressed_weights_path 
32 |         self.model_type = model_type 
33 |         self.model_training = model_training 
34 | 
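35 |         # Default to all but two CPU cores (at least one) so the host stays 
36 |         # responsive while gemma.cpp uses the remaining threads.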
37 |         self.n_threads = n_threads or max(multiprocessing.cpu_count() - 2, 1) 
38 | 
39 |         if not os.path.exists(self.tokenizer_path): 
40 |             raise FileNotFoundError(f"Tokenizer not found: {self.tokenizer_path}") 
41 | 
42 |         if not os.path.exists(self.compressed_weights_path): 
43 |             raise FileNotFoundError( 
44 |                 f"Compressed weights not found: {self.compressed_weights_path}" 
45 |             ) 
46 | 
47 |         self.model = _pygemma.GemmaModel( 
48 |             self.tokenizer_path, 
49 |             self.compressed_weights_path, 
50 |             self.model_type.value, 
51 |             self.model_training.value, 
52 |             self.n_threads, 
53 |         ) 
54 | 
55 |         assert self.model 
56 | 
57 |     @property 
58 |     def bos_token(self) -> int: 
59 |         assert self.model 
60 |         return self.model.bos_token 
61 | 
62 |     @property 
63 |     def eos_token(self) -> int: 
64 |         assert self.model 
65 |         return self.model.eos_token 
66 | 
67 |     def __call__( 
68 |         self, 
69 |         prompt: str, 
70 |         *, 
71 |         max_tokens: int = 2048, 
72 |         max_generated_tokens: int = 1024, 
73 |         temperature: float = 1.0, 
74 |         seed: Optional[int] = None, 
75 |         verbosity: int = 0, 
76 |     ) -> str: 
77 |         assert self.model 
78 | 
79 |         # Draw a random seed only when none was given; an explicit 0 is kept. 
80 |         seed = random.randint(0, 2**32 - 1) if seed is None else seed 
81 | 
82 |         return self.model.generate( 
83 |             prompt, 
84 |             max_tokens, 
85 |             max_generated_tokens, 
86 |             temperature, 
87 |             seed, 
88 |             verbosity, 
89 |         ) 
90 | 
91 |     def tokenize( 
92 |         self, 
93 |         text: str, 
94 |         add_bos: bool = True, 
95 |     ) -> List[int]: 
96 |         assert self.model 
97 |         return self.model.tokenize(text, add_bos) 
98 | 
99 |     def detokenize( 
100 |         self, 
101 |         tokens: List[int], 
102 |     ) -> str: 
103 |         assert self.model 
104 |         return self.model.detokenize(tokens) 
105 | 
--------------------------------------------------------------------------------
/tests/test_gemma.py:
--------------------------------------------------------------------------------
1 | import os 
2 | from pygemma import Gemma, ModelType, ModelTraining 
3 | 
4 | # Get the directory that this file is in 
5 | dir_path = os.path.dirname(os.path.realpath(__file__)) 
6 | 
7 | TOKENIZER_PATH = os.path.join(dir_path, "../model/tokenizer.spm") 
8 | COMPRESSED_WEIGHTS_PATH = os.path.join(dir_path, "../model/2b-it-mqa.sbs") 
9 | MODEL_TYPE = ModelType.Gemma2B 
10 | MODEL_TRAINING = ModelTraining.GEMMA_IT 
11 | 
12 | 
13 | def test_gemma(): 
14 |     gemma = Gemma( 
15 |         tokenizer_path=TOKENIZER_PATH, 
16 |         compressed_weights_path=COMPRESSED_WEIGHTS_PATH, 
17 |         model_type=MODEL_TYPE, 
18 |         model_training=MODEL_TRAINING, 
19 |     ) 
20 | 
21 |     assert gemma 
22 |     assert gemma.model 
23 | 
24 |     text = "Hello world!" 
25 | 
26 |     tokens = gemma.tokenize(text) 
27 |     assert tokens[0] == gemma.bos_token 
28 |     assert tokens == [2, 4521, 2134, 235341] 
29 |     detokenized = gemma.detokenize(tokens) 
30 |     assert detokenized == text 
31 | 
32 |     # without BOS 
33 |     tokens_without_bos = gemma.tokenize(text, add_bos=False) 
34 |     assert tokens_without_bos[0] != gemma.bos_token 
35 |     assert tokens_without_bos == [4521, 2134, 235341] 
36 | 
--------------------------------------------------------------------------------