├── Model Creation.ipynb
├── README.md
├── data.zip
└── main.py


/Model Creation.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "markdown",
  5 |    "metadata": {},
  6 |    "source": [
  7 |     "# 1. Import Libraries"
  8 |    ]
  9 |   },
 10 |   {
 11 |    "cell_type": "code",
 12 |    "execution_count": 2,
 13 |    "metadata": {},
 14 |    "outputs": [],
 15 |    "source": [
 16 |     "import random\n",
 17 |     "import copy\n",
 18 |     "import time\n",
 19 |     "import pandas as pd\n",
 20 |     "import numpy as np\n",
 21 |     "import gc\n",
 22 |     "import re\n",
 23 |     "import pickle\n",
 24 |     "from sklearn.model_selection import train_test_split\n",
 25 |     "from collections import Counter\n",
 26 |     "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer\n",
 27 |     "from sklearn.linear_model import LogisticRegression\n",
 28 |     "from sklearn.multiclass import OneVsRestClassifier\n",
 30 |     "from sklearn.metrics import accuracy_score\n",
 31 |     "import gensim\n",
 32 |     "from sklearn.metrics.pairwise import pairwise_distances_argmin\n",
 33 |     "import nltk\n",
 34 |     "# nltk.download('stopwords')\n",
 35 |     "from nltk.corpus import stopwords"
 36 |    ]
 37 |   },
 38 |   {
 39 |    "cell_type": "markdown",
 40 |    "metadata": {},
 41 |    "source": [
 42 |     "# 2. Read the Data"
 43 |    ]
 44 |   },
 45 |   {
 46 |    "cell_type": "code",
 47 |    "execution_count": 3,
 48 |    "metadata": {},
 49 |    "outputs": [],
 50 |    "source": [
 51 |     "dialogues = pd.read_csv(\"data/dialogues.tsv\", sep=\"\\t\")"
 52 |    ]
 53 |   },
 54 |   {
 55 |    "cell_type": "code",
 56 |    "execution_count": 4,
 57 |    "metadata": {},
 58 |    "outputs": [],
 59 |    "source": [
 60 |     "posts = pd.read_csv(\"data/tagged_posts.tsv\", sep=\"\\t\")"
 61 |    ]
 62 |   },
 63 |   {
 64 |    "cell_type": "code",
 65 |    "execution_count": 5,
 66 |    "metadata": {},
 67 |    "outputs": [
 68 |     {
 69 |      "data": {
 70 |       "text/html": [
 121 |        "<div>[HTML rendering of the DataFrame omitted; see the text/plain table below]</div>
" 122 | ], 123 | "text/plain": [ 124 | " text tag\n", 125 | "0 Okay -- you're gonna need to learn how to lie. dialogue\n", 126 | "1 I'm kidding. You know how sometimes you just ... dialogue\n", 127 | "2 Like my fear of wearing pastels? dialogue\n", 128 | "3 I figured you'd get to the good stuff eventually. dialogue\n", 129 | "4 Thank God! If I had to hear one more story ab... dialogue" 130 | ] 131 | }, 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "dialogues.head()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/html": [ 149 | "
\n", 150 | "\n", 163 | "\n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | "
post_idtitletag
09Calculate age in C#c#
116Filling a DataSet or DataTable from a LINQ que...c#
239Reliable timer in a console applicationc#
342Best way to allow plugins for a PHP applicationphp
459How do I get a distinct, ordered list of names...c#
\n", 205 | "
" 206 | ], 207 | "text/plain": [ 208 | " post_id title tag\n", 209 | "0 9 Calculate age in C# c#\n", 210 | "1 16 Filling a DataSet or DataTable from a LINQ que... c#\n", 211 | "2 39 Reliable timer in a console application c#\n", 212 | "3 42 Best way to allow plugins for a PHP application php\n", 213 | "4 59 How do I get a distinct, ordered list of names... c#" 214 | ] 215 | }, 216 | "execution_count": 6, 217 | "metadata": {}, 218 | "output_type": "execute_result" 219 | } 220 | ], 221 | "source": [ 222 | "posts.head()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 7, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "Num Posts: 2171575\n", 235 | "Num Dialogues: 218609\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "print(\"Num Posts:\",len(posts))\n", 241 | "print(\"Num Dialogues:\",len(dialogues))" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "# 3. Create training data for intent classifier - Chitchat/SO Question" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 8, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "texts = list(dialogues[:200000].text.values) + list(posts[:200000].title.values)\n", 258 | "labels = ['dialogue']*200000 + ['stackoverflow']*200000" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 9, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "data = pd.DataFrame({'text':texts,'target':labels})" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "def text_prepare(text):\n", 277 | " \"\"\"Performs tokenization and simple preprocessing.\"\"\"\n", 278 | " \n", 279 | " replace_by_space_re = re.compile('[/(){}\\[\\]\\|@,;]')\n", 280 | " bad_symbols_re = re.compile('[^0-9a-z #+_]')\n", 281 | " stopwords_set = set(stopwords.words('english'))\n", 282 | "\n", 283 | " text = text.lower()\n", 284 | " text = replace_by_space_re.sub(' ', text)\n", 285 | " text = bad_symbols_re.sub('', text)\n", 286 | " text = ' '.join([x for x in text.split() if x and x not in stopwords_set])\n", 287 | "\n", 288 | " return text.strip()" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 11, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "# Doing some data cleaning\n", 298 | "data['text'] = data['text'].apply(lambda x : text_prepare(x))" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 12, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "Train size = 360000, test size = 40000\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "X_train, X_test, y_train, y_test = train_test_split(data['text'],data['target'],test_size = .1 , random_state=0)\n", 316 | "\n", 317 | "print('Train size = {}, test size = {}'.format(len(X_train), len(X_test)))\n" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "# 4. 
Create Intent classifier" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 13, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "# We will keep our models and vectorizers in this folder" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 14, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "mkdir: resources: File exists\r\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "!mkdir resources" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 15, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "def tfidf_features(X_train, X_test, vectorizer_path):\n", 360 | " \"\"\"Performs TF-IDF transformation and dumps the model.\"\"\"\n", 361 | " tfv = TfidfVectorizer(dtype=np.float32, min_df=3, max_features=None, \n", 362 | " strip_accents='unicode', analyzer='word',token_pattern=r'\\w{1,}',\n", 363 | " ngram_range=(1, 3), use_idf=1,smooth_idf=1,sublinear_tf=1,\n", 364 | " stop_words = 'english')\n", 365 | " \n", 366 | " X_train = tfv.fit_transform(X_train)\n", 367 | " X_test = tfv.transform(X_test)\n", 368 | " \n", 369 | " pickle.dump(tfv,vectorizer_path)\n", 370 | " return X_train, X_test" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 16, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "X_train_tfidf, X_test_tfidf = tfidf_features(X_train, X_test, open(\"resources/tfidf.pkl\",'wb'))" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 17, 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "name": "stderr", 389 | "output_type": "stream", 390 | "text": [ 391 | "/miniconda3/envs/py36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. 
Specify a solver to silence this warning.\n", 392 | " FutureWarning)\n" 393 | ] 394 | }, 395 | { 396 | "data": { 397 | "text/plain": [ 398 | "LogisticRegression(C=10, class_weight=None, dual=False, fit_intercept=True,\n", 399 | " intercept_scaling=1, max_iter=100, multi_class='warn',\n", 400 | " n_jobs=None, penalty='l2', random_state=0, solver='warn',\n", 401 | " tol=0.0001, verbose=0, warm_start=False)" 402 | ] 403 | }, 404 | "execution_count": 17, 405 | "metadata": {}, 406 | "output_type": "execute_result" 407 | } 408 | ], 409 | "source": [ 410 | "intent_recognizer = LogisticRegression(C=10,random_state=0)\n", 411 | "intent_recognizer.fit(X_train_tfidf,y_train)\n" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 18, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "Test accuracy = 0.989825\n" 424 | ] 425 | } 426 | ], 427 | "source": [ 428 | "# Check test accuracy.\n", 429 | "y_test_pred = intent_recognizer.predict(X_test_tfidf)\n", 430 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n", 431 | "print('Test accuracy = {}'.format(test_accuracy))" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 19, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "pickle.dump(intent_recognizer, open(\"resources/intent_clf.pkl\" , 'wb'))" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "# 5 Create Programming Language classifier" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 20, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "X = posts['title'].values\n", 457 | "y = posts['tag'].values" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 21, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Train size = 1737260, test size = 434315\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n", 475 | "print('Train size = {}, test size = {}'.format(len(X_train), len(X_test)))" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 22, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "vectorizer = pickle.load(open(\"resources/tfidf.pkl\", 'rb'))\n", 485 | "X_train_tfidf, X_test_tfidf = vectorizer.transform(X_train), vectorizer.transform(X_test)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 24, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "name": "stderr", 495 | "output_type": "stream", 496 | "text": [ 497 | "/miniconda3/envs/py36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. 
Specify a solver to silence this warning.\n", 498 | " FutureWarning)\n" 499 | ] 500 | }, 501 | { 502 | "data": { 503 | "text/plain": [ 504 | "OneVsRestClassifier(estimator=LogisticRegression(C=5, class_weight=None, dual=False, fit_intercept=True,\n", 505 | " intercept_scaling=1, max_iter=100, multi_class='warn',\n", 506 | " n_jobs=None, penalty='l2', random_state=0, solver='warn',\n", 507 | " tol=0.0001, verbose=0, warm_start=False),\n", 508 | " n_jobs=None)" 509 | ] 510 | }, 511 | "execution_count": 24, 512 | "metadata": {}, 513 | "output_type": "execute_result" 514 | } 515 | ], 516 | "source": [ 517 | "tag_classifier = OneVsRestClassifier(LogisticRegression(C=5,random_state=0))\n", 518 | "tag_classifier.fit(X_train_tfidf,y_train)" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 25, 524 | "metadata": {}, 525 | "outputs": [ 526 | { 527 | "name": "stdout", 528 | "output_type": "stream", 529 | "text": [ 530 | "Test accuracy = 0.8043816124241622\n" 531 | ] 532 | } 533 | ], 534 | "source": [ 535 | "# Check test accuracy.\n", 536 | "y_test_pred = tag_classifier.predict(X_test_tfidf)\n", 537 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n", 538 | "print('Test accuracy = {}'.format(test_accuracy))" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 26, 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "pickle.dump(tag_classifier, open(\"resources/tag_clf.pkl\", 'wb'))" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": {}, 553 | "source": [ 554 | "# 6. Store Question database Embeddings" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": {}, 560 | "source": [ 561 | "You can use [pre-trained word vectors](https://code.google.com/archive/p/word2vec/) from Google." 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 27, 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "\n", 571 | "\n", 572 | "# Load Google's pre-trained Word2Vec model.\n", 573 | "model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True) " 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": {}, 579 | "source": [ 580 | "We want to convert every question to an embedding and store them. 
Whenever a user asks a Stack Overflow question, we use cosine similarity to retrieve the most similar stored question."
 581 |    ]
 582 |   },
 583 |   {
 584 |    "cell_type": "code",
 585 |    "execution_count": 28,
 586 |    "metadata": {},
 587 |    "outputs": [],
 588 |    "source": [
 589 |     "def question_to_vec(question, embeddings, dim=300):\n",
 590 |     "    \"\"\"\n",
 591 |     "        question: a string\n",
 592 |     "        embeddings: dict where the key is a word and the value is its embedding\n",
 593 |     "        dim: size of the representation\n",
 594 |     "\n",
 595 |     "        result: vector representation for the question\n",
 596 |     "    \"\"\"\n",
 597 |     "    word_tokens = question.split(\" \")\n",
 598 |     "    question_len = len(word_tokens)\n",
 599 |     "    question_mat = np.zeros((question_len,dim), dtype = np.float32)\n",
 600 |     "    \n",
 601 |     "    for idx, word in enumerate(word_tokens):\n",
 602 |     "        if word in embeddings:\n",
 603 |     "            question_mat[idx,:] = embeddings[word]\n",
 604 |     "            \n",
 605 |     "    # Remove zero rows, which stand for OOV words\n",
 606 |     "    question_mat = question_mat[~np.all(question_mat == 0, axis = 1)]\n",
 607 |     "    \n",
 608 |     "    # Average the remaining word vectors to get one vector for the whole question\n",
 609 |     "    if question_mat.shape[0] > 0:\n",
 610 |     "        vec = np.array(np.mean(question_mat, axis = 0), dtype = np.float32).reshape((1,dim))\n",
 611 |     "    else:\n",
 612 |     "        vec = np.zeros((1,dim), dtype = np.float32)\n",
 613 |     "    \n",
 614 |     "    return vec"
 615 |    ]
 616 |   },
 617 |   {
 618 |    "cell_type": "code",
 619 |    "execution_count": 29,
 620 |    "metadata": {},
 621 |    "outputs": [],
 622 |    "source": [
 623 |     "counts_by_tag = posts.groupby(by=['tag'])[\"tag\"].count().reset_index(name = 'count').sort_values(['count'], ascending = False)"
 624 |    ]
 625 |   },
 626 |   {
 627 |    "cell_type": "code",
 628 |    "execution_count": 30,
 629 |    "metadata": {},
 630 |    "outputs": [
 631 |     {
 632 |      "name": "stdout",
 633 |      "output_type": "stream",
 634 |      "text": [
 635 |       "[('c#', 394451), ('java', 383456), ('javascript', 375867), ('php', 321752), ('c_cpp', 281300), ('python', 208607), ('ruby', 99930), ('r', 36359), ('vb', 35044), ('swift', 34809)]\n"
 636 |      ]
 637 |     }
 638 |    ],
 639 |    "source": [
 640 |     "counts_by_tag = list(zip(counts_by_tag['tag'],counts_by_tag['count']))\n",
 641 |     "print(counts_by_tag)"
 642 |    ]
 643 |   },
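  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Editor's sketch, not part of the original run.)* Before building the per-tag embedding files below, it is worth estimating their footprint: every title becomes a 300-dimensional float32 vector, so the dump costs roughly `n_titles * 300 * 4` bytes across all tag files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Editor's sketch: rough size of the embeddings that section 6 writes to disk.\n",
    "total_titles = sum(count for _, count in counts_by_tag)  # 2,171,575 titles\n",
    "print('%d titles -> about %.1f GB of float32 vectors' % (total_titles, total_titles * 300 * 4 / 1e9))"
   ]
  },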
 644 |   {
 645 |    "cell_type": "code",
 646 |    "execution_count": 31,
 647 |    "metadata": {},
 648 |    "outputs": [
 649 |     {
 650 |      "name": "stdout",
 651 |      "output_type": "stream",
 652 |      "text": [
 653 |       "mkdir: resources/embeddings_folder: File exists\r\n"
 654 |      ]
 655 |     }
 656 |    ],
 657 |    "source": [
 658 |     "!mkdir resources/embeddings_folder"
 659 |    ]
 660 |   },
 661 |   {
 662 |    "cell_type": "code",
 663 |    "execution_count": 32,
 664 |    "metadata": {},
 665 |    "outputs": [],
 666 |    "source": [
 667 |     "for tag, count in counts_by_tag:\n",
 668 |     "    tag_posts = posts[posts['tag'] == tag]\n",
 669 |     "    tag_post_ids = tag_posts['post_id'].values\n",
 670 |     "    tag_vectors = np.zeros((count, 300), dtype=np.float32)\n",
 671 |     "    for i, title in enumerate(tag_posts['title']):\n",
 672 |     "        tag_vectors[i, :] = question_to_vec(title, model, 300)\n",
 673 |     "    # Dump post ids and vectors to a file.\n",
 674 |     "    filename = 'resources/embeddings_folder/' + tag + '.pkl'\n",
 675 |     "    pickle.dump((tag_post_ids, tag_vectors), open(filename, 'wb'))"
 676 |    ]
 677 |   },
 678 |   {
 679 |    "cell_type": "markdown",
 680 |    "metadata": {},
 681 |    "source": [
 682 |     "# 7. Given a question and a tag, retrieve the most similar question's post_id\n"
 683 |    ]
 684 |   },
 685 |   {
 686 |    "cell_type": "code",
 687 |    "execution_count": 33,
 688 |    "metadata": {},
 689 |    "outputs": [],
 690 |    "source": [
 691 |     "\n",
 692 |     "def get_similar_question(question, tag):\n",
 693 |     "    # Get the path where all question embeddings are kept and load the post_ids and post_embeddings\n",
 694 |     "    embeddings_path = 'resources/embeddings_folder/' + tag + \".pkl\"\n",
 695 |     "    post_ids, post_embeddings = pickle.load(open(embeddings_path, 'rb'))\n",
 696 |     "    # Get the embedding for the question\n",
 697 |     "    question_vec = question_to_vec(question, model, 300)\n",
 698 |     "    # Find the index of the most similar post\n",
 699 |     "    best_post_index = pairwise_distances_argmin(question_vec,\n",
 700 |     "                                                post_embeddings)\n",
 701 |     "    # Return the best post id\n",
 702 |     "    return post_ids[best_post_index]"
 703 |    ]
 704 |   },
 705 |   {
 706 |    "cell_type": "code",
 707 |    "execution_count": 37,
 708 |    "metadata": {},
 709 |    "outputs": [
 710 |     {
 711 |      "data": {
 712 |       "text/plain": [
 713 |        "array([5947137])"
 714 |       ]
 715 |      },
 716 |      "execution_count": 37,
 717 |      "metadata": {},
 718 |      "output_type": "execute_result"
 719 |     }
 720 |    ],
 721 |    "source": [
 722 |     "get_similar_question(\"how to use list comprehension in python?\", 'python')"
 723 |    ]
 724 |   },
 725 |   {
 726 |    "cell_type": "markdown",
 727 |    "metadata": {},
 728 |    "source": [
 729 |     "You can find this question at:\n",
 730 |     " \n",
 731 |     "https://stackoverflow.com/questions/8278287"
 732 |    ]
 733 |   }
 734 |  ],
 735 |  "metadata": {
 736 |   "kernelspec": {
 737 |    "display_name": "Python 3",
 738 |    "language": "python",
 739 |    "name": "python3"
 740 |   }
 741 |  },
 742 |  "nbformat": 4,
 743 |  "nbformat_minor": 2
 744 | }
 745 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | This is the repository for the post: https://mlwhiz.com/blog/2019/04/15/chatbot/
 2 | 
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLWhiz/chatbot/52145c66af6c6f0226dadcc3790c27e3fb5184f2/data.zip
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import requests
  4 | import time
  6 | import os
  7 | import json
  8 | from requests.compat import urljoin
  9 | import gensim
 10 | import pickle
 11 | import re
 12 | import nltk
 13 | from nltk.corpus import stopwords
 14 | import numpy as np
 15 | from sklearn.metrics.pairwise import pairwise_distances_argmin
 16 | 
 17 | class BotHandler(object):
 18 |     """
 19 |         BotHandler implements the back end of the bot.
 20 |         It has three main functions:
 21 |             'get_updates'  - checks for new messages
 22 |             'send_message' - posts a new message to the user
 23 |             'get_answer'   - computes the most relevant answer to a user's question
 24 |     """
 25 | 
 26 |     def __init__(self, token, dialogue_manager):
 27 |         # The Telegram access token is passed in from main()
 28 |         self.token = token
 29 |         self.api_url = "https://api.telegram.org/bot{}/".format(token)
 30 |         self.dialogue_manager = dialogue_manager
 31 | 
 32 |     def get_updates(self, offset=None, timeout=30):
 33 |         params = {"timeout": timeout, "offset": offset}
 34 |         raw_resp = requests.get(urljoin(self.api_url, "getUpdates"), params)
 35 |         try:
 36 |             resp = raw_resp.json()
 37 |         except json.decoder.JSONDecodeError as e:
 38 |             print("Failed to parse response {}: {}.".format(raw_resp.content, e))
 39 |             return []
 40 | 
 41 |         if "result" not in resp:
 42 |             return []
 43 |         return resp["result"]
 44 | 
 45 |     def send_message(self, chat_id, text):
 46 |         params = {"chat_id": chat_id, "text": text}
 47 |         return requests.post(urljoin(self.api_url, "sendMessage"), params)
 48 | 
 49 |     def get_answer(self, question):
 50 |         if question == '/start':
 51 |             return "Hi, I am your project bot. How can I help you today?"
 52 |         return self.dialogue_manager.generate_answer(question)
 53 | 
 54 | 
 55 | def is_unicode(text):
 56 |     return len(text) == len(text.encode())
 57 | 
 58 | 
 59 | 
 60 | # We will need this function to prepare text at prediction time
 61 | def text_prepare(text):
 62 |     """Performs tokenization and simple preprocessing."""
 63 | 
 64 |     replace_by_space_re = re.compile('[/(){}\[\]\|@,;]')
 65 |     bad_symbols_re = re.compile('[^0-9a-z #+_]')
 66 |     stopwords_set = set(stopwords.words('english'))
 67 | 
 68 |     text = text.lower()
 69 |     text = replace_by_space_re.sub(' ', text)
 70 |     text = bad_symbols_re.sub('', text)
 71 |     text = ' '.join([x for x in text.split() if x and x not in stopwords_set])
 72 | 
 73 |     return text.strip()
 74 | 
 75 | # Needed to convert questions asked by the user to vectors
 76 | def question_to_vec(question, embeddings, dim=300):
 77 |     """
 78 |         question: a string
 79 |         embeddings: dict where the key is a word and the value is its embedding
 80 |         dim: size of the representation
 81 | 
 82 |         result: vector representation for the question
 83 |     """
 84 |     word_tokens = question.split(" ")
 85 |     question_len = len(word_tokens)
 86 |     question_mat = np.zeros((question_len,dim), dtype = np.float32)
 87 | 
 88 |     for idx, word in enumerate(word_tokens):
 89 |         if word in embeddings:
 90 |             question_mat[idx,:] = embeddings[word]
 91 | 
 92 |     # Remove zero rows, which stand for OOV words
 93 |     question_mat = question_mat[~np.all(question_mat == 0, axis = 1)]
 94 | 
 95 |     # Average the remaining word vectors to get one vector for the whole question
 96 |     if question_mat.shape[0] > 0:
 97 |         vec = np.array(np.mean(question_mat, axis = 0), dtype = np.float32).reshape((1,dim))
 98 |     else:
 99 |         vec = np.zeros((1,dim), dtype = np.float32)
 100 | 
 101 |     return vec
 102 | 
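# --- Editor's sketch (not in the original script): a quick sanity check for
# question_to_vec. With a toy embedding dict, out-of-vocabulary words are
# dropped and the remaining word vectors are averaged, so the result is
# always a (1, dim) array:
#
#     toy_emb = {"python": np.full(300, 2.0, dtype=np.float32),
#                "list": np.full(300, 4.0, dtype=np.float32)}
#     v = question_to_vec("python list comprehension", toy_emb, 300)
#     v.shape         # (1, 300)
#     float(v[0, 0])  # 3.0 -- the mean of the two known words; "comprehension" is OOV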
 103 | class SimpleDialogueManager(object):
 104 |     """
 105 |     A simple dialogue manager to test the Telegram bot.
 106 |     The main logic of the bot lives here.
 107 |     """
 108 | 
 109 |     def __init__(self):
 110 | 
 111 |         # Instantiate all the models and the TF-IDF vectorizer.
 112 |         print("Loading resources...")
 113 |         # Instantiate a ChatterBot for chitchat-type questions
 114 |         from chatterbot import ChatBot
 115 |         from chatterbot.trainers import ChatterBotCorpusTrainer
 116 |         chatbot = ChatBot('MLWhizChatterbot')
 117 |         trainer = ChatterBotCorpusTrainer(chatbot)
 118 |         trainer.train('chatterbot.corpus.english')
 119 |         self.chitchat_bot = chatbot
 120 |         print("Loading Word2vec model...")
 121 |         # Load Google's pre-trained Word2Vec model.
 122 |         self.model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)
 123 |         print("Loading Classifier objects...")
 124 |         # Load the intent classifier and tag classifier
 125 |         self.intent_recognizer = pickle.load(open('resources/intent_clf.pkl', 'rb'))
 126 |         self.tag_classifier = pickle.load(open('resources/tag_clf.pkl', 'rb'))
 127 |         # Load the TFIDF vectorizer object
 128 |         self.tfidf_vectorizer = pickle.load(open('resources/tfidf.pkl', 'rb'))
 129 |         print("Finished Loading Resources")
 130 | 
 131 |     # Gets the most similar question's *post id* in the dataset, given the programming language of the question.
 132 |     def get_similar_question(self, question, tag):
 133 |         # Get the path where all question embeddings are kept and load the post_ids and post_embeddings
 134 |         embeddings_path = 'resources/embeddings_folder/' + tag + ".pkl"
 135 |         post_ids, post_embeddings = pickle.load(open(embeddings_path, 'rb'))
 136 |         # Get the embedding for the question
 137 |         question_vec = question_to_vec(question, self.model, 300)
 138 |         # Find the index of the most similar post
 139 |         best_post_index = pairwise_distances_argmin(question_vec,
 140 |                                                     post_embeddings)
 141 |         # Return the best post id
 142 |         return post_ids[best_post_index]
 143 | 
 144 |     def generate_answer(self, question):
 145 |         prepared_question = text_prepare(question)
 146 |         features = self.tfidf_vectorizer.transform([prepared_question])
 147 |         # Find the intent
 148 |         intent = self.intent_recognizer.predict(features)[0]
 149 |         # Chit-chat part:
 150 |         if intent == 'dialogue':
 151 |             response = self.chitchat_bot.get_response(question)
 152 |         # Stack Overflow question
 153 |         else:
 154 |             # Find the programming language
 155 |             tag = self.tag_classifier.predict(features)[0]
 156 |             # Find the most similar question's post id
 157 |             post_id = self.get_similar_question(question, tag)[0]
 158 |             # Respond with the predicted tag and a link to the most similar thread
 159 |             response = "I think it's about %s\nThis thread might help you: https://stackoverflow.com/questions/%s" % (tag, post_id)
 160 |         return response
 161 | 
 162 | def main():
 163 |     token = os.environ.get('TELEGRAM_TOKEN')  # read the bot token from the environment instead of hard-coding it
 164 |     simple_manager = SimpleDialogueManager()
 165 |     bot = BotHandler(token, simple_manager)
 166 |     ###############################################################
 167 | 
 168 |     print("Ready to talk!")
 169 |     offset = 0
 170 |     while True:
 171 |         updates = bot.get_updates(offset=offset)
 172 |         for update in updates:
 173 |             print("An update received.")
 174 |             if "message" in update:
 175 |                 chat_id = update["message"]["chat"]["id"]
 176 |                 if "text" in update["message"]:
 177 |                     text = update["message"]["text"]
 178 |                     if is_unicode(text):
 179 |                         print("Update content: {}".format(update))
 180 |                         bot.send_message(chat_id, bot.get_answer(update["message"]["text"]))
 181 |                     else:
 182 |                         bot.send_message(chat_id, "Hmm, you are sending some weird characters to me...")
 183 |             offset = max(offset, update['update_id'] + 1)
 184 |         time.sleep(1)
 185 | 
 186 | if __name__ == "__main__":
 187 |     main()
 188 | 
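# --- Editor's note: the notebook's markdown describes the retrieval step as
# cosine similarity, but pairwise_distances_argmin defaults to Euclidean
# distance. The two rankings only coincide when the vectors are length-
# normalized, which these mean-of-word2vec vectors are not. If you want true
# cosine ranking in get_similar_question, pass the metric explicitly:
#
#     best_post_index = pairwise_distances_argmin(question_vec,
#                                                 post_embeddings,
#                                                 metric='cosine')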
--------------------------------------------------------------------------------
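For readers who want to smoke-test the pipeline without registering a Telegram bot, here is a minimal offline sketch (an editor's addition, not a file in the repo; the file name and the example question are made up). It assumes the artifacts written by `Model Creation.ipynb` (`resources/tfidf.pkl`, `resources/intent_clf.pkl`, `resources/tag_clf.pkl`, and `resources/embeddings_folder/<tag>.pkl`) plus `GoogleNews-vectors-negative300.bin` sit next to `main.py`:

```python
# offline_test.py -- hypothetical smoke test; reuses helpers from main.py.
# Requires the NLTK stopwords: python -c "import nltk; nltk.download('stopwords')"
import pickle

import gensim
from sklearn.metrics.pairwise import pairwise_distances_argmin

from main import text_prepare, question_to_vec

vectorizer = pickle.load(open('resources/tfidf.pkl', 'rb'))
intent_clf = pickle.load(open('resources/intent_clf.pkl', 'rb'))
tag_clf = pickle.load(open('resources/tag_clf.pkl', 'rb'))
model = gensim.models.KeyedVectors.load_word2vec_format(
    'GoogleNews-vectors-negative300.bin', binary=True)

question = "How do I merge two dictionaries in python?"
features = vectorizer.transform([text_prepare(question)])

intent = intent_clf.predict(features)[0]  # 'dialogue' or 'stackoverflow'
if intent == 'stackoverflow':
    tag = tag_clf.predict(features)[0]    # e.g. 'python'
    post_ids, post_vecs = pickle.load(
        open('resources/embeddings_folder/%s.pkl' % tag, 'rb'))
    idx = pairwise_distances_argmin(question_to_vec(question, model, 300),
                                    post_vecs)[0]
    print('https://stackoverflow.com/questions/%s' % post_ids[idx])
else:
    print('Chitchat intent -- main.py would route this to ChatterBot.')
```

Note that the TF-IDF vectorizer was fit on prepared text, so `text_prepare` must be applied before `transform`, exactly as `generate_answer` does in `main.py`.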