├── .gitignore ├── README.md ├── app.py ├── data └── .gitkeep ├── helpers ├── __init__.py ├── inference.py ├── inference_chatgpt.py └── trainer.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ## Developer defined 2 | 3 | migrations/db_dumps 4 | client_secret.json 5 | *.log 6 | .idea 7 | .vscode 8 | *.env 9 | .DS_Store 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project 
settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/.dmypy.json 114 | .dmypy.json 115 | tmp 116 | 117 | # file types 118 | *.csv 119 | *.wav 120 | *.mp3 121 | *.mp4 122 | *.m4a 123 | *.png 124 | *.jpg 125 | *.tiff 126 | *.jsonl 127 | 128 | helpers/davinci_text_model.ipynb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interview-AI-GPT3 2 | 3 | A simple chatbot built using OpenAI GPT-3. We used a dataset from another repo as explained in this [Article](https://medium.com/@olahsymbo/fine-tuning-openai-gpt-3-to-build-custom-chatbot-fe2dea524561) 4 | 5 | ### Update: 6 | 7 | Added ChatGPT and an approach to handle conversation drift 8 | 9 | ## Getting Started 10 | Create a virtual environment and install the dependencies 11 | 12 | ``` 13 | virtualenv .virtualenv 14 | source .virtualenv/bin/activate 15 | 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | Create a `.env` file and place OpenAI API Key in the file, e.g 20 | 21 | ``` 22 | OPENAI_KEY=REPLACE-WITH-KEY-FROM-OPENAI 23 | ``` 24 | 25 | ### Finetuning the model 26 | 27 | Simply run: 28 | 29 | ``` 30 | python helpers/trainer.py 31 | ``` 32 | 33 | ### Test the model 34 | 35 | Run: 36 | 37 | ``` 38 | python helpers/inference.py 39 | ``` 40 | You can always change the input text 41 | 42 | ### Test the model with Flask 43 | 44 | Run: 45 | 46 | ``` 47 | python app.py 48 | ``` 49 | 50 | Example request: 51 | 52 | ``` 53 | curl --location 'http://127.0.0.1:5000/chat_openai_no_drift' \ 54 | --header 'Content-Type: application/json' \ 55 | --data '{ 56 | "text": "what is linked list", 57 | "session_id": "KJN=87aGNw" 58 | }' 59 | ``` 60 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
import os
import logging
import traceback

import openai
from flask import Flask, request, jsonify
from dotenv import load_dotenv

from helpers.inference_chatgpt import generate_response_chatgpt, generate_response_chatgpt_no_drift

load_dotenv()

app = Flask(__name__)

openai.api_key = os.environ.get('OPENAI_KEY')

# NOTE(review): debug mode enables the Werkzeug interactive debugger
# (arbitrary code execution) — disable before deploying anywhere public.
app.debug = True

logger = logging.getLogger(__name__)


@app.route("/chat_openai", methods=['POST'])
def get_bot_response_chatgpt():
    """Stateless chat endpoint.

    Expects JSON ``{"text": "..."}`` and forwards it to ChatGPT with no
    conversation history. Returns 200 with the model's reply, or 400 with a
    generic message on any failure.
    """
    try:
        user_text = request.get_json()["text"]
        output_text = generate_response_chatgpt(user_text)
        return jsonify({"error": None,
                        "message": str(output_text),
                        "status": 200}), 200
    except Exception as error:
        # Log the real cause server-side; the client only sees a generic message.
        logger.error("chat_openai failed: %s\n%s", error, traceback.format_exc())
        return jsonify({"error": True,
                        "message": "Can't process the input question",
                        "status": 400}), 400


# Module-level conversation state.
# NOTE(review): shared, unsynchronised mutable state — this supports exactly
# one active conversation per process and is not thread-safe. A per-session
# store keyed by session_id (e.g. Redis) is the proper fix.
session = {"session_id": ""}
interview_history = []


@app.route("/chat_openai_no_drift", methods=['POST'])
def get_bot_response_chatgpt_no_drift():
    """Stateful chat endpoint that replays prior Q/A turns to limit drift.

    Expects JSON ``{"text": "...", "session_id": "..."}``. Sending a new
    session_id resets the stored conversation history.
    """
    try:
        payload = request.get_json()
        user_text = payload["text"]
        session_id = payload["session_id"]
        logger.debug("session_id=%s (current=%s)", session_id, session["session_id"])

        if session_id != session["session_id"]:
            # New conversation: remember the id and drop the old history.
            # Bug fix: the original appended the first turn to a throwaway
            # local list here, so history was never persisted for a fresh
            # session and stale history from the previous session survived
            # in the global list.
            session["session_id"] = session_id
            interview_history.clear()

        output_text = generate_response_chatgpt_no_drift(user_text, interview_history)
        interview_history.append({
            "question": user_text,
            "answer": output_text,
        })

        return jsonify({"message": str(output_text),
                        "error": None,
                        "status": 200}), 200
    except Exception as error:
        logger.error("chat_openai_no_drift failed: %s\n%s", error, traceback.format_exc())
        return jsonify({"error": True,
                        "message": "Can't process the input question",
                        "status": 400}), 400


if __name__ == "__main__":
    app.run(host="0.0.0.0")
import os

import openai
from dotenv import load_dotenv

load_dotenv()

openai.api_key = os.getenv('OPENAI_KEY')

# Fine-tuned GPT-3 completion model produced by helpers/trainer.py.
DEFAULT_ENGINE = "davinci:ft-personal-2023-01-25-19-20-17"

