├── .gitignore
├── assets
│   └── images
│       └── Chat_Your_Data.gif
├── Makefile
├── requirements.txt
├── README.md
├── schemas.py
├── callback.py
├── query_data.py
├── main.py
└── templates
    └── index.html

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
config/*
--------------------------------------------------------------------------------
/assets/images/Chat_Your_Data.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sangyh/MockGPT4/HEAD/assets/images/Chat_Your_Data.gif
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: start
start:
	uvicorn main:app --reload --port 9000

.PHONY: format
format:
	black .
	isort .
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
openai
fastapi
black
isort
websockets
pydantic
langchain
uvicorn
jinja2
faiss-cpu
bs4
unstructured
libmagic
python-dotenv
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🦜️🔗 MockGPT4

This repo is an implementation of a locally hosted GPT-4 (or GPT-3.5) chatbot.
Built with [LangChain](https://github.com/hwchase17/langchain/) and [FastAPI](https://fastapi.tiangolo.com/).

The app leverages LangChain's streaming support and async API to update the page in real time for multiple users.

## ✅ Running locally
Download and run this [Colab notebook](https://colab.research.google.com/drive/1tFQWh7CUryLX6PthFuZyNIpz3_nALXAm?authuser=1#scrollTo=FmvlU7Af2XhZ).
--------------------------------------------------------------------------------
/schemas.py:
--------------------------------------------------------------------------------
"""Schemas for the chat app."""
from pydantic import BaseModel, validator


class ChatResponse(BaseModel):
    """Chat response schema."""

    sender: str
    message: str
    type: str

    @validator("sender")
    def sender_must_be_bot_or_you(cls, v):
        if v not in ["bot", "you"]:
            raise ValueError("sender must be bot or you")
        return v

    @validator("type")
    def validate_message_type(cls, v):
        if v not in ["start", "stream", "end", "error", "info"]:
            raise ValueError("type must be start, stream, end, error or info")
        return v
--------------------------------------------------------------------------------
/callback.py:
--------------------------------------------------------------------------------
"""Callback handlers used in the app."""
from typing import Any, Dict, List

from langchain.callbacks.base import AsyncCallbackHandler

from schemas import ChatResponse


class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    def __init__(self, websocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Forward each newly generated token to the client as a "stream" message.
        resp = ChatResponse(sender="bot", message=token, type="stream")
        await self.websocket.send_json(resp.dict())


class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation."""

    def __init__(self, websocket):
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        await self.websocket.send_json(resp.dict())
--------------------------------------------------------------------------------
/query_data.py:
--------------------------------------------------------------------------------
"""Create a ConversationChain for question/answering."""
from langchain.callbacks.base import AsyncCallbackManager
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)

SYSTEM_MESSAGE_PROMPT = "You are ChatGPT, a large language model trained by OpenAI to have friendly conversations with humans."
prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(SYSTEM_MESSAGE_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template("{input}"),
])


def get_chain(
    question_handler, stream_handler, tracing: bool = False
) -> ConversationChain:
    """Create a ConversationChain for question/answering."""
    manager = AsyncCallbackManager([])
    stream_manager = AsyncCallbackManager([stream_handler])

    streaming_llm = ChatOpenAI(
        streaming=True,
        callback_manager=stream_manager,
        verbose=True,
        temperature=0,
        model="gpt-4",  # or "gpt-3.5-turbo"
    )

    memory = ConversationBufferMemory(return_messages=True)
    conversation = ConversationChain(
        memory=memory,
        prompt=prompt,
        llm=streaming_llm,
        callback_manager=manager,
    )
    return conversation
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""Main entrypoint for the app."""
import logging
import os
from pathlib import Path

import openai
from dotenv import load_dotenv

# Load the OpenAI key from config/.env (the config/ directory is gitignored).
load_dotenv(Path("config/.env"))
openai.api_key = os.getenv("OPENAI_API_KEY")

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates

from callback import QuestionGenCallbackHandler, StreamingLLMCallbackHandler
from query_data import get_chain
from schemas import ChatResponse

app = FastAPI()
templates = Jinja2Templates(directory="templates")


@app.on_event("startup")
async def startup_event():
    logging.info("loading")


@app.get("/")
async def get(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    question_handler = QuestionGenCallbackHandler(websocket)
    stream_handler = StreamingLLMCallbackHandler(websocket)
    qa_chain = get_chain(question_handler, stream_handler)

    while True:
        try:
            # Receive and echo back the client message
            question = await websocket.receive_text()
            resp = ChatResponse(sender="you", message=question, type="stream")
            await websocket.send_json(resp.dict())

            # Construct a response
            start_resp = ChatResponse(sender="bot", message="", type="start")
            await websocket.send_json(start_resp.dict())

            # Tokens are pushed to the client by StreamingLLMCallbackHandler.
            await qa_chain.acall({"input": question})

            end_resp = ChatResponse(sender="bot", message="", type="end")
            await websocket.send_json(end_resp.dict())
        except WebSocketDisconnect:
            logging.info("websocket disconnect")
            break
        except Exception as e:
            logging.error(e)
            resp = ChatResponse(
                sender="bot",
                message="Sorry, something went wrong. Try again.",
                type="error",
            )
            await websocket.send_json(resp.dict())


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=9000)
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
(The HTML markup was stripped when this dump was generated; only the page title "MockGPT4" is recoverable. This file is the chat front end rendered by the "/" route in main.py.)
--------------------------------------------------------------------------------
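For readers who want to poke at the backend without the web UI, the following is a minimal sketch (not a file in this repo; the script name and handler class are made up for illustration) that drives `get_chain()` directly and shows the `ConversationBufferMemory` carrying context across turns. It assumes a valid `OPENAI_API_KEY` in `config/.env`, as `main.py` expects.

```python
# repl_chat.py (hypothetical helper, not part of the repo)
import asyncio
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
from langchain.callbacks.base import AsyncCallbackHandler

from query_data import get_chain

load_dotenv(Path("config/.env"))  # same env file main.py uses


class PrintTokenHandler(AsyncCallbackHandler):
    """Print streamed tokens to stdout instead of a websocket."""

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end="", flush=True)


async def main() -> None:
    # question_handler is unused by get_chain, so None is fine here.
    chain = get_chain(question_handler=None, stream_handler=PrintTokenHandler())
    await chain.acall({"input": "Hi! My name is Sam."})
    print()
    # Second turn: the ConversationBufferMemory should recall the name.
    await chain.acall({"input": "What did I just tell you my name was?"})
    print()


if __name__ == "__main__":
    asyncio.run(main())
```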
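Since the original `index.html` markup did not survive this dump, here is a minimal sketch of a client-side smoke test for the `/chat` websocket (a hypothetical script, not part of the repo). It follows the message protocol defined in `schemas.py` and `main.py`, uses the `websockets` package already listed in `requirements.txt`, and assumes the server is running locally on port 9000, e.g. via `make start`.

```python
# smoke_test_ws.py (hypothetical helper, not part of the repo)
import asyncio
import json

import websockets


async def ask(question: str, uri: str = "ws://localhost:9000/chat") -> str:
    """Send one question and return the streamed answer as a single string."""
    tokens = []
    async with websockets.connect(uri) as ws:
        await ws.send(question)
        while True:
            msg = json.loads(await ws.recv())
            if msg["type"] == "stream" and msg["sender"] == "bot":
                tokens.append(msg["message"])      # token-by-token output
            elif msg["type"] in ("end", "error"):  # "end" closes the turn
                break
            # "start" and "info" messages are status updates; the echo of the
            # question arrives with sender "you" and is skipped above.
    return "".join(tokens)


if __name__ == "__main__":
    print(asyncio.run(ask("Hello, who are you?")))
```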