├── .gitignore
├── .gitmodules
├── .idea
└── vcs.xml
├── Dockerfile
├── README.md
├── bot_samples.py
├── build.gradle
├── buildall.sh
├── gradle.properties
├── gradle
└── wrapper
│ ├── gradle-wrapper.jar
│ └── gradle-wrapper.properties
├── gradlew
├── gradlew.bat
├── settings.gradle
└── src
└── main
└── kotlin
└── com
└── purelymail
└── gpt2
└── Main.kt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Kotlin template
3 | # Compiled class file
4 | *.class
5 |
6 | # Log file
7 | *.log
8 |
9 | # BlueJ files
10 | *.ctxt
11 |
12 | # Mobile Tools for Java (J2ME)
13 | .mtj.tmp/
14 |
15 | # Package Files #
16 | *.jar
17 | *.war
18 | *.nar
19 | *.ear
20 | *.zip
21 | *.tar.gz
22 | *.rar
23 |
24 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
25 | hs_err_pid*
26 | ### Gradle template
27 | .gradle
28 | /build/
29 |
30 | # Ignore Gradle GUI config
31 | gradle-app.setting
32 |
33 | # Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
34 | !gradle-wrapper.jar
35 |
36 | # Cache of project
37 | .gradletasknamecache
38 |
39 | # # Work around https://youtrack.jetbrains.com/issue/IDEA-116898
40 | # gradle/wrapper/gradle-wrapper.properties
41 |
42 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "gpt-2"]
2 | path = gpt-2
3 | url = https://github.com/ScottPeterJohnson/gpt-2
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Runtime image for the Discord bot, layered on the GPT-2 image that
# buildall.sh tags as "gpt-2-submodule".
FROM gpt-2-submodule

# The bot talks to bot_samples.py through pipes. Python block-buffers stdout
# when it is not a terminal, so force unbuffered output to keep each reply
# line from getting stuck in the child's buffer.
ENV PYTHONUNBUFFERED=1

