├── .env.local ├── README.md ├── .env.production ├── app ├── config.py ├── test_connection.py ├── query_data.py └── load_pdf.py ├── requirements.txt └── .gitignore /.env.local: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.env.production: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DB_HOST = os.getenv('DB_HOST') 4 | DB_USER = os.getenv('DB_USER') 5 | DB_PASS = os.getenv('DB_PASS') 6 | DB_NAME = os.getenv('DB_NAME') 7 | OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') 8 | DB_PORT = 5432 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /app/test_connection.py: -------------------------------------------------------------------------------- 1 | # test_connection.py 2 | import psycopg2 3 | from config import DB_HOST, DB_NAME, DB_USER, DB_PASS, DB_PORT 4 | 5 | try: 6 | conn = psycopg2.connect( 7 | host=DB_HOST, 8 | database=DB_NAME, 9 | user=DB_USER, 10 | password=DB_PASS, 11 | port=DB_PORT 12 | ) 13 | print("Database connection successful!") 14 | except Exception as e: 15 | print(f"Database connection failed: {e}") 16 | finally: 17 | if conn: 18 | conn.close() 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.10.10 3 | aiosignal==1.3.1 4 | annotated-types==0.7.0 5 | anyio==4.6.2.post1 6 | async-timeout==4.0.3 7 | asyncpg==0.30.0 8 | attrs==24.2.0 9 | certifi==2024.8.30 10 | charset-normalizer==3.4.0 11 | click==8.1.7 12 | dnspython==2.7.0 13 | email_validator==2.2.0 14 | exceptiongroup==1.2.2 15 | fastapi==0.115.4 16 | fastapi-cli==0.0.5 17 | filelock==3.16.1 18 | frozenlist==1.5.0 19 | fsspec==2024.10.0 20 | greenlet==3.1.1 21 | h11==0.14.0 22 | httpcore==1.0.6 23 | httptools==0.6.4 24 | httpx==0.27.2 25 | huggingface-hub==0.26.2 26 | idna==3.10 27 | Jinja2==3.1.4 28 | jsonpatch==1.33 29 | jsonpointer==3.0.0 30 | langchain==0.3.7 31 | langchain-core==0.3.15 32 | langchain-text-splitters==0.3.2 33 | langsmith==0.1.142 34 | markdown-it-py==3.0.0 35 | MarkupSafe==3.0.2 36 | mdurl==0.1.2 37 | multidict==6.1.0 38 | numpy==1.26.4 39 | orjson==3.10.11 40 | packaging==24.2 41 | propcache==0.2.0 42 | psycopg2==2.9.10 43 | pydantic==2.9.2 44 | pydantic_core==2.23.4 45 | Pygments==2.18.0 46 | python-dotenv==1.0.1 47 | python-multipart==0.0.17 48 | PyYAML==6.0.2 49 | regex==2024.11.6 50 | requests==2.32.3 51 | requests-toolbelt==1.0.0 52 | rich==13.9.4 53 | safetensors==0.4.5 54 | shellingham==1.5.4 55 | sniffio==1.3.1 56 | SQLAlchemy==2.0.36 57 | starlette==0.41.2 58 | tenacity==9.0.0 59 | tokenizers==0.20.3 60 | tqdm==4.67.0 61 | transformers==4.46.2 62 | typer==0.13.0 63 | typing_extensions==4.12.2 -------------------------------------------------------------------------------- /app/query_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import psycopg2 3 | from psycopg2.extras import RealDictCursor 4 | from langchain.prompts import ChatPromptTemplate 5 | from langchain_community.embeddings import OpenAIEmbeddings 6 | #from ollama import generate 7 | from config import DB_HOST, DB_NAME, DB_USER, DB_PASS 8 | 9 | PROMPT_TEMPLATE = """ 10 | Answer the question based only on the following context: 11 | 12 | {context} 13 | 14 | --- 15 | 16 | Answer the question based on the above context: {question} 17 | """ 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("query_text", type=str, help="The query text.") 22 | args = parser.parse_args() 23 | query_text = args.query_text 24 | query_rag(query_text) 25 | 26 | def query_rag(query_text: str): 27 | embedding_function = OpenAIEmbeddings().embed_query 28 | query_vector = embedding_function(query_text) 29 | 30 | # Query PostgreSQL for relevant chunks 31 | results = query_postgres(query_vector, top_n=5) 32 | context_text = "\n\n---\n\n".join([row['content'] for row in results]) 33 | 34 | # Format the prompt 35 | prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text) 36 | print(f"Generated prompt:\n{prompt}") 37 | 38 | # Use OpenAI's GPT for the response 39 | response = query_openai(prompt) 40 | print(f"Response:\n{response}") 41 | 42 | def query_postgres(query_vector, top_n=5): 43 | connection = psycopg2.connect( 44 | host=DB_HOST, 45 | database=DB_NAME, 46 | user=DB_USER, 47 | password=DB_PASS 48 | ) 49 | cursor = connection.cursor(cursor_factory=RealDictCursor) 50 | 51 | query = """ 52 | SELECT chunk_id, content, embedding <-> %s::vector AS distance 53 | FROM document_chunks 54 | ORDER BY distance ASC 55 | LIMIT %s; 56 | """ 57 | cursor.execute(query, (query_vector, top_n)) 58 | results = cursor.fetchall() 59 | cursor.close() 60 | connection.close() 61 | return results 62 | 63 | def query_openai(prompt: str) -> str: 64 | openai.api_key = "" 65 | 66 | response = openai.ChatCompletion.create( 67 | model="gpt-4", # You can also use gpt-3.5-turbo for faster and cheaper responses 68 | messages=[ 69 | {"role": "system", "content": "You are an expert assistant."}, 70 | {"role": "user", "content": prompt} 71 | ], 72 | temperature=0.7 73 | ) 74 | return response['choices'][0]['message']['content'] 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /app/load_pdf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | import psycopg2 5 | from psycopg2.extras import execute_batch 6 | from langchain.text_splitter import RecursiveCharacterTextSplitter 7 | from langchain.schema import Document 8 | from langchain.embeddings import OpenAIEmbeddings # Example, adjust if you use a different embedding function 9 | from langchain.document_loaders import PyPDFDirectoryLoader 10 | from config import DB_HOST, DB_NAME, DB_USER, DB_PASS 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | FILE_PATH = "data" # Folder for PDFs 16 | 17 | 18 | def main(): 19 | print("[main]") 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--reset", action="store_true", help="Reset the database.") 22 | args = parser.parse_args() 23 | 24 | if args.reset: 25 | print("✨ Clearing Database") 26 | clear_database() 27 | 28 | documents = load_documents(FILE_PATH) 29 | chunks = split_documents(documents) 30 | add_to_postgres(chunks) 31 | 32 | 33 | def load_documents(directory: str): 34 | """Load documents from a directory of PDFs.""" 35 | print("[load_documents] Loading PDFs from:", directory) 36 | loader = PyPDFDirectoryLoader(directory) 37 | return loader.load() 38 | 39 | 40 | def split_documents(documents: list[Document]): 41 | """Split documents into smaller chunks.""" 42 | text_splitter = RecursiveCharacterTextSplitter( 43 | chunk_size=800, 44 | chunk_overlap=80, 45 | length_function=len, 46 | ) 47 | return text_splitter.split_documents(documents) 48 | 49 | 50 | def get_embeddings(): 51 | """Return embedding function (adjust as necessary).""" 52 | return OpenAIEmbeddings().embed_query # Example function 53 | 54 | 55 | def add_to_postgres(chunks: list[Document]): 56 | """Store chunks and their embeddings in PostgreSQL.""" 57 | embeddings = get_embeddings() 58 | 59 | # Establish DB connection 60 | connection = psycopg2.connect( 61 | host=DB_HOST, 62 | database=DB_NAME, 63 | user=DB_USER, 64 | password=DB_PASS 65 | ) 66 | cursor = connection.cursor() 67 | 68 | # Ensure pgvector extension is ready 69 | cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;") 70 | cursor.execute(""" 71 | CREATE TABLE IF NOT EXISTS document_chunks ( 72 | id SERIAL PRIMARY KEY, 73 | chunk_id TEXT, 74 | content TEXT, 75 | embedding VECTOR(1536) -- Adjust dimension to match your embeddings 76 | ); 77 | """) 78 | 79 | # Prepare batch insert 80 | insert_query = """ 81 | INSERT INTO document_chunks (chunk_id, content, embedding) 82 | VALUES (%s, %s, %s) 83 | """ 84 | records = [] 85 | for chunk in chunks: 86 | chunk_id = f"{chunk.metadata.get('source')}:{chunk.metadata.get('page')}" 87 | content = chunk.page_content 88 | embedding = embeddings(content) 89 | records.append((chunk_id, content, embedding)) 90 | 91 | execute_batch(cursor, insert_query, records) 92 | connection.commit() 93 | cursor.close() 94 | connection.close() 95 | 96 | print(f"Inserted {len(records)} chunks into PostgreSQL.") 97 | 98 | 99 | def clear_database(): 100 | """Clear the document_chunks table.""" 101 | connection = psycopg2.connect( 102 | host=DB_HOST, 103 | database=DB_NAME, 104 | user=DB_USER, 105 | password=DB_PASS 106 | ) 107 | cursor = connection.cursor() 108 | cursor.execute("DROP TABLE IF EXISTS document_chunks;") 109 | connection.commit() 110 | cursor.close() 111 | connection.close() 112 | print("Database cleared.") 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.js 2 | !jest.config.js 3 | !functions/*.js 4 | *.d.ts 5 | node_modules 6 | 7 | # CDK asset staging directory 8 | .cdk.staging 9 | cdk.out 10 | 11 | # Created by https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode,node 12 | # Edit at https://www.gitignore.io/?templates=osx,linux,python,windows,pycharm,visualstudiocode,node 13 | 14 | ### Linux ### 15 | *~ 16 | 17 | # temporary files which can be created if a process still has a handle open of a deleted file 18 | .fuse_hidden* 19 | 20 | # KDE directory preferences 21 | .directory 22 | 23 | # Linux trash folder which might appear on any partition or disk 24 | .Trash-* 25 | 26 | # .nfs files are created when an open file is removed but is still being accessed 27 | .nfs* 28 | 29 | ### Node ### 30 | # Logs 31 | logs 32 | *.log 33 | npm-debug.log* 34 | yarn-debug.log* 35 | yarn-error.log* 36 | lerna-debug.log* 37 | 38 | # Diagnostic reports (https://nodejs.org/api/report.html) 39 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 40 | 41 | # Runtime data 42 | pids 43 | *.pid 44 | *.seed 45 | *.pid.lock 46 | 47 | # Directory for instrumented libs generated by jscoverage/JSCover 48 | lib-cov 49 | 50 | # Coverage directory used by tools like istanbul 51 | coverage 52 | *.lcov 53 | 54 | # nyc test coverage 55 | .nyc_output 56 | 57 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 58 | .grunt 59 | 60 | # Bower dependency directory (https://bower.io/) 61 | bower_components 62 | 63 | # node-waf configuration 64 | .lock-wscript 65 | 66 | # Compiled binary addons (https://nodejs.org/api/addons.html) 67 | build/Release 68 | 69 | # Dependency directories 70 | node_modules/ 71 | jspm_packages/ 72 | 73 | # TypeScript v1 declaration files 74 | typings/ 75 | 76 | # TypeScript cache 77 | *.tsbuildinfo 78 | 79 | # Optional npm cache directory 80 | .npm 81 | 82 | # Optional eslint cache 83 | .eslintcache 84 | 85 | # Optional REPL history 86 | .node_repl_history 87 | 88 | # Output of 'npm pack' 89 | *.tgz 90 | 91 | # Yarn Integrity file 92 | .yarn-integrity 93 | 94 | # dotenv environment variables file 95 | .env 96 | .env.test 97 | 98 | # parcel-bundler cache (https://parceljs.org/) 99 | .cache 100 | 101 | # next.js build output 102 | .next 103 | 104 | # nuxt.js build output 105 | .nuxt 106 | 107 | # vuepress build output 108 | .vuepress/dist 109 | 110 | # Serverless directories 111 | .serverless/ 112 | 113 | # FuseBox cache 114 | .fusebox/ 115 | 116 | # DynamoDB Local files 117 | .dynamodb/ 118 | 119 | ### OSX ### 120 | # General 121 | .DS_Store 122 | .AppleDouble 123 | .LSOverride 124 | 125 | # Icon must end with two \r 126 | Icon 127 | 128 | # Thumbnails 129 | ._* 130 | 131 | # Files that might appear in the root of a volume 132 | .DocumentRevisions-V100 133 | .fseventsd 134 | .Spotlight-V100 135 | .TemporaryItems 136 | .Trashes 137 | .VolumeIcon.icns 138 | .com.apple.timemachine.donotpresent 139 | 140 | # Directories potentially created on remote AFP share 141 | .AppleDB 142 | .AppleDesktop 143 | Network Trash Folder 144 | Temporary Items 145 | .apdisk 146 | 147 | ### PyCharm ### 148 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 149 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 150 | 151 | # User-specific stuff 152 | .idea/**/workspace.xml 153 | .idea/**/tasks.xml 154 | .idea/**/usage.statistics.xml 155 | .idea/**/dictionaries 156 | .idea/**/shelf 157 | 158 | # Generated files 159 | .idea/**/contentModel.xml 160 | 161 | # Sensitive or high-churn files 162 | .idea/**/dataSources/ 163 | .idea/**/dataSources.ids 164 | .idea/**/dataSources.local.xml 165 | .idea/**/sqlDataSources.xml 166 | .idea/**/dynamic.xml 167 | .idea/**/uiDesigner.xml 168 | .idea/**/dbnavigator.xml 169 | 170 | # Gradle 171 | .idea/**/gradle.xml 172 | .idea/**/libraries 173 | 174 | # Gradle and Maven with auto-import 175 | # When using Gradle or Maven with auto-import, you should exclude module files, 176 | # since they will be recreated, and may cause churn. Uncomment if using 177 | # auto-import. 178 | .idea/*.xml 179 | .idea/*.iml 180 | .idea 181 | # .idea/modules 182 | # *.iml 183 | # *.ipr 184 | 185 | # CMake 186 | cmake-build-*/ 187 | 188 | # Mongo Explorer plugin 189 | .idea/**/mongoSettings.xml 190 | 191 | # File-based project format 192 | *.iws 193 | 194 | # IntelliJ 195 | out/ 196 | 197 | # mpeltonen/sbt-idea plugin 198 | .idea_modules/ 199 | 200 | # JIRA plugin 201 | atlassian-ide-plugin.xml 202 | 203 | # Cursive Clojure plugin 204 | .idea/replstate.xml 205 | 206 | # Crashlytics plugin (for Android Studio and IntelliJ) 207 | com_crashlytics_export_strings.xml 208 | crashlytics.properties 209 | crashlytics-build.properties 210 | fabric.properties 211 | 212 | # Editor-based Rest Client 213 | .idea/httpRequests 214 | 215 | # Android studio 3.1+ serialized cache file 216 | .idea/caches/build_file_checksums.ser 217 | 218 | ### PyCharm Patch ### 219 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 220 | 221 | # *.iml 222 | # modules.xml 223 | # .idea/misc.xml 224 | # *.ipr 225 | 226 | # Sonarlint plugin 227 | .idea/sonarlint 228 | 229 | ### Python ### 230 | # Byte-compiled / optimized / DLL files 231 | __pycache__/ 232 | *.py[cod] 233 | *$py.class 234 | 235 | # C extensions 236 | *.so 237 | 238 | # Distribution / packaging 239 | .Python 240 | build/ 241 | develop-eggs/ 242 | dist/ 243 | downloads/ 244 | eggs/ 245 | .eggs/ 246 | lib64/ 247 | parts/ 248 | sdist/ 249 | var/ 250 | wheels/ 251 | pip-wheel-metadata/ 252 | share/python-wheels/ 253 | *.egg-info/ 254 | .installed.cfg 255 | *.egg 256 | MANIFEST 257 | 258 | # PyInstaller 259 | # Usually these files are written by a python script from a template 260 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 261 | *.manifest 262 | *.spec 263 | 264 | # Installer logs 265 | pip-log.txt 266 | pip-delete-this-directory.txt 267 | 268 | # Unit test / coverage reports 269 | htmlcov/ 270 | .tox/ 271 | .nox/ 272 | .coverage 273 | .coverage.* 274 | nosetests.xml 275 | coverage.xml 276 | *.cover 277 | .hypothesis/ 278 | .pytest_cache/ 279 | 280 | # Translations 281 | *.mo 282 | *.pot 283 | 284 | # Django stuff: 285 | local_settings.py 286 | db.sqlite3 287 | db.sqlite3-journal 288 | 289 | # Flask stuff: 290 | instance/ 291 | .webassets-cache 292 | 293 | # Scrapy stuff: 294 | .scrapy 295 | 296 | # Sphinx documentation 297 | docs/_build/ 298 | 299 | # PyBuilder 300 | target/ 301 | 302 | # Jupyter Notebook 303 | .ipynb_checkpoints 304 | 305 | # IPython 306 | profile_default/ 307 | ipython_config.py 308 | 309 | # pyenv 310 | .python-version 311 | 312 | # pipenv 313 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 314 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 315 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 316 | # install all needed dependencies. 317 | #Pipfile.lock 318 | 319 | # celery beat schedule file 320 | celerybeat-schedule 321 | 322 | # SageMath parsed files 323 | *.sage.py 324 | 325 | # Environments 326 | .venv 327 | env/ 328 | venv/ 329 | ENV/ 330 | env.bak/ 331 | venv.bak/ 332 | 333 | # Spyder project settings 334 | .spyderproject 335 | .spyproject 336 | 337 | # Rope project settings 338 | .ropeproject 339 | 340 | # mkdocs documentation 341 | /site 342 | 343 | # mypy 344 | .mypy_cache/ 345 | .dmypy.json 346 | dmypy.json 347 | 348 | # Pyre type checker 349 | .pyre/ 350 | 351 | ### VisualStudioCode ### 352 | .vscode 353 | 354 | ### VisualStudioCode Patch ### 355 | # Ignore all local history of files 356 | .history 357 | 358 | ### Windows ### 359 | # Windows thumbnail cache files 360 | Thumbs.db 361 | Thumbs.db:encryptable 362 | ehthumbs.db 363 | ehthumbs_vista.db 364 | 365 | # Dump file 366 | *.stackdump 367 | 368 | # Folder config file 369 | [Dd]esktop.ini 370 | 371 | # Recycle Bin used on file shares 372 | $RECYCLE.BIN/ 373 | 374 | # Windows Installer files 375 | *.cab 376 | *.msi 377 | *.msix 378 | *.msm 379 | *.msp 380 | 381 | # Windows shortcuts 382 | *.lnk 383 | 384 | # End of https://www.gitignore.io/api/osx,linux,python,windows,pycharm,visualstudiocode,node 385 | 386 | ### CDK-specific ignores ### 387 | *.swp 388 | cdk.context.json 389 | package-lock.json 390 | yarn.lock 391 | .cdk.staging 392 | cdk.out 393 | 394 | #ek custom 395 | .idea 396 | idea/ 397 | data/ 398 | --------------------------------------------------------------------------------