# Few-shot preamble establishing the bot persona; the user's question is
# substituted into the final "Human:" slot.
PROMPT_TEMPLATE = (
    "The following is a conversation with DSA an AI assistant. "
    "DSA is an interview bot who is very helpful and knowledgeable in data structure and algorithms.\n\n"
    "Human: Hello, who are you?\n"
    "DSA: I am DSA, an interview digital assistant. How can I help you today?\n"
    "Human: {}\nDSA:"
)


def generate_response(input_text, engine=DEFAULT_ENGINE):
    """Return the fine-tuned model's single-turn answer to *input_text*.

    Args:
        input_text: The user's question, inserted into the prompt template.
        engine: Completion engine to query; defaults to the project's
            fine-tuned model (generalized from the previously hard-coded name).

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        openai.error.OpenAIError: on any API failure.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=PROMPT_TEMPLATE.format(input_text),
        temperature=0.9,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.6,
        # Stop as soon as the model starts a new conversational turn.
        stop=["\n", " Human:", " DSA:"]
    )
    return response.choices[0].text.strip()


if __name__ == "__main__":
    question = "what is breadth first search algorithm"
    print(generate_response(question))
import os

import openai
from dotenv import load_dotenv

load_dotenv()

openai.api_key = os.getenv('OPENAI_KEY')

# Shared conversation preamble: system instructions plus one fixed few-shot
# exchange establishing the bot's identity.
_BASE_PROMPT = [
    {"role": "system",
     "content": "You are DSA, a large language model for answering data structure and algorithm questions. "
                "Answer as concisely as possible. Limit your response to 60 words. \nKnowledge cutoff: "
                "2023-03-01\nCurrent date: 2023-03-02"},
    {"role": "user", "content": "who are you"},
    {"role": "assistant", "content": "I am DSA, my purpose is to answer your questions on data structure "
                                     "and algorithms"},
]


def _build_messages(input_text, interview_history=None):
    """Build the ChatCompletion message list.

    Order: preamble, then prior Q/A turns (oldest first), then the current
    user question last — the model answers the final user message.
    """
    messages = [dict(message) for message in _BASE_PROMPT]
    # Bug fixes vs. the original:
    #  * `interview_history is not []` was an identity check, always True;
    #  * the human question was labelled "assistant" and the bot answer
    #    "user" (roles swapped);
    #  * history was appended AFTER the new question, so the prompt ended
    #    with stale turns instead of the question to answer.
    for turn in (interview_history or []):
        messages.append({"role": "user", "content": "{}".format(turn["question"])})
        messages.append({"role": "assistant", "content": "{}".format(turn["answer"])})
    messages.append({"role": "user", "content": "{}".format(input_text)})
    return messages


def generate_response_chatgpt(input_text):
    """Return ChatGPT's answer to *input_text* with no conversation history."""
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=_build_messages(input_text)
    )
    return completion.choices[0].message.content.strip()


def generate_response_chatgpt_no_drift(input_text, interview_history):
    """Return ChatGPT's answer while replaying *interview_history*.

    Args:
        input_text: The new user question.
        interview_history: List of ``{"question": ..., "answer": ...}`` dicts
            from earlier turns of the same session (may be empty).
    """
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=_build_messages(input_text, interview_history)
    )
    return completion.choices[0].message.content.strip()
import json
import os
import subprocess

import pandas as pd


def build_training_file(csv_path='data/data.csv', jsonl_path='data/data.jsonl'):
    """Convert the interview transcript CSV into an OpenAI fine-tuning JSONL.

    Assumes even-indexed rows of the ``Text`` column are the bot
    ("Interview AI") and odd-indexed rows are the human — TODO confirm
    against the source dataset.

    Args:
        csv_path: Input CSV with a ``Text`` column of alternating turns.
        jsonl_path: Output path for ``{"prompt", "completion"}`` records.

    Returns:
        The path of the JSONL file written.
    """
    data = pd.read_csv(csv_path)

    ai_turns = data['Text'].iloc[::2].values
    human_turns = data['Text'].iloc[1::2].values

    with open(jsonl_path, 'w') as outfile:
        # zip() stops at the shorter sequence, so an odd trailing row is
        # dropped instead of raising (the original DataFrame construction
        # failed on mismatched lengths).
        for ai_text, human_text in zip(ai_turns, human_turns):
            json.dump({'prompt': human_text, 'completion': ai_text}, outfile)
            outfile.write('\n')
    return jsonl_path


def main():
    """Prepare the dataset and kick off an OpenAI fine-tune run."""
    # Local imports keep the pure CSV->JSONL helper importable without the
    # OpenAI SDK installed.
    import openai
    from dotenv import load_dotenv

    load_dotenv()
    api_key = os.getenv('OPENAI_KEY')
    if not api_key:
        # The original assigned None into os.environ, raising an opaque TypeError.
        raise RuntimeError("OPENAI_KEY is not set; add it to your .env file")
    os.environ['OPENAI_API_KEY'] = api_key
    openai.api_key = api_key

    build_training_file()

    # Argument lists with shell=False instead of os.system() string commands.
    subprocess.run(["openai", "tools", "fine_tunes.prepare_data", "-f", "data/data.jsonl"], check=True)
    subprocess.run(["openai", "api", "fine_tunes.create", "-t", "data/data_prepared.jsonl", "-m", "davinci"], check=True)
    # In case training is interrupted, resume with:
    # subprocess.run(["openai", "api", "fine_tunes.follow", "-i", "ft-..."], check=True)


if __name__ == "__main__":
    main()