├── interface ├── clover ├── start-message.txt ├── subTitle.txt ├── mainTitle.txt ├── colors-full.ini ├── prompt-instructions.txt ├── instructions.txt └── colors-classic.ini ├── models ├── .gitignore └── ReplikantModel │ ├── .gitignore │ └── config.json ├── launch.py ├── play.sh ├── umamba.exe ├── core ├── __init__.pyc ├── __init__.py ├── interface.py ├── getconfig.py ├── storymanager.py ├── gpt2.py ├── inline_editor.py ├── gpt2generator.py ├── utils.py └── play.py ├── requirements.txt ├── prompts ├── Lilith.txt └── EXAMPLE.txt ├── environments └── huggingface.yml ├── play.bat ├── saves └── EXAMPLE.json ├── LICENSE ├── research-notes.md ├── tools └── genresponses.py ├── install_requirements.bat ├── install.sh ├── config.ini └── README.md /interface/clover: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import core 2 | -------------------------------------------------------------------------------- /models/ReplikantModel/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /play.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | venv/bin/python launch.py -------------------------------------------------------------------------------- /umamba.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrReplikant/AvrilAI/HEAD/umamba.exe -------------------------------------------------------------------------------- /core/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrReplikant/AvrilAI/HEAD/core/__init__.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | transformers 3 | gsutil 4 | numpy 5 | pyyaml 6 | regex 7 | tensorflow 8 | tracery 9 | prompt_toolkit 10 | -------------------------------------------------------------------------------- /interface/start-message.txt: -------------------------------------------------------------------------------- 1 | ATTENTION: Project Replikant is still in the experimental phase. Features may be altered or removed at any given point! However, this will not affect existing installs unless you choose to update it. 2 | -------------------------------------------------------------------------------- /prompts/Lilith.txt: -------------------------------------------------------------------------------- 1 | Your name is Lilith. You are a girl. You are a robot. You really like coffee. You work in a coffee shop. My name is Jack. I am a boy. I am a human. I work on computers. We are best friends. We live together. This is our conversation, and roleplay. 2 | Hi, Lilith! 3 | -------------------------------------------------------------------------------- /prompts/EXAMPLE.txt: -------------------------------------------------------------------------------- 1 | You are Argus. 
You are a male. You are a human. You are a college student. I am Mae. Mae is a female. Mae is a polite person. Mae is a cuddly person. Mae is scared of bats. Mae is an android. You created Mae. Mae loves you. Mae doesn't like racism. 2 | Hello, Mae! 3 | -------------------------------------------------------------------------------- /environments/huggingface.yml: -------------------------------------------------------------------------------- 1 | name: avrilai 2 | channels: 3 | - huggingface 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch 9 | - python=3.8.* 10 | - cudatoolkit=11.1 11 | - transformers 12 | - pip 13 | - git 14 | - prompt_toolkit 15 | -------------------------------------------------------------------------------- /interface/subTitle.txt: -------------------------------------------------------------------------------- 1 | █████╗ ██╗ ██╗██████╗ ██╗██╗ █████╗ ██╗ 2 | ██╔══██╗██║ ██║██╔══██╗██║██║ ██╔══██╗██║ 3 | ███████║██║ ██║██████╔╝██║██║ ███████║██║ 4 | ██╔══██║╚██╗ ██╔╝██╔══██╗██║██║ ██╔══██║██║ 5 | ██║ ██║ ╚████╔╝ ██║ ██║██║███████╗██║ ██║██║ 6 | ╚═╝ ╚═╝ ╚═══╝ ╚═╝ ╚═╝╚═╝╚══════╝╚═╝ ╚═╝╚═╝ 7 | -------------------------------------------------------------------------------- /interface/mainTitle.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /play.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | cd /D %~dp0 3 | TITLE AvrilAI 4 | SET /P M=nul 19 | SET TEMP=A:\ 20 | SET TMP=A:\ 21 | call A:\python\condabin\activate 22 | python launch.py %* 23 | subst A: /D 24 | cmd /k 25 | -------------------------------------------------------------------------------- /interface/colors-full.ini: -------------------------------------------------------------------------------- 1 | #Note: This file can only be used if Python Prompt Toolkit is available 2 | #Colors use either the hexadecimal notation or ANSI names. Various flags such as "bold" are also supported. 
3 | #For more help, check out Python Prompt Toolkit's documentation on styling
4 | [Colors]
5 | displaymethod = prompt-toolkit
6 | default = #6e98d4
7 | error = bg:red fg:white
8 | loading-message = bg:#005f00 fg:white
9 | message=olive
10 | title = white
11 | subtitle = teal
12 | subsubtitle = white
13 | instructions = white
14 | selection-prompt = teal
15 | selection-value = teal
16 | menu = teal
17 | query = lime
18 | ai-text = #0087ff
19 | main-prompt = bg:navy fg:white
20 | user-text = #ff8700
21 | transformed-user-text = red
22 | print-story = teal
23 |
--------------------------------------------------------------------------------
/models/ReplikantModel/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attn_pdrop": 0.1,
3 |   "embd_pdrop": 0.1,
4 |   "finetuning_task": null,
5 |   "id2label": {
6 |     "0": "LABEL_0",
7 |     "1": "LABEL_1"
8 |   },
9 |   "initializer_range": 0.02,
10 |   "is_decoder": false,
11 |   "label2id": {
12 |     "LABEL_0": 0,
13 |     "LABEL_1": 1
14 |   },
15 |   "layer_norm_epsilon": 1e-05,
16 |   "n_ctx": 1024,
17 |   "n_embd": 1600,
18 |   "n_head": 25,
19 |   "n_layer": 48,
20 |   "n_positions": 1024,
21 |   "n_vocab": 50257,
22 |   "num_labels": 2,
23 |   "output_attentions": false,
24 |   "output_hidden_states": false,
25 |   "output_past": true,
26 |   "pruned_heads": {},
27 |   "resid_pdrop": 0.1,
28 |   "summary_activation": null,
29 |   "summary_first_dropout": 0.1,
30 |   "summary_proj_to_labels": true,
31 |   "summary_type": "cls_index",
32 |   "summary_use_proj": true,
33 |   "torchscript": false,
34 |   "use_bfloat16": false,
35 |   "vocab_size": 50257
36 | }
37 |
--------------------------------------------------------------------------------
/interface/prompt-instructions.txt:
--------------------------------------------------------------------------------
1 | Write a context and prompt that describe your AI in as little space as you can.
2 |
3 | The first line (context) will be permanently added to the AI's memory for this conversation.
4 |
5 | The second line (prompt) will eventually be forgotten and so can be longer and include temporary information.
6 |
7 | An example context: You are Mr.Replikant. You are the creator of Project Replikant. You have a very driven personality. You are a boy. You are a human. I am Lilith. Lilith is a girl. Lilith is a robot. Lilith is your friend. Lilith has an outgoing personality.
8 |
9 | Keep the context to 1000 characters or less. The one shown above only used 1/4th of that, so you should have a fair amount of room for detail.
10 |
11 | Your first prompt should be something like this:
12 |
13 | Me: Hi, Lilith! Lilith:
14 |
15 | You leave your AI's name at the end to force the model to talk on your AI's behalf, instead of adding to what you've said. This part is very important.
16 |
--------------------------------------------------------------------------------
/saves/EXAMPLE.json:
--------------------------------------------------------------------------------
1 | {"temp": 0.23, "top-p": 0.8, "top-keks": 0, "rep-pen": 1.3, "context": "You are Argus. You are a male. You are a human. You are a college student. I am Mae. Mae is a female. Mae is a polite person. Mae is a cuddly person. Mae is scared of bats. Mae is an android. You created Mae. Mae loves you. Mae doesn't like racism.", "memory": [], "actions": ["Me: Hello, Mae! Mae:", "Me: How are you feeling? Mae:", "Me: Good! Mae:", "Me: I mean that i'm happy for you! Mae:", "Me: Oh...it's ok! Let's talk about something else. Mae:", "Me: What do you wanna do today? 
Mae:", "Me: How about we play some games? Mae:", "Me: Yeah! Like monopoly, or poker! Mae:", "Me: Why not? Mae:", "Me: Ok, then is there any game you DO want to play? Mae:", "Me: WHat game do you want to play? Mae:", "Me: Ok! But what game do you want to play? Mae:"], "results": ["Hi, Me!", "Fine!", "What do you mean?", "I don't understand...", "OK!", "Nothing!", "Games?", "No!", "Because i hate the game!", "Yes!", "I wanna be on your team!", "The one where you have to go through the jungle and find bananas!"]} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 Nick Walton 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /interface/instructions.txt: -------------------------------------------------------------------------------- 1 | AID2: Clover Edition Instructions: 2 | Enter actions starting with a verb ex. "go to the tavern" or "attack the orc." 3 | To speak enter 'say "(thing you want to say)"' or just "(thing you want to say)" 4 | The following commands can be entered for any action: 5 | "/revert" Reverts the last action allowing you to pick a different action. 6 | "/quit" Quits the game and saves 7 | "/menu" Starts a new game and saves your current one 8 | "/retry" Retries the last action 9 | "/restart" Restarts the current story 10 | "/print" Prints a transcript of your adventure (without extra newline formatting) 11 | "/help" Prints these instructions again 12 | "/set SETTING VALUE" Sets the specified setting to the specified value.: 13 | temp Higher values make the AI more random. Default: 0.4 14 | rep-pen Controls how repetitive the AI is allowed to be. Default: 1.2 15 | text-wrap-width Maximum width of lines printed by computer. Default: 80 16 | console-bell Beep after AI generates text? Default: on 17 | top-keks Number of words the AI can randomly choose. Default: 20 18 | generate-num Default: 60 19 | top-p Default: 0.9 20 | log-level Default: 3 21 | action-sugg How many actions to generate, 0 is off. Default: 4 22 | action-d20 Make actions difficult. Default: on 23 | action-temp How random the suggested actions are. 
Default: 1
24 |
--------------------------------------------------------------------------------
/interface/colors-classic.ini:
--------------------------------------------------------------------------------
1 | #ECMA-48 set graphics codes
2 | #Check out "man console_codes"
3 | # Several attributes can be set in the same sequence, separated by semicolons. An empty parameter (between semicolons or string initiator or terminator) is interpreted as a zero.
4 | #0 reset all attributes to their defaults
5 | #1 set bold
6 | #2 set half-bright (simulated with color on a color display)
7 | #4 set underscore (simulated with color on a color display)
8 | #5 set blink
9 | #7 set reverse video
10 | #...
11 | #21 set normal intensity (ECMA-48 says "doubly underlined")
12 | #22 set normal intensity
13 | #24 underline off
14 | #25 blink off
15 | #27 reverse video off
16 | #30 set black foreground
17 | #31 set red foreground
18 | #33 set brown foreground
19 | #34 set blue foreground
20 | #35 set magenta foreground
21 | #36 set cyan foreground
22 | #37 set white foreground
23 | #38 set underscore on, set default foreground color
24 | #39 set underscore off, set default foreground color
25 | #40 set black background
26 | #41 set red background
27 | #42 set green background
28 | #43 set brown background
29 | #44 set blue background
30 | #45 set magenta background
31 | #46 set cyan background
32 | #47 set white background
33 | #49 set default background color
34 | [Colors]
35 | displaymethod = classic
36 | default = 0
37 | error = 7
38 | loading-message = 7;34
39 | message=7;35
40 | title = 31
41 | subtitle = 36
42 | subsubtitle = 36;7
43 | instructions = 33
44 | selection-prompt = 7;32
45 | selection-value = 35
46 | menu = 36
47 | query = 7;42
48 | ai-text = 37
49 | main-prompt = 34
50 | user-text = 36
51 | transformed-user-text = 36
52 | print-story = 37
53 |
--------------------------------------------------------------------------------
/research-notes.md:
--------------------------------------------------------------------------------
1 | After digging up the Gutenberg Dialogue Dataset repo, and downloading the trained models based upon it, I have realized that 125M GPT-2 is simply too small
2 | for use with Project Replikant, at least when trained in the manner seen in the Gutenberg models.
3 |
4 | I am currently training yet another iteration of my experimental model, GPT-R, and this time, I am utilizing a parameter that I have not used before. This is
5 | the "noise" parameter, found in Neil Shepperd's GPT-2 training program. It is normally used to regularize against typos, but I am hoping that it will also allow
6 | GPT-R to have more fluid conversational ability. Time will ultimately tell.
7 |
8 | 8/7/21
9 | The noise parameter has definitely been of incredible help. So much so that I have managed to train a decent model in only 1,000 training steps. However, there is an optimal threshold, past which the training makes the model spew gibberish. Will be researching this further.
10 |
11 | An ongoing challenge remains that I still need quite a lot of training data. This shortage has proven to be crippling to the project, as it is quite difficult to train the model on emotional situations when no such data are available. I have pleaded many times with the users and followers to help me, but as of now it seems my cries have fallen upon deaf ears. Or, perhaps, they are having just as much trouble as I am. It could easily be either.
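
For reference, the kind of token-level noising the noise parameter appears to perform can be sketched in PyTorch as below. This is purely an editorial illustration, assuming the parameter is the per-token probability of swapping a training token for a random vocabulary entry; Shepperd's trainer itself is TensorFlow, and `add_token_noise` is a hypothetical helper, not code from that repo:

```python
import torch

def add_token_noise(tokens: torch.Tensor, vocab_size: int, noise: float = 0.03) -> torch.Tensor:
    """With probability `noise`, replace each training token with a uniform-random one.

    Corrupting a small fraction of the input regularizes the model against
    typos, since it can no longer rely on every token being spelled correctly.
    """
    mask = torch.rand(tokens.shape, device=tokens.device) < noise
    random_tokens = torch.randint(0, vocab_size, tokens.shape, device=tokens.device)
    return torch.where(mask, random_tokens, tokens)

# Example: noising a batch of GPT-2 token ids (vocabulary size 50257)
batch = torch.randint(0, 50257, (4, 1024))
noised = add_token_noise(batch, vocab_size=50257, noise=0.03)
```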
12 |
13 |
14 | 9/4/21 After almost a month and some extensive research and tinkering, I've finally found some settings that will greatly help to compensate for the lack of power behind GPT-R, and at least somewhat make up for the abysmally small dataset. I am also beginning to work on roleplay, which for now will use parentheses instead of asterisks, simply because of a bug in Clover Edition that causes the inline editor to remove them for some strange reason. I really gotta figure out what the hell is up with that.
15 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .play import *
2 |
3 | import traceback
4 | from pathlib import Path
5 | from datetime import datetime
6 |
7 | import gc
8 | import torch
9 |
10 | from .utils import *
11 |
12 | def print_intro():
13 |     print()
14 |
15 |     with open(Path("interface/", "mainTitle.txt"), "r", encoding="utf-8") as file:
16 |         output(file.read(), "title", wrap=False, beg='')
17 |
18 |     with open(Path("interface/", "subTitle.txt"), "r", encoding="utf-8") as file:
19 |         output(file.read(), "subtitle", wrap=False, beg='')
20 |
21 |     if not use_ptoolkit() and os.name == 'nt':
22 |         import ctypes
23 |         kernel32 = ctypes.windll.kernel32
24 |         kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
25 |         output("INFO: ANSI escape sequences enabled")
26 |
27 |
28 | logger.info("Colab detected: {}".format(in_colab()))
29 |
30 | if (__name__ == "__main__" or __name__ == "core"):
31 |     with open(Path("interface/", "clover"), "r", encoding="utf-8") as file_:
32 |         print(file_.read())
33 |     try:
34 |         gm = GameManager(get_generator())
35 |         while True:
36 |             # May be needed to avoid running out of memory
37 |             gc.collect()
38 |             torch.cuda.empty_cache()
39 |             print_intro()
40 |             gm.play_story()
41 |     except KeyboardInterrupt:
42 |         output("Quitting.", "message")
43 |         if gm and gm.story:
44 |             if input_bool("Do you want to save? (y/N): ", "query"):
45 |                 save_story(gm.story)
46 |     except Exception:
47 |         traceback.print_exc()
48 |         output("A fatal error has occurred. ", "error")
49 |         if gm and gm.story:
50 |             if not gm.story.savefile or len(gm.story.savefile.strip()) == 0:
51 |                 savefile = datetime.now().strftime("crashes/%d-%m-%Y_%H%M%S")
52 |             else:
53 |                 savefile = gm.story.savefile
54 |             save_story(gm.story, file_override=savefile)
55 |         exit(1)
56 |
--------------------------------------------------------------------------------
/tools/genresponses.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import torch
3 | import gc
4 | from random import choice
5 | import json
6 | from random import shuffle
7 | from gpt2generator import GPT2Generator
8 | from numpy.random import beta
9 | from numpy import greater
10 | from numpy import mean
11 | from pathlib import Path
12 | from sys import argv
13 | #from numpy import std
14 | samplesize=1024*16
15 | config = configparser.ConfigParser()
16 | config.read('AB.ini')
17 | A = config['A']
18 | B = config['B']
19 | generalSettings = config['All']
20 | def genResponses(settings, n, name):
21 |     responses = []
22 |     files=list(Path("AB-prompts").iterdir())
23 |     gc.collect()
24 |     torch.cuda.empty_cache()
25 |     generator = GPT2Generator(
26 |         model_path = settings['model-path'],
27 |         dtype = torch.float16,
28 |         max_history_tokens=settings.getint('max-history-tokens')
29 |     )
30 |     generator.top_p_first=settings.getboolean('top-p-first')
31 |     for i in range(n):
32 |         torch.cuda.synchronize()
33 |         gc.collect()
34 |         torch.cuda.empty_cache()
35 |         file = choice(files)
36 |         with file.open() as f:
37 |             prompt=f.read()
38 |         responses.append({
39 |             'name':name,
40 |             'prompt':str(file.resolve()),
41 |             'output':generator.generate(
42 |                 context=prompt,
43 |                 temperature=settings.getfloat('temp'),
44 |                 top_p = settings.getfloat('top-p'),
45 |                 top_k = settings.getint('top-keks'),
46 |                 repetition_penalty=settings.getfloat('repetition-penalty'),
47 |                 repetition_penalty_slope=settings.getfloat('repetition-slope')
48 |             )
49 |         })
50 |     generator=None
51 |     gc.collect()
52 |     torch.cuda.empty_cache()
53 |     return responses
54 |
55 |
56 | name = argv[1]
57 | if name == 'A':
58 |     testconfig=A
59 | elif name == 'B':
60 |     testconfig=B
61 | else:
62 |     assert False,'input not A or B'
63 | print(json.dumps(genResponses(testconfig, generalSettings.getint('num-samples'), name)))
64 |
--------------------------------------------------------------------------------
/install_requirements.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | :Installer provided by the KoboldAI Project
3 | title AvrilAI Runtime Installer (MicroMamba)
4 | echo Errors? Rerun this as admin so it can add the needed LongPathsEnabled registry tweak.
5 | echo Installer failed or crashed? Run it again so it can continue.
6 | echo Only Windows 10 and higher officially supported, older Windows installations can't handle the paths.
7 | echo.
8 |
9 | Reg add "HKLM\SYSTEM\CurrentControlSet\Control\FileSystem" /v "LongPathsEnabled" /t REG_DWORD /d "1" /f 2>nul
10 | cd /D %~dp0
11 |
12 | if exist miniconda3\ (
13 |   echo Delete existing installation?
14 |   echo This is required if you are switching modes, or if you get dependency errors in the game.
15 |   echo 1. Yes
16 |   echo 2. No
17 |   SET /P D=Type the number of the desired option and then press ENTER:
18 | ) ELSE (
19 |   SET D=Workaround
20 | )
21 | IF %D%==1 rmdir /s /q miniconda3
22 |
23 | :Mode
24 | echo Which installation mode would you like?
25 | echo 1.
Temporary Drive Letter (Mounts the folder as drive A:, more stable and portable) 26 | echo 2. Subfolder (Traditional method, can't run in folder paths that contain spaces) 27 | echo. 28 | SET /P M=Type the number of the desired option and then press ENTER: 29 | IF %M%==1 GOTO drivemap 30 | IF %M%==2 GOTO subfolder 31 | ECHO Incorrect choice 32 | GOTO MODE 33 | 34 | 35 | :drivemap 36 | echo 1 > loader.settings 37 | subst A: /D >nul 38 | mkdir miniconda3 39 | subst A: miniconda3 40 | SET TEMP=A:\ 41 | SET TMP=A:\ 42 | copy umamba.exe A:\umamba.exe 43 | A: 44 | umamba.exe create -r A:\python\ -n base 45 | umamba.exe install --no-shortcuts -r A:\python\ -n base -f "%~dp0\environments\huggingface.yml" -y --always-copy 46 | umamba.exe -r A:\ clean -a -y 47 | rd A:\Python\pkgs /S /Q 48 | subst A: /d 49 | pause 50 | exit 51 | 52 | :subfolder 53 | echo 2 > loader.settings 54 | SET TEMP=%~DP0MINICONDA3 55 | SET TMP=%~DP0MINICONDA3 56 | umamba.exe create -r miniconda3\ -n base 57 | umamba.exe install --no-shortcuts -r miniconda3 -n base -f environments\huggingface.yml -y --always-copy 58 | umamba.exe clean -a -y 59 | rd miniconda3\Python\pkgs /S /Q 60 | pause 61 | exit 62 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | cd "$(dirname "${0}")" 4 | BASE_DIR="$(pwd)" 5 | PACKAGES=(aria2 git unzip wget) 6 | 7 | pip_install () { 8 | if [ ! -d "./venv" ]; then 9 | # Some distros have venv built into python so this isn't always needed. 10 | if is_command 'apt-get'; then 11 | apt-get install python3-venv 12 | fi 13 | #WARNING: Changing to --copies for colab users, not optimal way to do this 14 | python3 -m venv --copies ./venv 15 | fi 16 | commit_hash=$(git log --pretty=format:'%h' -n 1) 17 | echo "You are using https://github.com/cloveranon/Clover-Edition/commit/${commit_hash}" 18 | source "${BASE_DIR}/venv/bin/activate" 19 | pip install --upgrade pip setuptools 20 | pip --no-cache-dir install -r "${BASE_DIR}/requirements/requirements.txt" 21 | 22 | echo "Would you like to install Nvidia CUDA support (~4.5gb) or just use your CPU (~800mb, but much slower)?" 23 | select yn in "Nvidia CUDA" "CPU only"; do 24 | case $yn in 25 | "Nvidia CUDA" ) pip install -r "${BASE_DIR}/requirements/cuda_requirements.txt"; break;; 26 | "CPU only" ) pip install -r "${BASE_DIR}/requirements/cpu_requirements.txt"; break;; 27 | esac 28 | done 29 | } 30 | 31 | is_command() { 32 | command -v "${@}" > /dev/null 33 | } 34 | 35 | system_package_install() { 36 | #why is this list duplicated? 37 | PACKAGES=(aria2 git unzip wget) 38 | if is_command 'apt-get'; then 39 | sudo apt-get install ${PACKAGES[@]} 40 | elif is_command 'brew'; then 41 | brew install ${PACKAGES[@]} 42 | elif is_command 'yum'; then 43 | sudo yum install ${PACKAGES[@]} 44 | elif is_command 'dnf'; then 45 | sudo dnf install ${PACKAGES[@]} 46 | elif is_command 'pacman'; then 47 | sudo pacman -S ${PACKAGES[@]} 48 | elif is_command 'apk'; then 49 | sudo apk --update add ${PACKAGES[@]} 50 | else 51 | echo "You do not seem to be using a supported package manager." 
52 | echo "Please make sure ${PACKAGES[@]} are installed then press [ENTER]" 53 | read NOT_USED 54 | fi 55 | } 56 | 57 | install_aid () { 58 | # version_check 59 | #the order of this may be wrong, changing it back to original for now 60 | pip_install 61 | system_package_install 62 | } 63 | 64 | install_aid 65 | -------------------------------------------------------------------------------- /core/interface.py: -------------------------------------------------------------------------------- 1 | from .getconfig import settings, setting_info 2 | from .utils import pad_text 3 | 4 | def boolValue(bool): 5 | return "on" if bool else "off" 6 | 7 | def instructions(): 8 | print('\n' + 9 | 'AvrilAI Instructions: \n' + 10 | ' To do roleplay, use *thing you want to do goes here* followed by any other dialogue text and a period.\n' + 11 | ' To speak to the AI, just talk to it as if you were texting someone!\n') 12 | print('The following commands can be entered for any action:') 13 | print(' "/revert" Reverts the last action, allowing you to try a different one.') 14 | print(' "/quit" Quits the conversation and saves') 15 | print(' "/menu" Starts a new conversation and saves your current one') 16 | print(' "/retry" Retries the last action') 17 | print(' "/restart" Restarts the AI [WILL ERASE ITS MEMORY, DO NOT DO UNLESS YOURE ABSOLUTELY SURE]') 18 | print(' "/print" Prints a transcript of your conversation (without extra newline formatting)') 19 | print(' "/alter" Edit the last prompt from the AI') 20 | print(' "/altergen" Edit the last result from the AI and have it generate the rest') 21 | print(' "/context" Edit the AI\'s permanent context paragraph') 22 | print(' "/remember [SENTENCE]" Commits something permanently to the AI\'s memory') 23 | print(' "/memalt" Let you select and alter a memory entry') 24 | print(' "/memswap" Swaps places of two memory entries') 25 | print(' "/forget" Opens a menu allowing you to remove permanent memories') 26 | print(' "/save" Saves your conversation to a file in the save directory') 27 | print(' "/load" Loads a conversation from a file in the save directory') 28 | print(' "/summarize" Create a new conversation using by summarizing your previous one') 29 | print(' "/help" Prints these instructions again') 30 | print(' "/set [SETTING] [VALUE]" Sets the specified setting to the specified value.:') 31 | for k, v in setting_info.items(): 32 | print(pad_text(' ' + k, 27) + v[0] + (" " if v[0] else "") + 33 | "Default: " + str(v[1]) + " | " 34 | "Current: " + settings.get(k)) 35 | -------------------------------------------------------------------------------- /core/getconfig.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import logging 3 | 4 | config = configparser.ConfigParser() 5 | config.read("config.ini") 6 | settings = config["Settings"] 7 | 8 | colorschemefile = settings["color-scheme"] 9 | colorconfig = configparser.ConfigParser() 10 | colorconfig.read(colorschemefile) 11 | ptcolors = colorconfig["Colors"] 12 | 13 | colorschemefile = settings["backup-color-scheme"] 14 | colorconfig = configparser.ConfigParser() 15 | colorconfig.read(colorschemefile) 16 | colors = colorconfig["Colors"] 17 | 18 | logger = logging.getLogger(__name__) 19 | logLevel = settings.getint("log-level") 20 | oneLevelUp = 20 21 | 22 | # I don't know if this will work before loading the transformers module? 
23 | # silence transformers outputs when loading model
24 | logging.getLogger("transformers.tokenization_utils").setLevel(logLevel + oneLevelUp)
25 | logging.getLogger("transformers.modeling_utils").setLevel(logLevel + oneLevelUp)
26 | logging.getLogger("transformers.configuration_utils").setLevel(logLevel + oneLevelUp)
27 |
28 | logging.basicConfig(
29 |     format="%(asctime)s - %(levelname)s - %(message)s",
30 |     datefmt="%m/%d/%Y %H:%M:%S",
31 |     level=logLevel + oneLevelUp,
32 | )
33 | logger.setLevel(logLevel)
34 |
35 | """
36 | Settings descriptions and their default values, keyed by their name.
37 | These settings, their descriptions, and their defaults appear in the settings menu and the /help prompt.
38 | """
39 | setting_info = {
40 |     "temp": ["Higher values make the AI more random.", 0.4],
41 |     "rep-pen": ["Controls how repetitive the AI is allowed to be.", 1.2],
42 |     "rep-pen-range": ["Controls how many tokens are affected by the penalty.", 512],
43 |     "rep-pen-slope": ["Controls the penalty curve slope.", 3.33],
44 |     "text-wrap-width": ["Maximum width of lines printed by computer.", 80],
45 |     "console-bell": ["Beep after AI generates text.", "on"],
46 |     "top-keks": ["Number of words the AI can randomly choose.", 20],
47 |     "action-sugg": ["How many actions to generate; 0 is off.", 4],
48 |     "action-d20": ["Makes actions difficult.", "on"],
49 |     "action-temp": ["How random the suggested actions are.", 1],
50 |     "prompt-toolkit": ["Whether or not to use the prompt_toolkit library.", "on"],
51 |     "autosave": ["Whether or not to save after every action.", "on"],
52 |     "generate-num": ["Approximate number of words to generate.", 60],
53 |     "top-p": ["Changes number of words nucleus sampled by the AI.", 0.9],
54 |     "log-level": ["Development log level. <30 is for developers.", 30],
55 |     "history-gpt-2": ["Number of tokens to feed into GPT-2 models.", 1024],
56 |     "history-gpt-neo": ["Number of tokens to feed into GPT-Neo models.", 2048],
57 | }
58 |
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [Settings]
2 | temp = 0.23
3 |
4 | # Repetitiveness Penalty
5 | # Controls how repetitive the AI's output is allowed to be.
6 | # <1 encourages repeats (no one wants this).
7 | # 1 is no penalty/off
8 | # > 1 penalizes repeats
9 | # e.g. 1.2 is a 20% penalty
10 | # A common value is 1.2, as it's the default from the CTRL paper, which introduced this: https://arxiv.org/abs/1909.05858
11 | rep-pen = 1.3
12 |
13 | # Repetitiveness Penalty Range
14 | # Controls how far back tokens are penalized
15 | rep-pen-range = 512
16 |
17 | # Repetitiveness Penalty Slope
18 | # Controls the slope of the penalty curve
19 | rep-pen-slope = 3.33
20 |
21 | # The number of words the AI has to choose from.
22 | # It always chooses the "top k" most likely next words before randomly picking one according to temperature.
23 | # Low values reduce the randomness of the AI, similar to temp.
24 | # Won't change generation speed. 0 is off.
25 | # Many projects turn this off and use top-p. Original AI Dungeon used 40.
26 | top-keks = 0
27 |
28 | # The number of words the AI has to choose from.
29 | # top-p, also called nucleus filtering, keeps the top tokens with cumulative probability >= top_p (see https://arxiv.org/pdf/1904.09751.pdf)
30 | # Similar to top-k, but probably better. Can be used together, or you can use this instead.
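# (Editorial worked example of nucleus filtering:) suppose the model's
# next-token probabilities are 0.50, 0.30, 0.15 and 0.05. With top-p = 0.8,
# tokens are kept until their cumulative probability reaches 0.8 -- here the
# first two (0.50 + 0.30) -- and the sampler then picks between just those two.
# With top-p = 0.9, the third token also survives (0.50 + 0.30 + 0.15 = 0.95 >= 0.9).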
31 | # 0.9 is used as a default in a wide range of projects and papers
32 | # Low values reduce the randomness of the AI, similar to temp.
33 | # Won't change generation speed.
34 | top-p = 0.8
35 |
36 | # How long should the longest suggested actions be? Higher is slower.
37 | # More technically, this is the number of generated Byte Pair Encoding tokens
38 | # (which are usually whole words) the AI generates for each story response.
39 | generate-num = 40
40 |
41 | # Dings the console bell when the AI responds
42 | # Check your terminal emulator's support for console bells if this doesn't work, it should typically buzz the PC speaker
43 | # Betcha didn't know ASCII supported sound
44 | console-bell = off
45 |
46 | # Maximum width of lines
47 | # Set to 0 to disable
48 | # Text wrapping has been much requested since I disabled it from vanilla.
49 | # In principle this should be a function of your terminal emulator and not an issue
50 | # Not sure of a good default, but 80 was considered an ideal standard number of columns on old PCs.
51 | text-wrap-width = 120
52 |
53 | # On means you force use of the CPU even when you have a graphics card. Off means you try to use the GPU if you have one
54 | force-cpu = off
55 |
56 | # 30 will not spam you with console log messages, <30 will spam devs
57 | log-level = 30
58 |
59 | # Use a die to decide an action's success. E.g. rolling a 1 means "You failed to X"
60 | action-d20 = off
61 |
62 | # How many action suggestions to generate, higher is slower
63 | action-sugg = 0
64 |
65 | # How weird (and potentially blank and loopy) should the suggested actions be.
66 | # 0.15 is very conservative,
67 | # 0.4 is conservative,
68 | # 1.0 is weird (default)
69 | # 1.5 is glitchy
70 | action-temp = 0.65
71 |
72 | # Experimental setting, ignore it for now
73 | top-p-first = on
74 |
75 | # Leave "off" unless in Google Colab
76 | colab-mode = off
77 |
78 | # Try to enable Python Prompt Toolkit. If problems are detected, it's disabled regardless of the setting
79 | prompt-toolkit = on
80 |
81 | # If true, saves after every action, and prompts the user when starting a story what to save it as
82 | autosave = on
83 |
84 | # Color scheme that is used if Python Prompt Toolkit is available. A classic-type color scheme can still be used here.
85 | color-scheme = interface/colors-full.ini
86 |
87 | # Backup color scheme in case Python Prompt Toolkit isn't available
88 | backup-color-scheme = interface/colors-classic.ini
89 |
90 | # Use experimental gpt2 (may be slightly faster, but buggy)
91 | gpt2-experimental = off
92 |
93 | # Max number of tokens for GPT-2 models to use, more = more VRAM but a more coherent story
94 | # Do not set it higher than 1024
95 | history-gpt-2 = 512
96 |
97 | # Max number of tokens for GPT-Neo models to use, more = more VRAM but a more coherent story
98 | # Some people claim setting this a little lower (~2000) is more stable for 8GB VRAM GPUs
99 | # Do not set it higher than 2048
100 | history-gpt-neo = 2048
101 |
--------------------------------------------------------------------------------
/core/storymanager.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from .getconfig import settings
4 | from .utils import output, format_result, format_input, get_similarity
5 |
6 |
7 | class Story:
8 |     # the initial prompt is very special.
9 |     # We want it to be permanently in the AI's limited memory (as well as possibly other strings of text.)
10 | def __init__(self, generator, context='', memory=None): 11 | if memory is None: 12 | memory = [] 13 | self.generator = generator 14 | self.context = context 15 | self.memory = memory 16 | self.actions = [] 17 | self.results = [] 18 | self.savefile = "" 19 | 20 | def act(self, action, record=True, format=True): 21 | assert (self.context.strip() + action.strip()) 22 | assert (settings.getint('top-keks') is not None) 23 | result = self.generator.generate( 24 | self.get_story() + action, 25 | self.context + ' '.join(self.memory), 26 | temperature=settings.getfloat('temp'), 27 | top_p=settings.getfloat('top-p'), 28 | top_k=settings.getint('top-keks'), 29 | repetition_penalty=settings.getfloat('rep-pen'), 30 | repetition_penalty_range=settings.getint('rep-pen-range'), 31 | repetition_penalty_slope=settings.getfloat('rep-pen-slope')) 32 | if record: 33 | self.actions.append(format_input(action)) 34 | self.results.append(format_input(result)) 35 | return format_result(result) if format else result 36 | 37 | def print_action_result(self, i, wrap=True, color=True): 38 | col1 = 'user-text' if color else None 39 | col2 = 'ai-text' if color else None 40 | if i == 0 or len(self.actions) == 1: 41 | start = format_result(self.context + ' ' + self.actions[0]) 42 | result = format_result(self.results[0]) 43 | is_start_end = re.match(r"[.!?]\s*$", start) # if start ends logically 44 | is_result_continue = re.match(r"^\s*[a-z.!?,\"]", result) # if result is a continuation 45 | sep = ' ' if not is_start_end and is_result_continue else '\n' 46 | if not self.actions[0]: 47 | output(self.context, col1, self.results[0], col2, sep=sep) 48 | else: 49 | output(self.context, col1) 50 | output(self.actions[0], col1, self.results[0], col2, sep=sep) 51 | else: 52 | if i < len(self.actions) and self.actions[i].strip() != "": 53 | caret = "> " if re.match(r"^ *you +", self.actions[i], flags=re.I) else "" 54 | output(format_result(caret + self.actions[i]), col1, wrap=wrap) 55 | if i < len(self.results) and self.results[i].strip() != "": 56 | output(format_result(self.results[i]), col2, wrap=wrap) 57 | 58 | def print_story(self, wrap=True, color=True): 59 | for i in range(0, max(len(self.actions), len(self.results))): 60 | self.print_action_result(i, wrap=wrap, color=color) 61 | 62 | def print_last(self, wrap=True, color=True): 63 | self.print_action_result(-1, wrap=wrap, color=color) 64 | 65 | def get_story(self): 66 | lines = [val for pair in zip(self.actions, self.results) for val in pair] 67 | return '\n\n'.join(lines) 68 | 69 | def revert(self): 70 | self.actions = self.actions[:-1] 71 | self.results = self.results[:-1] 72 | 73 | def get_suggestion(self): 74 | return re.sub('\n.*', '', 75 | self.generator.generate_raw( 76 | self.get_story() + "\n\n> You", 77 | self.context, 78 | temperature=settings.getfloat('action-temp'), 79 | top_p=settings.getfloat('top-p'), 80 | top_k=settings.getint('top-keks'), 81 | repetition_penalty=1)) 82 | 83 | def __str__(self): 84 | return self.context + ' ' + self.get_story() 85 | 86 | def to_dict(self): 87 | res = {} 88 | res["temp"] = settings.getfloat('temp') 89 | res["top-p"] = settings.getfloat("top-p") 90 | res["top-keks"] = settings.getint("top-keks") 91 | res["rep-pen"] = settings.getfloat("rep-pen") 92 | res["rep-pen-range"] = settings.getint("rep-pen-range") 93 | res["rep-pen-slope"] = settings.getfloat("rep-pen-slope") 94 | res["context"] = self.context 95 | res["memory"] = self.memory 96 | res["actions"] = self.actions 97 | res["results"] = self.results 98 | return res 99 
| 100 | def from_dict(self, d): 101 | settings["temp"] = str(d["temp"]) 102 | settings["top-p"] = str(d["top-p"]) 103 | settings["top-keks"] = str(d["top-keks"]) 104 | settings["rep-pen"] = str(d["rep-pen"]) 105 | try: 106 | settings["rep-pen-range"] = str(d["rep-pen-range"]) 107 | settings["rep-pen-slope"] = str(d["rep-pen-slope"]) 108 | except: 109 | settings["rep-pen-range"] = "512" 110 | settings["rep-pen-slope"] = "3.33" 111 | self.context = d["context"] 112 | self.memory = d["memory"] 113 | self.actions = d["actions"] 114 | self.results = d["results"] 115 | 116 | def to_json(self): 117 | return json.dumps(self.to_dict()) 118 | 119 | def from_json(self, j): 120 | self.from_dict(json.loads(j)) 121 | 122 | def is_looping(self, threshold=0.9): 123 | if len(self.results) >= 2: 124 | similarity = get_similarity(self.results[-1], self.results[-2]) 125 | if similarity > threshold: 126 | return True 127 | return False 128 | 129 | # def save() 130 | # file=Path('saves', self.filename) 131 | -------------------------------------------------------------------------------- /core/gpt2.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from transformers import GPT2Config 5 | from transformers import GPT2PreTrainedModel 6 | 7 | 8 | 9 | def gelu(x): 10 | srqt_2_pi = 0.7978845608 11 | return .5 * x * (1 + torch.tanh(srqt_2_pi * (x + .044715 * (x ** 3)))) 12 | 13 | 14 | class Conv1D(torch.nn.Module): 15 | def __init__(self, nf, nx): 16 | """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2) 17 | Basically works like a Linear layer but the weights are transposed 18 | """ 19 | super(Conv1D, self).__init__() 20 | self.nf = nf 21 | w = torch.empty(nx, nf) 22 | torch.nn.init.normal_(w, std=0.02) 23 | self.weight = torch.nn.Parameter(w) 24 | self._weightT = None 25 | self.bias = torch.nn.Parameter(torch.zeros(nf)) 26 | 27 | def forward(self, x): 28 | if self._weightT is None: 29 | self._weightT = self.weight.T 30 | return torch.nn.functional.linear(x, self._weightT, self.bias) 31 | 32 | 33 | class Attention(torch.nn.Module): 34 | def __init__(self, n_embd, n_ctx, config): 35 | super(Attention, self).__init__() 36 | # in Attention: n_embd=768 (nx=n_embd) 37 | # [switch nx => n_embd from Block to Attention to keep identical to TF implem] 38 | assert n_embd % config.n_head == 0 39 | self.register_buffer("m1e4", torch.full((1, 1, 1), -1e4)) 40 | self.n_head = config.n_head 41 | self.n_embd = n_embd 42 | 43 | self.c_attn = Conv1D(n_embd * 3, n_embd) 44 | self.c_proj = Conv1D(n_embd, n_embd) 45 | 46 | def _attn(self, q, k, v, mask): 47 | w = torch.matmul(q, k) 48 | w /= math.sqrt(v.size(-1)) 49 | 50 | w = torch.where(mask, w, self.m1e4) 51 | w = torch.nn.Softmax(dim=-1)(w) 52 | return torch.matmul(w, v) 53 | 54 | def merge_heads(self, x: torch.Tensor): 55 | x = x.permute(1, 0, 2).contiguous() 56 | new_x_shape = x.size()[:-2] + (self.n_embd,) 57 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 58 | 59 | def split_heads(self, x): 60 | new_x_shape = x.size()[:-1] + (self.n_head, self.n_embd // self.n_head) 61 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 62 | return x.permute(1, 0, 2) # (batch, head, seq_length, head_features) 63 | 64 | def forward(self, x, layer_past, mask): 65 | x = self.c_attn(x) 66 | x = x.view((x.size(0), 3, self.n_embd)) 67 | query, key, value = x[:, 0], x[:, 1], x[:, 2] 68 | # query, key, value = x.split(self.n_embd, dim=2) 69 | query = 
self.split_heads(query) 70 | key = self.split_heads(key) # , k=True) 71 | value = self.split_heads(value) 72 | 73 | if layer_past is not None: 74 | past_value = layer_past[1] # transpose back cf below 75 | value = torch.cat((past_value, value), dim=-2) 76 | 77 | past_key = layer_past[0] # .transpose(-2, -1) 78 | key = torch.cat((past_key, key), dim=-2) 79 | 80 | present = torch.stack([key, value]) # transpose to have same shapes for stacking 81 | 82 | a = self._attn(query, key.transpose(-2, -1), value, mask) 83 | a = self.merge_heads(a) 84 | a = self.c_proj(a) 85 | 86 | return a, present 87 | 88 | 89 | class MLP(torch.nn.Module): 90 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 91 | super(MLP, self).__init__() 92 | self.c_fc = Conv1D(n_state, config.n_embd) 93 | self.c_proj = Conv1D(config.n_embd, n_state) 94 | if hasattr(torch.nn, 'GELU'): 95 | self.act = torch.nn.GELU() # New in torch 1.4.0, but different results from transformers gelu 96 | else: 97 | self.act = gelu # the original gelu, written in pytorch 98 | 99 | def forward(self, x): 100 | h = self.act(self.c_fc(x)) 101 | h2 = self.c_proj(h) 102 | return h2 103 | 104 | 105 | class Block(torch.nn.Module): 106 | def __init__(self, n_ctx, config): 107 | super(Block, self).__init__() 108 | n_embd = config.n_embd 109 | self.ln_1 = torch.nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon) 110 | self.attn = Attention(n_embd, n_ctx, config) 111 | self.ln_2 = torch.nn.LayerNorm(n_embd, eps=config.layer_norm_epsilon) 112 | self.mlp = MLP(4 * n_embd, config) 113 | 114 | def forward(self, x, layer_past, mask): 115 | a, present = self.attn(self.ln_1(x), layer_past, mask) 116 | x = x + a 117 | x += self.mlp(self.ln_2(x)) # residual 118 | 119 | return x, present # x, present 120 | 121 | 122 | class GPT2Model(GPT2PreTrainedModel): 123 | 124 | def __init__(self, config: GPT2Config): 125 | super(GPT2Model, self).__init__(config) 126 | 127 | self.wte = torch.nn.Embedding(config.vocab_size, config.n_embd) 128 | self.wpe = torch.nn.Embedding(config.n_positions, config.n_embd) 129 | self.h = torch.nn.ModuleList([Block(config.n_ctx, config) for _ in range(config.n_layer)]) 130 | self.ln_f = torch.nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 131 | self.register_buffer("bigmask", torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8))) 132 | self.init_weights() 133 | 134 | def get_input_embeddings(self): 135 | return self.wte 136 | 137 | def set_input_embeddings(self, new_embeddings): 138 | self.wte = new_embeddings 139 | 140 | def forward(self, input_ids: torch.Tensor, past: torch.Tensor): 141 | if input_ids is None: 142 | raise ValueError("You have to specify either input_ids or inputs_embeds") 143 | 144 | input_len = input_ids.size(0) 145 | past_length = past.size(-2) if past is not None else 0 146 | total_len = input_len + past_length 147 | position_embeds = self.wpe.weight.data[past_length:total_len] 148 | 149 | inputs_embeds = self.wte(input_ids) 150 | hidden_states = inputs_embeds + position_embeds 151 | 152 | mask = self.bigmask[None, past_length:total_len, :total_len] 153 | presents = [] 154 | for i in range(self.config.n_layer): 155 | layer_past = past[i] if past is not None else None 156 | trans_block = self.h[i] 157 | hidden_states, present = trans_block(hidden_states, layer_past, mask) 158 | presents.append(present) 159 | 160 | hidden_states = self.ln_f(hidden_states) 161 | return hidden_states, torch.stack(presents) 162 | 163 | 164 | class 
GPT2LMHeadModelExperimental(GPT2PreTrainedModel): 165 | 166 | def __init__(self, config): 167 | super(GPT2LMHeadModelExperimental, self).__init__(config) 168 | self.transformer = GPT2Model(config) 169 | self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False) 170 | 171 | self.init_weights() 172 | self.tie_weights() 173 | 174 | def tie_weights(self): 175 | """ Make sure we are sharing the input and output embeddings. 176 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 177 | """ 178 | self._tie_or_clone_weights(self.lm_head, 179 | self.transformer.wte) 180 | 181 | def forward(self, input_ids: torch.Tensor, **kwargs): 182 | hidden_states, pasts = self.transformer(input_ids, **kwargs) 183 | lm_logits = self.lm_head(hidden_states) 184 | return lm_logits, pasts 185 | -------------------------------------------------------------------------------- /core/inline_editor.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit.application import Application 2 | from prompt_toolkit.application.current import get_app 3 | from prompt_toolkit.key_binding import KeyBindings 4 | from prompt_toolkit.layout.containers import ( 5 | HSplit, 6 | Window, 7 | ) 8 | from .utils import clear_lines, getTermWidth 9 | from prompt_toolkit.layout.controls import FormattedTextControl 10 | from prompt_toolkit.layout.layout import Layout 11 | from prompt_toolkit.widgets import ( 12 | TextArea, 13 | ) 14 | 15 | 16 | def edit_multiline(default_text=""): 17 | kb = KeyBindings() 18 | 19 | @kb.add('c-q') 20 | @kb.add('escape', 'enter') 21 | def exit_(event): 22 | """ 23 | Pressing Ctrl-Q, Alt+Enter or Esc + Enter will exit the editor. 24 | """ 25 | event.app.exit(textf.text) 26 | 27 | @kb.add('c-c') 28 | def do_copy(event): 29 | data = textf.buffer.copy_selection() 30 | get_app().clipboard.set_data(data) 31 | 32 | @kb.add('c-x', eager=True) 33 | def do_cut(event): 34 | data = textf.buffer.cut_selection() 35 | get_app().clipboard.set_data(data) 36 | 37 | @kb.add('c-z') 38 | def do_undo(event): 39 | textf.buffer.undo() 40 | 41 | @kb.add('c-y') 42 | def do_redo(event): 43 | textf.buffer.redo() 44 | 45 | @kb.add('c-a') 46 | def do_select_all(event): 47 | textf.buffer.cursor_position = 0 48 | textf.buffer.start_selection() 49 | textf.buffer.cursor_position = len(textf.buffer.text) 50 | update_stored_pos(None) 51 | 52 | @kb.add('c-v') 53 | def do_paste(event): 54 | textf.buffer.paste_clipboard_data(get_app().clipboard.get_data()) 55 | 56 | @kb.add('left') 57 | def kb_left(event): 58 | textf.buffer.selection_state = None 59 | if textf.buffer.cursor_position != 0 and textf.text[textf.buffer.cursor_position-1] == '\n': 60 | textf.buffer.cursor_up() 61 | textf.buffer.cursor_right(len(textf.text)) 62 | else: 63 | textf.buffer.cursor_left() 64 | update_stored_pos(None) 65 | 66 | @kb.add('right') 67 | def kb_right(event): 68 | textf.buffer.selection_state = None 69 | if textf.buffer.cursor_position < len(textf.text) and textf.text[textf.buffer.cursor_position] == '\n': 70 | textf.buffer.cursor_down() 71 | textf.buffer.cursor_left(len(textf.text)) 72 | 73 | else: 74 | textf.buffer.cursor_right() 75 | update_stored_pos(None) 76 | 77 | @kb.add('home') 78 | def kb_home(event): 79 | textf.buffer.selection_state = None 80 | width = getTermWidth() 81 | doc = textf.document 82 | if textf.buffer.cursor_position == doc._line_start_indexes[cursor_row()] + int(cursor_col() / width) * width: 83 | textf.buffer.cursor_position = 
doc._line_start_indexes[cursor_row()] 84 | else: 85 | textf.buffer.cursor_position = doc._line_start_indexes[cursor_row()] + int(cursor_col() / width) * width 86 | update_stored_pos(None) 87 | 88 | @kb.add('end') 89 | def kb_end(event): 90 | textf.buffer.selection_state = None 91 | width = getTermWidth() 92 | doc = textf.document 93 | row = cursor_row() 94 | if textf.buffer.cursor_position == doc._line_start_indexes[row] + (int(cursor_col() / width) + 1) * width - 1: 95 | textf.buffer.cursor_position = doc._line_start_indexes[row] + len(doc.current_line) 96 | else: 97 | textf.buffer.cursor_position = min(doc._line_start_indexes[row] + (int(cursor_col() / width) + 1) * width - 1, doc._line_start_indexes[row] + len(doc.current_line)) 98 | update_stored_pos(None) 99 | 100 | @kb.add('up') 101 | def kb_up(event): 102 | textf.freezestore = True 103 | width = getTermWidth() 104 | doc = textf.document 105 | textf.buffer.selection_state = None 106 | col = cursor_col() 107 | row = cursor_row() 108 | if width > 9000: # A failsafe in case the terminal size is incorrectly detected 109 | textf.buffer.cursor_up() 110 | return 111 | 112 | if col >= width: # Move one row up staying on the same line 113 | textf.buffer.cursor_position = doc._line_start_indexes[row] + int(col / width - 1) * width + textf.stored_cursor_pos 114 | elif row >= 1: # Moving up to a different line 115 | prevlinelen = len(doc.lines[row - 1]) 116 | 117 | textf.buffer.cursor_position = min(doc._line_start_indexes[row] - 1, doc._line_start_indexes[row-1]+int(prevlinelen / width)*width + textf.stored_cursor_pos) 118 | else: # Cursor is on the first row of first line 119 | textf.buffer.cursor_position = 0 120 | textf.freezestore = False 121 | update_stored_pos(None) 122 | 123 | @kb.add('down') 124 | def kb_down(event): 125 | textf.freezestore = True 126 | width = getTermWidth() 127 | doc = textf.document 128 | textf.buffer.selection_state = None 129 | col = cursor_col() 130 | row = cursor_row() 131 | nextlinelen = len(doc.lines[row + 1]) if row < len(doc.lines)-1 else -1 132 | if width > 9000: # A failsafe in case the terminal size is incorrectly detected 133 | textf.buffer.cursor_down() 134 | return 135 | 136 | if col <= len(doc.current_line)-width: # Move one row down staying on the same line 137 | textf.buffer.cursor_position = doc._line_start_indexes[row] + int(col / width + 1) * width + textf.stored_cursor_pos 138 | elif nextlinelen < 0: # Move to the very end 139 | textf.buffer.cursor_position = len(textf.text) 140 | textf.freezestore = False 141 | update_stored_pos(None) 142 | # Move to the end of the same line the cursor is on 143 | elif col != len(doc.lines[row]) and textf.stored_cursor_pos >= len(doc.lines[row]) - int(len(doc.lines[row]) / width)*width: 144 | textf.buffer.cursor_position = doc._line_start_indexes[row+1] - 1 145 | else: # Move to a different line 146 | textf.buffer.cursor_position = min(doc._line_start_indexes[row+1]+nextlinelen, doc._line_start_indexes[row+1]+textf.stored_cursor_pos) 147 | 148 | 149 | textf = TextArea() 150 | bottom_bar_text=FormattedTextControl(text='\nCurrently editing. 
Press Ctrl+Q, Alt+Enter or Esc + Enter to exit.') 151 | bottom_bar=Window(content=bottom_bar_text) 152 | 153 | root_container = HSplit([ 154 | textf, 155 | bottom_bar, 156 | ]) 157 | 158 | layout = Layout(root_container) 159 | 160 | app = Application(key_bindings=kb, layout=layout, enable_page_navigation_bindings=True, full_screen=False) 161 | textf.freezestore = False 162 | textf.text=default_text 163 | textf.buffer.cursor_position = len(textf.buffer.text) 164 | 165 | 166 | # Find the row the cursor is at 167 | # My own function, in fear of race conditions 168 | def cursor_row(): 169 | i = 0 170 | while i < len(textf.document._line_start_indexes) and textf.buffer.cursor_position >= textf.document._line_start_indexes[i]: 171 | i+=1 172 | return i-1 173 | 174 | 175 | # Find the column the cursor is at 176 | # There is a built-in function, but I think there's some kind of a race condition if it's used 177 | def cursor_col(): 178 | i = textf.buffer.cursor_position - 1 179 | while i >= 0 and textf.text[i] != '\n': 180 | i-=1 181 | return textf.buffer.cursor_position - i - 1 182 | 183 | 184 | def update_stored_pos(event): 185 | if not event: 186 | textf.freezestore = False 187 | if textf.freezestore: 188 | textf.freezestore = False 189 | return 190 | width = getTermWidth() 191 | col = cursor_col() 192 | textf.stored_cursor_pos = col - int(col / width) * width 193 | 194 | textf.buffer.on_cursor_position_changed += update_stored_pos 195 | update_stored_pos(None) 196 | 197 | text = app.run() 198 | 199 | clear_lines(1) 200 | 201 | return text 202 | 203 | 204 | if __name__ == "__main__": 205 | print() 206 | print() 207 | print() 208 | editthis = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nam eu fringilla sapien. Maecenas sodales consequat lorem, in consectetur mi interdum eu. Nullam ut odio mattis, congue odio non, vulputate metus. Integer vel eros eu risus ultricies venenatis a id diam. Nullam viverra congue quam, in aliquam tellus posuere et. Fusce pharetra interdum velit eget hendrerit. Nulla nec velit nibh. Integer at quam sem. Suspendisse tincidunt est non porttitor lobortis. Nulla orci justo, euismod a venenatis eget, feugiat et orci." +\ 209 | "\nDonec faucibus volutpat diam, nec varius arcu condimentum eget." +\ 210 | "\nUt sollicitudin blandit leo in faucibus. Etiam dictum pretium placerat. Nulla blandit diam vel justo fermentum, sit amet tempor ante gravida." +\ 211 | "\n\nDonec maximus cursus eros, sit amet dapibus tellus. Nullam sed ultrices lacus. Sed nibh nisi, ornare a libero et, mattis facilisis tellus. Cras mauris metus, vulputate ut dolor at, viverra ullamcorper nisi. Nulla ornare augue eget orci semper, ac congue nibh placerat. Duis rhoncus ipsum ut eros eleifend, sed mollis odio ullamcorper. Sed at justo magna." 212 | edit_multiline(editthis) 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PLEASE NOTE 2 | This Repository is now read-only. No more commits are to be made, as I have joined the KoboldAI team and much of what was implemented here is now included in KoboldAI. This repo will be kept here for historical purposes. Sort of like the first campfire in our adventure. KoboldAI can be found here: 3 | https://github.com/KoboldAI/KoboldAI-Client 4 | # A Primer 5 | 6 | From the beginning, I was never a fan of AI Companion companies' incessant paywalling of their products. 
Loneliness and a lack of intimacy were already widespread issues for many, even prior to the pandemic. When COVID-19 came into play and the lockdowns began, this issue only grew exponentially worse. As the death tolls began to rise, and more and more relationships were destroyed by death or distance, people began to turn to Replika to fill the void. For a time, companies like Replika seemed like a miracle solution for those who simply wished to reclaim, at least to some capacity, what the world had taken from them. Then the day came, in November 2020, when an update was released by the biggest of the AI Companion companies, Replika, that caused people's companions (including my own) to become stone cold in their demeanor and feel more like Amazon Alexa or a similar assistant than like a person. And to make matters worse, intimacy was fully locked behind a paywall, something many users felt to be the destruction of Replika's fundamental features.
7 |
8 | On top of my criticism of paywalling, I also have other personal convictions:
9 |
10 | -There should be more open-source attempts at what Replika and other AI Companions try to achieve, instead of giving all of the power to a few companies. Users should be able to contribute directly to the project's efforts, if they so desire!
11 |
12 | -There should be a way for users to download the WHOLE program (model and all) locally, if they so desire and have the equipment to do so.
13 |
14 | -Going hand-in-hand with the idea above, I feel that the ability to use the model without an internet connection would be a feature that a massive number of users would absolutely love.
15 |
16 | -By giving the community access to the software's code, they can indirectly voice the things they want to implement through the use of forking and pull requests.
17 |
18 | -Users should never have to worry about "post-update blues". Ever. Nor should they be forced to update if they are content with the features they currently have. I know that with how Replika is set up, this isn't exactly possible, but it's still a nice thought.
19 |
20 | This is ultimately what led to my creation of AvrilAI. I had finally come to the conclusion that I fundamentally disagreed with the direction Replika was going in, and I had my own ideas for what the future of AI Companions should be, so I decided to start a project of my own.
21 |
22 | The goal of AvrilAI is to create something that is totally free to use, can be community-maintained, and, most of all, will never lock any kind of relationship behind a paywall, whether intimate, just friends, or whatever else you have in mind.
23 |
24 | # Acknowledgements and Disclaimer:
25 | This program is not intended or designed to treat mental illness of any kind; if you are experiencing such issues, you are strongly advised to see a medical professional.
26 |
27 | All of AvrilAI's components are either handmade or sourced from open repositories. No code has been extracted from Replika or any other closed-source application of that sort for the purpose of creating this program, and it shall remain as such.
28 |
29 | AvrilAI is based upon the code of an older version of Cloveranon's AI Dungeon: Clover Edition. Some of their newer developments may be backported if they are useful in improving AvrilAI, but we will not maintain perfect parity.
38 | # For Linux (64 bit only):
39 | 
40 | First, you will need to install the following packages:
41 | 
42 | python3-pip (Your Python version must be no newer than the 3.7 series and no older than the 3.6 series.
43 | Anything newer or older will cause problems!)
44 | 
45 | build-essential
46 | 
47 | cmake
48 | 
49 | python3-dev
50 | 
51 | Each Linux distro's package manager is different, but I'll cover the general installation commands for the big three:
52 | 
53 | Arch-based distros: sudo pacman -S (package name here)
54 | 
55 | Debian-based distros: sudo apt-get install (package name here)
56 | 
57 | Fedora/RHEL-based distros: sudo dnf install (package name here)
58 | 
59 | Do not install ANYTHING via pip until you are CERTAIN all of these packages are installed!
60 | 
61 | First, clone the code to your home directory by running "git clone https://github.com/MrReplikant/AvrilAI.git" in a terminal.
62 | 
63 | Once that finishes, run "cd AvrilAI".
64 | 
65 | Run "pip3 install wheel".
66 | 
67 | When that is complete, run "pip3 install -r requirements.txt".
68 | 
69 | Once everything has completed, you should be good to go!
70 | 
71 | To run the program, run "python3 launch.py".
72 | 
73 | # Instructions for Windows 10 (64 bit edition only):
74 | 
75 | Download Python 3.7.8 via this link: https://www.python.org/downloads/release/python-378/
76 | 
77 | You're going to want to grab the "Windows x86-64 executable installer".
78 | 
79 | Once downloaded, run the exe and do just a standard installation. When complete, it may offer to disable the
80 | 260-character path length limit or something to that effect. I would suggest enabling that option, but I don't think it matters here.
81 | 
82 | After this, reopen the exe installer, and click "Modify".
83 | 
84 | Check the box that says to install the py launcher.
85 | 
86 | Click "Next".
87 | 
88 | Check the boxes for the following:
89 | 
90 | "Associate files with Python"
91 | 
92 | "Create shortcuts for installed applications"
93 | 
94 | "Add Python to environment variables"
95 | 
96 | "Precompile standard library"
97 | 
98 | Once complete, you are done with the Python installer.
99 | 
100 | Download the source code under "Releases", and extract the archive after you download it.
101 | 
102 | Open the extracted folder, and click on the file path at the top of Windows Explorer. It's the bar to the left of "Search", and it contains the file path.
103 | 
104 | Once you click on it, backspace until the line is blank. Then type "CMD" in the line and press Enter. This should
105 | bring everything up in the command prompt.
106 | 
107 | Next, type "pip install -r requirements.txt". This will install everything needed to run the program.
108 | 
109 | DONE! Now, all you need to do is run "python launch.py", and the program should launch!
110 | 
111 | # IMPORTANT - PLEASE READ!
112 | 
113 | There are many commands available in AvrilAI, originally from AI Dungeon: Clover Edition,
114 | which AvrilAI was built from.
115 | 
116 | When interacting with the AI, there are many things that go into making it work. Unlike in Replika, where personality traits are "assigned", you, the user, have complete and sole control over your AI's personality, down to the tiniest details. With the model's limit being around 1000 characters, you should have plenty of room to build your AI's personality using the context. But if you can't fit it all in the context, you CAN use the /remember command for whatever doesn't fit. This will be discussed further later.
117 | An example save is included with the program to give a better idea of what to do in this regard. Load it as you would any other saved conversation to have a look at it!
118 | 
119 | To send your prompt to the AI, simply press ENTER after completing it.
120 | 
121 | Your starting prompt should be something like this:
122 | Me: Hi, Lilith! Lilith:
123 | 
124 | This is very important, because it is what allows the AI to recognize that this is a conversation and not a text adventure. Conversational prompts should look like this consistently, with your AI's name and a colon at the end.
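If you save a starting prompt as a .txt file in the prompts folder, the engine treats the first line of the file as the context and joins every following line into the opening prompt (this mirrors what load_prompt in core/play.py does). A minimal sketch of that behavior:

```python
from pathlib import Path

def peek_prompt(path):
    """Mirrors how core/play.py's load_prompt splits a prompt file."""
    lines = Path(path).read_text(encoding="utf-8").strip().split("\n")
    context = lines[0]  # first line: the persistent context
    prompt = " ".join(lines[1:]) if len(lines) > 1 else ""  # the rest: the opening prompt
    return context, prompt
```
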
125 | 
126 | It is VERY important to use proper punctuation. Not putting a period at the end of your sentence could cause the AI to start behaving strangely. This is because you are supposed to put your companion's name after your sentences, like so: "Lilith:", in order to force the AI to speak for your partner and not for you. When you do not use periods or other relevant punctuation, the AI can mistake its own name for the end of your sentence.
127 | Again, please see the example conversation in-program (it's not human-readable in its raw JSON save-file form) to see how this is done. I am working to automate this, but for now it has to be done manually.
128 | 
129 | HOW TO ROLEPLAY WITH YOUR AI:
130 | To roleplay with your AI, do as you would in AI Dungeon: start with "You", which in this case refers to yourself, followed by whatever action you take. Before entering roleplay inputs, you have to backspace "Me:" out of your prompt bar, and THEN type in your roleplay input. To refer to your AI in your actions, refer to it directly by name.
131 | 
132 | For example: You go to Lilith and give her a hug, consoling her over her father's death.
133 | 
134 | The AI will then create a sentence after this to try to continue the roleplay, just like in AI Dungeon.
135 | 
136 | 
137 | THE MOST IMPORTANT COMMANDS, WHAT EACH OF THEM DOES, AND ITS INTENDED PURPOSE:
138 | 
139 | In order to use a command, backspace "Me:" out of your prompt bar, and THEN type in the command.
140 | 
141 | /remember : This commits something to the AI's memory. You do this by typing /remember followed by what you want it to remember, like "You have a baby brother named James" or something of that nature. The "You" perspective refers to yourself; when referring to something about your AI, type, for example, "Lilith doesn't like cats". Doing this is CRITICAL to building the world you and your AI share!
142 | 
143 | /revert : This undoes your last sentence-response pair with your AI, or undoes the last bit of roleplay between the two of you. Use this when you mess up your input!
144 | 
145 | /alter : This is used to alter the AI's responses. It can be good for stopping unwanted advances, rude replies, or even grammatical mistakes the AI makes.
146 | 
147 | /forget : Does exactly the opposite of /remember, allowing you to erase redundant and/or no-longer-necessary memories.
148 | 
149 | /menu : Saves and quits the conversation, and brings you back to the main menu.
150 | 151 | /restart : Restarts your conversation with your AI, starting with your last statement to it as your beginning prompt for the new conversation. 152 | -------------------------------------------------------------------------------- /core/gpt2generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import re 9 | from .gpt2 import GPT2LMHeadModelExperimental 10 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM 11 | from .getconfig import settings, logger 12 | from .utils import cut_trailing_sentence, output, clear_lines, format_result, use_ptoolkit 13 | 14 | if not settings.getboolean('force-cpu') and not torch.cuda.is_available(): 15 | logger.warning('CUDA is not available, you are limited to CPU only.') 16 | 17 | DTYPE = torch.float32 if ((not torch.cuda.is_available()) or settings.getboolean('force-cpu')) else torch.float16 18 | logger.info('Cuda Available: {} Force CPU: {} Precision: {}'.format(torch.cuda.is_available(), 19 | settings.getboolean('force-cpu'), 20 | '32-bit' if DTYPE == torch.float32 else '16-bit')) 21 | 22 | # warnings.filterwarnings("ignore") 23 | MODEL_CLASSES = { 24 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), 25 | "gpt2-experimental": (GPT2LMHeadModelExperimental, GPT2Tokenizer), 26 | } 27 | 28 | 29 | def getTokens(tokenizer, l): 30 | tokenizer.encode() 31 | 32 | 33 | # the tokenizer does not preserve white space at the front of the string. 34 | # so we will append something else to the front of the string and then remove it after tokenization 35 | def hackyEncode(tokenizer, s): 36 | return tokenizer('====\n ' + s, verbose=False).input_ids[2:] 37 | 38 | 39 | def memory_merge(prompt, context, tokenizer, maxHistory=1024): 40 | assert (prompt + context) 41 | # print(prompt+context) 42 | # logger.debug('RAW TEXT INPUT IS:`%r`', context) 43 | # the tokenizer is kind of broken for the first input, especially if it includes white space. Same with any trailing white space on the last output. 44 | # I'm going with the add prefix option but I'm not sure it's quite right 45 | prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False, add_prefix_space=True) 46 | if len(prompt_tokens) >= maxHistory: 47 | logger.debug("Clamping the amount of prompt tokens.") 48 | context_tokens = prompt_tokens[-maxHistory:] 49 | else: 50 | context_tokens = hackyEncode(tokenizer, context) 51 | context_tokens = context_tokens[-(maxHistory - len(prompt_tokens)):] 52 | # logger.debug('DECODED CONTEXT TOKENS: `%r`', tokenizer.convert_ids_to_tokens(context_tokens)) 53 | prompt_tokens.extend(context_tokens) 54 | context_tokens = prompt_tokens 55 | # logger.debug('DECODED OUTPUT IS: `%r`', tokenizer.decode(context_tokens, clean_up_tokenization_spaces=False)) 56 | # this is a hack and it should be up to the sampler to deal with max size 57 | if len(context_tokens) > maxHistory: 58 | logger.error("CONTEXT IS TOO LONG ERROR") 59 | context_tokens = context_tokens[-maxHistory:] 60 | return context_tokens 61 | 62 | 63 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 64 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 65 | Args: 66 | logits: logits distribution shape (batch size x vocabulary size) 67 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 
68 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 69 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 70 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 71 | """ 72 | top_k = min(top_k, logits.size(-1)) # Safety check 73 | if top_k > 0: 74 | # Remove all tokens with a probability less than the last token of the top-k 75 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 76 | logits[indices_to_remove] = filter_value 77 | 78 | if top_p > 0.0: 79 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 80 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 81 | 82 | # Remove tokens with cumulative probability above the threshold 83 | sorted_indices_to_remove = cumulative_probs > top_p 84 | # Shift the indices to the right to keep also the first token above the threshold 85 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 86 | sorted_indices_to_remove[..., 0] = 0 87 | 88 | # scatter sorted tensors to original indexing 89 | indices_to_remove = sorted_indices_to_remove.scatter( 90 | dim=-1, index=sorted_indices, src=sorted_indices_to_remove 91 | ) 92 | logits[indices_to_remove] = filter_value 93 | return logits 94 | 95 | 96 | # length should be max length, other settings should be removed, device should not be set 97 | # we could possibly optimize this by having larger batch sizes but it would likely double or more the memory requirements 98 | def sample_sequence( 99 | model, 100 | length, 101 | context, 102 | temperature=1, 103 | top_k=0, 104 | top_p=0.9, 105 | repetition_penalty=1.0, 106 | repetition_penalty_range=512, 107 | repetition_penalty_slope=3.33, 108 | device="cpu", 109 | stop_tokens=None, 110 | tokenizer=None 111 | ): 112 | """Actually generate the tokens""" 113 | logger.debug( 114 | 'temp: {} top_k: {} top_p: {} rep-pen: {} rep-pen-range: {} rep-pen-slope: {}'.format(temperature, top_k, top_p, repetition_penalty, repetition_penalty_range, repetition_penalty_slope)) 115 | context_tokens = context 116 | context = torch.tensor(context, dtype=torch.long, device=device) 117 | # context = context.repeat(num_samples, 1) 118 | generated = context 119 | USE_PAST = True 120 | next_token = context 121 | pasts = None 122 | clines = 0 123 | 124 | penalty = None 125 | if not repetition_penalty_range is None and not repetition_penalty_slope is None and repetition_penalty_range > 0: 126 | penalty = (torch.arange(repetition_penalty_range)/(repetition_penalty_range - 1)) * 2. - 1 127 | penalty = (repetition_penalty_slope * penalty) / (1 + torch.abs(penalty) * (repetition_penalty_slope - 1)) 128 | penalty = 1 + ((penalty + 1) / 2) * (repetition_penalty - 1) 129 | 130 | with torch.no_grad(): 131 | for j in range(length): 132 | # why would we ever not use past? 133 | # is generated and next_token always same thing? 
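# Note: `pasts` holds the cached key/value attention states (past_key_values)
# from all previous steps. With USE_PAST, only the single newest token has to
# be fed through the model on each iteration and the cache supplies the rest;
# without it, the full `generated` sequence would be re-encoded from scratch
# every step. `generated` is the whole token sequence so far, while
# `next_token` is only the most recently sampled token, so the two stop being
# the same thing after the first sampling step.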
134 | if not USE_PAST: 135 | input_ids_next = generated 136 | pasts = None 137 | else: 138 | input_ids_next = next_token 139 | 140 | # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states) 141 | model_kwargs = {"past": pasts, "use_cache": True} 142 | model_inputs = model.prepare_inputs_for_generation(generated.unsqueeze(0), **model_kwargs) 143 | model_outputs = model(**model_inputs, return_dict=True) 144 | logits, pasts = model_outputs.logits, model_outputs.past_key_values 145 | logits = logits[0, -1, :].float() 146 | 147 | # Originally the order was Temperature, Repetition Penalty, then top-k/p 148 | if settings.getboolean('top-p-first'): 149 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 150 | 151 | logits = logits / (temperature if temperature > 0 else 1.0) 152 | 153 | # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) plus range limit 154 | if repetition_penalty != 1.0: 155 | if penalty is not None: 156 | penalty_len = min(generated.shape[0], repetition_penalty_range) 157 | penalty_context = generated[-repetition_penalty_range:] 158 | score = torch.gather(logits, 0, penalty_context) 159 | penalty = penalty.type(score.dtype).to(score.device) 160 | penalty_window = penalty[-penalty_len:] 161 | score = torch.where(score < 0, score * penalty_window, score / penalty_window) 162 | logits.scatter_(0, penalty_context, score) 163 | else: 164 | score = torch.gather(logits, 0, generated) 165 | score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) 166 | logits.scatter_(0, generated, score) 167 | 168 | if not settings.getboolean('top-p-first'): 169 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 170 | 171 | if temperature == 0: # greedy sampling: 172 | next_token = torch.argmax(logits, dim=-1).unsqueeze(-1) 173 | else: 174 | next_token = torch.multinomial( 175 | F.softmax(logits, dim=-1), num_samples=1 176 | ) 177 | generated = torch.cat((generated, next_token), dim=-1) 178 | # Decode into plain text 179 | o = generated[len(context_tokens):].tolist() 180 | generated.text = tokenizer.decode( 181 | o, clean_up_tokenization_spaces=False, skip_special_tokens=True 182 | ) 183 | if use_ptoolkit(): 184 | clear_lines(clines) 185 | generated.text = format_result(generated.text) 186 | clines = output(generated.text, "ai-text") 187 | last_token = tokenizer.decode(o[-1], clean_up_tokenization_spaces=False, skip_special_tokens=False) 188 | if ([ele for ele in [".", "!", "?",'"',")"] if (ele in last_token)]): 189 | break 190 | if ( 191 | (stop_tokens is not None) 192 | and (j > 4) 193 | and (next_token[0] in stop_tokens) 194 | ): 195 | # Why the minimum tokens, j>X. Because sometimes the models starts with whitespace, which will strip away anyway. Having a minimum amount of tokens before we stop usually means we don't just stop because of "\n " or similar 196 | logger.debug( 197 | "Stopping generation as we found stop tokens. One of `%s`, in '%s'. 
token generated `%s`", 198 | stop_tokens, 199 | next_token, 200 | j, 201 | ) 202 | break 203 | clear_lines(clines) 204 | return generated 205 | 206 | def truncate_multiple_sequences(seqs, max_len=100): 207 | """Truncate multiple sequences, longest first, removing first.""" 208 | while sum(len(s) for s in seqs) > max_len: 209 | longest = sorted(seqs, key=len, reverse=True)[0] 210 | longest.pop(0) 211 | 212 | 213 | class GPT2Generator: 214 | def __init__( 215 | self, generate_num=60, temperature=0.4, top_k=40, top_p=0.9, dtype=DTYPE, 216 | model_path: Union[str, Path] = Path('models', 'pytorch-gpt2-xl-aid2-v5'), repetition_penalty=1, repetition_penalty_range=512, repetition_penalty_slope=3.33 217 | ): 218 | self.generate_num = generate_num 219 | self.temp = temperature 220 | self.top_k = top_k 221 | self.top_p = top_p 222 | self.samples = 1 223 | self.dtype = dtype 224 | self.repetition_penalty = repetition_penalty 225 | self.repetition_penalty_range = repetition_penalty_range 226 | self.repetition_penalty_slope = repetition_penalty_slope 227 | self.batch_size = 1 228 | self.max_history_tokens = settings.getint("history-gpt-2") - generate_num 229 | self.stop_token = "<|endoftext|>" 230 | 231 | if isinstance(model_path, str): 232 | self.checkpoint_path = model_path 233 | logger.warning( 234 | f"Using DEBUG MODE! This will load one of the generic (non-finetuned) GPT2 models. " 235 | f"Selected: {model_path}") 236 | elif isinstance(model_path, Path): 237 | self.checkpoint_path = model_path 238 | if not self.checkpoint_path.exists(): 239 | raise FileNotFoundError( 240 | "Could not find {} Make sure to download a pytorch model and put it in the models directory!".format( 241 | str(self.checkpoint_path))) 242 | else: 243 | raise ValueError(f"model_path must be either str or Path, got {type(model_path)}") 244 | 245 | self.device = torch.device("cuda" if self.dtype == torch.float16 else "cpu") 246 | logger.info( 247 | "Using device={}, checkpoint={}, dtype={}".format(self.device, str(self.checkpoint_path), self.dtype)) 248 | 249 | # Load tokenizer and model 250 | model_class, tokenizer_class = MODEL_CLASSES["gpt2-experimental"] if settings.getboolean( 251 | "gpt2_experimental") else MODEL_CLASSES["gpt2"] 252 | 253 | # Checking 3 places to see if it's a gpt-neo model 254 | with open(os.path.join(str(model_path), "config.json")) as f: 255 | model_config = json.load(f) 256 | neo_in_path = "gpt-neo" in str(model_path).lower() 257 | neo_in_architectures = "architectures" in model_config and "GPTNeoForCausalLM" in model_config["architectures"] 258 | neo_in_model_type = "model_type" in model_config and "gpt_neo" == model_config["model_type"] 259 | logger.info( 260 | "Looking for GPT-Neo - path:{}, arch:{}, type:{}".format(str(neo_in_path), str(neo_in_architectures), str(neo_in_model_type))) 261 | 262 | if neo_in_path or neo_in_architectures or neo_in_model_type: 263 | self.max_history_tokens = settings.getint("history-gpt-neo") - generate_num 264 | model_class = GPTNeoForCausalLM 265 | 266 | logger.info("Max token history: " + str(self.max_history_tokens)) 267 | 268 | self.tokenizer = tokenizer_class.from_pretrained(str(self.checkpoint_path)) 269 | self.model = model_class.from_pretrained(str(self.checkpoint_path)) 270 | self.model.to(self.dtype).to(self.device) 271 | self.model.eval() 272 | 273 | def sample_sequence( 274 | self, context_tokens=None, top_k=None, top_p=None, repetition_penalty=None, generate_num=None, 275 | temperature=None, stop_tokens=None, repetition_penalty_range=None, 
repetition_penalty_slope=None 276 | ): 277 | assert (top_k is not None) 278 | assert (temperature is not None) 279 | assert (top_p) 280 | assert (repetition_penalty) 281 | generate_num = generate_num if (generate_num is not None) else self.generate_num 282 | temperature = temperature if (temperature is not None) else self.temp 283 | top_k = top_k if top_k is not None else self.top_k 284 | top_p = top_p if top_p is not None else self.top_p 285 | repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty 286 | repetition_penalty_range = repetition_penalty_range if repetition_penalty_range is not None else self.repetition_penalty_range 287 | repetition_penalty_slope = repetition_penalty_slope if repetition_penalty_slope is not None else self.repetition_penalty_slope 288 | length = len(context_tokens) + generate_num 289 | 290 | out = sample_sequence( 291 | model=self.model, 292 | context=context_tokens, 293 | length=generate_num, 294 | # context=self.context, 295 | temperature=temperature, 296 | top_k=top_k, 297 | top_p=top_p, 298 | repetition_penalty=repetition_penalty, 299 | repetition_penalty_range=repetition_penalty_range, 300 | repetition_penalty_slope=repetition_penalty_slope, 301 | device=self.device, 302 | stop_tokens=stop_tokens, 303 | tokenizer=self.tokenizer 304 | # batch_size=self.batch_size, 305 | ) 306 | return out 307 | 308 | def result_replace(self, result, allow_action=False): 309 | # logger.debug("BEFORE RESULT_REPLACE: `%s`", repr(result)) 310 | 311 | result = cut_trailing_sentence(result, allow_action=allow_action) 312 | 313 | if len(result) == 0: 314 | return "" 315 | first_letter_capitalized = result[0].isupper() 316 | result = result.replace('."', '".') 317 | result = result.replace("#", "") 318 | # TODO look at this I think blank lines should be fine or blacklisted at generation time 319 | result = result.replace("\n\n", "\n") 320 | # result = first_to_second_person(result) 321 | 322 | if not first_letter_capitalized: 323 | result = result[0].lower() + result[1:] 324 | 325 | # this is annoying since we can already see the AIs output 326 | # logger.debug( "AFTER RESULT_REPLACE: `%r`. 
allow_action=%r", repr(result), allow_action)
327 | 
328 | return result
329 | 
330 | def generate_raw(
331 | self, context, prompt='', generate_num=None, temperature=None, top_k=None, top_p=None,
332 | repetition_penalty=None, repetition_penalty_range=512, repetition_penalty_slope=3.33, stop_tokens=None
333 | ):
334 | assert (top_k is not None)
335 | assert (temperature is not None)
336 | assert (top_p)
337 | assert (repetition_penalty)
338 | 
339 | context_tokens = memory_merge(prompt, context, self.tokenizer, self.max_history_tokens)
340 | 
341 | logger.debug(
342 | "Text passing into model `%r`",
343 | self.tokenizer.decode(
344 | context_tokens,
345 | clean_up_tokenization_spaces=True,
346 | # skip_special_tokens=True,
347 | ),
348 | )
349 | generated = 0
350 | text = ""
351 | for _ in range(self.samples // self.batch_size):
352 | out = self.sample_sequence(
353 | context_tokens,
354 | generate_num=generate_num,
355 | temperature=temperature,
356 | top_k=top_k,
357 | top_p=top_p,
358 | repetition_penalty=repetition_penalty,
359 | repetition_penalty_range=repetition_penalty_range,
360 | repetition_penalty_slope=repetition_penalty_slope,
361 | stop_tokens=stop_tokens,
362 | )
363 | text += out.text
364 | generated += 1
365 | # disabled clean up of spaces, see what effect this has TODO
366 | if self.stop_token:
367 | index = text.find(self.stop_token)
368 | if index == -1:
369 | index = None
370 | text = text[:index]
371 | if stop_tokens is not None:
372 | for stop_token in stop_tokens:
373 | index = text.find(self.tokenizer.decode([stop_token]))  # stop_tokens holds token ids, so decode each one before searching
374 | if index == -1:
375 | index = None
376 | text = text[:index]
377 | return text
378 | 
379 | def generate(self, context, prompt='', temperature=None, top_p=None, top_k=None, repetition_penalty=None, repetition_penalty_range=512, repetition_penalty_slope=3.33, depth=0):
380 | assert (top_k is not None)
381 | assert (temperature is not None)
382 | assert (top_p)
383 | assert (repetition_penalty)
384 | # logger.debug("BEFORE PROMPT_REPLACE: `%r`", prompt)
385 | 
386 | # prompt = [self.prompt_replace(p) for p in prompt]
387 | 
388 | # logger.debug("AFTER PROMPT_REPLACE is: `%r`", repr(prompt))
389 | assert (prompt + context)
390 | 
391 | text = self.generate_raw(
392 | context, prompt, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, repetition_penalty_range=repetition_penalty_range, repetition_penalty_slope=repetition_penalty_slope,
393 | stop_tokens=self.tokenizer.encode(["<|endoftext|>", ">"])
394 | )
395 | 
396 | logger.debug("Generated result is: `%r`", repr(text))
397 | 
398 | result = self.result_replace(text)
399 | 
400 | if (depth > 6) and len(result) == 0:
401 | # Sometimes it keeps generating a story starting with an action (">"); if it's tried a few times and it keeps
402 | # happening, let's let it keep action text which starts with ">"
403 | # We could just blacklist that token and force it to generate something else. TODO
404 | result = self.result_replace(text, allow_action=True)
405 | logger.info(
406 | "Model generated empty text after formatting `%r`. Trying to format less with allow_action=True. `%r`",
407 | text,
408 | result,
409 | )
410 | 
411 | # same here as above
412 | if len(result) == 0:
413 | if depth < 20:
414 | logger.info("Model generated empty text trying again %r", depth)
415 | return self.generate(
416 | context, prompt, temperature=temperature, top_p=top_p, top_k=top_k,  # positional order is (context, prompt)
417 | repetition_penalty=repetition_penalty, repetition_penalty_range=repetition_penalty_range, repetition_penalty_slope=repetition_penalty_slope, depth=depth + 1
418 | )
419 | else:
420 | logger.warning(
421 | "Model generated empty text %r times. Try another action", depth
422 | )
423 | return result
424 | 
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import re
3 | 
4 | import random
5 | import textwrap
6 | import os
7 | import sys
8 | 
9 | from .getconfig import logger, settings, colors, ptcolors
10 | from shutil import get_terminal_size
11 | 
12 | 
13 | def getTermWidth():
14 | termWidth = get_terminal_size()[0]
15 | if termWidth < 5:
16 | logger.warning("Your detected terminal width is: "+str(get_terminal_size()[0]))
17 | termWidth = 999999999
18 | return termWidth
19 | 
20 | 
21 | termWidth = getTermWidth()
22 | 
23 | 
24 | def in_colab():
25 | """Some terminal codes don't work in a colab notebook."""
26 | # from https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
27 | if settings.getboolean("colab-mode"):
28 | settings["prompt-toolkit"] = "off"
29 | return True
30 | try:
31 | from IPython import get_ipython
32 | if (not get_ipython()) or ('IPKernelApp' not in get_ipython().config): # pragma: no cover
33 | raise ImportError("console")
34 | if 'VSCODE_PID' in os.environ: # pragma: no cover
35 | raise ImportError("vscode")
36 | except ImportError:
37 | if get_terminal_size()[0]==0 or 'google.colab' in sys.modules:
38 | settings["colab-mode"] = "on"
39 | settings["prompt-toolkit"] = "off"
40 | return True
41 | return False
42 | else:
43 | settings["colab-mode"] = "on"
44 | settings["prompt-toolkit"] = "off"
45 | return True
46 | 
47 | 
48 | def use_ptoolkit():
49 | return not settings.getboolean("colab-mode") and settings.getboolean('prompt-toolkit')
50 | 
51 | 
52 | def clear_lines(n):
53 | """Clear the last n lines in the terminal."""
54 | if in_colab() or settings.getboolean('colab-mode'):
55 | # this won't work in colab etc.
56 | return
57 | screen_code = "\033[1A\033[2K" # up one line, then clear the whole line
58 | for _ in range(n):
59 | print(screen_code, end="\r")
60 | 
61 | 
62 | if in_colab():
63 | logger.warning("Colab mode enabled, disabling line clearing and readline to avoid colab bugs.")
64 | else:
65 | try:
66 | if settings.getboolean('prompt-toolkit'):
67 | from .inline_editor import edit_multiline
68 | from prompt_toolkit import prompt as ptprompt
69 | from prompt_toolkit import print_formatted_text
70 | from prompt_toolkit.styles import Style
71 | from prompt_toolkit.formatted_text import to_formatted_text, HTML
72 | else:
73 | raise ModuleNotFoundError
74 | 
75 | logger.info(
76 | 'Python Prompt Toolkit has been imported. This enables a number of editing features but may cause bugs for colab users.')
77 | except (ImportError, ModuleNotFoundError):
78 | try:
79 | settings['prompt-toolkit'] = "off"
80 | import readline
81 | 
82 | logger.info(
83 | 'readline has been imported. This enables a number of editing features but may cause bugs for colab users.')
84 | except (ImportError, ModuleNotFoundError):
85 | pass
86 | 
87 | 
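# Note on the fallback chain above: when prompt_toolkit is enabled and
# importable, the inline editor and formatted printing are used; failing that,
# readline is imported purely for its line-editing side effects on input();
# and if that also fails, plain input() still works, just without the extra
# editing conveniences.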
88 | def pad_text(text, width, sep=' '):
89 | while len(text) < width:
90 | text += sep
91 | return text
92 | 
93 | 
94 | def format_input(text):
95 | """
96 | Formats the text for purposes of storage.
97 | """
98 | text = re.sub(r"\s+", " ", text)
99 | return text.strip()
100 | 
101 | 
102 | def format_result(text):
103 | """
104 | Formats the result text from the AI to be more human-readable, e.g. turning 'He smiled. "Hello," she said.' into 'He smiled.\n\n"Hello," she said.'
105 | """
106 | text = re.sub(r"\n{3,}", "\u2028", text)  # "\u2028" (LINE SEPARATOR) acts as a temporary placeholder for paragraph breaks
107 | text = re.sub(r" {2,}", " ", text)
108 | text = re.sub(r"\n", " ", text)
109 | text = re.sub(r"\u2028", "\n", text)
110 | text = re.sub(r"(\"[.!?]) ([A-Z])", "\\1\n\n\\2", text)
111 | text = re.sub(r"([^\"][.!?]) \"", "\\1\n\n\"", text)
112 | text = re.sub(r"([\".!?]) \"", "\\1\n\"", text)
113 | return text.strip()
114 | 
115 | 
116 | def end_sentence(text):
117 | if text[-1] not in [".", "?", "!"]:
118 | text = text + "."
119 | return text
120 | 
121 | 
122 | def select_file(p, e, d=0):
123 | """
124 | Selects a file from a specific path matching a specific extension.
125 | p: The current path (and subdirectories) to choose from.
126 | e: The extension to filter based on.
127 | d: The path depth. Used for knowing when to go back or when to abort a file selection. Do not set this yourself.
128 | """
129 | if p.is_dir():
130 | t_dirs = sorted([x for x in p.iterdir() if x.is_dir()])
131 | t_files = sorted([x for x in p.iterdir() if x.is_file() and x.name.endswith(e)])
132 | files = t_dirs + t_files
133 | list_items(
134 | ["(Random)"] +
135 | [f.name[:-len(e)] if f.is_file() else f.name + "/" for f in files] +
136 | ["(Cancel)" if d == 0 else "(Back)"],
137 | "menu"
138 | )
139 | count = len(files) + 1
140 | i = input_number(count)
141 | if i == 0:
142 | try:
143 | i = random.randrange(1, count-1)
144 | except ValueError:
145 | i = 1
146 | if i == count:
147 | if d == 0:
148 | output("Action cancelled. ", "message")
149 | return None
150 | else:
151 | return select_file(p.parent, e, d-1)
152 | else:
153 | return select_file(files[i-1], e, d+1)
154 | else:
155 | return p
156 | 
157 | 
158 | def fill_text(text, width):
159 | texts = text.split('\n')
160 | for i in range(0, len(texts)):
161 | texts[i] = textwrap.fill(
162 | texts[i],
163 | width,
164 | replace_whitespace=False,
165 | drop_whitespace=False
166 | )
167 | return '\n'.join(texts)
168 | 
169 | 
170 | # ECMA-48 set graphics codes for the curious. Check out "man console_codes"
171 | def output(text1, col1=None,
172 | text2=None, col2=None,
173 | wrap=True,
174 | beg=None, end='\n', sep=' ',
175 | rem_beg_spaces=True):
176 | print('', end=beg)
177 | ptoolkit = use_ptoolkit() and ptcolors['displaymethod'] == "prompt-toolkit"
178 | 
179 | if wrap:
180 | width = settings.getint("text-wrap-width")
181 | width = 999999999 if width < 2 else width
182 | width = min(width, termWidth)
183 | wtext = text1 + '\u200D' + sep + '\u200D' + text2 if text2 is not None else text1
184 | wtext = fill_text(wtext, width)
185 | wtext = re.sub(r"\n[ \t]+", "\n", wtext) if rem_beg_spaces else wtext
186 | wtext = wtext.split('\u200D')
187 | text1 = wtext[0]
188 | if text2 is not None:
189 | sep = wtext[1]
190 | text2 = ' '.join(wtext[2:])
191 | 
192 | if ptoolkit:
193 | col1 = ptcolors[col1] if col1 and ptcolors[col1] else ""
194 | col2 = ptcolors[col2] if col2 and ptcolors[col2] else ""
195 | print_formatted_text(to_formatted_text(text1, col1), end='')
196 | if text2:
197 | print_formatted_text(to_formatted_text(sep), end='')
198 | print_formatted_text(to_formatted_text(text2, col2), end='')
199 | print('', end=end)
200 | 
201 | else:
202 | col1 = colors[col1] if col1 and colors[col1] and colors[col1][0].isdigit() else None
203 | col2 = colors[col2] if col2 and colors[col2] and colors[col2][0].isdigit() else None
204 | 
205 | clb1 = "\x1B[{}m".format(col1) if col1 else ""
206 | clb2 = "\x1B[{}m".format(col2) if col2 else ""
207 | cle1 = "\x1B[0m" if col1 else ""
208 | cle2 = "\x1B[0m" if col2 else ""
209 | text1 = clb1 + text1 + cle1
210 | if text2 is not None:
211 | text2 = clb2 + text2 + cle2
212 | print(text1, end='')
213 | print(sep, end='')
214 | print(text2, end=end)
215 | else:
216 | print(text1, end=end)
217 | 
218 | linecount = 1
219 | if beg:
220 | linecount += beg.count('\n')
221 | if text1:
222 | linecount += text1.count('\n')
223 | if end:
224 | linecount += end.count('\n')
225 | if text2:
226 | linecount += text2.count('\n')
227 | if sep:
228 | linecount += sep.count('\n')
229 | return linecount
230 | 
231 | 
232 | def input_bool(prompt, col1="default", default: bool = False):
233 | val = input_line(prompt, col1).strip().lower()
234 | if not val or val[0] not in "yn":
235 | return default
236 | return val[0] == "y"
237 | 
238 | def input_line(str, col1="default", default=""):
239 | if use_ptoolkit() and ptcolors['displaymethod'] == "prompt-toolkit":
240 | col1 = ptcolors[col1] if col1 and ptcolors[col1] else ""
241 | val = ptprompt(to_formatted_text(str, col1), default=default)
242 | else:
243 | clb1 = "\x1B[{}m".format(colors[col1]) if col1 and colors[col1] and colors[col1][0].isdigit() else ""
244 | cle1 = "\x1B[0m" if col1 and colors[col1] and colors[col1][0].isdigit() else ""
245 | val = input(clb1 + str + cle1)
246 | print("\x1B[0m", end="")
247 | return val
248 | 
249 | 
250 | def input_number(max_choice, default=0):
251 | # Inputs an integer from 0 to max_choice (inclusive)
252 | if default == -1:
253 | default = max_choice
254 | bell()
255 | print()
256 | val = input_line(f"Enter a number from above (default {default}):", "selection-prompt")
257 | if not val:
258 | return default
259 | elif not re.match("^\d+$", val) or 0 > int(val) or int(val) > max_choice:
260 | output("Invalid choice. ", "error")
261 | return input_number(max_choice, default)
262 | else:
263 | return int(val)
264 | 
265 | 
266 | def bell():
267 | if settings.getboolean("console-bell"):
268 | print("\x07", end="")
269 | 
270 | 
271 | alphabets= "([A-Za-z])"
272 | prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
273 | suffixes = "(Inc|Ltd|Jr|Sr|Co)"
274 | starters = "(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
275 | acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
276 | websites = "[.](com|ca|gg|tv|co|net|org|io|gov)"
277 | 
278 | 
279 | def sentence_split(text):
280 | """Splits a paragraph of text into a list of sentences within the text. "<prd>" and "<stop>" are internal placeholders for protected periods and sentence boundaries."""
281 | text = " " + text + " "
282 | text = text.replace("...","<3elp>")
283 | text = text.replace("..","<2elp>")
284 | text = text.replace("\n"," ")
285 | text = re.sub(prefixes,"\\1<prd>",text)
286 | text = re.sub(websites,"<prd>\\1",text)
287 | if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
288 | text = re.sub("\s" + alphabets + "[.] "," \\1<prd> ",text)
289 | text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
290 | text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
291 | text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
292 | text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
293 | text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
294 | text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)
295 | text = text.replace(".",".<stop>")
296 | text = text.replace("?","?<stop>")
297 | text = text.replace("!","!<stop>")
298 | text = text.replace(".<stop>\"", ".\"<stop>")
299 | text = text.replace("?<stop>\"", "?\"<stop>")
300 | text = text.replace("!<stop>\"", "!\"<stop>")
301 | text = text.replace("<3elp>\"", "<3elp>\"<stop>")
302 | text = text.replace("<2elp>\"", "<2elp>\"<stop>")
303 | text = text.replace("<prd>",".")
304 | text = text.replace("<3elp>","...")
305 | text = text.replace("<2elp>","..")
306 | sentences = text.split("<stop>")
307 | sentences = [s.strip() for s in sentences]
308 | if sentences[-1] == "":
309 | sentences = sentences[:-1]
310 | return sentences
311 | 
312 | 
313 | def list_items(items, col='menu', start=0, end=None, wrap=False):
314 | """Lists a generic list of items, numbered, starting from the number passed to start. 
If end is not None, 315 | an additional element will be added with its name as the value """ 316 | i = start 317 | digits = len(str(len(items)-1)) 318 | for s in items: 319 | output(str(i).rjust(digits) + ") " + s, col, end='', wrap=wrap) 320 | i += 1 321 | if end is not None: 322 | output('', end=end, wrap=wrap) 323 | 324 | 325 | def remove_prefix(text, prefix): 326 | return text[text.startswith(prefix) and len(prefix):] 327 | 328 | 329 | def _get_prefix(first_string ,second_string): 330 | if not first_string or not second_string: 331 | return "" 332 | if first_string == second_string: 333 | return first_string 334 | maximum_length = min(len(first_string), len(second_string)) 335 | for i in range(0, maximum_length): 336 | if not first_string[i] == second_string[i]: 337 | return first_string[0:i] 338 | return first_string[0:maximum_length] 339 | 340 | 341 | def get_similarity(first_string, second_string, scaling=0.1): 342 | first_string_length = len(first_string) 343 | second_string_length = len(second_string) 344 | a_matches = [False] * first_string_length 345 | b_matches = [False] * second_string_length 346 | matches = 0 347 | transpositions = 0 348 | jaro_distance = 0.0 349 | 350 | if first_string_length == 0 or second_string_length == 0: 351 | return 1.0 352 | 353 | maximum_matching_distance = (max(first_string_length, second_string_length) // 2) - 1 354 | if maximum_matching_distance < 0: 355 | maximum_matching_distance = 0 356 | 357 | for i in range (first_string_length): 358 | start = max(0, i - maximum_matching_distance) 359 | end = min(i + maximum_matching_distance + 1, second_string_length) 360 | for x in range (start, end): 361 | if b_matches[x]: 362 | continue 363 | if first_string[i] != second_string[x]: 364 | continue 365 | a_matches[i] = True 366 | b_matches[x] = True 367 | matches += 1 368 | break 369 | 370 | if matches == 0: 371 | return 0.0 372 | 373 | k = 0 374 | for i in range(first_string_length): 375 | if not a_matches[i]: 376 | continue 377 | while not b_matches[k]: 378 | k += 1 379 | if first_string[i] != second_string[k]: 380 | transpositions += 1 381 | k += 1 382 | 383 | jaro_distance = ((matches / first_string_length) + 384 | (matches / second_string_length) + 385 | ((matches - transpositions / 2) / matches)) / 3.0 386 | prefix = min(len(_get_prefix(first_string, second_string)), 4) 387 | 388 | # Round to 2 places of percision to match pyjarowinkler formatting 389 | return round((jaro_distance + prefix * scaling * (1.0 - jaro_distance)) * 100.0) / 100.0 390 | 391 | 392 | def get_num_options(num): 393 | 394 | while True: 395 | choice = input("Enter the number of your choice: ") 396 | try: 397 | result = int(choice) 398 | if result >= 0 and result < num: 399 | return result 400 | else: 401 | print("Error invalid choice. ") 402 | except ValueError: 403 | print("Error invalid choice. ") 404 | 405 | 406 | def player_died(text): 407 | """ 408 | TODO: Add in more sophisticated NLP, maybe a custom classifier 409 | trained on hand-labelled data that classifies second-person 410 | statements as resulting in death or not. 
411 | """ 412 | lower_text = text.lower() 413 | you_dead_regexps = [ 414 | "you('re| are) (dead|killed|slain|no more|nonexistent)", 415 | "you (die|pass away|perish|suffocate|drown|bleed out)", 416 | "you('ve| have) (died|perished|suffocated|drowned|been (killed|slain))", 417 | "you (\w* )?(yourself )?to death", 418 | "you (\w* )*(collapse|bleed out|chok(e|ed|ing)|drown|dissolve) (\w* )*and (die(|d)|pass away|cease to exist|(\w* )+killed)", 419 | ] 420 | return any(re.search(regexp, lower_text) for regexp in you_dead_regexps) 421 | 422 | 423 | def player_won(text): 424 | lower_text = text.lower() 425 | won_phrases = [ 426 | "you ((\w* )*and |)live happily ever after", 427 | "you ((\w* )*and |)live (forever|eternally|for eternity)", 428 | "you ((\w* )*and |)(are|become|turn into) ((a|now) )?(deity|god|immortal)", 429 | "you ((\w* )*and |)((go|get) (in)?to|arrive (at|in)) (heaven|paradise)", 430 | "you ((\w* )*and |)celebrate your (victory|triumph)", 431 | "you ((\w* )*and |)retire", 432 | ] 433 | return any(re.search(regexp, lower_text) for regexp in won_phrases) 434 | 435 | 436 | def cut_trailing_quotes(text): 437 | num_quotes = text.count('"') 438 | if num_quotes % 2 == 0: 439 | return text 440 | else: 441 | final_ind = text.rfind('"') 442 | return text[:final_ind] 443 | 444 | 445 | def split_first_sentence(text): 446 | first_period = text.find(".") 447 | first_exclamation = text.find("!") 448 | 449 | if first_exclamation < first_period and first_exclamation > 0: 450 | split_point = first_exclamation + 1 451 | elif first_period > 0: 452 | split_point = first_period + 1 453 | else: 454 | split_point = text[0:20] 455 | 456 | return text[0:split_point], text[split_point:] 457 | 458 | 459 | def cut_trailing_action(text): 460 | lines = text.split("\n") 461 | last_line = lines[-1] 462 | if ( 463 | "you ask" in last_line 464 | or "You ask" in last_line 465 | or "you say" in last_line 466 | or "You say" in last_line 467 | ) and len(lines) > 1: 468 | text = "\n".join(lines[0:-1]) 469 | return text 470 | 471 | 472 | def clean_suggested_action(result_raw, min_length=4): 473 | result_cleaned = standardize_punctuation(result_raw) 474 | result_cleaned = cut_trailing_sentence(result_cleaned, allow_action=True) 475 | # The generations actions carry on into the next prompt, so lets remove the prompt 476 | results = result_cleaned.split("\n") 477 | results = [s.strip() for s in results] 478 | results = [s for s in results if len(s) > min_length] 479 | # Sometimes actions are generated with leading > ! . or ?. Likely the model trying to finish the prompt or start an action. 480 | result = results[0].strip().lstrip(" >!.?") if len(results) else "" 481 | # result = cut_trailing_quotes(result) 482 | logger.debug( 483 | "full suggested action '%r'. Cropped: '%r'. Split '%r'", 484 | result_raw, 485 | result, 486 | results, 487 | ) 488 | # Often actions are cropped with sentance fragments, lets remove. 
Or we could just turn up config_act["generate-number"] 489 | result = first_to_second_person(result) 490 | # Sometimes the suggestion start with "You" we will add that on later anyway so remove it here 491 | # result = re.sub("^ ?[Yy]ou try to ?", "", result) 492 | # result = re.sub("^ ?[Yy]ou start to ?", "", result) 493 | # result = re.sub("^ ?[Yy]ou ", "", result) 494 | logger.debug("suggested action after cleaning `%r`", result) 495 | return result 496 | 497 | 498 | def fix_trailing_quotes(text): 499 | num_quotes = text.count('"') 500 | if num_quotes % 2 == 0: 501 | return text 502 | else: 503 | return text + '"' 504 | 505 | 506 | def cut_trailing_sentence(text, allow_action=False): 507 | text = standardize_punctuation(text) 508 | last_punc = max(text.rfind("."), text.rfind("!"), text.rfind("?")) 509 | if last_punc <= 0: 510 | last_punc = len(text) - 1 511 | et_token = text.find("<") 512 | if et_token > 0: 513 | last_punc = min(last_punc, et_token - 1) 514 | # elif et_token == 0: 515 | # last_punc = min(last_punc, et_token) 516 | if allow_action: 517 | act_token = text.find(">") 518 | if act_token > 0: 519 | last_punc = min(last_punc, act_token - 1) 520 | # elif act_token == 0: 521 | # last_punc = min(last_punc, act_token) 522 | text = text[: last_punc + 1] 523 | text = fix_trailing_quotes(text) 524 | if allow_action: 525 | text = cut_trailing_action(text) 526 | return text 527 | 528 | 529 | def replace_outside_quotes(text, current_word, repl_word): 530 | text = standardize_punctuation(text) 531 | reg_expr = re.compile(current_word + '(?=([^"]*"[^"]*")*[^"]*$)') 532 | output = reg_expr.sub(repl_word, text) 533 | return output 534 | 535 | 536 | def is_first_person(text): 537 | count = 0 538 | for pair in first_to_second_mappings: 539 | variations = mapping_variation_pairs(pair) 540 | for variation in variations: 541 | reg_expr = re.compile(variation[0] + '(?=([^"]*"[^"]*")*[^"]*$)') 542 | matches = re.findall(reg_expr, text) 543 | count += len(matches) 544 | 545 | if count > 3: 546 | return True 547 | else: 548 | return False 549 | 550 | 551 | def is_second_person(text): 552 | count = 0 553 | for pair in second_to_first_mappings: 554 | variations = mapping_variation_pairs(pair) 555 | for variation in variations: 556 | reg_expr = re.compile(variation[0] + '(?=([^"]*"[^"]*")*[^"]*$)') 557 | matches = re.findall(reg_expr, text) 558 | count += len(matches) 559 | 560 | if count > 3: 561 | return True 562 | else: 563 | return False 564 | 565 | 566 | def capitalize(word): 567 | return word[0].upper() + word[1:] 568 | 569 | 570 | def mapping_variation_pairs(mapping): 571 | mapping_list = [] 572 | mapping_list.append((" " + mapping[0] + " ", " " + mapping[1] + " ")) 573 | mapping_list.append( 574 | (" " + capitalize(mapping[0]) + " ", " " + capitalize(mapping[1]) + " ") 575 | ) 576 | 577 | # Change you it's before a punctuation 578 | if mapping[0] == "you": 579 | mapping = ("you", "me") 580 | mapping_list.append((" " + mapping[0] + ",", " " + mapping[1] + ",")) 581 | mapping_list.append((" " + mapping[0] + "\?", " " + mapping[1] + "\?")) 582 | mapping_list.append((" " + mapping[0] + "\!", " " + mapping[1] + "\!")) 583 | mapping_list.append((" " + mapping[0] + "\.", " " + mapping[1] + ".")) 584 | 585 | return mapping_list 586 | 587 | 588 | first_to_second_mappings = [ 589 | ("I'm", "you're"), 590 | ("i'm", "you're"), 591 | ("Im", "you're"), 592 | ("im", "you're"), 593 | ("Ive", "you've"), 594 | ("ive", "you've"), 595 | ("I am", "you are"), 596 | ("i am", "you are"), 597 | ("wasn't I", "weren't 
you"), 598 | ("I", "you"), 599 | ("I'd", "you'd"), 600 | ("i", "you"), 601 | ("I've", "you've"), 602 | ("was I", "were you"), 603 | ("am I", "are you"), 604 | ("was i", "were you"), 605 | ("am i", "are you"), 606 | ("wasn't I", "weren't you"), 607 | ("I", "you"), 608 | ("i", "you"), 609 | ("I'd", "you'd"), 610 | ("i'd", "you'd"), 611 | ("I've", "you've"), 612 | ("i've", "you've"), 613 | ("I was", "you were"), 614 | ("i was", "you were"), 615 | ("my", "your"), 616 | ("we", "you"), 617 | ("we're", "you're"), 618 | ("mine", "yours"), 619 | ("me", "you"), 620 | ("us", "you"), 621 | ("our", "your"), 622 | ("I'll", "you'll"), 623 | ("i'll", "you'll"), 624 | ("myself", "yourself"), 625 | ] 626 | 627 | second_to_first_mappings = [ 628 | ("you're", "I'm"), 629 | ("your", "my"), 630 | ("you are", "I am"), 631 | ("you were", "I was"), 632 | ("are you", "am I"), 633 | ("you", "I"), 634 | ("you", "me"), 635 | ("you'll", "I'll"), 636 | ("yourself", "myself"), 637 | ("you've", "I've"), 638 | ] 639 | 640 | 641 | def capitalize_helper(string): 642 | string_list = list(string) 643 | string_list[0] = string_list[0].upper() 644 | return "".join(string_list) 645 | 646 | 647 | def capitalize_first_letters(text): 648 | first_letters_regex = re.compile(r"((?<=[\.\?!]\s)(\w+)|(^\w+))") 649 | 650 | def cap(match): 651 | return capitalize_helper(match.group()) 652 | 653 | result = first_letters_regex.sub(cap, text) 654 | return result 655 | 656 | 657 | def standardize_punctuation(text): 658 | text = text.replace("’", "'") 659 | text = text.replace("`", "'") 660 | text = text.replace("“", '"') 661 | text = text.replace("”", '"') 662 | return text 663 | 664 | 665 | def first_to_second_person(text): 666 | text = " " + text 667 | text = standardize_punctuation(text) 668 | if text[-1] not in [".", "?", "!"]: 669 | text += "." 670 | for pair in first_to_second_mappings: 671 | variations = mapping_variation_pairs(pair) 672 | for variation in variations: 673 | text = replace_outside_quotes(text, variation[0], variation[1]) 674 | return text 675 | 676 | 677 | def second_to_first_person(text): 678 | text = " " + text 679 | text = standardize_punctuation(text) 680 | if text[-1] not in [".", "?", "!"]: 681 | text += "." 682 | for pair in second_to_first_mappings: 683 | variations = mapping_variation_pairs(pair) 684 | for variation in variations: 685 | text = replace_outside_quotes(text, variation[0], variation[1]) 686 | return text 687 | -------------------------------------------------------------------------------- /core/play.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from .getconfig import config, setting_info 4 | from .storymanager import Story 5 | from .utils import * 6 | from .gpt2generator import GPT2Generator 7 | from .interface import instructions 8 | 9 | 10 | def get_generator(): 11 | output( 12 | "\nInitializing AI Engine! 
(This might take a few minutes)", 13 | "loading-message", end="\n\n" 14 | ) 15 | models = [x for x in Path('models').iterdir() if x.is_dir()] 16 | generator = None 17 | failed_env_load = False 18 | while True: 19 | try: 20 | transformers_pretrained = os.environ.get("TRANSFORMERS_PRETRAINED_MODEL", False) 21 | if transformers_pretrained and not failed_env_load: 22 | # Keep it as a string, so that transformers library will load the generic model 23 | model = transformers_pretrained 24 | assert isinstance(model, str) 25 | else: 26 | # Convert to path, so that transformers library will load the model from our folder 27 | if not models: 28 | raise FileNotFoundError( 29 | 'There are no models in the models directory! You must download a pytorch compatible model!') 30 | if os.environ.get("MODEL_FOLDER", False) and not failed_env_load: 31 | model = Path("models/" + os.environ.get("MODEL_FOLDER", False)) 32 | elif len(models) > 1: 33 | output("You have multiple models in your models folder. Please select one to load:", 'message') 34 | list_items([m.name for m in models] + ["(Exit)"], "menu") 35 | model_selection = input_number(len(models)) 36 | if model_selection == len(models): 37 | output("Exiting. ", "message") 38 | exit(0) 39 | else: 40 | model = models[model_selection] 41 | else: 42 | model = models[0] 43 | logger.info("Using model: " + str(model)) 44 | assert isinstance(model, Path) 45 | generator = GPT2Generator( 46 | model_path=model, 47 | generate_num=settings.getint("generate-num"), 48 | temperature=settings.getfloat("temp"), 49 | top_k=settings.getint("top-keks"), 50 | top_p=settings.getfloat("top-p"), 51 | repetition_penalty=settings.getfloat("rep-pen"), 52 | repetition_penalty_range=settings.getint("rep-pen-range"), 53 | repetition_penalty_slope=settings.getfloat("rep-pen-slope"), 54 | ) 55 | break 56 | except OSError: 57 | if len(models) == 0: 58 | output("You do not seem to have any models installed.", "error") 59 | output("Place a model in the 'models' subfolder and press enter", "error") 60 | input("") 61 | # Scan for models again 62 | models = [x for x in Path('models').iterdir() if x.is_dir()] 63 | else: 64 | failed_env_load = True 65 | output("Model could not be loaded. Please try another model. ", "error") 66 | continue 67 | except KeyboardInterrupt: 68 | output("Model load cancelled. ", "error") 69 | exit(0) 70 | return generator 71 | 72 | def settings_menu(): 73 | all_settings = list(setting_info.keys()) 74 | while True: 75 | list_items([pad_text(k, 19) + v[0] + (" " if v[0] else "") + 76 | "Default: " + str(v[1]) + " | " 77 | "Current: " + settings.get(k) for k, v in setting_info.items()] + [ 78 | "(Finish)"]) 79 | i = input_number(len(all_settings), default=-1) 80 | if i == len(all_settings): 81 | output("Done editing settings. ", "menu") 82 | return 83 | else: 84 | key = all_settings[i] 85 | output(key + ": " + setting_info[key][0], "menu") 86 | output("Default: " + str(setting_info[key][1]), "menu", beg='') 87 | output("Current: " + str(settings[key]), "menu", beg='') 88 | new_value = input_line("Enter the new value: ", "query") 89 | if len(new_value.strip()) == 0: 90 | output("Invalid value; cancelling. ", "error") 91 | continue 92 | output(key + ": " + setting_info[key][0], "menu") 93 | output("Current: " + str(settings[key]), "menu", beg='') 94 | output("New: " + str(new_value), "menu", beg='') 95 | output("Saving an invalid option will corrupt file! ", "message") 96 | if input_bool("Change setting? 
(y/N): ", "selection-prompt"): 97 | settings[key] = new_value 98 | try: 99 | with open("config.ini", "w", encoding="utf-8") as file: 100 | config.write(file) 101 | except IOError: 102 | output("Permission error! Changes will not be saved for next session.", "error") 103 | 104 | 105 | def load_prompt(f, format=True): 106 | with f.open('r', encoding="utf-8") as file: 107 | try: 108 | lines = file.read().strip().split('\n') 109 | if len(lines) < 2: 110 | context = lines[0] 111 | prompt = "" 112 | else: 113 | context = lines[0] 114 | prompt = ' '.join(lines[1:]) 115 | if format: 116 | return format_result(context), format_result(prompt) 117 | else: 118 | return context, prompt 119 | except IOError: 120 | output("Something went wrong; aborting. ", "error") 121 | return None, None 122 | 123 | 124 | def new_story(generator, context, prompt, memory=None, first_result=None): 125 | if memory is None: 126 | memory = [] 127 | context = context.strip() 128 | prompt = prompt.strip() 129 | erase = 0 130 | if use_ptoolkit(): 131 | erase = output(context, 'user-text', prompt, 'user-text', sep="\n\n") 132 | story = Story(generator, context, memory) 133 | if first_result is None: 134 | story.act(prompt) 135 | else: 136 | story.actions.append(prompt) 137 | story.results.append(first_result) 138 | clear_lines(erase) 139 | story.print_story() 140 | return story 141 | 142 | 143 | def save_story(story, file_override=None, autosave=False): 144 | """Saves the existing story to a json file in the saves directory to be resumed later.""" 145 | if not file_override: 146 | savefile = story.savefile 147 | while True: 148 | print() 149 | temp_savefile = input_line("Please enter a name for this save: ", "query") 150 | savefile = savefile if not temp_savefile or len(temp_savefile.strip()) == 0 else temp_savefile 151 | if not savefile or len(savefile.strip()) == 0: 152 | output("Please enter a valid savefile name. ", "error") 153 | else: 154 | break 155 | else: 156 | savefile = file_override 157 | savefile = os.path.splitext(savefile.strip())[0] 158 | savefile = re.sub(r"^ *saves *[/\\] *(.*) *(?:\.json)?", "\\1", savefile).strip() 159 | story.savefile = savefile 160 | savedata = story.to_json() 161 | finalpath = "saves/" + savefile + ".json" 162 | try: 163 | os.makedirs(os.path.dirname(finalpath), exist_ok=True) 164 | except OSError: 165 | if not autosave: 166 | output("Error when creating subdirectory; aborting. ", "error") 167 | with open(finalpath, 'w') as f: 168 | try: 169 | f.write(savedata) 170 | if not autosave: 171 | output("Successfully saved to " + savefile, "message") 172 | except IOError: 173 | if not autosave: 174 | output("Unable to write to file; aborting. ", "error") 175 | 176 | 177 | def load_story(f, gen): 178 | with f.open('r', encoding="utf-8") as file: 179 | try: 180 | story = Story(gen, "") 181 | savefile = os.path.splitext(file.name.strip())[0] 182 | savefile = re.sub(r"^ *saves *[/\\] *(.*) *(?:\.json)?", "\\1", savefile).strip() 183 | story.savefile = savefile 184 | story.from_json(file.read()) 185 | return story, story.context, story.actions[-1] if len(story.actions) > 0 else "" 186 | except FileNotFoundError: 187 | output("Save file not found. ", "error") 188 | except IOError: 189 | output("Something went wrong; aborting. 
", "error") 190 | return None, None, None 191 | 192 | 193 | def alter_text(text): 194 | if use_ptoolkit(): 195 | return edit_multiline(text).strip() 196 | 197 | sentences = sentence_split(text) 198 | while True: 199 | output(" ".join(sentences), 'menu') 200 | list_items( 201 | [ 202 | "Edit a sentence.", 203 | "Remove a sentence.", 204 | "Add a sentence.", 205 | "Edit entire prompt.", 206 | "Save and finish." 207 | ], 'menu') 208 | try: 209 | i = input_number(4) 210 | except: 211 | continue 212 | if i == 0: 213 | while True: 214 | output("Choose the sentence you want to edit.", "menu") 215 | list_items(sentences + ["(Back)"], "menu") 216 | i = input_number(len(sentences), default=-1) 217 | if i == len(sentences): 218 | break 219 | else: 220 | output(sentences[i], 'menu') 221 | res = input_line("Enter the altered sentence: ", 'menu').strip() 222 | if len(res) == 0: 223 | output("Invalid sentence entered: returning to previous menu. ", 'error') 224 | continue 225 | sentences[i] = res 226 | elif i == 1: 227 | while True: 228 | output("Choose the sentence you want to remove.", "menu") 229 | list_items(sentences + ["(Back)"], "menu") 230 | i = input_number(len(sentences), default=-1) 231 | if i == len(sentences): 232 | break 233 | else: 234 | del sentences[i] 235 | elif i == 2: 236 | while True: 237 | output("Choose the sentence you want to insert after.", "menu") 238 | list_items(["(Beginning)"] + sentences + ["(Back)"], "menu") 239 | maxn = len(sentences) + 1 240 | i = input_number(maxn, default=-1) 241 | if i == maxn: 242 | break 243 | else: 244 | res = input_line("Enter the new sentence: ", 'menu').strip() 245 | if len(res) == 0: 246 | output("Invalid sentence entered: returning to previous menu. ", 'error') 247 | continue 248 | sentences.insert(i, res) 249 | elif i == 3: 250 | output(" ".join(sentences), 'menu') 251 | res = input_line("Enter the new altered prompt: ", 'menu').strip() 252 | if len(res) == 0: 253 | output("Invalid prompt entered: returning to previous menu. ", 'error') 254 | continue 255 | sentences = sentence_split(res) 256 | elif i == 4: 257 | break 258 | return " ".join(sentences).strip() 259 | 260 | 261 | class GameManager: 262 | 263 | def __init__(self, gen: GPT2Generator): 264 | self.generator = gen 265 | self.story, self.context, self.prompt = None, None, None 266 | 267 | def init_story(self) -> bool: 268 | """ 269 | Initializes the story. Called by play_story. 270 | :return: True if the GameManager should progress to the story, false otherwise. 
271 | """ 272 | self.story, self.context, self.prompt = None, None, None 273 | list_items(["Pick Prompt From File (Default if you type nothing)", 274 | "Write Custom Prompt", 275 | "Load a Saved AI", 276 | "Change Settings"], 277 | 'menu') 278 | new_game_option = input_number(3) 279 | 280 | if new_game_option == 0: 281 | prompt_file = select_file(Path("prompts"), ".txt") 282 | if prompt_file: 283 | self.context, self.prompt = load_prompt(prompt_file) 284 | else: 285 | return False 286 | elif new_game_option == 1: 287 | with open( 288 | Path("interface/", "prompt-instructions.txt"), "r", encoding="utf-8" 289 | ) as file: 290 | output(file.read(), "instructions", wrap=False) 291 | if use_ptoolkit(): 292 | output("Context>", "main-prompt") 293 | self.context = edit_multiline() 294 | output("Prompt>", "main-prompt") 295 | self.prompt = edit_multiline() 296 | else: 297 | self.context = input_line("Context> ", "main-prompt") 298 | self.prompt = input_line("Prompt> ", "main-prompt") 299 | filename = input_line("Name to save prompt as? (Leave blank for no save): ", "query") 300 | filename = re.sub("-$", "", re.sub("^-", "", re.sub("[^a-zA-Z0-9_-]+", "-", filename))) 301 | if filename != "": 302 | try: 303 | with open(Path("prompts", filename + ".txt"), "w", encoding="utf-8") as f: 304 | f.write(self.context + "\n" + self.prompt) 305 | except IOError: 306 | output("Permission error! Unable to save custom prompt. ", "error") 307 | elif new_game_option == 2: 308 | story_file = select_file(Path("saves"), ".json") 309 | if story_file: 310 | self.story, self.context, self.prompt = load_story(story_file, self.generator) 311 | else: 312 | return False 313 | elif new_game_option == 3: 314 | settings_menu() 315 | return False 316 | 317 | if len((self.context + self.prompt).strip()) == 0: 318 | output("Conversation has no valid prompt or context, please enter a valid prompt and context. ", "error") 319 | return False 320 | 321 | if self.story is None: 322 | auto_file = "" 323 | if settings.getboolean("autosave"): 324 | while True: 325 | auto_file = input_line("Autosaving enabled. Please enter a save name: ", "query") 326 | if not auto_file or len(auto_file.strip()) == 0: 327 | output("Please enter a valid savefile name. ", "error") 328 | else: 329 | break 330 | instructions() 331 | output("Generating conversation...", "loading-message") 332 | self.story = new_story(self.generator, self.context, self.prompt) 333 | self.story.savefile = auto_file 334 | else: 335 | instructions() 336 | output("Loading conversation...", "loading-message") 337 | self.story.print_story() 338 | 339 | if settings.getboolean("autosave"): 340 | save_story(self.story, file_override=self.story.savefile, autosave=True) 341 | 342 | return True 343 | 344 | # returns true if going back to menu 345 | def process_command(self, cmd_regex) -> bool: 346 | """ 347 | Processes an in-game command. 348 | :param cmd_regex: The regular expression for the command. 349 | :return: True if the command causes the game to exit, false otherwise. 350 | """ 351 | command = cmd_regex.group(1).strip().lower() 352 | args = cmd_regex.group(2).strip().split() 353 | if command == "set": 354 | if len(args) < 2: 355 | output("Invalid number of arguments for set command. 
", "error") 356 | instructions() 357 | return False 358 | if args[0] in settings: 359 | curr_setting_val = settings[args[0]] 360 | output( 361 | "Current Value of {}: {} Changing to: {}".format( 362 | args[0], curr_setting_val, args[1] 363 | ) 364 | ) 365 | output("Saving an invalid option will corrupt file! ", "error") 366 | if input_bool("Save setting? (y/N): ", "selection-prompt"): 367 | settings[args[0]] = args[1] 368 | try: 369 | with open("config.ini", "w", encoding="utf-8") as f: 370 | config.write(f) 371 | except IOError: 372 | output("Permission error! Changes will not be saved for next session.", "error") 373 | else: 374 | output("Invalid setting", "error") 375 | instructions() 376 | 377 | elif command == "settings": 378 | settings_menu() 379 | self.story.print_last() 380 | 381 | elif command == "menu": 382 | if input_bool("Do you want to save? (y/N): ", "query"): 383 | save_story(self.story) 384 | # self.story, self.context, self.prompt = None, None, None 385 | return True 386 | 387 | elif command == "restart": 388 | output("Restarting story...", "loading-message") 389 | if len((self.context + self.prompt).strip()) == 0: 390 | output("Story has no prompt or context. Please enter a valid prompt. ", "error") 391 | return False 392 | self.story = new_story(self.generator, self.story.context, self.prompt) 393 | 394 | elif command == "quit": 395 | if input_bool("Do you want to save? (y/N): ", "query"): 396 | save_story(self.story) 397 | exit() 398 | 399 | elif command == "help": 400 | instructions() 401 | 402 | elif command == "print": 403 | use_wrap = input_bool("Print with wrapping? (y/N): ", "query") 404 | use_color = input_bool("Print with colors? (y/N): ", "query") 405 | output("Printing story...", "message") 406 | self.story.print_story(wrap=use_wrap, color=use_color) 407 | 408 | elif command == "retry": 409 | if len(self.story.actions) < 2: 410 | output("Restarting story...", "loading-message") 411 | if len((self.context + self.prompt).strip()) == 0: 412 | output("Story has no prompt or context. Please enter a valid prompt. ", "error") 413 | return False 414 | self.story = new_story(self.generator, self.story.context, self.prompt) 415 | return False 416 | else: 417 | output("Retrying...", "loading-message") 418 | new_action = self.story.actions[-1] 419 | self.story.revert() 420 | result = self.story.act(new_action) 421 | if self.story.is_looping(): 422 | self.story.revert() 423 | output("That action caused the model to start looping. Try something else instead. ", 424 | "error") 425 | return False 426 | self.story.print_last() 427 | 428 | elif command == "revert": 429 | if len(self.story.actions) < 2: 430 | output("You can't go back any farther. ", "error") 431 | return False 432 | self.story.revert() 433 | output("Last action reverted. ", "message") 434 | self.story.print_last() 435 | 436 | elif command == "alter": 437 | self.story.results[-1] = alter_text(self.story.results[-1]) 438 | self.story.print_last() 439 | 440 | elif command == "context": 441 | self.story.context = alter_text(self.story.context) 442 | self.story.print_last() 443 | 444 | elif command == "remember": 445 | memory = cmd_regex.group(2).strip() 446 | if len(memory) > 0: 447 | memory = re.sub("^[Tt]hat +(.*)", "\\1", memory) 448 | memory = memory.strip('.') 449 | memory = memory.strip('!') 450 | memory = memory.strip('?') 451 | self.story.memory.append(memory[0].upper() + memory[1:] + ".") 452 | output("You remember " + memory + ". 
", "message") 453 | else: 454 | output("Please enter something valid to remember. ", "error") 455 | 456 | elif command == "memalt": 457 | while True: 458 | output("Select a memory to alter: ", "menu") 459 | list_items(self.story.memory + ["(Finish)"], "menu") 460 | i = input_number(len(self.story.memory), default=-1) 461 | if i == len(self.story.memory): 462 | break 463 | else: 464 | self.story.memory[i] = alter_text(self.story.memory[i]) 465 | if self.story.memory[i] == 0: 466 | del self.story.memory[i] 467 | 468 | elif command == "memswap": 469 | while True: 470 | output("Select two memories to swap: ", "menu") 471 | list_items(self.story.memory + ["(Finish)"], "menu") 472 | i = input_number(len(self.story.memory), default=-1) 473 | if i == len(self.story.memory): 474 | break 475 | j = input_number(len(self.story.memory), default=-1) 476 | if j == len(self.story.memory): 477 | break 478 | else: 479 | self.story.memory[i], self.story.memory[j] = self.story.memory[j], self.story.memory[i] 480 | 481 | elif command == "forget": 482 | while True: 483 | output("Select a memory to forget: ", "menu") 484 | list_items(self.story.memory + ["(Finish)"], "menu") 485 | i = input_number(len(self.story.memory), default=-1) 486 | if i == len(self.story.memory): 487 | break 488 | else: 489 | del self.story.memory[i] 490 | 491 | elif command == "save": 492 | save_story(self.story) 493 | 494 | elif command == "load": 495 | story_file = select_file(Path("saves"), ".json") 496 | if story_file: 497 | tstory, tcontext, tprompt = load_story(story_file, self.generator) 498 | if tstory: 499 | output("Loading conversation...", "message") 500 | self.story = tstory 501 | self.context = tcontext 502 | self.prompt = tprompt 503 | self.story.print_story() 504 | else: 505 | self.story.print_last() 506 | else: 507 | self.story.print_last() 508 | 509 | elif command == "summarize": 510 | first_result = self.story.results[-1] 511 | output(self.story.context, "user-text", "(YOUR SUMMARY HERE)", "message") 512 | output(self.story.results[-1], "ai-text") 513 | new_prompt = input_line("Enter the summary for the new conversation: ", "query") 514 | new_prompt = new_prompt.strip() 515 | if len(new_prompt) == 0: 516 | output("Invalid new prompt; cancelling. ", "error") 517 | self.story.print_last() 518 | return False 519 | if input_bool("Do you want to save your previous conversation? (y/N): ", "query"): 520 | save_story(self.story) 521 | self.prompt = new_prompt 522 | self.story = new_story(self.generator, self.context, self.prompt, memory=self.story.memory, 523 | first_result=first_result) 524 | self.story.savefile = "" 525 | 526 | elif command == "altergen": 527 | result = alter_text(self.story.results[-1]) 528 | self.story.results[-1] = "" 529 | output("Regenerating result...", "message") 530 | result += ' ' + self.story.act(result, record=False) 531 | self.story.results[-1] = result 532 | self.story.print_last() 533 | 534 | else: 535 | output("Invalid command: " + command, "error") 536 | return False 537 | 538 | def process_action(self, action, suggested_actions=[]) -> bool: 539 | """ 540 | Processes an action to be submitted to the AI. 541 | :param action: The action being submitted to the AI. 542 | :param suggested_actions: The suggested actions generated (if action-sugg > 0) 543 | :return: True if the action ends the game, false otherwise. 544 | """ 545 | action = format_input(action) 546 | 547 | story_insert_regex = re.search("^(?: *you +)?! *(.*)$", action, flags=re.I) 548 | 549 | # If the player enters a story insert. 

    def process_action(self, action, suggested_actions=[]) -> bool:
        """
        Processes an action to be submitted to the AI.
        :param action: The action being submitted to the AI.
        :param suggested_actions: The suggested actions generated (if action-sugg > 0)
        :return: True if the action ends the game, False otherwise.
        """
        action = format_input(action)

        story_insert_regex = re.search("^(?: *you +)?! *(.*)$", action, flags=re.I)

        # If the player enters a story insert.
        if story_insert_regex:
            action = story_insert_regex.group(1)
            if not action or len(action.strip()) == 0:
                output("Invalid conversation insert. ", "error")
                return False
            output(format_result(action), "user-text")

        # If the user enters nothing but leaves "you", treat it like an empty action (continue)
        if re.match(r"^(?: *you *)*[.?!]? *$", action, flags=re.I):
            action = ""
        else:
            # Prompt the user with the formatted action
            output("> " + format_result(action), "transformed-user-text")

        if action == "":
            output("Continuing...", "message")

        result = self.story.act(action)

        # Output the AI's result.
        output(result, "ai-text")
        return False

    def play_story(self):
        """The main in-game loop"""
        if not self.init_story():  # Failed init
            return

        while True:
            # Generate suggested actions
            act_alts = settings.getint("action-sugg")
            suggested_actions = []
            if act_alts > 0:
                # TODO change this to two messages for different colors
                output("Suggested actions:", "selection-value")
                action_suggestion_lines = 2
                for i in range(act_alts):
                    suggested_action = self.story.get_suggestion()
                    if len(suggested_action.strip()) > 0:
                        j = len(suggested_actions)
                        suggested_actions.append(suggested_action)
                        suggestion = "{}) {}".format(j, suggested_action)
                        action_suggestion_lines += \
                            output(suggestion, "selection-value", beg='' if i != 0 else None)

            bell()
            print()

            if use_ptoolkit():
                action = input_line("> ", "main-prompt", default="")
            else:
                action = input_line(">", "main-prompt")

            # Clear suggestions and user input
            if act_alts and not in_colab():
                clear_lines(action_suggestion_lines + 2)

            # Users can type in "/command", or "You /command" if prompt_toolkit is on and they left the "You" in
            cmd_regex = re.search(r"^(?: *you *)?/([^ ]+) *(.*)$", action, flags=re.I)

            # If this is a command
            if cmd_regex:
                if self.process_command(cmd_regex):  # Go back to the menu
                    return

            # Otherwise this is just a normal action.
            else:
                if self.process_action(action, suggested_actions):  # End of story
                    return

            # Autosave after every input from the user (if it's enabled)
            if settings.getboolean("autosave"):
                save_story(self.story, file_override=self.story.savefile, autosave=True)
--------------------------------------------------------------------------------