├── .github ├── FUNDING.yml └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── poetry.lock ├── pyproject.toml ├── requirements.txt └── talk_codebase ├── __init__.py ├── cli.py ├── config.py ├── consts.py ├── llm.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [] -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | workflow_dispatch: 13 | 14 | permissions: 15 | contents: read 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.env 2 | /.idea/ 3 | /.vscode/ 4 | /.venv/ 5 | /talk_codebase/__pycache__/ 6 | .DS_Store 7 | /vector_store/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Saryev Rustam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # talk-codebase 2 | 3 | [![Python Package](https://github.com/rsaryev/talk-codebase/actions/workflows/python-publish.yml/badge.svg)](https://github.com/rsaryev/talk-codebase/actions/workflows/python-publish.yml) 4 | 5 | Talk-codebase is a tool that lets you converse with your codebase, using Large Language Models (LLMs) to answer your 6 | queries. It supports offline code processing with LlamaCpp and [GPT4All](https://github.com/nomic-ai/gpt4all), so your 7 | code is never shared with third parties, or you can use OpenAI if privacy is not a concern for you. Please note that 8 | talk-codebase is still under development and is recommended for educational purposes, not for production use. 9 | 10 |

11 | [demo animation: chat] 12 |
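Once a model is configured and the vector store has been built, a chat session looks roughly like the sketch below. The question, the streamed answer, and the file path are invented for illustration; the `👉` prompt, the `🤖` prefix, and the `📄` source listing are what the CLI prints. Type `exit` or `quit` to end the session.

```bash
$ talk-codebase chat
🤖 Config path: /home/<you>/.talk_codebase_config.yaml:
👉 where is the openai api key validated?
🤖 The API key is checked by api_key_is_invalid() in talk_codebase/config.py, which re-prompts until a valid key is entered.
📄 /home/<you>/my-project/talk_codebase/config.py:
👉 exit
```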

13 | 14 | ## Installation 15 | 16 | Requires Python 3.8.1 or higher. 17 | Your project must be in a git repository. 18 | 19 | ```bash 20 | pip install talk-codebase 21 | ``` 22 | 23 | After installation, you can chat with the codebase in the current directory by running the following command: 24 | 25 | ```bash 26 | talk-codebase chat 27 | ``` 28 | 29 | Select the model type: Local or OpenAI 30 | 31 | [screenshot: select_type] 32 | 33 | OpenAI 34 | 35 | If you use the OpenAI model, you need an OpenAI API key. You can get one from [here](https://beta.openai.com/). Then you 36 | will be offered a choice of the available models. 37 | 38 | [screenshot: select] 39 | 40 | 41 | Local 42 | 43 | [screenshot: 2023-07-12 at 03 47 58] 44 | 45 | If you want some files to be ignored, add them to `.gitignore`. 46 | 47 | ## Reset configuration 48 | 49 | To reset the configuration, run the following command: 50 | 51 | ```bash 52 | talk-codebase configure 53 | ``` 54 | 55 | ## Advanced configuration 56 | 57 | You can change the configuration manually by editing the `~/.talk_codebase_config.yaml` file. If you cannot find the configuration file, 58 | run the tool and it will print the path to the configuration file at the very beginning. 59 | 60 | ## Supported Extensions 61 | 62 | - [x] `.csv` 63 | - [x] `.doc` 64 | - [x] `.docx` 65 | - [x] `.epub` 66 | - [x] `.md` 67 | - [x] `.pdf` 68 | - [x] `.txt` 69 | - [x] `popular programming languages` 70 | 71 | ## Contributing 72 | 73 | * If you find a bug in talk-codebase, please report it on the project's issue tracker. When reporting a bug, please 74 | include as much information as possible, such as the steps to reproduce the bug, the expected behavior, and the actual 75 | behavior. 76 | * If you have an idea for a new feature for talk-codebase, please open an issue on the project's issue tracker. When 77 | suggesting a feature, please include a brief description of the feature, as well as the rationale for why it 78 | would be useful. 79 | * You can contribute to talk-codebase by writing code. The project is always looking for help with improving the 80 | codebase, adding new features, and fixing bugs. 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "talk-codebase" 3 | version = "0.1.50" 4 | description = "talk-codebase is a powerful tool for querying and analyzing codebases." 
5 | authors = ["Saryev Rustam "] 6 | readme = "README.md" 7 | packages = [{ include = "talk_codebase" }] 8 | keywords = ["chatgpt", "openai", "cli"] 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.8.1,<4.0" 12 | fire = "^0.5.0" 13 | faiss-cpu = "^1.7.4" 14 | halo = "^0.0.31" 15 | urllib3 = "1.26.18" 16 | gitpython = "^3.1.31" 17 | questionary = "^1.10.0" 18 | sentence-transformers = "^2.2.2" 19 | unstructured = "^0.6.10" 20 | langchain = ">=0.0.223,<0.1.12" 21 | llama-cpp-python = { version = "^0.1.68", optional = true } 22 | gpt4all = { version = "^1.0.1", optional = true } 23 | openai = { version = "^0.27.7", optional = true } 24 | tiktoken = { version = "^0.4.0", optional = true } 25 | 26 | [tool.poetry.extras] 27 | local = ["gpt4all", "llama-cpp-python"] 28 | all = ["gpt4all", "llama-cpp-python", "openai", "tiktoken"] 29 | 30 | 31 | [tool.poetry.group.dev.dependencies] 32 | setuptools = ">=68,<71" 33 | 34 | [build-system] 35 | requires = ["poetry-core"] 36 | build-backend = "poetry.core.masonry.api" 37 | 38 | [project.urls] 39 | "Source" = "https://github.com/rsaryev/talk-codebase" 40 | "Bug Tracker" = "https://github.com/rsaryev/talk-codebase/issues" 41 | 42 | [tool.poetry.scripts] 43 | talk-codebase = "talk_codebase.cli:main" 44 | 45 | 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.2 2 | aiosignal==1.3.1 3 | async-timeout==4.0.2 4 | attrs==23.1.0 5 | certifi==2023.7.22 6 | charset-normalizer==3.1.0 7 | click==8.1.3 8 | colorama==0.4.6 9 | colored==1.4.4 10 | dataclasses-json==0.5.7 11 | faiss-cpu==1.7.4 12 | filelock==3.12.0 13 | fire==0.5.0 14 | frozenlist==1.3.3 15 | fsspec==2023.5.0 16 | gitdb==4.0.10 17 | GitPython==3.1.41 18 | gpt4all==0.2.3 19 | halo==0.0.31 20 | huggingface-hub==0.14.1 21 | idna==3.7 22 | Jinja2==3.1.3 23 | joblib==1.2.0 24 | langchain==0.0.325 25 | log-symbols==0.0.14 26 | MarkupSafe==2.1.2 27 | marshmallow==3.19.0 28 | marshmallow-enum==1.5.1 29 | mpmath==1.3.0 30 | multidict==6.0.4 31 | mypy-extensions==1.0.0 32 | networkx==3.1 33 | nltk==3.8.1 34 | numexpr==2.8.4 35 | numpy==1.24.3 36 | openai==0.27.7 37 | openapi-schema-pydantic==1.2.4 38 | packaging==23.1 39 | Pillow==10.3.0 40 | prompt-toolkit==3.0.38 41 | pydantic==1.10.8 42 | PyYAML==6.0 43 | questionary==1.10.0 44 | regex==2023.5.5 45 | requests==2.31.0 46 | scikit-learn==1.2.2 47 | scipy==1.10.0 48 | sentence-transformers==2.2.2 49 | sentencepiece==0.1.99 50 | six==1.16.0 51 | smmap==5.0.0 52 | spinners==0.0.24 53 | SQLAlchemy==2.0.15 54 | sympy==1.12 55 | tenacity==8.2.2 56 | termcolor==2.3.0 57 | threadpoolctl==3.1.0 58 | tiktoken==0.4.0 59 | tokenizers==0.13.3 60 | torch==2.0.1 61 | torchvision==0.15.2 62 | tqdm==4.65.0 63 | transformers==4.38.0 64 | typing-inspect==0.9.0 65 | typing_extensions==4.6.2 66 | urllib3==1.26.18 67 | wcwidth==0.2.6 68 | yarl==1.9.2 69 | -------------------------------------------------------------------------------- /talk_codebase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rsaryev/talk-codebase/86f1771abb19eb00bd88a4ea76abbe3156e98125/talk_codebase/__init__.py -------------------------------------------------------------------------------- /talk_codebase/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | 5 | from talk_codebase.config import 
CONFIGURE_STEPS, save_config, get_config, config_path, remove_api_key, \ 6 | remove_model_type, remove_model_name_local 7 | from talk_codebase.consts import DEFAULT_CONFIG 8 | from talk_codebase.llm import factory_llm 9 | from talk_codebase.utils import get_repo 10 | 11 | 12 | def check_python_version(): 13 | if sys.version_info < (3, 8, 1): 14 | print("🤖 Please use Python 3.8.1 or higher") 15 | sys.exit(1) 16 | 17 | 18 | def update_config(config): 19 | for key, value in DEFAULT_CONFIG.items(): 20 | if key not in config: 21 | config[key] = value 22 | return config 23 | 24 | 25 | def configure(reset=True): 26 | if reset: 27 | remove_api_key() 28 | remove_model_type() 29 | remove_model_name_local() 30 | config = get_config() 31 | config = update_config(config) 32 | for step in CONFIGURE_STEPS: 33 | step(config) 34 | save_config(config) 35 | 36 | 37 | def chat_loop(llm): 38 | while True: 39 | query = input("👉 ").lower().strip() 40 | if not query: 41 | print("🤖 Please enter a query") 42 | continue 43 | if query in ('exit', 'quit'): 44 | break 45 | llm.send_query(query) 46 | 47 | 48 | def chat(): 49 | configure(False) 50 | config = get_config() 51 | repo = get_repo() 52 | if not repo: 53 | print("🤖 Git repository not found") 54 | sys.exit(1) 55 | llm = factory_llm(repo.working_dir, config) 56 | chat_loop(llm) 57 | 58 | 59 | def main(): 60 | check_python_version() 61 | print(f"🤖 Config path: {config_path}:") 62 | try: 63 | fire.Fire({ 64 | "chat": chat, 65 | "configure": lambda: configure(True) 66 | }) 67 | except KeyboardInterrupt: 68 | print("\n🤖 Bye!") 69 | except Exception as e: 70 | raise e 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /talk_codebase/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | openai_flag = True 4 | 5 | try: 6 | import openai 7 | except: 8 | openai_flag = False 9 | 10 | import gpt4all 11 | import questionary 12 | import yaml 13 | 14 | from talk_codebase.consts import MODEL_TYPES 15 | 16 | 17 | 18 | 19 | config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml") 20 | 21 | 22 | def get_config(): 23 | if os.path.exists(config_path): 24 | with open(config_path, "r") as f: 25 | config = yaml.safe_load(f) 26 | else: 27 | config = {} 28 | return config 29 | 30 | 31 | def save_config(config): 32 | with open(config_path, "w") as f: 33 | yaml.dump(config, f) 34 | 35 | 36 | def api_key_is_invalid(api_key): 37 | if not api_key: 38 | return True 39 | try: 40 | openai.api_key = api_key 41 | openai.Engine.list() 42 | except Exception: 43 | return True 44 | return False 45 | 46 | 47 | def get_gpt_models(openai): 48 | try: 49 | model_lst = openai.Model.list() 50 | except Exception: 51 | print("✘ Failed to retrieve model list") 52 | return [] 53 | 54 | return [i['id'] for i in model_lst['data'] if 'gpt' in i['id']] 55 | 56 | 57 | def configure_model_name_openai(config): 58 | api_key = config.get("api_key") 59 | 60 | if config.get("model_type") != MODEL_TYPES["OPENAI"] or config.get("openai_model_name"): 61 | return 62 | 63 | openai.api_key = api_key 64 | gpt_models = get_gpt_models(openai) 65 | choices = [{"name": model, "value": model} for model in gpt_models] 66 | 67 | if not choices: 68 | print("ℹ No GPT models available") 69 | return 70 | 71 | model_name = questionary.select("🤖 Select model name:", choices).ask() 72 | 73 | if not model_name: 74 | print("✘ No model selected") 75 | return 76 | 77 | 
config["openai_model_name"] = model_name 78 | save_config(config) 79 | print("🤖 Model name saved!") 80 | 81 | 82 | def remove_model_name_openai(): 83 | config = get_config() 84 | config["openai_model_name"] = None 85 | save_config(config) 86 | 87 | 88 | def configure_model_name_local(config): 89 | if config.get("model_type") != MODEL_TYPES["LOCAL"] or config.get("local_model_name"): 90 | return 91 | 92 | list_models = gpt4all.GPT4All.list_models() 93 | 94 | def get_model_info(model): 95 | return ( 96 | f"{model['name']} " 97 | f"| {model['filename']} " 98 | f"| {model['filesize']} " 99 | f"| {model['parameters']} " 100 | f"| {model['quant']} " 101 | f"| {model['type']}" 102 | ) 103 | 104 | choices = [ 105 | {"name": get_model_info(model), "value": model['filename']} for model in list_models 106 | ] 107 | 108 | model_name = questionary.select("🤖 Select model name:", choices).ask() 109 | config["local_model_name"] = model_name 110 | save_config(config) 111 | print("🤖 Model name saved!") 112 | 113 | 114 | def remove_model_name_local(): 115 | config = get_config() 116 | config["local_model_name"] = None 117 | save_config(config) 118 | 119 | 120 | def get_and_validate_api_key(): 121 | prompt = "🤖 Enter your OpenAI API key: " 122 | api_key = input(prompt) 123 | while api_key_is_invalid(api_key): 124 | print("✘ Invalid API key") 125 | api_key = input(prompt) 126 | return api_key 127 | 128 | 129 | def configure_api_key(config): 130 | if config.get("model_type") != MODEL_TYPES["OPENAI"]: 131 | return 132 | 133 | if api_key_is_invalid(config.get("api_key")): 134 | api_key = get_and_validate_api_key() 135 | config["api_key"] = api_key 136 | save_config(config) 137 | 138 | 139 | def remove_api_key(): 140 | config = get_config() 141 | config["api_key"] = None 142 | save_config(config) 143 | 144 | 145 | def remove_model_type(): 146 | config = get_config() 147 | config["model_type"] = None 148 | save_config(config) 149 | 150 | 151 | def configure_model_type(config): 152 | if config.get("model_type"): 153 | return 154 | 155 | choices = [{"name": "Local", "value": MODEL_TYPES["LOCAL"]}] 156 | 157 | if openai_flag: choices.append( 158 | {"name": "OpenAI", "value": MODEL_TYPES["OPENAI"]}) 159 | 160 | 161 | model_type = questionary.select( 162 | "🤖 Select model type:", 163 | choices=choices 164 | ).ask() 165 | config["model_type"] = model_type 166 | save_config(config) 167 | 168 | 169 | CONFIGURE_STEPS = [ 170 | configure_model_type, 171 | configure_api_key, 172 | configure_model_name_openai, 173 | configure_model_name_local, 174 | ] 175 | -------------------------------------------------------------------------------- /talk_codebase/consts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from langchain.document_loaders import CSVLoader, UnstructuredWordDocumentLoader, UnstructuredEPubLoader, \ 5 | PDFMinerLoader, UnstructuredMarkdownLoader, TextLoader 6 | 7 | EXCLUDE_DIRS = ['__pycache__', '.venv', '.git', '.idea', 'venv', 'env', 'node_modules', 'dist', 'build', '.vscode', 8 | '.github', '.gitlab'] 9 | ALLOW_FILES = ['.txt', '.js', '.mjs', '.ts', '.tsx', '.css', '.scss', '.less', '.html', '.htm', '.json', '.py', 10 | '.java', '.c', '.cpp', '.cs', '.go', '.php', '.rb', '.rs', '.swift', '.kt', '.scala', '.m', '.h', 11 | '.sh', '.pl', '.pm', '.lua', '.sql'] 12 | EXCLUDE_FILES = ['requirements.txt', 'package.json', 'package-lock.json', 'yarn.lock'] 13 | MODEL_TYPES = { 14 | "OPENAI": "openai", 15 | "LOCAL": "local", 16 | } 17 
| DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") 18 | 19 | DEFAULT_CONFIG = { 20 | "max_tokens": "2056", 21 | "chunk_size": "2056", 22 | "chunk_overlap": "256", 23 | "k": "2", 24 | "temperature": "0.7", 25 | "model_path": DEFAULT_MODEL_DIRECTORY, 26 | "n_batch": "8", 27 | } 28 | 29 | LOADER_MAPPING = { 30 | ".csv": { 31 | "loader": CSVLoader, 32 | "args": {} 33 | }, 34 | ".doc": { 35 | "loader": UnstructuredWordDocumentLoader, 36 | "args": {} 37 | }, 38 | ".docx": { 39 | "loader": UnstructuredWordDocumentLoader, 40 | "args": {} 41 | }, 42 | ".epub": { 43 | "loader": UnstructuredEPubLoader, 44 | "args": {} 45 | }, 46 | ".md": { 47 | "loader": UnstructuredMarkdownLoader, 48 | "args": {} 49 | }, 50 | ".pdf": { 51 | "loader": PDFMinerLoader, 52 | "args": {} 53 | } 54 | } 55 | 56 | for ext in ALLOW_FILES: 57 | if ext not in LOADER_MAPPING: 58 | LOADER_MAPPING[ext] = { 59 | "loader": TextLoader, 60 | "args": {} 61 | } 62 | -------------------------------------------------------------------------------- /talk_codebase/llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Optional 4 | 5 | import gpt4all 6 | import questionary 7 | from halo import Halo 8 | from langchain.vectorstores import FAISS 9 | from langchain.callbacks.manager import CallbackManager 10 | from langchain.chains import RetrievalQA 11 | from langchain.chat_models import ChatOpenAI 12 | from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings 13 | from langchain.llms import LlamaCpp 14 | from langchain.text_splitter import RecursiveCharacterTextSplitter 15 | 16 | from talk_codebase.consts import MODEL_TYPES 17 | from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut 18 | 19 | 20 | class BaseLLM: 21 | 22 | def __init__(self, root_dir, config): 23 | self.config = config 24 | self.llm = self._create_model() 25 | self.root_dir = root_dir 26 | self.vector_store = self._create_store(root_dir) 27 | 28 | def _create_store(self, root_dir): 29 | raise NotImplementedError("Subclasses must implement this method.") 30 | 31 | def _create_model(self): 32 | raise NotImplementedError("Subclasses must implement this method.") 33 | 34 | def embedding_search(self, query, k): 35 | return self.vector_store.search(query, k=k, search_type="similarity") 36 | 37 | def _create_vector_store(self, embeddings, index, root_dir): 38 | k = int(self.config.get("k")) 39 | index_path = os.path.join(root_dir, f"vector_store/{index}") 40 | new_db = get_local_vector_store(embeddings, index_path) 41 | if new_db is not None: 42 | return new_db.as_retriever(search_kwargs={"k": k}) 43 | 44 | docs = load_files() 45 | if len(docs) == 0: 46 | print("✘ No documents found") 47 | exit(0) 48 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=int(self.config.get("chunk_size")), 49 | chunk_overlap=int(self.config.get("chunk_overlap"))) 50 | texts = text_splitter.split_documents(docs) 51 | if index == MODEL_TYPES["OPENAI"]: 52 | cost = calculate_cost(docs, self.config.get("openai_model_name")) 53 | approve = questionary.select( 54 | f"Creating a vector store will cost ~${cost:.5f}. 
Do you want to continue?", 55 | choices=[ 56 | {"name": "Yes", "value": True}, 57 | {"name": "No", "value": False}, 58 | ] 59 | ).ask() 60 | if not approve: 61 | exit(0) 62 | 63 | spinners = Halo(text=f"Creating vector store", spinner='dots').start() 64 | db = FAISS.from_documents([texts[0]], embeddings) 65 | for i, text in enumerate(texts[1:]): 66 | spinners.text = f"Creating vector store ({i + 1}/{len(texts)})" 67 | db.add_documents([text]) 68 | db.save_local(index_path) 69 | time.sleep(1.5) 70 | 71 | spinners.succeed(f"Created vector store") 72 | return db.as_retriever(search_kwargs={"k": k}) 73 | 74 | def send_query(self, query): 75 | retriever = self._create_store(self.root_dir) 76 | qa = RetrievalQA.from_chain_type( 77 | llm=self.llm, 78 | chain_type="stuff", 79 | retriever=retriever, 80 | return_source_documents=True 81 | ) 82 | docs = qa(query) 83 | file_paths = [os.path.abspath(s.metadata["source"]) for s in docs['source_documents']] 84 | print('\n'.join([f'📄 {file_path}:' for file_path in file_paths])) 85 | 86 | 87 | class LocalLLM(BaseLLM): 88 | 89 | def _create_store(self, root_dir: str) -> Optional[FAISS]: 90 | embeddings = HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2') 91 | return self._create_vector_store(embeddings, MODEL_TYPES["LOCAL"], root_dir) 92 | 93 | def _create_model(self): 94 | os.makedirs(self.config.get("model_path"), exist_ok=True) 95 | gpt4all.GPT4All.retrieve_model(model_name=self.config.get("local_model_name"), 96 | model_path=self.config.get("model_path")) 97 | model_path = os.path.join(self.config.get("model_path"), self.config.get("local_model_name")) 98 | model_n_ctx = int(self.config.get("max_tokens")) 99 | model_n_batch = int(self.config.get("n_batch")) 100 | callbacks = CallbackManager([StreamStdOut()]) 101 | llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, 102 | verbose=False) 103 | llm.client.verbose = False 104 | return llm 105 | 106 | 107 | class OpenAILLM(BaseLLM): 108 | def _create_store(self, root_dir: str) -> Optional[FAISS]: 109 | embeddings = OpenAIEmbeddings(openai_api_key=self.config.get("api_key")) 110 | return self._create_vector_store(embeddings, MODEL_TYPES["OPENAI"], root_dir) 111 | 112 | def _create_model(self): 113 | return ChatOpenAI(model_name=self.config.get("openai_model_name"), 114 | openai_api_key=self.config.get("api_key"), 115 | streaming=True, 116 | max_tokens=int(self.config.get("max_tokens")), 117 | callback_manager=CallbackManager([StreamStdOut()]), 118 | temperature=float(self.config.get("temperature"))) 119 | 120 | 121 | def factory_llm(root_dir, config): 122 | if config.get("model_type") == "openai": 123 | return OpenAILLM(root_dir, config) 124 | else: 125 | return LocalLLM(root_dir, config) 126 | -------------------------------------------------------------------------------- /talk_codebase/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tiktoken 4 | from git import Repo 5 | from langchain.vectorstores import FAISS 6 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 7 | 8 | from talk_codebase.consts import LOADER_MAPPING, EXCLUDE_FILES 9 | 10 | 11 | def get_repo(): 12 | try: 13 | return Repo() 14 | except: 15 | return None 16 | 17 | 18 | class StreamStdOut(StreamingStdOutCallbackHandler): 19 | def on_llm_new_token(self, token: str, **kwargs) -> None: 20 | sys.stdout.write(token) 21 | sys.stdout.flush() 22 | 23 | def on_llm_start(self, serialized, prompts, 
**kwargs): 24 | sys.stdout.write("🤖 ") 25 | 26 | def on_llm_end(self, response, **kwargs): 27 | sys.stdout.write("\n") 28 | sys.stdout.flush() 29 | 30 | 31 | def load_files(): 32 | repo = get_repo() 33 | if repo is None: 34 | return [] 35 | files = [] 36 | tree = repo.tree() 37 | for blob in tree.traverse(): 38 | path = blob.path 39 | if any( 40 | path.endswith(exclude_file) for exclude_file in EXCLUDE_FILES): 41 | continue 42 | for ext in LOADER_MAPPING: 43 | if path.endswith(ext): 44 | print('\r' + f'📂 Loading files: {path}') 45 | args = LOADER_MAPPING[ext]['args'] 46 | loader = LOADER_MAPPING[ext]['loader'](path, *args) 47 | files.extend(loader.load()) 48 | return files 49 | 50 | 51 | def calculate_cost(texts, model_name): 52 | enc = tiktoken.encoding_for_model(model_name) 53 | all_text = ''.join([text.page_content for text in texts]) 54 | tokens = enc.encode(all_text) 55 | token_count = len(tokens) 56 | cost = (token_count / 1000) * 0.0004 57 | return cost 58 | 59 | 60 | def get_local_vector_store(embeddings, path): 61 | try: 62 | return FAISS.load_local(path, embeddings) 63 | except: 64 | return None 65 | --------------------------------------------------------------------------------
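A closing note on the `~/.talk_codebase_config.yaml` file referenced in the README's "Advanced configuration" section: it is plain YAML written by `save_config()` in `talk_codebase/config.py`. The sketch below is illustrative only: the keys mirror `DEFAULT_CONFIG` in `talk_codebase/consts.py` together with the values stored by the configure steps, and every concrete value shown (the API key, the model names, the home directory) is a placeholder, not something this repository ships.

```yaml
# ~/.talk_codebase_config.yaml (illustrative sketch; every value is a placeholder)
model_type: openai                      # "openai" or "local"
api_key: sk-REPLACE_ME                  # used only when model_type is "openai"
openai_model_name: gpt-3.5-turbo        # picked interactively from the models your key can access
local_model_name: ggml-model.bin        # used only when model_type is "local"
model_path: "/home/<you>/.cache/gpt4all"  # where local models are downloaded
# numeric settings are stored as strings and cast with int()/float() by the code
max_tokens: '2056'
chunk_size: '2056'
chunk_overlap: '256'
k: '2'
temperature: '0.7'
n_batch: '8'
```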