├── GPT-2-BOT.ipynb ├── README.md ├── SCORES.md ├── download_model.py ├── requirements.txt ├── src ├── GPT2-Learning.py ├── encoder.py ├── model.py ├── olddemo.py └── sample.py └── start.sh /GPT-2-BOT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled0.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyMrkPXfE3zUuuEDDYxAmfX6", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "TPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "colab": { 35 | "base_uri": "https://localhost:8080/" 36 | }, 37 | "id": "4MItH_gOezsh", 38 | "outputId": "228486d1-3fa0-47f0-c877-ca30ad840a81" 39 | }, 40 | "source": [ 41 | "from google.colab import drive\n", 42 | "drive.mount('/content/gdrive')" 43 | ], 44 | "execution_count": 1, 45 | "outputs": [ 46 | { 47 | "output_type": "stream", 48 | "text": [ 49 | "Mounted at /content/gdrive\n" 50 | ], 51 | "name": "stdout" 52 | } 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 58 | "id": "NowXHz16haZ0" 59 | }, 60 | "source": [ 61 | "Connect to google drive for model and project storage." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "colab": { 68 | "base_uri": "https://localhost:8080/" 69 | }, 70 | "id": "RbcOtiuzfCQB", 71 | "outputId": "2013f1d6-4d1f-4297-9702-b16bb2a4e1f0" 72 | }, 73 | "source": [ 74 | "%cd gdrive/My Drive/Colab Notebooks" 75 | ], 76 | "execution_count": 2, 77 | "outputs": [ 78 | { 79 | "output_type": "stream", 80 | "text": [ 81 | "/content/gdrive/My Drive/Colab Notebooks\n" 82 | ], 83 | "name": "stdout" 84 | } 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "colab": { 91 | "base_uri": "https://localhost:8080/" 92 | }, 93 | "id": "6jmk6czDflu4", 94 | "outputId": "dfa5e5d2-c648-4e52-ccbd-2df1ee22cee5" 95 | }, 96 | "source": [ 97 | "! git clone https://github.com/Existencce/GPT2-Telegram-Chatbot" 98 | ], 99 | "execution_count": 3, 100 | "outputs": [ 101 | { 102 | "output_type": "stream", 103 | "text": [ 104 | "fatal: destination path 'GPT2-Telegram-Chatbot' already exists and is not an empty directory.\n" 105 | ], 106 | "name": "stdout" 107 | } 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "metadata": { 113 | "colab": { 114 | "base_uri": "https://localhost:8080/" 115 | }, 116 | "id": "iN58X9hKgTvN", 117 | "outputId": "f4000c18-6727-41a0-a3e5-32d7cc7b16ec" 118 | }, 119 | "source": [ 120 | "%cd GPT2-Telegram-Chatbot" 121 | ], 122 | "execution_count": 4, 123 | "outputs": [ 124 | { 125 | "output_type": "stream", 126 | "text": [ 127 | "/content/gdrive/My Drive/Colab Notebooks/GPT2-Telegram-Chatbot\n" 128 | ], 129 | "name": "stdout" 130 | } 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "metadata": { 136 | "id": "IH4Yccl0gaNw" 137 | }, 138 | "source": [ 139 | "! git pull" 140 | ], 141 | "execution_count": null, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "id": "pBQj9tJJgpnN" 148 | }, 149 | "source": [ 150 | "Change to 774M model, Set your bot token below.\n", 151 | "Make sure to change runtime to GPU/TPU in google collab." 
152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "BB0SqY7xgxDl" 158 | }, 159 | "source": [ 160 | "!sed -i -e 's/BOTKEY/1887036376:AAF_gJdkt_2z44ZdoLLeiumQvN-9ihYRUBQ/' src/GPT2-Learning.py\n", 161 | "!sed -i -e 's/774M/1558M/' src/GPT2-Learning.py" 162 | ], 163 | "execution_count": 7, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "id": "o6bsyPRkg_EZ" 170 | }, 171 | "source": [ 172 | "Install Requirements.. You might need to do this a few times." 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "metadata": { 178 | "colab": { 179 | "base_uri": "https://localhost:8080/", 180 | "height": 1000 181 | }, 182 | "id": "KibtYZZ5hBAk", 183 | "outputId": "90b2073d-32bd-49e5-f3d7-bd9a74a20704" 184 | }, 185 | "source": [ 186 | "!pip3 install tqdm\n", 187 | "!pip3 install regex\n", 188 | "!pip3 install fire\n", 189 | "!pip3 install python-telegram-bot==12.0.0\n", 190 | "!pip3 install requests\n", 191 | "!pip3 install tensorflow-gpu==1.15.5" 192 | ], 193 | "execution_count": 8, 194 | "outputs": [ 195 | { 196 | "output_type": "stream", 197 | "text": [ 198 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n", 199 | "Requirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20)\n", 200 | "Collecting fire\n", 201 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/11/07/a119a1aa04d37bc819940d95ed7e135a7dcca1c098123a3764a6dcace9e7/fire-0.4.0.tar.gz (87kB)\n", 202 | "\u001b[K |████████████████████████████████| 92kB 4.2MB/s \n", 203 | "\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from fire) (1.15.0)\n", 204 | "Requirement already satisfied: termcolor in /usr/local/lib/python3.7/dist-packages (from fire) (1.1.0)\n", 205 | "Building wheels for collected packages: fire\n", 206 | " Building wheel for fire (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 207 | " Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115928 sha256=bbcd80493466e69a169b7e9329d46ade1766ba745010bccf7c3c80e6d42e0de9\n", 208 | " Stored in directory: /root/.cache/pip/wheels/af/19/30/1ea0cad502dcb4e66ed5a690279628c827aea38bbbab75d5ed\n", 209 | "Successfully built fire\n", 210 | "Installing collected packages: fire\n", 211 | "Successfully installed fire-0.4.0\n", 212 | "Collecting python-telegram-bot==12.0.0\n", 213 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/25/06/5047b87e9ec3ffd9af6f83069803921f2f02c9411753d610cc569c8e4638/python_telegram_bot-12.0.0-py2.py3-none-any.whl (346kB)\n", 214 | "\u001b[K |████████████████████████████████| 348kB 6.9MB/s \n", 215 | "\u001b[?25hRequirement already satisfied: certifi in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (2020.12.5)\n", 216 | "Requirement already satisfied: tornado>=5.1 in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (5.1.1)\n", 217 | "Collecting cryptography\n", 218 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b2/26/7af637e6a7e87258b963f1731c5982fb31cd507f0d90d91836e446955d02/cryptography-3.4.7-cp36-abi3-manylinux2014_x86_64.whl (3.2MB)\n", 219 | "\u001b[K |████████████████████████████████| 3.2MB 9.3MB/s \n", 220 | "\u001b[?25hRequirement already satisfied: future>=0.16.0 in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot==12.0.0) (0.16.0)\n", 221 | "Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.7/dist-packages (from cryptography->python-telegram-bot==12.0.0) (1.14.5)\n", 222 | "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.12->cryptography->python-telegram-bot==12.0.0) (2.20)\n", 223 | "Installing collected packages: cryptography, python-telegram-bot\n", 224 | "Successfully installed cryptography-3.4.7 python-telegram-bot-12.0.0\n", 225 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (2.23.0)\n", 226 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests) (3.0.4)\n", 227 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests) (1.24.3)\n", 228 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests) (2.10)\n", 229 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests) (2020.12.5)\n", 230 | "Collecting tensorflow-gpu==1.15.5\n", 231 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/73/b5/adc281ce4e631251c749d342793795832026edf9035df81c3813ef33fad2/tensorflow_gpu-1.15.5-cp37-cp37m-manylinux2010_x86_64.whl (411.0MB)\n", 232 | "\u001b[K |████████████████████████████████| 411.0MB 40kB/s \n", 233 | "\u001b[?25hRequirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.2.0)\n", 234 | "Collecting gast==0.2.2\n", 235 | " Downloading https://files.pythonhosted.org/packages/4e/35/11749bf99b2d4e3cceb4d55ca22590b0d7c2c62b9de38ac4a4a7f4687421/gast-0.2.2.tar.gz\n", 236 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.15.0)\n", 237 | "Collecting keras-applications>=1.0.8\n", 238 | "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications-1.0.8-py3-none-any.whl (50kB)\n", 239 | "\u001b[K |████████████████████████████████| 51kB 6.1MB/s \n", 240 | "\u001b[?25hRequirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.12.1)\n", 241 | "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (3.12.4)\n", 242 | "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.1.2)\n", 243 | "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (3.3.0)\n", 244 | "Collecting tensorboard<1.16.0,>=1.15.0\n", 245 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1e/e9/d3d747a97f7188f48aa5eda486907f3b345cd409f0a0850468ba867db246/tensorboard-1.15.0-py3-none-any.whl (3.8MB)\n", 246 | "\u001b[K |████████████████████████████████| 3.8MB 32.2MB/s \n", 247 | "\u001b[?25hRequirement already satisfied: h5py<=2.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (2.10.0)\n", 248 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.36.2)\n", 249 | "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.32.0)\n", 250 | "Collecting tensorflow-estimator==1.15.1\n", 251 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/de/62/2ee9cd74c9fa2fa450877847ba560b260f5d0fb70ee0595203082dafcc9d/tensorflow_estimator-1.15.1-py2.py3-none-any.whl (503kB)\n", 252 | "\u001b[K |████████████████████████████████| 512kB 53.4MB/s \n", 253 | "\u001b[?25hRequirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.8.1)\n", 254 | "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (0.12.0)\n", 255 | "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-gpu==1.15.5) (1.1.0)\n", 256 | "Collecting numpy<1.19.0,>=1.16.0\n", 257 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d6/c6/58e517e8b1fb192725cfa23c01c2e60e4e6699314ee9684a1c5f5c9b27e1/numpy-1.18.5-cp37-cp37m-manylinux1_x86_64.whl (20.1MB)\n", 258 | "\u001b[K |████████████████████████████████| 20.1MB 1.4MB/s \n", 259 | "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from protobuf>=3.6.1->tensorflow-gpu==1.15.5) (56.1.0)\n", 260 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (2.0.0)\n", 261 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.3.4)\n", 262 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (4.0.1)\n", 263 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.4.1)\n", 264 | 
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<1.16.0,>=1.15.0->tensorflow-gpu==1.15.5) (3.7.4.3)\n", 265 | "Building wheels for collected packages: gast\n", 266 | " Building wheel for gast (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 267 | " Created wheel for gast: filename=gast-0.2.2-cp37-none-any.whl size=7540 sha256=3d3a4bf0e39a6e2a93a59617b12913ca208bd56cfd5a81827bc5c088c22ce407\n", 268 | " Stored in directory: /root/.cache/pip/wheels/5c/2e/7e/a1d4d4fcebe6c381f378ce7743a3ced3699feb89bcfbdadadd\n", 269 | "Successfully built gast\n", 270 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n", 271 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.18.5 which is incompatible.\u001b[0m\n", 272 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement tensorboard~=2.4, but you'll have tensorboard 1.15.0 which is incompatible.\u001b[0m\n", 273 | "\u001b[31mERROR: tensorflow 2.4.1 has requirement tensorflow-estimator<2.5.0,>=2.4.0, but you'll have tensorflow-estimator 1.15.1 which is incompatible.\u001b[0m\n", 274 | "\u001b[31mERROR: tensorflow-probability 0.12.1 has requirement gast>=0.3.2, but you'll have gast 0.2.2 which is incompatible.\u001b[0m\n", 275 | "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n", 276 | "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n", 277 | "Installing collected packages: gast, numpy, keras-applications, tensorboard, tensorflow-estimator, tensorflow-gpu\n", 278 | " Found existing installation: gast 0.3.3\n", 279 | " Uninstalling gast-0.3.3:\n", 280 | " Successfully uninstalled gast-0.3.3\n", 281 | " Found existing installation: numpy 1.19.5\n", 282 | " Uninstalling numpy-1.19.5:\n", 283 | " Successfully uninstalled numpy-1.19.5\n", 284 | " Found existing installation: tensorboard 2.4.1\n", 285 | " Uninstalling tensorboard-2.4.1:\n", 286 | " Successfully uninstalled tensorboard-2.4.1\n", 287 | " Found existing installation: tensorflow-estimator 2.4.0\n", 288 | " Uninstalling tensorflow-estimator-2.4.0:\n", 289 | " Successfully uninstalled tensorflow-estimator-2.4.0\n", 290 | "Successfully installed gast-0.2.2 keras-applications-1.0.8 numpy-1.18.5 tensorboard-1.15.0 tensorflow-estimator-1.15.1 tensorflow-gpu-1.15.5\n" 291 | ], 292 | "name": "stdout" 293 | }, 294 | { 295 | "output_type": "display_data", 296 | "data": { 297 | "application/vnd.colab-display-data+json": { 298 | "pip_warning": { 299 | "packages": [ 300 | "numpy" 301 | ] 302 | } 303 | } 304 | }, 305 | "metadata": { 306 | "tags": [] 307 | } 308 | } 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": { 314 | "id": "Lit4usgih93R" 315 | }, 316 | "source": [ 317 | "After requirements installed, reconnect to google drive after restarting runtime and setting runtime to TPU under \"Runtime -> Change Runtime Type\" tab" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "metadata": { 323 | "colab": { 324 | "base_uri": "https://localhost:8080/" 325 | }, 326 | "id": "7QanwV8MiDsh", 327 | "outputId": "f2d66701-b99c-4b6c-e8b5-b7586169d80b" 328 | }, 329 | "source": [ 330 | "from google.colab import drive\n", 331 | "drive.mount('/content/gdrive')\n", 332 
| "%cd /content/gdrive/MyDrive/Colab Notebooks/GPT2-Telegram-Chatbot" 333 | ], 334 | "execution_count": 1, 335 | "outputs": [ 336 | { 337 | "output_type": "stream", 338 | "text": [ 339 | "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n", 340 | "/content/gdrive/MyDrive/Colab Notebooks/GPT2-Telegram-Chatbot\n" 341 | ], 342 | "name": "stdout" 343 | } 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "id": "xDZj822bieW7" 350 | }, 351 | "source": [ 352 | "Download the model" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "metadata": { 358 | "colab": { 359 | "base_uri": "https://localhost:8080/" 360 | }, 361 | "id": "EY6miAy9igEs", 362 | "outputId": "b2bd1c49-d4e6-4974-ce82-ee2663da0cdc" 363 | }, 364 | "source": [ 365 | "!python3 download_model.py 1558M" 366 | ], 367 | "execution_count": null, 368 | "outputs": [ 369 | { 370 | "output_type": "stream", 371 | "text": [ 372 | "\rFetching checkpoint: 0%| | 0.00/77.0 [00:00 1: 343 | running = True 344 | time.sleep(1) 345 | temp = temp - 1 346 | if running == True: 347 | mode = False 348 | learn = False 349 | learning = "" 350 | cache = "" 351 | user = "" 352 | update.message.reply_text('Timer has run down, bot has been reset into the default mode.') 353 | running = False 354 | else: 355 | left = str(temp) 356 | update.message.reply_text('Bot is in use, current cooldown is: ' + left + ' seconds.') 357 | 358 | def interact_model(bot, update, new): 359 | model_name = '1558M' 360 | seed = random.randint(1431655765, 2863311530) 361 | nsamples = 1 362 | batch_size = 1 363 | top_k = tok 364 | topp = top 365 | models_dir = 'models' 366 | tex = str(update.message.text) 367 | global learning 368 | global learn 369 | global mode 370 | global cache 371 | ############################################# 372 | # This does some basic length processing. 
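# For example: a 12-word message is answered with length 20, a 35-word message
# with length 40, and anything over 50 words with length 60; once the running
# "Me:"/"You:" log grows past 300 words, the oldest turns are trimmed from the
# front of the cache so the prompt stays within the model's context window.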
373 | if mode == True: 374 | tlen = len(tex.split()) 375 | if tlen > 300: 376 | update.message.reply_text('Input text is too long.') 377 | return 378 | if new == True and cache: 379 | m = re.search('.* You: ', cache) 380 | raw_text = m.group(0) 381 | tlensp = len(raw_text.split()) 382 | tlen = tlensp - 2 383 | length = tlen 384 | if tlen < 20: 385 | length = 20 386 | if tlen > 20: 387 | length = 20 388 | if tlen > 30: 389 | length = 40 390 | if tlen > 50: 391 | length = 60 392 | if debug == True: 393 | print("Cache is...") 394 | print(raw_text) 395 | if new != True: 396 | texm = 'Me: ' + tex 397 | initial = texm + ' You: ' 398 | raw_text = learning + initial 399 | length = tlen 400 | if tlen < 20: 401 | length = 20 402 | if tlen > 20: 403 | length = 20 404 | if tlen > 30: 405 | length = 40 406 | if tlen > 50: 407 | length = 60 408 | cache = raw_text 409 | maxls = len(raw_text.split()) 410 | if maxls > 300: 411 | while maxls > 300: 412 | if debug == True: 413 | print("Reducing memory of chat.") 414 | raw_text = raw_text.split(' Me:', 1)[-1] 415 | raw_text = "Me:" + raw_text 416 | maxls = len(raw_text.split()) 417 | if maxls > 300: 418 | if debug == True: 419 | print("Reducing memory of chat.") 420 | raw_text = raw_text.split('You:', 1)[-1] 421 | raw_text = "You:" + raw_text 422 | maxls = len(raw_text.split()) 423 | if debug == True: 424 | print("FINAL MEMORY REDUCTION:") 425 | print(raw_text) 426 | if mode == False: 427 | tlen = len(penguin.split()) 428 | length = tlen 429 | if length > 300: 430 | update.message.reply_text('Input text is too long.') 431 | return 432 | if new != True: 433 | cache = tex 434 | if new == True and cache: 435 | tex = cache 436 | length = len(tex.split()) 437 | tlen = length 438 | if debug == True: 439 | print("Cache is...") 440 | print(penguin) 441 | raw_text = tex 442 | toppf = float(topp) 443 | lengthm = float(tlen) 444 | multf = float(mx) 445 | lxm = float(lengthm * multf) 446 | top_p = lxm + toppf 447 | # The max here is 0.84 and minimum 0.005 448 | if top_p > 0.84: 449 | top_p = 0.84 450 | if top_p < 0.005: 451 | top_p = 0.005 452 | ############################################# 453 | update.message.reply_text('Computing...') 454 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 455 | if batch_size is None: 456 | batch_size = 1 457 | assert nsamples % batch_size == 0 458 | enc = encoder.get_encoder(model_name, models_dir) 459 | hparams = model.default_hparams() 460 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 461 | hparams.override_from_dict(json.load(f)) 462 | if length is None: 463 | length = hparams.n_ctx // 2 464 | elif length > hparams.n_ctx: 465 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 466 | with tf.Session(graph=tf.Graph()) as sess: 467 | context = tf.placeholder(tf.int32, [batch_size, None]) 468 | np.random.seed(seed) 469 | tf.set_random_seed(seed) 470 | output = sample.sample_sequence( 471 | hparams=hparams, length=length, 472 | context=context, 473 | batch_size=batch_size, 474 | temperature=degree, top_k=top_k, top_p=top_p 475 | ) 476 | saver = tf.train.Saver() 477 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 478 | saver.restore(sess, ckpt) 479 | context_tokens = enc.encode(raw_text) 480 | generated = 0 481 | for _ in range(nsamples // batch_size): 482 | out = sess.run(output, feed_dict={ 483 | context: [context_tokens for _ in range(batch_size)] 484 | })[:, len(context_tokens):] 485 | for i in range(batch_size): 486 | generated += 1 
487 | text = enc.decode(out[i]) 488 | if debug == True: 489 | print("==========") 490 | print("Raw output: " + text) 491 | print("==========") 492 | if mode == True: 493 | splitted = text.splitlines()[0] 494 | else: 495 | splitted = text 496 | encodedstr = splitted.encode(encoding=sys.stdout.encoding,errors='ignore') 497 | decodedstr = encodedstr.decode("utf-8") 498 | final = str(decodedstr) 499 | # disable any regex on finishsentence mode. 500 | if mode == True: 501 | # Final regex 502 | sanitized = regex(final) 503 | finalsan = " ".join(re.split("[^a-zA-Z.,?!'*]+", sanitized)) 504 | 505 | else: 506 | finalsan = final 507 | if learn == True: 508 | learning = raw_text + finalsan + " " 509 | update.message.reply_text(finalsan) 510 | if debug == True: 511 | modes = str(mode) 512 | print("Chatbot mode: " + modes) 513 | learns = str(learn) 514 | print("Learning mode: " + learns) 515 | lengths = str(length) 516 | print("Length: " + lengths) 517 | print("==========") 518 | splits = str(splitted) 519 | print("Before regex: " + splits) 520 | print("==========") 521 | print("Output: " + finalsan) 522 | print("==========") 523 | print("Raw_text or Original: " + raw_text) 524 | print("==========") 525 | print("Learning text or Next: " + learning) 526 | print("==========") 527 | tps = str(top_p) 528 | print("Final top_p: " + tps) 529 | print("==========") 530 | print("top_p in: " + tpstring) 531 | print("==========") 532 | sess.close() 533 | 534 | def error(bot, update): 535 | """Log Errors caused by Updates.""" 536 | logger.warning('Update "%s" caused error "%s"', update) 537 | 538 | def main(): 539 | """Start the bot.""" 540 | # Create the Updater and pass it your bot's token. 541 | # Make sure to set use_context=True to use the new context based callbacks 542 | # Post version 12 this will no longer be necessary 543 | updater = Updater("BOTKEY", use_context=False) 544 | # Get the dispatcher to register handlers 545 | dp = updater.dispatcher 546 | # on different commands - answer in Telegram 547 | dp.add_handler(CommandHandler("start", start)) 548 | dp.add_handler(CommandHandler("help", help)) 549 | dp.add_handler(CommandHandler("chatbot", chatbot)) 550 | dp.add_handler(CommandHandler("finish", finish)) 551 | dp.add_handler(CommandHandler("learnon", learnon)) 552 | dp.add_handler(CommandHandler("learnoff", learnoff)) 553 | dp.add_handler(CommandHandler("learnreset", learnreset)) 554 | dp.add_handler(CommandHandler("retry", retry)) 555 | # on noncommand i.e message - echo the message on Telegram 556 | dp.add_handler(MessageHandler(Filters.text, runn)) 557 | # log all errors 558 | dp.add_error_handler(error) 559 | # Start the Bot 560 | updater.start_polling() 561 | # Run the bot until you press Ctrl-C or the process receives SIGINT, 562 | # SIGTERM or SIGABRT. This should be used most of the time, since 563 | # start_polling() is non-blocking and will stop the bot gracefully. 564 | updater.idle() 565 | 566 | if __name__ == '__main__': 567 | main() 568 | -------------------------------------------------------------------------------- /src/encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | import os 4 | import json 5 | import regex as re 6 | from functools import lru_cache 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns list of utf-8 byte and a corresponding list of unicode strings. 12 | The reversible bpe codes work on unicode strings. 
13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 15 | This is a signficant percentage of your normal, say, 32K bpe vocab. 16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 17 | And avoids mapping to whitespace/control characters the bpe code barfs on. 18 | """ 19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 20 | cs = bs[:] 21 | n = 0 22 | for b in range(2**8): 23 | if b not in bs: 24 | bs.append(b) 25 | cs.append(2**8+n) 26 | n += 1 27 | cs = [chr(n) for n in cs] 28 | return dict(zip(bs, cs)) 29 | 30 | def get_pairs(word): 31 | """Return set of symbol pairs in a word. 32 | 33 | Word is represented as tuple of symbols (symbols being variable-length strings). 34 | """ 35 | pairs = set() 36 | prev_char = word[0] 37 | for char in word[1:]: 38 | pairs.add((prev_char, char)) 39 | prev_char = char 40 | return pairs 41 | 42 | class Encoder: 43 | def __init__(self, encoder, bpe_merges, errors='replace'): 44 | self.encoder = encoder 45 | self.decoder = {v:k for k,v in self.encoder.items()} 46 | self.errors = errors # how to handle errors in decoding 47 | self.byte_encoder = bytes_to_unicode() 48 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 49 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 50 | self.cache = {} 51 | 52 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 53 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 54 | 55 | def bpe(self, token): 56 | if token in self.cache: 57 | return self.cache[token] 58 | word = tuple(token) 59 | pairs = get_pairs(word) 60 | 61 | if not pairs: 62 | return token 63 | 64 | while True: 65 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 66 | if bigram not in self.bpe_ranks: 67 | break 68 | first, second = bigram 69 | new_word = [] 70 | i = 0 71 | while i < len(word): 72 | try: 73 | j = word.index(first, i) 74 | new_word.extend(word[i:j]) 75 | i = j 76 | except: 77 | new_word.extend(word[i:]) 78 | break 79 | 80 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 81 | new_word.append(first+second) 82 | i += 2 83 | else: 84 | new_word.append(word[i]) 85 | i += 1 86 | new_word = tuple(new_word) 87 | word = new_word 88 | if len(word) == 1: 89 | break 90 | else: 91 | pairs = get_pairs(word) 92 | word = ' '.join(word) 93 | self.cache[token] = word 94 | return word 95 | 96 | def encode(self, text): 97 | bpe_tokens = [] 98 | for token in re.findall(self.pat, text): 99 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 100 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 101 | return bpe_tokens 102 | 103 | def decode(self, tokens): 104 | text = ''.join([self.decoder[token] for token in tokens]) 105 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 106 | return text 107 | 108 | def get_encoder(model_name, models_dir): 109 | with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: 110 | encoder = json.load(f) 111 | with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: 112 | bpe_data = f.read() 113 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 
114 | return Encoder( 115 | encoder=encoder, 116 | bpe_merges=bpe_merges, 117 | ) 118 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.training import HParams 4 | 5 | def default_hparams(): 6 | return HParams( 7 | n_vocab=0, 8 | n_ctx=1024, 9 | n_embd=768, 10 | n_head=12, 11 | n_layer=12, 12 | ) 13 | 14 | def shape_list(x): 15 | """Deal with dynamic shape in tensorflow cleanly.""" 16 | static = x.shape.as_list() 17 | dynamic = tf.shape(x) 18 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 19 | 20 | def softmax(x, axis=-1): 21 | x = x - tf.reduce_max(x, axis=axis, keepdims=True) 22 | ex = tf.exp(x) 23 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) 24 | 25 | def gelu(x): 26 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) 27 | 28 | def norm(x, scope, *, axis=-1, epsilon=1e-5): 29 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 30 | with tf.variable_scope(scope): 31 | n_state = x.shape[-1].value 32 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) 33 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) 34 | u = tf.reduce_mean(x, axis=axis, keepdims=True) 35 | s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) 36 | x = (x - u) * tf.rsqrt(s + epsilon) 37 | x = x*g + b 38 | return x 39 | 40 | def split_states(x, n): 41 | """Reshape the last dimension of x into [n, x.shape[-1]/n].""" 42 | *start, m = shape_list(x) 43 | return tf.reshape(x, start + [n, m//n]) 44 | 45 | def merge_states(x): 46 | """Smash the last two dimensions of x into a single dimension.""" 47 | *start, a, b = shape_list(x) 48 | return tf.reshape(x, start + [a*b]) 49 | 50 | def conv1d(x, scope, nf, *, w_init_stdev=0.02): 51 | with tf.variable_scope(scope): 52 | *start, nx = shape_list(x) 53 | w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) 54 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) 55 | c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) 56 | return c 57 | 58 | def attention_mask(nd, ns, *, dtype): 59 | """1's in the lower triangle, counting from the lower right corner. 60 | 61 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 62 | """ 63 | i = tf.range(nd)[:,None] 64 | j = tf.range(ns) 65 | m = i >= j - ns + nd 66 | return tf.cast(m, dtype) 67 | 68 | 69 | def attn(x, scope, n_state, *, past, hparams): 70 | assert x.shape.ndims == 3 # Should be [batch, sequence, features] 71 | assert n_state % hparams.n_head == 0 72 | if past is not None: 73 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] 74 | 75 | def split_heads(x): 76 | # From [batch, sequence, features] to [batch, heads, sequence, features] 77 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) 78 | 79 | def merge_heads(x): 80 | # Reverse of split_heads 81 | return merge_states(tf.transpose(x, [0, 2, 1, 3])) 82 | 83 | def mask_attn_weights(w): 84 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 
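# For example, attention_mask(nd=2, ns=3) -- two new destination positions
# attending over one cached position plus themselves -- evaluates to
#   [[1, 1, 0],
#    [1, 1, 1]]
# so each token can attend to itself and to earlier positions only; the zeros
# are pushed to -1e10 just below, which removes them after the softmax.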
85 | _, _, nd, ns = shape_list(w) 86 | b = attention_mask(nd, ns, dtype=w.dtype) 87 | b = tf.reshape(b, [1, 1, nd, ns]) 88 | w = w*b - tf.cast(1e10, w.dtype)*(1-b) 89 | return w 90 | 91 | def multihead_attn(q, k, v): 92 | # q, k, v have shape [batch, heads, sequence, features] 93 | w = tf.matmul(q, k, transpose_b=True) 94 | w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype)) 95 | 96 | w = mask_attn_weights(w) 97 | w = softmax(w) 98 | a = tf.matmul(w, v) 99 | return a 100 | 101 | with tf.variable_scope(scope): 102 | c = conv1d(x, 'c_attn', n_state*3) 103 | q, k, v = map(split_heads, tf.split(c, 3, axis=2)) 104 | present = tf.stack([k, v], axis=1) 105 | if past is not None: 106 | pk, pv = tf.unstack(past, axis=1) 107 | k = tf.concat([pk, k], axis=-2) 108 | v = tf.concat([pv, v], axis=-2) 109 | a = multihead_attn(q, k, v) 110 | a = merge_heads(a) 111 | a = conv1d(a, 'c_proj', n_state) 112 | return a, present 113 | 114 | 115 | def mlp(x, scope, n_state, *, hparams): 116 | with tf.variable_scope(scope): 117 | nx = x.shape[-1].value 118 | h = gelu(conv1d(x, 'c_fc', n_state)) 119 | h2 = conv1d(h, 'c_proj', nx) 120 | return h2 121 | 122 | 123 | def block(x, scope, *, past, hparams): 124 | with tf.variable_scope(scope): 125 | nx = x.shape[-1].value 126 | a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams) 127 | x = x + a 128 | m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams) 129 | x = x + m 130 | return x, present 131 | 132 | def past_shape(*, hparams, batch_size=None, sequence=None): 133 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head] 134 | 135 | def expand_tile(value, size): 136 | """Add a new axis of given size.""" 137 | value = tf.convert_to_tensor(value, name='value') 138 | ndims = value.shape.ndims 139 | return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims) 140 | 141 | def positions_for(tokens, past_length): 142 | batch_size = tf.shape(tokens)[0] 143 | nsteps = tf.shape(tokens)[1] 144 | return expand_tile(past_length + tf.range(nsteps), batch_size) 145 | 146 | 147 | def model(hparams, X, past=None, scope='model', reuse=False): 148 | with tf.variable_scope(scope, reuse=reuse): 149 | results = {} 150 | batch, sequence = shape_list(X) 151 | 152 | wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd], 153 | initializer=tf.random_normal_initializer(stddev=0.01)) 154 | wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd], 155 | initializer=tf.random_normal_initializer(stddev=0.02)) 156 | past_length = 0 if past is None else tf.shape(past)[-2] 157 | h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length)) 158 | 159 | # Transformer 160 | presents = [] 161 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer 162 | assert len(pasts) == hparams.n_layer 163 | for layer, past in enumerate(pasts): 164 | h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) 165 | presents.append(present) 166 | results['present'] = tf.stack(presents, axis=1) 167 | h = norm(h, 'ln_f') 168 | 169 | # Language model loss. 
Do tokens <n predict token n? 170 | h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd]) 171 | logits = tf.matmul(h_flat, wte, transpose_b=True) 172 | logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab]) 173 | results['logits'] = logits 174 | return results 175 | -------------------------------------------------------------------------------- /src/olddemo.py: -------------------------------------------------------------------------------- 59 | elif length > hparams.n_ctx: 60 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 61 | 62 | self.sess = tf.Session(graph=tf.Graph()) 63 | self.sess.__enter__() 64 | 65 | self.context = tf.placeholder(tf.int32, [batch_size, None]) 66 | np.random.seed(seed) 67 | tf.set_random_seed(seed) 68 | self.output = sample.sample_sequence( 69 | hparams=hparams, length=length, 70 | context=self.context, 71 | batch_size=batch_size, 72 | temperature=temperature, top_k=top_k 73 | ) 74 | 75 | saver = tf.train.Saver() 76 | self.ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) 77 | saver.restore(self.sess, self.ckpt) 78 | 79 | def close(self): 80 | self.sess.close() 81 | 82 | def generate_conditional(self,raw_text): 83 | context_tokens = self.enc.encode(raw_text) 84 | generated = 0 85 | for _ in range(self.nsamples // self.batch_size): 86 | out = self.sess.run(self.output, feed_dict={ 87 | self.context: [context_tokens for _ in range(self.batch_size)] 88 | })[:, len(context_tokens):] 89 | for i in range(self.batch_size): 90 | generated += 1 91 | text = self.enc.decode(out[i]) 92 | return text 93 | #print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 94 | #print(text) 95 | #print("=" * 80) 96 | ### 97 | 98 | gpt2 = GPT2(model_name="1558M") 99 | ### 100 | class Who: 101 | """A class defining the conversation parties: me, you""" 102 | def __init__(self): 103 | self.prefixes = [] 104 | 105 | def matches(self,phrase): 106 | for prefix in self.prefixes: 107 | if phrase.startswith(prefix): 108 | #print(f"{phrase} starts with {prefix}") 109 | return True 110 | 111 | #print(f"{phrase} does not start with {self.prefixes}") 112 | return False 113 | 114 | def get_random_prefix(self): 115 | return self.prefixes[0] 116 | 117 | class Me(Who): 118 | def __init__(self): 119 | super().__init__() 120 | self.prefixes = ["I said: \""] 121 | 122 | 123 | class You(Who): 124 | def __init__(self): 125 | super().__init__() 126 | self.prefixes = ["You said: \""] 127 | 128 | class Conversation: 129 | 130 | def __init__(self, prior = None): 131 | if prior is None: 132 | prior=""" 133 | You said: "Nice to meet you. What's your name?" 134 | I said: "My name is Pete." 135 | You said: "That's an interesting name. How old are you?" 136 | I said: "I'm 40 years old." 137 | You said: "Can you tell me something about yourself?" 138 | I said: "Ofcourse! I like playing video games and eating cake. " 139 | You said: "I like sweet stuff too. What are your plans for tomorrow?"
140 | """ 141 | self.suggestion = None 142 | 143 | self.me = Me() 144 | self.you = You() 145 | self.parties = [ self.me, self.you ] 146 | 147 | self.conversation = [] 148 | 149 | lines = prior.split("\n") 150 | for line in lines: 151 | line = line.strip() 152 | if len(line)!=0: 153 | party = None 154 | for party in self.parties: 155 | if party.matches(line): 156 | break 157 | if party is None: 158 | raise Exception(f"Unknown party: {line}") 159 | 160 | self.conversation.append((party,line)) 161 | self.get_suggestion() 162 | 163 | 164 | def get_prior(self): 165 | conv = "" 166 | for (party, line) in self.conversation: 167 | conv+=line+"\n" 168 | return conv 169 | 170 | def get_suggestion(self): 171 | who, last_line = self.conversation[-1] 172 | 173 | party_index = self.parties.index(who) 174 | next_party = self.parties[(party_index+1) % len(self.parties)] 175 | 176 | conv = self.get_prior() 177 | conv += next_party.get_random_prefix() 178 | answer = self.get_answer(next_party, conv) 179 | 180 | if not next_party.matches(answer): 181 | prefix = next_party.get_random_prefix() 182 | answer = prefix + answer 183 | 184 | self.suggestion = (next_party, answer) 185 | 186 | def next(self, party = None, answer = ""): 187 | """Continue the conversation 188 | :param party: None -> use the current party which is currently in turn 189 | :param answer: None -> use the suggestion, specify a text to override the 190 | suggestion 191 | 192 | """ 193 | suggested_party, suggested_answer = self.suggestion 194 | if party is None: 195 | party = suggested_party 196 | 197 | if answer == "": 198 | answer = suggested_answer 199 | 200 | if not party.matches(answer): 201 | prefix = party.get_random_prefix() 202 | answer = prefix + answer 203 | 204 | answer = answer.strip() 205 | if answer[-1] != "\"": 206 | # add the closing " 207 | answer += "\"" 208 | 209 | self.conversation.append((party, answer)) 210 | self.get_suggestion() 211 | 212 | def retry(self): 213 | self.get_suggestion() 214 | 215 | def get_answer(self, party, conv): 216 | answer = gpt2.generate_conditional(raw_text=conv) 217 | lines = answer.split("\n") 218 | line = "" 219 | for line in lines: 220 | if line !="": 221 | break 222 | 223 | if line!="": 224 | return line 225 | 226 | return "" 227 | 228 | def show(self): 229 | conv = "" 230 | for (party, line) in self.conversation: 231 | conv+=line+"\n" 232 | print(conv) 233 | if self.suggestion is not None: 234 | party, answer = self.suggestion 235 | print("--> "+answer) 236 | 237 | 238 | c = Conversation() 239 | c.show() 240 | c.retry() 241 | c.show() 242 | c.next() 243 | c.show() 244 | c.retry() 245 | c.next(c.you, "Pizza is not to good for your health though.") 246 | c.show() 247 | gpt2.close() 248 | 249 | # This is for possible future development but way slow out of date etc. 
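# Example of the prompt that get_suggestion() assembles for gpt2.generate_conditional():
# the prior conversation followed by the next speaker's opening prefix, e.g.
#
#   You said: "I like sweet stuff too. What are your plans for tomorrow?"
#   I said: "
#
# The model completes the quoted line; next() then appends the reply to the
# conversation, re-adding the closing quote if the sample stopped before one.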
250 | -------------------------------------------------------------------------------- /src/sample.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import model 4 | 5 | def top_k_logits(logits, k): 6 | if k == 0: 7 | # no truncation 8 | return logits 9 | 10 | def _top_k(): 11 | values, _ = tf.nn.top_k(logits, k=k) 12 | min_values = values[:, -1, tf.newaxis] 13 | return tf.where( 14 | logits < min_values, 15 | tf.ones_like(logits, dtype=logits.dtype) * -1e10, 16 | logits, 17 | ) 18 | return tf.cond( 19 | tf.equal(k, 0), 20 | lambda: logits, 21 | lambda: _top_k(), 22 | ) 23 | 24 | 25 | def top_p_logits(logits, p): 26 | """Nucleus sampling""" 27 | batch, _ = logits.shape.as_list() 28 | sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1) 29 | cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1) 30 | indices = tf.stack([ 31 | tf.range(0, batch), 32 | # number of indices to include 33 | tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0), 34 | ], axis=-1) 35 | min_values = tf.gather_nd(sorted_logits, indices) 36 | return tf.where( 37 | logits < min_values, 38 | tf.ones_like(logits) * -1e10, 39 | logits, 40 | ) 41 | 42 | 43 | def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1): 44 | if start_token is None: 45 | assert context is not None, 'Specify exactly one of start_token and context!' 46 | else: 47 | assert context is None, 'Specify exactly one of start_token and context!' 48 | context = tf.fill([batch_size, 1], start_token) 49 | 50 | def step(hparams, tokens, past=None): 51 | lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE) 52 | 53 | logits = lm_output['logits'][:, :, :hparams.n_vocab] 54 | presents = lm_output['present'] 55 | presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size)) 56 | return { 57 | 'logits': logits, 58 | 'presents': presents, 59 | } 60 | 61 | with tf.name_scope('sample_sequence'): 62 | def body(past, prev, output): 63 | next_outputs = step(hparams, prev, past=past) 64 | logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature) 65 | logits = top_k_logits(logits, k=top_k) 66 | logits = top_p_logits(logits, p=top_p) 67 | samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32) 68 | return [ 69 | next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2), 70 | samples, 71 | tf.concat([output, samples], axis=1) 72 | ] 73 | 74 | past, prev, output = body(None, context, context) 75 | 76 | def cond(*args): 77 | return True 78 | 79 | _, _, tokens = tf.while_loop( 80 | cond=cond, body=body, 81 | maximum_iterations=length - 1, 82 | loop_vars=[ 83 | past, 84 | prev, 85 | output 86 | ], 87 | shape_invariants=[ 88 | tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)), 89 | tf.TensorShape([batch_size, None]), 90 | tf.TensorShape([batch_size, None]), 91 | ], 92 | back_prop=False, 93 | ) 94 | 95 | return tokens 96 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | while true 3 | do 4 | python3 src/GPT2-Learning.py 5 | done 6 | --------------------------------------------------------------------------------
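For reference, the nucleus-sampling cutoff that src/sample.py applies at every generation step (top_p_logits) can be traced on a toy distribution. The sketch below is not part of the repository: it re-implements the same selection rule for a single 1-D logits vector in NumPy, whereas the real function operates on batched TensorFlow logits inside the sampling graph.

import numpy as np

def top_p_filter(logits, p):
    # Mirror of the selection rule in top_p_logits (src/sample.py), 1-D case.
    sorted_logits = np.sort(logits)[::-1]                      # descending order
    probs = np.exp(sorted_logits) / np.exp(sorted_logits).sum()
    cumulative = np.cumsum(probs)
    # Keep the most likely tokens whose cumulative probability stays within p
    # (always at least one token).
    cutoff_index = max(int((cumulative <= p).sum()) - 1, 0)
    min_value = sorted_logits[cutoff_index]
    return np.where(logits < min_value, -1e10, logits)

logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))              # token probabilities 0.5 / 0.3 / 0.15 / 0.05
print(top_p_filter(logits, p=0.85))
# The 0.15 and 0.05 tokens are pushed to -1e10, so only the two most likely tokens remain sampleable.

In the bot itself, GPT2-Learning.py derives top_p from the input length (length * mx + top, clamped between 0.005 and 0.84) before handing it to sample_sequence, so the cutoff varies with how much text the user sends.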