├── .python-version
├── .github
  ├── ISSUE_TEMPLATE
  │ ├── config.yml
  │ ├── 03_others.yml
  │ ├── 01_bug_report.yml
  │ └── 02_feature_request.yml
  ├── pull_request_template.md
  ├── FUNDING.yml
  ├── branch-convention.md
  ├── semantic.yml
  └── commit-convention.md
├── img
  └── angryface.png
├── .gitignore
├── constants.py
├── main.py
├── LICENSE
├── pyproject.toml
├── llama2gptq
  ├── qa.py
  ├── ingest.py
  ├── quantize.py
  └── generate.py
├── chat.py
├── README.md
├── requirements.lock
└── requirements-dev.lock

/.python-version:
--------------------------------------------------------------------------------
1 | 3.10.0
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Why did you do this?
2 |
3 | ## How did you do it?
4 |
--------------------------------------------------------------------------------
/img/angryface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seonglae/llama2gptq/HEAD/img/angryface.png
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github:
2 |   - seonglae
3 | custom:
4 |   - 'https://paypal.me/seonglae'
5 |   - 'https://www.buymeacoffee.com/seongland'
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | knowledge
2 | .env
3 | __pycache__
4 | # mypy
5 | .mypy_cache/
6 | .dmypy.json
7 | dmypy.json
8 | .venv
9 | .chroma
10 | db*
11 |
--------------------------------------------------------------------------------
/.github/branch-convention.md:
--------------------------------------------------------------------------------
1 | ## Git Branch Name Convention
2 |
3 | #### TL;DR:
4 |
5 | Branch names must match the following regex:
6 |
7 | ```re
8 | ^(feature|bug|document|style|refactor|test|deps)\/\#[0-9]{1,5}-[a-z|A-Z|\-|0-9]{1,20}
9 | ```
10 |
--------------------------------------------------------------------------------
/.github/semantic.yml:
--------------------------------------------------------------------------------
1 | titleOnly: true
2 | types:
3 |   - feat
4 |   - fix
5 |   - docs
6 |   - style
7 |   - refactor
8 |   - test
9 |   - ci
10 |   - cd
11 |   - build
12 |   - lint
13 |   - merge
14 |   - typing
15 |   - perf
16 |   - meta
17 |   - deps
18 |   - pr
19 |   - chore
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/03_others.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Other Issue
3 | description: Any other kind of issue
4 | body:
5 |   - type: textarea
6 |     id: summary
7 |     attributes:
8 |       label: Summary
9 |       description: A precise description of the issue
10 |     validations:
11 |       required: true
12 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | from os.path import realpath, join, dirname
2 |
3 | from chromadb.config import Settings
4 | from langchain.document_loaders.base import BaseLoader
5 |
6 |
7 | ROOT_DIRECTORY = dirname(realpath(__file__))
8 |
9 | SOURCE_DIRECTORY = join(ROOT_DIRECTORY, 'knowledge')
10 |
11 | PERSIST_DIRECTORY = join(ROOT_DIRECTORY, 'db')
12 |
13 | CHROMA_SETTINGS = Settings(
14 |     chroma_db_impl='duckdb+parquet',
15 |     persist_directory=PERSIST_DIRECTORY,
16 |     anonymized_telemetry=False
17 | )
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/01_bug_report.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐛 Bug Report
3 | description: Something isn't working as expected
4 | labels:
5 |   - bug
6 | body:
7 |   - type: input
8 |     id: testcase
9 |     attributes:
10 |       label: Reproducible test case
11 |       description:
12 |         If possible, please create a minimal test case that reproduces your
13 |         problem.
14 |     validations:
15 |       required: true
16 |   - type: textarea
17 |     id: summary
18 |     attributes:
19 |       label: Additional information
20 |       description:
21 |         Please share any other relevant information not mentioned above. What
22 |         did you expect to happen? What do you think the problem might be?
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/02_feature_request.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🟢 Feature Request
3 | description: Wouldn’t it be nice if
4 | labels:
5 |   - feature
6 | body:
7 |   - type: textarea
8 |     id: summary
9 |     attributes:
10 |       label: What?
11 |       description: Describe your feature idea
12 |     validations:
13 |       required: true
14 |   - type: textarea
15 |     id: why
16 |     attributes:
17 |       label: Why?
18 |       description: Describe the problem you are facing
19 |     validations:
20 |       required: true
21 |   - type: textarea
22 |     id: alternatives
23 |     attributes:
24 |       label: How?
25 |       description:
26 |         Describe what you tried or your ideas for implementing the feature
27 |     validations:
28 |       required: false
29 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import fire
2 |
3 | from llama2gptq.ingest import ingest
4 | from llama2gptq.qa import chat_cli
5 | from llama2gptq.quantize import quantization
6 | from constants import (SOURCE_DIRECTORY, PERSIST_DIRECTORY)
7 |
8 |
9 | def chat(device: str = "cuda") -> str:
10 |     stats = chat_cli(device)
11 |     return stats
12 |
13 |
14 | def process(src_dir: str = SOURCE_DIRECTORY, dst_dir: str = PERSIST_DIRECTORY, device: str = "cuda") -> str:
15 |     return ingest(src_dir, dst_dir, device)
16 |
17 |
18 | def quantize(model: str = "meta-llama/Llama-2-13b-chat-hf",
19 |              output: str = "llama-2-13b-chat-hf-gptq",
20 |              push: bool = False, owner: str = 'seonglae',
21 |              safetensor: bool = False, inference_only: bool = False) -> str:
22 |     quantization(model, output, push, owner, safetensor, inference_only)
23 |     return 'complete'
24 |
25 |
26 | if __name__ == '__main__':
27 |     fire.Fire()
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Alan Jo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = [
3 |   {name = "seonglae", email = "sungle3737@gmail.com"},
4 | ]
5 | dependencies = [
6 |   "langchain~=0.0.225",
7 |   "chromadb~=0.3.26",
8 |   "transformers~=4.30.2",
9 |   "InstructorEmbedding~=1.0.1",
10 |   "sentence_transformers~=2.2.2",
11 |   "unstructured~=0.7.12",
12 |   "torch>=2.0.1",
13 |   "auto_gptq~=0.2.2",
14 |   "einops~=0.6.1",
15 |   "fire~=0.5.0",
16 |   "streamlit-chat~=0.1.1",
17 |   "protobuf<=3.20.0"
18 | ]
19 | description = "Chat AI that provides responses with reference documents via prompt engineering over a vector database."
20 | license = {text = "MIT"}
21 | name = "llama2gptq"
22 | readme = "README.md"
23 | requires-python = ">= 3.8"
24 | version = "0.1.0"
25 |
26 | [build-system]
27 | build-backend = "hatchling.build"
28 | requires = ["hatchling"]
29 |
30 | [tool.rye]
31 | dev-dependencies = [
32 |   "autopep8~=2.0.2",
33 |   "pip~=23.1.2",
34 |   "mypy~=1.3.0",
35 |   "setuptools~=68.0.0",
36 | ]
37 | managed = true
38 |
39 | [[tool.rye.sources]]
40 | name = "cuda"
41 | url = "https://download.pytorch.org/whl/cu118"
42 | type = "index"
43 |
--------------------------------------------------------------------------------
/llama2gptq/qa.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | from typing import Tuple, List
3 |
4 | import torch
5 | from transformers import Pipeline
6 |
7 | from llama2gptq.ingest import extract_ref
8 | from llama2gptq.generate import load_embeddings, load_db, load_model, TokenStoppingCriteria
9 |
10 |
11 | @torch.no_grad()
12 | def qa(query, device, db, transformer: Pipeline, history: List[List[str]],
13 |        user_token="USER: ",
14 |        bot_token="ASSISTANT: ",
15 |        sys_token="",
16 |        system="",
17 |        extract_ref=extract_ref) -> Tuple:
18 |     start = time()
19 |
20 |     if db is None:
21 |         embeddings = load_embeddings(device)
22 |         db = load_db(device, embeddings)
23 |     if transformer is None:
24 |         transformer = load_model(device)
25 |
26 |     # input similarity
27 |     conversation = [f"{user_token}{q}\n{bot_token}{a}\n" for [q, a] in history]
28 |     prompt = f"{sys_token}{system}" + \
29 |         "".join(conversation) + f'{user_token}{query}\n{bot_token}'
30 |     print(prompt)
31 |
32 |     # Inference
33 |     criteria = TokenStoppingCriteria(
34 |         user_token.strip(), prompt, transformer.tokenizer)
35 |     response = transformer(prompt, stopping_criteria=criteria)[
36 |         0]["generated_text"]
37 |     answer = response.replace(prompt, "").strip()
38 |
39 |     # output similarity
40 |     refs = db.search(
41 |         f'{user_token}{query}\n{bot_token}', search_type="similarity")
42 |
43 |     # Print the result
44 |     print('\nHelpful links\n')
45 |     for ref in refs:
46 |         ref_info = extract_ref(ref)
47 |         print(f"{ref_info['title']}: {ref_info['link']}")
48 |
49 |     print(f"\nTime taken: {time() - start} seconds\n")
50 |     print(prompt + answer + '\n')
51 |
52 |     return (answer, refs)
53 |
54 |
55 | def qa_cli(device, db, llm, history) -> Tuple:
56 |     query = input("\nQuestion: ")
57 |     if query == "exit":
58 |         return ()
59 |     return (query, *qa(query, device, db, llm, history))
60 |
61 |
62 | def chat_cli(device='cuda'):
63 |     embeddings = load_embeddings(device)
64 |     db = load_db(device, embeddings)
65 |     transformer = load_model(device)
66 |
67 |     pingpongs = []
68 |     while True:
69 |         history = [[pingpong[0], pingpong[1]] for pingpong in pingpongs]
70 |         pingpong = qa_cli(device, db, transformer, history)
71 |         if len(pingpong) == 0:
72 |             break
73 |         pingpongs.append(pingpong)
74 |     return pingpongs
75 |
--------------------------------------------------------------------------------
/llama2gptq/ingest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from re import split
3 | from typing import List, Type, Dict
4 | from pathlib import Path
5 |
6 | from langchain.docstore.document import Document
7 | from langchain.embeddings import HuggingFaceInstructEmbeddings
8 | from langchain.text_splitter import RecursiveCharacterTextSplitter
9 | from langchain.document_loaders.base import BaseLoader
10 | from langchain.vectorstores import Chroma
11 | from langchain.document_loaders import (
12 |     CSVLoader,
13 |     PDFMinerLoader,
14 |     TextLoader,
15 |     UnstructuredMarkdownLoader,
16 |     UnstructuredExcelLoader,
17 | )
18 |
19 | from constants import (CHROMA_SETTINGS)
20 |
21 |
22 | DOCUMENT_MAP = {
23 |     ".txt": TextLoader,
24 |     ".pdf": PDFMinerLoader,
25 |     ".csv": CSVLoader,
26 |     ".xls": UnstructuredExcelLoader,
27 |     ".xlsx": UnstructuredExcelLoader,
28 |     ".md": TextLoader,
29 | }
30 |
31 |
32 | def load_documents(folder_path: str) -> List[Document]:
33 |     glob = Path(folder_path).glob
34 |     ps = list(glob("**/*.md"))
35 |     documents = []
36 |     for p in ps:
37 |         file_extension = os.path.splitext(p)[1]
38 |         loader_class = DOCUMENT_MAP.get(file_extension)
39 |         if loader_class:
40 |             loader = loader_class(p, encoding="utf-8")
41 |             document = loader.load()[0]
42 |             document.metadata["source"] = str(p)
43 |             documents.append(document)
44 |         else:
45 |             continue
46 |     return documents
47 |
48 |
49 | def extract_ref(ref: Document) -> Dict[str, str]:
50 |     source = split(r"\\|/", ref.metadata["source"])[-1]
51 |     slug = split(r" |.md", source)[-2]
52 |     title = ' '.join(slug.split('-')[:-1])
53 |     link = f"https://texonom.com/{slug}"
54 |     return {"title": title, "link": link}
55 |
56 |
57 | def ingest(source: str, output: str, device='cuda'):
58 |     print(f"Loading documents from {source}")
59 |     documents = load_documents(source)
60 |     for doc in documents:
61 |         doc.metadata["source"] = extract_ref(doc)['link']
62 |     text_splitter = RecursiveCharacterTextSplitter(
63 |         chunk_size=1000, chunk_overlap=200)
64 |     texts = text_splitter.split_documents(documents)
65 |     print(f"Loaded {len(documents)} documents from {source}")
66 |     print(f"Split into {len(texts)} chunks of text")
67 |
68 |     embeddings = HuggingFaceInstructEmbeddings(
69 |         model_name="intfloat/multilingual-e5-large",
70 |         model_kwargs={"device": device},
71 |     )
72 |     db = Chroma.from_documents(
73 |         texts,
74 |         embeddings,
75 |         persist_directory=output,
76 |         client_settings=CHROMA_SETTINGS,
77 |     )
78 |     db.persist()
79 |
--------------------------------------------------------------------------------
/llama2gptq/quantize.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from transformers import AutoTokenizer, TextGenerationPipeline, GenerationConfig, LlamaTokenizer, LlamaTokenizerFast
3 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
4 |
5 |
6 | def quantization(source_model: str, output: str, push: bool, owner: str,
7 |                  safetensor=False, inference_only=False):
8 |     logging.basicConfig(
9 |         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
10 |     )
11 |
12 |     tokenizer = AutoTokenizer.from_pretrained(
13 |         source_model, use_fast=True, use_auth_token=True)
14 |     examples = [
15 |         tokenizer(
16 |             "Texonom is a knowledge system that can help you with your daily tasks using an AI chatbot."
17 |         )
18 |     ]
19 |
20 |     quantize_config = BaseQuantizeConfig(
21 |         bits=4,  # quantize model to 4-bit
22 |         group_size=128,  # it is recommended to set the value to 128
23 |         desc_act=False,  # disabling act-order can significantly speed up inference, but perplexity may be slightly worse
24 |     )
25 |
26 |     # quantize the model; the examples should be a list of dicts whose keys can only be "input_ids" and "attention_mask"
27 |     if not inference_only:
28 |         # load the un-quantized model; by default the model is loaded into CPU memory
29 |         model = AutoGPTQForCausalLM.from_pretrained(
30 |             source_model, quantize_config, use_safetensors=safetensor)
31 |         model.quantize(examples)
32 |         model.save_quantized(output, use_safetensors=safetensor)
33 |
34 |     # load the quantized model to the first GPU
35 |     quantized = AutoGPTQForCausalLM.from_quantized(
36 |         output,
37 |         device="cuda:0",
38 |         use_safetensors=safetensor
39 |     )
40 |
41 |     # inference with model.generate
42 |     query = "USER: Are you AI? Say yes or no.\n ASSISTANT:"
43 |
44 |     # or you can also use the pipeline
45 |     pipeline = TextGenerationPipeline(model=quantized, tokenizer=tokenizer)
46 |     print(pipeline(query)[0]["generated_text"])
47 |
48 |     # push the quantized model to the Hugging Face Hub.
49 |     # to use use_auth_token=True, log in first via `huggingface-cli login`.
50 |     if push and not inference_only:
51 |         commit_message = f"build: AutoGPTQ for {source_model}" + \
52 |             f": {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
53 |         generation_config = GenerationConfig.from_pretrained(source_model)
54 |         generation_config.push_to_hub(
55 |             output, use_auth_token=True, commit_message=commit_message)
56 |         tokenizer.push_to_hub(output, use_auth_token=True,
57 |                               commit_message=commit_message)
58 |         repo_id = f"{owner}/{output}"
59 |         quantized.push_to_hub(repo_id, use_safetensors=safetensor,
60 |                               commit_message=commit_message, use_auth_token=True)
61 |
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | from re import split
2 |
3 | import torch
4 | import streamlit as st
5 | from streamlit_chat import message
6 |
7 | from llama2gptq.qa import qa, load_model, load_db
8 | from llama2gptq.ingest import extract_ref
9 |
10 | DEVICE = 'cuda'
11 | TITLE = 'LLaMa2 GPTQ'
12 | HUG = 'https://em-content.zobj.net/source/microsoft-teams/363/hugging-face_1f917.png'
13 | ANGRY = 'https://em-content.zobj.net/source/microsoft-teams/363/pouting-face_1f621.png'
14 |
15 |
16 | st.set_page_config(page_title=TITLE)
17 | st.header(TITLE)
18 | st.markdown('''
19 | ### Ask anything about [Texonom](https://texonom.com).
20 | Ask questions about recently learned knowledge
21 | ''', unsafe_allow_html=True)
22 |
23 |
24 | @st.cache_resource
25 | def load_transformer():
26 |     return (load_model(DEVICE), load_db(DEVICE))
27 |
28 |
29 | transformer, db = load_transformer()
30 |
31 | styl = """
32 |
48 | """
49 |
50 | BTN_STYLE = """
51 | color: #aaa;
52 | padding-right: 0.5rem;
53 | """
54 |
55 |
56 | st.markdown(styl, unsafe_allow_html=True)
57 |
58 | if 'generated' not in st.session_state:
59 |     st.session_state['generated'] = []
60 |
61 | if 'past' not in st.session_state:
62 |     st.session_state['past'] = []
63 |
64 | if 'answers' not in st.session_state:
65 |     st.session_state['answers'] = []
66 |
67 |
68 | def query(query):
69 |     st.session_state.past.append(query)
70 |     history = []
71 |     for i, _ in enumerate(st.session_state['generated']):
72 |         history.append([st.session_state['past'][i],
73 |                         st.session_state["generated"][i]])
74 |
75 |     answer, refs = qa(query, DEVICE, db, transformer, history)
76 |
77 |     # Append references
78 |     st.session_state.generated.append(answer)
79 |
80 |     # Generate HTML
81 |     answer += '<br/><br/>References: '
82 |     for ref in refs:
83 |         ref_info = extract_ref(ref)
84 |         btn = f"<a style='{BTN_STYLE}' href='{ref_info['link']}'>{ref_info['title']}</a>"
85 |         answer += btn
86 |
87 |     st.session_state.answers.append(answer)
88 |     return answer
89 |
90 |
91 | def get_text():
92 |     input_text = st.text_input("You: ", key="input")
93 |     return input_text
94 |
95 |
96 | user_input = get_text()
97 |
98 |
99 | if user_input:
100 |     query(user_input)
101 |
102 |
103 | if st.session_state['generated']:
104 |     for i, _ in enumerate(st.session_state['generated']):
105 |         message(st.session_state['past'][i], is_user=True,
106 |                 key=str(i) + '_user', logo=HUG)
107 |         message(st.session_state["answers"][i],
108 |                 key=str(i), logo=ANGRY, allow_html=True)
109 |
--------------------------------------------------------------------------------
/.github/commit-convention.md:
--------------------------------------------------------------------------------
1 | ## Git Commit Message Convention
2 |
3 | > This is adapted from [Vite's commit convention](https://github.com/vitejs/vite/blob/main/.github/commit-convention.md).
4 |
5 | #### TL;DR:
6 |
7 | Messages must be matched by the following regex:
8 |
9 |
10 | ```re
11 | ^(revert: )?(feat|fix|docs|style|refactor|test|ci|cd|build|meta|pr|lint|typing|perf|deps|merge)(\(.+\))?: .{1,50}
12 | ```
13 |
14 | #### Examples
15 |
16 | Appears under "Features" header, `dev` subheader:
17 |
18 | ```
19 | feat(dev): add 'comments' option
20 | ```
21 |
22 | Appears under "Bug Fixes" header, `dev` subheader, with a link to issue #28:
23 |
24 | ```
25 | fix(dev): fix dev error
26 |
27 | close #28
28 | ```
29 |
30 | Appears under "Performance Improvements" header, and under "Breaking Changes" with the breaking change explanation:
31 |
32 | ```
33 | perf(build): remove 'foo' option
34 |
35 | BREAKING CHANGE: The 'foo' option has been removed.
36 | ```
37 |
38 | The following commit and commit `667ecc1` do not appear in the changelog if they are under the same release. If not, the revert commit appears under the "Reverts" header.
39 |
40 | ```
41 | revert: feat(compiler): add 'comments' option
42 |
43 | This reverts commit 667ecc1654a317a13331b17617d973392f415f02.
44 | ```
45 |
46 | ### Full Message Format
47 |
48 | A commit message consists of a **header**, **body** and **footer**. The header has a **type**, **scope** and **subject**:
49 |
50 | ```
51 | <type>(<scope>): <subject>
52 | <BLANK LINE>
53 | <body>
54 | <BLANK LINE>
55 | <footer>