# Install the JVM needed by the Gradle-generated launcher and drop the apt
# package lists in the same layer so they do not bloat the image.
RUN apt-get update \
    && apt-get install -y default-jdk \
    && rm -rf /var/lib/apt/lists/*

# Output of `./gradlew installDist` (contains gpt2/bin/gpt2).
COPY ./build/install /app
# Trailing slash makes it explicit that the destination is a directory.
COPY ./bot_samples.py /gpt-2/src/

WORKDIR /app/gpt2
CMD ["bin/gpt2"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPT-2 Discord Bot
2 |
3 | ## What is this?
4 | This is just for fun. It's a Discord bot that generates semi-coherent
5 | gibberish based on whatever you prompt it with.
6 |
7 | ## Add it to your server
8 | [Click here](
9 | https://discordapp.com/api/oauth2/authorize?client_id=574391151583559721&permissions=2048&redirect_uri=https%3A%2F%2Fgithub.com%2FScottPeterJohnson%2Fgpt2-discord&scope=bot), but please be gentle.
10 |
11 | ## Ask it things
12 | Type `!gpt (your question here)`
13 |
14 | ## Building
15 | Make sure to run `git submodule update --init` so that the `gpt-2` submodule is checked out.
16 | Run `./buildall.sh`, then start a container from the resulting Docker image with the `DISCORD_TOKEN` environment
17 | variable set.
--------------------------------------------------------------------------------
/bot_samples.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import fire
4 | import json
5 | import os
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | import model, sample, encoder
10 | import base64
11 |
def interact_model(
    model_name='124M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    models_dir='models',
):
    """
    Serve model completions over a line-oriented stdin/stdout protocol.

    Each request is a single line on stdin containing the base64-encoded
    UTF-8 prompt. Each reply is a single line on stdout containing the
    base64-encoded UTF-8 concatenation of all generated samples. The loop
    ends when stdin reaches end-of-file.

    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), half of
     the model's context window (hparams.n_ctx // 2) is used
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    :raises ValueError: if length exceeds the model's context window
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        while True:
            try:
                request_line = input()
            except EOFError:
                # The controlling process closed our stdin: stop serving
                # instead of dying with a traceback.
                break
            raw_text = base64.b64decode(request_line).decode("utf-8")
            context_tokens = enc.encode(raw_text)
            response = ""
            for _ in range(nsamples // batch_size):
                # The sampled sequence starts with the prompt tokens; slice
                # them off so only the newly generated tokens are decoded.
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    response += enc.decode(out[i])
            # stdout is block-buffered when it is a pipe; flush so the reader
            # on the other end receives the reply line immediately.
            print(base64.b64encode(response.encode("utf-8")).decode("utf-8"), flush=True)
83 |
# Script entry point: hand interact_model to python-fire so that command-line
# flags (e.g. --top_k, --length, --model_name) become its keyword arguments.
if __name__ == "__main__":
    fire.Fire(interact_model)
86 |
87 |
--------------------------------------------------------------------------------
/build.gradle:
--------------------------------------------------------------------------------
plugins {
    id 'org.jetbrains.kotlin.jvm' version '1.3.31'
}

// Provides the `installDist` task whose output the Dockerfile copies into the image.
apply plugin: 'application'

mainClassName = "com.purelymail.gpt2.MainKt"

group 'net.justmachinery'
version '1.0-SNAPSHOT'

// Single repository list (previously declared twice with mavenCentral() repeated).
repositories {
    mavenCentral()
    jcenter()
    // NOTE(review): presumably the host of the shellin artifact - confirm it is still reachable.
    maven { url 'https://dl.bintray.com/scottpjohnson/generic/' }
}

dependencies {
    // `implementation` replaces the deprecated `compile` configuration and matches
    // how the Kotlin stdlib was already declared.
    implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
    implementation "net.justmachinery.shellin:shellin:0.1.5"
    implementation 'net.dv8tion:JDA:3.8.3_462'
}

compileKotlin {
    kotlinOptions.jvmTarget = "1.8"
}
compileTestKotlin {
    kotlinOptions.jvmTarget = "1.8"
}
--------------------------------------------------------------------------------
/buildall.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Builds everything needed to run the bot:
#   1. the GPT-2 base image from the gpt-2 submodule (tagged gpt-2-submodule),
#   2. the Kotlin application distribution (build/install),
#   3. the final bot image from ./Dockerfile.

# Abort on the first failing step, on unset variables and on pipeline errors,
# so a failed `cd` or build can never let later steps run in the wrong state.
set -euo pipefail

# Always work from the repository root, wherever the script is invoked from.
cd "$(dirname "${BASH_SOURCE[0]}")"

# Subshell keeps the directory change local to the base-image build.
(
    cd gpt-2
    docker build --tag gpt-2-submodule -f Dockerfile.cpu ./
)

./gradlew installDist
docker build ./
--------------------------------------------------------------------------------
/gradle.properties:
--------------------------------------------------------------------------------
1 | kotlin.code.style=official
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ScottPeterJohnson/gpt2-discord/d4cbf882d36a9016ab9131881d3ba3fad2d809c7/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Sat May 04 17:29:15 PDT 2019
2 | distributionBase=GRADLE_USER_HOME
3 | distributionPath=wrapper/dists
4 | zipStoreBase=GRADLE_USER_HOME
5 | zipStorePath=wrapper/dists
6 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.3-all.zip
7 |
--------------------------------------------------------------------------------
/gradlew:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | ##############################################################################
4 | ##
5 | ## Gradle start up script for UN*X
6 | ##
7 | ##############################################################################
8 |
9 | # Attempt to set APP_HOME
10 | # Resolve links: $0 may be a link
11 | PRG="$0"
12 | # Need this for relative symlinks.
13 | while [ -h "$PRG" ] ; do
14 | ls=`ls -ld "$PRG"`
15 | link=`expr "$ls" : '.*-> \(.*\)$'`
16 | if expr "$link" : '/.*' > /dev/null; then
17 | PRG="$link"
18 | else
19 | PRG=`dirname "$PRG"`"/$link"
20 | fi
21 | done
22 | SAVED="`pwd`"
23 | cd "`dirname \"$PRG\"`/" >/dev/null
24 | APP_HOME="`pwd -P`"
25 | cd "$SAVED" >/dev/null
26 |
27 | APP_NAME="Gradle"
28 | APP_BASE_NAME=`basename "$0"`
29 |
30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
31 | DEFAULT_JVM_OPTS=""
32 |
33 | # Use the maximum available, or set MAX_FD != -1 to use that value.
34 | MAX_FD="maximum"
35 |
# Print all arguments, joined into one string, as a message on stdout.
# Does not stop the script.
warn () {
    echo "$*"
}
39 |
# Print all arguments as a message surrounded by blank lines, then terminate
# the script with exit status 1.
die () {
    echo
    echo "$*"
    echo
    exit 1
}
46 |
47 | # OS specific support (must be 'true' or 'false').
48 | cygwin=false
49 | msys=false
50 | darwin=false
51 | nonstop=false
52 | case "`uname`" in
53 | CYGWIN* )
54 | cygwin=true
55 | ;;
56 | Darwin* )
57 | darwin=true
58 | ;;
59 | MINGW* )
60 | msys=true
61 | ;;
62 | NONSTOP* )
63 | nonstop=true
64 | ;;
65 | esac
66 |
67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
68 |
69 | # Determine the Java command to use to start the JVM.
70 | if [ -n "$JAVA_HOME" ] ; then
71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
72 | # IBM's JDK on AIX uses strange locations for the executables
73 | JAVACMD="$JAVA_HOME/jre/sh/java"
74 | else
75 | JAVACMD="$JAVA_HOME/bin/java"
76 | fi
77 | if [ ! -x "$JAVACMD" ] ; then
78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
79 |
80 | Please set the JAVA_HOME variable in your environment to match the
81 | location of your Java installation."
82 | fi
83 | else
84 | JAVACMD="java"
85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
86 |
87 | Please set the JAVA_HOME variable in your environment to match the
88 | location of your Java installation."
89 | fi
90 |
91 | # Increase the maximum file descriptors if we can.
92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
93 | MAX_FD_LIMIT=`ulimit -H -n`
94 | if [ $? -eq 0 ] ; then
95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
96 | MAX_FD="$MAX_FD_LIMIT"
97 | fi
98 | ulimit -n $MAX_FD
99 | if [ $? -ne 0 ] ; then
100 | warn "Could not set maximum file descriptor limit: $MAX_FD"
101 | fi
102 | else
103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
104 | fi
105 | fi
106 |
107 | # For Darwin, add options to specify how the application appears in the dock
108 | if $darwin; then
109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
110 | fi
111 |
112 | # For Cygwin, switch paths to Windows format before running java
113 | if $cygwin ; then
114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"`
115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
116 | JAVACMD=`cygpath --unix "$JAVACMD"`
117 |
118 | # We build the pattern for arguments to be converted via cygpath
119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
120 | SEP=""
121 | for dir in $ROOTDIRSRAW ; do
122 | ROOTDIRS="$ROOTDIRS$SEP$dir"
123 | SEP="|"
124 | done
125 | OURCYGPATTERN="(^($ROOTDIRS))"
126 | # Add a user-defined pattern to the cygpath arguments
127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then
128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
129 | fi
130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh
131 | i=0
132 | for arg in "$@" ; do
133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
135 |
136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
138 | else
139 | eval `echo args$i`="\"$arg\""
140 | fi
141 | i=$((i+1))
142 | done
143 | case $i in
144 | (0) set -- ;;
145 | (1) set -- "$args0" ;;
146 | (2) set -- "$args0" "$args1" ;;
147 | (3) set -- "$args0" "$args1" "$args2" ;;
148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
154 | esac
155 | fi
156 |
157 | # Escape application args
# Emit the function's arguments as a single-quoted, shell-escaped list so it
# can later be passed through `eval set --` without word splitting.
# For each argument the sed program:
#   - rewrites every ' as '\'' (close quote, escaped quote, reopen quote),
#   - puts an opening ' at the start of the argument's first line,
#   - appends a closing ' followed by " \" to the argument's last line.
# A final line holding a single space is printed after the last argument.
save () {
    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
    echo " "
}
162 | APP_ARGS=$(save "$@")
163 |
164 | # Collect all arguments for the java command, following the shell quoting and substitution rules
165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
166 |
167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
169 | cd "$(dirname "$0")"
170 | fi
171 |
172 | exec "$JAVACMD" "$@"
173 |
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
1 | @if "%DEBUG%" == "" @echo off
2 | @rem ##########################################################################
3 | @rem
4 | @rem Gradle startup script for Windows
5 | @rem
6 | @rem ##########################################################################
7 |
8 | @rem Set local scope for the variables with windows NT shell
9 | if "%OS%"=="Windows_NT" setlocal
10 |
11 | set DIRNAME=%~dp0
12 | if "%DIRNAME%" == "" set DIRNAME=.
13 | set APP_BASE_NAME=%~n0
14 | set APP_HOME=%DIRNAME%
15 |
16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
17 | set DEFAULT_JVM_OPTS=
18 |
19 | @rem Find java.exe
20 | if defined JAVA_HOME goto findJavaFromJavaHome
21 |
22 | set JAVA_EXE=java.exe
23 | %JAVA_EXE% -version >NUL 2>&1
24 | if "%ERRORLEVEL%" == "0" goto init
25 |
26 | echo.
27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
28 | echo.
29 | echo Please set the JAVA_HOME variable in your environment to match the
30 | echo location of your Java installation.
31 |
32 | goto fail
33 |
34 | :findJavaFromJavaHome
35 | set JAVA_HOME=%JAVA_HOME:"=%
36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
37 |
38 | if exist "%JAVA_EXE%" goto init
39 |
40 | echo.
41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
42 | echo.
43 | echo Please set the JAVA_HOME variable in your environment to match the
44 | echo location of your Java installation.
45 |
46 | goto fail
47 |
48 | :init
49 | @rem Get command-line arguments, handling Windows variants
50 |
51 | if not "%OS%" == "Windows_NT" goto win9xME_args
52 |
53 | :win9xME_args
54 | @rem Slurp the command line arguments.
55 | set CMD_LINE_ARGS=
56 | set _SKIP=2
57 |
58 | :win9xME_args_slurp
59 | if "x%~1" == "x" goto execute
60 |
61 | set CMD_LINE_ARGS=%*
62 |
63 | :execute
64 | @rem Setup the command line
65 |
66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
67 |
68 | @rem Execute Gradle
69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
70 |
71 | :end
72 | @rem End local scope for the variables with windows NT shell
73 | if "%ERRORLEVEL%"=="0" goto mainEnd
74 |
75 | :fail
76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
77 | rem the _cmd.exe /c_ return code!
78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
79 | exit /b 1
80 |
81 | :mainEnd
82 | if "%OS%"=="Windows_NT" endlocal
83 |
84 | :omega
85 |
--------------------------------------------------------------------------------
/settings.gradle:
--------------------------------------------------------------------------------
1 | rootProject.name = 'gpt2'
2 |
3 |
--------------------------------------------------------------------------------
/src/main/kotlin/com/purelymail/gpt2/Main.kt:
--------------------------------------------------------------------------------
1 | package com.purelymail.gpt2
2 |
3 | import net.dv8tion.jda.core.JDA
4 | import net.dv8tion.jda.core.JDABuilder
5 | import net.dv8tion.jda.core.MessageBuilder
6 | import net.dv8tion.jda.core.events.message.MessageReceivedEvent
7 | import net.dv8tion.jda.core.hooks.ListenerAdapter
8 | import java.io.ByteArrayOutputStream
9 | import java.io.File
10 | import java.util.*
11 | import kotlin.concurrent.thread
12 |
/**
 * Program entry point. Constructing [MainServer] launches the sampling
 * subprocess and connects to Discord from its property initializers.
 */
