├── GPT-2-BOT.ipynb
├── README.md
├── SCORES.md
├── download_model.py
├── requirements.txt
├── src
│   ├── GPT2-Learning.py
│   ├── encoder.py
│   ├── model.py
│   ├── olddemo.py
│   └── sample.py
└── start.sh
/GPT-2-BOT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Untitled0.ipynb",
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyMrkPXfE3zUuuEDDYxAmfX6",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "TPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | "
"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "metadata": {
34 | "colab": {
35 | "base_uri": "https://localhost:8080/"
36 | },
37 | "id": "4MItH_gOezsh",
38 | "outputId": "228486d1-3fa0-47f0-c877-ca30ad840a81"
39 | },
40 | "source": [
41 | "from google.colab import drive\n",
42 | "drive.mount('/content/gdrive')"
43 | ],
44 | "execution_count": 1,
45 | "outputs": [
46 | {
47 | "output_type": "stream",
48 | "text": [
49 | "Mounted at /content/gdrive\n"
50 | ],
51 | "name": "stdout"
52 | }
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {
58 | "id": "NowXHz16haZ0"
59 | },
60 | "source": [
61 | "Connect to google drive for model and project storage."
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "metadata": {
67 | "colab": {
68 | "base_uri": "https://localhost:8080/"
69 | },
70 | "id": "RbcOtiuzfCQB",
71 | "outputId": "2013f1d6-4d1f-4297-9702-b16bb2a4e1f0"
72 | },
73 | "source": [
74 | "%cd gdrive/My Drive/Colab Notebooks"
75 | ],
76 | "execution_count": 2,
77 | "outputs": [
78 | {
79 | "output_type": "stream",
80 | "text": [
81 | "/content/gdrive/My Drive/Colab Notebooks\n"
82 | ],
83 | "name": "stdout"
84 | }
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "metadata": {
90 | "colab": {
91 | "base_uri": "https://localhost:8080/"
92 | },
93 | "id": "6jmk6czDflu4",
94 | "outputId": "dfa5e5d2-c648-4e52-ccbd-2df1ee22cee5"
95 | },
96 | "source": [
97 | "! git clone https://github.com/Existencce/GPT2-Telegram-Chatbot"
98 | ],
99 | "execution_count": 3,
100 | "outputs": [
101 | {
102 | "output_type": "stream",
103 | "text": [
104 | "fatal: destination path 'GPT2-Telegram-Chatbot' already exists and is not an empty directory.\n"
105 | ],
106 | "name": "stdout"
107 | }
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "metadata": {
113 | "colab": {
114 | "base_uri": "https://localhost:8080/"
115 | },
116 | "id": "iN58X9hKgTvN",
117 | "outputId": "f4000c18-6727-41a0-a3e5-32d7cc7b16ec"
118 | },
119 | "source": [
120 | "%cd GPT2-Telegram-Chatbot"
121 | ],
122 | "execution_count": 4,
123 | "outputs": [
124 | {
125 | "output_type": "stream",
126 | "text": [
127 | "/content/gdrive/My Drive/Colab Notebooks/GPT2-Telegram-Chatbot\n"
128 | ],
129 | "name": "stdout"
130 | }
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "metadata": {
136 | "id": "IH4Yccl0gaNw"
137 | },
138 | "source": [
139 | "! git pull"
140 | ],
141 | "execution_count": null,
142 | "outputs": []
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {
147 | "id": "pBQj9tJJgpnN"
148 | },
149 | "source": [
150 | "Change to 774M model, Set your bot token below.\n",
151 | "Make sure to change runtime to GPU/TPU in google collab."
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "metadata": {
157 | "id": "BB0SqY7xgxDl"
158 | },
159 | "source": [
160 | "!sed -i -e 's/BOTKEY/1887036376:AAF_gJdkt_2z44ZdoLLeiumQvN-9ihYRUBQ/' src/GPT2-Learning.py\n",
161 | "!sed -i -e 's/774M/1558M/' src/GPT2-Learning.py"
162 | ],
163 | "execution_count": 7,
164 | "outputs": []
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "id": "o6bsyPRkg_EZ"
170 | },
171 | "source": [
172 | "Install Requirements.. You might need to do this a few times."
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "metadata": {
178 | "colab": {
179 | "base_uri": "https://localhost:8080/",
180 | "height": 1000
181 | },
182 | "id": "KibtYZZ5hBAk",
183 | "outputId": "90b2073d-32bd-49e5-f3d7-bd9a74a20704"
184 | },
185 | "source": [
186 | "!pip3 install tqdm\n",
187 | "!pip3 install regex\n",
188 | "!pip3 install fire\n",
189 | "!pip3 install python-telegram-bot==12.0.0\n",
190 | "!pip3 install requests\n",
191 | "!pip3 install tensorflow-gpu==1.15.5"
192 | ],
193 | "execution_count": 8,
194 | "outputs": [
195 | {
196 | "output_type": "stream",
197 | "text": [
198 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n",
199 | "Requirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20)\n",
200 | "Collecting fire\n",
201 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/11/07/a119a1aa04d37bc819940d95ed7e135a7dcca1c098123a3764a6dcace9e7/fire-0.4.0.tar.gz (87kB)\n",
202 | "\u001b[K |████████████████████████████████| 92kB 4.2MB/s \n",
203 | "\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from fire) (1.15.0)\n",
204 | "Requirement already satisfied: termcolor in /usr/local/lib/python3.7/dist-packages (from fire) (1.1.0)\n",
205 | "Building wheels for collected packages: fire\n",
206 | " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
207 | " Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115928 sha256=bbcd80493466e69a169b7e9329d46ade1766ba745010bccf7c3c80e6d42e0de9\n",
208 | " Stored in directory: /root/.cache/pip/wheels/af/19/30/1ea0cad502dcb4e66ed5a690279628c827aea38bbbab75d5ed\n",
209 | "Successfully built fire\n",
210 | "Installing collected packages: fire\n",
211 | "Successfully installed fire-0.4.0\n",
212 | "Collecting python-telegram-bot==12.0.0\n",
213 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/25/06/5047b87e9ec3ffd9af6f83069803921f2f02c9411753d610cc569c8e4638/python_telegram_bot-12.0.0-py2.py3-none-any.whl (346kB)\n",
214 | "\u001b[K |████████████████████████████████| 348kB 6.9MB/s \n",
215 | "\u001b[?25hRequirement already satisfied: certifi in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (2020.12.5)\n",
216 | "Requirement already satisfied: tornado>=5.1 in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (5.1.1)\n",
217 | "Collecting cryptography\n",
218 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n",
219 | "\u001b[K |████████████████████████████████| 3.2MB 9.3MB/s \n",
220 | "\u001b[?25hRequirement already satisfied: future>=0.16.0 in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (0.16.0)\n",
221 | "Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.7/dist-packages (from cryptography->python-telegram-bot==12.0.0) (1.14.5)\n",
222 | "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.12->cryptography->python-telegram-bot==12.0.0) (2.20)\n",
223 | "Installing collected packages: cryptography, python-telegram-bot\n",
224 | "Successfully installed cryptography-3.4.7 python-telegram-bot-12.0.0\n",
225 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (2.23.0)\n",
226 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests) (3.0.4)\n",
227 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests) (1.24.3)\n",
228 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests) (2.10)\n",
229 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests) (2020.12.5)\n",
230 | "Collecting tensorflow-gpu==1.15.5\n",
231 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/73/b5/adc281ce4e631251c749d342793795832026edf9035df81c3813ef33fad2/tensorflow_gpu-1.15.5-cp37-cp37m-manylinux2010_x86_64.whl (411.0MB)\n",
232 | "\u001b[K |████████████████████████████████| 411.0MB 40kB/s \n",
233 | "\u001b[?25hRequirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.2.0)\n",
234 | "Collecting gast==0.2.2\n",
235 | " Downloading https://files.pythonhosted.org/packages/4e/35/11749bf99b2d4e3cceb4d55ca22590b0d7c2c62b9de38ac4a4a7f4687421/gast-0.2.2.tar.gz\n",
236 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.15.0)\n",
237 | "Collecting keras-applications>=1.0.8\n",
238 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications-1.0.8-py3-none-any.whl (50kB)\n",
239 | "\u001b[K |████████████████████████████████| 51kB 6.1MB/s \n",
240 | "\u001b[?25hRequirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.12.1)\n",
241 | "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (3.12.4)\n",
242 | "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.1.2)\n",
243 | "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (3.3.0)\n",
244 | "Collecting tensorboard<1.16.0,>=1.15.0\n",
245 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1e/e9/d3d747a97f7188f48aa5eda486907f3b345cd409f0a0850468ba867db246/tensorboard-1.15.0-py3-none-any.whl (3.8MB)\n",
246 | "\u001b[K |████████████████████████████████| 3.8MB 32.2MB/s \n",
247 | "\u001b[?25hRequirement already satisfied: h5py<=2.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (2.10.0)\n",
248 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.36.2)\n",
249 | "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.32.0)\n",
250 | "Collecting tensorflow-estimator==1.15.1\n",
251 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/de/62/2ee9cd74c9fa2fa450877847ba560b260f5d0fb70ee0595203082dafcc9d/tensorflow_estimator-1.15.1-py2.py3-none-any.whl (503kB)\n",
252 | "\u001b[K |████████████████████████████████| 512kB 53.4MB/s \n",
253 | "\u001b[?25hRequirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.8.1)\n",
254 | "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.12.0)\n",
255 | "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.1.0)\n",
256 | "Collecting numpy<1.19.0,>=1.16.0\n",
257 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d6/c6/58e517e8b1fb192725cfa23c01c2e60e4e6699314ee9684a1c5f5c9b27e1/numpy-1.18.5-cp37-cp37m-manylinux1_x86_64.whl (20.1MB)\n",
258 | "\u001b[K |████████████████████████████████| 20.1MB 1.4MB/s \n",
259 | "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from protobuf>=3.6.1->tensorflow-gpu==1.15.5) (56.1.0)\n",
260 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (2.0.0)\n",
261 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.3.4)\n",
262 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (4.0.1)\n",
263 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.4.1)\n",
264 | "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.7.4.3)\n",
265 | "Building wheels for collected packages: gast\n",
266 | " Building wheel for gast (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
267 | " Created wheel for gast: filename=gast-0.2.2-cp37-none-any.whl size=7540 sha256=3d3a4bf0e39a6e2a93a59617b12913ca208bd56cfd5a81827bc5c088c22ce407\n",
268 | " Stored in directory: /root/.cache/pip/wheels/5c/2e/7e/a1d4d4fcebe6c381f378ce7743a3ced3699feb89bcfbdadadd\n",
269 | "Successfully built gast\n",
270 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n",
271 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.18.5 which is incompatible.\u001b[0m\n",
272 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement tensorboard~=2.4, but you'll have tensorboard 1.15.0 which is incompatible.\u001b[0m\n",
273 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement tensorflow-estimator<2.5.0,>=2.4.0, but you'll have tensorflow-estimator 1.15.1 which is incompatible.\u001b[0m\n",
274 | "\u001b[31mERROR: tensorflow-probability 0.12.1 has requirement gast>=0.3.2, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n",
275 | "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n",
276 | "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
277 | "Installing collected packages: gast, numpy, keras-applications, tensorboard, tensorflow-estimator, tensorflow-gpu\n",
278 | " Found existing installation: gast 0.3.3\n",
279 | " Uninstalling gast-0.3.3:\n",
280 | " Successfully uninstalled gast-0.3.3\n",
281 | " Found existing installation: numpy 1.19.5\n",
282 | " Uninstalling numpy-1.19.5:\n",
283 | " Successfully uninstalled numpy-1.19.5\n",
284 | " Found existing installation: tensorboard 2.4.1\n",
285 | " Uninstalling tensorboard-2.4.1:\n",
286 | " Successfully uninstalled tensorboard-2.4.1\n",
287 | " Found existing installation: tensorflow-estimator 2.4.0\n",
288 | " Uninstalling tensorflow-estimator-2.4.0:\n",
289 | " Successfully uninstalled tensorflow-estimator-2.4.0\n",
290 | "Successfully installed gast-0.2.2 keras-applications-1.0.8 numpy-1.18.5 tensorboard-1.15.0 tensorflow-estimator-1.15.1 tensorflow-gpu-1.15.5\n"
291 | ],
292 | "name": "stdout"
293 | },
294 | {
295 | "output_type": "display_data",
296 | "data": {
297 | "application/vnd.colab-display-data+json": {
298 | "pip_warning": {
299 | "packages": [
300 | "numpy"
301 | ]
302 | }
303 | }
304 | },
305 | "metadata": {
306 | "tags": []
307 | }
308 | }
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {
314 | "id": "Lit4usgih93R"
315 | },
316 | "source": [
317 | "After requirements installed, reconnect to google drive after restarting runtime and setting runtime to TPU under \"Runtime -> Change Runtime Type\" tab"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "metadata": {
323 | "colab": {
324 | "base_uri": "https://localhost:8080/"
325 | },
326 | "id": "7QanwV8MiDsh",
327 | "outputId": "f2d66701-b99c-4b6c-e8b5-b7586169d80b"
328 | },
329 | "source": [
330 | "from google.colab import drive\n",
331 | "drive.mount('/content/gdrive')\n",
332 | "%cd /content/gdrive/MyDrive/Colab Notebooks/GPT2-Telegram-Chatbot"
333 | ],
334 | "execution_count": 1,
335 | "outputs": [
336 | {
337 | "output_type": "stream",
338 | "text": [
339 | "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n",
340 | "/content/gdrive/MyDrive/Colab Notebooks/GPT2-Telegram-Chatbot\n"
341 | ],
342 | "name": "stdout"
343 | }
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {
349 | "id": "xDZj822bieW7"
350 | },
351 | "source": [
352 | "Download the model"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "metadata": {
358 | "colab": {
359 | "base_uri": "https://localhost:8080/"
360 | },
361 | "id": "EY6miAy9igEs",
362 | "outputId": "b2bd1c49-d4e6-4974-ce82-ee2663da0cdc"
363 | },
364 | "source": [
365 | "!python3 download_model.py 1558M"
366 | ],
367 | "execution_count": null,
368 | "outputs": [
369 | {
370 | "output_type": "stream",
371 | "text": [
372 | "\rFetching checkpoint: 0%| | 0.00/77.0 [00:00, ?it/s]\rFetching checkpoint: 1.00kit [00:00, 1.05Mit/s] \n",
373 | "Fetching encoder.json: 1.04Mit [00:00, 8.09Mit/s] \n",
374 | "Fetching hparams.json: 1.00kit [00:00, 1.07Mit/s] \n",
375 | "Fetching model.ckpt.data-00000-of-00001: 6.23Git [02:42, 38.3Mit/s] \n",
376 | "Fetching model.ckpt.index: 21.0kit [00:00, 4.20Mit/s] \n",
377 | "Fetching model.ckpt.meta: 1.84Mit [00:00, 12.2Mit/s] \n",
378 | "Fetching vocab.bpe: 457kit [00:00, 4.45Mit/s] \n"
379 | ],
380 | "name": "stdout"
381 | }
382 | ]
383 | },
384 | {
385 | "cell_type": "markdown",
386 | "metadata": {
387 | "id": "mXT-Vh_jkDE5"
388 | },
389 | "source": [
390 | "Run the bot, if any errors appear try re-installing requirements."
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "metadata": {
396 | "id": "jcizkIQAkEEQ"
397 | },
398 | "source": [
399 | "!chmod +x ./start.sh"
400 | ],
401 | "execution_count": 2,
402 | "outputs": []
403 | },
404 | {
405 | "cell_type": "code",
406 | "metadata": {
407 | "colab": {
408 | "base_uri": "https://localhost:8080/"
409 | },
410 | "id": "MZLiNrOPlQcx",
411 | "outputId": "5154346c-b604-48bb-e627-4efb4946bcc7"
412 | },
413 | "source": [
414 | "!./start.sh"
415 | ],
416 | "execution_count": null,
417 | "outputs": [
418 | {
419 | "output_type": "stream",
420 | "text": [
421 | "src/GPT2-Learning.py:545: TelegramDeprecationWarning: Old Handler API is deprecated - see https://git.io/fxJuV for details\n",
422 | " updater = Updater(\"1887036376:AAF_gJdkt_2z44ZdoLLeiumQvN-9ihYRUBQ\", use_context=False)\n"
423 | ],
424 | "name": "stdout"
425 | }
426 | ]
427 | }
428 | ]
429 | }
430 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## GPT2-Telegram-Chatbot
2 |
3 | A GPT-2 Telegram chatbot that's been relatively tuned for chatting. Feel free to make me PRs and I'll check out your code! The bot isn't 100% accurate all the time, which is why I coded in a /retry function.
4 |
5 | Since the bot consumes so much memory, it runs in a round-robin sort of mode. Each input resets a timer on your account ID; once the timer runs down, the bot is free for other users. You will be notified when the timer runs down, and other users can see how much time is left and whether the bot is in use.
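
A minimal sketch of that round-robin lock, simplified from the `wait()` handler in src/GPT2-Learning.py (the names here are illustrative, not the script's):

```
import threading
import time

TIMEOUT = 1500   # seconds a user "owns" the bot; matches the script's default
owner = None     # Telegram user ID currently holding the bot
remaining = 0    # seconds left on the current session

def try_acquire(user_id):
    """Claim the bot for user_id, or refuse while someone else holds it."""
    global owner, remaining
    if owner is None or owner == user_id:
        owner = user_id
        remaining = TIMEOUT   # every input resets the timer
        return True
    return False              # busy; callers report `remaining` seconds left

def countdown():
    """Background thread: frees the bot when the timer runs out."""
    global owner, remaining
    while True:
        time.sleep(1)
        if owner is not None:
            remaining -= 1
            if remaining <= 0:
                owner = None  # bot is free for the next user

threading.Thread(target=countdown, daemon=True).start()
```

The real script also resets the bot's mode and learned conversation when the timer expires.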
6 |
7 | ### Installation/How to use:
8 |
9 | Brief install instructions on Ubuntu 20/WSL.
10 |
11 | I highly recommend looking at the Jupyter notebook (GPT-2-BOT.ipynb) on Google Colab instead.
12 |
13 | Install Python 3.7 (I think 3.6 might work as well, but not 3.8):
14 |
15 | ```
16 | sudo add-apt-repository ppa:deadsnakes/ppa
17 | sudo apt-get update
18 | sudo apt-get install python3.7
19 | ```
20 |
21 | Install pip for Python 3.7:
22 |
23 | ```
24 | wget https://bootstrap.pypa.io/get-pip.py
25 | python3.7 get-pip.py
26 | ```
27 |
28 | Install the requirements inside the bot folder after cloning the repository:
29 | ```
30 | python3.7 -m pip install -r requirements.txt
31 | ```
32 |
33 | Note: You realistically need 16GB of RAM or an 8GB video card; otherwise you will wait forever.
34 | You can use the GPU with at least an 8GB video card that supports CUDA Toolkit 10.0 and cuDNN for CUDA Toolkit 10. This install also works on Windows with Python 3.7 and NVIDIA hardware; you must run the command prompt as admin when running Python 3.7 on Windows.
35 |
36 | Download the model:
37 | ```
38 | python3.7 download_model.py 1558M
39 | ```
40 |
41 | Set your Telegram bot API key in src/GPT2-Learning.py:
42 | ```
43 | Replace "BOTKEY" with your Telegram bot token, e.g. "1827396499:AAHifc06oS31oQ9L3TuCiZxD9EIfKPi0oWQ"
44 | ```
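
If you prefer a one-liner, this is what the Colab notebook does (substitute your own token):

```
sed -i -e 's/BOTKEY/1827396499:AAHifc06oS31oQ9L3TuCiZxD9EIfKPi0oWQ/' src/GPT2-Learning.py
```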
45 |
46 | Run the bot:
47 |
48 | If using the python3 command:
49 | ```
50 | ./start.sh
51 | ```
52 |
53 | If using the python3.7 command:
54 | ```
55 | python3.7 src/GPT2-Learning.py
56 | ```
57 |
58 | You can edit start.sh to match your Python command as you wish.
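
start.sh itself isn't reproduced in this listing; judging from the commands above, it is just a thin wrapper, something like:

```
#!/bin/bash
python3 src/GPT2-Learning.py
```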
59 |
60 | ### For quick command reference:
61 |
62 | ```
63 | start - Basic start command.
64 | help - Show quick help command.
65 | chatbot - Chatbot mode, me+you string edit and open-ended questions.
66 | finish - Default finish sentence command counting words as output length.
67 | learnon - Turn on sentence compilation with Me+you string for conversation.
68 | learnoff - Turn off the sentence compilation.
69 | learnreset - Reset the learned conversation.
70 | retry - Retry the current computation.
71 | ```
72 | Here's a sample in learning chatbot mode with the latest GPT-2 model and top_p set to 0.83:
73 |
74 | ```
75 | Me: Hello! How are you today? Are you feeling alright? I hope you are doing well and not getting into any trouble.
76 | Bot: Yeah, I'm fine. Do you want to make something to eat?
77 | Me: I already made some food, but thanks! I had some jalapeno poppers. I made a ton of them last night after I sold a lawnmower to some folks I got for free earlier in the day. So what are your plans for the day?
78 | Bot: Nothing special. Nothing exciting in particular. I just wanna chill out for a bit.
79 | Me: Just a typical day of kicking back and relaxing huh? Are you gonna watch cartoons or any movies in particular?
80 | Bot: If it's not sports I'll probably just watch something on the TV.
81 | Me: Sounds alright. I coded on my chatbot project most of the morning. Was waiting on someone to come buy a projector but they haven't shown up yet unfortunately. What do you do for work if you don't mind me asking?
82 | Bot: I'm a web developer and engineer. It's fun.
83 | Me: Cool! I'm also a coder and an ethical hacker. I mostly know how to code in Python but I've been unemployed for some time now after my spine injury. I wish I could have joined the military when I was younger. Were you ever in the military?
84 | Bot: No. I'm not really too proud of my military career. It ended up taking a huge toll on my back.
85 | Me: What happened to your back if you don't mind me asking?
86 | Bot: First of all, it's better than before.
87 | Me: Well that's good to hear. Did you injure your back in the military or doing something else?
88 | Bot: Nope. It's an old injury.
89 | Me: Ah I see- Well I'll talk to you later. Sorry about your back! I hope we can catch up tomorrow!
90 | Bot: Yes, of course. Thanks again for chatting!
91 | ```
92 |
93 | ### Example of learning mode backend:
94 | ```
95 | INFO:tensorflow:Restoring parameters from models/1558M/model.ckpt
96 | 2020-05-10 14:19:43,794 - tensorflow - INFO - Restoring parameters from models/1558M/model.ckpt
97 | ==========
98 | Before splitlines: Nothing special. Nothing exciting in particular. I just wanna chill out for a bit. You: I'll try to keep my spirits up. Just keep in mind that today's the 10th anniversary of the citywide strike
99 | ==========
100 | ==========
101 | Mode: True
102 | Learn: True
103 | Length: 44
104 | ==========
105 | Before regex: Nothing special. Nothing exciting in particular. I just wanna chill out for a bit. You: I'll try to keep my spirits up. Just keep in mind that today's the 10th anniversary of the citywide strike
106 | ==========
107 | Output: Nothing special. Nothing exciting in particular. I just wanna chill out for a bit.
108 | ==========
109 | Raw_text or Original: You: Hello! How are you today? Are you feeling alright? I hope you are doing well and not getting into any trouble. Me: Yeah, I'm fine. Do you want to make something to eat? You: I already made some food, but thanks! I had some jalapeno poppers. I made a ton of them last night after I sold a lawnmower to some folks I got for free earlier in the day. So what are your plans for the day? Me:
110 | ==========
111 | Learning text or Next: You: Hello! How are you today? Are you feeling alright? I hope you are doing well and not getting into any trouble. Me: Yeah, I'm fine. Do you want to make something to eat? You: I already made some food, but thanks! I had some jalapeno poppers. I made a ton of them last night after I sold a lawnmower to some folks I got for free earlier in the day. So what are your plans for the day? Me: Nothing special. Nothing exciting in particular. I just wanna chill out for a bit.
112 | ==========
113 | top_p out: 0.8338636363636364
114 | ==========
115 | top_p in: 0.83
116 | ==========
117 | ```
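
The `top_p out` value above is derived from the base top_p and the input word count. A condensed sketch of that logic from `interact_model` in src/GPT2-Learning.py (the helper name is mine, not the script's):

```
def scaled_top_p(base, word_count, mx=0.00375):
    """top_p grows with the prompt's word count, clipped to [0.005, 0.84]."""
    return min(max(base + word_count * mx, 0.005), 0.84)

# e.g. scaled_top_p(0.66, 20) -> 0.735
```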
118 |
119 | For a list of Grammarly scores, please see [/SCORES.md](/SCORES.md).
120 |
121 | Tips:
122 |
123 | 0.77 top_p can sound emotional, confused, and copycat-ish.
124 |
125 | 0.66 top_p can sound thought-out and literal, but can have ASCII and cut-off errors.
126 |
--------------------------------------------------------------------------------
/SCORES.md:
--------------------------------------------------------------------------------
1 | Here's a list of Grammarly scores:
2 |
3 | Top-P | Score
4 | ------------- | -------------
5 | 0.59 | 68
6 | 0.60 | 82
7 | 0.61 | 97
8 | 0.62 | 99
9 | 0.63 | 97
10 | 0.64 | 81
11 | 0.65 | 97
12 | 0.66 | 99
13 | 0.67 | 96
14 | 0.68 | 99
15 | 0.69 | 99
16 | 0.70 | 99
17 | 0.71 | 96
18 | 0.72 | 99
19 | 0.73 | 99
20 | 0.74 | 99
21 | 0.75 | 86
22 | 0.76 | 85
23 | 0.77 | 93
24 | 0.78 | 99
25 | 0.79 | 97
26 | 0.80 | 96
27 | 0.81 | 93
28 | 0.82 | 92
29 | 0.83 | 91
30 | 0.84 | 96
31 | 0.85 | 76
32 | 0.86 | 87
33 | 0.87 | 70
34 | 0.88 | 86
35 | 0.89 | 87
36 | 0.90 | 79
37 |
38 |
39 | Here is a list of length scores run through Grammarly at 0.73 top_p:
40 | (Somewhere along here I started playing with the context words: in the code, it now adds "Response:" and "Reply:" before anything computed, where it previously used "Me:" and "You:".)
41 |
42 | Sentence Length | Score
43 | ------------- | -------------
44 | Length 25 | 77
45 | Length 24 | 80
46 | Length 23 | 72
47 | Length 22 | 71
48 | Length 21 | 86
49 | Length 20 | 88
50 | Length 19 | 76
51 | Length 18 | 68
52 | Length 17 | 61
53 | Length 16 | 53
54 | Length 15 | 87
55 | Length 14 | 83
56 | Length 13 | 94
57 | Length 12 | 55
58 | Length 11 | 79
59 | Length 10 | 99
60 | Length 9 | 92
61 | Length 8 | 85
62 | Length 7 | 83
63 |
64 | I'm unsure of how useful this would be. I think the input text was 5 words.
65 |
--------------------------------------------------------------------------------
/download_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import requests
4 | from tqdm import tqdm
5 |
6 | if len(sys.argv) != 2:
7 | print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
8 | sys.exit(1)
9 |
10 | model = sys.argv[1]
11 |
12 | subdir = os.path.join('models', model)
13 | if not os.path.exists(subdir):
14 | os.makedirs(subdir)
15 | subdir = subdir.replace('\\','/') # needed for Windows
16 |
17 | for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
18 |
19 | r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True)
20 |
21 | with open(os.path.join(subdir, filename), 'wb') as f:
22 | file_size = int(r.headers["content-length"])
23 | chunk_size = 1000
24 | with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
25 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
26 | for chunk in r.iter_content(chunk_size=chunk_size):
27 | f.write(chunk)
28 | pbar.update(len(chunk))  # the final chunk may be shorter than chunk_size
29 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | regex
3 | python-telegram-bot==12.0.0
4 | requests
5 | tensorflow-gpu==1.15.5
6 |
--------------------------------------------------------------------------------
/src/GPT2-Learning.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | from telegram.ext import Updater, CommandHandler, MessageHandler, Filters
4 | import json, os, string, sys, threading, random, model, sample, encoder, logging, time
5 | import numpy as np
6 | import tensorflow as tf
7 | import re
8 | import os
9 |
10 | # Enable console logging
11 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
12 | level=logging.INFO)
13 | logger = logging.getLogger(__name__)
14 |
15 | # Console output debug prints
16 | debug = True
17 |
18 | # Session timeout
19 | timeout = 1500
20 |
21 | # top_p (refer to gpt-2 documentation)
22 | top = 0.66
23 |
24 | # Temperature (refer to gpt-2 documentation)
25 | degree = 1
26 |
27 | # Top_p multiplier - add to top_p per word
28 | # 0.00375 - may be shorter
29 | # 0.00400
30 | # 0.00425
31 | # 0.00450
32 | # 0.00475
33 | # 0.00500 - may be longer
34 | mx = 0.00375
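# Effective top_p per request works out to top + (input word count * mx),
# clipped to [0.005, 0.84] (see interact_model below).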
35 |
36 | # Top_K unused here, might be useful eventually.
37 | tok = 0
38 |
39 | # This is the start of the learning cache, could be useful eventually.
40 | learning = ""
41 |
42 | # End settings
43 | mode = True
44 | learn = True
45 | user = ""
46 | cache = ""
47 | running = False; tim = timeout  # tim is the shared session countdown the handlers report
48 | temps = str(degree)
49 | tpstring = str(top)
50 |
51 | # Define a few command handlers. These usually take the two arguments bot and
52 | # update. Error handlers also receive the raised TelegramError object in error.
53 |
54 | def start(bot, update):
55 | """Send a message when the command /start is issued."""
56 | global running
57 | global mode
58 | global learn
59 | global user
60 | global tim
61 | global learning
62 | global cache
63 | if user == "":
64 | user = update.message.from_user.id
65 | mode = True
66 | learn = True
67 | learning = ""
68 | cache = ""
69 | if mode == True and learn == True:
70 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
71 | if mode == True and learn == False:
72 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
73 | if mode == False:
74 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
75 | return
76 | if user == update.message.from_user.id:
77 | mode = True
78 | learn = True
79 | learning = ""
80 | cache = ""
81 | if mode == True and learn == True:
82 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
83 | if mode == True and learn == False:
84 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
85 | if mode == False:
86 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
87 | return
88 | else:
89 | left = str(tim)
90 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
91 |
92 | def help(bot, update):
93 | """Send a message when the command /help is issued."""
94 | update.message.reply_text('Just type a message... It could be lagged out. /chatbot goes into Me: You: mode. /finish just finishes the text. /learnon is for conversation learning mode.')
95 |
96 | def chatbot(bot, update):
97 | """Send a message when the command /chatbot is issued."""
98 | global running
99 | global mode
100 | global learn
101 | global user
102 | global tim
103 | global learning
104 | global cache
105 | if user == "":
106 | user = update.message.from_user.id
107 | mode = True
108 | learn = False
109 | learning = ""
110 | cache = ""
111 | if mode == True and learn == True:
112 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
113 | if mode == True and learn == False:
114 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
115 | if mode == False:
116 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
117 | return
118 | if user == update.message.from_user.id:
119 | mode = True
120 | learn = False
121 | learning = ""
122 | cache = ""
123 | if mode == True and learn == True:
124 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
125 | if mode == True and learn == False:
126 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
127 | if mode == False:
128 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
129 | return
130 | else:
131 | left = str(tim)
132 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
133 |
134 | def finish(bot, update):
135 | """Send a message when the command /finish is issued."""
136 | global running
137 | global mode
138 | global learn
139 | global user
140 | global tim
141 | global learning
142 | global cache
143 | if user == "":
144 | user = update.message.from_user.id
145 | mode = False
146 | learn = False
147 | learning = ""
148 | cache = ""
149 | if mode == True and learn == True:
150 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
151 | if mode == True and learn == False:
152 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
153 | if mode == False:
154 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
155 | return
156 | if user == update.message.from_user.id:
157 | mode = False
158 | learn = False
159 | learning = ""
160 | cache = ""
161 | if mode == True and learn == True:
162 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
163 | if mode == True and learn == False:
164 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
165 | if mode == False:
166 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
167 | return
168 | else:
169 | left = str(tim)
170 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
171 |
172 | def learnon(bot, update):
173 | """Send a message when the command /learnon is issued."""
174 | global running
175 | global mode
176 | global learn
177 | global user
178 | global tim
179 | global learning
180 | global cache
181 | if user == "":
182 | user = update.message.from_user.id
183 | mode = True
184 | learn = True
185 | learning = ""
186 | cache = ""
187 | if mode == True and learn == True:
188 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
189 | if mode == True and learn == False:
190 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
191 | if mode == False:
192 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
193 | return
194 | if user == update.message.from_user.id:
195 | mode = True
196 | learn = True
197 | learning = ""
198 | cache = ""
199 | if mode == True and learn == True:
200 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
201 | if mode == True and learn == False:
202 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
203 | if mode == False:
204 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
205 | return
206 | else:
207 | left = str(tim)
208 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
209 |
210 | def learnoff(bot, update):
211 | """Send a message when the command /learnoff is issued."""
212 | global running
213 | global mode
214 | global learn
215 | global user
216 | global tim
217 | global learning
218 | global cache
219 | if user == "":
220 | user = update.message.from_user.id
221 | mode = True
222 | learn = False
223 | learning = ""
224 | cache = ""
225 | if mode == True and learn == True:
226 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
227 | if mode == True and learn == False:
228 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
229 | if mode == False:
230 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
231 | return
232 | if user == update.message.from_user.id:
233 | mode = True
234 | learn = False
235 | learning = ""
236 | cache = ""
237 | if mode == True and learn == True:
238 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
239 | if mode == True and learn == False:
240 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
241 | if mode == False:
242 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
243 | return
244 | else:
245 | left = str(tim)
246 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
247 |
248 | def learnreset(bot, update):
249 | """Send a message when the command /learnreset is issued."""
250 | global running
251 | global mode
252 | global learn
253 | global user
254 | global tim
255 | global learning
256 | global cache
257 | if user == "":
258 | user = update.message.from_user.id
259 | mode = True
260 | learn = True
261 | learning = ""
262 | cache = ""
263 | if mode == True and learn == True:
264 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
265 | if mode == True and learn == False:
266 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
267 | if mode == False:
268 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
269 | return
270 | if user == update.message.from_user.id:
271 | mode = True
272 | learn = True
273 | learning = ""
274 | cache = ""
275 | if mode == True and learn == True:
276 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M. I am in the learning chatbot mode.')
277 | if mode == True and learn == False:
278 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the chatbot mode.')
279 | if mode == False:
280 | update.message.reply_text('Send a message! Get it computed! 1558M Settings: Logic: ' + tpstring + ' Rate:' + temps + ' GPT-2 1558M I am in the finishsentence mode.')
281 | return
282 | else:
283 | left = str(tim)
284 | update.message.reply_text('Bot is currently in use, make sure to set your settings when their timer runs down. ' + left + ' seconds.')
285 |
286 | def regex(mew):
287 | meow = mew
288 | if "You:" in meow:
289 | meow = meow[0:meow.find('You:')]
290 | if "Me:" in meow:
291 | meow = meow[0:meow.find('Me:')]
292 | return meow
293 | if "Me:" in meow:
294 | meow = meow[0:meow.find('Me:')]
295 | if "You:" in meow:
296 | meow = meow[0:meow.find('You:')]
297 | return meow
298 | if "?" in meow:
299 | meow = meow[0:meow.find('?')]
300 | meow = meow + "?"
301 | return meow
302 | if "!" in meow:
303 | meow = meow.rsplit('!', 1)[0]
304 | meow = meow + "!"
305 | return meow
306 | else:
307 | meow = meow.rsplit('.', 1)[0]
308 | meow = meow + "."
309 | return meow
310 | meow = "Error."
311 | return meow
312 |
313 |
314 | def retry(bot, update):
315 | retr = True
316 | new = retr
317 | comput = threading.Thread(target=wait, args=(bot, update, new,))
318 | comput.start()
319 |
320 | def runn(bot, update):
321 | retr = False
322 | new = retr
323 | comput = threading.Thread(target=wait, args=(bot, update, new,))
324 | comput.start()
325 |
326 | def wait(bot, update, new):
327 | global tim
328 | global user
329 | global running
330 | global mode
331 | global learn
332 | global learning
333 | global cache
334 | if user == "":
335 | user = update.message.from_user.id
336 | if user == update.message.from_user.id:
337 | user = update.message.from_user.id
338 | tim = timeout  # count down on the shared global so the handlers can report time left
339 | compute = threading.Thread(target=interact_model, args=(bot, update, new,))
340 | compute.start()
341 | if running == False:
342 | while tim > 1:
343 | running = True
344 | time.sleep(1)
345 | tim = tim - 1
346 | if running == True:
347 | mode = False
348 | learn = False
349 | learning = ""
350 | cache = ""
351 | user = ""
352 | update.message.reply_text('Timer has run down, bot has been reset into the default mode.')
353 | running = False
354 | else:
355 | left = str(tim)
356 | update.message.reply_text('Bot is in use, current cooldown is: ' + left + ' seconds.')
357 |
358 | def interact_model(bot, update, new):
359 | model_name = '1558M'
360 | seed = random.randint(1431655765, 2863311530)
361 | nsamples = 1
362 | batch_size = 1
363 | top_k = tok
364 | topp = top
365 | models_dir = 'models'
366 | tex = str(update.message.text)
367 | global learning
368 | global learn
369 | global mode
370 | global cache
371 | #############################################
372 | # This does some basic length processing.
373 | if mode == True:
374 | tlen = len(tex.split())
375 | if tlen > 300:
376 | update.message.reply_text('Input text is too long.')
377 | return
378 | if new == True and cache:
379 | m = re.search('.* You: ', cache)
380 | raw_text = m.group(0)
381 | tlensp = len(raw_text.split())
382 | tlen = tlensp - 2
383 | length = tlen
384 | if tlen < 20:
385 | length = 20
386 | if tlen > 20:
387 | length = 20
388 | if tlen > 30:
389 | length = 40
390 | if tlen > 50:
391 | length = 60
392 | if debug == True:
393 | print("Cache is...")
394 | print(raw_text)
395 | if new != True:
396 | texm = 'Me: ' + tex
397 | initial = texm + ' You: '
398 | raw_text = learning + initial
399 | length = tlen
400 | if tlen < 20:
401 | length = 20
402 | if tlen > 20:
403 | length = 20
404 | if tlen > 30:
405 | length = 40
406 | if tlen > 50:
407 | length = 60
408 | cache = raw_text
409 | maxls = len(raw_text.split())
410 | if maxls > 300:
411 | while maxls > 300:
412 | if debug == True:
413 | print("Reducing memory of chat.")
414 | raw_text = raw_text.split(' Me:', 1)[-1]
415 | raw_text = "Me:" + raw_text
416 | maxls = len(raw_text.split())
417 | if maxls > 300:
418 | if debug == True:
419 | print("Reducing memory of chat.")
420 | raw_text = raw_text.split('You:', 1)[-1]
421 | raw_text = "You:" + raw_text
422 | maxls = len(raw_text.split())
423 | if debug == True:
424 | print("FINAL MEMORY REDUCTION:")
425 | print(raw_text)
426 | if mode == False:
427 | tlen = len(tex.split())  # was "penguin", which is undefined
428 | length = tlen
429 | if length > 300:
430 | update.message.reply_text('Input text is too long.')
431 | return
432 | if new != True:
433 | cache = tex
434 | if new == True and cache:
435 | tex = cache
436 | length = len(tex.split())
437 | tlen = length
438 | if debug == True:
439 | print("Cache is...")
440 | print(cache)  # was "penguin", which is undefined
441 | raw_text = tex
442 | toppf = float(topp)
443 | lengthm = float(tlen)
444 | multf = float(mx)
445 | lxm = float(lengthm * multf)
446 | top_p = lxm + toppf
447 | # The max here is 0.84 and minimum 0.005
448 | if top_p > 0.84:
449 | top_p = 0.84
450 | if top_p < 0.005:
451 | top_p = 0.005
452 | #############################################
453 | update.message.reply_text('Computing...')
454 | models_dir = os.path.expanduser(os.path.expandvars(models_dir))
455 | if batch_size is None:
456 | batch_size = 1
457 | assert nsamples % batch_size == 0
458 | enc = encoder.get_encoder(model_name, models_dir)
459 | hparams = model.default_hparams()
460 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
461 | hparams.override_from_dict(json.load(f))
462 | if length is None:
463 | length = hparams.n_ctx // 2
464 | elif length > hparams.n_ctx:
465 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
466 | with tf.Session(graph=tf.Graph()) as sess:
467 | context = tf.placeholder(tf.int32, [batch_size, None])
468 | np.random.seed(seed)
469 | tf.set_random_seed(seed)
470 | output = sample.sample_sequence(
471 | hparams=hparams, length=length,
472 | context=context,
473 | batch_size=batch_size,
474 | temperature=degree, top_k=top_k, top_p=top_p
475 | )
476 | saver = tf.train.Saver()
477 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
478 | saver.restore(sess, ckpt)
479 | context_tokens = enc.encode(raw_text)
480 | generated = 0
481 | for _ in range(nsamples // batch_size):
482 | out = sess.run(output, feed_dict={
483 | context: [context_tokens for _ in range(batch_size)]
484 | })[:, len(context_tokens):]
485 | for i in range(batch_size):
486 | generated += 1
487 | text = enc.decode(out[i])
488 | if debug == True:
489 | print("==========")
490 | print("Raw output: " + text)
491 | print("==========")
492 | if mode == True:
493 | splitted = text.splitlines()[0]
494 | else:
495 | splitted = text
496 | encodedstr = splitted.encode(encoding=sys.stdout.encoding,errors='ignore')
497 | decodedstr = encodedstr.decode("utf-8")
498 | final = str(decodedstr)
499 | # disable any regex on finishsentence mode.
500 | if mode == True:
501 | # Final regex
502 | sanitized = regex(final)
503 | finalsan = " ".join(re.split("[^a-zA-Z.,?!'*]+", sanitized))
504 |
505 | else:
506 | finalsan = final
507 | if learn == True:
508 | learning = raw_text + finalsan + " "
509 | update.message.reply_text(finalsan)
510 | if debug == True:
511 | modes = str(mode)
512 | print("Chatbot mode: " + modes)
513 | learns = str(learn)
514 | print("Learning mode: " + learns)
515 | lengths = str(length)
516 | print("Length: " + lengths)
517 | print("==========")
518 | splits = str(splitted)
519 | print("Before regex: " + splits)
520 | print("==========")
521 | print("Output: " + finalsan)
522 | print("==========")
523 | print("Raw_text or Original: " + raw_text)
524 | print("==========")
525 | print("Learning text or Next: " + learning)
526 | print("==========")
527 | tps = str(top_p)
528 | print("Final top_p: " + tps)
529 | print("==========")
530 | print("top_p in: " + tpstring)
531 | print("==========")
532 | sess.close()
533 |
534 | def error(bot, update, err):
535 | """Log Errors caused by Updates."""
536 | logger.warning('Update "%s" caused error "%s"', update, err)
537 |
538 | def main():
539 | """Start the bot."""
540 | # Create the Updater and pass it your bot's token.
541 | # Make sure to set use_context=True to use the new context based callbacks
542 | # Post version 12 this will no longer be necessary
543 | updater = Updater("BOTKEY", use_context=False)
544 | # Get the dispatcher to register handlers
545 | dp = updater.dispatcher
546 | # on different commands - answer in Telegram
547 | dp.add_handler(CommandHandler("start", start))
548 | dp.add_handler(CommandHandler("help", help))
549 | dp.add_handler(CommandHandler("chatbot", chatbot))
550 | dp.add_handler(CommandHandler("finish", finish))
551 | dp.add_handler(CommandHandler("learnon", learnon))
552 | dp.add_handler(CommandHandler("learnoff", learnoff))
553 | dp.add_handler(CommandHandler("learnreset", learnreset))
554 | dp.add_handler(CommandHandler("retry", retry))
555 | # on noncommand i.e message - echo the message on Telegram
556 | dp.add_handler(MessageHandler(Filters.text, runn))
557 | # log all errors
558 | dp.add_error_handler(error)
559 | # Start the Bot
560 | updater.start_polling()
561 | # Run the bot until you press Ctrl-C or the process receives SIGINT,
562 | # SIGTERM or SIGABRT. This should be used most of the time, since
563 | # start_polling() is non-blocking and will stop the bot gracefully.
564 | updater.idle()
565 |
566 | if __name__ == '__main__':
567 | main()
568 |
--------------------------------------------------------------------------------
/src/encoder.py:
--------------------------------------------------------------------------------
1 | """Byte pair encoding utilities"""
2 |
3 | import os
4 | import json
5 | import regex as re
6 | from functools import lru_cache
7 |
8 | @lru_cache()
9 | def bytes_to_unicode():
10 | """
11 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
12 | The reversible bpe codes work on unicode strings.
13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
15 | This is a significant percentage of your normal, say, 32K bpe vocab.
16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
17 | And avoids mapping to whitespace/control characters the bpe code barfs on.
18 | """
19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
20 | cs = bs[:]
21 | n = 0
22 | for b in range(2**8):
23 | if b not in bs:
24 | bs.append(b)
25 | cs.append(2**8+n)
26 | n += 1
27 | cs = [chr(n) for n in cs]
28 | return dict(zip(bs, cs))
29 |
30 | def get_pairs(word):
31 | """Return set of symbol pairs in a word.
32 |
33 | Word is represented as tuple of symbols (symbols being variable-length strings).
34 | """
35 | pairs = set()
36 | prev_char = word[0]
37 | for char in word[1:]:
38 | pairs.add((prev_char, char))
39 | prev_char = char
40 | return pairs
41 |
42 | class Encoder:
43 | def __init__(self, encoder, bpe_merges, errors='replace'):
44 | self.encoder = encoder
45 | self.decoder = {v:k for k,v in self.encoder.items()}
46 | self.errors = errors # how to handle errors in decoding
47 | self.byte_encoder = bytes_to_unicode()
48 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
49 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
50 | self.cache = {}
51 |
52 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
53 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
54 |
55 | def bpe(self, token):
56 | if token in self.cache:
57 | return self.cache[token]
58 | word = tuple(token)
59 | pairs = get_pairs(word)
60 |
61 | if not pairs:
62 | return token
63 |
64 | while True:
65 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
66 | if bigram not in self.bpe_ranks:
67 | break
68 | first, second = bigram
69 | new_word = []
70 | i = 0
71 | while i < len(word):
72 | try:
73 | j = word.index(first, i)
74 | new_word.extend(word[i:j])
75 | i = j
76 | except ValueError:  # "first" does not occur in the rest of the word
77 | new_word.extend(word[i:])
78 | break
79 |
80 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
81 | new_word.append(first+second)
82 | i += 2
83 | else:
84 | new_word.append(word[i])
85 | i += 1
86 | new_word = tuple(new_word)
87 | word = new_word
88 | if len(word) == 1:
89 | break
90 | else:
91 | pairs = get_pairs(word)
92 | word = ' '.join(word)
93 | self.cache[token] = word
94 | return word
95 |
96 | def encode(self, text):
97 | bpe_tokens = []
98 | for token in re.findall(self.pat, text):
99 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
100 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
101 | return bpe_tokens
102 |
103 | def decode(self, tokens):
104 | text = ''.join([self.decoder[token] for token in tokens])
105 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
106 | return text
107 |
108 | def get_encoder(model_name, models_dir):
109 | with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
110 | encoder = json.load(f)
111 | with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
112 | bpe_data = f.read()
113 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
114 | return Encoder(
115 | encoder=encoder,
116 | bpe_merges=bpe_merges,
117 | )
118 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib.training import HParams
4 |
5 | def default_hparams():
6 | return HParams(
7 | n_vocab=0,
8 | n_ctx=1024,
9 | n_embd=768,
10 | n_head=12,
11 | n_layer=12,
12 | )
13 |
14 | def shape_list(x):
15 | """Deal with dynamic shape in tensorflow cleanly."""
16 | static = x.shape.as_list()
17 | dynamic = tf.shape(x)
18 | return [dynamic[i] if s is None else s for i, s in enumerate(static)]
19 |
20 | def softmax(x, axis=-1):
21 | x = x - tf.reduce_max(x, axis=axis, keepdims=True)
22 | ex = tf.exp(x)
23 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
24 |
25 | def gelu(x):
26 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
27 |
28 | def norm(x, scope, *, axis=-1, epsilon=1e-5):
29 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
30 | with tf.variable_scope(scope):
31 | n_state = x.shape[-1].value
32 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
33 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
34 | u = tf.reduce_mean(x, axis=axis, keepdims=True)
35 | s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
36 | x = (x - u) * tf.rsqrt(s + epsilon)
37 | x = x*g + b
38 | return x
39 |
40 | def split_states(x, n):
41 | """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
42 | *start, m = shape_list(x)
43 | return tf.reshape(x, start + [n, m//n])
44 |
45 | def merge_states(x):
46 | """Smash the last two dimensions of x into a single dimension."""
47 | *start, a, b = shape_list(x)
48 | return tf.reshape(x, start + [a*b])
49 |
50 | def conv1d(x, scope, nf, *, w_init_stdev=0.02):
51 | with tf.variable_scope(scope):
52 | *start, nx = shape_list(x)
53 | w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
54 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
55 | c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
56 | return c
57 |
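Despite the name, conv1d is a position-wise dense layer: every [batch, sequence, nx] input is multiplied by a single [nx, nf] matrix. A NumPy shape check under the default hparams (nx=768, nf=3*768 as used for c_attn):

    import numpy as np

    x = np.zeros((2, 5, 768))            # [batch, sequence, features]
    w = np.zeros((768, 2304))            # the [1, nx, nf] kernel, squeezed
    c = np.reshape(np.reshape(x, (-1, 768)) @ w, (2, 5, 2304))
    assert c.shape == (2, 5, 2304)       # q, k, v packed together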
58 | def attention_mask(nd, ns, *, dtype):
59 | """1's in the lower triangle, counting from the lower right corner.
60 |
61 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
62 | """
63 | i = tf.range(nd)[:,None]
64 | j = tf.range(ns)
65 | m = i >= j - ns + nd
66 | return tf.cast(m, dtype)
67 |
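A NumPy restatement showing what attention_mask(2, 4) returns: the two query positions are the newest ones, so each may attend to every earlier key plus itself:

    import numpy as np

    def attention_mask_np(nd, ns):
        # same arithmetic as the TF version above
        i = np.arange(nd)[:, None]
        j = np.arange(ns)
        return (i >= j - ns + nd).astype(np.float32)

    print(attention_mask_np(2, 4))
    # [[1. 1. 1. 0.]
    #  [1. 1. 1. 1.]]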
68 |
69 | def attn(x, scope, n_state, *, past, hparams):
70 | assert x.shape.ndims == 3 # Should be [batch, sequence, features]
71 | assert n_state % hparams.n_head == 0
72 | if past is not None:
73 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
74 |
75 | def split_heads(x):
76 | # From [batch, sequence, features] to [batch, heads, sequence, features]
77 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
78 |
79 | def merge_heads(x):
80 | # Reverse of split_heads
81 | return merge_states(tf.transpose(x, [0, 2, 1, 3]))
82 |
83 | def mask_attn_weights(w):
84 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
85 | _, _, nd, ns = shape_list(w)
86 | b = attention_mask(nd, ns, dtype=w.dtype)
87 | b = tf.reshape(b, [1, 1, nd, ns])
88 | w = w*b - tf.cast(1e10, w.dtype)*(1-b)
89 | return w
90 |
91 | def multihead_attn(q, k, v):
92 | # q, k, v have shape [batch, heads, sequence, features]
93 | w = tf.matmul(q, k, transpose_b=True)
94 | w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
95 |
96 | w = mask_attn_weights(w)
97 | w = softmax(w)
98 | a = tf.matmul(w, v)
99 | return a
100 |
101 | with tf.variable_scope(scope):
102 | c = conv1d(x, 'c_attn', n_state*3)
103 | q, k, v = map(split_heads, tf.split(c, 3, axis=2))
104 | present = tf.stack([k, v], axis=1)
105 | if past is not None:
106 | pk, pv = tf.unstack(past, axis=1)
107 | k = tf.concat([pk, k], axis=-2)
108 | v = tf.concat([pv, v], axis=-2)
109 | a = multihead_attn(q, k, v)
110 | a = merge_heads(a)
111 | a = conv1d(a, 'c_proj', n_state)
112 | return a, present
113 |
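A NumPy sketch of the past/present bookkeeping above, assuming the default 12 heads with 64 features each: cached keys and values grow along the sequence axis, which is what lets generation feed only the newest token at each step.

    import numpy as np

    past_k = np.zeros((1, 12, 10, 64))   # keys cached from 10 earlier tokens
    new_k = np.zeros((1, 12, 1, 64))     # key for the single newest token
    k = np.concatenate([past_k, new_k], axis=-2)   # mirrors tf.concat(axis=-2)
    assert k.shape == (1, 12, 11, 64)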
114 |
115 | def mlp(x, scope, n_state, *, hparams):
116 | with tf.variable_scope(scope):
117 | nx = x.shape[-1].value
118 | h = gelu(conv1d(x, 'c_fc', n_state))
119 | h2 = conv1d(h, 'c_proj', nx)
120 | return h2
121 |
122 |
123 | def block(x, scope, *, past, hparams):
124 | with tf.variable_scope(scope):
125 | nx = x.shape[-1].value
126 | a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
127 | x = x + a
128 | m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
129 | x = x + m
130 | return x, present
131 |
132 | def past_shape(*, hparams, batch_size=None, sequence=None):
133 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
134 |
135 | def expand_tile(value, size):
136 | """Add a new axis of given size."""
137 | value = tf.convert_to_tensor(value, name='value')
138 | ndims = value.shape.ndims
139 | return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)
140 |
141 | def positions_for(tokens, past_length):
142 | batch_size = tf.shape(tokens)[0]
143 | nsteps = tf.shape(tokens)[1]
144 | return expand_tile(past_length + tf.range(nsteps), batch_size)
145 |
146 |
147 | def model(hparams, X, past=None, scope='model', reuse=False):
148 | with tf.variable_scope(scope, reuse=reuse):
149 | results = {}
150 | batch, sequence = shape_list(X)
151 |
152 | wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
153 | initializer=tf.random_normal_initializer(stddev=0.01))
154 | wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
155 | initializer=tf.random_normal_initializer(stddev=0.02))
156 | past_length = 0 if past is None else tf.shape(past)[-2]
157 | h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
158 |
159 | # Transformer
160 | presents = []
161 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
162 | assert len(pasts) == hparams.n_layer
163 | for layer, past in enumerate(pasts):
164 | h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
165 | presents.append(present)
166 | results['present'] = tf.stack(presents, axis=1)
167 | h = norm(h, 'ln_f')
168 |
169 | # Language model loss.  Do tokens <n predict token n?
170 | h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
171 | flat_logits = tf.matmul(h_flat, wte, transpose_b=True)
172 | logits = tf.reshape(flat_logits, [batch, sequence, hparams.n_vocab])
173 | results['logits'] = logits
174 | return results
175 |
--------------------------------------------------------------------------------
/src/olddemo.py:
--------------------------------------------------------------------------------
59 | if length > hparams.n_ctx:
60 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
61 |
62 | self.sess = tf.Session(graph=tf.Graph())
63 | self.sess.__enter__()
64 |
65 | self.context = tf.placeholder(tf.int32, [batch_size, None])
66 | np.random.seed(seed)
67 | tf.set_random_seed(seed)
68 | self.output = sample.sample_sequence(
69 | hparams=hparams, length=length,
70 | context=self.context,
71 | batch_size=batch_size,
72 | temperature=temperature, top_k=top_k
73 | )
74 |
75 | saver = tf.train.Saver()
76 | self.ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
77 | saver.restore(self.sess, self.ckpt)
78 |
79 | def close(self):
80 | self.sess.close()
81 |
82 | def generate_conditional(self, raw_text):
83 | context_tokens = self.enc.encode(raw_text)
84 | generated = 0
85 | for _ in range(self.nsamples // self.batch_size):
86 | out = self.sess.run(self.output, feed_dict={
87 | self.context: [context_tokens for _ in range(self.batch_size)]
88 | })[:, len(context_tokens):]
89 | for i in range(self.batch_size):
90 | generated += 1
91 | text = self.enc.decode(out[i])
92 | return text  # return after the first sample
93 | #print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
94 | #print(text)
95 | #print("=" * 80)
96 | ###
97 |
98 | gpt2 = GPT2(model_name="1558M")
99 | ###
100 | class Who:
101 | """A class defining the conversation parties: me, he"""
102 | def __init__(self):
103 | self.prefixes = []
104 |
105 | def matches(self,phrase):
106 | for prefix in self.prefixes:
107 | if phrase.startswith(prefix):
108 | #print(f"{phrase} starts with {prefix}")
109 | return True
110 |
111 | #print(f"{phrase} does not start with {self.prefixes}")
112 | return False
113 |
114 | def get_random_prefix(self):
115 | return self.prefixes[0]  # only one prefix is defined, so "random" is deterministic
116 |
117 | class Me(Who):
118 | def __init__(self):
119 | super().__init__()
120 | self.prefixes = ["I said: \""]
121 |
122 |
123 | class You(Who):
124 | def __init__(self):
125 | super().__init__()
126 | self.prefixes = ["You said: \""]
127 |
128 | class Conversation:
129 |
130 | def __init__(self, prior=None):
131 | if prior is None:
132 | prior="""
133 | You said: "Nice to meet you. What's your name?"
134 | I said: "My name is Pete."
135 | You said: "That's an interesting name. How old are you?"
136 | I said: "I'm 40 years old."
137 | You said: "Can you tell me something about yourself?"
138 | I said: "Ofcourse! I like playing video games and eating cake. "
139 | You said: "I like sweet stuff too. What are your plans for tomorrow?"
140 | """
141 | self.suggestion = None
142 |
143 | self.me = Me()
144 | self.you = You()
145 | self.parties = [ self.me, self.you ]
146 |
147 | self.conversation = []
148 |
149 | lines = prior.split("\n")
150 | for line in lines:
151 | line = line.strip()
152 | if len(line) != 0:
153 | # for/else: the else branch runs only when no party matched the line
154 | for party in self.parties:
155 | if party.matches(line):
156 | break
157 | else:
158 | raise Exception(f"Unknown party: {line}")
159 |
160 | self.conversation.append((party, line))
161 | self.get_suggestion()
162 |
163 |
164 | def get_prior(self):
165 | conv = ""
166 | for (party, line) in self.conversation:
167 | conv += line + "\n"
168 | return conv
169 |
170 | def get_suggestion(self):
171 | who, last_line = self.conversation[-1]
172 |
173 | party_index = self.parties.index(who)
174 | next_party = self.parties[(party_index+1) % len(self.parties)]
175 |
176 | conv = self.get_prior()
177 | conv += next_party.get_random_prefix()
178 | answer = self.get_answer(next_party, conv)
179 |
180 | if not next_party.matches(answer):
181 | prefix = next_party.get_random_prefix()
182 | answer = prefix + answer
183 |
184 | self.suggestion = (next_party, answer)
185 |
186 | def next(self, party=None, answer=""):
187 | """Continue the conversation
188 | :param party: None -> use the party whose turn it currently is
189 | :param answer: "" -> use the suggestion; pass a text to override
190 | the suggestion
191 |
192 | """
193 | suggested_party, suggested_answer = self.suggestion
194 | if party is None:
195 | party = suggested_party
196 |
197 | if answer == "":
198 | answer = suggested_answer
199 |
200 | if not party.matches(answer):
201 | prefix = party.get_random_prefix()
202 | answer = prefix + answer
203 |
204 | answer = answer.strip()
205 | if not answer.endswith("\""):
206 | # add the closing "
207 | answer += "\""
208 |
209 | self.conversation.append((party, answer))
210 | self.get_suggestion()
211 |
212 | def retry(self):
213 | self.get_suggestion()
214 |
215 | def get_answer(self, party, conv):
216 | answer = gpt2.generate_conditional(raw_text=conv)
217 | lines = answer.split("\n")
218 | line = ""
219 | for line in lines:
220 | if line !="":
221 | break
222 |
223 | if line!="":
224 | return line
225 |
226 | return ""
227 |
228 | def show(self):
229 | conv = ""
230 | for (party, line) in self.conversation:
231 | conv += line + "\n"
232 | print(conv)
233 | if self.suggestion is not None:
234 | party, answer = self.suggestion
235 | print("--> "+answer)
236 |
237 |
238 | c = Conversation()
239 | c.show()
240 | c.retry()
241 | c.show()
242 | c.next()
243 | c.show()
244 | c.retry()
245 | c.next(c.you, "Pizza is not too good for your health though.")
246 | c.show()
247 | gpt2.close()
248 |
249 | # Kept for possible future development, but it is slow and out of date.
250 |
--------------------------------------------------------------------------------
/src/sample.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import model
4 |
5 | def top_k_logits(logits, k):
6 | if k == 0:
7 | # no truncation
8 | return logits
9 |
10 | def _top_k():
11 | values, _ = tf.nn.top_k(logits, k=k)
12 | min_values = values[:, -1, tf.newaxis]
13 | return tf.where(
14 | logits < min_values,
15 | tf.ones_like(logits, dtype=logits.dtype) * -1e10,
16 | logits,
17 | )
18 | return tf.cond(
19 | tf.equal(k, 0),
20 | lambda: logits,
21 | lambda: _top_k(),
22 | )
23 |
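The same truncation in NumPy, with a hypothetical k of 2: everything below the k-th largest logit is pushed to -1e10, so softmax assigns it effectively zero probability.

    import numpy as np

    logits = np.array([[1.0, 3.0, 2.0, 0.5]])
    k = 2
    min_keep = np.sort(logits, axis=-1)[:, -k, None]       # k-th largest value
    truncated = np.where(logits < min_keep, -1e10, logits)
    # -> [[-1e10, 3.0, 2.0, -1e10]]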
24 |
25 | def top_p_logits(logits, p):
26 | """Nucleus sampling"""
27 | batch, _ = logits.shape.as_list()
28 | sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
29 | cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
30 | indices = tf.stack([
31 | tf.range(0, batch),
32 | # number of indices to include
33 | tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
34 | ], axis=-1)
35 | min_values = tf.gather_nd(sorted_logits, indices)
36 | return tf.where(
37 | logits < min_values,
38 | tf.ones_like(logits) * -1e10,
39 | logits,
40 | )
41 |
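A NumPy sketch of the nucleus cutoff with p = 0.8 and toy probabilities; as in the TF code, the index arithmetic keeps the smallest set of tokens whose cumulative probability reaches p:

    import numpy as np

    p = 0.8
    probs = np.array([[0.5, 0.3, 0.15, 0.05]])    # already sorted, descending
    logits = np.log(probs)
    cum = np.cumsum(probs, axis=-1)               # [0.5, 0.8, 0.95, 1.0]
    keep = max(int((cum <= p).sum()) - 1, 0)      # index of the last kept logit
    truncated = np.where(logits < logits[0, keep], -1e10, logits)
    # only the first two tokens survive (0.5 + 0.3 = 0.8)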
42 |
43 | def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
44 | if start_token is None:
45 | assert context is not None, 'Specify exactly one of start_token and context!'
46 | else:
47 | assert context is None, 'Specify exactly one of start_token and context!'
48 | context = tf.fill([batch_size, 1], start_token)
49 |
50 | def step(hparams, tokens, past=None):
51 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
52 |
53 | logits = lm_output['logits'][:, :, :hparams.n_vocab]
54 | presents = lm_output['present']
55 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
56 | return {
57 | 'logits': logits,
58 | 'presents': presents,
59 | }
60 |
61 | with tf.name_scope('sample_sequence'):
62 | def body(past, prev, output):
63 | next_outputs = step(hparams, prev, past=past)
64 | logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
65 | logits = top_k_logits(logits, k=top_k)
66 | logits = top_p_logits(logits, p=top_p)
67 | samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
68 | return [
69 | next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
70 | samples,
71 | tf.concat([output, samples], axis=1)
72 | ]
73 |
74 | past, prev, output = body(None, context, context)
75 |
76 | def cond(*args):
77 | return True
78 |
79 | _, _, tokens = tf.while_loop(
80 | cond=cond, body=body,
81 | maximum_iterations=length - 1,
82 | loop_vars=[
83 | past,
84 | prev,
85 | output
86 | ],
87 | shape_invariants=[
88 | tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
89 | tf.TensorShape([batch_size, None]),
90 | tf.TensorShape([batch_size, None]),
91 | ],
92 | back_prop=False,
93 | )
94 |
95 | return tokens
96 |
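How olddemo.py drives this graph (a sketch of the calls visible above, not a new API): build the op once against a [batch_size, None] int32 context placeholder, then run it per prompt and slice off the prompt tokens:

    # output = sample_sequence(hparams=hparams, length=length, context=context,
    #                          batch_size=batch_size, temperature=temperature,
    #                          top_k=top_k)
    # out = sess.run(output, feed_dict={context: [context_tokens] * batch_size})
    # new_tokens = out[:, len(context_tokens):]   # strip the prompt, keep samples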
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | while true
3 | do
4 | python3 src/GPT2-Learning.py   # relaunch the bot whenever it exits or crashes
5 | done
6 |
--------------------------------------------------------------------------------