├── .env.example
├── .gitignore
├── README.md
├── requirements.txt
└── scripts
    ├── agent.py
    ├── embedding_search.py
    ├── eval.py
    └── ingest.py

/.env.example:
--------------------------------------------------------------------------------
OPENAI_API_KEY=
DB_URL=
HF_TOKEN=
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.env
venv/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RAG is more than vector search

## Introduction

This repository contains the code for the article `RAG is more than embeddings`. Head over to the [Timescale blog](https://www.timescale.com/blog/rag-is-more-than-just-vector-search/) to read the article if you haven't already. The code is compatible with Python >= 3.9.

## Instructions

1. First, install all of the dependencies listed in the `requirements.txt` file:

```bash
pip install -r requirements.txt
```

2. Make sure to create a `.env` file with the same environment variables as our `.env.example` file. You can get your `DB_URL` after creating a Timescale instance by following the instructions [here](https://docs.timescale.com/getting-started/latest/services/#create-your-timescale-account).

3. Next, ingest some GitHub issues from the `bigcode/the-stack-github-issues` dataset by running the `scripts/ingest.py` file. This crawls the first 400 issues that match the list of whitelisted repos in our file:

```bash
python3 ./scripts/ingest.py
```

4. We can then test the function-calling ability of our model by running the `scripts/eval.py` file to verify that our model chooses the right tool for a given user query:

```bash
python3 ./scripts/eval.py
```

5. To perform embedding search, we define an `execute` method on each tool itself. This lets us call `.execute()` as soon as a tool is selected and immediately get back a list of relevant results (see the sketch at the end of this list). Run the command below to fetch the top 10 relevant summaries related to the `kubernetes/kubernetes` repository from our database using embedding search:

```bash
python3 ./scripts/embedding_search.py
```

6. Lastly, we put it all together in the `agent.py` file, where we create a one-step agent that can answer questions about specific repositories in our database. We can run this agent by executing the command below:

```bash
python3 ./scripts/agent.py
```
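Each tool in the scripts follows the same pattern: a Pydantic model whose fields are the arguments the model fills in, plus an `execute` method that runs the actual query. Here's a minimal sketch of that pattern, assuming a `github_issues` table like the one created by `scripts/ingest.py` (the class name here is illustrative, not part of the scripts):

```python
from typing import Optional

from asyncpg import Connection
from pydantic import BaseModel, Field


class SearchExample(BaseModel):
    """The docstring doubles as the tool description the LLM sees."""

    query: Optional[str] = Field(description="Relevant user query, if any")
    repo: str = Field(description="Repo to search, in 'owner/repo' format")

    async def execute(self, conn: Connection, limit: int):
        # The real tools embed `query` and rank rows by vector similarity;
        # this sketch only shows the shape of the call.
        return await conn.fetch(
            "SELECT * FROM github_issues WHERE repo_name = $1 LIMIT $2",
            self.repo,
            limit,
        )
```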
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
instructor==1.3.7
openai==1.41.0
tqdm==4.66.5
pydantic==2.8.2
datasets==2.21.0
pgvector==0.3.2
asyncpg==0.29.0
jinja2==3.1.4
python-dotenv==1.0.1
fuzzywuzzy==0.18.0
python-Levenshtein==0.25.1
--------------------------------------------------------------------------------
/scripts/agent.py:
--------------------------------------------------------------------------------
import os
from asyncio import run
from typing import Iterable, Optional, Union

import asyncpg
import instructor
from asyncpg import Connection, Record
from dotenv import load_dotenv
from fuzzywuzzy import process
from jinja2 import Template
from openai import OpenAI
from pgvector.asyncpg import register_vector
from pydantic import BaseModel, Field, ValidationInfo, field_validator


def find_closest_repo(query: str, repos: list[str]) -> Optional[str]:
    if not query:
        return None

    best_match = process.extractOne(query, repos)
    return best_match[0] if best_match[1] >= 80 else None
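# Note: `process.extractOne` returns a (choice, score) tuple, where the score
# is a 0-100 fuzzy-match ratio. With the >= 80 cutoff above, a close variant
# like "kubernetes" still resolves to "kubernetes/kubernetes", while an
# unrelated name falls through and returns None so validation can fail loudly.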
class SearchIssues(BaseModel):
    """
    Use this when the user wants to get original issue information from the database
    """

    query: Optional[str]
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )

    @field_validator("repo")
    def validate_repo(cls, v: str, info: ValidationInfo):
        matched_repo = find_closest_repo(v, info.context["repos"])
        if matched_repo is None:
            raise ValueError(
                f"Unable to match repo {v} to a list of known repos of {info.context['repos']}"
            )
        return matched_repo

    async def execute(self, conn: Connection, limit: int):
        if self.query:
            embedding = (
                OpenAI()
                .embeddings.create(input=self.query, model="text-embedding-3-small")
                .data[0]
                .embedding
            )
            args = [self.repo, limit, embedding]
        else:
            args = [self.repo, limit]
            embedding = None

        sql_query = Template(
            """
            SELECT *
            FROM {{ table_name }}
            WHERE repo_name = $1
            {%- if embedding is not none %}
            ORDER BY embedding <=> $3
            {%- endif %}
            LIMIT $2
            """
        ).render(table_name="github_issues", embedding=embedding)

        return await conn.fetch(sql_query, *args)


class RunSQLReturnPandas(BaseModel):
    """
    Use this function when the user wants to do time series analysis, data analysis or compute some specific statistics and we don't have a tool that can supply the necessary information
    """

    user_query_summary: str = Field(description="Description of user's query")
    repos: list[str] = Field(
        description="the repos to run the query on, should be in the format of 'owner/repo'"
    )

    async def execute(self, conn: Connection, limit: int):
        prompt = f"""
You are a SQL expert tasked with writing queries including a time attribute for the relevant table. The user wants to execute a query for the following repos: {self.repos} to answer the query of {self.user_query_summary}.

- If you need to filter by repository, use the `repo_name` column.
- When partitioning items by a specific time period, always use the `time_bucket` function provided by TimescaleDB. For example:
```sql
SELECT time_bucket('2 months', start_ts) AS month,
       COUNT(*) AS issue_count
FROM github_issues
GROUP BY month
ORDER BY month;
```
- This groups the rows into 2-month buckets. Adjust the interval (e.g., '2 weeks', '1 day') as needed for the specific query.
- The `time_bucket` function can take any arbitrary interval such as week, month, or year.
- When looking at comments, note that `order_in_issue` begins with 1 and increments thereafter, so make sure to account for that.
- The `metadata` field is currently empty, so do not use it.
- To determine whether an issue is open or closed, use the `label` column (of type `issue_label`) on the `github_issue_summaries` table.
- To detect involvement or participation in an issue, check for comments in the `github_issue_comments` table.
- Only use the tables and the fields provided in the database schema below.

**Database Schema:**

```sql
CREATE TABLE IF NOT EXISTS github_issues (
    issue_id INTEGER,
    metadata JSONB,
    text TEXT,
    repo_name TEXT,
    start_ts TIMESTAMPTZ NOT NULL,
    end_ts TIMESTAMPTZ,
    embedding VECTOR(1536) NOT NULL
);

CREATE INDEX github_issue_embedding_idx
ON github_issues
USING diskann (embedding);

-- Create a Hypertable that breaks it down by 1 month intervals
SELECT create_hypertable('github_issues', 'start_ts', chunk_time_interval => INTERVAL '1 month');

CREATE TABLE github_issue_summaries (
    issue_id INTEGER,
    text TEXT,
    label issue_label NOT NULL,
    repo_name TEXT,
    embedding VECTOR(1536) NOT NULL
);

CREATE INDEX github_issue_summaries_embedding_idx
ON github_issue_summaries
USING diskann (embedding);
```

Example:
Query: How many issues were created in the rust-lang/rust repository every two months?
SQL:
```sql
SELECT time_bucket('2 months', start_ts) AS two_month_period, COUNT(*) AS issue_count
FROM github_issues
WHERE repo_name = 'rust-lang/rust'
GROUP BY two_month_period
ORDER BY two_month_period;
```

two_month_period        | issue_count
------------------------+-------------
2013-09-01 00:00:00+00  | 1
2015-01-01 00:00:00+00  | 3
2015-03-01 00:00:00+00  | 2
2015-05-01 00:00:00+00  | 1

Example:
Query: What is the average time to first response for issues in the MicrosoftDocs/azure-docs repository?
SQL:
```sql
SELECT AVG(EXTRACT(EPOCH FROM (first_comment.created_at - issues.start_ts))) / 3600 AS average_time_to_first_response_hours
FROM github_issues AS issues
LEFT JOIN (
    SELECT issue_id, MIN(created_at) AS created_at
    FROM github_issue_comments
    WHERE order_in_issue = 1
    GROUP BY issue_id
) AS first_comment ON issues.issue_id = first_comment.issue_id
WHERE issues.repo_name = 'MicrosoftDocs/azure-docs'
AND first_comment.created_at IS NOT NULL;
```

average_time_to_first_response_hours
-------------------------------------
12.5

Example:
Query: How many unique issues has alextp commented on in the tensorflow library?
SQL:
```sql
SELECT COUNT(DISTINCT issue_id) AS unique_issues_count
FROM github_issue_comments
WHERE author = 'alextp'
AND repo_name = 'tensorflow/tensorflow';
```

unique_issues_count
--------------------
4
"""

        class GeneratedSQL(BaseModel):
            chain_of_thought: str
            sql: str = Field(description="Generated SQL Query")

        client = instructor.from_openai(OpenAI())
        sql = client.chat.completions.create(
            messages=[
                {"role": "system", "content": prompt},
            ],
            response_model=GeneratedSQL,
            model="gpt-4o-mini",
        )
        return await conn.fetch(sql.sql)
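# RunSQLReturnPandas is a two-step tool: the model first drafts the query as a
# structured `GeneratedSQL` object (its reasoning goes in `chain_of_thought`,
# the query in `sql`), and we then run that SQL verbatim on the connection.
# A standalone smoke test might look like this (hypothetical values):
#
#   tool = RunSQLReturnPandas(
#       user_query_summary="issues created per year",
#       repos=["rust-lang/rust"],
#   )
#   rows = await tool.execute(conn, limit=10)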
class SearchSummaries(BaseModel):
    """
    This function retrieves summarized information about GitHub issues that match/are similar to a specific query. It's particularly useful for obtaining a quick snapshot of issue trends or patterns within a project.
    """

    query: Optional[str] = Field(description="Relevant user query if any")
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )

    @field_validator("repo")
    def validate_repo(cls, v: str, info: ValidationInfo):
        matched_repo = find_closest_repo(v, info.context["repos"])
        if matched_repo is None:
            raise ValueError(
                f"Unable to match repo {v} to a list of known repos of {info.context['repos']}"
            )
        return matched_repo

    async def execute(self, conn: Connection, limit: int):
        if self.query:
            embedding = (
                OpenAI()
                .embeddings.create(input=self.query, model="text-embedding-3-small")
                .data[0]
                .embedding
            )
            args = [self.repo, limit, embedding]
        else:
            args = [self.repo, limit]
            embedding = None

        sql_query = Template(
            """
            SELECT *
            FROM {{ table_name }}
            WHERE repo_name = $1
            {%- if embedding is not none %}
            ORDER BY embedding <=> $3
            {%- endif %}
            LIMIT $2
            """
        ).render(table_name="github_issue_summaries", embedding=embedding)

        return await conn.fetch(sql_query, *args)


class Summary(BaseModel):
    chain_of_thought: str
    summary: str
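# Asking for `chain_of_thought` before `summary` is a common instructor trick:
# the model writes out its reasoning first, which tends to improve the final
# field. Only `summary` is printed downstream.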
def summarize_content(results: list[Record], query: Optional[str]):
    client = instructor.from_openai(OpenAI())
    return client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant that summarizes information about issues from a github repository. Be sure to output your response in a single paragraph that is concise and to the point.",
            },
            {
                "role": "user",
                "content": Template(
                    """
                    Here are the relevant issues:
                    {% for result in results %}

                    - {{ result['tool'] }}
                    {% for row in result['result'] %}
                    {{ dict(row) }}
                    {% endfor %}

                    {% endfor %}
                    {% if query %}
                    My specific query is: {{ query }}
                    {% else %}
                    Please provide a broad summary and key insights from the issues above.
                    {% endif %}
                    """
                ).render(results=results, query=query),
            },
        ],
        response_model=Summary,
        model="gpt-4o-mini",
    )


async def get_conn():
    conn = await asyncpg.connect(os.getenv("DB_URL"))
    await register_vector(conn)
    return conn


def one_step_agent(question: str, repos: list[str]):
    client = instructor.from_openai(OpenAI(), mode=instructor.Mode.PARALLEL_TOOLS)

    return client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "You are an AI assistant that helps users query and analyze GitHub issues stored in a PostgreSQL database. Search for summaries when the user wants to understand the high level trends or patterns within a project. Otherwise just get the issues and return them. Only resort to SQL queries if the other tools are not able to answer the user's query.",
            },
            {
                "role": "user",
                "content": Template(
                    """
                    Here is the user's question: {{ question }}
                    Here is a list of repos that we have stored in our database. Choose the one that is most relevant to the user's query:
                    {% for repo in repos %}
                    - {{ repo }}
                    {% endfor %}
                    """
                ).render(question=question, repos=repos),
            },
        ],
        validation_context={"repos": repos},
        response_model=Iterable[
            Union[
                RunSQLReturnPandas,
                SearchIssues,
                SearchSummaries,
            ]
        ],
    )
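# Mode.PARALLEL_TOOLS lets the model request several tools in one completion;
# instructor parses each call into one of the models in the Union above and
# yields them from the Iterable response_model. `main` below just executes
# whatever tools come back.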
async def main():
    query = "What was the difference in the number of yearly issues for the rust repository? Generate a markdown table that shows the year, number of issues and the change from the previous year with a short description at the end summarising the overall trends"

    print(f"Query: {query}")

    repos = [
        "rust-lang/rust",
        "kubernetes/kubernetes",
        "apache/spark",
        "golang/go",
        "tensorflow/tensorflow",
        "MicrosoftDocs/azure-docs",
        "pytorch/pytorch",
        "Microsoft/TypeScript",
        "python/cpython",
        "facebook/react",
        "django/django",
        "rails/rails",
        "bitcoin/bitcoin",
        "nodejs/node",
        "ocaml/opam-repository",
        "apache/airflow",
        "scipy/scipy",
        "vercel/next.js",
    ]

    resp = one_step_agent(query, repos)

    conn = await get_conn()
    limit = 10

    tools = list(resp)
    print(f"Tools: {tools}")

    result = []
    for tool in tools:
        tool_call = await tool.execute(conn, limit)
        result.append({"tool": tool, "result": tool_call})

    summary = summarize_content(result, query)
    print(summary.summary)


if __name__ == "__main__":
    load_dotenv(dotenv_path=".env", override=True)
    run(main())
--------------------------------------------------------------------------------
/scripts/embedding_search.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, Field
from typing import Optional
from openai import OpenAI
from jinja2 import Template
from asyncpg import Connection
import asyncpg
from pgvector.asyncpg import register_vector
from dotenv import load_dotenv
from asyncio import run
import os


class SearchIssues(BaseModel):
    """
    Use this when the user wants to get original issue information from the database
    """

    query: Optional[str]
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )

    async def execute(self, conn: Connection, limit: int):
        if self.query:
            embedding = (
                OpenAI()
                .embeddings.create(input=self.query, model="text-embedding-3-small")
                .data[0]
                .embedding
            )
            args = [self.repo, limit, embedding]
        else:
            args = [self.repo, limit]
            embedding = None

        sql_query = Template(
            """
            SELECT *
            FROM {{ table_name }}
            WHERE repo_name = $1
            {%- if embedding is not none %}
            ORDER BY embedding <=> $3
            {%- endif %}
            LIMIT $2
            """
        ).render(table_name="github_issues", embedding=embedding)

        return await conn.fetch(sql_query, *args)
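# `<=>` is pgvector's cosine-distance operator, so when a query embedding is
# supplied the rows come back nearest-first; without one we fall back to a
# plain filtered LIMIT query.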
class RunSQLReturnPandas(BaseModel):
    """
    Use this function when the user wants to do time series analysis or data analysis and we don't have a tool that can supply the necessary information
    """

    query: str = Field(description="Description of user's query")
    repos: list[str] = Field(
        description="the repos to run the query on, should be in the format of 'owner/repo'"
    )

    async def execute(self, conn: Connection, limit: int):
        # Stub in this script; the full text-to-SQL implementation lives in
        # scripts/agent.py.
        pass


class SearchSummaries(BaseModel):
    """
    This function retrieves summarized information about GitHub issues that match/are similar to a specific query. It's particularly useful for obtaining a quick snapshot of issue trends or patterns within a project.
    """

    query: Optional[str] = Field(description="Relevant user query if any")
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )

    async def execute(self, conn: Connection, limit: int):
        if self.query:
            embedding = (
                OpenAI()
                .embeddings.create(input=self.query, model="text-embedding-3-small")
                .data[0]
                .embedding
            )
            args = [self.repo, limit, embedding]
        else:
            args = [self.repo, limit]
            embedding = None

        sql_query = Template(
            """
            SELECT *
            FROM {{ table_name }}
            WHERE repo_name = $1
            {%- if embedding is not none %}
            ORDER BY embedding <=> $3
            {%- endif %}
            LIMIT $2
            """
        ).render(table_name="github_issue_summaries", embedding=embedding)

        return await conn.fetch(sql_query, *args)


async def get_conn():
    conn = await asyncpg.connect(os.getenv("DB_URL"))
    await register_vector(conn)
    return conn


async def main():
    query = (
        "What are the main problems people are facing with installation with Kubernetes"
    )

    conn = await get_conn()
    limit = 10
    resp = await SearchSummaries(query=query, repo="kubernetes/kubernetes").execute(
        conn, limit
    )

    for row in resp[:3]:
        print(row["text"])


if __name__ == "__main__":
    load_dotenv(dotenv_path=".env", override=True)
    run(main())
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, Field
from typing import Iterable, Optional, Union
import instructor
import openai


class SearchIssues(BaseModel):
    """
    Use this when the user wants to get original issue information from the database
    """

    query: Optional[str]
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )


class RunSQLReturnPandas(BaseModel):
    """
    Use this function when the user wants to do time series analysis or data analysis and we don't have a tool that can supply the necessary information
    """

    query: str = Field(description="Description of user's query")
    repos: list[str] = Field(
        description="the repos to run the query on, should be in the format of 'owner/repo'"
    )


class SearchSummaries(BaseModel):
    """
    This function retrieves summarized information about GitHub issues that match/are similar to a specific query. It's particularly useful for obtaining a quick snapshot of issue trends or patterns within a project.
    """

    query: Optional[str] = Field(description="Relevant user query if any")
    repo: str = Field(
        description="the repo to search for issues in, should be in the format of 'owner/repo'"
    )
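# Each Pydantic class above doubles as a tool schema: instructor converts the
# three models into OpenAI function definitions, and the docstrings become the
# descriptions the model reads when deciding which tool fits a query.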
def one_step_agent(question: str):
    client = instructor.from_openai(
        openai.OpenAI(), mode=instructor.Mode.PARALLEL_TOOLS
    )

    return client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "You are an AI assistant that helps users query and analyze GitHub issues stored in a PostgreSQL database. Search for summaries when the user wants to understand the high level trends or patterns within a project. Otherwise just get the issues and return them. Only resort to SQL queries if the other tools are not able to answer the user's query.",
            },
            {"role": "user", "content": question},
        ],
        response_model=Iterable[
            Union[
                RunSQLReturnPandas,
                SearchIssues,
                SearchSummaries,
            ]
        ],
    )


if __name__ == "__main__":
    tests = [
        [
            "What is the average time to first response for issues in the azure repository over the last 6 months? Has this metric improved or worsened?",
            [RunSQLReturnPandas],
        ],
        [
            "How many issues mentioned issues with Cohere in the 'vercel/next.js' repository in the last 6 months?",
            [SearchIssues],
        ],
        [
            "What were some of the big features that were implemented in the last 4 months for the scipy repo that addressed some previously open issues?",
            [SearchSummaries],
        ],
    ]

    for query, expected_result in tests:
        response = one_step_agent(query)
        for expected_call, agent_call in zip(expected_result, response):
            assert isinstance(
                agent_call, expected_call
            ), f"Expected {expected_call} but got {type(agent_call)}"

    print("All tests passed")
--------------------------------------------------------------------------------
/scripts/ingest.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
from datetime import datetime
from openai import AsyncOpenAI
from asyncio import run, Semaphore

# tqdm_asyncio.gather is a drop-in replacement for asyncio.gather that shows a
# progress bar while the coroutines run.
from tqdm.asyncio import tqdm_asyncio as asyncio
from textwrap import dedent
from instructor import from_openai
from jinja2 import Template
import os
from pgvector.asyncpg import register_vector
import asyncpg
import json
from dotenv import load_dotenv
from pydantic import BaseModel
from typing import Literal, Any, Optional


class ClassifiedSummary(BaseModel):
    chain_of_thought: str
    label: Literal["OPEN", "CLOSED"]
    summary: str


class ProcessedIssue(BaseModel):
    issue_id: int
    text: str
    label: Literal["OPEN", "CLOSED"]
    repo_name: str
    embedding: Optional[list[float]]


class GithubIssue(BaseModel):
    issue_id: int
    metadata: dict[str, Any]
    text: str
    repo_name: str
    start_ts: datetime
    end_ts: Optional[datetime]
    embedding: Optional[list[float]]


def get_issues(n: int, repos: list[str]):
    dataset = (
        load_dataset("bigcode/the-stack-github-issues", split="train", streaming=True)
        .filter(lambda x: x["repo"] in repos)
        .take(n)
    )

    for row in dataset:
        start_time = None
        end_time = None
        for event in row["events"]:
            event_type = event["action"]
            timestamp = event["datetime"]
            timestamp = timestamp.replace("Z", "+00:00")

            if event_type == "opened":
                start_time = datetime.fromisoformat(timestamp)

            elif event_type == "closed":
                end_time = datetime.fromisoformat(timestamp)

            # Small fallback here - some issues have no creation event
            elif event_type == "created" and not start_time:
                start_time = datetime.fromisoformat(timestamp)

            elif event_type == "reopened" and not start_time:
                start_time = datetime.fromisoformat(timestamp)

        yield GithubIssue(
            issue_id=row["issue_id"],
            metadata={},
            text=row["content"],
            repo_name=row["repo"],
            start_ts=start_time,
            end_ts=end_time,
            embedding=None,
        )


async def batch_classify_issue(
    batch: list[GithubIssue], max_concurrent_requests: int = 20
) -> list[ProcessedIssue]:
    async def classify_issue(issue: GithubIssue, semaphore: Semaphore):
        client = from_openai(AsyncOpenAI())
        async with semaphore:
            classification = await client.chat.completions.create(
                response_model=ClassifiedSummary,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that classifies and summarizes GitHub issues. When summarizing the issues, make sure to expand on specific acronyms and add additional explanation where necessary.",
                    },
                    {
                        "role": "user",
                        "content": Template(
                            dedent(
                                """
                                Repo Name: {{ repo_name }}
                                Issue Text: {{ issue_text }}
                                """
                            )
                        ).render(repo_name=issue.repo_name, issue_text=issue.text),
                    },
                ],
                model="gpt-4o-mini",
            )
            return ProcessedIssue(
                issue_id=issue.issue_id,
                repo_name=issue.repo_name,
                text=classification.summary,
                label=classification.label,
                embedding=None,
            )

    semaphore = Semaphore(max_concurrent_requests)
    coros = [classify_issue(item, semaphore) for item in batch]
    results = await asyncio.gather(*coros)
    return results


async def batch_embeddings(
    data: list[ProcessedIssue],
    max_concurrent_calls: int = 20,
) -> list[ProcessedIssue]:
    oai = AsyncOpenAI()

    async def embed_row(
        item: ProcessedIssue,
        semaphore: Semaphore,
    ):
        async with semaphore:
            # Crude guard against the embedding model's context limit: this is
            # a character count, not a token count, so long texts are truncated
            # conservatively.
            input_text = item.text if len(item.text) < 8000 else item.text[:6000]
            embedding = (
                (
                    await oai.embeddings.create(
                        input=input_text, model="text-embedding-3-small"
                    )
                )
                .data[0]
                .embedding
            )
            item.embedding = embedding
            return item

    semaphore = Semaphore(max_concurrent_calls)
    coros = [embed_row(item, semaphore) for item in data]
    results = await asyncio.gather(*coros)
    return results
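# Both batch helpers above share the same concurrency pattern: a Semaphore
# caps in-flight OpenAI requests at max_concurrent_*, while tqdm's gather
# drop-in reports progress and preserves the input order of results.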
async def get_conn():
    conn = await asyncpg.connect(os.getenv("DB_URL"))
    await register_vector(conn)
    return conn


async def insert_github_issue_summaries(conn, issues: list[GithubIssue]):
    insert_query = """
    INSERT INTO github_issue_summaries (issue_id, text, label, embedding, repo_name)
    VALUES ($1, $2, $3, $4, $5)
    """
    summarized_issues = await batch_classify_issue(issues)
    embedded_summaries = await batch_embeddings(summarized_issues)

    await conn.executemany(
        insert_query,
        [
            (item.issue_id, item.text, item.label, item.embedding, item.repo_name)
            for item in embedded_summaries
        ],
    )

    print("GitHub issue summaries inserted successfully.")


async def insert_github_issues(conn, issues: list[GithubIssue]):
    insert_query = """
    INSERT INTO github_issues (issue_id, metadata, text, repo_name, start_ts, end_ts, embedding)
    VALUES ($1, $2, $3, $4, $5, $6, $7)
    """
    embedded_issues = await batch_embeddings(issues)

    await conn.executemany(
        insert_query,
        [
            (
                item.issue_id,
                json.dumps(item.metadata),
                item.text,
                item.repo_name,
                item.start_ts,
                item.end_ts,
                item.embedding,
            )
            for item in embedded_issues
        ],
    )
    print("GitHub issues inserted successfully.")


async def setup_db(conn: asyncpg.Connection):
    init_sql = """
    CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE;

    DROP TABLE IF EXISTS github_issue_summaries CASCADE;
    DROP TABLE IF EXISTS github_issues CASCADE;

    -- The summaries table uses a custom enum type for its label column, so
    -- create it before the tables that reference it.
    DROP TYPE IF EXISTS issue_label CASCADE;
    CREATE TYPE issue_label AS ENUM ('OPEN', 'CLOSED');

    CREATE TABLE IF NOT EXISTS github_issues (
        issue_id INTEGER,
        metadata JSONB,
        text TEXT,
        repo_name TEXT,
        start_ts TIMESTAMPTZ NOT NULL,
        end_ts TIMESTAMPTZ,
        embedding VECTOR(1536) NOT NULL
    );

    CREATE INDEX github_issue_embedding_idx
    ON github_issues
    USING diskann (embedding);

    -- Create a Hypertable that breaks it down by 1 month intervals
    SELECT create_hypertable('github_issues', 'start_ts', chunk_time_interval => INTERVAL '1 month');

    CREATE UNIQUE INDEX ON github_issues (issue_id, start_ts);

    CREATE TABLE github_issue_summaries (
        issue_id INTEGER,
        text TEXT,
        label issue_label NOT NULL,
        repo_name TEXT,
        embedding VECTOR(1536) NOT NULL
    );

    CREATE INDEX github_issue_summaries_embedding_idx
    ON github_issue_summaries
    USING diskann (embedding);
    """

    await conn.execute(init_sql)
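# Note that setup_db is destructive: it drops and recreates both tables on
# every run. The `diskann` indexes come from the vectorscale extension
# (pgvectorscale's StreamingDiskANN), and create_hypertable partitions
# github_issues into monthly chunks keyed on start_ts.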
async def process_issues(n_issues: int, repos: list[str], conn: asyncpg.Connection):
    issues = list(get_issues(n_issues, repos))
    await insert_github_issues(conn, issues)
    await insert_github_issue_summaries(conn, issues)


async def main():
    repos = [
        "rust-lang/rust",
        "kubernetes/kubernetes",
        "apache/spark",
        "golang/go",
        "tensorflow/tensorflow",
        "MicrosoftDocs/azure-docs",
        "pytorch/pytorch",
        "Microsoft/TypeScript",
        "python/cpython",
        "facebook/react",
        "django/django",
        "rails/rails",
        "bitcoin/bitcoin",
        "nodejs/node",
        "ocaml/opam-repository",
        "apache/airflow",
        "scipy/scipy",
        "vercel/next.js",
    ]
    conn = await get_conn()
    try:
        n_issues = 400
        await setup_db(conn)
        await process_issues(n_issues, repos, conn)
    finally:
        await conn.close()


if __name__ == "__main__":
    load_dotenv(dotenv_path=".env", override=True)
    run(main())
--------------------------------------------------------------------------------