├── GPT2_with_Javascript_interface_POC.ipynb
├── LICENSE
├── README.md
├── gpt2js.png
└── how_it_works.md
/GPT2_with_Javascript_interface_POC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "GPT2_with_Javascript_interface_POC.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | ""
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "ZMMtepKIwxR8"
32 | },
33 | "source": [
34 | "# This is a proof of concept showing that GPT-2 can be run from Colab with a JavaScript interface\n",
35 | "**Q: How do I run this?**\n",
36 | "\n",
37 | "A: \n",
38 | "1. Runtime -> Change runtime type -> Hardware accelerator -> GPU\n",
39 | "2. Runtime -> Reset all runtimes\n",
40 | "3. Runtime -> Run all\n",
41 | "4. Scroll down and wait until you see the little window\n",
42 | "5. Type text\n",
43 | "6. The button \"Continue with GPT-2\" will invoke GPT-2 and it will continue your text.\n",
44 | "\n"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "metadata": {
50 | "id": "rGUX9yKxnaRe"
51 | },
52 | "source": [
53 | "%tensorflow_version 1.x\n",
54 | "!git clone https://github.com/gpt2ent/gpt-2-simple.git\n",
55 | "%cd gpt-2-simple\n",
56 | "!git checkout context-trim\n",
57 | "!pip install .\n",
58 | "%cd ..\n",
59 | "!git clone https://github.com/gpt2ent/gpt2colab-js.git\n",
60 | "%cd gpt2colab-js\n",
61 | "\n",
62 | "import gpt_2_simple as gpt2\n",
63 | "\n",
64 | "import os\n",
65 | "import requests\n",
66 | "import tensorflow as tf\n",
67 | "\n",
68 | "import re\n",
69 | "\n",
70 | "# Determine the graphics card used by Colab: the full model only fits on a P100 or T4\n",
71 | "\n",
72 | "try:\n",
73 | " !cat /proc/driver/nvidia/gpus/0000:00:04.0/information >> /content/card_info.txt\n",
74 | " with open('/content/card_info.txt','r') as f:\n",
75 | " graphics_card = re.split('\\n|\\t\\t ',f.read())[1]\n",
76 | "\n",
77 | " if not graphics_card.startswith(\"Tesla P100\") and not graphics_card.startswith(\"Tesla T4\"):\n",
78 | " print(\"=\"*90+'\\n'+\"=\"*90+'\\n\\n')\n",
79 | " print('\\n\\tYour current GPU - %s - cannot fit the full GPT-2 model!' % graphics_card)\n",
80 | " print('\\n\\tFalling back on 774M model.')\n",
81 | " print('\\n\\tNothing I can do. just pray to Google to give you a P100')\n",
82 | "    print('\\t\\tnext time. ¯\\\\_(ツ)_/¯')\n",
83 | " print('\\n\\tAlso you might try TPU runtime.')\n",
84 | " print('\\n\\n'+\"=\"*90+'\\n'+\"=\"*90+'\\n\\n')\n",
85 | " model_name = \"774M\"\n",
86 | " spinner_speed = \"300ms\"\n",
87 | " else:\n",
88 | " print('GPU: %s' % graphics_card)\n",
89 | " model_name = \"1558M\"\n",
90 | " spinner_speed = '400ms'\n",
91 | "except IndexError:\n",
92 | " print(\"=\"*90+'\\n'+\"=\"*90+'\\n\\n')\n",
93 | " print('\\n\\tYou\\'re not in a GPU runtime.\\n')\n",
94 | " print('\\n\\tTrying 1558M model anyways - assuming you\\'re on a good TPU.')\n",
95 | " print('\\n\\tIf it fails, you have to go to Runtime -> Change runtime type')\n",
96 | " print('\\n\\tand choose GPU.')\n",
97 | " print('\\n\\n'+\"=\"*90+'\\n'+\"=\"*90+'\\n\\n')\n",
98 | " model_name = \"1558M\"\n",
99 | " spinner_speed = \"1200ms\"\n",
100 | "\n",
101 | "\n",
102 | "# Override the default model choice\n",
103 | "#model_name = \"1558M\"\n",
104 | "#model_name = \"774M\"\n",
105 | "#model_name = \"124M\"\n",
106 | "#model_name = \"355M\"\n",
107 | "\n",
108 | "\n",
109 | "if not os.path.isdir(os.path.join(\"models\", model_name)):\n",
110 | " print(f\"Downloading {model_name} model...\")\n",
111 | " gpt2.download_gpt2(model_name=model_name)\n",
112 | " \n",
113 | "sess = gpt2.start_tf_sess()\n",
114 | "gpt2.load_gpt2(sess, model_name=model_name)\n",
115 | "\n",
116 | "generate_count = 0\n",
117 | "\n",
118 | "import google.colab.output\n",
119 | "\n",
120 | "import json\n",
121 | "\n",
122 | "class JsonRepr:\n",
123 | " \"\"\"\n",
124 | "    For some reason I can only use the result of __repr__\n",
125 | " from inside Javascript. So this wrapper uses json.dumps() as __repr__\n",
126 | " for python function output.\n",
127 | " \"\"\"\n",
128 | " def __init__(self, obj):\n",
129 | " self.obj = obj\n",
130 | "\n",
131 | " def __repr__(self):\n",
132 | " return json.dumps(self.obj)\n",
133 | "\n",
134 | "def overlap(a, b):\n",
135 | " return max(i for i in range(len(b)+1) if a.endswith(b[:i]))\n",
136 | "\n",
137 | "\n",
138 | "def ai_generate(prefix, temp, top_k, length):\n",
139 | " global sess\n",
140 | " global generate_count\n",
141 | "\n",
142 | " temp = float(temp)\n",
143 | " top_k = int(top_k)\n",
144 | " length = int(length)\n",
145 | " result = gpt2.generate(sess, model_name=model_name, prefix=prefix, temperature=temp,\n",
146 | " top_k=top_k, length=length, include_prefix=False, return_as_list=True)[0]\n",
147 | " \n",
148 | " j = overlap(prefix, result)\n",
149 | " result = result[j:]\n",
150 | " \n",
151 | " generate_count += 1\n",
152 | " if generate_count == 6:\n",
153 | " #prevent memory leak as in https://github.com/minimaxir/gpt-2-simple/issues/71\n",
154 | " tf.reset_default_graph()\n",
155 | " sess.close()\n",
156 | " sess = gpt2.start_tf_sess()\n",
157 | " gpt2.load_gpt2(sess, model_name=model_name)\n",
158 | " generate_count = 0\n",
159 | " return JsonRepr(result)\n",
160 | "\n",
161 | "# Register the callback so it can be invoked from JavaScript\n",
162 | "google.colab.output.register_callback('ai_generate', ai_generate)\n",
163 | "\n",
164 | "print('Done')"
165 | ],
166 | "execution_count": null,
167 | "outputs": []
168 | },
169 | {
170 | "cell_type": "code",
171 | "metadata": {
172 | "id": "CdyRipC0o8vR"
173 | },
174 | "source": [
175 | "from IPython.display import HTML\n",
176 | "\n",
177 | "#spinner from https://codepen.io/vovchisko/pen/vROoYQ\n",
178 | "spinner_css = \"\"\"\n",
179 | "\n",
216 | "\"\"\"\n",
217 | "\n",
218 | "input_form = \"\"\"\n",
219 | "\n",
220 | "\n",
221 | "
You have currently loaded %s model
\n", 223 | "\n", 239 | "
\n", 240 | "