├── Model Creation.ipynb
├── README.md
├── data.zip
└── main.py
/Model Creation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 1. Import Libraries"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 2,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import random\n",
17 | "import copy\n",
18 | "import time\n",
19 | "import pandas as pd\n",
20 | "import numpy as np\n",
21 | "import gc\n",
22 | "import re\n",
23 | "import pickle\n",
24 | "from sklearn.model_selection import train_test_split\n",
25 | "from collections import Counter\n",
26 | "from sklearn.feature_extraction.text import CountVectorizer,TfidfVectorizer, HashingVectorizer\n",
27 | "from sklearn.linear_model import LogisticRegression\n",
28 | "from sklearn.multiclass import OneVsRestClassifier\n",
29 | "from sklearn.linear_model import LogisticRegression\n",
30 | "from sklearn.metrics import accuracy_score\n",
31 | "import gensim\n",
32 | "from sklearn.metrics.pairwise import pairwise_distances_argmin\n",
33 | "import nltk\n",
34 | "#nltk.download('stopwords')\n",
35 | "from nltk.corpus import stopwords"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "# 2. Read the Data"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "dialogues = pd.read_csv(\"data/dialogues.tsv\",sep=\"\\t\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "posts = pd.read_csv(\"data/tagged_posts.tsv\",sep=\"\\t\")"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 5,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "data": {
70 | "text/html": [
71 | "
\n",
72 | "\n",
85 | "
\n",
86 | " \n",
87 | " \n",
88 | " | \n",
89 | " text | \n",
90 | " tag | \n",
91 | "
\n",
92 | " \n",
93 | " \n",
94 | " \n",
95 | " 0 | \n",
96 | " Okay -- you're gonna need to learn how to lie. | \n",
97 | " dialogue | \n",
98 | "
\n",
99 | " \n",
100 | " 1 | \n",
101 | " I'm kidding. You know how sometimes you just ... | \n",
102 | " dialogue | \n",
103 | "
\n",
104 | " \n",
105 | " 2 | \n",
106 | " Like my fear of wearing pastels? | \n",
107 | " dialogue | \n",
108 | "
\n",
109 | " \n",
110 | " 3 | \n",
111 | " I figured you'd get to the good stuff eventually. | \n",
112 | " dialogue | \n",
113 | "
\n",
114 | " \n",
115 | " 4 | \n",
116 | " Thank God! If I had to hear one more story ab... | \n",
117 | " dialogue | \n",
118 | "
\n",
119 | " \n",
120 | "
\n",
121 | "
"
122 | ],
123 | "text/plain": [
124 | " text tag\n",
125 | "0 Okay -- you're gonna need to learn how to lie. dialogue\n",
126 | "1 I'm kidding. You know how sometimes you just ... dialogue\n",
127 | "2 Like my fear of wearing pastels? dialogue\n",
128 | "3 I figured you'd get to the good stuff eventually. dialogue\n",
129 | "4 Thank God! If I had to hear one more story ab... dialogue"
130 | ]
131 | },
132 | "execution_count": 5,
133 | "metadata": {},
134 | "output_type": "execute_result"
135 | }
136 | ],
137 | "source": [
138 | "dialogues.head()"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "text/html": [
149 | "\n",
150 | "\n",
163 | "
\n",
164 | " \n",
165 | " \n",
166 | " | \n",
167 | " post_id | \n",
168 | " title | \n",
169 | " tag | \n",
170 | "
\n",
171 | " \n",
172 | " \n",
173 | " \n",
174 | " 0 | \n",
175 | " 9 | \n",
176 | " Calculate age in C# | \n",
177 | " c# | \n",
178 | "
\n",
179 | " \n",
180 | " 1 | \n",
181 | " 16 | \n",
182 | " Filling a DataSet or DataTable from a LINQ que... | \n",
183 | " c# | \n",
184 | "
\n",
185 | " \n",
186 | " 2 | \n",
187 | " 39 | \n",
188 | " Reliable timer in a console application | \n",
189 | " c# | \n",
190 | "
\n",
191 | " \n",
192 | " 3 | \n",
193 | " 42 | \n",
194 | " Best way to allow plugins for a PHP application | \n",
195 | " php | \n",
196 | "
\n",
197 | " \n",
198 | " 4 | \n",
199 | " 59 | \n",
200 | " How do I get a distinct, ordered list of names... | \n",
201 | " c# | \n",
202 | "
\n",
203 | " \n",
204 | "
\n",
205 | "
"
206 | ],
207 | "text/plain": [
208 | " post_id title tag\n",
209 | "0 9 Calculate age in C# c#\n",
210 | "1 16 Filling a DataSet or DataTable from a LINQ que... c#\n",
211 | "2 39 Reliable timer in a console application c#\n",
212 | "3 42 Best way to allow plugins for a PHP application php\n",
213 | "4 59 How do I get a distinct, ordered list of names... c#"
214 | ]
215 | },
216 | "execution_count": 6,
217 | "metadata": {},
218 | "output_type": "execute_result"
219 | }
220 | ],
221 | "source": [
222 | "posts.head()"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 7,
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "Num Posts: 2171575\n",
235 | "Num Dialogues: 218609\n"
236 | ]
237 | }
238 | ],
239 | "source": [
240 | "print(\"Num Posts:\",len(posts))\n",
241 | "print(\"Num Dialogues:\",len(dialogues))"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "# 3. Create training data for intent classifier - Chitchat/SO Question"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 8,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "texts = list(dialogues[:200000].text.values) + list(posts[:200000].title.values)\n",
258 | "labels = ['dialogue']*200000 + ['stackoverflow']*200000"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 9,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "data = pd.DataFrame({'text':texts,'target':labels})"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 10,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "def text_prepare(text):\n",
277 | " \"\"\"Performs tokenization and simple preprocessing.\"\"\"\n",
278 | " \n",
279 | " replace_by_space_re = re.compile('[/(){}\\[\\]\\|@,;]')\n",
280 | " bad_symbols_re = re.compile('[^0-9a-z #+_]')\n",
281 | " stopwords_set = set(stopwords.words('english'))\n",
282 | "\n",
283 | " text = text.lower()\n",
284 | " text = replace_by_space_re.sub(' ', text)\n",
285 | " text = bad_symbols_re.sub('', text)\n",
286 | " text = ' '.join([x for x in text.split() if x and x not in stopwords_set])\n",
287 | "\n",
288 | " return text.strip()"
289 | ]
290 | },
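{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check (editorial addition, hypothetical input): text_prepare lowercases,\n",
"# strips punctuation and drops NLTK English stopwords, so a title like the one below\n",
"# should reduce to something like 'sort list python'.\n",
"text_prepare(\"How do I sort a list in Python?\")"
]
},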
291 | {
292 | "cell_type": "code",
293 | "execution_count": 11,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "# Doing some data cleaning\n",
298 | "data['text'] = data['text'].apply(lambda x : text_prepare(x))"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 12,
304 | "metadata": {},
305 | "outputs": [
306 | {
307 | "name": "stdout",
308 | "output_type": "stream",
309 | "text": [
310 | "Train size = 360000, test size = 40000\n"
311 | ]
312 | }
313 | ],
314 | "source": [
315 | "X_train, X_test, y_train, y_test = train_test_split(data['text'],data['target'],test_size = .1 , random_state=0)\n",
316 | "\n",
317 | "print('Train size = {}, test size = {}'.format(len(X_train), len(X_test)))\n"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {},
323 | "source": [
324 | "# 4. Create Intent classifier"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 13,
330 | "metadata": {},
331 | "outputs": [],
332 | "source": [
333 | "# We will keep our models and vectorizers in this folder"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 14,
339 | "metadata": {},
340 | "outputs": [
341 | {
342 | "name": "stdout",
343 | "output_type": "stream",
344 | "text": [
345 | "mkdir: resources: File exists\r\n"
346 | ]
347 | }
348 | ],
349 | "source": [
350 | "!mkdir resources"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 15,
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "def tfidf_features(X_train, X_test, vectorizer_path):\n",
360 | " \"\"\"Performs TF-IDF transformation and dumps the model.\"\"\"\n",
361 | " tfv = TfidfVectorizer(dtype=np.float32, min_df=3, max_features=None, \n",
362 | " strip_accents='unicode', analyzer='word',token_pattern=r'\\w{1,}',\n",
363 | " ngram_range=(1, 3), use_idf=1,smooth_idf=1,sublinear_tf=1,\n",
364 | " stop_words = 'english')\n",
365 | " \n",
366 | " X_train = tfv.fit_transform(X_train)\n",
367 | " X_test = tfv.transform(X_test)\n",
368 | " \n",
369 | " pickle.dump(tfv,vectorizer_path)\n",
370 | " return X_train, X_test"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": 16,
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | "X_train_tfidf, X_test_tfidf = tfidf_features(X_train, X_test, open(\"resources/tfidf.pkl\",'wb'))"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": 17,
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "name": "stderr",
389 | "output_type": "stream",
390 | "text": [
391 | "/miniconda3/envs/py36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
392 | " FutureWarning)\n"
393 | ]
394 | },
395 | {
396 | "data": {
397 | "text/plain": [
398 | "LogisticRegression(C=10, class_weight=None, dual=False, fit_intercept=True,\n",
399 | " intercept_scaling=1, max_iter=100, multi_class='warn',\n",
400 | " n_jobs=None, penalty='l2', random_state=0, solver='warn',\n",
401 | " tol=0.0001, verbose=0, warm_start=False)"
402 | ]
403 | },
404 | "execution_count": 17,
405 | "metadata": {},
406 | "output_type": "execute_result"
407 | }
408 | ],
409 | "source": [
410 | "intent_recognizer = LogisticRegression(C=10,random_state=0)\n",
411 | "intent_recognizer.fit(X_train_tfidf,y_train)\n"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 18,
417 | "metadata": {},
418 | "outputs": [
419 | {
420 | "name": "stdout",
421 | "output_type": "stream",
422 | "text": [
423 | "Test accuracy = 0.989825\n"
424 | ]
425 | }
426 | ],
427 | "source": [
428 | "# Check test accuracy.\n",
429 | "y_test_pred = intent_recognizer.predict(X_test_tfidf)\n",
430 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n",
431 | "print('Test accuracy = {}'.format(test_accuracy))"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 19,
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "pickle.dump(intent_recognizer, open(\"resources/intent_clf.pkl\" , 'wb'))"
441 | ]
442 | },
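{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (editorial addition, assumes the cells above have run):\n",
"# reload the pickled vectorizer and intent classifier and score a new sentence,\n",
"# mirroring what main.py does at prediction time.\n",
"loaded_vectorizer = pickle.load(open(\"resources/tfidf.pkl\", 'rb'))\n",
"loaded_intent_clf = pickle.load(open(\"resources/intent_clf.pkl\", 'rb'))\n",
"sample = text_prepare(\"how do I read a csv file in pandas\")\n",
"loaded_intent_clf.predict(loaded_vectorizer.transform([sample]))"
]
},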
443 | {
444 | "cell_type": "markdown",
445 | "metadata": {},
446 | "source": [
447 | "# 5 Create Programming Language classifier"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": 20,
453 | "metadata": {},
454 | "outputs": [],
455 | "source": [
456 | "X = posts['title'].values\n",
457 | "y = posts['tag'].values"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": 21,
463 | "metadata": {},
464 | "outputs": [
465 | {
466 | "name": "stdout",
467 | "output_type": "stream",
468 | "text": [
469 | "Train size = 1737260, test size = 434315\n"
470 | ]
471 | }
472 | ],
473 | "source": [
474 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",
475 | "print('Train size = {}, test size = {}'.format(len(X_train), len(X_test)))"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": 22,
481 | "metadata": {},
482 | "outputs": [],
483 | "source": [
484 | "vectorizer = pickle.load(open(\"resources/tfidf.pkl\", 'rb'))\n",
485 | "X_train_tfidf, X_test_tfidf = vectorizer.transform(X_train), vectorizer.transform(X_test)"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 24,
491 | "metadata": {},
492 | "outputs": [
493 | {
494 | "name": "stderr",
495 | "output_type": "stream",
496 | "text": [
497 | "/miniconda3/envs/py36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
498 | " FutureWarning)\n"
499 | ]
500 | },
501 | {
502 | "data": {
503 | "text/plain": [
504 | "OneVsRestClassifier(estimator=LogisticRegression(C=5, class_weight=None, dual=False, fit_intercept=True,\n",
505 | " intercept_scaling=1, max_iter=100, multi_class='warn',\n",
506 | " n_jobs=None, penalty='l2', random_state=0, solver='warn',\n",
507 | " tol=0.0001, verbose=0, warm_start=False),\n",
508 | " n_jobs=None)"
509 | ]
510 | },
511 | "execution_count": 24,
512 | "metadata": {},
513 | "output_type": "execute_result"
514 | }
515 | ],
516 | "source": [
517 | "tag_classifier = OneVsRestClassifier(LogisticRegression(C=5,random_state=0))\n",
518 | "tag_classifier.fit(X_train_tfidf,y_train)"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": 25,
524 | "metadata": {},
525 | "outputs": [
526 | {
527 | "name": "stdout",
528 | "output_type": "stream",
529 | "text": [
530 | "Test accuracy = 0.8043816124241622\n"
531 | ]
532 | }
533 | ],
534 | "source": [
535 | "# Check test accuracy.\n",
536 | "y_test_pred = tag_classifier.predict(X_test_tfidf)\n",
537 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n",
538 | "print('Test accuracy = {}'.format(test_accuracy))"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": 26,
544 | "metadata": {},
545 | "outputs": [],
546 | "source": [
547 | "pickle.dump(tag_classifier, open(\"resources/tag_clf.pkl\", 'wb'))"
548 | ]
549 | },
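{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative usage (editorial addition, hypothetical title): predict the programming-language\n",
"# tag for a new question with the already-loaded TF-IDF vectorizer, as main.py does for\n",
"# Stack Overflow questions.\n",
"tag_classifier.predict(vectorizer.transform([\"How do I merge two dictionaries?\"]))"
]
},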
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {},
553 | "source": [
554 | "# 6. Store Question database Embeddings"
555 | ]
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "metadata": {},
560 | "source": [
561 | "You can use [pre-trained word vectors](https://code.google.com/archive/p/word2vec/) from Google."
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": 27,
567 | "metadata": {},
568 | "outputs": [],
569 | "source": [
570 | "\n",
571 | "\n",
572 | "# Load Google's pre-trained Word2Vec model.\n",
573 | "model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True) "
574 | ]
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "metadata": {},
579 | "source": [
580 | "We want to convert every question to an embedding and store them. Whenever user asks a stack overflow question we want to use cosine similarity to get the most similar question"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 28,
586 | "metadata": {},
587 | "outputs": [],
588 | "source": [
589 | "def question_to_vec(question, embeddings, dim=300):\n",
590 | " \"\"\"\n",
591 | " question: a string\n",
592 | " embeddings: dict where the key is a word and a value is its' embedding\n",
593 | " dim: size of the representation\n",
594 | "\n",
595 | " result: vector representation for the question\n",
596 | " \"\"\"\n",
597 | " word_tokens = question.split(\" \")\n",
598 | " question_len = len(word_tokens)\n",
599 | " question_mat = np.zeros((question_len,dim), dtype = np.float32)\n",
600 | " \n",
601 | " for idx, word in enumerate(word_tokens):\n",
602 | " if word in embeddings:\n",
603 | " question_mat[idx,:] = embeddings[word]\n",
604 | " \n",
605 | " # remove zero-rows which stand for OOV words \n",
606 | " question_mat = question_mat[~np.all(question_mat == 0, axis = 1)]\n",
607 | " \n",
608 | " # Compute the mean of each word along the sentence\n",
609 | " if question_mat.shape[0] > 0:\n",
610 | " vec = np.array(np.mean(question_mat, axis = 0), dtype = np.float32).reshape((1,dim))\n",
611 | " else:\n",
612 | " vec = np.zeros((1,dim), dtype = np.float32)\n",
613 | " \n",
614 | " return vec"
615 | ]
616 | },
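{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (editorial addition): question_to_vec averages the Word2Vec\n",
"# vectors of the in-vocabulary tokens, so any input maps to a (1, 300) array\n",
"# (a zero vector if every token is out of vocabulary).\n",
"question_to_vec(\"python list comprehension\", model).shape"
]
},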
617 | {
618 | "cell_type": "code",
619 | "execution_count": 29,
620 | "metadata": {},
621 | "outputs": [],
622 | "source": [
623 | "counts_by_tag = posts.groupby(by=['tag'])[\"tag\"].count().reset_index(name = 'count').sort_values(['count'], ascending = False)"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": 30,
629 | "metadata": {},
630 | "outputs": [
631 | {
632 | "name": "stdout",
633 | "output_type": "stream",
634 | "text": [
635 | "[('c#', 394451), ('java', 383456), ('javascript', 375867), ('php', 321752), ('c_cpp', 281300), ('python', 208607), ('ruby', 99930), ('r', 36359), ('vb', 35044), ('swift', 34809)]\n"
636 | ]
637 | }
638 | ],
639 | "source": [
640 | "counts_by_tag = list(zip(counts_by_tag['tag'],counts_by_tag['count']))\n",
641 | "print(counts_by_tag)"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 31,
647 | "metadata": {},
648 | "outputs": [
649 | {
650 | "name": "stdout",
651 | "output_type": "stream",
652 | "text": [
653 | "mkdir: resources/embeddings_folder: File exists\r\n"
654 | ]
655 | }
656 | ],
657 | "source": [
658 | "! mkdir resources/embeddings_folder"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": 32,
664 | "metadata": {},
665 | "outputs": [],
666 | "source": [
667 | "for tag, count in counts_by_tag:\n",
668 | " tag_posts = posts[posts['tag'] == tag]\n",
669 | " tag_post_ids = tag_posts['post_id'].values\n",
670 | " tag_vectors = np.zeros((count, 300), dtype=np.float32)\n",
671 | " for i, title in enumerate(tag_posts['title']):\n",
672 | " tag_vectors[i, :] = question_to_vec(title, model, 300)\n",
673 | " # Dump post ids and vectors to a file.\n",
674 | " filename = 'resources/embeddings_folder/'+ tag + '.pkl'\n",
675 | " pickle.dump((tag_post_ids, tag_vectors), open(filename, 'wb'))"
676 | ]
677 | },
678 | {
679 | "cell_type": "markdown",
680 | "metadata": {},
681 | "source": [
682 | "# Given a question and tag can I retrieve the most similar question post_id\n"
683 | ]
684 | },
685 | {
686 | "cell_type": "code",
687 | "execution_count": 33,
688 | "metadata": {},
689 | "outputs": [],
690 | "source": [
691 | "\n",
692 | "def get_similar_question(question,tag):\n",
693 | " # get the path where all question embeddings are kept and load the post_ids and post_embeddings\n",
694 | " embeddings_path = 'resources/embeddings_folder/' + tag + \".pkl\"\n",
695 | " post_ids, post_embeddings = pickle.load(open(embeddings_path, 'rb'))\n",
696 | " # Get the embeddings for the question\n",
697 | " question_vec = question_to_vec(question, model, 300)\n",
698 | " # find index of most similar post\n",
699 | " best_post_index = pairwise_distances_argmin(question_vec,\n",
700 | " post_embeddings)\n",
701 | " # return best post id\n",
702 | " return post_ids[best_post_index]"
703 | ]
704 | },
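{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note (editorial addition): pairwise_distances_argmin defaults to euclidean distance.\n",
"# A variant that literally uses the cosine similarity described above would pass\n",
"# metric='cosine'; this sketch is illustrative and may retrieve a different post\n",
"# than get_similar_question.\n",
"def get_similar_question_cosine(question, tag):\n",
"    embeddings_path = 'resources/embeddings_folder/' + tag + \".pkl\"\n",
"    post_ids, post_embeddings = pickle.load(open(embeddings_path, 'rb'))\n",
"    question_vec = question_to_vec(question, model, 300)\n",
"    best_post_index = pairwise_distances_argmin(question_vec, post_embeddings, metric='cosine')\n",
"    return post_ids[best_post_index]"
]
},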
705 | {
706 | "cell_type": "code",
707 | "execution_count": 37,
708 | "metadata": {},
709 | "outputs": [
710 | {
711 | "data": {
712 | "text/plain": [
713 | "array([5947137])"
714 | ]
715 | },
716 | "execution_count": 37,
717 | "metadata": {},
718 | "output_type": "execute_result"
719 | }
720 | ],
721 | "source": [
722 | "get_similar_question(\"how to use list comprehension in python?\",'python')"
723 | ]
724 | },
725 | {
726 | "cell_type": "markdown",
727 | "metadata": {},
728 | "source": [
729 | "You can find this question at:\n",
730 | " \n",
731 | "https://stackoverflow.com/questions/8278287"
732 | ]
733 | }
734 | ],
735 | "metadata": {
736 | "kernelspec": {
737 | "display_name": "Python 3",
738 | "language": "python",
739 | "name": "python3"
740 | }
741 | },
742 | "nbformat": 4,
743 | "nbformat_minor": 2
744 | }
745 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is the code repository for the blog post: https://mlwhiz.com/blog/2019/04/15/chatbot/
2 |
3 |
--------------------------------------------------------------------------------
/data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLWhiz/chatbot/52145c66af6c6f0226dadcc3790c27e3fb5184f2/data.zip
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import requests
4 | import time
5 | import argparse
6 | import os
7 | import json
8 | from requests.compat import urljoin
9 | import gensim
10 | import pickle
11 | import re
12 | import nltk
13 | from nltk.corpus import stopwords
14 | import numpy as np
15 | from sklearn.metrics.pairwise import pairwise_distances_argmin
16 |
17 | class BotHandler(object):
18 | """
19 | BotHandler is a class which implements all back-end of the bot.
20 | It has three main functions:
21 | 'get_updates' — checks for new messages
22 | 'send_message' — posts a new message to the user
23 | 'get_answer' — computes the most relevant answer to a user's question
24 | """
25 |
26 | def __init__(self, token, dialogue_manager):
27 | # Put the Telegram Access token here
28 | self.token = token
29 | self.api_url = "https://api.telegram.org/bot{}/".format(token)
30 | self.dialogue_manager = dialogue_manager
31 |
32 | def get_updates(self, offset=None, timeout=30):
33 | params = {"timeout": timeout, "offset": offset}
34 | raw_resp = requests.get(urljoin(self.api_url, "getUpdates"), params)
35 | try:
36 | resp = raw_resp.json()
37 | except json.decoder.JSONDecodeError as e:
38 | print("Failed to parse response {}: {}.".format(raw_resp.content, e))
39 | return []
40 |
41 | if "result" not in resp:
42 | return []
43 | return resp["result"]
44 |
45 | def send_message(self, chat_id, text):
46 | params = {"chat_id": chat_id, "text": text}
47 | return requests.post(urljoin(self.api_url, "sendMessage"), params)
48 |
49 | def get_answer(self, question):
50 | if question == '/start':
51 | return "Hi, I am your project bot. How can I help you today?"
52 | return self.dialogue_manager.generate_answer(question)
53 |
54 |
55 | def is_unicode(text):
56 | return len(text) == len(text.encode())
57 |
58 |
59 |
60 | # We will need this function to prepare text at prediction time
61 | def text_prepare(text):
62 | """Performs tokenization and simple preprocessing."""
63 |
64 | replace_by_space_re = re.compile(r'[/(){}\[\]\|@,;]')
65 | bad_symbols_re = re.compile('[^0-9a-z #+_]')
66 | stopwords_set = set(stopwords.words('english'))
67 |
68 | text = text.lower()
69 | text = replace_by_space_re.sub(' ', text)
70 | text = bad_symbols_re.sub('', text)
71 | text = ' '.join([x for x in text.split() if x and x not in stopwords_set])
72 |
73 | return text.strip()
74 |
75 | # need this to convert questions asked by user to vectors
76 | def question_to_vec(question, embeddings, dim=300):
77 | """
78 | question: a string
79 | embeddings: dict where the key is a word and the value is its embedding
80 | dim: size of the representation
81 |
82 | result: vector representation for the question
83 | """
84 | word_tokens = question.split(" ")
85 | question_len = len(word_tokens)
86 | question_mat = np.zeros((question_len,dim), dtype = np.float32)
87 |
88 | for idx, word in enumerate(word_tokens):
89 | if word in embeddings:
90 | question_mat[idx,:] = embeddings[word]
91 |
92 | # remove zero-rows which stand for OOV words
93 | question_mat = question_mat[~np.all(question_mat == 0, axis = 1)]
94 |
95 | # Compute the mean of each word along the sentence
96 | if question_mat.shape[0] > 0:
97 | vec = np.array(np.mean(question_mat, axis = 0), dtype = np.float32).reshape((1,dim))
98 | else:
99 | vec = np.zeros((1,dim), dtype = np.float32)
100 |
101 | return vec
102 |
103 | class SimpleDialogueManager(object):
104 | """
105 | This is a simple dialogue manager to test the telegram bot.
106 | The main part of our bot will be written here.
107 | """
108 |
109 | def __init__(self):
110 |
111 | # Instantiate all the models and TFIDF Objects.
112 | print("Loading resources...")
113 | # Instantiate a Chatterbot for Chitchat type questions
114 | from chatterbot import ChatBot
115 | from chatterbot.trainers import ChatterBotCorpusTrainer
116 | chatbot = ChatBot('MLWhizChatterbot')
117 | trainer = ChatterBotCorpusTrainer(chatbot)
118 | trainer.train('chatterbot.corpus.english')
119 | self.chitchat_bot = chatbot
120 | print("Loading Word2vec model...")
121 | # Instantiate the Google's pre-trained Word2Vec model.
122 | self.model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)
123 | print("Loading Classifier objects...")
124 | # Load the intent classifier and tag classifier
125 | self.intent_recognizer = pickle.load(open('resources/intent_clf.pkl', 'rb'))
126 | self.tag_classifier = pickle.load(open('resources/tag_clf.pkl', 'rb'))
127 | # Load the TFIDF vectorizer object
128 | self.tfidf_vectorizer = pickle.load(open('resources/tfidf.pkl', 'rb'))
129 | print("Finished Loading Resources")
130 |
131 | # Same helper as in the notebook: given a question and its programming-language tag, return the post_id of the most similar question in the dataset.
132 | def get_similar_question(self,question,tag):
133 | # get the path where all question embeddings are kept and load the post_ids and post_embeddings
134 | embeddings_path = 'resources/embeddings_folder/' + tag + ".pkl"
135 | post_ids, post_embeddings = pickle.load(open(embeddings_path, 'rb'))
136 | # Get the embeddings for the question
137 | question_vec = question_to_vec(question, self.model, 300)
138 | # find index of most similar post
139 | best_post_index = pairwise_distances_argmin(question_vec,
140 | post_embeddings)
141 | # return best post id
142 | return post_ids[best_post_index]
143 |
144 | def generate_answer(self, question):
145 | prepared_question = text_prepare(question)
146 | features = self.tfidf_vectorizer.transform([prepared_question])
147 | # find intent
148 | intent = self.intent_recognizer.predict(features)[0]
149 | # Chit-chat part:
150 | if intent == 'dialogue':
151 | response = self.chitchat_bot.get_response(question)
152 | # Stack Overflow Question
153 | else:
154 | # find programming language
155 | tag = self.tag_classifier.predict(features)[0]
156 | # find most similar question post id
157 | post_id = self.get_similar_question(question,tag)[0]
158 | # respond with the predicted tag and a link to the most similar thread
159 | response = "I think it's about %s\nThis thread might help you: https://stackoverflow.com/questions/%s" % (tag, post_id)
160 | return response
161 |
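# Minimal offline check (editorial sketch, not called anywhere): assuming the resources/
# folder produced by Model Creation.ipynb and the Word2Vec binary are present, the
# dialogue manager can be exercised without Telegram.
def offline_smoke_test():
    dm = SimpleDialogueManager()
    print(dm.generate_answer("hi there, how are you?"))        # chit-chat branch
    print(dm.generate_answer("how to sort a list in python"))  # Stack Overflow branch
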
162 | def main():
163 | token = '839585958:AAEfTDo2X6PgHb9IEdb62ueS4SmdpCkhtmc'
164 | simple_manager = SimpleDialogueManager()
165 | bot = BotHandler(token, simple_manager)
166 | ###############################################################
167 |
168 | print("Ready to talk!")
169 | offset = 0
170 | while True:
171 | updates = bot.get_updates(offset=offset)
172 | for update in updates:
173 | print("An update received.")
174 | if "message" in update:
175 | chat_id = update["message"]["chat"]["id"]
176 | if "text" in update["message"]:
177 | text = update["message"]["text"]
178 | if is_unicode(text):
179 | print("Update content: {}".format(update))
180 | bot.send_message(chat_id, bot.get_answer(text))
181 | else:
182 | bot.send_message(chat_id, "Hmm, you are sending some weird characters to me...")
183 | offset = max(offset, update['update_id'] + 1)
184 | time.sleep(1)
185 |
186 | if __name__ == "__main__":
187 | main()
188 |
--------------------------------------------------------------------------------