├── LICENSE
├── README.md
├── app.py
├── chat.py
├── intents.json
├── model.py
├── nltk_utils.py
├── standalone-frontend
│   ├── app.js
│   ├── base.html
│   ├── images
│   │   └── chatbox-icon.svg
│   └── style.css
├── static
│   ├── app.js
│   ├── images
│   │   └── chatbox-icon.svg
│   └── style.css
├── templates
│   └── base.html
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Patrick Loeber

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Chatbot Deployment with Flask and JavaScript

In this tutorial, we deploy the chatbot I created in [this](https://github.com/python-engineer/pytorch-chatbot) tutorial using Flask and JavaScript.

This gives 2 deployment options:
- Deploy within a Flask app using a jinja2 template
- Serve only the Flask prediction API. The HTML and JavaScript files can then be included in any frontend application (with only slight modifications) and run completely separately from the Flask app.

## Initial Setup:
This repo currently contains the starter files.

Clone the repo and create a virtual environment
```
$ git clone https://github.com/python-engineer/chatbot-deployment.git
$ cd chatbot-deployment
$ python3 -m venv venv
$ . venv/bin/activate
```
Install dependencies
```
(venv) $ pip install Flask torch torchvision nltk
```
Download the NLTK `punkt` tokenizer data
```
(venv) $ python
>>> import nltk
>>> nltk.download('punkt')
```
Modify `intents.json` with different intents and responses for your chatbot

Run
```
(venv) $ python train.py
```
This saves the trained model to `data.pth`. Then run the following command to test it in the console.
```
(venv) $ python chat.py
```

Now, for deployment, follow my tutorial to implement `app.py` and `app.js`.

## Watch the Tutorial
[https://youtu.be/a37BL0stIuM](https://youtu.be/a37BL0stIuM)

## Note
In the video we implement the first approach, using jinja2 templates within our Flask app. Only slight modifications are needed to run the frontend separately. I put the final code for a standalone frontend application in the [standalone-frontend](/standalone-frontend) folder.

## Credits:
This repo was used for the frontend code:
https://github.com/hitchcliff/front-end-chatjs
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickloeber/chatbot-deployment/9f347f8505c03202983f362ad4e02907659f5d67/app.py
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
import random
import json

import torch

from model import NeuralNet
from nltk_utils import bag_of_words, tokenize

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('intents.json', 'r') as json_data:
    intents = json.load(json_data)

FILE = "data.pth"
# map_location remaps tensors saved on a GPU onto `device`,
# so a model trained with CUDA also loads on a CPU-only machine
data = torch.load(FILE, map_location=device)

input_size = data["input_size"]
hidden_size = data["hidden_size"]
output_size = data["output_size"]
all_words = data['all_words']
tags = data['tags']
model_state = data["model_state"]

model = NeuralNet(input_size, hidden_size, output_size).to(device)
model.load_state_dict(model_state)
model.eval()

bot_name = "Sam"

def get_response(msg):
    sentence = tokenize(msg)
    X = bag_of_words(sentence, all_words)
    X = X.reshape(1, X.shape[0])
    X = torch.from_numpy(X).to(device)

    output = model(X)
    _, predicted = torch.max(output, dim=1)

    tag = tags[predicted.item()]

    probs = torch.softmax(output, dim=1)
    prob = probs[0][predicted.item()]
    if prob.item() > 0.75:
        for intent in intents['intents']:
            if tag == intent["tag"]:
                return random.choice(intent['responses'])

    return "I do not understand..."


if __name__ == "__main__":
    print("Let's chat! (type 'quit' to exit)")
    while True:
        # sentence = "do you use credit cards?"
        sentence = input("You: ")
        if sentence == "quit":
            break

        resp = get_response(sentence)
        print(resp)
--------------------------------------------------------------------------------
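`get_response` turns the raw logits from `NeuralNet` (which applies no softmax) into a reply: it takes the highest-scoring class, converts the logits to probabilities with softmax, and answers only when that class's probability exceeds 0.75. A plain-Python sketch of just that post-processing, assuming one row of logits given as a list of floats; `choose_reply` and the toy tags and intents in the usage are hypothetical.

```python
import math
import random


def softmax(logits):
    # Subtract the max before exponentiating to avoid overflow; the result is unchanged.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def choose_reply(logits, tags, intents, threshold=0.75,
                 fallback="I do not understand..."):
    """Mirror get_response's post-processing for one row of model output."""
    probs = softmax(logits)
    # Argmax of the logits equals argmax of the probabilities (softmax is monotonic).
    best = max(range(len(logits)), key=lambda i: logits[i])
    if probs[best] <= threshold:  # chat.py answers only when prob > 0.75
        return fallback
    for intent in intents["intents"]:
        if intent["tag"] == tags[best]:
            return random.choice(intent["responses"])
    return fallback
```

With two classes, logits `[1.0, 1.5]` give the top class a probability of only about 0.62, so the fallback reply is returned even though one class "won". Raising or lowering `threshold` trades unanswered questions against wrong answers.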
/intents.json:
--------------------------------------------------------------------------------
{
  "intents": [
    {
      "tag": "greeting",
      "patterns": [
        "Hi",
        "Hey",
        "How are you",
        "Is anyone there?",
        "Hello",
        "Good day"
      ],
      "responses": [
        "Hey :-)",
        "Hello, thanks for visiting",
        "Hi there, what can I do for you?",
        "Hi there, how can I help?"
      ]
    },
    {
      "tag": "goodbye",
      "patterns": ["Bye", "See you later", "Goodbye"],
      "responses": [
        "See you later, thanks for visiting",
        "Have a nice day",
        "Bye! Come back again soon."
      ]
    },
    {
      "tag": "thanks",
      "patterns": ["Thanks", "Thank you", "That's helpful", "Thanks a lot!"],
      "responses": ["Happy to help!", "Any time!", "My pleasure"]
    },
    {
      "tag": "items",
      "patterns": [
        "Which items do you have?",
        "What kinds of items are there?",
        "What do you sell?"
      ],
      "responses": [
        "We sell coffee and tea",
        "We have coffee and tea"
      ]
    },
    {
      "tag": "payments",
      "patterns": [
        "Do you take credit cards?",
        "Do you accept Mastercard?",
        "Can I pay with Paypal?",
        "Are you cash only?"
      ],
      "responses": [
        "We accept VISA, Mastercard and Paypal",
        "We accept most major credit cards, and Paypal"
      ]
    },
    {
      "tag": "delivery",
      "patterns": [
        "How long does delivery take?",
        "How long does shipping take?",
        "When do I get my delivery?"
      ],
      "responses": [
        "Delivery takes 2-4 days",
        "Shipping takes 2-4 days"
      ]
    },
    {
      "tag": "funny",
      "patterns": [
        "Tell me a joke!",
        "Tell me something funny!",
        "Do you know a joke?"
      ],
      "responses": [
        "Why did the hipster burn his mouth? He drank the coffee before it was cool.",
        "What did the buffalo say when his son left for college? Bison."
      ]
    }
  ]
}
--------------------------------------------------------------------------------
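The README suggests editing `intents.json` for your own chatbot. `train.py` and `chat.py` assume every intent has a `tag`, a `patterns` list, and a `responses` list, so a quick structural check before retraining can catch mistakes early. A minimal sketch; `validate_intents` is a hypothetical helper, not part of the repo.

```python
REQUIRED_KEYS = ("tag", "patterns", "responses")


def validate_intents(data):
    """Return a list of problems in an intents.json-style dict (empty if none)."""
    intents = data.get("intents") if isinstance(data, dict) else None
    if not isinstance(intents, list):
        return ["top-level key 'intents' must be a list"]
    problems = []
    seen_tags = set()
    for i, intent in enumerate(intents):
        if not isinstance(intent, dict):
            problems.append(f"intent {i}: must be an object")
            continue
        for key in REQUIRED_KEYS:
            if key not in intent:
                problems.append(f"intent {i}: missing '{key}'")
        tag = intent.get("tag")
        if tag is not None and tag in seen_tags:
            # train.py merges duplicate tags into one class, and chat.py only
            # ever answers with the first intent carrying that tag.
            problems.append(f"intent {i}: duplicate tag {tag!r}")
        seen_tags.add(tag)
        for key in ("patterns", "responses"):
            # An empty 'responses' list would make random.choice in chat.py raise IndexError.
            if key in intent and (not isinstance(intent[key], list) or not intent[key]):
                problems.append(f"intent {i}: '{key}' must be a non-empty list")
    return problems
```

Loading the file above with `json.load` and passing the result to `validate_intents` should return an empty list, since all seven intents have unique tags and non-empty `patterns` and `responses`.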
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        out = self.relu(out)
        out = self.l3(out)
        # no activation and no softmax at the end
        return out
--------------------------------------------------------------------------------
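`NeuralNet` is a three-layer feed-forward network: two hidden layers of `hidden_size` units with ReLU, then a linear layer producing one raw score (logit) per tag, which is what `nn.CrossEntropyLoss` in `train.py` expects. A NumPy sketch of the same forward pass and its parameter count, assuming `nn.Linear`'s convention of an `(out_features, in_features)` weight used as `x @ W.T + b`. The vocabulary size of 50 is hypothetical (`train.py` prints the real `input_size`); `hidden_size = 8` comes from `train.py` and the 7 classes from the tags in `intents.json`. The helper names are illustrative.

```python
import numpy as np


def linear_params(n_in, n_out):
    # nn.Linear(n_in, n_out) holds an (n_out, n_in) weight matrix plus n_out biases.
    return n_in * n_out + n_out


def count_params(input_size, hidden_size, num_classes):
    return (linear_params(input_size, hidden_size)
            + linear_params(hidden_size, hidden_size)
            + linear_params(hidden_size, num_classes))


def random_layers(input_size, hidden_size, num_classes, seed=0):
    """Random (weight, bias) pairs shaped like l1, l2, l3 in NeuralNet."""
    rng = np.random.default_rng(seed)
    sizes = [(input_size, hidden_size), (hidden_size, hidden_size), (hidden_size, num_classes)]
    return [(rng.standard_normal((n_out, n_in)), rng.standard_normal(n_out))
            for n_in, n_out in sizes]


def forward(x, layers):
    """Linear -> ReLU -> Linear -> ReLU -> Linear, returning raw logits."""
    out = x
    for i, (weight, bias) in enumerate(layers):
        out = out @ weight.T + bias
        if i < len(layers) - 1:  # no activation after the last layer
            out = np.maximum(out, 0.0)
    return out
```

With those sizes the model has 408 + 72 + 63 = 543 trainable parameters, and a single bag-of-words row of shape `(1, 50)` comes out as logits of shape `(1, 7)`.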
/nltk_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import nltk
# nltk.download('punkt')
from nltk.stem.porter import PorterStemmer
stemmer = PorterStemmer()


def tokenize(sentence):
    """
    split sentence into array of words/tokens
    a token can be a word, a punctuation character, or a number
    """
    return nltk.word_tokenize(sentence)


def stem(word):
    """
    stemming = find the root form of the word
    examples:
    words = ["organize", "organizes", "organizing"]
    words = [stem(w) for w in words]
    -> ["organ", "organ", "organ"]
    """
    return stemmer.stem(word.lower())


def bag_of_words(tokenized_sentence, words):
    """
    return bag of words array:
    1 for each known word that exists in the sentence, 0 otherwise
    example:
    sentence = ["hello", "how", "are", "you"]
    words = ["hi", "hello", "I", "you", "bye", "thank", "cool"]
    bag   = [  0 ,    1 ,    0 ,   1 ,    0 ,    0 ,      0]
    """
    # stem each word
    sentence_words = [stem(word) for word in tokenized_sentence]
    # initialize bag with 0 for each word
    bag = np.zeros(len(words), dtype=np.float32)
    for idx, w in enumerate(words):
        if w in sentence_words:
            bag[idx] = 1

    return bag
--------------------------------------------------------------------------------
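`bag_of_words` encodes a sentence as a fixed-length 0/1 vector over the training vocabulary `all_words`, discarding word order and repeat counts. A dependency-free sketch that reproduces the docstring example; `bag_of_words_sketch` is a hypothetical name, `str.lower` stands in for Porter stemming (an assumption so it runs without NLTK), and it returns a Python list instead of an `np.float32` array.

```python
def bag_of_words_sketch(tokenized_sentence, words, normalize=str.lower):
    """Binary bag of words: 1.0 where a vocabulary word occurs in the sentence.

    normalize stands in for nltk_utils.stem; str.lower is an assumption that
    lets this sketch run without NLTK (the repo stems with PorterStemmer).
    """
    sentence_words = {normalize(w) for w in tokenized_sentence}  # set for O(1) lookups
    return [1.0 if w in sentence_words else 0.0 for w in words]
```

Because `stem` lowercases its input, only lowercase vocabulary entries can ever match: the capitalised `"I"` in the docstring's example vocabulary never does. That is harmless in practice, since `train.py` builds `all_words` from stemmed, and therefore lowercase, words.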
/standalone-frontend/app.js:
--------------------------------------------------------------------------------
class Chatbox {
    constructor() {
        this.args = {
            openButton: document.querySelector('.chatbox__button'),
            chatBox: document.querySelector('.chatbox__support'),
            sendButton: document.querySelector('.send__button')
        }

        this.state = false;
        this.messages = [];
    }

    display() {
        const {openButton, chatBox, sendButton} = this.args;

        openButton.addEventListener('click', () => this.toggleState(chatBox))

        sendButton.addEventListener('click', () => this.onSendButton(chatBox))

        const node = chatBox.querySelector('input');
        node.addEventListener("keyup", ({key}) => {
            if (key === "Enter") {
                this.onSendButton(chatBox)
            }
        })
    }

    toggleState(chatbox) {
        this.state = !this.state;

        // shows or hides the box
        if (this.state) {
            chatbox.classList.add('chatbox--active')
        } else {
            chatbox.classList.remove('chatbox--active')
        }
    }

    onSendButton(chatbox) {
        var textField = chatbox.querySelector('input');
        let text1 = textField.value
        if (text1 === "") {
            return;
        }

        let msg1 = { name: "User", message: text1 }
        this.messages.push(msg1);

        fetch('http://127.0.0.1:5000/predict', {
            method: 'POST',
            body: JSON.stringify({ message: text1 }),
            mode: 'cors',
            headers: {
                'Content-Type': 'application/json'
            },
        })
        .then(r => r.json())
        .then(r => {
            let msg2 = { name: "Sam", message: r.answer };
            this.messages.push(msg2);
            this.updateChatText(chatbox)
            textField.value = ''

        }).catch((error) => {
            console.error('Error:', error);
            this.updateChatText(chatbox)
            textField.value = ''
        });
    }

    updateChatText(chatbox) {
        var html = '';
        this.messages.slice().reverse().forEach(function(item, index) {
            if (item.name === "Sam")
            {
                html += '<div class="messages__item messages__item--visitor">' + item.message + '</div>'
            }
            else
            {
                html += '<div class="messages__item messages__item--operator">' + item.message + '</div>'
            }
        });

        const chatmessage = chatbox.querySelector('.chatbox__messages');
        chatmessage.innerHTML = html;
    }
}


const chatbox = new Chatbox();
chatbox.display();
--------------------------------------------------------------------------------
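`this.messages` is stored oldest-first, but `updateChatText` renders it newest-first, and `.chatbox__messages` in `style.css` uses `flex-direction: column-reverse`, which lays the first child out at the bottom. The net effect is chronological order with the newest message at the bottom, next to the input field. A Python sketch of that rendering step, assuming bot messages get `messages__item--visitor` (left-aligned, grey) and user messages `messages__item--operator` (right-aligned, purple), as the stylesheet's margins and colors suggest; `render_messages` is a hypothetical name.

```python
from html import escape


def render_messages(messages, bot_name="Sam"):
    """Build chat HTML like updateChatText: newest message first in the markup."""
    parts = []
    for item in reversed(messages):  # same as messages.slice().reverse()
        side = ("messages__item--visitor" if item["name"] == bot_name
                else "messages__item--operator")
        # escape() is an addition: the JS concatenates item.message into innerHTML as-is.
        parts.append(f'<div class="messages__item {side}">{escape(item["message"])}</div>')
    return "".join(parts)
```

The `escape` call is a deliberate difference from the JavaScript: since the text typed by the user is echoed into `innerHTML`, escaping it (or assigning `textContent` on a created element in the browser) keeps markup in a message from being interpreted.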
/standalone-frontend/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Chatbot
8 |
9 |
10 |
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/standalone-frontend/images/chatbox-icon.svg:
--------------------------------------------------------------------------------
1 |
4 |
--------------------------------------------------------------------------------
/standalone-frontend/style.css:
--------------------------------------------------------------------------------
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

body {
    font-family: 'Nunito', sans-serif;
    font-weight: 400;
    font-size: 100%;
    background: #F1F1F1;
}

*, html {
    --primaryGradient: linear-gradient(93.12deg, #581B98 0.52%, #9C1DE7 100%);
    --secondaryGradient: linear-gradient(268.91deg, #581B98 -2.14%, #9C1DE7 99.69%);
    --primaryBoxShadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    --secondaryBoxShadow: 0px -10px 15px rgba(0, 0, 0, 0.1);
    --primary: #581B98;
}

/* CHATBOX
=============== */
.chatbox {
    position: absolute;
    bottom: 30px;
    right: 30px;
}

/* CONTENT IS CLOSED */
.chatbox__support {
    display: flex;
    flex-direction: column;
    background: #eee;
    width: 300px;
    height: 350px;
    z-index: -123456;
    opacity: 0;
    transition: all .5s ease-in-out;
}

/* CONTENT IS OPEN */
.chatbox--active {
    transform: translateY(-40px);
    z-index: 123456;
    opacity: 1;
}

/* BUTTON */
.chatbox__button {
    text-align: right;
}

.send__button {
    padding: 6px;
    background: transparent;
    border: none;
    outline: none;
    cursor: pointer;
}


/* HEADER */
.chatbox__header {
    position: sticky;
    top: 0;
    background: orange;
}

/* MESSAGES */
.chatbox__messages {
    margin-top: auto;
    display: flex;
    overflow-y: scroll;
    flex-direction: column-reverse;
}

.messages__item {
    background: orange;
    max-width: 60.6%;
    width: fit-content;
}

.messages__item--operator {
    margin-left: auto;
}

.messages__item--visitor {
    margin-right: auto;
}

/* FOOTER */
.chatbox__footer {
    position: sticky;
    bottom: 0;
}

.chatbox__support {
    background: #f9f9f9;
    height: 450px;
    width: 350px;
    box-shadow: 0px 0px 15px rgba(0, 0, 0, 0.1);
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
}

/* HEADER */
.chatbox__header {
    background: var(--primaryGradient);
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: center;
    padding: 15px 20px;
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    box-shadow: var(--primaryBoxShadow);
}

.chatbox__image--header {
    margin-right: 10px;
}

.chatbox__heading--header {
    font-size: 1.2rem;
    color: white;
}

.chatbox__description--header {
    font-size: .9rem;
    color: white;
}

/* Messages */
.chatbox__messages {
    padding: 0 20px;
}

.messages__item {
    margin-top: 10px;
    background: #E0E0E0;
    padding: 8px 12px;
    max-width: 70%;
}

.messages__item--visitor,
.messages__item--typing {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-right-radius: 20px;
}

.messages__item--operator {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-left-radius: 20px;
    background: var(--primary);
    color: white;
}

/* FOOTER */
.chatbox__footer {
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: space-between;
    padding: 20px 20px;
    background: var(--secondaryGradient);
    box-shadow: var(--secondaryBoxShadow);
    border-bottom-right-radius: 10px;
    border-bottom-left-radius: 10px;
    margin-top: 20px;
}

.chatbox__footer input {
    width: 80%;
    border: none;
    padding: 10px 10px;
    border-radius: 30px;
    text-align: left;
}

.chatbox__send--footer {
    color: white;
}

.chatbox__button button,
.chatbox__button button:focus,
.chatbox__button button:visited {
    padding: 10px;
    background: white;
    border: none;
    outline: none;
    border-top-left-radius: 50px;
    border-top-right-radius: 50px;
    border-bottom-left-radius: 50px;
    box-shadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    cursor: pointer;
}
--------------------------------------------------------------------------------
/static/app.js:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickloeber/chatbot-deployment/9f347f8505c03202983f362ad4e02907659f5d67/static/app.js
--------------------------------------------------------------------------------
/static/images/chatbox-icon.svg:
--------------------------------------------------------------------------------
1 |
4 |
--------------------------------------------------------------------------------
/static/style.css:
--------------------------------------------------------------------------------
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

body {
    font-family: 'Nunito', sans-serif;
    font-weight: 400;
    font-size: 100%;
    background: #F1F1F1;
}

*, html {
    --primaryGradient: linear-gradient(93.12deg, #581B98 0.52%, #9C1DE7 100%);
    --secondaryGradient: linear-gradient(268.91deg, #581B98 -2.14%, #9C1DE7 99.69%);
    --primaryBoxShadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    --secondaryBoxShadow: 0px -10px 15px rgba(0, 0, 0, 0.1);
    --primary: #581B98;
}

/* CHATBOX
=============== */
.chatbox {
    position: absolute;
    bottom: 30px;
    right: 30px;
}

/* CONTENT IS CLOSED */
.chatbox__support {
    display: flex;
    flex-direction: column;
    background: #eee;
    width: 300px;
    height: 350px;
    z-index: -123456;
    opacity: 0;
    transition: all .5s ease-in-out;
}

/* CONTENT IS OPEN */
.chatbox--active {
    transform: translateY(-40px);
    z-index: 123456;
    opacity: 1;
}

/* BUTTON */
.chatbox__button {
    text-align: right;
}

.send__button {
    padding: 6px;
    background: transparent;
    border: none;
    outline: none;
    cursor: pointer;
}


/* HEADER */
.chatbox__header {
    position: sticky;
    top: 0;
    background: orange;
}

/* MESSAGES */
.chatbox__messages {
    margin-top: auto;
    display: flex;
    overflow-y: scroll;
    flex-direction: column-reverse;
}

.messages__item {
    background: orange;
    max-width: 60.6%;
    width: fit-content;
}

.messages__item--operator {
    margin-left: auto;
}

.messages__item--visitor {
    margin-right: auto;
}

/* FOOTER */
.chatbox__footer {
    position: sticky;
    bottom: 0;
}

.chatbox__support {
    background: #f9f9f9;
    height: 450px;
    width: 350px;
    box-shadow: 0px 0px 15px rgba(0, 0, 0, 0.1);
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
}

/* HEADER */
.chatbox__header {
    background: var(--primaryGradient);
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: center;
    padding: 15px 20px;
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    box-shadow: var(--primaryBoxShadow);
}

.chatbox__image--header {
    margin-right: 10px;
}

.chatbox__heading--header {
    font-size: 1.2rem;
    color: white;
}

.chatbox__description--header {
    font-size: .9rem;
    color: white;
}

/* Messages */
.chatbox__messages {
    padding: 0 20px;
}

.messages__item {
    margin-top: 10px;
    background: #E0E0E0;
    padding: 8px 12px;
    max-width: 70%;
}

.messages__item--visitor,
.messages__item--typing {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-right-radius: 20px;
}

.messages__item--operator {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-left-radius: 20px;
    background: var(--primary);
    color: white;
}

/* FOOTER */
.chatbox__footer {
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: space-between;
    padding: 20px 20px;
    background: var(--secondaryGradient);
    box-shadow: var(--secondaryBoxShadow);
    border-bottom-right-radius: 10px;
    border-bottom-left-radius: 10px;
    margin-top: 20px;
}

.chatbox__footer input {
    width: 80%;
    border: none;
    padding: 10px 10px;
    border-radius: 30px;
    text-align: left;
}

.chatbox__send--footer {
    color: white;
}

.chatbox__button button,
.chatbox__button button:focus,
.chatbox__button button:visited {
    padding: 10px;
    background: white;
    border: none;
    outline: none;
    border-top-left-radius: 50px;
    border-top-right-radius: 50px;
    border-bottom-left-radius: 50px;
    box-shadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    cursor: pointer;
}
--------------------------------------------------------------------------------
/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Chatbot
8 |
9 |
10 |
35 |
36 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from nltk_utils import bag_of_words, tokenize, stem
from model import NeuralNet

with open('intents.json', 'r') as f:
    intents = json.load(f)

all_words = []
tags = []
xy = []
# loop through each sentence in our intents patterns
for intent in intents['intents']:
    tag = intent['tag']
    # add to tag list
    tags.append(tag)
    for pattern in intent['patterns']:
        # tokenize each word in the sentence
        w = tokenize(pattern)
        # add to our words list
        all_words.extend(w)
        # add to xy pair
        xy.append((w, tag))

# stem and lower each word, skipping punctuation
ignore_words = ['?', '.', '!']
all_words = [stem(w) for w in all_words if w not in ignore_words]
# remove duplicates and sort
all_words = sorted(set(all_words))
tags = sorted(set(tags))

print(len(xy), "patterns")
print(len(tags), "tags:", tags)
print(len(all_words), "unique stemmed words:", all_words)

# create training data
X_train = []
y_train = []
for (pattern_sentence, tag) in xy:
    # X: bag of words for each pattern_sentence
    bag = bag_of_words(pattern_sentence, all_words)
    X_train.append(bag)
    # y: PyTorch CrossEntropyLoss needs only class labels, not one-hot
    label = tags.index(tag)
    y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

# Hyper-parameters
num_epochs = 1000
batch_size = 8
learning_rate = 0.001
input_size = len(X_train[0])
hidden_size = 8
output_size = len(tags)
print(input_size, output_size)

class ChatDataset(Dataset):

    def __init__(self):
        self.n_samples = len(X_train)
        self.x_data = X_train
        self.y_data = y_train

    # support indexing such that dataset[i] can be used to get i-th sample
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # we can call len(dataset) to return the size
    def __len__(self):
        return self.n_samples

dataset = ChatDataset()
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = NeuralNet(input_size, hidden_size, output_size).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    for (words, labels) in train_loader:
        words = words.to(device)
        labels = labels.to(dtype=torch.long).to(device)

        # Forward pass
        outputs = model(words)
        # if y were one-hot, we would need to apply
        # labels = torch.max(labels, 1)[1]
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


print(f'final loss: {loss.item():.4f}')

data = {
    "model_state": model.state_dict(),
    "input_size": input_size,
    "hidden_size": hidden_size,
    "output_size": output_size,
    "all_words": all_words,
    "tags": tags
}

FILE = "data.pth"
torch.save(data, FILE)

print(f'training complete. file saved to {FILE}')
--------------------------------------------------------------------------------
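With the shipped `intents.json` (26 patterns across 7 tags) and the hyperparameters above, the training loop's bookkeeping can be worked out by hand: `tags` is sorted before `tags.index(tag)` assigns labels, and `DataLoader` keeps the final partial batch because `drop_last` defaults to `False`. A small sketch of that arithmetic; the per-tag counts are copied from `intents.json`, and `label_map` and `optimizer_steps` are hypothetical helper names.

```python
import math

# Pattern counts per tag, copied from intents.json above.
PATTERN_COUNTS = {
    "greeting": 6, "goodbye": 3, "thanks": 4, "items": 3,
    "payments": 4, "delivery": 3, "funny": 3,
}


def label_map(tags):
    # train.py sorts the unique tags; tags.index(tag) then becomes the class label.
    ordered = sorted(set(tags))
    return {tag: ordered.index(tag) for tag in ordered}


def optimizer_steps(n_samples, batch_size=8, num_epochs=1000):
    # DataLoader keeps the last partial batch by default (drop_last=False).
    batches_per_epoch = math.ceil(n_samples / batch_size)
    return batches_per_epoch, batches_per_epoch * num_epochs
```

Since 26 = 3 × 8 + 2, every epoch ends with a 2-sample batch, so the loss that `train.py` prints every 100th epoch (and as the final loss) is measured on whichever two shuffled patterns landed in that last batch, not on the whole dataset.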