fun main() {
    MainServer()
}
16 |
/**
 * Bridges Discord and a long-running GPT-2 sampling subprocess.
 *
 * On construction the Python sampler is started and a Discord connection is
 * opened. Messages beginning with `!gpt ` are forwarded to the sampler as one
 * base64-encoded line on its stdin; the reply is read back as one
 * base64-encoded line from its stdout and posted to the originating channel.
 */
class MainServer {
    /** Sampler subprocess; its stdin/stdout carry the request/response lines. */
    val process = startProcess()

    /** Discord connection. Assigned only after [connectDiscord] returns. */
    val discord = connectDiscord()

    /**
     * Launches `src/bot_samples.py` in `/gpt-2` with the sampler's stderr
     * forwarded to this process's stderr.
     */
    fun startProcess(): Process =
        ProcessBuilder()
            .directory(File("/gpt-2"))
            .command("python3", "src/bot_samples.py", "--top_k", "40", "--length", "200", "--model_name", "774M")
            .redirectError(ProcessBuilder.Redirect.INHERIT)
            .start()

    /**
     * Reads one newline-terminated line from the sampler's stdout.
     *
     * @return the line without its trailing newline, decoded as UTF-8
     * @throws IllegalStateException if the stream ends before a newline is seen
     */
    private fun readResponse(): String {
        val line = ByteArrayOutputStream()
        val source = process.inputStream
        while (true) {
            val next = source.read()
            // read() signals end of stream with -1 only; 0 is a valid byte
            // value and must not be mistaken for termination.
            if (next < 0) {
                break
            }
            if (next.toChar() == '\n') {
                return line.toByteArray().toString(Charsets.UTF_8)
            }
            line.write(next)
        }
        throw java.lang.IllegalStateException("Process terminated")
    }

