├── models └── .gitkeep ├── users ├── __init__.py ├── tests.py ├── apps.py ├── utils.py ├── tasks.py ├── admin.py ├── models.py ├── urls.py ├── backends.py ├── serializers.py └── views.py ├── chatbot ├── __init__.py ├── tests.py ├── apps.py ├── admin.py ├── serializers.py ├── utils.py ├── urls.py ├── models.py ├── tasks.py └── views.py ├── site_settings ├── __init__.py ├── tests.py ├── apps.py ├── urls.py ├── serializers.py ├── views.py ├── admin.py └── models.py ├── static └── .gitkeep ├── training_model ├── __init__.py ├── tests.py ├── urls.py ├── apps.py ├── admin.py ├── models.py ├── views.py ├── faiss_helpers.py └── pinecone_helpers.py ├── config ├── __init__.py ├── celery.py ├── settings │ ├── __init__.py │ ├── stage.py │ ├── local.py │ ├── key_values.py │ ├── prod.py │ └── common.py ├── asgi.py ├── wsgi.py └── urls.py ├── templates └── admin │ └── training_model │ └── document │ └── change_form.html ├── requirements.txt ├── manage.py ├── filestructure.txt ├── example.env ├── .gitignore └── README.md /models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /users/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chatbot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /site_settings/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /static/.gitkeep: -------------------------------------------------------------------------------- 1 | # Keep in git -------------------------------------------------------------------------------- 
/training_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chatbot/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /users/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .celery import app as celery_app 2 | 3 | __all__ = ['celery_app'] 4 | -------------------------------------------------------------------------------- /site_settings/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /training_model/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 
4 | -------------------------------------------------------------------------------- /users/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class UsersConfig(AppConfig): 5 | default_auto_field = 'django.db.models.BigAutoField' 6 | name = 'users' 7 | -------------------------------------------------------------------------------- /chatbot/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ChatbotConfig(AppConfig): 5 | default_auto_field = 'django.db.models.BigAutoField' 6 | name = 'chatbot' 7 | -------------------------------------------------------------------------------- /site_settings/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class SiteSettingsConfig(AppConfig): 5 | default_auto_field = 'django.db.models.BigAutoField' 6 | name = 'site_settings' 7 | -------------------------------------------------------------------------------- /training_model/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from .views import TrainView 3 | 4 | urlpatterns = [ 5 | path('train//', TrainView.as_view(), name='train_view'), 6 | ] 7 | -------------------------------------------------------------------------------- /training_model/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class TrainingModelConfig(AppConfig): 5 | default_auto_field = 'django.db.models.BigAutoField' 6 | name = 'training_model' 7 | -------------------------------------------------------------------------------- /templates/admin/training_model/document/change_form.html: -------------------------------------------------------------------------------- 1 | {% extends 
"admin/change_form.html" %} 2 | 3 | {% block submit_buttons_bottom %} 4 | {{ block.super }} 5 | {% block custom_buttons %} 6 |
7 | 8 |
9 | {% endblock %} 10 | {% endblock %} 11 | -------------------------------------------------------------------------------- /config/celery.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | from celery import Celery 4 | from django.conf import settings 5 | 6 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') 7 | app = Celery('customize_chatgpt') 8 | 9 | app.config_from_object('django.conf:settings', namespace='CELERY') 10 | app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Django 2 | djangorestframework 3 | django-filter 4 | drf-yasg 5 | celery 6 | django-oauth-toolkit 7 | social-auth-app-django 8 | django-rest-auth 9 | django-cors-headers 10 | python-dotenv 11 | django-celery-results 12 | redis 13 | psycopg2-binary 14 | openai 15 | gunicorn 16 | django-storages 17 | boto3 18 | pycurl 19 | pillow 20 | langchain 21 | pinecone-client 22 | bs4 23 | tiktoken 24 | pypdf -------------------------------------------------------------------------------- /users/utils.py: -------------------------------------------------------------------------------- 1 | def jwt_response_payload_handler(token, user=None, request=None): 2 | """ 3 | Custom JWT response payload handler. 
4 | """ 5 | return { 6 | 'token': token, 7 | 'user': { 8 | 'id': user.id, 9 | 'username': user.username, 10 | 'email': user.email, 11 | 'first_name': user.first_name, 12 | 'last_name': user.last_name, 13 | }, 14 | } 15 | -------------------------------------------------------------------------------- /config/settings/__init__.py: -------------------------------------------------------------------------------- 1 | """For production, we'll automatically generate settings from prod.py via ci/cd script""" 2 | import os 3 | from dotenv import load_dotenv 4 | 5 | load_dotenv() 6 | # DEV = False 7 | env_name = os.getenv('ENV_NAME', 'local') 8 | 9 | from .key_values import * 10 | 11 | if env_name == "prod": 12 | from .prod import * 13 | elif env_name == "stage": 14 | from .stage import * 15 | else: 16 | from .local import * 17 | -------------------------------------------------------------------------------- /users/tasks.py: -------------------------------------------------------------------------------- 1 | from celery import shared_task 2 | from django.core.mail import send_mail 3 | from django.conf import settings 4 | 5 | 6 | @shared_task 7 | def send_forgot_password_email(subject, message, recipient): 8 | """ 9 | Send an email for a forgotten password. 10 | """ 11 | send_mail( 12 | subject, 13 | message, 14 | settings.EMAIL_HOST_USER, 15 | [recipient], 16 | fail_silently=False, 17 | ) 18 | -------------------------------------------------------------------------------- /config/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for config project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /config/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for config project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.1/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /site_settings/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from .views import SiteSettingList, LanguageList, AdList 3 | 4 | app_name = 'site_settings' 5 | 6 | urlpatterns = [ 7 | # SiteSettings list endpoint 8 | path('settings/', SiteSettingList.as_view(), name='settings-list'), 9 | 10 | # Languages list endpoint 11 | path('languages/', LanguageList.as_view(), name='languages-list'), 12 | 13 | # Ads list endpoint 14 | path('ads/', AdList.as_view(), name='ads-list'), 15 | ] 16 | -------------------------------------------------------------------------------- /users/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from django.contrib.auth.admin import UserAdmin 3 | from .models import CustomUser 4 | 5 | 6 | class CustomUserAdmin(UserAdmin): 7 | fieldsets = UserAdmin.fieldsets + 
( 8 | (None, {"fields": ("phone_number", "address", "profile_picture")}), 9 | ) 10 | add_fieldsets = UserAdmin.add_fieldsets + ( 11 | (None, { 12 | 'classes': ('wide',), 13 | 'fields': ('phone_number', 'address', 'profile_picture') 14 | }), 15 | ) 16 | 17 | 18 | admin.site.register(CustomUser, CustomUserAdmin) 19 | -------------------------------------------------------------------------------- /users/models.py: -------------------------------------------------------------------------------- 1 | from django.contrib.auth.models import AbstractUser 2 | from django.db import models 3 | from django.utils.translation import gettext_lazy as _ 4 | 5 | 6 | class CustomUser(AbstractUser): 7 | phone_number = models.CharField(_("Phone number"), max_length=15, blank=True, null=True) 8 | address = models.TextField(_("Address"), blank=True, null=True) 9 | profile_picture = models.ImageField(_("Profile picture"), upload_to="profile_pictures/", blank=True, null=True) 10 | 11 | class Meta: 12 | verbose_name = _("User") 13 | verbose_name_plural = _("Users") 14 | -------------------------------------------------------------------------------- /config/settings/stage.py: -------------------------------------------------------------------------------- 1 | from .common import * # noqa 2 | 3 | ALLOWED_HOSTS = ["*"] 4 | INSTALLED_APPS += ["gunicorn", ] 5 | 6 | # Postgres 7 | DATABASES = { 8 | 'default': { 9 | 'ENGINE': 'django.db.backends.postgresql', 10 | 'NAME': os.getenv('DB_NAME'), 11 | 'USER': os.getenv('DB_USER'), 12 | 'PASSWORD': os.getenv('DB_PASSWORD'), 13 | 'HOST': os.getenv('DB_HOST', 'localhost'), 14 | 'PORT': os.getenv('DB_PORT'), 15 | } 16 | } 17 | 18 | # DATABASES = { 19 | # 'default': { 20 | # 'ENGINE': 'django.db.backends.sqlite3', 21 | # 'NAME': BASE_DIR / 'db.sqlite3', 22 | # } 23 | # } 24 | 25 | # Social 26 | SOCIAL_AUTH_REDIRECT_IS_HTTPS = True 27 | -------------------------------------------------------------------------------- /users/urls.py: 
-------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from .views import UserRegistrationView, UserProfileView, GoogleLoginView, LoginView, LogoutView 3 | 4 | app_name = 'users' 5 | 6 | urlpatterns = [ 7 | # User registration endpoint 8 | path('register/', UserRegistrationView.as_view(), name='register'), 9 | 10 | # User profile retrieval and update endpoint 11 | path('profile/', UserProfileView.as_view(), name='profile'), 12 | 13 | # Login endpoint 14 | path('login/', LoginView.as_view(), name='login'), 15 | 16 | # Google OAuth2 login endpoint 17 | path('login/google/', GoogleLoginView.as_view(), name='google-login'), 18 | 19 | # Logout endpoint 20 | path("logout/", LogoutView.as_view(), name="logout"), 21 | ] 22 | -------------------------------------------------------------------------------- /training_model/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from django.urls import reverse 3 | from django.utils.html import format_html 4 | 5 | from .models import Document 6 | 7 | 8 | class DocumentAdmin(admin.ModelAdmin): 9 | """ 10 | Admin View for Document 11 | """ 12 | list_display = ('file_name', 'index_name', 'storage_type', 'is_trained', 'uploaded_at', 'train_button') 13 | search_fields = ('file', 'index_name', 'storage_type') 14 | list_filter = ('is_trained',) 15 | 16 | def train_button(self, obj): 17 | train_url = reverse('train_view', args=[obj.pk]) 18 | return format_html('{}', train_url, "Train") 19 | 20 | 21 | admin.site.register(Document, DocumentAdmin) 22 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | """Run administrative tasks.""" 9 
| os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') 10 | try: 11 | from django.core.management import execute_from_command_line 12 | except ImportError as exc: 13 | raise ImportError( 14 | "Couldn't import Django. Are you sure it's installed and " 15 | "available on your PYTHONPATH environment variable? Did you " 16 | "forget to activate a virtual environment?" 17 | ) from exc 18 | execute_from_command_line(sys.argv) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /site_settings/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from .models import SiteSetting, Language, Ad 3 | 4 | 5 | class LanguageSerializer(serializers.ModelSerializer): 6 | """ 7 | Serializer for the Language model. 8 | """ 9 | 10 | class Meta: 11 | model = Language 12 | fields = ['name', 'code'] 13 | 14 | 15 | class SiteSettingSerializer(serializers.ModelSerializer): 16 | """ 17 | Serializer for the SiteSetting model. 18 | """ 19 | 20 | class Meta: 21 | model = SiteSetting 22 | fields = ['title', 'logo'] 23 | 24 | 25 | class AdSerializer(serializers.ModelSerializer): 26 | """ 27 | Serializer for the Ad model. 28 | """ 29 | 30 | class Meta: 31 | model = Ad 32 | fields = ['title', 'description', 'image'] 33 | -------------------------------------------------------------------------------- /chatbot/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from .models import Conversation, Message 3 | 4 | 5 | @admin.register(Conversation) 6 | class ConversationAdmin(admin.ModelAdmin): 7 | """ 8 | Admin site configuration for Conversation model. 
9 | """ 10 | list_display = ('id', 'title', 'user', 'favourite', 'archive', 'created_at', 'updated_at',) 11 | list_filter = ('created_at', 'updated_at', 'favourite', 'archive',) 12 | search_fields = ('user__username', 'title',) 13 | 14 | 15 | @admin.register(Message) 16 | class MessageAdmin(admin.ModelAdmin): 17 | """ 18 | Admin site configuration for Message model. 19 | """ 20 | list_display = ('id', 'conversation', 'content', 'is_from_user', 'created_at') 21 | list_filter = ('is_from_user', 'conversation__user__username', 'created_at') 22 | search_fields = ('content',) 23 | ordering = ('-created_at',) 24 | -------------------------------------------------------------------------------- /chatbot/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import Conversation, Message 4 | from .utils import time_since 5 | 6 | 7 | class MessageSerializer(serializers.ModelSerializer): 8 | """ 9 | Message serializer. 10 | """ 11 | 12 | class Meta: 13 | model = Message 14 | fields = ['id', 'conversation', 'content', 'is_from_user', 'in_reply_to', 'created_at', ] 15 | 16 | 17 | class ConversationSerializer(serializers.ModelSerializer): 18 | """ 19 | Conversation serializer. 
20 | """ 21 | messages = MessageSerializer(many=True, read_only=True) 22 | created_at = serializers.SerializerMethodField() 23 | 24 | class Meta: 25 | model = Conversation 26 | fields = ['id', 'title', 'favourite', 'archive', 'created_at', 'messages'] 27 | 28 | def get_created_at(self, obj): 29 | return time_since(obj.created_at) 30 | -------------------------------------------------------------------------------- /site_settings/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import generics, permissions 2 | from .models import SiteSetting, Language, Ad 3 | from .serializers import SiteSettingSerializer, LanguageSerializer, AdSerializer 4 | 5 | 6 | class SiteSettingList(generics.ListAPIView): 7 | """ 8 | API view to list SiteSettings. 9 | """ 10 | queryset = SiteSetting.objects.all() 11 | serializer_class = SiteSettingSerializer 12 | permission_classes = [permissions.AllowAny] 13 | 14 | 15 | class LanguageList(generics.ListAPIView): 16 | """ 17 | API view to list available Languages. 18 | """ 19 | queryset = Language.objects.all() 20 | serializer_class = LanguageSerializer 21 | permission_classes = [permissions.AllowAny] 22 | 23 | 24 | class AdList(generics.ListAPIView): 25 | """ 26 | API view to list Ads. 
27 | """ 28 | queryset = Ad.objects.all() 29 | serializer_class = AdSerializer 30 | -------------------------------------------------------------------------------- /config/settings/local.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | 3 | ALLOWED_HOSTS = ["*"] 4 | 5 | # DATABASES = { 6 | # 'default': { 7 | # 'ENGINE': 'django.db.backends.sqlite3', 8 | # 'NAME': BASE_DIR / 'db.sqlite3', 9 | # } 10 | # } 11 | 12 | 13 | # Postgres 14 | DATABASES = { 15 | 'default': { 16 | 'ENGINE': 'django.db.backends.postgresql', 17 | 'NAME': os.getenv('DB_NAME'), 18 | 'USER': os.getenv('DB_USER'), 19 | 'PASSWORD': os.getenv('DB_PASSWORD'), 20 | 'HOST': os.getenv('DB_HOST', 'localhost'), 21 | 'PORT': os.getenv('DB_PORT'), 22 | } 23 | } 24 | 25 | STATIC_ROOT = os.path.join(BASE_DIR, "static_cdn", "static_root") 26 | STATICFILES_DIRS = ( 27 | os.path.join(BASE_DIR, "static"), 28 | ) 29 | 30 | MEDIA_ROOT = os.path.join(BASE_DIR, "static_cdn", "media_root") 31 | 32 | # Static files (CSS, JavaScript, Images) 33 | 34 | STATIC_URL = 'static/' 35 | MEDIA_URL = 'media/' 36 | -------------------------------------------------------------------------------- /config/settings/key_values.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Celery 4 | BROKER_URL = os.getenv('CELERY_BROKER_URL', 'redis://redis:6379') 5 | CELERY_RESULT_BACKEND = os.getenv('CELERY_RESULT_BACKEND', 'redis://redis:6379') 6 | 7 | # Google OAuth2 settings 8 | SOCIAL_AUTH_GOOGLE_OAUTH2_KEY = os.getenv('GOOGLE_KEY') 9 | SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET = os.getenv('GOOGLE_SECRET') 10 | 11 | # Open AI key 12 | OPENAI_API_KEY = os.getenv('OPEN_AI_KEY') 13 | 14 | # Pinecone 15 | PINECONE_API_KEY = os.getenv('PINECONE_API_KEY') 16 | PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT') 17 | PINECONE_INDEX_NAME = os.getenv('PINECONE_INDEX_NAME') 18 | PINECONE_NAMESPACE_NAME = 
os.getenv('PINECONE_NAMESPACE_NAME') 19 | 20 | # Admin Site Config 21 | ADMIN_SITE_HEADER = os.getenv('ADMIN_SITE_HEADER') 22 | ADMIN_SITE_TITLE = os.getenv('ADMIN_SITE_TITLE') 23 | ADMIN_SITE_INDEX = os.getenv('ADMIN_SITE_INDEX') 24 | 25 | # OAuth2 settings 26 | APPLICATION_NAME = os.getenv('APPLICATION_NAME', 'chatbot') 27 | -------------------------------------------------------------------------------- /chatbot/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | 3 | 4 | def time_since(dt): 5 | """ 6 | Returns string representing "time since" e.g. 7 | """ 8 | now = datetime.now(timezone.utc) 9 | diff = now - dt 10 | 11 | seconds = diff.total_seconds() 12 | minutes = int(seconds // 60) 13 | hours = int(minutes // 60) 14 | days = int(hours // 24) 15 | months = int(days // 30) 16 | years = int(days // 365) 17 | 18 | if years > 0: 19 | return f"{years} year{'s' if years > 1 else ''} ago" 20 | elif months > 0: 21 | return f"{months} month{'s' if months > 1 else ''} ago" 22 | elif days > 0: 23 | return f"{days} day{'s' if days > 1 else ''} ago" 24 | elif hours > 0: 25 | return f"{hours} hour{'s' if hours > 1 else ''} ago" 26 | elif minutes > 0: 27 | return f"{minutes} minute{'s' if minutes > 1 else ''} ago" 28 | else: 29 | return f"{int(seconds)} second{'s' if seconds > 1 else ''} ago" 30 | -------------------------------------------------------------------------------- /users/backends.py: -------------------------------------------------------------------------------- 1 | from django.contrib.auth.backends import ModelBackend 2 | from django.contrib.auth import get_user_model 3 | 4 | User = get_user_model() 5 | 6 | 7 | class EmailBackend(ModelBackend): 8 | """ 9 | Authenticate using an e-mail address. 
10 | """ 11 | 12 | def authenticate(self, request, username=None, password=None, **kwargs): 13 | UserModel = get_user_model() 14 | try: 15 | user = UserModel.objects.get(email=username) 16 | except UserModel.DoesNotExist: 17 | try: 18 | user = UserModel.objects.get(username=username) 19 | except UserModel.DoesNotExist: 20 | return None 21 | 22 | if user.check_password(password): 23 | return user 24 | 25 | def get_user(self, user_id): 26 | UserModel = get_user_model() 27 | try: 28 | return UserModel.objects.get(pk=user_id) 29 | except UserModel.DoesNotExist: 30 | return None 31 | -------------------------------------------------------------------------------- /users/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from django.contrib.auth import get_user_model 3 | from oauth2_provider.models import get_application_model 4 | 5 | Application = get_application_model() 6 | User = get_user_model() 7 | 8 | 9 | class UserRegistrationSerializer(serializers.ModelSerializer): 10 | """ 11 | User registration serializer. 12 | """ 13 | password = serializers.CharField(write_only=True) 14 | 15 | class Meta: 16 | model = User 17 | fields = ['id', 'username', 'email', 'password', 'first_name', 'last_name', 'phone_number', 'address', 18 | 'profile_picture'] 19 | 20 | 21 | class UserProfileSerializer(serializers.ModelSerializer): 22 | """ 23 | User profile serializer. 
24 | """ 25 | 26 | class Meta: 27 | model = User 28 | fields = ['username', 'email', 'first_name', 'last_name', 'phone_number', 'address', 'profile_picture'] 29 | read_only_fields = ['username', 'email'] 30 | -------------------------------------------------------------------------------- /filestructure.txt: -------------------------------------------------------------------------------- 1 | gpt_chat/ 2 | |-- .env 3 | |-- .gitignore 4 | |-- README.md 5 | |-- manage.py 6 | |-- requirements.txt 7 | |-- config/ 8 | | |-- __init__.py 9 | | |-- asgi.py 10 | | |-- celery.py 11 | | |-- settings.py 12 | | |-- urls.py 13 | | |-- wsgi.py 14 | |-- users/ 15 | | |-- __init__.py 16 | | |-- admin.py 17 | | |-- apps.py 18 | | |-- migrations/ 19 | | |-- models.py 20 | | |-- serializers.py 21 | | |-- urls.py 22 | | |-- views.py 23 | | |-- tasks.py 24 | |-- site_settings/ 25 | | |-- __init__.py 26 | | |-- admin.py 27 | | |-- apps.py 28 | | |-- migrations/ 29 | | |-- models.py 30 | | |-- serializers.py 31 | | |-- urls.py 32 | | |-- views.py 33 | |-- chatbot/ 34 | | |-- __init__.py 35 | | |-- admin.py 36 | | |-- apps.py 37 | | |-- migrations/ 38 | | |-- models.py 39 | | |-- serializers.py 40 | | |-- urls.py 41 | | |-- views.py 42 | | |-- tasks.py 43 | |-- training_model/ 44 | | |-- __init__.py 45 | | |-- admin.py 46 | | |-- apps.py 47 | | |-- migrations/ 48 | | |-- models.py 49 | | |-- tests.py 50 | | |-- views.py 51 | -------------------------------------------------------------------------------- /training_model/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from django.db import models 4 | 5 | 6 | def upload_to_faiss(instance, filename): 7 | basename, ext = os.path.splitext(filename) 8 | new_filename = f"{basename}_{uuid.uuid4().hex}{ext}" 9 | return f'documents/faiss/{new_filename}' 10 | 11 | 12 | def upload_to_pinecone(instance, filename): 13 | basename, ext = os.path.splitext(filename) 14 | 
new_filename = f"{basename}_{uuid.uuid4().hex}{ext}" 15 | return f'documents/pinecone/{new_filename}' 16 | 17 | 18 | def dynamic_upload_to(instance, filename): 19 | if instance.storage_type == 'FAISS': 20 | return upload_to_faiss(instance, filename) 21 | else: 22 | return upload_to_pinecone(instance, filename) 23 | 24 | 25 | class Document(models.Model): 26 | CHOICES = ( 27 | ('FAISS', 'FAISS'), 28 | ('PINECONE', 'PINECONE') 29 | ) 30 | 31 | file = models.FileField(upload_to=dynamic_upload_to) 32 | index_name = models.CharField(max_length=255) 33 | storage_type = models.CharField(max_length=255, choices=CHOICES) 34 | is_trained = models.BooleanField(default=False) 35 | uploaded_at = models.DateTimeField(auto_now_add=True) 36 | 37 | def __str__(self): 38 | return self.file.name 39 | 40 | def file_name(self): 41 | return os.path.basename(self.file.name) 42 | -------------------------------------------------------------------------------- /site_settings/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from .models import SiteSetting, Language, Ad, PineconeIndex, DefaultSettings 3 | 4 | 5 | @admin.register(Language) 6 | class LanguageAdmin(admin.ModelAdmin): 7 | """ 8 | Admin site configuration for Language model. 9 | """ 10 | list_display = ('name', 'code',) 11 | search_fields = ('name', 'code',) 12 | 13 | 14 | @admin.register(SiteSetting) 15 | class SiteSettingAdmin(admin.ModelAdmin): 16 | """ 17 | Admin site configuration for SiteSetting model. 18 | """ 19 | list_display = ('title', 'logo', 'prompt',) 20 | search_fields = ('title',) 21 | 22 | 23 | @admin.register(Ad) 24 | class AdAdmin(admin.ModelAdmin): 25 | """ 26 | Admin site configuration for Ad model. 
27 | """ 28 | list_display = ('title', 'description', 'image',) 29 | search_fields = ('title', 'description',) 30 | 31 | 32 | @admin.register(PineconeIndex) 33 | class PineconeIndexAdmin(admin.ModelAdmin): 34 | """ 35 | Admin site configuration for PineconeIndex model. 36 | """ 37 | list_display = ('name', 'index_id',) 38 | search_fields = ('name', 'index_id',) 39 | 40 | 41 | @admin.register(DefaultSettings) 42 | class DefaultSettingsAdmin(admin.ModelAdmin): 43 | """ 44 | Admin site configuration for DefaultSettings model. 45 | """ 46 | list_display = ('language', 'site_setting', 'ad',) 47 | search_fields = ('language', 'site_setting', 'ad',) 48 | -------------------------------------------------------------------------------- /chatbot/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path 2 | from . import views 3 | 4 | app_name = 'chatbot' 5 | 6 | urlpatterns = [ 7 | # List and create conversations 8 | path('conversations/', views.ConversationListCreate.as_view(), name='conversation-list-create'), 9 | # Retrieve, update, and delete a specific conversation 10 | # path('conversations//', views.ConversationDetail.as_view(), name='conversation-detail'), 11 | # Favourite a conversation 12 | path('conversations//favourite/', views.ConversationFavourite.as_view(), name='conversation-favourite'), 13 | # Archive a conversation 14 | path('conversations//archive/', views.ConversationArchive.as_view(), name='conversation-archive'), 15 | 16 | # Delete a conversation 17 | path('conversations//delete/', views.ConversationDelete.as_view(), name='conversation-delete'), 18 | 19 | # update title 20 | path('conversations//title/', views.ConversationRetrieveUpdateView.as_view(), name='conversation-title'), 21 | 22 | # List messages in a conversation 23 | path('conversations//messages/', views.MessageList.as_view(), name='message-list'), 24 | 25 | # Create a message in a conversation 26 | 
path('conversations//messages/create/', views.MessageCreate.as_view(), name='message-create'), 27 | # async gpt task 28 | # path('conversations/task//', views.GPT3TaskStatus.as_view(), name='gpt_task_status'), 29 | ] 30 | -------------------------------------------------------------------------------- /config/urls.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from django.urls import path, include 3 | from rest_framework import permissions 4 | from drf_yasg.views import get_schema_view 5 | from drf_yasg import openapi 6 | from django.conf import settings 7 | 8 | schema_view = get_schema_view( 9 | openapi.Info( 10 | title="Customized ChatGPT API", 11 | default_version='v1', 12 | description="API documentation for the Customized ChatGPT project", 13 | ), 14 | public=True, 15 | permission_classes=[permissions.AllowAny, ], 16 | ) 17 | 18 | urlpatterns = [ 19 | path('admin/', admin.site.urls), 20 | 21 | # Training model app URLs 22 | path('training-model/', include('training_model.urls')), 23 | 24 | # OAuth2 provider URLs 25 | path('oauth2/', include('oauth2_provider.urls', namespace='oauth2_provider')), 26 | 27 | # Social auth URLs 28 | path('social-auth/', include('social_django.urls', namespace='social')), 29 | 30 | # Rest framework URLs 31 | path('api-auth/', include('rest_framework.urls')), 32 | 33 | # Users app URLs 34 | path('api/v1/users/', include('users.urls')), 35 | 36 | # Site settings app URLs 37 | path('api/v1/site-settings/', include('site_settings.urls')), 38 | 39 | # Chatbot app URLs 40 | path('api/v1/chatbot/', include('chatbot.urls')), 41 | 42 | # Swagger URLs 43 | path('', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'), 44 | path('redoc/', schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'), 45 | ] 46 | 47 | admin.site.site_header = settings.ADMIN_SITE_HEADER 48 | admin.site.site_title = settings.ADMIN_SITE_TITLE 49 | 
admin.site.index_title = settings.ADMIN_SITE_INDEX
50 |
--------------------------------------------------------------------------------
/chatbot/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 | from django.conf import settings
3 | import secrets
4 |
5 |
6 | def generate_secure_random_id():
7 |     min_value = 10 ** 10  # Minimum value of the range (inclusive)
8 |     max_value = 10 ** 11 - 1  # Maximum value of the range (exclusive)
9 |     return secrets.randbelow(max_value - min_value) + min_value
10 |
11 |
12 | class Conversation(models.Model):
13 |     """
14 |     Conversation model representing a chat conversation.
15 |     """
16 |     STATUS_CHOICES = [
17 |         ('active', 'Active'),
18 |         ('archived', 'Archived'),
19 |         ('ended', 'Ended'),
20 |     ]
21 |
22 |     id = models.BigIntegerField(primary_key=True, default=generate_secure_random_id, editable=False)
23 |     title = models.CharField(max_length=255, default="Empty")
24 |     user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
25 |     created_at = models.DateTimeField(auto_now_add=True)
26 |     updated_at = models.DateTimeField(auto_now=True)
27 |     favourite = models.BooleanField(default=False)
28 |     archive = models.BooleanField(default=False)
29 |
30 |     # status = models.CharField(max_length=10, choices=STATUS_CHOICES, default='active')
31 |
32 |     class Meta:
33 |         ordering = ['created_at']
34 |
35 |     def __str__(self):
36 |         return f"Conversation {self.title} - {self.user.username}"
37 |
38 |
39 | class Message(models.Model):
40 |     """
41 |     Message model representing a message within a conversation.
42 |     """
43 |     conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE)
44 |     content = models.TextField()
45 |     created_at = models.DateTimeField(auto_now_add=True)
46 |     is_from_user = models.BooleanField(default=True)
47 |     in_reply_to = models.ForeignKey('self', null=True, blank=True, on_delete=models.SET_NULL, related_name='replies')
48 |
49 |     class Meta:
50 |         ordering = ['-created_at']
51 |
52 |     def __str__(self):
53 |         return f"Message {self.id} - {self.conversation}"
54 |
--------------------------------------------------------------------------------
/training_model/views.py:
--------------------------------------------------------------------------------
1 | from django.http import HttpResponseForbidden, HttpResponseRedirect
2 | from django.views import View
3 | from django.urls import reverse
4 | from django.contrib import messages
5 | from django.shortcuts import get_object_or_404
6 | import requests
7 | import tempfile
8 | import os
9 | from urllib.parse import urlsplit
10 | from django.conf import settings
11 | from django.contrib.auth import get_user_model
12 | from .models import Document
13 | from .pinecone_helpers import build_or_update_pinecone_index
14 |
15 | User = get_user_model()
16 |
17 | # Seconds to wait for the document download before giving up.
18 | DOWNLOAD_TIMEOUT = 60
19 |
20 |
21 | class TrainView(View):
22 |     """
23 |     View to train a Pinecone index from a Document's uploaded file.
24 |
25 |     Downloads the file into a temporary directory, passes it to
26 |     build_or_update_pinecone_index, marks the document as trained and
27 |     redirects back to the document's admin change page.
28 |     """
29 |
30 |     def get(self, request, object_id):
31 |         # Check if user is staff or superuser
32 |         if not request.user.is_staff and not request.user.is_superuser:
33 |             return HttpResponseForbidden("You don't have permission to access this page.")
34 |
35 |         # Unknown ids get a 404 instead of an unhandled DoesNotExist (HTTP 500).
36 |         document = get_object_or_404(Document, pk=object_id)
37 |         index_name = settings.PINECONE_INDEX_NAME
38 |         # namespace = User.objects.get(pk=request.user.id).username
39 |         namespace = settings.PINECONE_NAMESPACE_NAME
40 |         admin_url = reverse('admin:training_model_document_change', args=[object_id])
41 |
42 |         # NOTE(review): requests needs an absolute URL; if the storage backend returns a
43 |         # relative one (e.g. local file storage) the download fails and is reported below.
44 |         file_url = document.file.url
45 |         try:
46 |             response = requests.get(file_url, timeout=DOWNLOAD_TIMEOUT)
47 |             # Otherwise an HTTP error page would be saved and indexed as the document.
48 |             response.raise_for_status()
49 |         except requests.RequestException as e:
50 |             messages.error(request, f"Could not download the document: {e}")
51 |             return HttpResponseRedirect(admin_url)
52 |
53 |         # TemporaryDirectory removes the file and directory even if training raises.
54 |         with tempfile.TemporaryDirectory() as temp_dir:
55 |             # Use only the URL path so a query string cannot end up in the file
56 |             # name and defeat extension-based file-type detection.
57 |             file_path = os.path.join(temp_dir, os.path.basename(urlsplit(file_url).path))
58 |
59 |             with open(file_path, 'wb') as f:
60 |                 f.write(response.content)
61 |
62 |             # Load and process files
63 |             try:
64 |                 build_or_update_pinecone_index(file_path, index_name, namespace)
65 |             except ValueError as e:
66 |                 # The document loader raises ValueError for unsupported file types.
67 |                 messages.error(request, f"Training failed: {e}")
68 |                 return HttpResponseRedirect(admin_url)
69 |
70 |         # Update is_trained to True
71 |         document.is_trained = True
72 |         document.save()
73 |
74 |         # Redirect to Django admin with a success message
75 |         messages.success(request, "Training complete.")
76 |         return HttpResponseRedirect(admin_url)
77 |
--------------------------------------------------------------------------------
/site_settings/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
4 | class Language(models.Model):
5 |     """
6 |     Language model representing available languages.
7 |     """
8 |     name = models.CharField(max_length=100, unique=True)
9 |     code = models.CharField(max_length=10, unique=True)
10 |
11 |     class Meta:
12 |         verbose_name_plural = "Languages"
13 |
14 |     def __str__(self):
15 |         return self.name
16 |
17 |
18 | class SiteSetting(models.Model):
19 |     """
20 |     SiteSetting model representing website settings.
21 |     """
22 |     title = models.CharField(max_length=200)
23 |     logo = models.ImageField(upload_to='site_logo/')
24 |     prompt = models.TextField(null=True, blank=True)
25 |
26 |     class Meta:
27 |         verbose_name_plural = "Site Settings"
28 |
29 |     def __str__(self):
30 |         return self.title
31 |
32 |
33 | class Ad(models.Model):
34 |     """
35 |     Ad model representing advertisements.
36 |     """
37 |     title = models.CharField(max_length=200)
38 |     description = models.TextField()
39 |     image = models.ImageField(upload_to='ads/')
40 |
41 |     class Meta:
42 |         verbose_name_plural = "Ads"
43 |
44 |     def __str__(self):
45 |         return self.title
46 |
47 |
48 | class PineconeIndex(models.Model):
49 |     """
50 |     PineconeIndex model representing pinecone indexes.
51 |     """
52 |     name = models.CharField(max_length=200)
53 |     index_id = models.CharField(max_length=200)
54 |
55 |     class Meta:
56 |         verbose_name_plural = "Pinecone Indexes"
57 |
58 |     def __str__(self):
59 |         return self.name
60 |
61 |
62 | class DefaultSettings(models.Model):
63 |     """
64 |     DefaultSettings model representing default settings.
65 |     """
66 |     language = models.ForeignKey(Language, on_delete=models.CASCADE)
67 |     site_setting = models.ForeignKey(SiteSetting, on_delete=models.CASCADE)
68 |     ad = models.ForeignKey(Ad, on_delete=models.CASCADE)
69 |
70 |     class Meta:
71 |         verbose_name_plural = "Default Settings"
72 |
73 |     def __str__(self):
74 |         return "Default Settings"
75 |
--------------------------------------------------------------------------------
/example.env:
--------------------------------------------------------------------------------
1 | ENV_NAME='local'
2 |
3 | # TEMPORARY SECRET - DO NOT USE THIS
4 | DJANGO_SETTINGS_MODULE=config.settings.local
5 | DJANGO_SECRET_KEY='django-insecure-z#p-sa90q*k(n9xh632k@rgo=sd3@589=(%l8xpn=c-+yr&q=0'
6 | DJANGO_DEBUG=True
7 |
8 | # Database
9 | DB_NAME=database_name
10 | DB_USER=database_username
11 | DB_PASSWORD=database_user_password
12 | DB_HOST=database_host
13 | DB_PORT=database_port
14 |
15 | # Celery
16 | CELERY_BROKER_URL = 'redis://localhost:6379/0'
17 | CELERY_RESULT_BACKEND = 'redis://localhost:6379/0'
18 |
19 | # Email
20 | EMAIL_BACKEND=django.core.mail.backends.smtp.EmailBackend
21 | EMAIL_HOST=''
22 | EMAIL_PORT=587
23 | EMAIL_FROM=""
24 |
25 | SITE_URL=http://localhost:8000
26 | SITE_ID=1
27 |
28 |
29 | # CORS
30 | CSRF_COOKIE_SECURE=True
31 | SESSION_COOKIE_SECURE=True
32 | CSRF_COOKIE_HTTPONLY=False
33 | SESSION_COOKIE_HTTPONLY=True
34 | SESSION_COOKIE_SAMESITE="None"
35 | CSRF_COOKIE_SAMESITE="None"
36 | CORS_ALLOW_CREDENTIALS=True
37 | CORS_ORIGIN_ALLOW_ALL=False
38 | CSRF_COOKIE_NAME="csrftoken"
39 | CORS_ALLOWED_ORIGINS=http://127.0.0.1:3000,http://localhost:3000
40 | CSRF_TRUSTED_ORIGINS=http://127.0.0.1:3000,http://localhost:3000
41 |
42 | # Security
43 |
44 | X_FRAME_OPTIONS='DENY'
45 | SECURE_BROWSER_XSS_FILTER=True
46 |
47 | # GENERALS
48 | AUTH_USER_MODEL=users.CustomUser
49 | LANGUAGE_CODE="en-us"
50 | APPEND_SLASH=True
51 | TIME_ZONE='UTC'
52 | USE_I18N=True
53 | USE_TZ=True
54 | USE_L10N=True
55 |
56 | # Social
57 | FACEBOOK_KEY=''
58 | FACEBOOK_SECRET=''
59 | GOOGLE_KEY=''
60 | GOOGLE_SECRET=''
61 |
62 | # Other API
63 | OPEN_AI_KEY=''
64 | PINECONE_API_KEY=''
65 | PINECONE_ENVIRONMENT=''
66 | PINECONE_INDEX_NAME=''
67 | PINECONE_NAMESPACE_NAME=''
68 | SENTRY_DSN=''
69 |
70 | # AWS
71 | AWS_ACCESS_KEY=''
72 | AWS_SECRET_KEY=''
73 | REGION_NAME=''
74 | QUEUE_NAME=''
75 |
76 | DJANGO_AWS_STORAGE_BUCKET_NAME=''
77 |
78 |
79 | # Admin Site Config
80 | ADMIN_SITE_HEADER="Chatbot"
81 | ADMIN_SITE_TITLE="Chatbot Dashboard"
82 | ADMIN_SITE_INDEX="Chatbot Dashboard"
83 |
84 | # RASA
85 |
86 | RASA_API_URL = "http://localhost:5005/model/parse"
87 |
88 | # OAuth2 settings
89 | APPLICATION_NAME = "Chatbot"
90 |
91 |
92 | LOGIN_REDIRECT_URL = "/"
--------------------------------------------------------------------------------
/training_model/faiss_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import mimetypes
4 | from django.conf import settings
5 | from langchain.document_loaders import (
6 |     CSVLoader,
7 |     UnstructuredWordDocumentLoader,
8 |     PyPDFLoader,
9 | )
10 | from langchain.embeddings import OpenAIEmbeddings
11 | from langchain.vectorstores import FAISS as BaseFAISS
12 |
13 | OPENAI_API_KEY = settings.OPENAI_API_KEY
14 | BASE_DIR = settings.BASE_DIR
15 | MODELS_DIR = os.path.join(BASE_DIR, "models")
16 |
17 |
18 | class FAISS(BaseFAISS):
19 |     """
20 |     FAISS is a vector store that uses the FAISS library to store and search vectors.
21 |
22 |     Persisted with pickle: only load files this application wrote itself,
23 |     never untrusted input.
24 |     """
25 |
26 |     @classmethod
27 |     def load(cls, file_path):
28 |         with open(file_path, "rb") as f:
29 |             return pickle.load(f)
30 |
31 |     def save(self, file_path):
32 |         with open(file_path, "wb") as f:
33 |             pickle.dump(self, f)
34 |
35 |     def add_vectors(self, new_embeddings):
36 |         """
37 |         Add vectors to this store.
38 |
39 |         ``new_embeddings`` is either another FAISS vector store (what
40 |         build_or_update_faiss_index passes), which is merged in together with
41 |         its documents, or raw vectors accepted by the underlying faiss index.
42 |         """
43 |         if isinstance(new_embeddings, BaseFAISS):
44 |             # index.add() expects raw vectors, not a vector store; merge_from
45 |             # also copies the docstore so search results resolve to text.
46 |             self.merge_from(new_embeddings)
47 |         else:
48 |             # NOTE(review): raw vectors get no docstore entries, so search hits
49 |             # on them presumably cannot be mapped back to documents.
50 |             self.index.add(new_embeddings)
51 |
52 |
53 | def get_loader(file_path):
54 |     mime_type, _ = mimetypes.guess_type(file_path)
55 |
56 |     if mime_type == "application/pdf":
57 |         return PyPDFLoader(file_path)
58 |     elif mime_type == "text/csv":
59 |         return CSVLoader(file_path)
60 |     elif mime_type in [
61 |         "application/msword",
62 |         "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
63 |     ]:
64 |         return UnstructuredWordDocumentLoader(file_path)
65 |     else:
66 |         raise ValueError(f"Unsupported file type: {mime_type}")
67 |
68 |
69 | def build_or_update_faiss_index(file_path, index_name):
70 |     faiss_obj_path = os.path.join(MODELS_DIR, f"{index_name}.pickle")
71 |     embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
72 |     loader = get_loader(file_path)
73 |     pages = loader.load_and_split()
74 |
75 |     if os.path.exists(faiss_obj_path):
76 |         faiss_index = FAISS.load(faiss_obj_path)
77 |         new_embeddings = FAISS.from_documents(pages, embeddings, index_name=index_name)
78 |         faiss_index.add_vectors(new_embeddings)
79 |     else:
80 |         faiss_index = FAISS.from_documents(pages, embeddings, index_name=index_name)
81 |
82 |     faiss_index.save(faiss_obj_path)
83 |     return faiss_index
84 |
--------------------------------------------------------------------------------
/config/settings/prod.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .common import *  # noqa
3 |
4 | # Site
5 | # https://docs.djangoproject.com/en/2.0/ref/settings/#allowed-hosts
6 | ALLOWED_HOSTS = ["*"]  # NOTE(review): accepts any Host header; restrict to the real domain(s)
7 | INSTALLED_APPS += [
8 |     "gunicorn",
9 |     "storages",
10 | ]  # noqa
11 |
12 | # Postgres
13 | DATABASES = {
14 |     'default': {
15 |         'ENGINE': 'django.db.backends.postgresql',
16 |         'NAME': os.getenv('DB_NAME'),
17 |         'USER': os.getenv('DB_USER'),
18 |         'PASSWORD': os.getenv('DB_PASSWORD'),
19 |         'HOST': os.getenv('DB_HOST', 'localhost'),
20 |         'PORT': os.getenv('DB_PORT'),
21 |     }
22 | }
23 |
24 | # Static files (CSS, JavaScript, Images)
25 | # https://docs.djangoproject.com/en/2.0/howto/static-files/
26 | # http://django-storages.readthedocs.org/en/latest/index.html
27 | AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY')
28 | AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_KEY')
29 | AWS_S3_REGION_NAME = os.getenv('REGION_NAME')
30 | AWS_STORAGE_BUCKET_NAME = os.getenv('DJANGO_AWS_STORAGE_BUCKET_NAME')
31 |
32 | # By default files with the same name will overwrite each other.
33 | # Set this to False to have extra characters appended.
34 | AWS_S3_FILE_OVERWRITE = False
35 | AWS_DEFAULT_ACL = 'public-read'
36 | AWS_AUTO_CREATE_BUCKET = True
37 | AWS_QUERYSTRING_AUTH = False
38 |
39 | # Must set AWS_S3_CUSTOM_DOMAIN to get static files
40 | AWS_S3_CUSTOM_DOMAIN = "{}.s3.{}.amazonaws.com".format(AWS_STORAGE_BUCKET_NAME, AWS_S3_REGION_NAME)
41 |
42 | # AWS_LOCATION = 'static'
43 | # STATIC_URL = f'https://{AWS_S3_CUSTOM_DOMAIN}/{AWS_LOCATION}/'
44 |
45 | STATIC_URL = f"https://{AWS_S3_CUSTOM_DOMAIN}/"  # needs a scheme; a bare host is treated as a relative path
46 | STATICFILES_STORAGE = 'storages.backends.s3boto3.S3Boto3Storage'
47 |
48 | MEDIA_URL = f"https://{AWS_S3_CUSTOM_DOMAIN}/"
49 | DEFAULT_FILE_STORAGE = 'storages.backends.s3boto3.S3Boto3Storage'
50 |
51 | STATICFILES_DIRS = (os.path.join(BASE_DIR, 'static'),)
52 |
53 | # https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching#cache-control
54 | # Response can be cached by browser and any intermediary caches (i.e. it is "public") for up to 1 day
55 | # 86400 = (60 seconds x 60 minutes x 24 hours)
56 | AWS_S3_OBJECT_PARAMETERS = {
57 |     'CacheControl': 'max-age=86400',
58 | }
59 | # Social
60 | SOCIAL_AUTH_REDIRECT_IS_HTTPS = True
61 |
62 | # easy thumbnails lib & S3
63 | THUMBNAIL_DEFAULT_STORAGE = 'storages.backends.s3boto3.S3Boto3Storage'
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | .env
9 | migrations/
10 |
11 | # Distribution / packaging
12 | .idea/
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | env 112 | .venv 113 | env/ 114 | venv/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | /facerecognition.sqlite 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic AI Chatbot with Custom Training Sources 2 | ## Customizable-gpt-chatbot 3 | This project is a dynamic AI chatbot that can be trained from various sources, such as PDFs, documents, websites, and YouTube videos. It uses a user system with social authentication through Google, and the Django REST framework for its backend. The chatbot leverages OpenAI's GPT-3.5 language model to conduct conversations and is designed for scalability and ease of use. 
4 | 5 | ## Features 6 | - Train chatbot from multiple sources (PDFs, documents, websites, YouTube videos) 7 | - User system with social authentication through Google 8 | - Connect with OpenAI GPT-3.5 language model for conversation 9 | - Use Pinecone and FAISS for vector indexing 10 | - Employ OpenAI's text-embedding-ada-002 for text embedding 11 | - Python Langchain library for file processing and text conversion 12 | - Scalable architecture with separate settings for local, staging, and production environments 13 | - Dynamic site settings for title and prompt updates 14 | - Multilingual support 15 | - PostgreSQL database support 16 | - Celery task scheduler with Redis and AWS SQS options 17 | - AWS S3 bucket support for scalable hosting 18 | - Easy deployment on Heroku or AWS 19 | 20 | ## Technologies 21 | - Language: Python 22 | - Framework: Django REST Framework 23 | - Database: PostgreSQL 24 | 25 | ### Major Libraries: 26 | - Celery 27 | - Langchain 28 | - OpenAI 29 | - Pinecone 30 | - FAISS 31 | ## Requirements 32 | - Python 3.8 or above 33 | - Django 4.1 or above 34 | - Pinecone API Key 35 | - API key from OpenAI 36 | - Redis or AWS SQS 37 | - PostgreSQL database 38 | 39 | ## Future Scope 40 | - Integration with more third-party services for authentication 41 | - Support for additional file formats and media types for chatbot training 42 | - Improved context-awareness in conversations 43 | - Enhanced multilingual support with automatic language detection 44 | - Integration with popular messaging platforms and chat applications 45 | 46 | ## How to run 47 | - Clone the repository. 
`git clone https://github.com/shamspias/customizable-gpt-chatbot` 48 | - Install the required packages by running `pip install -r requirements.txt` 49 | - Start a Celery worker: `celery -A config worker --loglevel=info` 50 | - Create and apply the database migrations (the `migrations/` folders are git-ignored, so they are not in the repository): `python manage.py makemigrations` and `python manage.py migrate`, then run `python manage.py runserver` 51 | - Open `http://127.0.0.1:8000/` in your browser 52 | 53 | If you use AWS SQS you also need `pycurl`, which depends on system libraries. On Debian/Ubuntu-based Linux (on macOS, install the equivalent packages with your package manager), first install the Python development headers with `sudo apt install python3-dev -y`, then: 54 | 1. Make sure the libcurl development libraries are installed on your system. You can install them with: `sudo apt-get install libcurl4-openssl-dev gcc libssl-dev -y` 55 | 2. Make sure you have the latest versions of pip and setuptools: `pip install --upgrade pip setuptools` 56 | 3. `pip install pycurl` 57 | 58 | ## Deployment 59 | The chatbot can be deployed on Heroku or AWS by following the standard procedures for Django deployment on these platforms. 60 | 61 | ## Issues 62 | - If you don't use AWS SQS, you don't need to install the `pycurl` and `boto3` packages. 63 | - If you don't use AWS S3, you don't need to install the `django-storages` package. 64 | 65 | ## Note 66 | Make sure you have an OpenAI API key before running the project. 67 | 68 | This is a basic implementation of the project; you can always add more features and customizations to suit your requirements. 69 | 70 | Enjoy!
71 |
--------------------------------------------------------------------------------
/training_model/pinecone_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pinecone
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from urllib.parse import urldefrag, urljoin, urlsplit
7 | import mimetypes
8 | from django.conf import settings
9 | from langchain.document_loaders import (
10 |     CSVLoader,
11 |     UnstructuredWordDocumentLoader,
12 |     PyPDFLoader,
13 |     WebBaseLoader,
14 | )
15 | from langchain.embeddings import OpenAIEmbeddings
16 | from langchain.vectorstores import Pinecone
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | OPENAI_API_KEY = settings.OPENAI_API_KEY
21 | PINECONE_API_KEY = settings.PINECONE_API_KEY
22 | PINECONE_ENVIRONMENT = settings.PINECONE_ENVIRONMENT
23 | PINECONE_INDEX_NAME = settings.PINECONE_INDEX_NAME
24 | PINECONE_NAMESPACE_NAME = settings.PINECONE_NAMESPACE_NAME
25 | BASE_DIR = settings.BASE_DIR
26 | MODELS_DIR = os.path.join(BASE_DIR, "models")
27 |
28 | # Vector size for newly created indexes (text-embedding-ada-002 produces 1536-d vectors).
29 | EMBEDDING_DIMENSION = 1536
30 | # Seconds to wait for each page fetched while collecting links.
31 | REQUEST_TIMEOUT = 30
32 |
33 | embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
34 |
35 |
36 | class PineconeManager:
37 |     """
38 |     This class is used to manage the Pinecone Indexes.
39 |
40 |     Instantiating it calls pinecone.init with the given credentials.
41 |     """
42 |
43 |     def __init__(self, api_key, environment):
44 |         pinecone.init(
45 |             api_key=api_key,
46 |             environment=environment
47 |         )
48 |
49 |     def list_of_indexes(self):
50 |         """Return the result of pinecone.list_indexes(); failures are logged and re-raised."""
51 |         try:
52 |             pinecone_index_list = pinecone.list_indexes()
53 |         except Exception as e:
54 |             logger.error("Error in listing the Pinecone Indexes: %s", e)
55 |             # Chain the original error so its traceback is kept.
56 |             raise Exception(f"Error in listing the Pinecone Indexes: {e}") from e
57 |         logger.info("List of Pinecone Indexes: %s", pinecone_index_list)
58 |         return pinecone_index_list
59 |
60 |     def create_index(self, index_name, dimension, metric):
61 |         pinecone.create_index(name=index_name, dimension=dimension, metric=metric)
62 |
63 |     def delete_index(self, index_name):
64 |         pinecone.delete_index(index_name)
65 |
66 |
67 | class URLHandler:
68 |     """
69 |     This class is used to handle the URLs
70 |     """
71 |
72 |     @staticmethod
73 |     def is_valid_url(url):
74 |         """Return True only for absolute http(s) URLs (mailto:, ftp:, relative paths are rejected)."""
75 |         parsed_url = urlsplit(url)
76 |         return parsed_url.scheme in ("http", "https") and bool(parsed_url.netloc)
77 |
78 |     @staticmethod
79 |     def extract_links(url):
80 |         """Return the absolute http(s) URLs of all <a href> links on the page at ``url``."""
81 |         response = requests.get(url, timeout=REQUEST_TIMEOUT)
82 |         # Do not harvest links from an HTTP error page.
83 |         response.raise_for_status()
84 |         soup = BeautifulSoup(response.text, 'html.parser')
85 |
86 |         links = []
87 |         for link in soup.find_all('a'):
88 |             href = link.get('href')
89 |             if href:
90 |                 # Drop "#fragment" so in-page anchors do not repeat the same page.
91 |                 absolute_url = urldefrag(urljoin(url, href))[0]
92 |                 if URLHandler.is_valid_url(absolute_url):
93 |                     links.append(absolute_url)
94 |
95 |         return links
96 |
97 |     @staticmethod
98 |     def extract_links_from_websites(websites):
99 |         """Collect the links of every page in ``websites``, without duplicates, in first-seen order."""
100 |         all_links = []
101 |
102 |         for website in websites:
103 |             all_links.extend(URLHandler.extract_links(website))
104 |
105 |         # dict keeps insertion order, so this removes duplicates without reordering.
106 |         return list(dict.fromkeys(all_links))
107 |
108 |
109 | class DocumentLoaderFactory:
110 |     """
111 |     This class is used to load the documents
112 |     """
113 |
114 |     @staticmethod
115 |     def get_loader(file_path_or_url):
116 |         """
117 |         Return a LangChain document loader for a local file path or an http(s) URL.
118 |
119 |         A URL loads the page itself plus every page it links to. For files the
120 |         type is guessed from the extension; unsupported types raise ValueError.
121 |         """
122 |         if file_path_or_url.startswith(("http://", "https://")):
123 |             # Include the page itself, not only the pages it links to.
124 |             urls = [file_path_or_url] + URLHandler.extract_links_from_websites([file_path_or_url])
125 |             return WebBaseLoader(list(dict.fromkeys(urls)))
126 |
127 |         mime_type, _ = mimetypes.guess_type(file_path_or_url)
128 |
129 |         if mime_type == 'application/pdf':
130 |             return PyPDFLoader(file_path_or_url)
131 |         elif mime_type == 'text/csv':
132 |             return CSVLoader(file_path_or_url)
133 |         elif mime_type in ['application/msword',
134 |                            'application/vnd.openxmlformats-officedocument.wordprocessingml.document']:
135 |             return UnstructuredWordDocumentLoader(file_path_or_url)
136 |         else:
137 |             raise ValueError(f"Unsupported file type: {mime_type}")
138 |
139 |
140 | class PineconeIndexManager:
141 |     """
142 |     This class is used to manage the Pinecone Indexes
143 |     """
144 |
145 |     def __init__(self, pinecone_manager, index_name):
146 |         self.pinecone_manager = pinecone_manager
147 |         self.index_name = index_name
148 |
149 |     def index_exists(self):
150 |         active_indexes = self.pinecone_manager.list_of_indexes()
151 |         return self.index_name in active_indexes
152 |
153 |     def create_index(self, dimension, metric):
154 |         self.pinecone_manager.create_index(self.index_name, dimension, metric)
155 |
156 |     def delete_index(self):
157 |         self.pinecone_manager.delete_index(self.index_name)
158 |
159 |
160 | def build_or_update_pinecone_index(file_path, index_name, name_space):
161 |     """
162 |     Build or update the Pinecone index ``index_name`` from a file path or URL.
163 |
164 |     Creates the index (cosine metric) if it does not exist yet, then embeds the
165 |     loaded pages and upserts them. Returns the LangChain Pinecone store.
166 |
167 |     NOTE(review): ``name_space`` is not used; pages always go to
168 |     settings.PINECONE_NAMESPACE_NAME, the namespace chatbot.tasks.get_pinecone_index
169 |     also queries. Change both together.
170 |     """
171 |     pinecone_index_manager = PineconeIndexManager(PineconeManager(PINECONE_API_KEY, PINECONE_ENVIRONMENT), index_name)
172 |     loader = DocumentLoaderFactory.get_loader(file_path)
173 |     pages = loader.load_and_split()
174 |
175 |     if pinecone_index_manager.index_exists():
176 |         logger.info("Updating the model")
177 |     else:
178 |         logger.info("Training the model")
179 |         pinecone_index_manager.create_index(dimension=EMBEDDING_DIMENSION, metric="cosine")
180 |
181 |     # Upserting is the same for new and existing indexes.
182 |     return Pinecone.from_documents(documents=pages, embedding=embeddings,
183 |                                    index_name=pinecone_index_manager.index_name,
184 |                                    namespace=PINECONE_NAMESPACE_NAME)
185 |
--------------------------------------------------------------------------------
/chatbot/tasks.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import openai
4 | # from langchain.vectorstores import FAISS as BaseFAISS
5 | from training_model.pinecone_helpers import (
6 |     PineconeManager,
7 |     PineconeIndexManager,
8 |     embeddings,
9 | )
10 | from langchain.vectorstores import Pinecone
11 | from langchain.embeddings import OpenAIEmbeddings
12 | from langchain.chat_models import ChatOpenAI
13 | from langchain.schema import (
14 |     SystemMessage,
15 |     HumanMessage,
16 |     AIMessage,
17 | )
18 |
19 | from celery import shared_task
20 | from celery.utils.log import get_task_logger
21 | from django.conf import settings
22 |
23 | logger = get_task_logger(__name__)
24 |
25 | PINECONE_API_KEY = settings.PINECONE_API_KEY
26 | PINECONE_ENVIRONMENT = settings.PINECONE_ENVIRONMENT
27 | PINECONE_INDEX_NAME = settings.PINECONE_INDEX_NAME
28 | OPENAI_API_KEY = settings.OPENAI_API_KEY
29 |
30 | chat = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
31 | embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
32 |
33 |
34 | #
35 | # class FAISS(BaseFAISS):
36 | #     """
37 | #     FAISS is a vector store that uses the FAISS library to store and search vectors.
38 | #     """
39 | #
40 | #     def save(self, file_path):
41 | #         with open(file_path, "wb") as f:
42 | #             pickle.dump(self, f)
43 | #
44 | #     @staticmethod
45 | #     def load(file_path):
46 | #         with open(file_path, "rb") as f:
47 | #             return pickle.load(f)
48 |
49 | #
50 | # def get_faiss_index(index_name):
51 | #     faiss_obj_path = os.path.join(settings.BASE_DIR, "models", "{}.pickle".format(index_name))
52 | #
53 | #     if os.path.exists(faiss_obj_path):
54 | #         # Load the FAISS object from disk
55 | #         try:
56 | #             faiss_index = FAISS.load(faiss_obj_path)
57 | #             return faiss_index
58 | #         except Exception as e:
59 | #             logger.error(f"Failed to load FAISS index: {e}")
60 | #             return None
61 |
62 | def get_pinecone_index(index_name, name_space):
63 |     pinecone_manager = PineconeManager(PINECONE_API_KEY, PINECONE_ENVIRONMENT)
64 |     pinecone_index_manager = PineconeIndexManager(pinecone_manager, index_name)
65 |
66 |     try:
67 |         pinecone_index = Pinecone.from_existing_index(index_name=pinecone_index_manager.index_name,
68 |                                                       embedding=embeddings, namespace=settings.PINECONE_NAMESPACE_NAME)
69 |         # pinecone_index = Pinecone.from_existing_index(index_name=pinecone_index_manager.index_name,
70 |         #                                               namespace=name_space, embedding=embeddings)
71 |         return pinecone_index
72 |
73 |     except Exception as e:
74 |         logger.error(f"Failed to load Pinecone index: {e}")
75 |         return None
76 |
77 |
78 | @shared_task
79 | def send_gpt_request(message_list, name_space, system_prompt):
80 |     try:
81 |
82 |         # new_messages_list = []
83 |         # for msg in message_list:
84 |         #     if msg["role"] == "user":
85 |         #         new_messages_list.append(HumanMessage(content=msg["content"]))
86 |         #     else:
87 |         #         new_messages_list.append(AIMessage(content=msg["content"]))
88 |         # Load the FAISS index
89 |         # base_index = get_faiss_index("buffer_salaries")
90 |
91 |         # Load the Pinecone index
92 |         base_index = get_pinecone_index(PINECONE_INDEX_NAME, name_space)
93 |
94 |         if base_index:
95 |             # Add extra text to the content of the last message
96 |             last_message = message_list[-1]
97 |
98 |             query_text = last_message["content"]
99 |
100 |             # Get the most similar documents to the last message
101 |             try:
102 |                 docs = base_index.similarity_search(query=last_message["content"], k=2)
103 |
104 |                 updated_content = '"""'
105 |                 for doc in docs:
106 |                     updated_content += doc.page_content + "\n\n"
107 |                 updated_content += '"""\nQuestion:' + query_text
108 |             except Exception as e:
109 |                 logger.error(f"Failed to get similar documents: {e}")
110 |                 updated_content = query_text
111 |
112 |             # Create a new HumanMessage object with the updated content
113 |             # updated_message = HumanMessage(content=updated_content)
114 |             updated_message = {"role": "user", "content": updated_content}
115 |
116 |             # Replace the last message in message_list with the updated message
117 |             message_list[-1] = updated_message
118 |
119 |         openai.api_key = settings.OPENAI_API_KEY
120 |         # Send request to GPT-3 (replace with actual GPT-3 API call)
121 |         gpt3_response = openai.ChatCompletion.create(
122 |             model="gpt-3.5-turbo-16k",
123 |             messages=[
124 |                 {"role": "system",
125 |                  "content": f"{system_prompt}"},
126 |             ] + message_list
127 |         )
128 |
129 |         assistant_response = gpt3_response["choices"][0]["message"]["content"].strip()
130 |
131 |     except Exception as e:
132 |         logger.error(f"Failed to send request to GPT-3.5: {e}")
133 |         return "Sorry, I'm having trouble understanding you."
134 |     return assistant_response
135 |
136 |
137 | @shared_task
138 | def generate_title_request(message_list):
139 |     try:
140 |         openai.api_key = settings.OPENAI_API_KEY
141 |         # Send request to GPT-3 (replace with actual GPT-3 API call)
142 |         gpt3_response = openai.ChatCompletion.create(
143 |             model="gpt-3.5-turbo-16k",
144 |             messages=[
145 |                 {"role": "system",
146 |                  "content": "Summarize and make a very short meaningful title under 24 characters"},
147 |             ] + message_list
148 |         )
149 |         response = gpt3_response["choices"][0]["message"]["content"].strip()
150 |
151 |     except Exception as e:
152 |         logger.error(f"Failed to send request to GPT-3.5: {e}")
153 |         return "Problematic title with error."
154 |     return response
155 |
--------------------------------------------------------------------------------
/users/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import generics
2 | from rest_framework.permissions import IsAuthenticated
3 | from .serializers import UserRegistrationSerializer, UserProfileSerializer
4 | from .tasks import send_forgot_password_email
5 | from django.contrib.auth import get_user_model
6 | from django.shortcuts import redirect, get_object_or_404
7 | from social_django.utils import load_strategy
8 | from rest_framework import permissions
9 | from rest_framework import status
10 | from django.utils import timezone
11 | from django.contrib.auth import authenticate
12 | from rest_framework.response import Response
13 | from rest_framework.views import APIView
14 | from oauth2_provider.models import Application
15 | from oauthlib.common import generate_token
16 | from oauth2_provider.settings import oauth2_settings
17 | from oauth2_provider.models import AccessToken, RefreshToken
18 | from datetime import timedelta
19 | from django.contrib.auth.hashers import make_password
20 | from django.conf import settings
21 |
22 | User = get_user_model()
23 |
24 |
25 | class LoginView(APIView):
26 |     """
27 |     Authenticate by username (or email) + password and issue OAuth2 access/refresh tokens.
28 |     """
29 |     permission_classes = [permissions.AllowAny]
30 |
31 |     def post(self, request, *args, **kwargs):
32 |         username = request.data.get("username")
33 |         email = request.data.get("email")
34 |         password = request.data.get("password")
35 |         client_id = request.data.get("client_id")
36 |
37 |         if username:
38 |             if password is None or client_id is None:
39 |                 return Response({"error": "username, password and client_id are required"},
40 |                                 status=status.HTTP_400_BAD_REQUEST)
41 |             user = authenticate(request, username=username, password=password)
42 |
43 |         else:
44 |             if email is None or password is None or client_id is None:
45 |                 return Response({"error": "email, password and client_id are required"},
46 |                                 status=status.HTTP_400_BAD_REQUEST)
47 |
48 |             user = authenticate(request, username=email, password=password)
49 |
50 |         if user is None:
51 |             return Response({"error": "Invalid email or password"}, status=status.HTTP_401_UNAUTHORIZED)
52 |
53 |         try:
54 |             app = Application.objects.get(client_id=client_id)
55 |         except Application.DoesNotExist:
56 |             return Response({"error": "Invalid client_id"}, status=status.HTTP_401_UNAUTHORIZED)
57 |
58 |         # Generate tokens for the user
59 |         access_token = generate_token()
60 |         refresh_token = generate_token()
61 |         expires_in = timedelta(seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
62 |         expires = timezone.now() + expires_in
63 |
64 |         access_token_obj = AccessToken.objects.create(
65 |             user=user,
66 |             token=access_token,
67 |             application=app,
68 |             scope=" ".join(oauth2_settings.DEFAULT_SCOPES),  # scope is a space-separated string, not a list repr
69 |             expires=expires,
70 |         )
71 |
72 |         RefreshToken.objects.create(
73 |             user=user,
74 |             token=refresh_token,
75 |             application=app,
76 |             access_token=access_token_obj,  # reuse the created row instead of querying it back
77 |         )
78 |         context = {
79 |             "access_token": access_token,
80 |             "refresh_token": refresh_token,
81 |             # "expires_in": expires_in.total_seconds()
82 |         }
83 |
84 |         return Response(context, status=status.HTTP_200_OK)
85 |
86 |
87 | class GoogleLoginView(APIView):
88 |     """
89 |     View for Google login.
90 |     """
91 |     permission_classes = [permissions.AllowAny]
92 |
93 |     def post(self, request, *args, **kwargs):
94 |         user = request.user
95 |         if user.is_authenticated:
96 |             try:
97 |                 # Get the OAuth2 Application
98 |                 app = Application.objects.get(name="google")
99 |             except Application.DoesNotExist:
100 |                 return Response({"error": "OAuth2 Application not found."}, status=status.HTTP_404_NOT_FOUND)
101 |
102 |             try:
103 |                 access_token = app.accesstoken_set.filter(user=user).latest('expires')  # newest; .get() fails on several
104 |                 refresh_token = RefreshToken.objects.get(user=user, access_token=access_token)
105 |
106 |                 context = {
107 |                     "access_token": access_token.token,
108 |                     "refresh_token": refresh_token.token
109 |                     # "expires_in": access_token.expires
110 |                 }
111 |
112 |                 return Response(context, status=status.HTTP_200_OK)
113 |
114 |             except (AccessToken.DoesNotExist, RefreshToken.DoesNotExist):  # either token missing -> 404, not 500
115 |                 return Response({"error": "Access token not found for the user."}, status=status.HTTP_404_NOT_FOUND)
116 |
117 |         else:
118 |             return redirect(load_strategy().build_absolute_uri('/social-auth/login/google-oauth2/'))
119 |
120 |
121 | class UserRegistrationView(generics.CreateAPIView):
122 |     """
123 |     View for user registration.
124 | """ 125 | queryset = User.objects.all() 126 | serializer_class = UserRegistrationSerializer 127 | permission_classes = [permissions.AllowAny] 128 | 129 | def create(self, request, *args, **kwargs): 130 | email = request.data.get('email') 131 | if User.objects.filter(email=email).exists(): 132 | return Response({'error': 'Email address already exists'}, status=status.HTTP_400_BAD_REQUEST) 133 | 134 | password = request.data.get('password') 135 | request.data['password'] = make_password(password) 136 | 137 | serializer = self.get_serializer(data=request.data) 138 | serializer.is_valid(raise_exception=True) 139 | 140 | user = serializer.save() 141 | 142 | # Generate tokens for the user 143 | app = get_object_or_404(Application, name=settings.APPLICATION_NAME) 144 | access_token = generate_token() 145 | refresh_token = generate_token() 146 | 147 | AccessToken.objects.create( 148 | user=user, 149 | token=access_token, 150 | application=app, 151 | scope=oauth2_settings.DEFAULT_SCOPES, 152 | expires=timezone.now() + timedelta(seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) 153 | ) 154 | 155 | RefreshToken.objects.create( 156 | user=user, 157 | token=refresh_token, 158 | application=app, 159 | access_token=AccessToken.objects.get(token=access_token) 160 | ) 161 | 162 | tokens = { 163 | 'access_token': access_token, 164 | 'refresh_token': refresh_token 165 | } 166 | 167 | context = serializer.data.copy() 168 | context.update(tokens) 169 | 170 | return Response(context, status=status.HTTP_201_CREATED) 171 | 172 | 173 | class UserProfileView(generics.RetrieveUpdateAPIView): 174 | """ 175 | View for user profile retrieval and update. 
176 | """ 177 | serializer_class = UserProfileSerializer 178 | permission_classes = [IsAuthenticated] 179 | 180 | def get_object(self): 181 | return self.request.user 182 | 183 | def perform_update(self, serializer): 184 | user = serializer.save() 185 | if 'new_password' in self.request.data: 186 | new_password = self.request.data['new_password'] 187 | user.set_password(new_password) 188 | user.save() 189 | 190 | # Send an email to notify the user about the password change 191 | subject = 'Password Changed' 192 | message = 'Your password has been changed successfully.' 193 | recipient = user.email 194 | send_forgot_password_email.delay(subject, message, recipient) 195 | 196 | 197 | class LogoutView(APIView): 198 | """ 199 | View for user logout. 200 | """ 201 | 202 | def post(self, request, *args, **kwargs): 203 | token = request.auth 204 | if token: 205 | access_token = AccessToken.objects.filter(token=token) 206 | if access_token.exists(): 207 | access_token.delete() 208 | return Response({"detail": "Logout successful"}, status=status.HTTP_200_OK) 209 | 210 | return Response({"detail": "Invalid token"}, status=status.HTTP_400_BAD_REQUEST) 211 | -------------------------------------------------------------------------------- /chatbot/views.py: -------------------------------------------------------------------------------- 1 | from django.shortcuts import get_object_or_404 2 | from rest_framework import generics, status 3 | from rest_framework.views import APIView 4 | from rest_framework.response import Response 5 | from rest_framework.pagination import LimitOffsetPagination 6 | from celery.result import AsyncResult 7 | from django.core.exceptions import ObjectDoesNotExist 8 | from django.contrib.auth import get_user_model 9 | 10 | from .models import Conversation, Message 11 | from .serializers import ConversationSerializer, MessageSerializer 12 | from .tasks import send_gpt_request, generate_title_request 13 | 14 | User = get_user_model() 15 | 16 | 17 | class 
LastMessagesPagination(LimitOffsetPagination): 18 | """ 19 | Pagination class for last messages. 20 | """ 21 | default_limit = 10 22 | max_limit = 10 23 | 24 | 25 | # List and create conversations 26 | class ConversationListCreate(generics.ListCreateAPIView): 27 | """ 28 | List and create conversations. 29 | """ 30 | serializer_class = ConversationSerializer 31 | 32 | def get_queryset(self): 33 | return Conversation.objects.filter(user=self.request.user).order_by('created_at') 34 | 35 | def perform_create(self, serializer): 36 | serializer.save(user=self.request.user) 37 | 38 | 39 | # Retrieve, update, and delete a specific conversation 40 | class ConversationDetail(generics.RetrieveUpdateDestroyAPIView): 41 | """ 42 | Retrieve, update, and delete a specific conversation. 43 | """ 44 | serializer_class = ConversationSerializer 45 | 46 | def get_queryset(self): 47 | return Conversation.objects.filter(user=self.request.user) 48 | 49 | def delete(self, request, *args, **kwargs): 50 | conversation = self.get_object() 51 | if conversation.user != request.user: 52 | return Response(status=status.HTTP_403_FORBIDDEN) 53 | return super().delete(request, *args, **kwargs) 54 | 55 | 56 | # Archive a conversation 57 | class ConversationArchive(APIView): 58 | """ 59 | Archive a conversation. 60 | """ 61 | 62 | def patch(self, request, pk): 63 | conversation = get_object_or_404(Conversation, id=pk, user=request.user) 64 | if conversation.archive: 65 | conversation.archive = False 66 | conversation.save() 67 | return Response({"message": "remove from archive"}, status=status.HTTP_200_OK) 68 | else: 69 | conversation.archive = True 70 | conversation.save() 71 | return Response({"message": "add to archive"}, status=status.HTTP_200_OK) 72 | 73 | 74 | class ConversationFavourite(APIView): 75 | """ 76 | Favourite a conversation. 
77 | """ 78 | 79 | def patch(self, request, pk): 80 | conversation = get_object_or_404(Conversation, id=pk, user=request.user) 81 | if conversation.favourite: 82 | conversation.favourite = False 83 | conversation.save() 84 | return Response({"message": "remove from favourite"}, status=status.HTTP_200_OK) 85 | else: 86 | conversation.favourite = True 87 | conversation.save() 88 | return Response({"message": "add to favourite"}, status=status.HTTP_200_OK) 89 | 90 | 91 | # Delete a conversation 92 | class ConversationDelete(APIView): 93 | """ 94 | Delete a conversation. 95 | """ 96 | 97 | def delete(self, request, pk): 98 | conversation = get_object_or_404(Conversation, id=pk, user=request.user) 99 | conversation.delete() 100 | return Response({"message": "conversation deleted"}, status=status.HTTP_200_OK) 101 | 102 | 103 | # List messages in a conversation 104 | class MessageList(generics.ListAPIView): 105 | """ 106 | List messages in a conversation. 107 | """ 108 | serializer_class = MessageSerializer 109 | pagination_class = LastMessagesPagination 110 | 111 | def get_queryset(self): 112 | conversation = get_object_or_404(Conversation, id=self.kwargs['conversation_id'], user=self.request.user) 113 | return Message.objects.filter(conversation=conversation).select_related('conversation') 114 | 115 | 116 | # Create a message in a conversation 117 | class MessageCreate(generics.CreateAPIView): 118 | """ 119 | Create a message in a conversation. 
120 | """ 121 | serializer_class = MessageSerializer 122 | 123 | def perform_create(self, serializer): 124 | conversation = get_object_or_404(Conversation, id=self.kwargs['conversation_id'], user=self.request.user) 125 | serializer.save(conversation=conversation, is_from_user=True) 126 | 127 | # Retrieve the last 10 messages from the conversation 128 | messages = Message.objects.filter(conversation=conversation).order_by('-created_at')[:10][::-1] 129 | 130 | # Build the list of dictionaries containing the message data 131 | message_list = [] 132 | for msg in messages: 133 | if msg.is_from_user: 134 | message_list.append({"role": "user", "content": msg.content}) 135 | else: 136 | message_list.append({"role": "assistant", "content": msg.content}) 137 | 138 | name_space = User.objects.get(id=self.request.user.id).username 139 | 140 | from site_settings.models import SiteSetting 141 | # Get system prompt from site settings 142 | try: 143 | system_prompt_obj = SiteSetting.objects.first() 144 | system_prompt = system_prompt_obj.prompt 145 | except Exception as e: 146 | print(str(e)) 147 | system_prompt = "You are sonic you can do anything you want." 
148 | 149 | # Call the Celery task to get a response from GPT-3 150 | task = send_gpt_request.apply_async(args=(message_list, name_space, system_prompt)) 151 | print(message_list) 152 | response = task.get() 153 | return [response, conversation.id, messages[0].id] 154 | 155 | def create(self, request, *args, **kwargs): 156 | serializer = self.get_serializer(data=request.data) 157 | serializer.is_valid(raise_exception=True) 158 | response_list = self.perform_create(serializer) 159 | assistant_response = response_list[0] 160 | conversation_id = response_list[1] 161 | last_user_message_id = response_list[2] 162 | 163 | try: 164 | # Store GPT response as a message 165 | message = Message( 166 | conversation_id=conversation_id, 167 | content=assistant_response, 168 | is_from_user=False, 169 | in_reply_to_id=last_user_message_id 170 | ) 171 | message.save() 172 | 173 | except ObjectDoesNotExist: 174 | error = f"Conversation with id {conversation_id} does not exist" 175 | Response({"error": error}, status=status.HTTP_400_BAD_REQUEST) 176 | except Exception as e: 177 | error_mgs = str(e) 178 | error = f"Failed to save GPT-3 response as a message: {error_mgs}" 179 | Response({"error": error}, status=status.HTTP_400_BAD_REQUEST) 180 | 181 | headers = self.get_success_headers(serializer.data) 182 | return Response({"response": assistant_response}, status=status.HTTP_200_OK, headers=headers) 183 | 184 | 185 | class ConversationRetrieveUpdateView(generics.RetrieveUpdateAPIView): 186 | """ 187 | Retrieve View to update or get the title 188 | """ 189 | queryset = Conversation.objects.all() 190 | serializer_class = ConversationSerializer 191 | lookup_url_kwarg = 'conversation_id' 192 | 193 | def retrieve(self, request, *args, **kwargs): 194 | conversation = self.get_object() 195 | 196 | if conversation.title == "Empty": 197 | messages = Message.objects.filter(conversation=conversation) 198 | 199 | if messages.exists(): 200 | message_list = [] 201 | for msg in messages: 202 | if 
msg.is_from_user: 203 | message_list.append({"role": "user", "content": msg.content}) 204 | else: 205 | message_list.append({"role": "assistant", "content": msg.content}) 206 | 207 | task = generate_title_request.apply_async(args=(message_list,)) 208 | my_title = task.get() 209 | # if length of title is greater than 55, truncate it 210 | my_title = my_title[:30] 211 | conversation.title = my_title 212 | conversation.save() 213 | serializer = self.get_serializer(conversation) 214 | return Response(serializer.data) 215 | else: 216 | return Response({"message": "No messages in conversation."}, status=status.HTTP_204_NO_CONTENT) 217 | else: 218 | serializer = self.get_serializer(conversation) 219 | return Response(serializer.data) 220 | 221 | 222 | class GPT3TaskStatus(APIView): 223 | """ 224 | Check the status of a GPT task and return the result if it's ready. 225 | """ 226 | 227 | def get(self, request, task_id, *args, **kwargs): 228 | task = AsyncResult(task_id) 229 | 230 | if task.ready(): 231 | response = task.result 232 | return Response({"status": "READY", "response": response}) 233 | else: 234 | return Response({"status": "PENDING"}) 235 | -------------------------------------------------------------------------------- /config/settings/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from corsheaders.defaults import default_headers 4 | from urllib.parse import quote 5 | 6 | BASE_DIR = Path(__file__).resolve().parent.parent.parent 7 | 8 | DJANGO_APPS = [ 9 | 'django.contrib.admin', 10 | 'django.contrib.auth', 11 | 'django.contrib.contenttypes', 12 | 'django.contrib.sessions', 13 | 'django.contrib.messages', 14 | 'django.contrib.staticfiles', 15 | 'django.contrib.sites', 16 | ] 17 | 18 | THIRD_PARTY_APPS = [ 19 | 'rest_framework', 20 | 'django_filters', 21 | 'drf_yasg', # another way to swagger 22 | 'corsheaders', # Cross Origin 23 | 'django_celery_results', # Store Celery 
Result and cache 24 | 25 | # Social Authentication 26 | 'oauth2_provider', 27 | 'social_django', 28 | ] 29 | 30 | LOCAL_APPS = [ 31 | 'chatbot.apps.ChatbotConfig', 32 | 'users.apps.UsersConfig', 33 | 'site_settings.apps.SiteSettingsConfig', 34 | 'training_model.apps.TrainingModelConfig', 35 | 36 | ] 37 | 38 | INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS 39 | 40 | MIDDLEWARE = [ 41 | 'django.middleware.security.SecurityMiddleware', 42 | 'django.contrib.sessions.middleware.SessionMiddleware', 43 | 44 | 'django.middleware.csrf.CsrfViewMiddleware', 45 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 46 | 'django.contrib.messages.middleware.MessageMiddleware', 47 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 48 | 49 | # CROSS Origin 50 | 'corsheaders.middleware.CorsMiddleware', 51 | 52 | 'django.middleware.common.CommonMiddleware', 53 | ] 54 | 55 | # SECURITY WARNING: don't run with debug turned on in production! 56 | SITE_ID = int(os.getenv('SITE_ID', 1)) 57 | 58 | DEBUG = os.getenv('DJANGO_DEBUG', True) 59 | SITE_URL = os.getenv('SITE_URL', 'http://localhost:8000') 60 | 61 | SECRET_KEY = os.getenv('DJANGO_SECRET_KEY', 'django-insecure-p!1w7j+^j5v8y-@$_9j*8mr-)l#$u=08=c)!=(b1dleci18$7+') 62 | ROOT_URLCONF = 'config.urls' 63 | WSGI_APPLICATION = 'config.wsgi.application' 64 | ASGI_APPLICATION = 'config.asgi.application' 65 | 66 | TEMPLATES = [ 67 | { 68 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 69 | 'DIRS': [BASE_DIR / 'templates'] 70 | , 71 | 'APP_DIRS': True, 72 | 'OPTIONS': { 73 | 'context_processors': [ 74 | 'django.template.context_processors.debug', 75 | 'django.template.context_processors.request', 76 | 'django.contrib.auth.context_processors.auth', 77 | 'django.contrib.messages.context_processors.messages', 78 | ], 79 | }, 80 | }, 81 | ] 82 | 83 | # Password validation 84 | # https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators 85 | 86 | AUTH_PASSWORD_VALIDATORS = [ 87 | { 88 
| 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 89 | }, 90 | { 91 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 92 | }, 93 | { 94 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 95 | }, 96 | { 97 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 98 | }, 99 | ] 100 | 101 | # Email 102 | EMAIL_BACKEND = os.getenv('EMAIL_BACKEND', 'django.core.mail.backends.smtp.EmailBackend') 103 | EMAIL_HOST = os.getenv('EMAIL_HOST', 'localhost') 104 | EMAIL_PORT = os.getenv('EMAIL_PORT', 1025) 105 | EMAIL_FROM = os.getenv('EMAIL_FROM', 'noreply@somehost.local') 106 | 107 | EMAIL_USE_TLS = os.getenv('EMAIL_USE_TLS', True) 108 | EMAIL_HOST_USER = os.getenv('EMAIL_HOST_USER', 'shamspias0@gmail.com') 109 | EMAIL_HOST_PASSWORD = os.getenv('EMAIL_HOST_PASSWORD', 'password') 110 | 111 | ADMINS = () 112 | 113 | LOGIN_REDIRECT_URL = (os.getenv('LOGIN_REDIRECT_URL', '/')) 114 | # CORS 115 | 116 | CSRF_COOKIE_SECURE = bool(os.getenv('CSRF_COOKIE_SECURE', True)) 117 | SESSION_COOKIE_SECURE = bool(os.getenv('SESSION_COOKIE_SECURE', True)) 118 | 119 | # False since we will grab it via universal-cookies 120 | CSRF_COOKIE_HTTPONLY = bool(os.getenv('CSRF_COOKIE_HTTPONLY', False)) 121 | 122 | SESSION_COOKIE_HTTPONLY = bool(os.getenv('SESSION_COOKIE_HTTPONLY', True)) 123 | SESSION_COOKIE_SAMESITE = os.getenv('SESSION_COOKIE_SAMESITE', "None") 124 | CSRF_COOKIE_SAMESITE = os.getenv('CSRF_COOKIE_SAMESITE', "None") 125 | CORS_ALLOW_CREDENTIALS = bool(os.getenv('CORS_ALLOW_CREDENTIALS', True)) 126 | CORS_ORIGIN_ALLOW_ALL = bool(os.getenv('CORS_ORIGIN_ALLOW_ALL', True)) 127 | CSRF_COOKIE_NAME = os.getenv('CSRF_COOKIE_NAME', "csrftoken") 128 | CSRF_TRUSTED_ORIGIN = os.getenv('CSRF_TRUSTED_ORIGINS', "localhost:3000") 129 | CSRF_TRUSTED_ORIGIN = CSRF_TRUSTED_ORIGIN.split(',') 130 | CSRF_TRUSTED_ORIGINS = CSRF_TRUSTED_ORIGIN 131 | 132 | # 133 | # CORS_ALLOW_ORIGINS = 
os.getenv('CORS_ALLOWED_ORIGINS') 134 | # CORS_ALLOW_ORIGINS = CORS_ALLOW_ORIGINS.split(',') 135 | # CORS_ALLOWED_ORIGINS = CORS_ALLOW_ORIGINS 136 | # # 137 | # CORS_ALLOW_METHODS = ( 138 | # 'GET', 139 | # 'POST', 140 | # 'PUT', 141 | # 'PATCH', 142 | # 'DELETE', 143 | # 'OPTIONS' 144 | # ) 145 | # 146 | CORS_ALLOW_HEADERS = list(default_headers) + [ 147 | 'X-CSRFToken', 148 | ] 149 | CORS_EXPOSE_HEADERS = ['Content-Type', 'X-CSRFToken'] 150 | 151 | X_FRAME_OPTIONS = os.getenv('X_FRAME_OPTIONS', 'DENY') 152 | SECURE_BROWSER_XSS_FILTER = bool(os.getenv('SECURE_BROWSER_XSS_FILTER', True)) 153 | # GENERALS 154 | APPEND_SLASH = bool(os.getenv('APPEND_SLASH', True)) 155 | 156 | LANGUAGE_CODE = os.getenv('LANGUAGE_CODE', 'en-us') 157 | 158 | TIME_ZONE = os.getenv('TIME_ZONE', 'UTC') 159 | USE_I18N = bool(os.getenv('USE_I18N', True)) 160 | USE_TZ = bool(os.getenv('USE_TZ', True)) 161 | USE_L10N = bool(os.getenv('USE_L10N', True)) 162 | 163 | # Headers 164 | USE_X_FORWARDED_HOST = True 165 | SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') 166 | 167 | # Logging 168 | LOGGING = { 169 | 'version': 1, 170 | 'disable_existing_loggers': False, 171 | 'formatters': { 172 | 'django.server': { 173 | '()': 'django.utils.log.ServerFormatter', 174 | 'format': '[%(server_time)s] %(message)s', 175 | }, 176 | 'verbose': {'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'}, 177 | 'simple': {'format': '%(levelname)s %(message)s'}, 178 | }, 179 | 'filters': { 180 | 'require_debug_true': { 181 | '()': 'django.utils.log.RequireDebugTrue', 182 | }, 183 | }, 184 | 'handlers': { 185 | 'django.server': { 186 | 'level': 'INFO', 187 | 'class': 'logging.StreamHandler', 188 | 'formatter': 'django.server', 189 | }, 190 | 'console': {'level': 'DEBUG', 'class': 'logging.StreamHandler', 'formatter': 'simple'}, 191 | 'mail_admins': {'level': 'ERROR', 'class': 'django.utils.log.AdminEmailHandler'}, 192 | }, 193 | 'loggers': { 194 | 'django': { 195 | 
'handlers': ['console'], 196 | 'propagate': True, 197 | }, 198 | 'django.server': { 199 | 'handlers': ['django.server'], 200 | 'level': 'INFO', 201 | 'propagate': False, 202 | }, 203 | 'django.request': { 204 | 'handlers': ['mail_admins', 'console'], 205 | 'level': 'ERROR', 206 | 'propagate': False, 207 | }, 208 | 'django.db.backends': {'handlers': ['console'], 'level': 'INFO'}, 209 | }, 210 | } 211 | 212 | # Custom user app 213 | AUTH_USER_MODEL = os.getenv('AUTH_USER_MODEL', 'users.User') 214 | 215 | AUTHENTICATION_BACKENDS = ( 216 | 'users.backends.EmailBackend', 217 | 'social_core.backends.google.GoogleOAuth2', 218 | 'django.contrib.auth.backends.ModelBackend', 219 | ) 220 | 221 | # Django Rest Framework 222 | REST_FRAMEWORK = { 223 | 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination', 224 | 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend', 225 | 'rest_framework.filters.OrderingFilter'], 226 | 'PAGE_SIZE': int(os.getenv('DJANGO_PAGINATION_LIMIT', 18)), 227 | 'DATETIME_FORMAT': '%Y-%m-%dT%H:%M:%S.%fZ', 228 | 'DEFAULT_RENDERER_CLASSES': ( 229 | 'rest_framework.renderers.JSONRenderer', 230 | 'rest_framework.renderers.BrowsableAPIRenderer', 231 | ), 232 | 'DEFAULT_PERMISSION_CLASSES': [ 233 | 'rest_framework.permissions.IsAuthenticated', 234 | ], 235 | 'DEFAULT_AUTHENTICATION_CLASSES': ( 236 | 'oauth2_provider.contrib.rest_framework.OAuth2Authentication', 237 | 'rest_framework.authentication.SessionAuthentication', 238 | 'rest_framework.authentication.BasicAuthentication' 239 | ), 240 | 'DEFAULT_THROTTLE_CLASSES': [ 241 | 'rest_framework.throttling.AnonRateThrottle', 242 | 'rest_framework.throttling.UserRateThrottle', 243 | 'rest_framework.throttling.ScopedRateThrottle', 244 | ], 245 | 'DEFAULT_THROTTLE_RATES': {'anon': '100/second', 'user': '1000/second', 'subscribe': '60/minute'}, 246 | 'TEST_REQUEST_DEFAULT_FORMAT': 'json', 247 | } 248 | 249 | DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' 250 | 251 
| # # AWS 252 | # AWS_ACCESS_KEY = quote(os.getenv('AWS_ACCESS_KEY'), safe='') 253 | # AWS_SECRET_KEY = quote(os.getenv('AWS_SECRET_KEY'), safe='') 254 | # REGION_NAME = quote(os.getenv('REGION_NAME'), safe='') 255 | # QUEUE_NAME = quote(os.getenv('QUEUE_NAME'), safe='') 256 | # 257 | # """ 258 | # AWS celery configuration 259 | # """ 260 | # 261 | # BROKER_URL = 'sqs://{access_key}:{secret_key}@'.format( 262 | # access_key=AWS_ACCESS_KEY, 263 | # secret_key=AWS_SECRET_KEY, 264 | # ) 265 | # # RESULT_BACKEND = '{}{}/{}celery'.format(BROKER_URL, REGION_NAME, QUEUE_NAME) 266 | # 267 | # BROKER_TRANSPORT_OPTIONS = { 268 | # 'region': REGION_NAME, 269 | # 'visibility_timeout': 60, # 1 minutes 270 | # # 'polling_interval': 5, # 5 seconds 271 | # # 'queue_name_prefix': QUEUE_NAME 272 | # } 273 | # 274 | # # CELERY namespaced 275 | # CELERY_BROKER_URL = BROKER_URL 276 | # CELERY_BROKER_TRANSPORT_OPTIONS = BROKER_TRANSPORT_OPTIONS 277 | # # CELERY_TASK_DEFAULT_QUEUE = QUEUE_NAME 278 | # 279 | # CELERY_ACCEPT_CONTENT = ['application/json'] 280 | # CELERY_TASK_SERIALIZER = 'json' 281 | # CELERY_RESULT_SERIALIZER = 'json' 282 | # CELERY_TIMEZONE = 'UTC' 283 | # CELERY_RESULT_BACKEND = 'django-db' # using django-celery-results 284 | # CELERY_CACHE_BACKEND = 'django-cache' 285 | # 286 | --------------------------------------------------------------------------------