├── .gitignore ├── .gitmodules ├── .idea └── vcs.xml ├── Dockerfile ├── README.md ├── bot_samples.py ├── build.gradle ├── buildall.sh ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle └── src └── main └── kotlin └── com └── purelymail └── gpt2 └── Main.kt /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Kotlin template 3 | # Compiled class file 4 | *.class 5 | 6 | # Log file 7 | *.log 8 | 9 | # BlueJ files 10 | *.ctxt 11 | 12 | # Mobile Tools for Java (J2ME) 13 | .mtj.tmp/ 14 | 15 | # Package Files # 16 | *.jar 17 | *.war 18 | *.nar 19 | *.ear 20 | *.zip 21 | *.tar.gz 22 | *.rar 23 | 24 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 25 | hs_err_pid* 26 | ### Gradle template 27 | .gradle 28 | /build/ 29 | 30 | # Ignore Gradle GUI config 31 | gradle-app.setting 32 | 33 | # Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored) 34 | !gradle-wrapper.jar 35 | 36 | # Cache of project 37 | .gradletasknamecache 38 | 39 | # # Work around https://youtrack.jetbrains.com/issue/IDEA-116898 40 | # gradle/wrapper/gradle-wrapper.properties 41 | 42 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "gpt-2"] 2 | path = gpt-2 3 | url = https://github.com/ScottPeterJohnson/gpt-2 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gpt-2-submodule 2 | 
RUN apt-get update && apt-get install -y default-jdk 3 | COPY ./build/install /app 4 | COPY ./bot_samples.py /gpt-2/src 5 | WORKDIR /app/gpt2 6 | CMD ["bin/gpt2"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPT-2 Discord Bot 2 | 3 | ## What is this? 4 | This is just for fun. It's a Discord bot that generates semi-coherent 5 | gibberish based on whatever you prompt it with. 6 | 7 | ## Add it to your server 8 | [Click here]( 9 | https://discordapp.com/api/oauth2/authorize?client_id=574391151583559721&permissions=2048&redirect_uri=https%3A%2F%2Fgithub.com%2FScottPeterJohnson%2Fgpt2-discord&scope=bot), but please be gentle. 10 | 11 | ## Ask it things 12 | Type `!gpt (your question here)` 13 | 14 | ## Building 15 | Make sure to run `git submodule update`. 16 | Run `./buildall.sh`, then run the generated docker container with the DISCORD_TOKEN environment 17 | variable set. -------------------------------------------------------------------------------- /bot_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import model, sample, encoder 10 | import base64 11 | 12 | def interact_model( 13 | model_name='124M', 14 | seed=None, 15 | nsamples=1, 16 | batch_size=1, 17 | length=None, 18 | temperature=1, 19 | top_k=0, 20 | models_dir='models', 21 | ): 22 | """ 23 | Interactively run the model 24 | :model_name=124M : String, which model to use 25 | :seed=None : Integer seed for random number generators, fix seed to reproduce 26 | results 27 | :nsamples=1 : Number of samples to return total 28 | :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples. 
29 | :length=None : Number of tokens in generated text, if None (default), is 30 | determined by model hyperparameters 31 | :temperature=1 : Float value controlling randomness in boltzmann 32 | distribution. Lower temperature results in less random completions. As the 33 | temperature approaches zero, the model will become deterministic and 34 | repetitive. Higher temperature results in more random completions. 35 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 36 | considered for each step (token), resulting in deterministic completions, 37 | while 40 means 40 words are considered at each step. 0 (default) is a 38 | special setting meaning no restrictions. 40 generally is a good value. 39 | :models_dir : path to parent folder containing model subfolders 40 | (i.e. contains the folder) 41 | """ 42 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 43 | if batch_size is None: 44 | batch_size = 1 45 | assert nsamples % batch_size == 0 46 | 47 | enc = encoder.get_encoder(model_name, models_dir) 48 | hparams = model.default_hparams() 49 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 50 | hparams.override_from_dict(json.load(f)) 51 | 52 | if length is None: 53 | length = hparams.n_ctx // 2 54 | elif length > hparams.n_ctx: 55 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 56 | 57 | with tf.Session(graph=tf.Graph()) as sess: 58 | context = tf.placeholder(tf.int32, [batch_size, None]) 59 | np.random.seed(seed) 60 | tf.set_random_seed(seed) 61 | output = sample.sample_sequence( 62 | hparams=hparams, length=length, 63 | context=context, 64 | batch_size=batch_size, 65 | temperature=temperature, top_k=top_k 66 | ) 67 | 68 | saver = tf.train.Saver() 69 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 70 | saver.restore(sess, ckpt) 71 | 72 | while True: 73 | raw_text = base64.b64decode(input()).decode("utf-8") 74 | context_tokens = enc.encode(raw_text) 
75 | response = "" 76 | for _ in range(nsamples // batch_size): 77 | out = sess.run(output, feed_dict={ 78 | context: [context_tokens for _ in range(batch_size)] 79 | })[:, len(context_tokens):] 80 | for i in range(batch_size): 81 | response += enc.decode(out[i]) 82 | print(base64.b64encode(response.encode("utf-8")).decode("utf-8")) 83 | 84 | if __name__ == '__main__': 85 | fire.Fire(interact_model) 86 | 87 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id 'org.jetbrains.kotlin.jvm' version '1.3.31' 3 | } 4 | 5 | apply plugin: 'application' 6 | 7 | mainClassName = "com.purelymail.gpt2.MainKt" 8 | 9 | group 'net.justmachinery' 10 | version '1.0-SNAPSHOT' 11 | 12 | repositories { 13 | mavenCentral() 14 | } 15 | 16 | repositories { 17 | mavenCentral() 18 | jcenter() 19 | maven { url 'https://dl.bintray.com/scottpjohnson/generic/' } 20 | } 21 | 22 | 23 | dependencies { 24 | implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" 25 | compile "net.justmachinery.shellin:shellin:0.1.5" 26 | compile 'net.dv8tion:JDA:3.8.3_462' 27 | } 28 | 29 | compileKotlin { 30 | kotlinOptions.jvmTarget = "1.8" 31 | } 32 | compileTestKotlin { 33 | kotlinOptions.jvmTarget = "1.8" 34 | } -------------------------------------------------------------------------------- /buildall.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd gpt-2 3 | docker build --tag gpt-2-submodule -f Dockerfile.cpu ./ 4 | cd .. 
5 | ./gradlew installDist 6 | docker build ./ -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | kotlin.code.style=official -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ScottPeterJohnson/gpt2-discord/d4cbf882d36a9016ab9131881d3ba3fad2d809c7/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Sat May 04 17:29:15 PDT 2019 2 | distributionBase=GRADLE_USER_HOME 3 | distributionPath=wrapper/dists 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.3-all.zip 7 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 
13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS="" 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 34 | MAX_FD="maximum" 35 | 36 | warn () { 37 | echo "$*" 38 | } 39 | 40 | die () { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false'). 48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ ! -x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 
89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? -ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) 
set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Escape application args 158 | save () { 159 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 160 | echo " " 161 | } 162 | APP_ARGS=$(save "$@") 163 | 164 | # Collect all arguments for the java command, following the shell quoting and substitution rules 165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 166 | 167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 169 | cd "$(dirname "$0")" 170 | fi 171 | 172 | exec "$JAVACMD" "$@" 173 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 
13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 
78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
79 | exit /b 1
80 | 
81 | :mainEnd
82 | if "%OS%"=="Windows_NT" endlocal
83 | 
84 | :omega
85 | 
--------------------------------------------------------------------------------
/settings.gradle:
--------------------------------------------------------------------------------
1 | rootProject.name = 'gpt2'
2 | 
3 | 
--------------------------------------------------------------------------------
/src/main/kotlin/com/purelymail/gpt2/Main.kt:
--------------------------------------------------------------------------------
1 | package com.purelymail.gpt2
2 | 
3 | import net.dv8tion.jda.core.JDA
4 | import net.dv8tion.jda.core.JDABuilder
5 | import net.dv8tion.jda.core.MessageBuilder
6 | import net.dv8tion.jda.core.events.message.MessageReceivedEvent
7 | import net.dv8tion.jda.core.hooks.ListenerAdapter
8 | import java.io.File
9 | import java.util.*
10 | import java.util.concurrent.ExecutionException
11 | import kotlin.concurrent.thread
12 | 
13 | /** Discord rejects messages longer than this, and JDA's MessageBuilder.build() throws for them. */
14 | private const val DISCORD_MESSAGE_LIMIT = 2000
15 | 
16 | /** Characters taken by the `**` markers around the bold prompt. */
17 | private const val BOLD_MARKUP_LENGTH = 4
18 | 
19 | fun main() {
20 |     MainServer()
21 | }
22 | 
23 | /**
24 |  * Discord bot that answers `!gpt <prompt>` messages with a GPT-2 continuation.
25 |  *
26 |  * Generation happens in a long-lived `python3 src/bot_samples.py` child process. Each request is
27 |  * one line of base64-encoded UTF-8 written to the child's stdin. Each reply is one line of
28 |  * base64-encoded UTF-8 read from its stdout. Requests are serialized with `synchronized(this)`.
29 |  */
30 | class MainServer {
31 |     val process = startProcess()
32 | 
33 |     // Declared before `discord` so it is initialized before any Discord event can be handled.
34 |     private val responses = process.inputStream.bufferedReader(Charsets.UTF_8)
35 | 
36 |     val discord = connectDiscord()
37 | 
38 |     /** Starts the GPT-2 sampler in /gpt-2; its stderr is inherited by this process. */
39 |     fun startProcess(): Process =
40 |         ProcessBuilder()
41 |             .directory(File("/gpt-2"))
42 |             .command("python3", "src/bot_samples.py", "--top_k", "40", "--length", "200", "--model_name", "774M")
43 |             .redirectError(ProcessBuilder.Redirect.INHERIT)
44 |             .start()
45 | 
46 |     /**
47 |      * Reads one reply line from the sampler.
48 |      *
49 |      * @throws IllegalStateException if the sampler's stdout reached end-of-stream (e.g. it exited).
50 |      */
51 |     private fun readResponse(): String =
52 |         responses.readLine() ?: throw IllegalStateException("Process terminated")
53 | 
54 |     /**
55 |      * Sends [prompt] to the sampler and returns the generated text, cut at the first
56 |      * `<|endoftext|>` marker. Calls are serialized so request/reply lines never interleave.
57 |      *
58 |      * @throws IllegalStateException if the sampler's stdout has closed. Writing to a sampler that
59 |      *   has already exited may instead fail with an IOException from the pipe.
60 |      */
61 |     fun processRequest(prompt: String): String = synchronized(this) {
62 |         val stdin = process.outputStream
63 |         stdin.write(Base64.getEncoder().encode(prompt.toByteArray(Charsets.UTF_8)))
64 |         stdin.write("\n".toByteArray())
65 |         stdin.flush()
66 | 
67 |         val decoded = Base64.getDecoder().decode(readResponse()).toString(Charsets.UTF_8)
68 |         decoded.substringBefore("<|endoftext|>")
69 |     }
70 | 
71 |     /** Logs in with the `DISCORD_TOKEN` environment variable and registers the `!gpt` handler. */
72 |     fun connectDiscord(): JDA {
73 |         val token = System.getenv("DISCORD_TOKEN") ?: throw IllegalStateException("Discord token not supplied")
74 |         return JDABuilder(token)
75 |             .addEventListener(object : ListenerAdapter() {
76 |                 override fun onMessageReceived(event: MessageReceivedEvent) {
77 |                     handleMessage(event)
78 |                 }
79 |             })
80 |             .build()
81 |     }
82 | 
83 |     /** Replies to `!gpt <prompt>` messages; other messages, and the bot's own, are ignored. */
84 |     private fun handleMessage(event: MessageReceivedEvent) {
85 |         val content = event.message.contentStripped
86 |         // event.jda is used instead of the `discord` property so this check never depends on
87 |         // connectDiscord() having already returned and assigned it.
88 |         if (!content.startsWith("!gpt ") || event.author == event.jda.selfUser) return
89 | 
90 |         val prompt = content.removePrefix("!gpt ").take(900)
91 |         println("Got prompt: $prompt")
92 |         val typing = startTypingIndicator(event)
93 |         try {
94 |             val response = processRequest(prompt)
95 |             println("Response: $response")
96 |             event.channel.sendMessage(
97 |                 MessageBuilder()
98 |                     .append(prompt, MessageBuilder.Formatting.BOLD)
99 |                     .append(fitResponse(prompt, response))
100 |                     .build()
101 |             ).queue()
102 |         } finally {
103 |             typing.interrupt()
104 |         }
105 |     }
106 | 
107 |     /**
108 |      * Starts a daemon thread that re-sends the typing indicator every 5 seconds until interrupted.
109 |      * A failed typing request only stops the indicator; it does not affect the reply.
110 |      */
111 |     private fun startTypingIndicator(event: MessageReceivedEvent): Thread = thread(isDaemon = true) {
112 |         try {
113 |             while (true) {
114 |                 event.channel.sendTyping().submit().get()
115 |                 Thread.sleep(5000)
116 |             }
117 |         } catch (e: InterruptedException) {
118 |             // Expected: the caller interrupts us once the reply has been handled.
119 |         } catch (e: ExecutionException) {
120 |             println("Could not send typing indicator: ${e.cause}")
121 |         }
122 |     }
123 | 
124 |     /**
125 |      * Truncates [response] so that the bold [prompt] followed by it fits in one Discord message,
126 |      * without leaving half of a UTF-16 surrogate pair at the cut.
127 |      */
128 |     private fun fitResponse(prompt: String, response: String): String {
129 |         val budget = DISCORD_MESSAGE_LIMIT - BOLD_MARKUP_LENGTH - prompt.length
130 |         if (response.length <= budget) return response
131 |         if (budget <= 0) return ""
132 |         val cut = response.take(budget)
133 |         return if (cut.last().isHighSurrogate()) cut.dropLast(1) else cut
134 |     }
135 | }
136 | 
137 | 
--------------------------------------------------------------------------------