├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── poetry.lock
├── pyproject.toml
├── requirements.txt
└── talk_codebase
    ├── __init__.py
    ├── cli.py
    ├── config.py
    ├── consts.py
    ├── llm.py
    └── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: []
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | workflow_dispatch:
13 |
14 | permissions:
15 | contents: read
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v3
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.env
2 | /.idea/
3 | /.vscode/
4 | /.venv/
5 | /talk_codebase/__pycache__/
6 | .DS_Store
7 | /vector_store/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Saryev Rustam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # talk-codebase
2 |
3 | [![Upload Python Package](https://github.com/rsaryev/talk-codebase/actions/workflows/python-publish.yml/badge.svg)](https://github.com/rsaryev/talk-codebase/actions/workflows/python-publish.yml)
4 |
5 | Talk-codebase is a tool that allows you to converse with your codebase using Large Language Models (LLMs) to answer your
6 | queries. It supports offline code processing using LlamaCpp and [GPT4All](https://github.com/nomic-ai/gpt4all) without
7 | sharing your code with third parties, or you can use OpenAI if privacy is not a concern for you. Please note that
8 | talk-codebase is still under development and is recommended for educational purposes, not for production use.
9 |
10 |
11 |
12 |
13 |
14 | ## Installation
15 |
16 | Requires Python 3.8.1 or higher.
17 | Your project must be in a git repository.
18 |
19 | ```bash
20 | pip install talk-codebase
21 | ```
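
The model backends are declared as optional extras in `pyproject.toml` (`local` and `all`). Assuming those extras are published with the package under the same names, they can be installed explicitly:

```bash
# local backends only (gpt4all, llama-cpp-python)
pip install "talk-codebase[local]"

# all backends (gpt4all, llama-cpp-python, openai, tiktoken)
pip install "talk-codebase[all]"
```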
22 |
23 | After installation, you can use it to chat with your codebase in the current directory by running the following command:
24 |
25 | ```bash
26 | talk-codebase chat
27 | ```
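
Type `exit` or `quit` at the prompt to end the chat session.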
28 |
29 | Select model type: Local or OpenAI
30 |
31 |
32 |
33 | ### OpenAI
34 |
35 | If you use the OpenAI model, you need an OpenAI API key. You can get it from [here](https://beta.openai.com/). Then you
36 | will be offered a choice of available models.
37 |
38 |
39 |
40 |
41 | ### Local
42 |
43 |
44 |
45 | If you want some files to be ignored, add them to `.gitignore`.
46 |
47 | ## Reset configuration
48 |
49 | To reset the configuration, run the following command:
50 |
51 | ```bash
52 | talk-codebase configure
53 | ```
54 |
55 | ## Advanced configuration
56 |
57 | You can adjust the configuration manually by editing the `~/.talk_codebase_config.yaml` file. If you cannot find the configuration file,
58 | run the tool and it will output the path to the configuration file at the very beginning.
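
For orientation, below is a sketch of the keys the tool reads, based on `DEFAULT_CONFIG` in `talk_codebase/consts.py` and the fields written by `talk_codebase/config.py`; the API key, model names, and home path are illustrative placeholders rather than values taken from this repository:

```yaml
model_type: openai                 # "openai" or "local"
api_key: sk-...                    # OpenAI only; placeholder
openai_model_name: gpt-3.5-turbo   # OpenAI only; placeholder
local_model_name: some-model.bin   # local only; placeholder
model_path: /home/user/.cache/gpt4all   # defaults to ~/.cache/gpt4all
max_tokens: "2056"
chunk_size: "2056"
chunk_overlap: "256"
k: "2"
temperature: "0.7"
n_batch: "8"
```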
59 |
60 | ## Supported Extensions
61 |
62 | - [x] `.csv`
63 | - [x] `.doc`
64 | - [x] `.docx`
65 | - [x] `.epub`
66 | - [x] `.md`
67 | - [x] `.pdf`
68 | - [x] `.txt`
69 | - [x] popular programming languages (the source-file extensions listed in `talk_codebase/consts.py`)
70 |
71 | ## Contributing
72 |
73 | * If you find a bug in talk-codebase, please report it on the project's issue tracker. When reporting a bug, please
74 | include as much information as possible, such as the steps to reproduce the bug, the expected behavior, and the actual
75 | behavior.
76 | * If you have an idea for a new feature for talk-codebase, please open an issue on the project's issue tracker. When
77 | suggesting a feature, please include a brief description of the feature, as well as any rationale for why the feature
78 | would be useful.
79 | * You can contribute to talk-codebase by writing code. The project is always looking for help with improving the
80 | codebase, adding new features, and fixing bugs.
81 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "talk-codebase"
3 | version = "0.1.50"
4 | description = "talk-codebase is a powerful tool for querying and analyzing codebases."
5 | authors = ["Saryev Rustam "]
6 | readme = "README.md"
7 | packages = [{ include = "talk_codebase" }]
8 | keywords = ["chatgpt", "openai", "cli"]
9 |
10 | [tool.poetry.dependencies]
11 | python = ">=3.8.1,<4.0"
12 | fire = "^0.5.0"
13 | faiss-cpu = "^1.7.4"
14 | halo = "^0.0.31"
15 | urllib3 = "1.26.18"
16 | gitpython = "^3.1.31"
17 | questionary = "^1.10.0"
18 | sentence-transformers = "^2.2.2"
19 | unstructured = "^0.6.10"
20 | langchain = ">=0.0.223,<0.1.12"
21 | llama-cpp-python = { version = "^0.1.68", optional = true }
22 | gpt4all = { version = "^1.0.1", optional = true }
23 | openai = { version = "^0.27.7", optional = true }
24 | tiktoken = { version = "^0.4.0", optional = true }
25 |
26 | [tool.poetry.extras]
27 | local = ["gpt4all", "llama-cpp-python"]
28 | all = ["gpt4all", "llama-cpp-python", "openai", "tiktoken"]
29 |
30 |
31 | [tool.poetry.group.dev.dependencies]
32 | setuptools = ">=68,<71"
33 |
34 | [build-system]
35 | requires = ["poetry-core"]
36 | build-backend = "poetry.core.masonry.api"
37 |
38 | [tool.poetry.urls]
39 | "Source" = "https://github.com/rsaryev/talk-codebase"
40 | "Bug Tracker" = "https://github.com/rsaryev/talk-codebase/issues"
41 |
42 | [tool.poetry.scripts]
43 | talk-codebase = "talk_codebase.cli:main"
44 |
45 |
46 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.2
2 | aiosignal==1.3.1
3 | async-timeout==4.0.2
4 | attrs==23.1.0
5 | certifi==2023.7.22
6 | charset-normalizer==3.1.0
7 | click==8.1.3
8 | colorama==0.4.6
9 | colored==1.4.4
10 | dataclasses-json==0.5.7
11 | faiss-cpu==1.7.4
12 | filelock==3.12.0
13 | fire==0.5.0
14 | frozenlist==1.3.3
15 | fsspec==2023.5.0
16 | gitdb==4.0.10
17 | GitPython==3.1.41
18 | gpt4all==0.2.3
19 | halo==0.0.31
20 | huggingface-hub==0.14.1
21 | idna==3.7
22 | Jinja2==3.1.3
23 | joblib==1.2.0
24 | langchain==0.0.325
25 | log-symbols==0.0.14
26 | MarkupSafe==2.1.2
27 | marshmallow==3.19.0
28 | marshmallow-enum==1.5.1
29 | mpmath==1.3.0
30 | multidict==6.0.4
31 | mypy-extensions==1.0.0
32 | networkx==3.1
33 | nltk==3.8.1
34 | numexpr==2.8.4
35 | numpy==1.24.3
36 | openai==0.27.7
37 | openapi-schema-pydantic==1.2.4
38 | packaging==23.1
39 | Pillow==10.3.0
40 | prompt-toolkit==3.0.38
41 | pydantic==1.10.8
42 | PyYAML==6.0
43 | questionary==1.10.0
44 | regex==2023.5.5
45 | requests==2.31.0
46 | scikit-learn==1.2.2
47 | scipy==1.10.0
48 | sentence-transformers==2.2.2
49 | sentencepiece==0.1.99
50 | six==1.16.0
51 | smmap==5.0.0
52 | spinners==0.0.24
53 | SQLAlchemy==2.0.15
54 | sympy==1.12
55 | tenacity==8.2.2
56 | termcolor==2.3.0
57 | threadpoolctl==3.1.0
58 | tiktoken==0.4.0
59 | tokenizers==0.13.3
60 | torch==2.0.1
61 | torchvision==0.15.2
62 | tqdm==4.65.0
63 | transformers==4.38.0
64 | typing-inspect==0.9.0
65 | typing_extensions==4.6.2
66 | urllib3==1.26.18
67 | wcwidth==0.2.6
68 | yarl==1.9.2
69 |
--------------------------------------------------------------------------------
/talk_codebase/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsaryev/talk-codebase/86f1771abb19eb00bd88a4ea76abbe3156e98125/talk_codebase/__init__.py
--------------------------------------------------------------------------------
/talk_codebase/cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import fire
4 |
5 | from talk_codebase.config import CONFIGURE_STEPS, save_config, get_config, config_path, remove_api_key, \
6 | remove_model_type, remove_model_name_local
7 | from talk_codebase.consts import DEFAULT_CONFIG
8 | from talk_codebase.llm import factory_llm
9 | from talk_codebase.utils import get_repo
10 |
11 |
12 | def check_python_version():
13 | if sys.version_info < (3, 8, 1):
14 | print("🤖 Please use Python 3.8.1 or higher")
15 | sys.exit(1)
16 |
17 |
18 | def update_config(config):
19 | for key, value in DEFAULT_CONFIG.items():
20 | if key not in config:
21 | config[key] = value
22 | return config
23 |
24 |
25 | def configure(reset=True):
26 | if reset:
27 | remove_api_key()
28 | remove_model_type()
29 | remove_model_name_local()
30 | config = get_config()
31 | config = update_config(config)
32 | for step in CONFIGURE_STEPS:
33 | step(config)
34 | save_config(config)
35 |
36 |
37 | def chat_loop(llm):
38 | while True:
39 | query = input("👉 ").lower().strip()
40 | if not query:
41 | print("🤖 Please enter a query")
42 | continue
43 | if query in ('exit', 'quit'):
44 | break
45 | llm.send_query(query)
46 |
47 |
48 | def chat():
49 | configure(False)
50 | config = get_config()
51 | repo = get_repo()
52 | if not repo:
53 | print("🤖 Git repository not found")
54 | sys.exit(1)
55 | llm = factory_llm(repo.working_dir, config)
56 | chat_loop(llm)
57 |
58 |
59 | def main():
60 | check_python_version()
61 |     print(f"🤖 Config path: {config_path}")
62 | try:
63 | fire.Fire({
64 | "chat": chat,
65 | "configure": lambda: configure(True)
66 | })
67 | except KeyboardInterrupt:
68 | print("\n🤖 Bye!")
69 | except Exception as e:
70 | raise e
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/talk_codebase/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | openai_flag = True
4 |
5 | try:
6 | import openai
7 | except ImportError:
8 | openai_flag = False
9 |
10 | import gpt4all
11 | import questionary
12 | import yaml
13 |
14 | from talk_codebase.consts import MODEL_TYPES
15 |
16 |
17 |
18 |
19 | config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml")
20 |
21 |
22 | def get_config():
23 | if os.path.exists(config_path):
24 | with open(config_path, "r") as f:
25 | config = yaml.safe_load(f)
26 | else:
27 | config = {}
28 | return config
29 |
30 |
31 | def save_config(config):
32 | with open(config_path, "w") as f:
33 | yaml.dump(config, f)
34 |
35 |
36 | def api_key_is_invalid(api_key):
37 | if not api_key:
38 | return True
39 | try:
40 | openai.api_key = api_key
41 | openai.Engine.list()
42 | except Exception:
43 | return True
44 | return False
45 |
46 |
47 | def get_gpt_models(openai):
48 | try:
49 | model_lst = openai.Model.list()
50 | except Exception:
51 | print("✘ Failed to retrieve model list")
52 | return []
53 |
54 | return [i['id'] for i in model_lst['data'] if 'gpt' in i['id']]
55 |
56 |
57 | def configure_model_name_openai(config):
58 | api_key = config.get("api_key")
59 |
60 | if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("openai_model_name"):
61 | return
62 |
63 | openai.api_key = api_key
64 | gpt_models = get_gpt_models(openai)
65 | choices = [{"name": model, "value": model} for model in gpt_models]
66 |
67 | if not choices:
68 | print("ℹ No GPT models available")
69 | return
70 |
71 | model_name = questionary.select("🤖 Select model name:", choices).ask()
72 |
73 | if not model_name:
74 | print("✘ No model selected")
75 | return
76 |
77 | config["openai_model_name"] = model_name
78 | save_config(config)
79 | print("🤖 Model name saved!")
80 |
81 |
82 | def remove_model_name_openai():
83 | config = get_config()
84 | config["openai_model_name"] = None
85 | save_config(config)
86 |
87 |
88 | def configure_model_name_local(config):
89 | if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"):
90 | return
91 |
92 | list_models = gpt4all.GPT4All.list_models()
93 |
94 | def get_model_info(model):
95 | return (
96 | f"{model['name']} "
97 | f"| {model['filename']} "
98 | f"| {model['filesize']} "
99 | f"| {model['parameters']} "
100 | f"| {model['quant']} "
101 | f"| {model['type']}"
102 | )
103 |
104 | choices = [
105 | {"name": get_model_info(model), "value": model['filename']} for model in list_models
106 | ]
107 |
108 | model_name = questionary.select("🤖 Select model name:", choices).ask()
109 | config["local_model_name"] = model_name
110 | save_config(config)
111 | print("🤖 Model name saved!")
112 |
113 |
114 | def remove_model_name_local():
115 | config = get_config()
116 | config["local_model_name"] = None
117 | save_config(config)
118 |
119 |
120 | def get_and_validate_api_key():
121 | prompt = "🤖 Enter your OpenAI API key: "
122 | api_key = input(prompt)
123 | while api_key_is_invalid(api_key):
124 | print("✘ Invalid API key")
125 | api_key = input(prompt)
126 | return api_key
127 |
128 |
129 | def configure_api_key(config):
130 | if config.get("model_type") != MODEL_TYPES["OPENAI"]:
131 | return
132 |
133 | if api_key_is_invalid(config.get("api_key")):
134 | api_key = get_and_validate_api_key()
135 | config["api_key"] = api_key
136 | save_config(config)
137 |
138 |
139 | def remove_api_key():
140 | config = get_config()
141 | config["api_key"] = None
142 | save_config(config)
143 |
144 |
145 | def remove_model_type():
146 | config = get_config()
147 | config["model_type"] = None
148 | save_config(config)
149 |
150 |
151 | def configure_model_type(config):
152 | if config.get("model_type"):
153 | return
154 |
155 | choices = [{"name": "Local", "value": MODEL_TYPES["LOCAL"]}]
156 |
157 |     if openai_flag:
158 |         choices.append({"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]})
159 |
160 |
161 | model_type = questionary.select(
162 | "🤖 Select model type:",
163 | choices=choices
164 | ).ask()
165 | config["model_type"] = model_type
166 | save_config(config)
167 |
168 |
169 | CONFIGURE_STEPS = [
170 | configure_model_type,
171 | configure_api_key,
172 | configure_model_name_openai,
173 | configure_model_name_local,
174 | ]
175 |
--------------------------------------------------------------------------------
/talk_codebase/consts.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from langchain.document_loaders import CSVLoader, UnstructuredWordDocumentLoader, UnstructuredEPubLoader, \
5 | PDFMinerLoader, UnstructuredMarkdownLoader, TextLoader
6 |
7 | EXCLUDE_DIRS = ['__pycache__', '.venv', '.git', '.idea', 'venv', 'env', 'node_modules', 'dist', 'build', '.vscode',
8 | '.github', '.gitlab']
9 | ALLOW_FILES = ['.txt', '.js', '.mjs', '.ts', '.tsx', '.css', '.scss', '.less', '.html', '.htm', '.json', '.py',
10 | '.java', '.c', '.cpp', '.cs', '.go', '.php', '.rb', '.rs', '.swift', '.kt', '.scala', '.m', '.h',
11 | '.sh', '.pl', '.pm', '.lua', '.sql']
12 | EXCLUDE_FILES = ['requirements.txt', 'package.json', 'package-lock.json', 'yarn.lock']
13 | MODEL_TYPES = {
14 | "OPENAI": "openai",
15 | "LOCAL": "local",
16 | }
17 | DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
18 |
19 | DEFAULT_CONFIG = {
20 | "max_tokens": "2056",
21 | "chunk_size": "2056",
22 | "chunk_overlap": "256",
23 | "k": "2",
24 | "temperature": "0.7",
25 | "model_path": DEFAULT_MODEL_DIRECTORY,
26 | "n_batch": "8",
27 | }
28 |
29 | LOADER_MAPPING = {
30 | ".csv": {
31 | "loader": CSVLoader,
32 | "args": {}
33 | },
34 | ".doc": {
35 | "loader": UnstructuredWordDocumentLoader,
36 | "args": {}
37 | },
38 | ".docx": {
39 | "loader": UnstructuredWordDocumentLoader,
40 | "args": {}
41 | },
42 | ".epub": {
43 | "loader": UnstructuredEPubLoader,
44 | "args": {}
45 | },
46 | ".md": {
47 | "loader": UnstructuredMarkdownLoader,
48 | "args": {}
49 | },
50 | ".pdf": {
51 | "loader": PDFMinerLoader,
52 | "args": {}
53 | }
54 | }
55 |
56 | for ext in ALLOW_FILES:
57 | if ext not in LOADER_MAPPING:
58 | LOADER_MAPPING[ext] = {
59 | "loader": TextLoader,
60 | "args": {}
61 | }
62 |
--------------------------------------------------------------------------------
/talk_codebase/llm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Optional
4 |
5 | import gpt4all
6 | import questionary
7 | from halo import Halo
8 | from langchain.vectorstores import FAISS
9 | from langchain.callbacks.manager import CallbackManager
10 | from langchain.chains import RetrievalQA
11 | from langchain.chat_models import ChatOpenAI
12 | from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
13 | from langchain.llms import LlamaCpp
14 | from langchain.text_splitter import RecursiveCharacterTextSplitter
15 |
16 | from talk_codebase.consts import MODEL_TYPES
17 | from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut
18 |
19 |
20 | class BaseLLM:
21 |
22 | def __init__(self, root_dir, config):
23 | self.config = config
24 | self.llm = self._create_model()
25 | self.root_dir = root_dir
26 | self.vector_store = self._create_store(root_dir)
27 |
28 | def _create_store(self, root_dir):
29 | raise NotImplementedError("Subclasses must implement this method.")
30 |
31 | def _create_model(self):
32 | raise NotImplementedError("Subclasses must implement this method.")
33 |
34 | def embedding_search(self, query, k):
35 | return self.vector_store.search(query, k=k, search_type="similarity")
36 |
37 | def _create_vector_store(self, embeddings, index, root_dir):
38 | k = int(self.config.get("k"))
39 | index_path = os.path.join(root_dir, f"vector_store/{index}")
40 | new_db = get_local_vector_store(embeddings, index_path)
41 | if new_db is not None:
42 | return new_db.as_retriever(search_kwargs={"k": k})
43 |
44 | docs = load_files()
45 | if len(docs) == 0:
46 | print("✘ No documents found")
47 | exit(0)
48 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size")),
49 | chunk_overlap=int(self.config.get("chunk_overlap")))
50 | texts = text_splitter.split_documents(docs)
51 | if index == MODEL_TYPES["OPENAI"]:
52 | cost = calculate_cost(docs, self.config.get("openai_model_name"))
53 | approve = questionary.select(
54 | f"Creating a vector store will cost ~${cost:.5f}. Do you want to continue?",
55 | choices=[
56 | {"name": "Yes", "value": True},
57 | {"name": "No", "value": False},
58 | ]
59 | ).ask()
60 | if not approve:
61 | exit(0)
62 |
63 |         spinners = Halo(text="Creating vector store", spinner='dots').start()
64 | db = FAISS.from_documents([texts[0]], embeddings)
65 | for i, text in enumerate(texts[1:]):
66 | spinners.text = f"Creating vector store ({i + 1}/{len(texts)})"
67 | db.add_documents([text])
68 | db.save_local(index_path)
69 | time.sleep(1.5)
70 |
71 |         spinners.succeed("Created vector store")
72 | return db.as_retriever(search_kwargs={"k": k})
73 |
74 | def send_query(self, query):
75 | retriever = self._create_store(self.root_dir)
76 | qa = RetrievalQA.from_chain_type(
77 | llm=self.llm,
78 | chain_type="stuff",
79 | retriever=retriever,
80 | return_source_documents=True
81 | )
82 | docs = qa(query)
83 | file_paths = [os.path.abspath(s.metadata["source"]) for s in docs['source_documents']]
84 | print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
85 |
86 |
87 | class LocalLLM(BaseLLM):
88 |
89 | def _create_store(self, root_dir: str) -> Optional[FAISS]:
90 | embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2')
91 | return self._create_vector_store(embeddings, MODEL_TYPES["LOCAL"], root_dir)
92 |
93 | def _create_model(self):
94 | os.makedirs(self.config.get("model_path"), exist_ok=True)
95 | gpt4all.GPT4All.retrieve_model(model_name=self.config.get("local_model_name"),
96 | model_path=self.config.get("model_path"))
97 | model_path = os.path.join(self.config.get("model_path"), self.config.get("local_model_name"))
98 | model_n_ctx = int(self.config.get("max_tokens"))
99 | model_n_batch = int(self.config.get("n_batch"))
100 | callbacks = CallbackManager([StreamStdOut()])
101 | llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks,
102 | verbose=False)
103 | llm.client.verbose = False
104 | return llm
105 |
106 |
107 | class OpenAILLM(BaseLLM):
108 | def _create_store(self, root_dir: str) -> Optional[FAISS]:
109 | embeddings = OpenAIEmbeddings(openai_api_key=self.config.get("api_key"))
110 | return self._create_vector_store(embeddings, MODEL_TYPES["OPENAI"], root_dir)
111 |
112 | def _create_model(self):
113 | return ChatOpenAI(model_name=self.config.get("openai_model_name"),
114 | openai_api_key=self.config.get("api_key"),
115 | streaming=True,
116 | max_tokens=int(self.config.get("max_tokens")),
117 | callback_manager=CallbackManager([StreamStdOut()]),
118 | temperature=float(self.config.get("temperature")))
119 |
120 |
121 | def factory_llm(root_dir, config):
122 |     if config.get("model_type") == MODEL_TYPES["OPENAI"]:
123 | return OpenAILLM(root_dir, config)
124 | else:
125 | return LocalLLM(root_dir, config)
126 |
--------------------------------------------------------------------------------
/talk_codebase/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import tiktoken
4 | from git import InvalidGitRepositoryError, NoSuchPathError, Repo
5 | from langchain.vectorstores import FAISS
6 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
7 |
8 | from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES
9 |
10 |
11 | def get_repo():
12 | try:
13 | return Repo()
14 |     except (InvalidGitRepositoryError, NoSuchPathError):
15 | return None
16 |
17 |
18 | class StreamStdOut(StreamingStdOutCallbackHandler):
19 | def on_llm_new_token(self, token: str, **kwargs) -> None:
20 | sys.stdout.write(token)
21 | sys.stdout.flush()
22 |
23 | def on_llm_start(self, serialized, prompts, **kwargs):
24 | sys.stdout.write("🤖 ")
25 |
26 | def on_llm_end(self, response, **kwargs):
27 | sys.stdout.write("\n")
28 | sys.stdout.flush()
29 |
30 |
31 | def load_files():
32 | repo = get_repo()
33 | if repo is None:
34 | return []
35 | files = []
36 | tree = repo.tree()
37 | for blob in tree.traverse():
38 | path = blob.path
39 | if any(
40 | path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES):
41 | continue
42 | for ext in LOADER_MAPPING:
43 | if path.endswith(ext):
44 | print('\r' + f'📂 Loading files: {path}')
45 | args = LOADER_MAPPING[ext]['args']
46 |                 loader = LOADER_MAPPING[ext]['loader'](path, **args)
47 | files.extend(loader.load())
48 | return files
49 |
50 |
51 | def calculate_cost(texts, model_name):
52 | enc = tiktoken.encoding_for_model(model_name)
53 | all_text = ''.join([text.page_content for text in texts])
54 | tokens = enc.encode(all_text)
55 | token_count = len(tokens)
56 | cost = (token_count / 1000) * 0.0004
57 | return cost
58 |
59 |
60 | def get_local_vector_store(embeddings, path):
61 | try:
62 | return FAISS.load_local(path, embeddings)
63 |     except Exception:
64 | return None
65 |
--------------------------------------------------------------------------------