    /**
     * Sends [prompt] to the sampler and returns its completion, cut at the
     * first `<|endoftext|>` marker.
     *
     * Calls are serialized on this instance so that each request line is
     * paired with its own response line.
     *
     * @throws IllegalStateException if the sampler's output stream ends
     */
    fun processRequest(prompt: String): String {
        synchronized(this) {
            val stdin = process.outputStream
            stdin.write(Base64.getEncoder().encode(prompt.toByteArray()))
            stdin.write("\n".toByteArray())
            stdin.flush()

            val decoded = Base64.getDecoder().decode(readResponse()).toString(Charsets.UTF_8)
            return decoded.substringBefore("<|endoftext|>")
        }
    }

    /**
     * Logs in with the token from the `DISCORD_TOKEN` environment variable and
     * registers the `!gpt ` message handler.
     *
     * @throws IllegalStateException if `DISCORD_TOKEN` is not set
     */
    fun connectDiscord(): JDA {
        val token = System.getenv("DISCORD_TOKEN") ?: throw IllegalStateException("Discord token not supplied")
        return JDABuilder(token)
            .addEventListener(object : ListenerAdapter() {
                override fun onMessageReceived(event: MessageReceivedEvent) {
                    val content = event.message.contentStripped
                    // Use the JDA instance attached to the event rather than the
                    // enclosing `discord` property: that property is not assigned
                    // until connectDiscord() returns, so an early event could see
                    // it uninitialized.
                    if (!content.startsWith("!gpt ") || event.author == event.jda.selfUser) {
                        return
                    }

                    val prompt = content.removePrefix("!gpt ").take(900)
                    println("Got prompt: $prompt")

                    // Keep the "typing..." indicator alive while the model runs.
                    val typing = thread {
                        try {
                            while (true) {
                                event.channel.sendTyping().submit().get()
                                Thread.sleep(5000)
                            }
                        } catch (ignored: InterruptedException) {
                            // Interrupted by the handler below once the reply is ready.
                        }
                    }
                    try {
                        val response = processRequest(prompt)
                        println("Response: $response")
                        // Building a message longer than Discord's limit throws, so
                        // trim the completion to whatever room the prompt leaves.
                        val room = MAX_MESSAGE_LENGTH - prompt.length - BOLD_MARKUP_LENGTH
                        event.channel.sendMessage(
                            MessageBuilder()
                                .append(prompt, MessageBuilder.Formatting.BOLD)
                                .append(response.take(room))
                                .build()
                        ).queue()
                    } finally {
                        typing.interrupt()
                    }
                }
            })
            .build()
    }

    private companion object {
        /** Maximum number of characters Discord accepts in one message. */
        const val MAX_MESSAGE_LENGTH = 2000

        /** Characters added by bold formatting: `**` before and after the prompt. */
        const val BOLD_MARKUP_LENGTH = 4
    }
}
94 |
95 |
--------------------------------------------------------------------------------