├── utils ├── build_graph.py └── pdf_splitter.py ├── main.py ├── README.md ├── vector_store └── weaviate_store.py ├── models ├── embedder.py └── llm.py ├── .gitignore ├── requirements.txt └── LICENSE /utils/build_graph.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | 4 | class Knowledge_graph: 5 | def __init__(self, df: str = None): 6 | self.df = df 7 | 8 | def create_graph(self): 9 | self.G = nx.from_pandas_edgelist(self.df, 'node_1', 'node_2', edge_attr='edge', create_using=nx.MultiGraph()) 10 | nx.draw(self.G, with_labels=True) 11 | 12 | def query_sub_graph(self, query_node): 13 | neighbors = list(self.G.neighbors(query_node)) + [query_node] 14 | subgraph = self.G.subgraph(neighbors) 15 | 16 | pos = nx.spring_layout(subgraph) 17 | 18 | plt.figure(figsize=(8, 8)) 19 | 20 | node_size = 2000 21 | node_color = 'lightblue' 22 | font_color = 'black' 23 | font_weight = 'bold' 24 | font_size = 8 25 | edge_color = 'gray' 26 | edge_style = 'dashed' 27 | 28 | # Draw the subgraph 29 | nx.draw(subgraph, pos, with_labels=True, node_size=node_size, node_color=node_color, font_color=font_color, font_size = font_size, 30 | font_weight=font_weight, edge_color=edge_color, style=edge_style) 31 | 32 | # Add additional customizations 33 | plt.title(f"Graph of Node: {query_node}") 34 | 35 | # Save the plot to a file 36 | plt.savefig('subgraph.png') 37 | #plt.show() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from models.embedder import Embedder 2 | from models.llm import RAG_LLM 3 | from vector_store.weaviate_store import Weaviate_Store 4 | 5 | from utils.pdf_splitter import PDFSplitter 6 | from utils.build_graph import Knowledge_graph 7 | 8 | import torch 9 | 10 | if __name__ == '__main__': 11 | splitter = PDFSplitter() 12 | chunks = splitter.split_document(r"C:\Users\Jino Rohit\Downloads\Jino-Rohit-Resume.pdf", max_size = 500) 13 | 14 | rag_llm = RAG_LLM( 15 | model_directory="C:/Users/Jino Rohit/Downloads/mistral-7b-orca", 16 | temperature=1.0, 17 | top_k=5, 18 | top_p=0.8, 19 | top_a=0.9, 20 | token_repetition_penalty=1.2 21 | ) 22 | rag_llm.setup_model() 23 | entities_df = rag_llm.generate_nodes(chunks, max_new_tokens = 1000) 24 | 25 | kg = Knowledge_graph(entities_df) 26 | kg.create_graph() 27 | 28 | #if you want to do hybrid search, create embeddings 29 | emb_model = Embedder() 30 | nodes_embeds = emb_model.embed(list(entities_df['node_1'] + ' ' + entities_df['node_2'] + ' ' + entities_df['edge'])) 31 | 32 | entities_df['vectors'] = nodes_embeds.tolist() 33 | 34 | vector_store = Weaviate_Store(store_name = 'Demo') 35 | vector_store.store_vectors(entities_df) 36 | 37 | query = "what was the solution for the hackathon" 38 | response = vector_store.keyword_search(query = query, top_k = 5) 39 | 40 | for _r in response['data']['Get']['Demo']: 41 | kg.query_sub_graph(_r['source']) 42 | 43 | rag_llm.generate_answers(chunks, query, 500) 44 | 45 | 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAG with Knowledge Graph 2 | 3 | ## Overview 4 | Retrieval Augmented Generation(RAG) is a way of generating reliable answers from LLM using an external knowledge base. 5 | This project shows how to use RAG with a knowledge graph using Weaviate as the vector database and the exllamav2 implementation of the mistral orca model. 6 | 7 | The following is the pipeline - 8 | 1. Extract text from a PDF 9 | 2. Chunk the data into k size with w overlap. 10 | 2. Extract (source, relation, target) from the chunks and create a knowledge graph 11 | 3. Extract embeddings for the nodes and relationships. 12 | 4. Store the text and vectors in weaviate vector database. 13 | 5. Apply a keyword search on the nodes and retrieve the top k chunks. 14 | 6. Generate the answer from the top k retrieved chunks. 15 | 7. You can also visualize the sub-graph of the nodes used to generate the answer. 16 | 17 | ## Installation 18 | 19 | ```bash 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## Features 24 | - Uses the Exllamav2 implementation of the mistral orca model which is extremely fast and memory efficient. 25 | - Construction of knowledge graph from text to understand the concepts better and retrieve more relevant text chunks. 26 | - The vector database used is Weaviate for storing the data and applying the keyword search(can also be done for hybrid search). 27 | 28 | ## Usage 29 | 30 | - First signup to weaviate free sandbox to get your api key and weaviate instance url. 31 | - Create a .env file and store your credentials there. 32 | 33 | Sample .env 34 | 35 | ```env 36 | WEAVIATE_API_KEY = 'xxx' 37 | WEAVIATE_CLIENT_URL = 'xxx' 38 | ``` 39 | 40 | - Run the main file. 41 | 42 | ```bash 43 | python main.py 44 | ``` 45 | -------------------------------------------------------------------------------- /vector_store/weaviate_store.py: -------------------------------------------------------------------------------- 1 | import weaviate 2 | import os 3 | from loguru import logger 4 | from dotenv import load_dotenv 5 | load_dotenv() 6 | 7 | class Weaviate_Store: 8 | def __init__(self, store_name: str) -> None: 9 | self.store_name = store_name 10 | auth_config = weaviate.AuthApiKey(api_key = os.environ['WEAVIATE_API_KEY']) 11 | 12 | self.client = weaviate.Client( 13 | url = os.environ['WEAVIATE_CLIENT_URL'], 14 | auth_client_secret = auth_config 15 | ) 16 | 17 | class_obj = { 18 | "class": self.store_name, 19 | "vectorizer": "none", 20 | } 21 | self.client.schema.create_class(class_obj) 22 | 23 | logger.info('Store has been created') 24 | 25 | def store_vectors(self, data): 26 | self.client.batch.configure(batch_size=100) 27 | with self.client.batch as batch: 28 | for i, d in data.iterrows(): 29 | print(f"importing question: {i+1}") 30 | 31 | properties = { 32 | "source": d["node_1"], 33 | "relation": d["edge"], 34 | "target": d["node_2"], 35 | "chunk": d["chunk"] 36 | } 37 | 38 | batch.add_data_object(properties, self.store_name, vector=d["vectors"]) 39 | 40 | total_docs = self.client.query.aggregate(self.store_name).with_meta_count().do() 41 | logger.info(f'Total items : {total_docs}') 42 | 43 | def keyword_search(self, query: str , top_k: int = 5): 44 | response = ( 45 | self.client.query.get(self.store_name, ["source", "target", "relation", "chunk"]) 46 | .with_bm25(query = query) 47 | .with_limit(top_k) 48 | .do() 49 | ) 50 | return response 51 | 52 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModel 2 | import torch 3 | 4 | class Embedder: 5 | def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2' ): 6 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 7 | self.model = AutoModel.from_pretrained(model_name) 8 | self.dimension = self.model.embeddings.position_embeddings.embedding_dim 9 | self.max_seq_length = self.model.embeddings.position_embeddings.num_embeddings 10 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | print("Device:", self.device) 12 | 13 | def embed(self, texts, max_seq_length = 500): 14 | self.model.to(self.device) 15 | 16 | encoded_input = self.tokenizer(texts, padding = True, truncation = True, return_tensors='pt', max_length = max_seq_length) 17 | #print("Encoded input done",encoded_input['input_ids'].shape) 18 | 19 | encoded_input = {name: tensor.to(self.device) for name, tensor in encoded_input.items()} 20 | print("Encoded input moved to device") 21 | 22 | with torch.no_grad(): 23 | model_output = self.model(**encoded_input) 24 | 25 | embeddings = self.mean_pooling(model_output, encoded_input['attention_mask']) 26 | tensor_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 27 | np_embeddings = tensor_embeddings.cpu().numpy() 28 | return np_embeddings 29 | 30 | def mean_pooling(self, model_output, attention_mask): 31 | token_embeddings = model_output.last_hidden_state.to(self.device) 32 | attention_mask = attention_mask.to(self.device) 33 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 34 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 35 | 36 | def check_sim(self, embed1, embed2): 37 | return embed1 @ embed2 -------------------------------------------------------------------------------- /utils/pdf_splitter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Union 3 | from tqdm.auto import tqdm 4 | import fitz 5 | from langchain.text_splitter import CharacterTextSplitter 6 | from loguru import logger 7 | 8 | 9 | class PDFSplitter: 10 | def __init__(self, chunk_overlap: int = 200) -> None: 11 | self.chunk_overlap = chunk_overlap 12 | 13 | def split_document( 14 | self, document_path: Union[str, Path], max_size: int, **kwargs 15 | ) -> List[dict]: 16 | 17 | logger.info(f"Partitioning document: {document_path}") 18 | 19 | all_chunks = [] 20 | splitter = CharacterTextSplitter( 21 | separator="\n", 22 | keep_separator=True, 23 | chunk_size=max_size, 24 | chunk_overlap=self.chunk_overlap, 25 | ) 26 | 27 | doc = fitz.open(document_path) 28 | current_text = "" 29 | for page in doc: 30 | text = page.get_text("block") 31 | 32 | if len(text) > max_size: 33 | all_chunks.append( 34 | {"text": current_text, "metadata": {"page": page.number}} 35 | ) 36 | chunks = splitter.split_text(text) 37 | for chunk in chunks: 38 | logger.info( 39 | f"Flushing chunk. Length: {len(chunk)}, page: {page.number}" 40 | ) 41 | all_chunks.append( 42 | {"text": chunk, "metadata": {"page": page.number}} 43 | ) 44 | current_text = "" 45 | 46 | elif len(current_text + text) >= max_size: 47 | if current_text != "": 48 | all_chunks.append( 49 | {"text": current_text, "metadata": {"page": page.number}} 50 | ) 51 | logger.info( 52 | f"Flushing chunk. Length: {len(current_text)}, page: {page.number}" 53 | ) 54 | current_text = text 55 | 56 | else: 57 | current_text += text 58 | 59 | # Filter out empty docs 60 | all_chunks = [ 61 | chunk for chunk in all_chunks if chunk["text"].strip().replace(" ", "") 62 | ] 63 | return all_chunks -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | .vscode/ 133 | lightning_logs/ 134 | *.zip 135 | *.jpg 136 | *.jpeg 137 | *.png 138 | *.gif 139 | 140 | /data 141 | .DS_Store 142 | *.gguf 143 | *.ggml -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | aiohttp==3.9.1 3 | aiosignal==1.3.1 4 | annotated-types==0.6.0 5 | anyio==4.2.0 6 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | Authlib==1.3.0 10 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 11 | certifi==2022.12.7 12 | cffi==1.16.0 13 | charset-normalizer==2.1.1 14 | click==8.1.7 15 | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work 16 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1704278392174/work 17 | contourpy==1.2.0 18 | cramjam==2.7.0 19 | cryptography==41.0.7 20 | cycler==0.12.1 21 | dataclasses-json==0.6.3 22 | datasets==2.14.7 23 | debugpy @ file:///C:/b/abs_c0y1fjipt2/croot/debugpy_1690906864587/work 24 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 25 | dill==0.3.7 26 | exceptiongroup==1.2.0 27 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work 28 | exllamav2==0.0.11 29 | fastparquet==2023.10.1 30 | filelock==3.13.1 31 | FlagEmbedding==1.1.8 32 | fonttools==4.47.0 33 | frozenlist==1.4.1 34 | fsspec==2023.10.0 35 | greenlet==3.0.3 36 | huggingface-hub==0.17.3 37 | idna==3.4 38 | importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1703269254275/work 39 | importlib-resources==6.1.1 40 | ipykernel @ file:///D:/bld/ipykernel_1703631901958/work 41 | ipython @ file:///D:/bld/ipython_1680185618122/work 42 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work 43 | Jinja2==3.1.2 44 | joblib==1.3.2 45 | jsonpatch==1.33 46 | jsonpointer==2.4 47 | jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1699283905679/work 48 | jupyter_core @ file:///D:/bld/jupyter_core_1669775233774/work 49 | kiwisolver==1.4.5 50 | langchain==0.0.354 51 | langchain-community==0.0.8 52 | langchain-core==0.1.6 53 | langsmith==0.0.77 54 | loguru==0.7.2 55 | MarkupSafe==2.1.3 56 | marshmallow==3.20.1 57 | matplotlib==3.8.2 58 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 59 | mpmath==1.3.0 60 | multidict==6.0.4 61 | multiprocess==0.70.15 62 | mypy-extensions==1.0.0 63 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1697083700168/work 64 | networkx==3.2.1 65 | ninja==1.11.1.1 66 | nltk==3.8.1 67 | numpy==1.26.3 68 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work 69 | pandas==2.1.4 70 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 71 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 72 | Pillow==9.3.0 73 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work 74 | psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work 75 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 76 | pyarrow==14.0.2 77 | pyarrow-hotfix==0.6 78 | pycparser==2.21 79 | pydantic==2.5.3 80 | pydantic_core==2.14.6 81 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work 82 | PyMuPDF==1.23.8 83 | PyMuPDFb==1.23.7 84 | pyparsing==3.1.1 85 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 86 | python-dotenv==1.0.0 87 | pytz==2023.3.post1 88 | pywin32==227 89 | PyYAML==6.0.1 90 | pyzmq @ file:///C:/b/abs_655zk4a3s8/croot/pyzmq_1686601465034/work 91 | regex==2023.12.25 92 | requests==2.31.0 93 | safetensors==0.4.1 94 | scikit-learn==1.3.2 95 | scipy==1.11.4 96 | sentence-transformers==2.2.2 97 | sentencepiece==0.1.99 98 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 99 | sniffio==1.3.0 100 | SQLAlchemy==2.0.25 101 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 102 | sympy==1.12 103 | tenacity==8.2.3 104 | threadpoolctl==3.2.0 105 | tokenizers==0.14.1 106 | torch==2.1.2+cu121 107 | torchaudio==2.1.2+cu118 108 | torchvision==0.16.2 109 | tornado @ file:///D:/bld/tornado_1656937934674/work 110 | tqdm==4.66.1 111 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1704212992681/work 112 | transformers==4.34.0 113 | typing-inspect==0.9.0 114 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1702176139754/work 115 | tzdata==2023.4 116 | urllib3==1.26.13 117 | validators==0.22.0 118 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1700607916581/work 119 | weaviate-client==3.26.0 120 | websockets==12.0 121 | win32-setctime==1.1.0 122 | xxhash==3.4.1 123 | yarl==1.9.4 124 | zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work 125 | -------------------------------------------------------------------------------- /models/llm.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from loguru import logger 3 | from utils.pdf_splitter import PDFSplitter 4 | 5 | from exllamav2 import( 6 | ExLlamaV2, 7 | ExLlamaV2Config, 8 | ExLlamaV2Cache, 9 | ExLlamaV2Tokenizer, 10 | ) 11 | 12 | from exllamav2.generator import ( 13 | ExLlamaV2BaseGenerator, 14 | ExLlamaV2Sampler 15 | ) 16 | 17 | import re 18 | import pandas as pd 19 | 20 | def lowercase_dict(d): 21 | return {key: value.lower() for key, value in d.items()} 22 | 23 | 24 | def process_results(object): 25 | filtered_list = [item for item in object if len(item) > 1] 26 | 27 | elements_to_remove = { 28 | 'node_1': 'A concept from extracted ontology', 29 | 'node_2': 'A related concept from extracted ontology', 30 | 'edge': 'relationship between the two concepts, node_1 and node_2 in one or two sentences' 31 | } 32 | 33 | filtered_list = [lowercase_dict(item) for item in filtered_list if item != elements_to_remove] 34 | return filtered_list 35 | 36 | system_prompt = """ 37 | You are a network graph maker who extracts terms and their relations from a given context. 38 | You are provided with a context chunk (delimited by ```) Your task is to extract the ontology 39 | of terms mentioned in the given context. These terms should represent the key concepts as per the context. \n 40 | Thought 1: While traversing through each sentence, Think about the key terms mentioned in it.\n 41 | \tTerms may include object, entity, location, organization, person, \n 42 | \tcondition, acronym, documents, service, concept, etc.\n 43 | \tTerms should be as atomistic as possible\n\n 44 | Thought 2: Think about how these terms can have one on one relation with other terms.\n 45 | \tTerms that are mentioned in the same sentence or the same paragraph are typically related to each other.\n 46 | \tTerms can be related to many other terms\n\n 47 | Thought 3: Find out the relation between each such related pair of terms. \n\n 48 | Format your output as a list of json. Each element of the list contains a pair of terms 49 | and the relation between them, like the following: \n 50 | [\n 51 | {\n 52 | "node_1": "A concept from extracted ontology",\n 53 | "node_2": "A related concept from extracted ontology",\n 54 | "edge": "relationship between the two concepts, node_1 and node_2 in one or two sentences"\n 55 | }, {...}\n" 56 | ]" 57 | DO NOT RETURN ANY EXPLANATION, ONLY RETURN THE LIST OF JSON. 58 | """ 59 | 60 | qna_prompt = """You are a helpful assistant. You do not respond as 'User' or pretend to be 'User'. 61 | You only respond once as Assistant. You are allowed to use only the given context below to answer the user's queries, 62 | and if the answer is not present in the context, say you don't know the answer. 63 | CONTEXT: {context} 64 | """ 65 | 66 | class RAG_LLM: 67 | def __init__(self, model_directory: str, temperature: float, top_k: float, top_p: float, top_a: float, token_repetition_penalty: float): 68 | 69 | self.model_directory = model_directory 70 | self.temperature = temperature 71 | self.top_k = top_k 72 | self.top_p = top_p 73 | self.top_a = top_a 74 | self.token_repetition_penalty = token_repetition_penalty 75 | 76 | #r"C:\Users\Jino Rohit\Downloads\mistral-7b-orca" 77 | 78 | def setup_model(self) -> None: 79 | self.config = ExLlamaV2Config() 80 | self.config.model_dir = self.model_directory 81 | self.config.prepare() 82 | 83 | self.model = ExLlamaV2(self.config) 84 | logger.info("Loading model...") 85 | 86 | self.cache = ExLlamaV2Cache(self.model, lazy = True) 87 | self.model.load_autosplit(self.cache) 88 | 89 | self.tokenizer = ExLlamaV2Tokenizer(self.config) 90 | 91 | self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer) 92 | 93 | self.settings = ExLlamaV2Sampler.Settings() 94 | self.settings.temperature = self.temperature 95 | self.settings.top_k = self.top_k 96 | self.settings.top_p = self.top_p 97 | self.settings.top_a = self.top_a 98 | self.settings.token_repetition_penalty = self.token_repetition_penalty 99 | self.settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) 100 | 101 | def generate_nodes(self, chunks, max_new_tokens) -> str: 102 | if self.generator is None or self.settings is None: 103 | raise RuntimeError("Model not initialized. Call setup_model() first.") 104 | 105 | all_matches = [] 106 | 107 | self.generator.warmup() 108 | 109 | prompt = """<|im_start|>system 110 | {system_prompt} 111 | <|im_end|> 112 | <|im_start|>user 113 | {text_chunk} 114 | <|im_end|> 115 | <|im_start|>assistant 116 | """ 117 | 118 | for idx, chunk in enumerate(chunks): 119 | logger.info(f"Extracting tuples from chunk: {idx}") 120 | output = self.generator.generate_simple(prompt.format(system_prompt = system_prompt, text_chunk = chunk['text']), self.settings, max_new_tokens, seed = 1234) 121 | 122 | #extracting dict types 123 | pattern = r'\{[^}]+\}' 124 | matches = re.findall(pattern, output) 125 | try: 126 | dictionaries = [eval(match) for match in matches] 127 | dictionaries = process_results(dictionaries) 128 | 129 | for _d in dictionaries: 130 | _d['chunk'] = chunk['text'] 131 | 132 | all_matches.extend(dictionaries) 133 | except: 134 | pass 135 | 136 | df = pd.DataFrame(all_matches) 137 | df = df.drop_duplicates(subset=['node_1', 'node_2', 'edge'], keep=False) 138 | return df 139 | 140 | def generate_answers(self, chunks, query, max_new_tokens) -> str: 141 | if self.generator is None or self.settings is None: 142 | raise RuntimeError("Model not initialized. Call setup_model() first.") 143 | 144 | self.generator.warmup() 145 | 146 | prompt = """<|im_start|>system 147 | {qna_prompt} 148 | <|im_end|> 149 | <|im_start|>user 150 | {query} 151 | <|im_end|> 152 | <|im_start|>assistant 153 | """ 154 | logger.info(f"Asking the assistant : {query}") 155 | output = self.generator.generate_simple(prompt.format(qna_prompt = qna_prompt.format(context = chunks), query = query), self.settings, max_new_tokens, seed = 1234) 156 | 157 | start_tag = "<|im_start|>" 158 | end_tag = "<|im_end|>" 159 | start_index = output.rfind(start_tag) 160 | end_index = output.rfind(end_tag) 161 | logger.info(f"Answer : {output[start_index + len(start_tag): end_index]}") 162 | 163 | 164 | 165 | 166 | 167 | # if __name__ == '__main__': 168 | # rag_llm_instance = RAG_LLM( 169 | # model_directory="C:/Users/Jino Rohit/Downloads/mistral-7b-orca", 170 | # temperature=1.0, 171 | # top_k=5, 172 | # top_p=0.8, 173 | # top_a=0.9, 174 | # token_repetition_penalty=1.2 175 | # ) 176 | 177 | # rag_llm_instance.setup_model() 178 | 179 | # prompt = "write a story on" 180 | # generated_text = rag_llm_instance.generate_text(prompt, 100) 181 | # print(generated_text) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2023 Jino Rohit 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | --------------------------------------------------------------------------------