├── .python-version
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── 03_others.yml
│   │   ├── 01_bug_report.yml
│   │   └── 02_feature_request.yml
│   ├── pull_request_template.md
│   ├── FUNDING.yml
│   ├── branch-convention.md
│   ├── semantic.yml
│   └── commit-convention.md
├── img
│   └── angryface.png
├── .gitignore
├── constants.py
├── main.py
├── LICENSE
├── pyproject.toml
├── llama2gptq
│   ├── qa.py
│   ├── ingest.py
│   ├── quantize.py
│   └── generate.py
├── chat.py
├── README.md
├── requirements.lock
└── requirements-dev.lock
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10.0
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Why did you do this
2 |
3 | ## How did you do that
4 |
--------------------------------------------------------------------------------
/img/angryface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seonglae/llama2gptq/HEAD/img/angryface.png
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github:
2 |   - seonglae
3 | custom:
4 |   - 'https://paypal.me/seonglae'
5 |   - 'https://www.buymeacoffee.com/seongland'
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | knowledge
2 | .env
3 | __pycache__
4 | # mypy
5 | .mypy_cache/
6 | .dmypy.json
7 | dmypy.json
8 | .venv
9 | .chroma
10 | db*
11 |
--------------------------------------------------------------------------------
/.github/branch-convention.md:
--------------------------------------------------------------------------------
1 | ## Git Branch Name Convention
2 |
3 | #### TL;DR:
4 |
5 | Branch name must be matched by the following regex:
6 |
7 | ```re
8 | ^(feature|bug|document|style|refactor|test|deps)\/\#[0-9]{1,5}-[a-z|A-Z|\-|0-9]{1,20}
9 | ```
10 |
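11 | For example, the following branch names match the pattern (issue numbers are hypothetical):
12 |
13 | ```
14 | feature/#12-add-chat-ui
15 | bug/#345-fix-ingest-encoding
16 | deps/#7-bump-langchain
17 | ```
18 |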
--------------------------------------------------------------------------------
/.github/semantic.yml:
--------------------------------------------------------------------------------
1 | titleOnly: true
2 | types:
3 | - feat
4 | - fix
5 | - docs
6 | - style
7 | - refactor
8 | - test
9 | - ci
10 | - cd
11 | - build
12 | - lint
13 | - merge
14 | - typing
15 | - perf
16 | - meta
17 | - deps
18 | - pr
19 | - chore
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/03_others.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Other Issue
3 | description: Any other kind of issue
4 | body:
5 |   - type: textarea
6 |     id: summary
7 |     attributes:
8 |       label: Summary
9 |       description: A precise description of the issue
10 |     validations:
11 |       required: true
12 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | from os.path import realpath, join, dirname
2 |
3 | from chromadb.config import Settings
4 | from langchain.document_loaders.base import BaseLoader
5 |
6 |
7 | ROOT_DIRECTORY = dirname(realpath(__file__))
8 |
9 | SOURCE_DIRECTORY = join(ROOT_DIRECTORY, 'knowledge')
10 |
11 | PERSIST_DIRECTORY = join(ROOT_DIRECTORY, 'db')
12 |
13 | CHROMA_SETTINGS = Settings(
14 |     chroma_db_impl='duckdb+parquet',
15 |     persist_directory=PERSIST_DIRECTORY,
16 |     anonymized_telemetry=False
17 | )
18 |
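19 | # Note (assumption): ingest.py writes the DuckDB+Parquet store under ./db with these
20 | # settings, and the QA side is expected to read the same persisted location.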
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/01_bug_report.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐛 Bug Report
3 | description: Something isn't working as expected
4 | labels:
5 |   - bug
6 | body:
7 |   - type: input
8 |     id: testcase
9 |     attributes:
10 |       label: Reproducible test case
11 |       description:
12 |         If possible, please create a minimal test case that reproduces your
13 |         problem.
14 |     validations:
15 |       required: true
16 |   - type: textarea
17 |     id: summary
18 |     attributes:
19 |       label: Additional information
20 |       description:
21 |         Please share any other relevant information not mentioned above. What
22 |         did you expect to happen? What do you think the problem might be?
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/02_feature_request.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🟢 Feature Request
3 | description: Wouldn’t it be nice if
4 | labels:
5 |   - feature
6 | body:
7 |   - type: textarea
8 |     id: summary
9 |     attributes:
10 |       label: What?
11 |       description: Describe your feature idea
12 |     validations:
13 |       required: true
14 |   - type: textarea
15 |     id: why
16 |     attributes:
17 |       label: Why?
18 |       description: Describe the problem you are facing
19 |     validations:
20 |       required: true
21 |   - type: textarea
22 |     id: alternatives
23 |     attributes:
24 |       label: How?
25 |       description:
26 |         Describe what you tried or your ideas for implementing the feature
27 |     validations:
28 |       required: false
29 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import fire
2 |
3 | from llama2gptq.ingest import ingest
4 | from llama2gptq.qa import chat_cli
5 | from llama2gptq.quantize import quantization
6 | from constants import (SOURCE_DIRECTORY, PERSIST_DIRECTORY)
7 |
8 |
9 | def chat(device: str = "cuda") -> str:
10 |     stats = chat_cli(device)
11 |     return stats
12 |
13 |
14 | def process(src_dir: str = SOURCE_DIRECTORY, dst_dir: str = PERSIST_DIRECTORY, device: str = "cuda") -> str:
15 |     return ingest(src_dir, dst_dir, device)
16 |
17 |
18 | def quantize(model: str = "meta-llama/Llama-2-13b-chat-hf",
19 |              output: str = "llama-2-13b-chat-hf-gptq",
20 |              push: bool = False, owner: str = 'seonglae',
21 |              safetensor: bool = False, inference_only: bool = False) -> str:
22 |     quantization(model, output, push, owner, safetensor, inference_only)
23 |     return 'complete'
24 |
25 |
26 | if __name__ == '__main__':
27 |     fire.Fire()
28 |
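29 | # CLI usage sketch via fire.Fire() (flags mirror the defaults above; paths are illustrative):
30 | #   python main.py process --src_dir knowledge --dst_dir db
31 | #   python main.py chat --device cuda
32 | #   python main.py quantize --model meta-llama/Llama-2-13b-chat-hf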
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Alan Jo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | authors = [
3 | {name = "seonglae", email = "sungle3737@gmail.com"},
4 | ]
5 | dependencies = [
6 | "langchain~=0.0.225",
7 | "chromadb~=0.3.26",
8 | "transformers~=4.30.2",
9 | "InstructorEmbedding~=1.0.1",
10 | "sentence_transformers~=2.2.2",
11 | "unstructured~=0.7.12",
12 | "torch>=2.0.1",
13 | "auto_gptq~=0.2.2",
14 | "einops~=0.6.1",
15 | "fire~=0.5.0",
16 | "streamlit-chat~=0.1.1",
17 | "protobuf<=3.20.0"
18 | ]
19 | description = "Chat AI that provides responses with reference documents via prompt engineering over a vector database."
20 | license = {text = "MIT"}
21 | name = "llama2gptq"
22 | readme = "README.md"
23 | requires-python = ">= 3.8"
24 | version = "0.1.0"
25 |
26 | [build-system]
27 | build-backend = "hatchling.build"
28 | requires = ["hatchling"]
29 |
30 | [tool.rye]
31 | dev-dependencies = [
32 | "autopep8~=2.0.2",
33 | "pip~=23.1.2",
34 | "mypy~=1.3.0",
35 | "setuptools~=68.0.0",
36 | ]
37 | managed = true
38 |
39 | [[tool.rye.sources]]
40 | name = "cuda"
41 | url = "https://download.pytorch.org/whl/cu118"
42 | type = "index"
43 |
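44 | # Environment sketch (assuming Rye is installed): `rye sync` resolves these dependencies
45 | # into requirements.lock / requirements-dev.lock and the project .venv.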
--------------------------------------------------------------------------------
/llama2gptq/qa.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | from typing import Tuple, List
3 |
4 | import torch
5 | from transformers import Pipeline
6 |
7 | from llama2gptq.ingest import extract_ref
8 | from llama2gptq.generate import load_embeddings, load_db, load_model, TokenStoppingCriteria
9 |
10 |
11 | @torch.no_grad()
12 | def qa(query, device, db, transformer: Pipeline, history: List[List[str]],
13 |        user_token="USER: ",
14 |        bot_token="ASSISTANT: ",
15 |        sys_token="",
16 |        system="",
17 |        extract_ref=extract_ref) -> Tuple:
18 |     start = time()
19 |
20 |     if db is None:
21 |         embeddings = load_embeddings(device)
22 |         db = load_db(device, embeddings)
23 |     if transformer is None:
24 |         transformer = load_model(device)
25 |
26 |     # Input side: build the prompt from chat history and the current query
27 |     conversation = [f"{user_token}{q}\n{bot_token}{a}\n" for [q, a] in history]
28 |     prompt = f"{sys_token}{system}" + \
29 |         "".join(conversation) + f'{user_token}{query}\n{bot_token}'
30 |     print(prompt)
31 |
32 |     # Inference
33 |     criteria = TokenStoppingCriteria(
34 |         user_token.strip(), prompt, transformer.tokenizer)
35 |     response = transformer(prompt, stopping_criteria=criteria)[
36 |         0]["generated_text"]
37 |     answer = response.replace(prompt, "").strip()
38 |
39 |     # Output side: retrieve reference documents similar to the query
40 |     refs = db.search(
41 |         f'{user_token}{query}\n{bot_token}', search_type="similarity")
42 |
43 |     # Print the result
44 |     print('\nHelpful links\n')
45 |     for ref in refs:
46 |         ref_info = extract_ref(ref)
47 |         print(f"{ref_info['title']}: {ref_info['link']}")
48 |
49 |     print(f"\nTime taken: {time() - start} seconds\n")
50 |     print(prompt + answer + '\n')
51 |
52 |     return (answer, refs)
53 |
54 |
55 | def qa_cli(device, db, llm, history) -> Tuple:
56 |     query = input("\nQuestion: ")
57 |     if query == "exit":
58 |         return ()
59 |     return (query, *qa(query, device, db, llm, history))
60 |
61 |
62 | def chat_cli(device='cuda'):
63 |     embeddings = load_embeddings(device)
64 |     db = load_db(device, embeddings)
65 |     transformer = load_model(device)
66 |
67 |     pingpongs = []
68 |     while True:
69 |         history = [[pingpong[0], pingpong[1]] for pingpong in pingpongs]
70 |         pingpong = qa_cli(device, db, transformer, history)
71 |         if len(pingpong) == 0:
72 |             break
73 |         pingpongs.append(pingpong)
74 |     return pingpongs
75 |
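76 | # Illustrative prompt assembled by qa() for a one-turn history (the content is hypothetical):
77 | #   USER: What is Texonom?
78 | #   ASSISTANT: Texonom is a knowledge system.
79 | #   USER: How are documents ingested?
80 | #   ASSISTANT: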
--------------------------------------------------------------------------------
/llama2gptq/ingest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from re import split
3 | from typing import List, Type, Dict
4 | from pathlib import Path
5 |
6 | from langchain.docstore.document import Document
7 | from langchain.embeddings import HuggingFaceInstructEmbeddings
8 | from langchain.text_splitter import RecursiveCharacterTextSplitter
9 | from langchain.document_loaders.base import BaseLoader
10 | from langchain.vectorstores import Chroma
11 | from langchain.document_loaders import (
12 |     CSVLoader,
13 |     PDFMinerLoader,
14 |     TextLoader,
15 |     UnstructuredMarkdownLoader,
16 |     UnstructuredExcelLoader,
17 | )
18 |
19 | from constants import (CHROMA_SETTINGS)
20 |
21 |
22 | DOCUMENT_MAP = {
23 |     ".txt": TextLoader,
24 |     ".pdf": PDFMinerLoader,
25 |     ".csv": CSVLoader,
26 |     ".xls": UnstructuredExcelLoader,
27 |     ".xlsx": UnstructuredExcelLoader,
28 |     ".md": TextLoader,
29 | }
30 |
31 |
32 | def load_documents(folder_path: str) -> List[Document]:
33 |     glob = Path(folder_path).glob
34 |     ps = list(glob("**/*.md"))
35 |     documents = []
36 |     for p in ps:
37 |         file_extension = os.path.splitext(p)[1]
38 |         loader_class = DOCUMENT_MAP.get(file_extension)
39 |         if loader_class:
40 |             loader = loader_class(p, encoding="utf-8")
41 |             document = loader.load()[0]
42 |             document.metadata["source"] = str(p)
43 |             documents.append(document)
44 |         else:
45 |             continue
46 |     return documents
47 |
48 |
49 | def extract_ref(ref: Document) -> Dict[str, str]:
50 |     source = split(r"\\|/", ref.metadata["source"])[-1]
51 |     slug = split(r" |\.md", source)[-2]
52 |     title = ' '.join(slug.split('-')[:-1])
53 |     link = f"https://texonom.com/{slug}"
54 |     return {"title": title, "link": link}
55 |
56 |
57 | def ingest(source: str, output: str, device='cuda'):
58 |     print(f"Loading documents from {source}")
59 |     documents = load_documents(source)
60 |     for doc in documents:
61 |         doc.metadata["source"] = extract_ref(doc)['link']
62 |     text_splitter = RecursiveCharacterTextSplitter(
63 |         chunk_size=1000, chunk_overlap=200)
64 |     texts = text_splitter.split_documents(documents)
65 |     print(f"Loaded {len(documents)} documents from {source}")
66 |     print(f"Split into {len(texts)} chunks of text")
67 |
68 |     embeddings = HuggingFaceInstructEmbeddings(
69 |         model_name="intfloat/multilingual-e5-large",
70 |         model_kwargs={"device": device},
71 |     )
72 |     db = Chroma.from_documents(
73 |         texts,
74 |         embeddings,
75 |         persist_directory=output,
76 |         client_settings=CHROMA_SETTINGS,
77 |     )
78 |     db.persist()
79 |
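80 | # Usage sketch (the defaults in constants.py point at ./knowledge and ./db):
81 | #   from llama2gptq.ingest import ingest
82 | #   ingest('knowledge', 'db', device='cuda')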
--------------------------------------------------------------------------------
/llama2gptq/quantize.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from transformers import AutoTokenizer, TextGenerationPipeline, GenerationConfig, LlamaTokenizer, LlamaTokenizerFast
3 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
4 |
5 |
6 | def quantization(source_model: str, output: str, push: bool, owner: str,
7 |                  safetensor=False, inference_only=False):
8 |     logging.basicConfig(
9 |         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
10 |     )
11 |
12 |     tokenizer = AutoTokenizer.from_pretrained(
13 |         source_model, use_fast=True, use_auth_token=True)
14 |     examples = [
15 |         tokenizer(
16 |             "Texonom is a knowledge system that can help you with your daily tasks using an AI chatbot."
17 |         )
18 |     ]
19 |
20 |     quantize_config = BaseQuantizeConfig(
21 |         bits=4,  # quantize model to 4-bit
22 |         group_size=128,  # it is recommended to set the value to 128
23 |         desc_act=False,  # disabling act-order significantly speeds up inference at a slight cost in perplexity
24 |     )
25 |
26 |     # quantize the model; the examples should be a list of dicts whose keys can only be "input_ids" and "attention_mask"
27 |     if not inference_only:
28 |         # load the un-quantized model; by default the model is loaded into CPU memory
29 |         model = AutoGPTQForCausalLM.from_pretrained(
30 |             source_model, quantize_config, use_safetensors=safetensor)
31 |         model.quantize(examples)
32 |         model.save_quantized(output, use_safetensors=safetensor)
33 |
34 |     # load the quantized model onto the first GPU
35 |     quantized = AutoGPTQForCausalLM.from_quantized(
36 |         output,
37 |         device="cuda:0",
38 |         use_safetensors=safetensor
39 |     )
40 |
41 |     # inference with model.generate
42 |     query = "USER: Are you AI? Say yes or no.\n ASSISTANT:"
43 |
44 |     # or you can also use a pipeline
45 |     pipeline = TextGenerationPipeline(model=quantized, tokenizer=tokenizer)
46 |     print(pipeline(query)[0]["generated_text"])
47 |
48 |     # push the quantized model to the Hugging Face Hub.
49 |     # to use use_auth_token=True, log in first via `huggingface-cli login`.
50 |     if push and not inference_only:
51 |         commit_message = f"build: AutoGPTQ for {source_model}" + \
52 |             f": {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
53 |         generation_config = GenerationConfig.from_pretrained(source_model)
54 |         generation_config.push_to_hub(
55 |             output, use_auth_token=True, commit_message=commit_message)
56 |         tokenizer.push_to_hub(output, use_auth_token=True,
57 |                               commit_message=commit_message)
58 |         repo_id = f"{owner}/{output}"
59 |         quantized.push_to_hub(repo_id, use_safetensors=safetensor,
60 |                               commit_message=commit_message, use_auth_token=True)
61 |
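62 | # Usage sketch (model and output names mirror the defaults in main.py):
63 | #   quantization("meta-llama/Llama-2-13b-chat-hf", "llama-2-13b-chat-hf-gptq",
64 | #                push=False, owner="seonglae", safetensor=False)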
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | from re import split
2 |
3 | import torch
4 | import streamlit as st
5 | from streamlit_chat import message
6 |
7 | from llama2gptq.qa import qa, load_model, load_db
8 | from llama2gptq.ingest import extract_ref
9 |
10 | DEVICE = 'cuda'
11 | TITLE = 'LLaMa2 GPTQ'
12 | HUG = 'https://em-content.zobj.net/source/microsoft-teams/363/hugging-face_1f917.png'
13 | ANGRY = 'https://em-content.zobj.net/source/microsoft-teams/363/pouting-face_1f621.png'
14 |
15 |
16 | st.set_page_config(page_title=TITLE)
17 | st.header(TITLE)
18 | st.markdown('''
19 | ### Ask anything to [Texonom](https://texonom.com).
20 | Questions about recently learned knowledge
21 | ''', unsafe_allow_html=True)
22 |
23 |
24 | @st.cache_resource
25 | def load_transformer():
26 |     return (load_model(DEVICE), load_db(DEVICE))
27 |
28 |
29 | transformer, db = load_transformer()
30 |
31 | styl = """
32 |
48 | """
49 |
50 | BTN_STYLE = """
51 | color: #aaa;
52 | padding-right: 0.5rem;
53 | """
54 |
55 |
56 | st.markdown(styl, unsafe_allow_html=True)
57 |
58 | if 'generated' not in st.session_state:
59 | st.session_state['generated'] = []
60 |
61 | if 'past' not in st.session_state:
62 | st.session_state['past'] = []
63 |
64 | if 'answers' not in st.session_state:
65 | st.session_state['answers'] = []
66 |
67 |
68 | def query(query):
69 |     st.session_state.past.append(query)
70 |     history = []
71 |     for i, _ in enumerate(st.session_state['generated']):
72 |         history.append([st.session_state['past'][i],
73 |                         st.session_state["generated"][i]])
74 |
75 |     answer, refs = qa(query, DEVICE, db, transformer, history)
76 |
77 |     # Store the raw answer
78 |     st.session_state.generated.append(answer)
79 |
80 |     # Generate HTML with reference links
81 |     answer += '<br>References: '
82 |     for ref in refs:
83 |         ref_info = extract_ref(ref)
84 |         btn = f"<a style='{BTN_STYLE}' href='{ref_info['link']}'>{ref_info['title']}</a>"
85 |         answer += btn
86 |
87 |     st.session_state.answers.append(answer)
88 |     return answer
89 |
90 |
91 | def get_text():
92 |     input_text = st.text_input("You: ", key="input")
93 |     return input_text
94 |
95 |
96 | user_input = get_text()
97 |
98 |
99 | if user_input:
100 |     query(user_input)
101 |
102 |
103 | if st.session_state['generated']:
104 |     for i, _ in enumerate(st.session_state['generated']):
105 |         message(st.session_state['past'][i], is_user=True,
106 |                 key=str(i) + '_user', logo=HUG)
107 |         message(st.session_state["answers"][i],
108 |                 key=str(i), logo=ANGRY, allow_html=True)
109 |
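110 | # Launch sketch: this Streamlit app is typically started with
111 | #   streamlit run chat.py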
--------------------------------------------------------------------------------
/.github/commit-convention.md:
--------------------------------------------------------------------------------
1 | ## Git Commit Message Convention
2 |
3 | > This is adapted from [Vite's commit convention](https://github.com/vitejs/vite/blob/main/.github/commit-convention.md).
4 |
5 | #### TL;DR:
6 |
7 | Messages must be matched by the following regex:
8 |
9 |
10 | ```re
11 | ^(revert: )?(feat|fix|docs|style|refactor|test|ci|cd|build|meta|pr|lint|typing|perf|deps|merge)(\(.+\))?: .{1,50}
12 | ```
13 |
14 | #### Examples
15 |
16 | Appears under "Features" header, `dev` subheader:
17 |
18 | ```
19 | feat(dev): add 'comments' option
20 | ```
21 |
22 | Appears under "Bug Fixes" header, `dev` subheader, with a link to issue #28:
23 |
24 | ```
25 | fix(dev): fix dev error
26 |
27 | close #28
28 | ```
29 |
30 | Appears under "Performance Improvements" header, and under "Breaking Changes" with the breaking change explanation:
31 |
32 | ```
33 | perf(build): remove 'foo' option
34 |
35 | BREAKING CHANGE: The 'foo' option has been removed.
36 | ```
37 |
38 | The following commit and commit `667ecc1` do not appear in the changelog if they are under the same release. If not, the revert commit appears under the "Reverts" header.
39 |
40 | ```
41 | revert: feat(compiler): add 'comments' option
42 |
43 | This reverts commit 667ecc1654a317a13331b17617d973392f415f02.
44 | ```
45 |
46 | ### Full Message Format
47 |
48 | A commit message consists of a **header**, **body** and **footer**. The header has a **type**, **scope** and **subject**:
49 |
50 | ```
51 | <type>(<scope>): <subject>
52 | <BLANK LINE>
53 | <body>
54 | <BLANK LINE>
55 | <footer>