├── .env.example ├── .flake8 ├── .github └── workflows │ └── lint.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── nagato ├── __init__.py ├── service │ ├── __init__.py │ ├── embedding.py │ ├── finetune.py │ ├── prompts.py │ ├── query.py │ └── vectordb.py └── utils │ ├── __init__.py │ ├── lazy_model_loader.py │ └── logger.py ├── poetry.lock ├── pyproject.toml ├── setup.py └── tests ├── __init__.py ├── embedding.py ├── finetune.py ├── predict.py ├── predict_with_embedding.py └── query_embedding.py /.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= 2 | HF_API_KEY= 3 | PINECONE_API_KEY= 4 | REPLICATE_API_TOKEN= 5 | COHERE_API_KEY= -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | venv 4 | .venv 5 | __pycache__ 6 | notebooks 7 | # Recommend matching the black line length (default 88), 8 | # rather than using the flake8 default of 79: 9 | max-line-length = 88 10 | extend-ignore = 11 | # See https://github.com/PyCQA/pycodestyle/issues/373 12 | E203, 13 | E501, -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | env: 9 | POETRY_VERSION: "1.4.2" 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: 17 | - "3.8" 18 | - "3.9" 19 | - "3.10" 20 | - "3.11" 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Install poetry 24 | run: | 25 | pipx install poetry==$POETRY_VERSION 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | cache: poetry 31 | - name: Install dependencies 32 | run: | 33 | poetry install 34 | - name: Analysing the code with our lint 35 | run: | 36 | make lint -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .env 4 | .env*.local 5 | .venv 6 | superenv/ 7 | .DS_Store 8 | venv/ 9 | /.vscode 10 | /.codesandbox 11 | .pypirc 12 | dist/ 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ismail Pelaseyed 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | format: 2 | poetry run black . 3 | poetry run ruff --select I --fix . 4 | poetry run vulture . 5 | 6 | PYTHON_FILES=. 7 | lint: PYTHON_FILES=. 8 | lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$') 9 | 10 | lint lint_diff: 11 | poetry run black $(PYTHON_FILES) --check 12 | poetry run ruff . 13 | poetry run vulture . -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
21 | Quick Start Guide • 22 | Features • 23 | Key benefits • 24 | How it works 25 |
26 | 27 | ----- 28 | 29 | ## Quick Start Guide 30 | Full documentation of all methods in the `nagato` library will be posted soon. 31 | 32 | 1. Change the name of `.env.example` to `.env` and populate the environment variables 33 | 34 | 2. Install the `nagato-ai` package using either pip or Poetry: 35 | 36 | - For pip: 37 | ```sh 38 | pip install nagato-ai 39 | ``` 40 | - For Poetry: 41 | ```sh 42 | poetry add nagato-ai 43 | ``` 44 | 45 | 3. Create and store embeddings 46 | ```python 47 | from nagato import create_vector_embeddings 48 | 49 | results = create_vector_embeddings( 50 | type="PDF", 51 | model="all-MiniLM-L6-v2", 52 | filter_id="MY_DOCUMENT_ID", 53 | url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf", 54 | ) 55 | ``` 56 | 57 | 4. Create a fine-tuned model 58 | ```python 59 | from nagato import create_finetuned_model 60 | 61 | results = create_finetuned_model( 62 | url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf", 63 | type="PDF", 64 | base_model="LLAMA2_7B_CHAT", 65 | provider="REPLICATE", 66 | webhook_url="https://webhook.site/ebe803b9-1e34-4b20-a6ca-d06356961cd1", 67 | ) 68 | ``` 69 | 70 | ## Features 71 | 72 | - Data ingestion from various formats such as JSON, CSV, TXT, PDF, etc. 73 | - Data embedding using pre-trained or fine-tuned models. 74 | - Storage of embedded vectors 75 | - Automatic generation of question/answer pairs for model fine-tuning 76 | - Built-in code interpreter 77 | - API concurrency for scalability and performance 78 | - Workflow management for ingestion pipelines 79 | 80 | ## Key benefits 81 | 82 | - **Faster inference**: Generic models often bring overhead in terms of computational time due to their broad-based training. In contrast, our fine-tuned models are optimized for specific domains, enabling faster inference and more timely results. 83 | 84 | - **Lower costs**: Utilizing fine-tuned models tailored for a specific corpus minimizes the number of tokens needed for accurate understanding and response generation. This reduction in token count translates to decreased computational costs and thus lower operational expenses. 85 | 86 | - **Better results**: Fine-tuned models offer superior performance on specialized tasks when compared to generic, all-purpose models. Whether you're generating embeddings or answering complex queries, you can expect more accurate and contextually relevant outcomes. 87 | 88 | ## How it works 89 | 90 | Nagato utilizes distinct strategies to process structured and unstructured data, aiming to produce fine-tuned models for both types. Below is a breakdown of how this is accomplished: 91 | 92 |  93 | 94 | ### Unstructured data: 95 | 96 | 1. **Selection of Embedding Model**: The first step involves a careful analysis of the textual content to select an appropriate text-based embedding model. Based on various characteristics of the corpus such as vocabulary, context, and domain-specific jargon, Nagato picks the most suitable pre-trained text-based model for embedding. 97 | 98 | 2. **Fine-Tuning the Embedding Model**: Once the initial text-based model is selected, it is then fine-tuned to align more closely with the specific domain or subject matter of the corpus. This ensures that the embeddings generated are as accurate and relevant as possible. 99 | 100 | 3. **Fine-Tuning the Language Model**: After generating and storing embeddings, Nagato creates question-answer pairs for the purpose of fine-tuning a GPT-based language model.
This yields a language model that is highly specialized in understanding and generating text within the domain of the corpus. 101 | 102 | ### Structured data: 103 | 104 | 1. **Sandboxed REPL**: Nagato features a secure, sandboxed Read-Eval-Print Loop (REPL) environment to execute code snippets against the structured text data. This facilitates flexible and dynamic processing of structured data formats like JSON, CSV or XML. 105 | 106 | 2. **Evaluation/Prediction Using a Code Interpreter**: Post-initial processing, a code interpreter evaluates various code snippets within the sandboxed environment to produce predictions or analyses based on the structured text data. This capability allows the extraction of highly specialized insights tailored to the domain or subject matter. 107 | 108 | 109 | ## Citation 110 | 111 | If you use Nagato in your research, please cite it as follows: 112 | 113 | ``` 114 | @misc{nagato, 115 | author = {Ismail Pelaseyed}, 116 | title = {Nagato: The open framework for Q&A finetuning LLMs on private data}, 117 | year = {2023}, 118 | publisher = {GitHub}, 119 | journal = {GitHub repository}, 120 | howpublished = {\url{https://github.com/homanp/nagato}}, 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /nagato/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .service import create_finetuned_model, create_vector_embeddings 4 | -------------------------------------------------------------------------------- /nagato/service/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | import requests 4 | 5 | from nagato.service.embedding import ( 6 | MODEL_TO_INDEX, 7 | EmbeddingService, 8 | get_vector_service, 9 | ) 10 | from nagato.service.finetune import get_finetuning_service 11 | from nagato.service.query import QueryService 12 | from nagato.utils.lazy_model_loader import LazyModelLoader 13 | 14 | 15 | def create_vector_embeddings( 16 | type: str, model: str, filter_id: str, url: str = None, content: str = None 17 | ) -> List: 18 | embedding_service = EmbeddingService(type=type, content=content, url=url) 19 | documents = embedding_service.generate_documents() 20 | nodes = embedding_service.generate_chunks(documents=documents) 21 | embedding_service.generate_embeddings(nodes=nodes, filter_id=filter_id, model=model) 22 | return nodes 23 | 24 | 25 | def create_finetuned_model( 26 | provider: str, 27 | base_model: str, 28 | type: str, 29 | url: str = None, 30 | content: str = None, 31 | webhook_url: str = None, 32 | num_questions_per_chunk: int = 10, 33 | ) -> dict: 34 | embedding_service = EmbeddingService(type=type, url=url, content=content) 35 | documents = embedding_service.generate_documents() 36 | nodes = embedding_service.generate_chunks(documents=documents) 37 | finetunning_service = get_finetuning_service( 38 | nodes=nodes, 39 | provider=provider, 40 | batch_size=5, 41 | base_model=base_model, 42 | num_questions_per_chunk=num_questions_per_chunk, 43 | ) 44 | training_file = finetunning_service.generate_dataset() 45 | formatted_training_file = finetunning_service.validate_dataset( 46 | training_file=training_file 47 | ) 48 | finetune = finetunning_service.finetune( 49 | training_file=formatted_training_file, webhook_url=webhook_url 50 | ) 51 | if provider == "OPENAI": 52 | requests.post(webhook_url, json=finetune) 53 | 
finetunning_service.cleanup(training_file=finetune.get("training_file")) 54 | return finetune 55 | 56 | 57 | def predict( 58 | input: str, 59 | provider: str, 60 | model: str, 61 | system_prompt: str = None, 62 | callback: Callable = None, 63 | enable_streaming: bool = False, 64 | ) -> dict: 65 | query_service = QueryService(provider=provider, model=model) 66 | output = query_service.predict( 67 | input=input, 68 | callback=callback, 69 | enable_streaming=enable_streaming, 70 | system_prompt=system_prompt, 71 | ) 72 | return output 73 | 74 | 75 | def predict_with_embedding( 76 | input: str, 77 | provider: str, 78 | model: str, 79 | vector_db: str, 80 | embedding_model: str, 81 | embedding_filter_id: str, 82 | callback: Callable = None, 83 | system_prompt: str = "You are a helpful assistant", 84 | enable_streaming: bool = False, 85 | ) -> dict: 86 | context = query_embedding( 87 | query=input, 88 | model=embedding_model, 89 | filter_id=embedding_filter_id, 90 | vector_db=vector_db, 91 | ) 92 | query_service = QueryService(provider=provider, model=model) 93 | output = query_service.predict_with_embedding( 94 | input=input, 95 | callback=callback, 96 | enable_streaming=enable_streaming, 97 | context=context, 98 | system_prompt=system_prompt, 99 | ) 100 | return output 101 | 102 | 103 | def query_embedding( 104 | query: str, 105 | model: str = "thenlper/gte-small", 106 | vector_db: str = "PINECONE", 107 | filter_id: str = None, 108 | top_k: int = 5, 109 | re_rank: bool = True, 110 | ) -> dict: 111 | model_name = MODEL_TO_INDEX[model.split("/")[-1]].get("index_name") 112 | model_dimensions = MODEL_TO_INDEX[model.split("/")[-1]].get("dimensions") 113 | embedding_model = LazyModelLoader(model_name=model_name) 114 | vectordb = get_vector_service( 115 | provider=vector_db, 116 | index_name=model_name, 117 | filter_id=filter_id, 118 | dimension=model_dimensions, 119 | ) 120 | embedding = embedding_model.model.encode([query]).tolist() 121 | docs = vectordb.query(queries=embedding, top_k=top_k, include_metadata=True) 122 | if re_rank: 123 | docs = vectordb.rerank(query=query, documents=docs, top_n=top_k) 124 | return docs[0] 125 | -------------------------------------------------------------------------------- /nagato/service/embedding.py: -------------------------------------------------------------------------------- 1 | from tempfile import NamedTemporaryFile 2 | from typing import List, Union 3 | 4 | import requests 5 | from llama_index import Document, SimpleDirectoryReader 6 | from llama_index.node_parser import SimpleNodeParser 7 | from numpy import ndarray 8 | from tqdm import tqdm 9 | 10 | from nagato.service.vectordb import get_vector_service 11 | from nagato.utils.lazy_model_loader import LazyModelLoader 12 | 13 | MODEL_TO_INDEX = { 14 | "all-MiniLM-L6-v2": {"index_name": "all-minilm-l6-v2", "dimensions": 384}, 15 | "thenlper/gte-base": {"index_name": "gte-base", "dimensions": 768}, 16 | "thenlper/gte-small": {"index_name": "gte-small", "dimensions": 384}, 17 | "thenlper/gte-large": {"index_name": "gte-large", "dimensions": 1024}, 18 | "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768}, 19 | "BAAI/bge-large-en-v1.5": {"index_name": "bge-large", "dimensions": 1024}, 20 | "jinaai/jina-embeddings-v2-base-en": { 21 | "index_name": "jina-embeddings-v2", 22 | "dimensions": 768, 23 | } 24 | # Add more mappings here as needed 25 | } 26 | 27 | 28 | class EmbeddingService: 29 | def __init__(self, type: str, url: str = None, content: str = None): 30 | self.type = type 31 | 
self.url = url 32 | self.content = content 33 | 34 | def get_datasource_suffix(self) -> str: 35 | suffixes = {"TXT": ".txt", "PDF": ".pdf", "MARKDOWN": ".md"} 36 | try: 37 | return suffixes[self.type] 38 | except KeyError: 39 | raise ValueError("Unsupported datasource type") 40 | 41 | def generate_documents(self) -> List[Document]: 42 | with NamedTemporaryFile( 43 | suffix=self.get_datasource_suffix(), delete=True 44 | ) as temp_file: 45 | if self.url: 46 | response = requests.get(self.url, stream=True) 47 | total_size_in_bytes = int(response.headers.get("content-length", 0)) 48 | block_size = 1024 49 | content = b"" 50 | with tqdm( 51 | total=total_size_in_bytes, 52 | desc="🟠 Downloading file", 53 | unit="iB", 54 | unit_scale=True, 55 | ) as progress_bar: 56 | for data in response.iter_content(block_size): 57 | progress_bar.update(len(data)) 58 | content += data 59 | if ( 60 | total_size_in_bytes != 0 61 | and progress_bar.n != total_size_in_bytes 62 | ): 63 | print("ERROR, something went wrong") 64 | else: 65 | progress_bar.set_description("🟢 Downloading file") 66 | else: 67 | content = self.content 68 | temp_file.write(content) 69 | temp_file.flush() 70 | 71 | with tqdm(total=1, desc="🟠 Processing data") as pbar: 72 | reader = SimpleDirectoryReader(input_files=[temp_file.name]) 73 | docs = reader.load_data() 74 | pbar.update() 75 | pbar.set_description("🟢 Processing data") 76 | 77 | return docs 78 | 79 | def generate_chunks(self, documents: List[Document]) -> List[Union[Document, None]]: 80 | parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20) 81 | with tqdm(total=1, desc="🟠 Generating chunks") as pbar: 82 | nodes = parser.get_nodes_from_documents(documents, show_progress=False) 83 | pbar.update() 84 | pbar.set_description("🟢 Generating chunks") 85 | return nodes 86 | 87 | def generate_embeddings( 88 | self, 89 | nodes: List[Union[Document, None]], 90 | filter_id: str, 91 | model: str = "all-MiniLM-L6-v2", 92 | embedding_provider: str = "PINECONE", 93 | ) -> List[ndarray]: 94 | vectordb = get_vector_service( 95 | provider=embedding_provider, 96 | index_name=MODEL_TO_INDEX[model].get("index_name"), 97 | filter_id=filter_id, 98 | dimension=MODEL_TO_INDEX[model].get("dimensions"), 99 | ) 100 | embedding_model = LazyModelLoader(model_name=model) 101 | embeddings = [] 102 | with tqdm(total=len(nodes), desc="🟠 Generating embeddings") as pbar: 103 | for node in nodes: 104 | if node is not None: 105 | embedding = ( 106 | node.id_, 107 | embedding_model.model.encode(node.text).tolist(), 108 | {**node.metadata, "content": node.text}, 109 | ) 110 | embeddings.append(embedding) 111 | pbar.update() 112 | vectordb.upsert(vectors=embeddings) 113 | pbar.set_description("🟢 Generating embeddings") 114 | 115 | return embeddings 116 | -------------------------------------------------------------------------------- /nagato/service/finetune.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | import sys 4 | import requests 5 | import json 6 | import os 7 | import uuid 8 | import openai 9 | import replicate 10 | import concurrent.futures 11 | 12 | from tqdm import tqdm 13 | from decouple import config 14 | from abc import ABC, abstractmethod 15 | from typing import Dict, List, Union 16 | from concurrent.futures import ThreadPoolExecutor 17 | 18 | from nagato.utils.logger import logger 19 | from decouple import config 20 | from llama_index import Document 21 | 22 | from nagato.service.prompts import ( 23 | GPT_DATA_FORMAT, 24 | 
REPLICATE_FORMAT, 25 | generate_qa_pair_prompt, 26 | ) 27 | 28 | openai.api_key = config("OPENAI_API_KEY") 29 | 30 | 31 | REPLICATE_MODELS = { 32 | "LLAMA2_7B_CHAT": "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e", 33 | "LLAMA2_7B": "meta/llama-2-7b:527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef", 34 | "LLAMA2_13B_CHAT": "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", 35 | "LLAMA2_13B": "meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38", 36 | "LLAMA2_70B_CHAT": "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", 37 | "LLAMA2_70B": "meta/llama-2-70b:a52e56fee2269a78c9279800ec88898cecb6c8f1df22a6483132bea266648f00", 38 | "GPT_J_6B": "replicate/gpt-j-6b:b3546aeec6c9891f0dd9929c2d3bedbf013c12e02e7dd0346af09c37e008c827", 39 | "DOLLY_V2_12B": "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5", 40 | } 41 | 42 | OPENAI_MODELS = {"GPT_35_TURBO": "gpt-3.5-turbo"} 43 | 44 | 45 | class FinetuningService(ABC): 46 | def __init__( 47 | self, 48 | nodes: List[Union[Document, None]], 49 | num_questions_per_chunk: int, 50 | batch_size: int, 51 | ): 52 | self.nodes = nodes 53 | self.num_questions_per_chunk = num_questions_per_chunk 54 | self.batch_size = batch_size 55 | 56 | @abstractmethod 57 | def generate_prompt_and_completion(self, node): 58 | pass 59 | 60 | @abstractmethod 61 | def validate_dataset(self, training_file: str) -> str: 62 | pass 63 | 64 | @abstractmethod 65 | def finetune(self, training_file: str, base_model: str) -> Dict: 66 | pass 67 | 68 | def generate_dataset(self) -> str: 69 | training_file = f"{uuid.uuid4()}.jsonl" 70 | total_pairs = len(self.nodes) * self.num_questions_per_chunk 71 | with open(training_file, "w") as f: 72 | with ThreadPoolExecutor() as executor: 73 | progress_bar = tqdm( 74 | total=total_pairs, 75 | desc="🟠 Generating synthetic Q&A pairs", 76 | file=sys.stdout, 77 | ) 78 | for i in range( 79 | 0, len(self.nodes), self.batch_size 80 | ): # Process nodes in chunks of batch_size 81 | batch_nodes = self.nodes[i : i + self.batch_size] 82 | tasks = [ 83 | executor.submit(self.generate_prompt_and_completion, node) 84 | for node in batch_nodes 85 | ] 86 | for future in concurrent.futures.as_completed(tasks): 87 | qa_pair = future.result() 88 | json_objects = qa_pair.split("\n\n") 89 | for json_obj in json_objects: 90 | f.write(json_obj + "\n") 91 | progress_bar.update(1) 92 | progress_bar.set_description("🟢 Generating synthetic Q&A pairs") 93 | progress_bar.close() 94 | return training_file 95 | 96 | def cleanup(self, training_file: str) -> None: 97 | os.remove(training_file) 98 | 99 | 100 | class OpenAIFinetuningService(FinetuningService): 101 | def __init__( 102 | self, 103 | nodes: List[Union[Document, None]], 104 | num_questions_per_chunk: int, 105 | batch_size: int, 106 | base_model: str = "GPT_35_TURBO", 107 | ): 108 | super().__init__( 109 | nodes=nodes, 110 | num_questions_per_chunk=num_questions_per_chunk, 111 | batch_size=batch_size, 112 | ) 113 | self.base_model = base_model 114 | 115 | def generate_prompt_and_completion(self, node): 116 | prompt = generate_qa_pair_prompt( 117 | context=node.text, 118 | num_of_qa_pairs=self.num_questions_per_chunk, 119 | format=GPT_DATA_FORMAT, 120 | ) 121 | completion = openai.ChatCompletion.create( 122 | model="gpt-3.5-turbo", 123 | messages=[{"role": "user", "content": prompt}], 124 | temperature=0, 125 | ) 126 | return 
completion.choices[0].message.content 127 | 128 | def validate_dataset(self, training_file: str) -> str: 129 | valid_lines = [] 130 | with open(training_file, "r") as file: 131 | lines = file.readlines() 132 | total_lines = len(lines) 133 | progress_bar = tqdm( 134 | total=total_lines, desc="🟠 Validating dataset", file=sys.stdout 135 | ) 136 | for line in lines: 137 | try: 138 | data = json.loads(line) 139 | if "messages" not in data: 140 | continue 141 | messages = data["messages"] 142 | if len(messages) != 3: 143 | continue 144 | if not ( 145 | messages[0]["role"] == "system" 146 | and messages[1]["role"] == "user" 147 | and messages[2]["role"] == "assistant" 148 | ): 149 | continue 150 | valid_lines.append(line) 151 | except json.JSONDecodeError: 152 | continue 153 | finally: 154 | progress_bar.update(1) 155 | progress_bar.set_description("🟢 Validating dataset") 156 | progress_bar.close() 157 | 158 | with open(training_file, "w") as file: 159 | file.writelines(valid_lines) 160 | return training_file 161 | 162 | def finetune(self, training_file: str, webhook_url: str = None) -> Dict: 163 | file = openai.File.create(file=open(training_file, "rb"), purpose="fine-tune") 164 | finetune = openai.FineTuningJob.create( 165 | training_file=file.get("id"), model=OPENAI_MODELS[self.base_model] 166 | ) 167 | return {**finetune, "training_file": training_file} 168 | 169 | 170 | class ReplicateFinetuningService(FinetuningService): 171 | def __init__( 172 | self, 173 | nodes: List[Union[Document, None]], 174 | num_questions_per_chunk: int, 175 | batch_size: int, 176 | base_model: str = "LLAMA2_7B_CHAT", 177 | ): 178 | super().__init__( 179 | nodes=nodes, 180 | num_questions_per_chunk=num_questions_per_chunk, 181 | batch_size=batch_size, 182 | ) 183 | self.base_model = base_model 184 | 185 | def generate_prompt_and_completion(self, node): 186 | prompt = generate_qa_pair_prompt( 187 | context=node.text, 188 | num_of_qa_pairs=self.num_questions_per_chunk, 189 | format=REPLICATE_FORMAT, 190 | ) 191 | completion = openai.ChatCompletion.create( 192 | model="gpt-3.5-turbo", 193 | messages=[{"role": "user", "content": prompt}], 194 | temperature=0, 195 | ) 196 | return completion.choices[0].message.content 197 | 198 | def validate_dataset(self, training_file: str) -> str: 199 | valid_data = [] 200 | with open(training_file, "r") as f: 201 | lines = f.readlines() 202 | total_lines = len(lines) 203 | progress_bar = tqdm( 204 | total=total_lines, 205 | desc="🟠 Validating training data", 206 | file=sys.stdout, 207 | ) 208 | for i, line in enumerate(lines, start=1): 209 | try: 210 | data = json.loads(line) 211 | if "prompt" in data and "completion" in data: 212 | valid_data.append(data) 213 | except json.JSONDecodeError: 214 | pass 215 | progress_bar.update(1) 216 | progress_bar.set_description("🟢 Validating training data") 217 | progress_bar.close() 218 | 219 | with open(training_file, "w") as f: 220 | for data in valid_data: 221 | f.write(json.dumps(data) + "\n") 222 | 223 | return training_file 224 | 225 | def finetune(self, training_file: str, webhook_url: str = None) -> Dict: 226 | training_file_url = upload_replicate_dataset(training_file=training_file) 227 | training = replicate.Client( 228 | api_token=config("REPLICATE_API_KEY") 229 | ).trainings.create( 230 | version=REPLICATE_MODELS[self.base_model], 231 | input={ 232 | "train_data": training_file_url, 233 | "num_train_epochs": 3, 234 | }, 235 | destination="homanp/test", 236 | webhook=webhook_url, 237 | ) 238 | progress_bar = tqdm( 239 | total=1, 240 | 
desc="🟢 Started model training", 241 | file=sys.stdout, 242 | ) 243 | progress_bar.update(1) 244 | progress_bar.close() 245 | return {"id": training.id, "training_file": training_file} 246 | 247 | 248 | def get_finetuning_service( 249 | nodes: List[Union[Document, None]], 250 | provider: str = "OPENAI", 251 | base_model: str = "GPT_35_TURBO", 252 | num_questions_per_chunk: int = 10, 253 | batch_size: int = 10, 254 | ): 255 | services = { 256 | "OPENAI": OpenAIFinetuningService, 257 | "REPLICATE": ReplicateFinetuningService, 258 | # Add other providers here 259 | } 260 | service = services.get(provider) 261 | if service is None: 262 | raise ValueError(f"Unsupported provider: {provider}") 263 | return service( 264 | nodes=nodes, 265 | num_questions_per_chunk=num_questions_per_chunk, 266 | batch_size=batch_size, 267 | base_model=base_model, 268 | ) 269 | 270 | 271 | def upload_replicate_dataset(training_file: str) -> str: 272 | headers = {"Authorization": f"Token {config('REPLICATE_API_KEY')}"} 273 | upload_response = requests.post( 274 | "https://dreambooth-api-experimental.replicate.com/v1/upload/data.jsonl", 275 | headers=headers, 276 | ) 277 | upload_response_data = upload_response.json() 278 | upload_url = upload_response_data["upload_url"] 279 | 280 | with open(training_file, "rb") as f: 281 | requests.put( 282 | upload_url, 283 | headers={"Content-Type": "application/jsonl"}, 284 | data=f.read(), 285 | ) 286 | 287 | serving_url = upload_response_data["serving_url"] 288 | return serving_url 289 | -------------------------------------------------------------------------------- /nagato/service/prompts.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | GPT_DATA_FORMAT = ( 4 | "{" 5 | '"messages": [' 6 | '{"role": "system", "content": "You are an AI agent that\'s an expert at answering questions."}, ' 7 | '{"role": "user", "content": "What\'s the capital of France?"}, ' 8 | '{"role": "assistant", "content": "Paris is the capital of France."}' 9 | "]" 10 | "}" 11 | ) 12 | 13 | REPLICATE_FORMAT = ( 14 | "{" 15 | '"prompt": "What\'s the capital of France?",' 16 | '"completion": "Paris is the capital of France"' 17 | "}" 18 | ) 19 | 20 | 21 | def generate_rag_prompt(context: str, input: str) -> str: 22 | prompt = ( 23 | "You are an assistant for question-answering tasks. Use the following pieces " 24 | "of retrieved context to answer the question. If you don't know the answer, " 25 | "just say that you don't know. Use three sentences maximum and keep the answer concise.\n\n" 26 | f"Question: {input}\n" 27 | f"Context: {context}\n" 28 | f"Answer:" 29 | ) 30 | return prompt 31 | 32 | 33 | def generate_qa_pair_prompt( 34 | format: str, context: str, num_of_qa_pairs: int = 10 35 | ) -> str: 36 | prompt = ( 37 | "You are an AI assistant tasked with generating question and answer pairs " 38 | "for the given context using the given format. Only answer in the format with " 39 | f"no other text. You should create the following number of question/answer pairs: {num_of_qa_pairs}. " 40 | "Return the question/answer pairs as a JSONL. "
41 | "Each dict in the list should have the full context provided," 42 | "a relevant question to the context and an answer to the question.\n\n" 43 | f"Format:\n {format}\n\n" 44 | f"Context:\n {context}" 45 | ) 46 | return prompt 47 | -------------------------------------------------------------------------------- /nagato/service/query.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Callable 3 | 4 | import litellm 5 | from decouple import config 6 | 7 | from nagato.service.prompts import ( 8 | generate_rag_prompt, 9 | ) 10 | 11 | 12 | class QueryService(ABC): 13 | def __init__( 14 | self, 15 | provider: str, 16 | model: str, 17 | ): 18 | self.provider = provider 19 | self.model = model 20 | if self.provider == "REPLICATE": 21 | self.api_key = config("REPLICATE_API_KEY") 22 | elif self.provider == "OPENAI": 23 | self.api_key = config("OPENAI_API_KEY") 24 | else: 25 | self.api_key = None 26 | 27 | def predict_with_embedding( 28 | self, 29 | input: str, 30 | context: str, 31 | system_prompt: str, 32 | enable_streaming: bool = False, 33 | callback: Callable = None, 34 | ): 35 | litellm.api_key = self.api_key 36 | prompt = generate_rag_prompt(context=context, input=input) 37 | output = litellm.completion( 38 | model=self.model, 39 | messages=[ 40 | { 41 | "content": system_prompt, 42 | "role": "system", 43 | }, 44 | { 45 | "content": prompt, 46 | "role": "user", 47 | }, 48 | ], 49 | max_tokens=2000, 50 | temperature=0, 51 | stream=enable_streaming, 52 | ) 53 | if enable_streaming: 54 | for chunk in output: 55 | callback(chunk["choices"][0]["delta"]["content"]) 56 | return output 57 | 58 | def predict( 59 | self, 60 | input: str, 61 | enable_streaming: bool = False, 62 | system_prompt: str = None, 63 | callback: Callable = None, 64 | ): 65 | litellm.api_key = self.api_key 66 | 67 | output = litellm.completion( 68 | model=self.model, 69 | messages=[ 70 | {"content": system_prompt, "role": "system"}, 71 | {"content": input, "role": "user"}, 72 | ], 73 | max_tokens=450, 74 | temperature=0, 75 | stream=enable_streaming, 76 | ) 77 | if enable_streaming: 78 | for chunk in output: 79 | callback(chunk["choices"][0]["delta"]["content"]) 80 | return output 81 | -------------------------------------------------------------------------------- /nagato/service/vectordb.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, List 3 | 4 | import pinecone 5 | from decouple import config 6 | from numpy import ndarray 7 | 8 | 9 | class VectorDBService(ABC): 10 | def __init__(self, index_name: str, dimension: int, filter_id: str = None): 11 | self.index_name = index_name 12 | self.filter_id = filter_id 13 | self.dimension = dimension 14 | 15 | @abstractmethod 16 | def upsert(): 17 | pass 18 | 19 | @abstractmethod 20 | def query(): 21 | pass 22 | 23 | @abstractmethod 24 | def rerank(self, query: str, documents: list, top_n: int = 3): 25 | pass 26 | 27 | 28 | class PineconeVectorService(VectorDBService): 29 | def __init__(self, index_name: str, dimension: int, filter_id: str = None): 30 | super().__init__( 31 | index_name=index_name, dimension=dimension, filter_id=filter_id 32 | ) 33 | pinecone.init( 34 | api_key=config("PINECONE_API_KEY"), 35 | environment=config("PINECONE_ENVIRONMENT"), 36 | ) 37 | # Create a new vector index if it doesn't 38 | # exist dimensions should be passed in the arguments 39 | if index_name not in 
pinecone.list_indexes(): 40 | pinecone.create_index( 41 | name=index_name, metric="cosine", shards=1, dimension=dimension 42 | ) 43 | self.index = pinecone.Index(index_name=self.index_name) 44 | 45 | def upsert(self, vectors: ndarray): 46 | self.index.upsert(vectors=vectors, namespace=self.filter_id) 47 | 48 | def query(self, queries: List[ndarray], top_k: int, include_metadata: bool = True): 49 | results = self.index.query( 50 | queries=queries, 51 | top_k=top_k, 52 | include_metadata=include_metadata, 53 | namespace=self.filter_id, 54 | ) 55 | return results["results"][0]["matches"] 56 | 57 | def rerank(self, query: str, documents: Any, top_n: int = 3): 58 | from cohere import Client 59 | 60 | api_key = config("COHERE_API_KEY") 61 | if not api_key: 62 | raise ValueError("API key for Cohere is not present.") 63 | cohere_client = Client(api_key=api_key) 64 | docs = [ 65 | ( 66 | f"{doc['metadata']['content']}\n\n" 67 | f"page number: {doc['metadata']['page_label']}" 68 | ) 69 | for doc in documents 70 | ] 71 | re_ranked = cohere_client.rerank( 72 | model="rerank-multilingual-v2.0", 73 | query=query, 74 | documents=docs, 75 | top_n=top_n, 76 | ).results 77 | results = [] 78 | for obj in re_ranked: 79 | results.append(obj.document["text"]) 80 | return results 81 | 82 | 83 | def get_vector_service( 84 | provider: str, index_name: str, filter_id: str = None, dimension: int = 384 85 | ): 86 | services = { 87 | "PINECONE": PineconeVectorService, 88 | # Add other providers here 89 | # e.g "weaviate": WeaviateVectorService, 90 | } 91 | service = services.get(provider) 92 | if service is None: 93 | raise ValueError(f"Unsupported provider: {provider}") 94 | return service(index_name=index_name, filter_id=filter_id, dimension=dimension) 95 | -------------------------------------------------------------------------------- /nagato/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/homanp/nagato/202d752bf03cbb471fbd574a70274696c9c41c30/nagato/utils/__init__.py -------------------------------------------------------------------------------- /nagato/utils/lazy_model_loader.py: -------------------------------------------------------------------------------- 1 | from decouple import config 2 | 3 | 4 | class LazyModelLoader: 5 | def __init__(self, model_name: str = None): 6 | self._model = None 7 | self._model_name = model_name 8 | 9 | @property 10 | def model(self): 11 | if self._model is None and self._model_name is not None: 12 | from sentence_transformers import SentenceTransformer 13 | 14 | self._model = SentenceTransformer( 15 | model_name_or_path=self._model_name, use_auth_token=config("HF_API_KEY") 16 | ) 17 | return self._model 18 | -------------------------------------------------------------------------------- /nagato/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import colorlog 4 | 5 | 6 | def setup_logger(): 7 | """Return a logger with a default ColoredFormatter.""" 8 | formatter = colorlog.ColoredFormatter( 9 | "%(log_color)s%(levelname)-8s%(reset)s %(blue)s%(message)s", 10 | datefmt=None, 11 | reset=True, 12 | log_colors={ 13 | "DEBUG": "cyan", 14 | "INFO": "green", 15 | "WARNING": "yellow", 16 | "ERROR": "red", 17 | "CRITICAL": "red,bg_white", 18 | }, 19 | secondary_log_colors={}, 20 | style="%", 21 | ) 22 | 23 | logger = colorlog.getLogger(__name__) 24 | handler = logging.StreamHandler() 25 | handler.setFormatter(formatter) 26 | 
logger.addHandler(handler) 27 | logger.setLevel(logging.DEBUG) 28 | 29 | return logger 30 | 31 | 32 | # Set up logger 33 | logger = setup_logger() 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "nagato-ai" 3 | version = "0.0.16" 4 | description = "" 5 | authors = ["Ismail Pelaseyed"] 6 | readme = "./README.md" 7 | packages = [{include = "nagato"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | python-decouple = "^3.8" 12 | pydantic = "^1.10.7" 13 | flake8 = "^6.0.0" 14 | ruff = "^0.0.265" 15 | black = "^23.3.0" 16 | pinecone-client = "^2.2.2" 17 | unstructured = "^0.10.16" 18 | requests = "^2.31.0" 19 | colorlog = "^6.7.0" 20 | vulture = "^2.7" 21 | llama-index = "^0.8.37" 22 | pypdf = "^3.16.2" 23 | tiktoken = "^0.5.1" 24 | sentence-transformers = "^2.2.2" 25 | replicate = "^0.15.4" 26 | wheel = "^0.41.0" 27 | python-dotenv = "^1.0.0" 28 | tqdm = "^4.66.1" 29 | setuptools = "^68.2.2" 30 | cohere = "^4.32" 31 | litellm = "^0.12.9" 32 | 33 | [build-system] 34 | requires = ["poetry-core"] 35 | build-backend = "poetry.core.masonry.api" 36 | 37 | [tool.vulture] 38 | exclude = [ 39 | "*settings.py", 40 | "*/docs/*.py", 41 | "*/test_*.py", 42 | "*/.venv/*.py", 43 | ] 44 | ignore_decorators = ["@app.route", "@require_*"] 45 | ignore_names = ["visit_*", "do_*"] 46 | make_whitelist = true 47 | min_confidence = 100 48 | paths = ["."] 49 | sort_by_size = true 50 | verbose = false 51 | 52 | [tool.ruff] 53 | exclude = [ 54 | "*settings.py", 55 | "*/docs/*.py", 56 | "*/test_*.py", 57 | "*/.venv/*.py", 58 | "whitelist.py" 59 | ] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="nagato-ai", 5 | version="0.0.16", 6 | packages=["nagato"], 7 | description="The open framework for Q&A finetuning LLMs on private data", 8 | long_description=open("README.md").read(), 9 | long_description_content_type="text/markdown", 10 | author="Ismail Pelaseyed", 11 | author_email="ismail@superagent.sh", 12 | url="https://github.com/homanp/nagato", 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: MIT License", 16 | ], 17 | install_requires=[ 18 | "python-decouple>=3.8", 19 | "pydantic>=1.10.7", 20 | "flake8>=6.0.0", 21 | "ruff>=0.0.265", 22 | "black>=23.3.0", 23 | "pinecone-client>=2.2.2", 24 | "unstructured>=0.10.16", 25 | "requests>=2.31.0", 26 | "colorlog>=6.7.0", 27 | "vulture>=2.7", 28 | "llama-index>=0.8.37", 29 | "pypdf>=3.16.2", 30 | "tiktoken>=0.5.1", 31 | "sentence-transformers>=2.2.2", 32 | "replicate>=0.15.4", 33 | "wheel>=0.41.0", 34 | "python-dotenv>=1.0.0", 35 | "tqdm>=4.66.1", 36 | "setuptools>=68.2.2", 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/homanp/nagato/202d752bf03cbb471fbd574a70274696c9c41c30/tests/__init__.py -------------------------------------------------------------------------------- /tests/embedding.py: -------------------------------------------------------------------------------- 1 | from nagato.service import create_vector_embeddings 2 | 3 | 4 | def main(): 5 | result = create_vector_embeddings( 6 | 
type="PDF", 7 | url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf", 8 | filter_id="011", 9 | model="all-MiniLM-L6-v2", 10 | ) 11 | print(result) 12 | 13 | 14 | main() 15 | -------------------------------------------------------------------------------- /tests/finetune.py: -------------------------------------------------------------------------------- 1 | from nagato.service import create_finetuned_model 2 | 3 | 4 | def main(): 5 | result = create_finetuned_model( 6 | url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf", 7 | type="PDF", 8 | base_model="LLAMA2_7B_CHAT", 9 | provider="REPLICATE", 10 | webhook_url="https://webhook.site/ebe803b9-1e34-4b20-a6ca-d06356961cd1", 11 | num_questions_per_chunk=40, 12 | ) 13 | print(f"🤖 MODEL: {result}") 14 | 15 | 16 | main() 17 | -------------------------------------------------------------------------------- /tests/predict.py: -------------------------------------------------------------------------------- 1 | from nagato.service import predict 2 | 3 | 4 | def callback_method(item): 5 | print(item) 6 | 7 | 8 | def main(): 9 | result = predict( 10 | input="What was Teslas YoY revenue increase in Q2 2023?", 11 | provider="REPLICATE", 12 | model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f", 13 | system_prompt="You are an helpful assistant", 14 | enable_streaming=True, 15 | callback=callback_method, 16 | ) 17 | print(result) 18 | 19 | 20 | main() 21 | -------------------------------------------------------------------------------- /tests/predict_with_embedding.py: -------------------------------------------------------------------------------- 1 | from nagato.service import predict_with_embedding 2 | 3 | 4 | def callback_method(item): 5 | print(item) 6 | 7 | 8 | def main(): 9 | result = predict_with_embedding( 10 | input="What was the EBITDA for Q2 2023?", 11 | provider="REPLICATE", 12 | model="meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", 13 | vector_db="PINECONE", 14 | embedding_model="huggingface/sentence-transformers/all-MiniLM-L6-v2", 15 | embedding_filter_id="011", 16 | enable_streaming=False, 17 | callback=callback_method, 18 | ) 19 | print(result) 20 | 21 | 22 | main() 23 | -------------------------------------------------------------------------------- /tests/query_embedding.py: -------------------------------------------------------------------------------- 1 | from nagato.service import query_embedding 2 | 3 | 4 | def callback_method(item): 5 | print(item) 6 | 7 | 8 | def main(): 9 | result = query_embedding( 10 | query="What was total revenues in Q2 2023?", 11 | filter_id="011", 12 | model="huggingface/sentence-transformers/all-MiniLM-L6-v2", 13 | ) 14 | print(result) 15 | 16 | 17 | main() 18 | --------------------------------------------------------------------------------
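The tests above exercise each service in isolation. As a minimal end-to-end sketch (hypothetical, not a file in the repository), the snippet below chains them together using the public functions defined in `nagato/service/__init__.py`. It assumes a populated `.env` (OpenAI, Hugging Face and Pinecone keys, plus a Cohere key for re-ranking), the same Tesla investor-update PDF and `filter_id` used in the tests, and that the Pinecone index for `all-MiniLM-L6-v2` can be created on first use.

```python
# Hypothetical end-to-end walkthrough (a sketch, not shipped in the repo):
# embed a PDF, retrieve the most relevant chunk, then answer a question with RAG.
from nagato.service import (
    create_vector_embeddings,
    predict_with_embedding,
    query_embedding,
)

FILTER_ID = "011"  # namespace used by the existing tests
PDF_URL = (
    "https://digitalassets.tesla.com/tesla-contents/image/upload/IR/"
    "TSLA-Q2-2023-Update.pdf"
)

# 1. Chunk the document and upsert embeddings into the Pinecone index
#    mapped to the model in MODEL_TO_INDEX ("all-minilm-l6-v2").
create_vector_embeddings(
    type="PDF",
    url=PDF_URL,
    filter_id=FILTER_ID,
    model="all-MiniLM-L6-v2",
)

# 2. Retrieve the best-matching chunk; with the default re_rank=True the
#    results are also re-ranked through Cohere, so COHERE_API_KEY must be set.
context = query_embedding(
    query="What was total revenues in Q2 2023?",
    model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
    filter_id=FILTER_ID,
)
print(context)

# 3. Answer the question with retrieval-augmented generation via litellm.
#    The repo's own tests use a Replicate model here; "OPENAI" is the other
#    provider wired up in QueryService.
answer = predict_with_embedding(
    input="What was total revenues in Q2 2023?",
    provider="OPENAI",
    model="gpt-3.5-turbo",
    vector_db="PINECONE",
    embedding_model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
    embedding_filter_id=FILTER_ID,
)
print(answer)
```

Streaming output is available on both `predict` functions by passing `enable_streaming=True` together with a `callback`, as shown in `tests/predict.py`.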