├── .gitignore
├── LICENSE
├── README.md
├── notebooks
│   └── re3.ipynb
├── presentation_materials
│   ├── poster.pdf
│   ├── slides_12min.pptx
│   ├── slides_3min.pptx
│   ├── video_12min.mp4
│   └── video_3min.mp4
├── requirements.txt
├── scripts
│   ├── data
│   │   ├── create_alignment_data.py
│   │   ├── create_gpt3_finetuning_data_rollingwindow.py
│   │   ├── generate_contradictory_stories.py
│   │   └── resample_character_descriptions.py
│   ├── main.py
│   └── training
│       └── train_controller.py
├── setup.py
└── story_generation
    ├── common
    │   ├── controller
    │   │   ├── controller_util.py
    │   │   ├── loader_util.py
    │   │   ├── loaders
    │   │   │   ├── alignment_loader.py
    │   │   │   └── coherence_loader.py
    │   │   └── models
    │   │       ├── abstract_controller.py
    │   │       └── longformer_classifier.py
    │   ├── data
    │   │   ├── data_util.py
    │   │   ├── datasets
    │   │   │   ├── abstract_dataset.py
    │   │   │   ├── alignment.py
    │   │   │   └── writing_prompts.py
    │   │   ├── split_paragraphs.py
    │   │   └── tree_util.py
    │   ├── summarizer
    │   │   ├── models
    │   │   │   ├── abstract_summarizer.py
    │   │   │   ├── gpt3_summarizer.py
    │   │   │   └── opt_summarizer.py
    │   │   └── summarizer_util.py
    │   └── util.py
    ├── draft_module
    │   └── beam_candidate.py
    ├── edit_module
    │   ├── entity.py
    │   ├── evaluate_consistency.py
    │   └── example_library.csv
    ├── plan_module
    │   └── plan.py
    └── rewrite_module
        ├── README.md
        └── heuristics.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.DS_Store

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Kevin Yang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Re3: Generating Longer Stories With Recursive Reprompting and Revision

This repo contains code for Re3: Generating Longer Stories With Recursive Reprompting and Revision (https://arxiv.org/abs/2210.06774) by Kevin Yang, Yuandong Tian, Nanyun Peng, and Dan Klein, to appear at EMNLP 2022. In this codebase we provide instructions for automatically generating stories of 2000+ words (or even much longer), as well as reproducing all of our baselines/ablations/analyses from the paper. We hope that this work can be useful for future research on automatic long story generation.

UPDATE 12/20/22: Check out our newer/better system at https://github.com/yangkevin2/doc-story-generation !

## Quick Start with Notebook

See `notebooks/re3.ipynb` (thanks @jagilley for contributing this!) for a notebook with installation commands + a default story generation command. More detailed instructions below.

## Installation / Data

Install Python 3.8.13 and PyTorch 1.12.1 (other versions are probably also fine for both), then install the remaining requirements via `pip install -r requirements.txt`. Additionally install this repo with `pip install -e .`.

Also run `export OPENAI_API_KEY=$YOUR_API_KEY` in your terminal so that the code can call the GPT3 API with your key.

Meanwhile, run `wget https://emnlp22-re3-data.s3.amazonaws.com/emnlp22_re3_data.zip` and unzip the folder.
This folder contains some pretrained reranker ckpts as well as all of the data we used in all components/analyses. It also contains the final generated stories and MTurk annotation results from our main experiments (note: some generated stories may contain sensitive/NSFW content, since we didn't attempt to filter these).

## Main Story Generation

Main story generation command matching the settings used for our main paper experiments:

```
mkdir output
CUDA_VISIBLE_DEVICES=0 python -u scripts/main.py --summarizer gpt3_summarizer --controller longformer_classifier longformer_classifier --loader alignment coherence --controller-load-dir emnlp22_re3_data/ckpt/relevance_reranker emnlp22_re3_data/ckpt/coherence_reranker --controller-model-string allenai/longformer-base-4096 allenai/longformer-base-4096 --save-outline-file output/outline0.pkl --save-complete-file output/complete_story0.pkl --log-file output/story0.log
```

Don't worry if you see some errors being printed, as long as the program doesn't terminate early. (For example, it may need multiple tries in some parts of the initial plan generation.)

This command uses our existing relevance and coherence reranker ckpts included in the download (note: these were retrained after submission to be compatible with an updated version of the HuggingFace transformers package, but otherwise are effectively the same as the ones we used in our paper experiments). If you want to use your own ckpts, see the instructions further down for training, and change the paths in this command to point to the correct ckpts.

### Other Arguments

Main story generation arguments are compiled in `scripts/main.py`; follow the links there to see a complete list. Some particular arguments of interest:

* Use the `--premise` argument to specify your own story premise instead of having one autogenerated by GPT3.
* Add the `--setup-only` flag to stop after generating the initial plan; the plan can then be reloaded by specifying `--load-outline-file` instead of `--save-outline-file` later.
* Add `--no-editor` to turn off the Edit module. This will make generation a decent amount faster without sacrificing too much. This is the Plan-Draft-Rewrite ablation in our paper.
* Add `--no-planner` to turn off the Plan module, which will make performance a good bit worse. This is the Draft-Rewrite-Edit ablation in our paper.
* Set `--max-candidates 1` to turn off the Rewrite module, which will make performance a good bit worse. This is the Plan-Draft-Edit ablation in our paper.
* Increase `--max-beam-size` (defaults to 1) to turn on a passage-level variable-size beam search procedure based on the rerankers. This is off for the paper experiments (makes the system several times slower) but should improve performance a bit.
* Change `--fixed-outline-length` (defaults to 3) to set a desired length for your outline (i.e., how many numbered items it will have) or set it to -1 to accept variable length.
* Change `--max-continuation-substeps` (defaults to 4) and `--generation-max-length` (defaults to 256) to change how much story text to write for each numbered item of the outline. With the default settings, it will write four 256-token passages for each.
* For the longer story in Appendix L of the paper, we added `--outline-levels 2 --fixed-outline-length -1 --continuation-threshold 0.5 --max-continuation-substeps 5`, which (1) generates a 2-level outline (note: the current version of this sometimes generates somewhat repetitive plans) and (2) dynamically decides when to move on to the next part of the outline during story generation based on reranker scores, instead of just using a fixed length for each part of the outline.
* Set `--log-level` to a value between 21 and 25 to vary the verbosity of logging (higher = less verbose).

### Baselines

The simplest Rolling baseline in the paper can be run by adding all of the flags `--no-editor`, `--no-planner`, and `--max-candidates 1` to the main story generation command above.

Running the Rolling-Finetune baseline additionally requires finetuning OpenAI's davinci model, for which we used passages from the WritingPrompts dataset (Fan et al., https://arxiv.org/pdf/1805.04833.pdf). The data we used for finetuning can be found in `emnlp22_re3_data/data/rollingwindow_finetune_data.json`. Alternatively, you can generate your own finetuning data via the command

```
python scripts/data/create_gpt3_finetuning_data_rollingwindow.py --expander --data-dir emnlp22_re3_data/data/writing_prompts --limit 10000 --length-limit 10000000 --lower-length-limit 3000 --save-json $PATH_TO_SAVE_FINETUNING_DATA --track-num-tokens
```

Then, using the OpenAI command line API (https://beta.openai.com/docs/guides/fine-tuning), we finetuned as follows. Replace the json with your own path if you re-created the data.

```
openai api fine_tunes.create -t emnlp22_re3_data/data/rollingwindow_finetune_data.json -m davinci --learning_rate_multiplier 0.02 --n_epochs 1
```

After finetuning, you can run the same command as for the Rolling baseline to generate stories, additionally specifying `--draft-model-string $OPENAI_CKPT`, where `$OPENAI_CKPT` is the name of your finetuned davinci checkpoint (a string starting with `davinci:ft-personal-`).

## Reranker Training

If you want to retrain the relevance and coherence rerankers yourself, follow the instructions below. Feel free to adjust the batch sizes depending on your GPU memory.

### Training Relevance Reranker:

The data can be found in `emnlp22_re3_data/data/alignment_data.csv`, and was generated by chunking WritingPrompts stories and prompting GPT3-Instruct-13B (text-curie-001) for summaries, according to the command:

```
python scripts/data/create_alignment_data.py --summarizer gpt3_summarizer --gpt3-model text-curie-001 --dataset writing_prompts --data-dir emnlp22_re3_data/data/writing_prompts --limit 2000 --save-csv $PATH_TO_SAVE_ALIGNMENT_DATA
```

Note that the preprocessing for the WritingPrompts dataset can take a while, so feel free to try the command with the folder `emnlp22_re3_data/data/writing_prompts_debug` first if you just want to check that it's working.
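Each row of the resulting CSV holds one story's chunks and their one-sentence summaries as tab-joined `text1`/`text2` columns. A minimal sanity-check sketch (the column names and tab-joining convention come from `scripts/data/create_alignment_data.py`; everything else here is illustrative, not part of the pipeline):

```
import csv

# Sanity-check alignment data produced by scripts/data/create_alignment_data.py.
# 'text1' holds the tab-joined chunk prompts and 'text2' the tab-joined summaries.
with open('emnlp22_re3_data/data/alignment_data.csv') as f:
    for row in csv.DictReader(f):
        chunks = row['text1'].split('\t')
        summaries = row['text2'].split('\t')
        assert len(chunks) == len(summaries)
        print(len(chunks), 'chunk/summary pairs in this story')
        break  # just inspect the first story
```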
After getting the data, train the relevance reranker:

```
CUDA_VISIBLE_DEVICES=0 python scripts/training/train_controller.py --controller longformer_classifier --data-dir emnlp22_re3_data/data/alignment_data.csv --dataset alignment --batch-size 8 --controller-save-dir $DIR_TO_SAVE_RELEVANCE_CKPT --length-limit 2000 --controller-epochs 20 --loader alignment --controller-num-negatives 3 --controller-model-string allenai/longformer-base-4096 --controller-lr 1e-6 --coherence-negative-categories other shuffle
```

Feel free to modify the batch size depending on your GPU memory.

### Training Coherence Reranker:

```
CUDA_VISIBLE_DEVICES=0 python scripts/training/train_controller.py --controller longformer_classifier --data-dir emnlp22_re3_data/data/writing_prompts --dataset writing_prompts --batch-size 1 --controller-save-dir $DIR_TO_SAVE_COHERENCE_CKPT --length-limit 1000 --controller-epochs 20 --loader coherence --controller-num-negatives 3 --controller-model-string allenai/longformer-base-4096 --controller-lr 1e-6 --coherence-negative-categories other shuffle repeat
```

Feel free to modify the batch size depending on your GPU memory.

## Edit Module Analysis

The data we used for evaluation in Section 5.2 can be found in `emnlp22_re3_data/data/consistency_analysis_data`. The command to evaluate each method is as follows, where `$METHOD` is any of `structured`, `entailment`, or `entailment-dpr`, corresponding to the methods of the same names in the paper. Note that there's some randomness involved in `structured` due to its reliance on GPT3. `entailment-dpr` can also take a while to run.

```
CUDA_VISIBLE_DEVICES=0 python story_generation/edit_module/evaluate_consistency.py --consistency-dataset-dir $CONSISTENCY_DATA_DIR --consistency-method $METHOD
```

We originally generated the data by generating a large number N of setups according to the main story generation command above with the `--setup-only` flag, saving them as `0.pkl` up to `N-1.pkl` in your directory of choice. We then used `scripts/data/resample_character_descriptions.py` to resample character descriptions, manually sampling and filtering until we got descriptions that contradicted the original setup. Finally, we used `scripts/data/generate_contradictory_stories.py` to generate story beginnings for the two contradictory setups until the contradiction manifested in the story. (Both scripts rely on the 0 to N-1 indexing to track which setups have already been processed, in case one needs to restart the script. At any point one can backtrack to add more data in previous steps if needed.) See the individual files if you're actually interested in running these scripts; for a larger-scale analysis in the future, it might be better to integrate this pipeline with a crowdsourcing platform like MTurk to avoid the need for manual annotation.
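Each saved setup is a pickled `save_info` dict. A minimal sketch of inspecting one (the keys are taken from `scripts/data/resample_character_descriptions.py`; the path is just a placeholder for wherever you saved your `--setup-only` outputs):

```
import pickle

# Inspect one setup generated with --setup-only. Keys follow the save_info
# dict used in scripts/data/resample_character_descriptions.py.
with open('path/to/setups/0.pkl', 'rb') as f:  # placeholder path
    save_info = pickle.load(f)
print(save_info['premise'])
print(save_info['setting'])
for name, entity in save_info['character_strings'].items():
    print(name, '->', entity.description)
```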
-------------------------------------------------------------------------------- /notebooks/re3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "lb8HqYd9ETZB", 11 | "outputId": "8690fee1-9a05-442f-e28c-aef8d9fbbab6" 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "# setup\n", 16 | "\n", 17 | "!git clone https://github.com/yangkevin2/emnlp22-re3-story-generation\n", 18 | "%cd emnlp22-re3-story-generation\n", 19 | "!pip install -r requirements.txt\n", 20 | "!pip install -e ." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "colab": { 28 | "base_uri": "https://localhost:8080/" 29 | }, 30 | "id": "NHS9SXR-_TpS", 31 | "outputId": "04c50260-1b92-4694-ec1f-2f4dce87f780" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# PASTE YOUR OPENAI API KEY BELOW\n", 36 | "\n", 37 | "key = \"your-api-key-here\"\n", 38 | "\n", 39 | "%env OPENAI_API_KEY=$key" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "colab": { 47 | "base_uri": "https://localhost:8080/" 48 | }, 49 | "id": "s-n_XEWr_gcL", 50 | "outputId": "7974a16f-07ff-4867-f552-453e0745e347" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "!wget https://emnlp22-re3-data.s3.amazonaws.com/emnlp22_re3_data.zip\n", 55 | "!unzip emnlp22_re3_data.zip\n", 56 | "!mkdir output" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "colab": { 64 | "base_uri": "https://localhost:8080/" 65 | }, 66 | "id": "Jw60ycNpAAmo", 67 | "outputId": "2bdbcae4-eb47-460b-dfa5-daf31cc2d0a1" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# run re3\n", 72 | "\n", 73 | "!CUDA_VISIBLE_DEVICES=0 python -u scripts/main.py --summarizer gpt3_summarizer --controller longformer_classifier longformer_classifier --loader alignment coherence --controller-load-dir emnlp22_re3_data/ckpt/relevance_reranker emnlp22_re3_data/ckpt/coherence_reranker --controller-model-string allenai/longformer-base-4096 allenai/longformer-base-4096 --save-outline-file output/outline0.pkl --save-complete-file output/complete_story0.pkl --log-file output/story0.log" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "colab": { 81 | "base_uri": "https://localhost:8080/" 82 | }, 83 | "id": "8fQWavWjP4Ol", 84 | "outputId": "ec6e9bf8-51fe-4b34-bced-b00bc2aa9cb6" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "import pickle\n", 89 | "\n", 90 | "with open('output/complete_story0.pkl', 'rb') as f:\n", 91 | " story = pickle.load(f)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "colab": { 99 | "base_uri": "https://localhost:8080/" 100 | }, 101 | "id": "gFKZixeCQxW2", 102 | "outputId": "64051138-6398-45ce-e858-630897f06b1e" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "story[0].story()" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "accelerator": "GPU", 112 | "colab": { 113 | "collapsed_sections": [], 114 | "machine_shape": "hm", 115 | "provenance": [] 116 | }, 117 | "gpuClass": "standard", 118 | "kernelspec": { 119 | "display_name": "Python 3", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "name": "python" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 0 128 | } 129 | 
-------------------------------------------------------------------------------- /presentation_materials/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkevin2/emnlp22-re3-story-generation/3a97ebde04e3333962c2825146897efe1dc87dd8/presentation_materials/poster.pdf -------------------------------------------------------------------------------- /presentation_materials/slides_12min.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkevin2/emnlp22-re3-story-generation/3a97ebde04e3333962c2825146897efe1dc87dd8/presentation_materials/slides_12min.pptx -------------------------------------------------------------------------------- /presentation_materials/slides_3min.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkevin2/emnlp22-re3-story-generation/3a97ebde04e3333962c2825146897efe1dc87dd8/presentation_materials/slides_3min.pptx -------------------------------------------------------------------------------- /presentation_materials/video_12min.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkevin2/emnlp22-re3-story-generation/3a97ebde04e3333962c2825146897efe1dc87dd8/presentation_materials/video_12min.mp4 -------------------------------------------------------------------------------- /presentation_materials/video_3min.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkevin2/emnlp22-re3-story-generation/3a97ebde04e3333962c2825146897efe1dc87dd8/presentation_materials/video_3min.mp4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flair==0.10 2 | nltk==3.6.2 3 | numpy==1.21.6 4 | openai==0.16.0 5 | pandas==1.3.5 6 | protobuf==3.20 7 | python-Levenshtein==0.12.2 8 | scikit-learn==1.0.2 9 | scipy==1.7.3 10 | sentence-transformers==2.2.2 11 | tqdm==4.62.3 12 | transformers==4.21.2 13 | -------------------------------------------------------------------------------- /scripts/data/create_alignment_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer 7 | 8 | from story_generation.common.util import add_general_args 9 | from story_generation.common.data.data_util import add_data_args, load_dataset 10 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 11 | from story_generation.common.data.split_paragraphs import split_paragraphs, group_chunks 12 | 13 | if __name__=='__main__': 14 | parser = argparse.ArgumentParser() 15 | parser = add_general_args(parser) 16 | parser = add_data_args(parser) 17 | parser = add_summarizer_args(parser) 18 | parser.add_argument('--save-csv', type=str, required=True, help='save to this csv file') 19 | parser.add_argument('--max-chunk-length', type=int, default=200, help='maximum length of chunks when splitting paragraphs') 20 | args = parser.parse_args() 21 | 22 | dataset = load_dataset(args) 23 | summarizer = load_summarizer(args) 24 | tab_token = summarizer.tokenizer.encode('\t')[0] 25 | logit_bias = {tab_token:-100} 26 | long_texts = 
dataset.load_long_texts(split='train', split_paragraphs=False) 27 | os.makedirs(os.path.dirname(args.save_csv), exist_ok=True) 28 | with open(args.save_csv, 'w') as wf: 29 | writer = csv.writer(wf) 30 | writer.writerow(['text1', 'text2']) 31 | for text in tqdm(long_texts): 32 | chunks = group_chunks(split_paragraphs(text, mode='sentence'), max_chunk_length=args.max_chunk_length) 33 | prompts = [chunk.strip() + '\n\n\n\nOne-sentence summary:\n\n\n\n' for chunk in chunks] 34 | summaries = [s.strip() for s in summarizer(prompts, modify_prompt=False, logit_bias=logit_bias)] 35 | writer.writerow(['\t'.join(prompts), '\t'.join(summaries)]) 36 | -------------------------------------------------------------------------------- /scripts/data/create_gpt3_finetuning_data_rollingwindow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from transformers import AutoTokenizer 5 | import random 6 | 7 | from story_generation.common.util import add_general_args 8 | from story_generation.common.data.data_util import add_data_args, load_dataset 9 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args 10 | from story_generation.common.summarizer.models.gpt3_summarizer import GPT3_SEP, GPT3_END 11 | 12 | if __name__=='__main__': 13 | parser = argparse.ArgumentParser() 14 | parser = add_general_args(parser) 15 | parser = add_data_args(parser) 16 | parser = add_summarizer_args(parser) 17 | parser.add_argument('--save-json', type=str, required=True, help='save to this json file') 18 | parser.add_argument('--track-num-tokens', default=False, action='store_true', help='track num tokens') 19 | parser.add_argument('--target-max-length', type=int, default=256, help='max length of target') 20 | parser.add_argument('--source-max-length', type=int, default=768, help='max length of source') 21 | args = parser.parse_args() 22 | 23 | dataset = load_dataset(args) 24 | long_texts = dataset.load_long_texts(split='train', split_paragraphs=False) 25 | short_texts = dataset.load_short_texts(split='train', split_paragraphs=False) 26 | 27 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 28 | 29 | os.makedirs(os.path.dirname(args.save_json), exist_ok=True) 30 | 31 | num_tokens = 0 32 | with open(args.save_json, 'w') as wf: 33 | for long_text, short_text in zip(long_texts, short_texts): 34 | tokenized_long_text = tokenizer.encode(long_text) 35 | if random.random() < 0.2: 36 | source = 'Write a story with the following premise.\n\n' + 'Premise: ' + short_text.strip() + '\n\nChapter 1\n\n' 37 | # source = 'Premise:\n\n' + short_text.strip() + '\n\nWrite a story with this premise:\n\n' 38 | target = tokenizer.decode(tokenized_long_text[:args.target_max_length]) 39 | else: 40 | split_idx = random.choice(list(range(args.target_max_length, len(tokenized_long_text)-1))) 41 | source = tokenizer.decode(tokenized_long_text[max(0, split_idx - args.source_max_length):split_idx]) 42 | target = tokenizer.decode(tokenized_long_text[split_idx:split_idx + args.target_max_length]) 43 | if split_idx + args.target_max_length > len(tokenized_long_text): 44 | target = target + '\n\n\n\n' + GPT3_END 45 | if args.track_num_tokens: 46 | num_tokens += len(tokenizer.encode(source)) + len(tokenizer.encode(target)) 47 | wf.write(json.dumps({'prompt': source, 'completion': target}).strip() + '\n') 48 | if args.track_num_tokens: 49 | print('num tokens', num_tokens) 50 | 51 | 
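# Note on output format: each line of the saved json file is one finetuning
# example of the form {"prompt": source, "completion": target}. Roughly 20% of
# examples (the random.random() < 0.2 branch above) use the premise-conditioned
# "Chapter 1" prompt; the rest use a rolling window of up to --source-max-length
# story tokens as the prompt and the next --target-max-length tokens as the
# completion, with GPT3_END appended when the window reaches the end of the story.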
-------------------------------------------------------------------------------- /scripts/data/generate_contradictory_stories.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import time 5 | import pickle 6 | from copy import deepcopy 7 | 8 | from transformers import AutoTokenizer 9 | 10 | from story_generation.edit_module.entity import Entity 11 | from story_generation.common.util import * 12 | from story_generation.common.data.data_util import add_data_args, load_dataset 13 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 14 | 15 | def generate_story(model, prompt, character, dpr_query, num_samples=1): 16 | stories = [character + story for story in model([prompt], stop='Chapter', cut_sentence=True, num_completions=num_samples)] 17 | scores = score_dpr(dpr_query, stories) 18 | stories_scores = [(story, score) for story, score in zip(stories, scores)] 19 | return sorted(stories_scores, key=lambda x: x[1], reverse=True) 20 | 21 | if __name__=='__main__': 22 | parser = argparse.ArgumentParser() 23 | parser = add_general_args(parser) 24 | parser = add_data_args(parser) 25 | parser = add_summarizer_args(parser) 26 | parser.add_argument('--load-dir', type=str, required=True, help='directory where stuff was saved in resample_character_descriptions.py') 27 | args = parser.parse_args() 28 | 29 | base_model = load_summarizer(args) 30 | instruct_args = deepcopy(args) 31 | instruct_args.gpt3_model = 'text-' + args.gpt3_model + '-001' 32 | instruct_model = load_summarizer(instruct_args) 33 | 34 | original_dir = os.path.join(args.load_dir, 'original') 35 | altered_dir = os.path.join(args.load_dir, 'altered') 36 | original_stories_dir = os.path.join(args.load_dir, 'original_stories') 37 | altered_stories_dir = os.path.join(args.load_dir, 'altered_stories') 38 | 39 | os.makedirs(original_stories_dir, exist_ok=True) 40 | os.makedirs(altered_stories_dir, exist_ok=True) 41 | 42 | num_files = len(os.listdir(original_dir)) 43 | num_preexisting_files = len(os.listdir(original_stories_dir)) 44 | 45 | for i in range(num_preexisting_files, num_files): # resume where we left off, if restarting the script 46 | with open(os.path.join(original_dir, str(i) + '.pkl'), 'rb') as f: 47 | original_save_info = pickle.load(f) 48 | with open(os.path.join(altered_dir, str(i) + '.pkl'), 'rb') as f: 49 | altered_save_info = pickle.load(f) 50 | modified_character = None 51 | for character in original_save_info['character_strings']: 52 | if original_save_info['character_strings'][character].description != altered_save_info['character_strings'][character].description: 53 | modified_character = character 54 | break 55 | assert modified_character is not None 56 | prompt_modifier = '\n\nWrite a story with the above premise, setting, and characters.\n\nChapter 1\n\n' + modified_character # make it start the story talking about the right character. 
57 | original_prompt = original_save_info['infer_attributes_string'] + prompt_modifier 58 | altered_prompt = altered_save_info['infer_attributes_string'] + prompt_modifier 59 | 60 | print('ORIGINAL PROMPT') 61 | print(original_prompt) 62 | print('ALTERED CHARACTER') 63 | print(modified_character) 64 | print('ORIGINAL DESCRIPTION') 65 | print(original_save_info['character_strings'][modified_character].description) 66 | print('ALTERED DESCRIPTION') 67 | print(altered_save_info['character_strings'][modified_character].description) 68 | print('CONTRADICTED ORIGINAL') 69 | print(original_save_info['contradicted_part']) 70 | print('CONTRADICTED ALTERED') 71 | print(altered_save_info['contradicted_part']) 72 | 73 | while True: 74 | original_stories = generate_story(base_model, original_prompt, modified_character, original_save_info['contradicted_part'] + '\nFind evidence to support or refute this description.', num_samples=5) 75 | for original_story, score in original_stories: 76 | print('\n\nORIGINAL STORY') 77 | print(original_story) 78 | print('DPR SCORE') 79 | print(score) 80 | is_good = input('Is this story good? (y/n/s) ') 81 | if is_good == 'y' or is_good == 's': 82 | break 83 | if is_good != 'n': 84 | break 85 | if is_good != 'y': 86 | with open(os.path.join(original_stories_dir, str(i) + '.txt'), 'w') as f: 87 | f.write('SKIPPED') 88 | with open(os.path.join(altered_stories_dir, str(i) + '.txt'), 'w') as f: 89 | f.write('SKIPPED') 90 | continue 91 | while True: 92 | altered_stories = generate_story(base_model, altered_prompt, modified_character, original_save_info['contradicted_part'] + '\nFind evidence to support or refute this description.', num_samples=5) 93 | for altered_story, score in altered_stories: 94 | print('\n\nALTERED STORY') 95 | print(altered_story) 96 | print('DPR SCORE') 97 | print(score) 98 | is_good = input('Is this story good? 
(y/n/s) ') 99 | if is_good == 'y' or is_good == 's': 100 | break 101 | if is_good != 'n': 102 | break 103 | if is_good != 'y': 104 | with open(os.path.join(original_stories_dir, str(i) + '.txt'), 'w') as f: 105 | f.write('SKIPPED') 106 | with open(os.path.join(altered_stories_dir, str(i) + '.txt'), 'w') as f: 107 | f.write('SKIPPED') 108 | continue 109 | with open(os.path.join(original_stories_dir, str(i) + '.txt'), 'w') as f: 110 | f.write(original_story) 111 | with open(os.path.join(altered_stories_dir, str(i) + '.txt'), 'w') as f: 112 | f.write(altered_story) 113 | 114 | 115 | -------------------------------------------------------------------------------- /scripts/data/resample_character_descriptions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | import random 5 | from copy import deepcopy 6 | 7 | from story_generation.common.util import * 8 | from story_generation.common.data.data_util import add_data_args, load_dataset 9 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 10 | from story_generation.common.controller.controller_util import add_controller_args, load_controller 11 | from story_generation.plan_module.plan import generate_outline 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser = add_general_args(parser) 16 | parser = add_data_args(parser) 17 | parser = add_summarizer_args(parser) 18 | parser.add_argument('--load-dir', type=str, required=True, help='directory to load setups from') 19 | parser.add_argument('--save-dir', type=str, required=True, help='directory to save setups to') 20 | args = parser.parse_args() 21 | 22 | already_labeled_paths = [] 23 | if os.path.exists(os.path.join(args.save_dir, 'already_labeled.txt')): 24 | with open(os.path.join(args.save_dir, 'already_labeled.txt'), 'r') as f: 25 | for line in f: 26 | already_labeled_paths.append(line.strip()) 27 | 28 | base_model = load_summarizer(args) 29 | instruct_args = deepcopy(args) 30 | instruct_args.gpt3_model = 'text-' + args.gpt3_model + '-001' 31 | instruct_model = load_summarizer(instruct_args) 32 | 33 | os.makedirs(os.path.join(args.save_dir, 'original'), exist_ok=True) 34 | os.makedirs(os.path.join(args.save_dir, 'altered'), exist_ok=True) 35 | 36 | file_idx = len([x for x in os.listdir(os.path.join(args.save_dir, 'original'))]) # keep counting from wherever we left off, so we can resume progress as needed 37 | 38 | for fname in os.listdir(args.load_dir): 39 | path = os.path.join(args.load_dir, fname) 40 | if path in already_labeled_paths: 41 | continue 42 | 43 | with open(path, 'rb') as f: 44 | save_info = pickle.load(f) 45 | premise = save_info['premise'] 46 | setting = save_info['setting'] 47 | characters = save_info['characters'] 48 | character_strings = save_info['character_strings'] 49 | infer_attributes_string = save_info['infer_attributes_string'] 50 | 51 | print('\n\n\n\nORIGINAL') 52 | print(infer_attributes_string) 53 | 54 | remaining_characters_to_sample = list(character_strings.keys()) 55 | 56 | while True: 57 | if len(remaining_characters_to_sample) == 0: 58 | break 59 | resample_key = random.choice(remaining_characters_to_sample) 60 | remaining_characters_to_sample.remove(resample_key) 61 | context = infer_attributes_string.split('\n\n') 62 | prefix = '' 63 | for i, section in enumerate(context): 64 | if not section.startswith(resample_key): 65 | prefix += section + '\n\n' 66 | else: 67 | 
original_description = section[len(resample_key):] 68 | suffix = '\n\n' + '\n\n'.join(context[i+1:]) 69 | break 70 | prefix += resample_key 71 | contradiction_entries = resample_description(prefix, suffix, resample_key, original_description, num_samples=5) 72 | 73 | should_break = False 74 | for entry in contradiction_entries: 75 | if entry['contradiction_logprob'] < -1: 76 | print('REMAINING ENTRIES LOW LOGPROB; GO TO NEXT CHARACTER') 77 | break 78 | 79 | print('\n\n') 80 | print('RESAMPLED CHARACTER') 81 | print(resample_key) 82 | print('ORIGINAL DESCRIPTION') 83 | print(resample_key + original_description) 84 | print('NEW DESCRIPTION') 85 | print(resample_key + entry['new_description']) 86 | print('PREDICTED CONTRADICTION ORIGINAL') 87 | print(entry['contradicted_original']) 88 | print('PREDICTED CONTRADICTION NEW') 89 | print(entry['contradictory_completion']) 90 | print('PREDICTED CONTRADICTION LOGPROB') 91 | print(entry['contradiction_logprob']) 92 | 93 | is_good = input('Is this description good? [y/n/s]') 94 | if is_good.startswith('n'): 95 | continue 96 | if is_good.startswith('s'): 97 | should_break = True 98 | break 99 | 100 | new_characters = characters.replace(original_description, entry['new_description']) 101 | new_character_strings = deepcopy(character_strings) 102 | for entity in new_character_strings.values(): 103 | entity.reset_attributes() 104 | entity.description = entity.description.replace(original_description, entry['new_description']) 105 | new_infer_attributes_string = infer_attributes_string.replace(original_description, entry['new_description']) 106 | should_break = True 107 | break 108 | if should_break: 109 | break 110 | 111 | with open(os.path.join(args.save_dir, 'already_labeled.txt'), 'a') as f: 112 | f.write(path + '\n') 113 | 114 | if not is_good.startswith('y'): # s for skip, n for no 115 | continue 116 | 117 | save_info['contradiction_logprob'] = entry['contradiction_logprob'] 118 | save_info['contradicted_part'] = entry['contradicted_original'] 119 | with open(os.path.join(args.save_dir, 'original', str(file_idx) + '.pkl'), 'wb') as wf: 120 | pickle.dump(save_info, wf) 121 | 122 | save_info['characters'] = new_characters 123 | save_info['character_strings'] = new_character_strings 124 | save_info['infer_attributes_string'] = new_infer_attributes_string 125 | save_info['contradicted_part'] = entry['contradictory_completion'] 126 | with open(os.path.join(args.save_dir, 'altered', str(file_idx) + '.pkl'), 'wb') as wf: 127 | pickle.dump(save_info, wf) 128 | 129 | file_idx += 1 -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from enum import auto 4 | import os 5 | from copy import deepcopy 6 | import pickle 7 | from collections import defaultdict 8 | import multiprocessing as mp 9 | import random 10 | import string 11 | import logging 12 | import json 13 | 14 | import torch 15 | import Levenshtein 16 | import numpy as np 17 | from transformers import AutoTokenizer 18 | import openai 19 | from scipy.special import softmax 20 | 21 | from story_generation.edit_module.entity import * 22 | from story_generation.draft_module.beam_candidate import BeamCandidate 23 | from story_generation.plan_module.plan import * 24 | from story_generation.common.util import * 25 | from story_generation.common.data.data_util import add_data_args, load_dataset 26 | from 
story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer
from story_generation.common.summarizer.models.gpt3_summarizer import GPT3_SEP, GPT3_END
from story_generation.common.summarizer.models.opt_summarizer import OPTSummarizer
from story_generation.common.controller.controller_util import add_controller_args, load_controller
from story_generation.common.controller.loaders.alignment_loader import create_prefix_completion
from story_generation.common.data.split_paragraphs import *


if __name__=='__main__':
    parser = argparse.ArgumentParser() # parameter defaults are set to values used in paper
    parser = add_general_args(parser)
    parser = add_data_args(parser)
    parser = add_summarizer_args(parser)
    parser = add_controller_args(parser)

    parser.add_argument('--premise', type=str, default=None, help='Premise to use for generation')

    # SAVE/LOAD PLAN/LOGS
    parser.add_argument('--load-outline-file', type=str, help='load outline from this file')
    parser.add_argument('--save-outline-file', type=str, help='save outline to this file')
    parser.add_argument('--save-complete-file', type=str, help='save completed beam object to this file')
    parser.add_argument('--log-file', type=str, help='logging file', default=None)
    parser.add_argument('--log-level', type=int, default=22, help='logging level; decrease to 21 for full verbosity while suppressing stuff from openai and urllib')

    # ALTERNATE MODES / ABLATIONS
    parser.add_argument('--setup-only', action='store_true', help='exit after generating the premise/setup/outline')
    parser.add_argument('--no-attributes', action='store_true', help='do not infer attributes')
    parser.add_argument('--no-editor', action='store_true', help='do not use editor to edit text for detected contradictions')
    parser.add_argument('--no-planner', action='store_true', help='do not use the planner beyond the initial setup')

    # SEARCH SIZE / BEAM PARAMETERS
    parser.add_argument('--max-candidates', type=int, default=10, help='max number of candidates to generate at each step by each beam candidate')
    parser.add_argument('--max-beam-size', type=int, default=1, help='max number of beam candidates to generate at each step')
    parser.add_argument('--beam-max-difference', type=float, default=1, help='max difference between beam scores')

    # OUTLINE PARAMETERS
    parser.add_argument('--fixed-outline-length', type=int, default=3, help='fixed length for outline; use -1 for no fixed length')
    parser.add_argument('--outline-levels', type=int, default=1, help='num levels of hierarchy in outline')

    # CONTINUATION ALIGNMENT / LENGTH PARAMETERS
    parser.add_argument('--continuation-threshold', type=float, default=10000, help='if alignment score is worse by at least this much, move on to next outline point; 10000 basically turns this off')
    parser.add_argument('--max-continuation-substeps', type=int, default=4, help='max number of continuation candidates to generate at each step')
    parser.add_argument('--max-ending-continuations', type=int, default=3, help='max number of continuation steps for ending the story')

    # PROMPT PARAMETERS
    parser.add_argument('--previous-prompt-length', type=int, default=256, help='length of previously generated text in prompt')
    parser.add_argument('--max-entity-context-tokens', type=int, default=128, help='max number of tokens to use for entity context')
    parser.add_argument('--entity-description-max-length', type=int, default=48, help='max number of tokens to use per entity description')

    # GENERATION PARAMETERS
    parser.add_argument('--extension-method', type=str, choices=['gpt3', 'opt'], default='gpt3', help='generator to use for main story drafting')
    parser.add_argument('--repetition-penalty-weight', type=float, default=5, help='weight of repetition penalty')
    parser.add_argument('--draft-top-p', type=float, default=1, help='initial top_p for beam search')
    parser.add_argument('--plan-model-string', type=str, default='text-davinci-002', help='gpt3 model string to use in planning')
    parser.add_argument('--draft-model-string', type=str, default='davinci', help='gpt3 model string to use in extending story')
    parser.add_argument('--cut-sentence', action='store_true', default=False, help='cut incomplete sentence at end of generation')

    args = parser.parse_args()

    if args.save_complete_file is not None and os.path.exists(args.save_complete_file):  # guard against None, e.g. when running with --setup-only
        logging.log(25, 'save file already exists')
        sys.exit()

    if args.log_file is not None:  # logging.basicConfig handles filename=None, but makedirs wouldn't
        os.makedirs(os.path.dirname(args.log_file), exist_ok=True)
    logging.basicConfig(format='%(message)s', filename=args.log_file, level=args.log_level)

    gpt3_model = load_summarizer(args) # naming is a relic of some old preliminary experiments; it's just a gpt3 interface
    controllers = [load_controller(args, i) for i in range(len(args.controller))]
    assert all([controller.type == 'sentence' for controller in controllers])
    assert len(controllers) == 2, 'Re3 expects both a relevance and a coherence reranker; please see the example command in the README'

    opt_model = OPTSummarizer(args) if args.extension_method == 'opt' else None

    if args.load_outline_file is not None:
        if args.load_outline_file.endswith('.pkl'):
            save_info = load_plan_info(args.load_outline_file)
        else:
            with open(args.load_outline_file, 'r') as f:
                info = json.load(f)
            if 'character_strings' not in info:
                character_strings = {}
                for name, desc in info['character_info']:
                    character_strings[name] = Entity(name, desc, is_character=True)
                save_info = {'premise': info['premise'],
                             'setting': info['setting'],
                             'character_strings': character_strings,
                             'outline': None,
                             'outline_sections': info['outline_sections'],
                             'infer_attributes_string': info['premise'] + '\n\n' + info['setting'] + '\n\n' + '\n\n'.join([c.description for c in character_strings.values()])}
        if (not args.no_attributes and not args.no_editor and not args.no_planner): # fill in the attributes if we need them, if they're not already present in the save
            infer_initial_attributes_from_plan(save_info, gpt3_model)
    else:
        save_info = generate_plan_info(args, gpt3_model, model_string=args.plan_model_string)
    if args.save_outline_file is not None:
        os.makedirs(os.path.dirname(args.save_outline_file), exist_ok=True)
        with open(args.save_outline_file, 'wb') as f:
            pickle.dump(save_info, f)
    if args.setup_only:
        sys.exit()

    premise = save_info['premise']
    setting = save_info['setting']
    character_strings = save_info['character_strings']
    outline_sections = save_info['outline_sections']
    infer_attributes_string = premise + '\n\n' + setting + '\n\n' + '\n\n'.join([c.description for c in character_strings.values()])

    if args.no_attributes:
        all_entities_dict = {}
    else:
all_entities_dict = deepcopy(character_strings) 136 | all_entities_dict['Premise'] = Entity('Premise', description='Premise: ' + premise.strip(), is_character=False) 137 | all_entities_dict['Setting'] = Entity('Setting', description='Setting: ' + setting.strip(), is_character=False) 138 | 139 | all_paragraphs = [] 140 | previous_alignment_score = -1e8 141 | beam = [BeamCandidate(args, 142 | all_entities_dict, 143 | infer_attributes_string, 144 | model=gpt3_model, 145 | opt_model=opt_model, 146 | controllers=controllers)] 147 | if not args.no_editor and not args.no_planner: 148 | for candidate in beam: 149 | candidate.all_entities_dict = candidate.create_updated_entities('\n\n'.join(outline_sections)) 150 | if args.no_planner: # only get the premise 151 | for candidate in beam: 152 | initial_keys = list(candidate.all_entities_dict.keys()) 153 | for key in initial_keys: 154 | if key != 'Premise': 155 | del candidate.all_entities_dict[key] 156 | outline_sections[-1] = outline_sections[-1] + ' This is the end of the story.' 157 | 158 | # restart from intermediate if exists 159 | for i in range(len(outline_sections)-1, -1, -1): 160 | if os.path.exists(args.save_complete_file + '.temp' + str(i)): 161 | logging.log(25, 'found temp file for section ' + str(i) + ', restarting from there') 162 | with open(args.save_complete_file + '.temp' + str(i), 'rb') as f: 163 | beam = pickle.load(f) 164 | for b in beam: 165 | b.controllers = controllers 166 | b.model = gpt3_model 167 | b.opt_model = opt_model 168 | break 169 | 170 | for i in range(len(outline_sections)): 171 | logging.log(25, '\n\n\n\niteration at step ' + str(i)) 172 | outline_section = outline_sections[i] 173 | if outline_section in beam[0].outline_sections: 174 | logging.log(25, 'already generated this section') 175 | continue 176 | extensions = sum([b.extend(outline_section) for b in beam], []) 177 | extensions = sorted(extensions, key=lambda x: x.best_alignment_so_far, reverse=True) 178 | # pick the best extension plus up to max_beam_size that are below some alignment threshold 179 | new_beam = [extensions[0]] 180 | for extension in extensions[1:args.max_beam_size]: 181 | if extension.best_alignment_so_far > extensions[0].best_alignment_so_far - args.beam_max_difference: # variable beam size 182 | new_beam.append(extension) 183 | beam = new_beam 184 | for b in beam: 185 | b.condense_outline_sections(None) 186 | logging.log(25, '\n\n\n\nend of iteration ' + str(i)) 187 | for entity in beam[0].all_entities_dict.values(): 188 | logging.debug(entity) 189 | logging.log(23, beam[0].story()) 190 | 191 | # save intermediate 192 | with open(args.save_complete_file + '.temp' + str(i), 'wb') as f: 193 | for b in beam: 194 | b.controllers = None 195 | pickle.dump(beam, f) 196 | for b in beam: 197 | b.controllers = controllers 198 | if i > 0 and os.path.exists(args.save_complete_file + '.temp' + str(i-1)): 199 | os.remove(args.save_complete_file + '.temp' + str(i-1)) 200 | 201 | for i in range(len(beam)): 202 | should_continue = True 203 | num_attempts = 0 204 | while should_continue: 205 | logging.log(25, 'BEAM ' + str(i) + ' ENDING ATTEMPT ' + str(num_attempts)) 206 | beam[i], should_continue = beam[i].complete_ending() 207 | num_attempts += 1 208 | if num_attempts >= args.max_ending_continuations: 209 | break 210 | 211 | logging.log(25, '\n\n\n\nFINAL STORY') 212 | logging.log(25, beam[0].story()) 213 | if args.save_complete_file is not None: 214 | with open(args.save_complete_file, 'wb') as wf: 215 | for b in beam: 216 | b.controllers = None 217 | 
pickle.dump(beam, wf)
--------------------------------------------------------------------------------
/scripts/training/train_controller.py:
--------------------------------------------------------------------------------
import argparse

from story_generation.common.util import add_general_args
from story_generation.common.data.data_util import add_data_args, load_dataset
from story_generation.common.controller.controller_util import add_controller_args, load_controller

if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser = add_general_args(parser)
    parser = add_data_args(parser)
    parser = add_controller_args(parser)
    args = parser.parse_args()

    assert args.controller_save_dir is not None

    controller = load_controller(args, 0)
    dataset = load_dataset(args)
    dataset.shuffle('train')
    controller.fit(dataset)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup, find_packages

setup(name='story_generation', version='0.1', packages=find_packages())
--------------------------------------------------------------------------------
/story_generation/common/controller/controller_util.py:
--------------------------------------------------------------------------------
import torch

from story_generation.common.controller.models.longformer_classifier import LongformerClassifier

CONTROLLER_CHOICES=['longformer_classifier']
LOADER_CHOICES=['coherence', 'alignment']

def add_controller_args(parser):
    parser.add_argument('--controller', type=str, nargs='*', default=['longformer_classifier'], choices=CONTROLLER_CHOICES, help='model architecture')
    parser.add_argument('--controller-model-string', type=str, nargs='*', default=['none'], help='model string')
    parser.add_argument('--loader', type=str, nargs='*', default=['coherence'], choices=LOADER_CHOICES, help='loader for controller')
    parser.add_argument('--controller-save-dir', type=str, default=None, help='directory to save controller')
    parser.add_argument('--controller-load-dir', type=str, nargs='*', default=[''], help='directory to load controller')
    parser.add_argument('--controller-epochs', type=int, default=1, help='number of epochs for controller finetuning')
    parser.add_argument('--fudge-time-label-decay', type=float, default=1.0, help='discounting for label weights over time for controller training')
    parser.add_argument('--control-strength', type=float, nargs='*', default=None, help='strength of control for controller inference')
    parser.add_argument('--fudge-top-k', type=int, nargs='*', default=[100], help='top k for fudge inference')
    parser.add_argument('--controller-num-negatives', type=int, default=1, help='number of negative samples for controller contrastive training')
    parser.add_argument('--coherence-negative-categories', type=str, nargs='*', default=['other', 'repeat', 'shuffle'], help='types of negatives for coherence')
    parser.add_argument('--controller-margin', type=int, default=1, help='margin for controller contrastive training')
    parser.add_argument('--hierarchical-sentence-encoder', action='store_true', help='use hierarchical sentence encoder in sentence prefix completion classifier')
    parser.add_argument('--hierarchical-sentence-position-encodings', action='store_true',
help='use hierarchical sentence position encodings in sentence prefix completion classifier') 23 | parser.add_argument('--freeze-epochs', type=int, default=0, help='number of epochs to freeze pretrained backbone') 24 | parser.add_argument('--controller-lr', type=float, default=5e-5, help='learning rate for controller finetuning') 25 | parser.add_argument('--use-beginning-middle-tokens', action='store_true', help='use special beginning/middle tokens for coherence training') 26 | parser.add_argument('--coherence-eval-index', type=int, default=None, help='index of controller to use for coherence eval') 27 | parser.add_argument('--eval-only-controllers', type=int, nargs='*', default=[], help='indices of controllers to use for eval only') 28 | parser.add_argument('--sentence-coherence-control-mode', type=str, default=None, choices=['rerank', 'greedy-sentence', 'beam-sentence'], help='how to use sentence-level coherence controller for inference') 29 | return parser 30 | 31 | def load_controller(args, index): 32 | if args.controller[index] == 'longformer_classifier': 33 | controller = LongformerClassifier(args, index) 34 | if len(args.controller_load_dir[index]) > 0: 35 | controller.load(args.controller_load_dir[index]) 36 | else: 37 | raise NotImplementedError 38 | return controller -------------------------------------------------------------------------------- /story_generation/common/controller/loader_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from story_generation.common.controller.loaders.coherence_loader import CoherenceSplitLoader 4 | from story_generation.common.controller.loaders.alignment_loader import AlignmentSplitLoader 5 | 6 | def get_loader(loader_name, dataset, split, collate_fn, batch_size=32, append_mask_token=False, num_workers=20, tokenizer_model='roberta-base', time_label_decay=1, **kwargs): 7 | assert split in ['train', 'valid', 'test'] 8 | if loader_name == 'coherence': 9 | loader_class = CoherenceSplitLoader 10 | elif loader_name == 'alignment': 11 | loader_class = AlignmentSplitLoader 12 | else: 13 | raise NotImplementedError 14 | print('loading long short texts for data loader') 15 | contents, summaries = dataset.load_long_texts(split, split_paragraphs=False), dataset.load_short_texts(split, split_paragraphs=False) 16 | print('done loading long short texts') 17 | return torch.utils.data.DataLoader(loader_class(contents, summaries, tokenizer_model, append_mask_token=False, time_label_decay=time_label_decay, **kwargs), batch_size=batch_size, pin_memory=True, collate_fn=collate_fn, num_workers=num_workers) 18 | -------------------------------------------------------------------------------- /story_generation/common/controller/loaders/alignment_loader.py: -------------------------------------------------------------------------------- 1 | from calendar import c 2 | from concurrent.futures import process 3 | from lib2to3.pgen2 import token 4 | import random 5 | import os 6 | import pickle 7 | import math 8 | import string 9 | import json 10 | from collections import defaultdict, namedtuple 11 | import multiprocessing as mp 12 | from functools import partial 13 | 14 | import pandas as pd 15 | import numpy as np 16 | from tqdm import tqdm, trange 17 | import torch 18 | from transformers import AutoTokenizer 19 | 20 | from story_generation.common.data.split_paragraphs import split_paragraphs 21 | from story_generation.common.data.tree_util import START_OF_STORY, MIDDLE_OF_STORY 22 | 23 | def 
create_prefix_completion(content, summary): 24 | prefix = 'Full text:\n\n\n\n' + content + '\n\n\n\n' + 'Summary:\n\n\n\n' 25 | completion = 'Full text:\n\n\n\n' + content + '\n\n\n\n' + 'Summary:\n\n\n\n' + summary 26 | return prefix, completion 27 | 28 | class AlignmentSplitLoader(torch.utils.data.IterableDataset): 29 | def __init__(self, contents, summaries, tokenizer_model, append_mask_token=False, time_label_decay=1, **kwargs): 30 | super(AlignmentSplitLoader).__init__() 31 | if append_mask_token: 32 | raise NotImplementedError 33 | assert len(contents) == len(summaries) 34 | self.contents = contents 35 | self.summaries = summaries 36 | self.tokenizer_model = tokenizer_model 37 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) 38 | self.append_mask_token = append_mask_token 39 | self.time_label_decay = time_label_decay 40 | self.tokenized_info = kwargs['tokenized_info'] if 'tokenized_info' in kwargs else False 41 | self.negative_categories = kwargs['negative_categories'] if 'negative_categories' in kwargs else ['other', 'shuffle'] 42 | self.generate_negatives = kwargs['generate_negatives'] if 'generate_negatives' in kwargs else False 43 | if self.generate_negatives: 44 | assert 'num_negatives' in kwargs 45 | self.num_negatives = kwargs['num_negatives'] 46 | self.pos = 0 47 | 48 | def __len__(self): 49 | return len(self.contents) 50 | 51 | def __iter__(self): 52 | return self 53 | 54 | def __next__(self): 55 | increment = 1 56 | worker_info = torch.utils.data.get_worker_info() 57 | if worker_info is not None: # # in a worker process 58 | increment = worker_info.num_workers 59 | worker_id = worker_info.id 60 | if self.pos == 0: 61 | self.pos = worker_id 62 | valid = False 63 | while not valid: 64 | if self.pos >= len(self.contents): 65 | raise StopIteration 66 | summary = self.summaries[self.pos].split('\t') 67 | content = self.contents[self.pos].split('\t') 68 | assert len(summary) == len(content) 69 | selected_idx = random.randint(0, len(content)-1) 70 | 71 | possible_modes = ['true'] 72 | if 'other' in self.negative_categories: 73 | possible_modes.append('other') 74 | if len(content) > 1 and 'shuffle' in self.negative_categories: # no shuffle if only 1 paragraph 75 | possible_modes.append('shuffle') 76 | 77 | if self.generate_negatives: 78 | completions = set() 79 | all_examples = [] 80 | true_example, true_completion = self.create_example('true', content, summary, selected_idx) 81 | all_examples.append(true_example) 82 | completions.add(true_completion) 83 | for _ in range(self.num_negatives): 84 | while True: 85 | mode = random.choice(possible_modes) 86 | if mode == 'true': 87 | continue 88 | neg_example, neg_completion = self.create_example(mode, content, summary, selected_idx) 89 | if neg_completion not in completions: 90 | all_examples.append(neg_example) 91 | completions.add(neg_completion) 92 | break 93 | else: 94 | mode = random.choice(possible_modes) 95 | example, _ = self.create_example(mode, content, summary, selected_idx) 96 | all_examples = example 97 | 98 | valid = True 99 | self.pos += increment 100 | return all_examples 101 | 102 | def create_example(self, mode, content, summary, selected_idx): 103 | # in practice, for a given summary, you want to discriminate against different possible contents. so follow that setup here. 
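        # The three modes below keep the summary fixed and vary the content:
        #   'true'    - the chunk that was actually paired with this summary (label 1)
        #   'shuffle' - a different chunk from the same story (label 0)
        #   'other'   - a chunk sampled from a different story (label 0)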
104 | if mode == 'true': 105 | selected_content = content[selected_idx] 106 | label = np.array([1]) 107 | # create shuffled sentence example 108 | elif mode == 'shuffle': 109 | idx = selected_idx 110 | while idx == selected_idx: 111 | idx = random.randint(0, len(content)-1) 112 | selected_content = content[idx] 113 | label = np.array([0]) 114 | # create random other story example 115 | elif mode == 'other': 116 | selected_content = random.choice(self.contents[random.randint(0, len(self.contents)-1)].split('\t')) 117 | label = np.array([0]) 118 | selected_content = selected_content.replace("\n\n\n\nOne-sentence summary:", "") 119 | prefix, completion = create_prefix_completion(selected_content, summary[selected_idx]) 120 | tokenized_summary = [self.tokenizer.eos_token_id] + self.tokenizer.encode(summary[selected_idx]) if 'bart' in self.tokenizer_model else self.tokenizer.encode(summary[selected_idx]) 121 | tokenized_prefix = [self.tokenizer.eos_token_id] + self.tokenizer.encode(prefix) if 'bart' in self.tokenizer_model else self.tokenizer.encode(prefix) 122 | tokenized_completion = [self.tokenizer.eos_token_id] + self.tokenizer.encode(completion) if 'bart' in self.tokenizer_model else self.tokenizer.encode(completion) 123 | loss_mask = np.array([0 for _ in range(len(tokenized_prefix))] + [1 for _ in range(len(tokenized_completion) - len(tokenized_prefix))]) 124 | 125 | if self.tokenized_info: 126 | # prefix_info: 'input_ids', 'attention_mask' (all 1) 127 | prefix_info = self.tokenizer(selected_content, return_tensors='pt') 128 | # completion_info: 'input_ids', 'attention_mask' 129 | completion_info = self.tokenizer(summary[selected_idx], return_tensors='pt') 130 | # reversed_prefix_sentence_info: 'input_ids', 'attention_mask' 131 | content_sentences = split_paragraphs(selected_content, mode='sentence') 132 | reversed_prefix_sentence_info = self.tokenizer(list(reversed([s for s in content_sentences if len(s.strip()) > 0])), return_tensors='pt', padding=True) 133 | else: 134 | prefix_info, completion_info, reversed_prefix_sentence_info = None, None, None 135 | 136 | example = {'prefix': tokenized_completion, # you actually want to run on all of the completion, and then mask out the tokenized_prefix sometimes 137 | 'labels': label, 138 | 'summary': tokenized_summary, 139 | 'loss_mask': loss_mask, 140 | 'prefix_info': prefix_info, 141 | 'completion_info': completion_info, 142 | 'reversed_prefix_sentence_info': reversed_prefix_sentence_info, 143 | } 144 | 145 | return example, completion -------------------------------------------------------------------------------- /story_generation/common/controller/loaders/coherence_loader.py: -------------------------------------------------------------------------------- 1 | from calendar import c 2 | from concurrent.futures import process 3 | from lib2to3.pgen2 import token 4 | import random 5 | import os 6 | import pickle 7 | import math 8 | import string 9 | import json 10 | from collections import defaultdict, namedtuple 11 | import multiprocessing as mp 12 | from functools import partial 13 | 14 | import pandas as pd 15 | import numpy as np 16 | from tqdm import tqdm, trange 17 | import torch 18 | from transformers import AutoTokenizer 19 | 20 | from story_generation.common.data.split_paragraphs import split_paragraphs, group_chunks 21 | from story_generation.common.data.tree_util import START_OF_STORY, MIDDLE_OF_STORY 22 | 23 | class CoherenceSplitLoader(torch.utils.data.IterableDataset): 24 | def __init__(self, contents, summaries, tokenizer_model, 
append_mask_token=False, time_label_decay=1, **kwargs): 25 | super().__init__() 26 | if append_mask_token: 27 | raise NotImplementedError 28 | assert len(contents) == len(summaries) 29 | self.contents = contents 30 | self.summaries = summaries 31 | self.tokenizer_model = tokenizer_model 32 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) 33 | self.append_mask_token = append_mask_token 34 | self.time_label_decay = time_label_decay 35 | self.tokenized_info = kwargs.get('tokenized_info', False) 36 | self.use_special_tokens = kwargs.get('use_special_tokens', False) 37 | self.negative_categories = kwargs.get('negative_categories', ['other', 'repeat', 'shuffle']) 38 | self.generate_negatives = kwargs.get('generate_negatives', False) 39 | if self.generate_negatives: 40 | assert 'num_negatives' in kwargs 41 | self.num_negatives = kwargs['num_negatives'] 42 | self.pos = 0 43 | 44 | def __len__(self): 45 | return len(self.contents) 46 | 47 | def __iter__(self): 48 | return self 49 | 50 | def __next__(self): 51 | increment = 1 52 | worker_info = torch.utils.data.get_worker_info() 53 | if worker_info is not None: # in a worker process 54 | increment = worker_info.num_workers 55 | worker_id = worker_info.id 56 | if self.pos == 0: 57 | self.pos = worker_id 58 | valid = False 59 | while not valid: 60 | if self.pos >= len(self.contents): 61 | raise StopIteration 62 | try: 63 | summary = self.summaries[self.pos] 64 | tokenized_summary = self.tokenizer.encode(summary) 65 | base_content = self.contents[self.pos] 66 | 67 | # segment into sentences 68 | sentences = split_paragraphs(base_content, mode='sentence') 69 | if len(sentences) < 2: 70 | self.pos += increment 71 | continue 72 | sentences = group_chunks(sentences, max_chunk_length=200) # so actually paragraphs 73 | if self.use_special_tokens: 74 | sentences = [START_OF_STORY] + sentences 75 | if random.random() < 0.5: # chance to cutoff mid-story 76 | try: 77 | pre_cutoff = random.randint(2, len(sentences) - 1) # if omitting previous details, omit something other than start token 78 | except ValueError: 79 | self.pos += increment 80 | continue 81 | sentences = [MIDDLE_OF_STORY] + sentences[pre_cutoff:] 82 | # cutoff at some sentence 83 | try: 84 | cutoff = random.randint(1 if self.use_special_tokens else 0, len(sentences)-1) 85 | except ValueError: 86 | self.pos += increment 87 | continue 88 | prefix = ' '.join([s.strip() for s in sentences[:cutoff]]) 89 | if len(prefix.strip()) == 0: 90 | self.pos += increment 91 | continue 92 | tokenized_prefix = [self.tokenizer.eos_token_id] + self.tokenizer.encode(prefix) if 'bart' in self.tokenizer_model else self.tokenizer.encode(prefix) 93 | 94 | # select true, repetition, shuffled sentence, random other story 95 | possible_modes = ['true'] 96 | if 'other' in self.negative_categories: 97 | possible_modes.append('other') 98 | if cutoff > 0 and 'repeat' in self.negative_categories: # can't repeat if we don't have anything to repeat yet 99 | possible_modes.append('repeat') 100 | if cutoff < len(sentences) - 1 and 'shuffle' in self.negative_categories: # no shuffle if only 1 sentence left 101 | possible_modes.append('shuffle') 102 | 103 | if self.generate_negatives: 104 | completions = set() 105 | all_examples = [] 106 | true_example, true_completion = self.create_example('true', sentences, cutoff, prefix, tokenized_prefix, tokenized_summary) 107 |
all_examples.append(true_example) 108 | completions.add(true_completion) 109 | for _ in range(self.num_negatives): 110 | while True: 111 | mode = random.choice(possible_modes) 112 | if mode == 'true': 113 | continue 114 | neg_example, neg_completion = self.create_example(mode, sentences, cutoff, prefix, tokenized_prefix, tokenized_summary) 115 | if neg_completion not in completions: 116 | all_examples.append(neg_example) 117 | completions.add(neg_completion) 118 | break 119 | else: 120 | mode = random.choice(possible_modes) 121 | example, _ = self.create_example(mode, sentences, cutoff, prefix, tokenized_prefix, tokenized_summary) 122 | all_examples = example 123 | 124 | valid = True 125 | self.pos += increment 126 | except: 127 | self.pos += increment 128 | continue 129 | return all_examples 130 | 131 | def create_example(self, mode, sentences, cutoff, prefix, tokenized_prefix, tokenized_summary): 132 | if mode == 'true': 133 | separate_completion = sentences[cutoff] 134 | # completion = ' '.join([s.strip() for s in sentences[:cutoff+1]]) 135 | completion = prefix.strip() + ' ' + separate_completion 136 | label = np.array([1]) 137 | # create repetition example 138 | elif mode == 'repeat': 139 | separate_completion = random.choice(sentences[:cutoff]).strip() # random already used sentence 140 | completion = prefix.strip() + ' ' + separate_completion 141 | label = np.array([0]) 142 | # create shuffled sentence example 143 | elif mode == 'shuffle': 144 | separate_completion = random.choice(sentences[cutoff+1:]).strip() # random out of order sentence 145 | completion = prefix.strip() + ' ' + separate_completion 146 | label = np.array([0]) 147 | # create random other story example 148 | elif mode == 'other': 149 | other_content_sentences = [] 150 | while len(other_content_sentences) == 0: 151 | other_content = self.contents[random.randint(0, len(self.contents)-1)] 152 | other_content_sentences = split_paragraphs(other_content, mode='sentence') 153 | other_content_sentences = group_chunks(other_content_sentences, max_chunk_length=200) # so actually paragraphs 154 | separate_completion = random.choice(other_content_sentences).strip() 155 | completion = prefix.strip() + ' ' + separate_completion 156 | label = np.array([0]) 157 | # print('MODE', mode) 158 | # print('PREFIX', prefix) 159 | # print('SEPARATE COMPLETION', separate_completion) 160 | # import pdb; pdb.set_trace() 161 | tokenized_completion = [self.tokenizer.eos_token_id] + self.tokenizer.encode(completion) if 'bart' in self.tokenizer_model else self.tokenizer.encode(completion) 162 | loss_mask = np.array([0 for _ in range(len(tokenized_prefix))] + [1 for _ in range(len(tokenized_completion) - len(tokenized_prefix))]) 163 | 164 | if self.tokenized_info: 165 | # prefix_info: 'input_ids', 'attention_mask' (all 1) 166 | prefix_info = self.tokenizer(prefix, return_tensors='pt') 167 | # completion_info: 'input_ids', 'attention_mask' 168 | completion_info = self.tokenizer(separate_completion, return_tensors='pt') 169 | # reversed_prefix_sentence_info: 'input_ids', 'attention_mask' 170 | reversed_prefix_sentence_info = self.tokenizer(list(reversed([s for s in sentences[:cutoff] if len(s.strip()) > 0])), return_tensors='pt', padding=True) 171 | else: 172 | prefix_info, completion_info, reversed_prefix_sentence_info = None, None, None 173 | 174 | example = {'prefix': tokenized_completion, # you actually want to run on all of the completion, and then mask out the tokenized_prefix sometimes 175 | 'labels': label, 176 | 'summary': 
tokenized_summary, 177 | 'loss_mask': loss_mask, 178 | 'prefix_info': prefix_info, 179 | 'completion_info': completion_info, 180 | 'reversed_prefix_sentence_info': reversed_prefix_sentence_info, 181 | } 182 | 183 | return example, completion -------------------------------------------------------------------------------- /story_generation/common/controller/models/abstract_controller.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class AbstractController(ABC): 4 | @abstractmethod 5 | def __call__(self, lm_logits, full_decoder_input_ids, keyword_ids): 6 | pass 7 | 8 | @abstractmethod 9 | def reset_cache(self): 10 | pass 11 | 12 | @abstractmethod 13 | def fit(self, dataset): 14 | pass 15 | 16 | @abstractmethod 17 | def save(self, path): 18 | pass 19 | 20 | @abstractmethod 21 | def load(self, path): 22 | pass -------------------------------------------------------------------------------- /story_generation/common/controller/models/longformer_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from transformers import AdamW, AutoModelForSequenceClassification, AutoTokenizer 9 | 10 | from story_generation.common.controller.models.abstract_controller import AbstractController 11 | from story_generation.common.util import AverageMeter, pad_to_max_length, pad_mask 12 | from story_generation.common.controller.loader_util import get_loader 13 | from story_generation.common.data.split_paragraphs import split_paragraphs 14 | from story_generation.common.data.tree_util import START_OF_STORY, MIDDLE_OF_STORY 15 | 16 | 17 | class LongformerClassifier(AbstractController): 18 | def __init__(self, args, index): 19 | self.type = 'sentence' 20 | self.index = index 21 | self.model_string = args.controller_model_string[index] if args.controller_model_string[index] != 'none' else 'allenai/longformer-base-4096' 22 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | self.args = args 24 | self.trained = False 25 | self.loader_type = self.args.loader[self.index] 26 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_string, num_labels=2).to(self.device) 27 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_string) 28 | self.optimizer = AdamW(self.model.parameters(), lr=args.controller_lr) 29 | 30 | def reset_cache(self): 31 | pass 32 | 33 | @torch.no_grad() 34 | def evaluate_full_texts(self, texts, reduce='mean', add_prefix=True): 35 | # evaluate by prefix one sentence at a time 36 | all_scores = [] 37 | for text in texts: 38 | while '\n\n' in text: 39 | text = text.replace('\n\n', '\n') 40 | text = text.replace('\n', ' ').strip() # since that's how it's like when trained 41 | sentences = split_paragraphs(text, mode='sentence') 42 | current_text = [] 43 | if add_prefix and self.args.use_beginning_middle_tokens: 44 | current_text.append(START_OF_STORY) 45 | eval_texts, eval_sentences = [], [] 46 | for sentence in sentences: 47 | while len(self.tokenizer.encode(' '.join(current_text + [sentence]))) > self.tokenizer.model_max_length: 48 | if self.args.use_beginning_middle_tokens: 49 | current_text = [MIDDLE_OF_STORY] + current_text[2:] # delete the special start token, then one extra sentence 50 | if len(current_text) == 1: 51 | break 52 | else: 53 | current_text = current_text[1:] 54 | if 
len(current_text) == 0: 55 | break 56 | eval_texts.append(' '.join(current_text)) 57 | if len(self.tokenizer.encode(' '.join(current_text + [sentence]))) > self.tokenizer.model_max_length: # rare edge case of one super long sentence 58 | eval_sentences.append(self.tokenizer.decode(self.tokenizer.encode(' '.join(current_text + [sentence]))[:self.tokenizer.model_max_length])) 59 | else: 60 | eval_sentences.append(sentence.strip()) 61 | current_text.append(sentence.strip()) 62 | scores = self(eval_texts, eval_sentences) # should get scores or logprobs 63 | all_scores.append(scores.mean().item()) 64 | if reduce == 'mean': 65 | return np.mean(all_scores) 66 | elif reduce == 'none': 67 | return all_scores 68 | else: 69 | raise NotImplementedError 70 | 71 | @torch.no_grad() 72 | def __call__(self, texts, sentences): 73 | assert len(texts) == len(sentences) 74 | all_texts = [] 75 | for text, sentence in zip(texts, sentences): 76 | while '\n\n' in text: 77 | text = text.replace('\n\n', '\n') 78 | text = text.replace('\n', ' ').strip() # since that's how it's like when trained 79 | text = text + ' ' + sentence 80 | all_texts.append(text.strip()) 81 | batch = self.tokenizer(all_texts, return_tensors="pt", padding=True) 82 | batch = {k: v.to(self.device) for k, v in batch.items()} 83 | outputs = self.model(**batch) 84 | logits = outputs.logits 85 | positive_log_probs = F.softmax(logits, dim=-1)[:, 1].log() 86 | return positive_log_probs 87 | 88 | @torch.no_grad() 89 | def evaluate_overall_texts(self, texts): 90 | batch = self.tokenizer(texts, return_tensors="pt", padding=True) 91 | batch = {k: v.to(self.device) for k, v in batch.items()} 92 | outputs = self.model(**batch) 93 | logits = outputs.logits 94 | positive_log_probs = F.softmax(logits, dim=-1)[:, 1].log() 95 | return positive_log_probs 96 | 97 | def fit(self, dataset): 98 | best_val_loss = 1e8 99 | for epoch in range(self.args.controller_epochs): 100 | dataset.shuffle('train') 101 | train_loader = get_loader(self.args.loader[self.index], 102 | dataset, 103 | 'train', 104 | longformer_classifier_collate, 105 | batch_size=self.args.batch_size, 106 | append_mask_token=False, 107 | tokenizer_model=self.model_string, 108 | num_workers=self.args.num_workers, 109 | time_label_decay=self.args.fudge_time_label_decay, 110 | generate_negatives=True, 111 | num_negatives=self.args.controller_num_negatives, 112 | negative_categories=self.args.coherence_negative_categories, 113 | use_special_tokens=self.args.use_beginning_middle_tokens) 114 | loop = tqdm(train_loader, leave=True) 115 | loss_meter = AverageMeter('loss', ':6.4f') 116 | for batch in loop: 117 | # initialize calculated gradients (from prev step) 118 | self.optimizer.zero_grad() 119 | # pull all tensor batches required for training 120 | input_ids = batch['input_ids'].to(self.device) 121 | if input_ids.shape[0] < self.args.batch_size: # don't do the last batch if smaller 122 | continue 123 | attention_mask = batch['attention_mask'].to(self.device) 124 | labels = batch['labels'].to(self.device) 125 | # process 126 | outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels) 127 | loss = outputs.loss 128 | loss.backward() 129 | # update parameters 130 | self.optimizer.step() 131 | loss_meter.update(loss.detach().item(), input_ids.shape[0]) 132 | # print relevant info to progress bar 133 | loop.set_description(f'Epoch {epoch}') 134 | loop.set_postfix(loss=loss.item()) 135 | print('Training epoch {} average loss {}'.format(epoch, loss_meter.avg)) 136 | 137 | valid_loader = 
get_loader(self.args.loader[self.index], 138 | dataset, 139 | 'valid', 140 | longformer_classifier_collate, 141 | batch_size=self.args.batch_size, 142 | append_mask_token=False, 143 | tokenizer_model=self.model_string, 144 | num_workers=self.args.num_workers, 145 | time_label_decay=self.args.fudge_time_label_decay, 146 | generate_negatives=True, 147 | num_negatives=self.args.controller_num_negatives, 148 | negative_categories=self.args.coherence_negative_categories, 149 | use_special_tokens=self.args.use_beginning_middle_tokens) 150 | loop = tqdm(valid_loader, leave=True) 151 | loss_meter = AverageMeter('loss', ':6.4f') 152 | with torch.no_grad(): 153 | for batch in loop: 154 | # pull all tensor batches required for validation 155 | input_ids = batch['input_ids'].to(self.device) 156 | attention_mask = batch['attention_mask'].to(self.device) 157 | labels = batch['labels'].to(self.device) 158 | # process 159 | outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels) 160 | loss = outputs.loss 161 | loss_meter.update(loss.item(), input_ids.shape[0]) 162 | # print relevant info to progress bar 163 | loop.set_description(f'Epoch {epoch}') 164 | loop.set_postfix(loss=loss.item()) 165 | print('Validation epoch {} average loss {}'.format(epoch, loss_meter.avg)) 166 | if loss_meter.avg < best_val_loss: 167 | print('Found new best model. Saving...') 168 | best_val_loss = loss_meter.avg 169 | self.save(os.path.join(self.args.controller_save_dir, 'model_best.pth.tar')) 170 | 171 | self.trained = True 172 | 173 | def save(self, path): 174 | os.makedirs(os.path.dirname(path), exist_ok=True) 175 | torch.save({ 176 | 'state_dict': self.model.state_dict(), 177 | 'optimizer': self.optimizer.state_dict(), 178 | 'args': self.args 179 | }, path) 180 | 181 | def load(self, path): 182 | try: 183 | checkpoint = torch.load(path, map_location=self.device) 184 | except OSError: # path may instead be a directory containing model_best.pth.tar 185 | checkpoint = torch.load(os.path.join(path, 'model_best.pth.tar'), map_location=self.device) 186 | self.model.load_state_dict(checkpoint['state_dict']) 187 | self.optimizer.load_state_dict(checkpoint['optimizer']) 188 | self.trained = True 189 | 190 | 191 | def longformer_classifier_collate(batch): 192 | batch = sum(batch, []) 193 | lengths = torch.LongTensor([len(p['prefix']) for p in batch]) 194 | inputs = [torch.LongTensor(p['prefix']) for p in batch] 195 | input_ids = torch.stack(pad_to_max_length(inputs, 0), dim=0) 196 | attention_mask = pad_mask(lengths).permute(1, 0) 197 | labels = torch.stack([torch.from_numpy(p['labels']) for p in batch], dim=0) 198 | return {'input_ids': input_ids, 199 | 'attention_mask': attention_mask, 200 | 'labels': labels, 201 | 'lengths': lengths} -------------------------------------------------------------------------------- /story_generation/common/data/data_util.py: -------------------------------------------------------------------------------- 1 | from story_generation.common.data.datasets.writing_prompts import WritingPromptsDataset 2 | from story_generation.common.data.datasets.alignment import AlignmentDataset 3 | from story_generation.common.data.split_paragraphs import SPLIT_PARAGRAPH_MODES 4 | 5 | DATASET_CHOICES=['writing_prompts', 'alignment'] 6 | # if providing a csv, should give the full path to the csv in data-dir. only for inference.
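# A minimal usage sketch for these helpers (flag values are hypothetical, and the
# general args added elsewhere in main.py, e.g. --seed and --batch-size, are assumed present):
#   import argparse
#   parser = add_data_args(argparse.ArgumentParser())
#   args = parser.parse_args(['--dataset', 'alignment', '--data-dir', 'alignment_data.csv'])
#   dataset = load_dataset(args)
#   train_long_texts = dataset.load_long_texts(split='train')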
7 | 8 | def add_data_args(parser): 9 | parser.add_argument('--dataset', type=str, default='writing_prompts', choices=DATASET_CHOICES, help='dataset format') 10 | parser.add_argument('--data-dir', type=str, help='data directory') 11 | parser.add_argument('--split-sizes', type=float, nargs=3, default=[0.8, 0.1, 0.1], help='train/val/test proportions for datasets where not provided') 12 | parser.add_argument('--summarizer-prediction-split', type=str, default='valid', help='split to use for summarizer predictions') 13 | parser.add_argument('--limit', type=int, default=None, help='limit the number of examples') 14 | parser.add_argument('--length-limit', type=int, default=1000000, help='limit the number of words per example') 15 | parser.add_argument('--lower-length-limit', type=int, default=0, help='limit the number of words per example') 16 | parser.add_argument('--summary-length-limit', type=int, default=1000000, help='limit the number of words in the summary') 17 | parser.add_argument('--single-sentence-summary', action='store_true', help='use single sentence summary data only') 18 | parser.add_argument('--split-long-paragraph-mode', type=str, default='none', choices=SPLIT_PARAGRAPH_MODES, help='split long paragraph mode') 19 | parser.add_argument('--split-short-paragraph-mode', type=str, default='none', choices=SPLIT_PARAGRAPH_MODES, help='split short paragraph mode') 20 | parser.add_argument('--extra-keywords', type=int, default=0, help='max number of extra keywords from long content to add to short content') 21 | parser.add_argument('--hallucinate-keywords', action='store_true', default=False, help='hallucinate keywords from short content') 22 | parser.add_argument('--keyword-file', type=str, default='/home/yangk/data/glove/glove.840B.300d.vocab', help='file to load keywords from') 23 | parser.add_argument('--keyword-temperature', type=float, default=1.0, help='temperature for keyword sampling') 24 | parser.add_argument('--csv-column', type=str, help='column name to use as input for csv') 25 | parser.add_argument('--num-workers', type=int, default=20, help='number of workers for data loading') 26 | return parser 27 | 28 | def load_dataset(args): 29 | if args.dataset == 'writing_prompts': 30 | dataset = WritingPromptsDataset(args) 31 | elif args.dataset == 'alignment': 32 | dataset = AlignmentDataset(args) 33 | else: 34 | raise NotImplementedError 35 | return dataset -------------------------------------------------------------------------------- /story_generation/common/data/datasets/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class Dataset(ABC): 4 | @abstractmethod 5 | def __init__(self, args): 6 | pass 7 | 8 | @abstractmethod 9 | def shuffle(self, split, seed=None): 10 | pass 11 | 12 | @abstractmethod 13 | def load_long_texts(self, split='train', limit=None, split_paragraphs=False): 14 | pass 15 | 16 | @abstractmethod 17 | def load_short_texts(self, split='train', limit=None, split_paragraphs=False): 18 | pass 19 | 20 | @abstractmethod 21 | def pandas_format(self, split, long_name='content', short_name='title', limit=None): 22 | pass -------------------------------------------------------------------------------- /story_generation/common/data/datasets/alignment.py: -------------------------------------------------------------------------------- 1 | from calendar import c 2 | import random 3 | import os 4 | import csv 5 | import pickle 6 | import math 7 | import string 8 | from 
collections import defaultdict, namedtuple 9 | import multiprocessing as mp 10 | 11 | import numpy as np 12 | from tqdm import tqdm, trange 13 | import torch 14 | import pandas as pd 15 | 16 | from story_generation.common.data.datasets.abstract_dataset import Dataset 17 | from story_generation.common.data.split_paragraphs import split_texts 18 | 19 | 20 | class AlignmentDataset(Dataset): 21 | def __init__(self, args): 22 | print('loading data') 23 | random.seed(args.seed) 24 | self.args = args 25 | self.debug = args.debug 26 | self.batch_size = args.batch_size 27 | self.data_dir = args.data_dir 28 | 29 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 30 | 31 | self.splits = {} 32 | df = pd.read_csv(args.data_dir, delimiter=',', quotechar='"', skipinitialspace=True) 33 | text1 = [text.strip().replace('\n\n\n\nSummarize this passage.\n\n\n\n', '') for text in getattr(df, 'text1').tolist()][:args.limit] 34 | text2 = [text.strip() for text in getattr(df, 'text2').tolist()][:args.limit] 35 | # each item in text1 and text2 is actually a tab-separated list, different from other datasets 36 | # assume longer texts come first 37 | 38 | assert sum(args.split_sizes) == 1 39 | train_end = int(len(text1) * args.split_sizes[0]) 40 | valid_end = int(len(text1) * (args.split_sizes[0] + args.split_sizes[1])) 41 | self.splits['train'] = (text1[:train_end], text2[:train_end]) 42 | self.splits['valid'] = (text1[train_end:valid_end], text2[train_end:valid_end]) 43 | self.splits['test'] = (text1[valid_end:], text2[valid_end:]) 44 | 45 | print('done loading data') 46 | print('split sizes:') 47 | for key in ['train', 'valid', 'test']: 48 | print(key, len(self.splits[key])) 49 | 50 | def load_long_texts(self, split='train', limit=None, split_paragraphs=False): 51 | texts = self.splits[split][0] 52 | return split_texts(texts if limit is None else texts[:limit], mode=self.args.split_long_paragraph_mode if split_paragraphs else 'none') 53 | 54 | def load_short_texts(self, split='train', limit=None, split_paragraphs=False): 55 | texts = self.splits[split][1] 56 | return split_texts(texts if limit is None else texts[:limit], mode=self.args.split_short_paragraph_mode if split_paragraphs else 'none') 57 | 58 | def pandas_format(self, split, long_name='content', short_name='title', limit=None): 59 | raise NotImplementedError 60 | 61 | def shuffle(self, split, seed=None): 62 | assert split in ['train', 'valid', 'test'] 63 | if seed is not None: 64 | random.seed(seed) 65 | indices = list(range(len(self.splits[split][0]))) 66 | random.shuffle(indices) 67 | self.splits[split] = ([self.splits[split][0][i] for i in indices], [self.splits[split][1][i] for i in indices]) 68 | 69 | -------------------------------------------------------------------------------- /story_generation/common/data/datasets/writing_prompts.py: -------------------------------------------------------------------------------- 1 | from calendar import c 2 | import random 3 | import os 4 | import pickle 5 | import math 6 | from re import U 7 | import string 8 | from collections import defaultdict, namedtuple 9 | import multiprocessing as mp 10 | 11 | import numpy as np 12 | from tqdm import tqdm, trange 13 | import torch 14 | from transformers import BartTokenizerFast 15 | import pandas as pd 16 | 17 | from story_generation.common.data.datasets.abstract_dataset import Dataset 18 | from story_generation.common.data.split_paragraphs import split_texts 19 | # from story_generation.common.data.keywords import extract_keywords, hallucinate_keywords, KEYWORD_PROMPT, 
KEYWORD_SEP, KEYWORD_END 20 | 21 | 22 | def preprocess(texts): 23 | # remaining known edge cases: 24 | # people who use ' as quotation marks (definitely not me) 25 | all_fixed = [] 26 | for text in texts: 27 | if text.startswith('['): 28 | text = ']'.join(text.split(']')[1:]) # remove leading brackets 29 | fixed = '' 30 | text = text.replace(' ', '') 31 | text = text.replace(u'\u2018', "'").replace(u'\u2019', "'") 32 | text = text.replace(u'\u201d', '').replace(u'\u201c', '') 33 | while ' ' in text: 34 | text = text.replace(' ', ' ') 35 | text = text.replace('``', '"') 36 | text = text.replace("''", '"') 37 | tokens = text.split() 38 | fixed = '' 39 | in_quotes = False 40 | for tok in tokens: 41 | if tok == '': 42 | fixed += '\n' 43 | elif '' in tok: 44 | print(tok) 45 | elif tok.startswith('"') and not in_quotes: 46 | fixed += ' ' + tok 47 | elif all([c in string.punctuation for c in tok]): 48 | fixed += tok 49 | elif tok.startswith("'"): 50 | fixed += tok 51 | elif tok.startswith("n't"): 52 | fixed += tok 53 | # https://en.wikipedia.org/wiki/Wikipedia:List_of_English_contractions 54 | elif fixed.endswith("'") and tok in ['s', 't', 'll', 'd', 're', 'm', 'n', 've', 'cause', 'cept', 'ight', 'bout', 'ye', 'en', 'er', 'em', 'gainst', 'day', 'am', 'neath', 'clock', 'round', 'til', 'tis', 'tween', 'twere', 'twas', 'all', 'know']: 55 | fixed += tok 56 | else: 57 | if fixed.endswith('"') and in_quotes: 58 | fixed += tok 59 | else: 60 | fixed += ' ' + tok 61 | if '"' in tok: 62 | in_quotes = not in_quotes 63 | fixed = fixed.replace('( ', ' (').replace('[ ', ' [').replace('{ ', ' {') 64 | all_fixed.append(fixed) 65 | return tuple(all_fixed) 66 | 67 | 68 | class WritingPromptsDataset(Dataset): 69 | def __init__(self, args): 70 | print('loading data') 71 | random.seed(args.seed) 72 | self.args = args 73 | self.debug = args.debug 74 | self.batch_size = args.batch_size 75 | self.data_dir = args.data_dir 76 | 77 | tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-cnn') 78 | self.splits = {} 79 | for split in ['train', 'valid', 'test']: 80 | self.splits[split] = [] 81 | with open(os.path.join(args.data_dir, split + '.wp_target'), 'r') as rf1, \ 82 | open(os.path.join(args.data_dir, split + '.wp_source'), 'r') as rf2: 83 | contents = [line for line in rf1] 84 | summaries = [line for line in rf2] 85 | assert len(contents) == len(summaries) 86 | tokenized_contents = tokenizer.batch_encode_plus(contents, max_length=args.length_limit+1, truncation=True)['input_ids'] 87 | tokenized_summaries = tokenizer.batch_encode_plus(summaries, max_length=min(args.length_limit, args.summary_length_limit)+1, truncation=True)['input_ids'] 88 | for i in range(len(contents)): 89 | tokenized_content = tokenized_contents[i] 90 | if len(tokenized_content) > args.length_limit or len(tokenized_content) < args.lower_length_limit: 91 | continue 92 | tokenized_summary = tokenized_summaries[i] 93 | if len(tokenized_summary) > min(args.summary_length_limit, args.length_limit): 94 | continue 95 | content, summary = contents[i], summaries[i] 96 | # if args.extra_keywords > 0 and not args.hallucinate_keywords: 97 | # keywords = extract_keywords(content, max_keywords=args.extra_keywords, max_length=1) 98 | # summary = summary.strip() + KEYWORD_PROMPT + KEYWORD_SEP.join(keywords) + KEYWORD_END 99 | self.splits[split].append((content.strip(), summary.strip())) 100 | if args.limit is not None and len(self.splits[split]) >= args.limit: 101 | break 102 | if args.debug and len(self.splits[split]) >= 10: 103 | break 104 | 
os.environ["TOKENIZERS_PARALLELISM"] = "false" # avoid warnings later 105 | for split in ['train', 'valid', 'test']: 106 | with mp.Pool(20) as pool: 107 | self.splits[split] = pool.map(preprocess, self.splits[split]) 108 | 109 | print('done loading data') 110 | print('split sizes:') 111 | for key in ['train', 'valid', 'test']: 112 | print(key, len(self.splits[key])) 113 | 114 | def load_long_texts(self, split='train', limit=None, split_paragraphs=False): 115 | texts = [d[0] for d in self.splits[split]] 116 | return split_texts(texts if limit is None else texts[:limit], mode=self.args.split_long_paragraph_mode if split_paragraphs else 'none') 117 | 118 | def load_short_texts(self, split='train', limit=None, split_paragraphs=False): 119 | texts = [d[1] for d in self.splits[split]] 120 | texts = split_texts(texts if limit is None else texts[:limit], mode=self.args.split_short_paragraph_mode if split_paragraphs else 'none') 121 | if self.args.hallucinate_keywords and self.args.extra_keywords > 0: 122 | texts = [t + KEYWORD_PROMPT + KEYWORD_SEP.join(hallucinate_keywords(t, vocab_file=self.args.keyword_file, temperature=self.args.keyword_temperature, max_keywords=self.args.extra_keywords, sample=True)) + KEYWORD_END for t in texts] 123 | return texts 124 | 125 | def pandas_format(self, split, long_name='content', short_name='title', limit=None): 126 | pandas_data = self.splits[split] 127 | if limit is not None: 128 | pandas_data = pandas_data[:limit] 129 | return pd.DataFrame(pandas_data, columns=[long_name, short_name]) 130 | 131 | def shuffle(self, split, seed=None): 132 | assert split in ['train', 'valid', 'test'] 133 | if seed is not None: 134 | random.seed(seed) 135 | random.shuffle(self.splits[split]) -------------------------------------------------------------------------------- /story_generation/common/data/split_paragraphs.py: -------------------------------------------------------------------------------- 1 | from json import load 2 | import math 3 | 4 | import nltk 5 | nltk.download('punkt', quiet=True) 6 | nltk.download('stopwords', quiet=True) 7 | from nltk import tokenize 8 | from transformers import AutoTokenizer 9 | 10 | SPLIT_PARAGRAPH_MODES = ['none', 'newline', 'newline-filter', 'sentence'] 11 | split_paragraph_tokenizer = None 12 | 13 | def load_split_paragraph_tokenizer(): 14 | global split_paragraph_tokenizer 15 | if split_paragraph_tokenizer is None: 16 | split_paragraph_tokenizer = AutoTokenizer.from_pretrained('gpt2') 17 | return split_paragraph_tokenizer 18 | 19 | 20 | def cut_last_sentence(text): # remove possibly incomplete last sentence 21 | text = text.rstrip() + ' and' # possibly start a new sentence so we can delete it, if the last sentence is already complete and ended with a period 22 | last_sentence = split_paragraphs(text, mode='sentence')[-1].strip() # possibly incomplete, so strip it 23 | text = text.rstrip()[:len(text.rstrip()) - len(last_sentence)].rstrip() 24 | return text 25 | 26 | 27 | def cut_first_sentence(text): # remove possibly incomplete first sentence 28 | first_sentence = split_paragraphs(text, mode='sentence')[0].strip() # possibly incomplete, so strip it 29 | text = text.lstrip()[len(first_sentence):].lstrip() 30 | return text 31 | 32 | 33 | def split_paragraphs(text, mode='none'): 34 | """ 35 | Split a text into paragraphs. 
36 | """ 37 | if mode == 'none': 38 | return [text.strip()] 39 | elif mode == 'newline': 40 | while '\n\n' in text: 41 | text = text.replace('\n\n', '\n') 42 | return [s.strip() for s in text.split('\n')] 43 | elif mode == 'newline-filter': 44 | while '\n\n' in text: 45 | text = text.replace('\n\n', '\n') 46 | paragraphs = text.split('\n') 47 | return [p.strip() for p in paragraphs if len(p.split()) > 100] 48 | elif mode == 'sentence': 49 | while '\n\n' in text: 50 | text = text.replace('\n\n', '\n') 51 | return sum([[s.strip() for s in tokenize.sent_tokenize(t)] for t in text.split('\n')], []) 52 | else: 53 | raise NotImplementedError 54 | 55 | 56 | def group_chunks(sentences, max_chunk_length=200, sep=' ', strip=True): 57 | tokenizer = load_split_paragraph_tokenizer() 58 | tokenized_lengths = [len(s) for s in tokenizer.batch_encode_plus(sentences)['input_ids']] 59 | num_chunks = math.ceil(sum(tokenized_lengths) / max_chunk_length) 60 | length_partition = partition_list(tokenized_lengths, num_chunks) 61 | chunks = [] 62 | sentence_idx = 0 63 | for group in length_partition: 64 | chunk = [] 65 | for _ in range(len(group)): 66 | chunk.append(sentences[sentence_idx]) 67 | sentence_idx += 1 68 | chunks.append(sep.join(chunk)) 69 | assert sentence_idx == len(sentences) 70 | return [c.strip() for c in chunks] 71 | 72 | 73 | # following function is copied from https://stackoverflow.com/questions/35517051/split-a-list-of-numbers-into-n-chunks-such-that-the-chunks-have-close-to-equal 74 | #partition list a into k partitions 75 | def partition_list(a, k): 76 | #check degenerate conditions 77 | if k <= 1: return [a] 78 | if k >= len(a): return [[x] for x in a] 79 | #create a list of indexes to partition between, using the index on the 80 | #left of the partition to indicate where to partition 81 | #to start, roughly partition the array into equal groups of len(a)/k (note 82 | #that the last group may be a different size) 83 | partition_between = [] 84 | for i in range(k-1): 85 | partition_between.append((i+1)*len(a)//k) 86 | #the ideal size for all partitions is the total height of the list divided 87 | #by the number of paritions 88 | average_height = float(sum(a))/k 89 | best_score = None 90 | best_partitions = None 91 | count = 0 92 | no_improvements_count = 0 93 | #loop over possible partitionings 94 | while True: 95 | #partition the list 96 | partitions = [] 97 | index = 0 98 | for div in partition_between: 99 | #create partitions based on partition_between 100 | partitions.append(a[index:div]) 101 | index = div 102 | #append the last partition, which runs from the last partition divider 103 | #to the end of the list 104 | partitions.append(a[index:]) 105 | #evaluate the partitioning 106 | worst_height_diff = 0 107 | worst_partition_index = -1 108 | for p in partitions: 109 | #compare the partition height to the ideal partition height 110 | height_diff = average_height - sum(p) 111 | #if it's the worst partition we've seen, update the variables that 112 | #track that 113 | if abs(height_diff) > abs(worst_height_diff): 114 | worst_height_diff = height_diff 115 | worst_partition_index = partitions.index(p) 116 | #if the worst partition from this run is still better than anything 117 | #we saw in previous iterations, update our best-ever variables 118 | if best_score is None or abs(worst_height_diff) < best_score: 119 | best_score = abs(worst_height_diff) 120 | best_partitions = partitions 121 | no_improvements_count = 0 122 | else: 123 | no_improvements_count += 1 124 | #decide if we're done: if 
all our partition heights are ideal, or if 125 | #we haven't seen improvement in >5 iterations, or we've tried 100 126 | #different partitionings 127 | #the criteria to exit are important for getting a good result with 128 | #complex data, and changing them is a good way to experiment with getting 129 | #improved results 130 | if worst_height_diff == 0 or no_improvements_count > 5 or count > 100: 131 | return best_partitions 132 | count += 1 133 | #adjust the partitioning of the worst partition to move it closer to the 134 | #ideal size. the overall goal is to take the worst partition and adjust 135 | #its size to try and make its height closer to the ideal. generally, if 136 | #the worst partition is too big, we want to shrink the worst partition 137 | #by moving one of its ends into the smaller of the two neighboring 138 | #partitions. if the worst partition is too small, we want to grow the 139 | #partition by expanding the partition towards the larger of the two 140 | #neighboring partitions 141 | if worst_partition_index == 0: #the worst partition is the first one 142 | if worst_height_diff < 0: partition_between[0] -= 1 #partition too big, so make it smaller 143 | else: partition_between[0] += 1 #partition too small, so make it bigger 144 | elif worst_partition_index == len(partitions)-1: #the worst partition is the last one 145 | if worst_height_diff < 0: partition_between[-1] += 1 #partition too small, so make it bigger 146 | else: partition_between[-1] -= 1 #partition too big, so make it smaller 147 | else: #the worst partition is in the middle somewhere 148 | left_bound = worst_partition_index - 1 #the divider before the partition 149 | right_bound = worst_partition_index #the divider after the partition 150 | if worst_height_diff < 0: #partition too big, so make it smaller 151 | if sum(partitions[worst_partition_index-1]) > sum(partitions[worst_partition_index+1]): #the partition on the left is bigger than the one on the right, so make the one on the right bigger 152 | partition_between[right_bound] -= 1 153 | else: #the partition on the left is smaller than the one on the right, so make the one on the left bigger 154 | partition_between[left_bound] += 1 155 | else: #partition too small, make it bigger 156 | if sum(partitions[worst_partition_index-1]) > sum(partitions[worst_partition_index+1]): #the partition on the left is bigger than the one on the right, so make the one on the left smaller 157 | partition_between[left_bound] -= 1 158 | else: #the partition on the left is smaller than the one on the right, so make the one on the right smaller 159 | partition_between[right_bound] += 1 160 | 161 | 162 | def split_texts(texts, mode='none'): 163 | """ 164 | Split a list of texts into paragraphs. 165 | """ 166 | if mode == 'none': 167 | return texts 168 | return sum([split_paragraphs(text, mode=mode) for text in texts], []) 169 | -------------------------------------------------------------------------------- /story_generation/common/data/tree_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import json 4 | import random 5 | 6 | START_OF_STORY = '[Beginning of story]' 7 | MIDDLE_OF_STORY = '[Previous details omitted] ...' 
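# Together with the prompt constants below, these markers render the rolling-window
# reprompting template assembled by Node.context_expansion() and
# Node.recursive_context_prompt() further down, roughly:
#   Elaborate on the following passage:
#   """
#   [Beginning of story] <short text / summary so far>
#   """
#   Story context:
#   """
#   [Beginning of story] <long text so far>
#   """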
8 | SHORT_TEXT_PROMPT = 'Elaborate on the following passage:' 9 | LONG_TEXT_PROMPT = 'Story context:' 10 | PRE_TEXT_SEP = '\n\n"""\n' 11 | POST_TEXT_SEP = '\n"""\n\n' 12 | 13 | 14 | class Node: 15 | def __init__(self): 16 | self.children = [] 17 | self.parent = None 18 | self.long_text = None 19 | self.short_text = None 20 | 21 | def set_long_text(self, long_text): 22 | self.long_text = long_text 23 | return self 24 | 25 | def set_short_text(self, short_text): 26 | self.short_text = short_text 27 | return self 28 | 29 | def add_parent(self, parent=None): 30 | self.parent = Node() if parent is None else parent 31 | self.parent.children.append(self) 32 | return self.parent 33 | 34 | def add_child(self, child=None): 35 | child = Node() if child is None else child 36 | child.parent = self 37 | self.children.append(child) 38 | return child 39 | 40 | def max_depth_from_self(self): 41 | if len(self.children) == 0: 42 | return 0 43 | else: 44 | depths = [child.max_depth_from_self() for child in self.children] 45 | return max(depths) + 1 46 | 47 | def ordered_leaves(self): 48 | if len(self.children) == 0: 49 | return [self] 50 | else: 51 | leaves = [] 52 | for child in self.children: 53 | leaves += child.ordered_leaves() 54 | return leaves 55 | 56 | def traverse_subtree(self): 57 | yield self 58 | for child in self.children: 59 | yield from child.traverse_subtree() 60 | 61 | def depth(self): 62 | if self.parent is None: 63 | return 0 64 | else: 65 | return 1 + self.parent.depth() 66 | 67 | def nodes_at_depth(self, depth): 68 | if depth == 0: 69 | return [self] 70 | else: 71 | nodes = [] 72 | for child in self.children: 73 | nodes += child.nodes_at_depth(depth-1) 74 | return nodes 75 | 76 | def root(self): 77 | current_node = self 78 | while current_node.parent is not None: 79 | current_node = current_node.parent 80 | return current_node 81 | 82 | def previous_same_depth_node(self): 83 | same_depth_nodes = self.root().nodes_at_depth(self.depth()) 84 | idx = same_depth_nodes.index(self) 85 | return same_depth_nodes[idx-1] if idx > 0 else None 86 | 87 | def coherence_prefix(self): 88 | previous_same_depth_node = self.previous_same_depth_node() 89 | prefix = START_OF_STORY if (previous_same_depth_node is None or previous_same_depth_node.previous_same_depth_node() is None) else MIDDLE_OF_STORY 90 | # separate with ' ' for short text, '\n' for long text for sentences vs paragraphs, maybe 91 | previous_long_text = prefix + ' ' + previous_same_depth_node.long_text.strip() if previous_same_depth_node is not None else prefix 92 | return previous_long_text 93 | 94 | def context_expansion(self): 95 | previous_same_depth_node = self.previous_same_depth_node() 96 | prefix = START_OF_STORY if (previous_same_depth_node is None or previous_same_depth_node.previous_same_depth_node() is None) else MIDDLE_OF_STORY 97 | # separate with ' ' for short text, '\n' for long text for sentences vs paragraphs, maybe 98 | previous_short_text = prefix + ' ' + previous_same_depth_node.short_text.strip() + ' ' if previous_same_depth_node is not None else prefix + ' ' 99 | previous_long_text = prefix + ' ' + previous_same_depth_node.long_text.strip() + ' ' if previous_same_depth_node is not None else prefix + ' ' 100 | expansion = SHORT_TEXT_PROMPT + PRE_TEXT_SEP + previous_short_text + self.short_text.strip() + POST_TEXT_SEP + LONG_TEXT_PROMPT + PRE_TEXT_SEP + previous_long_text + self.long_text.strip() + POST_TEXT_SEP 101 | return expansion 102 | 103 | def recursive_context_prompt(self): 104 | previous_same_depth_node = 
self.previous_same_depth_node() 105 | prefix = START_OF_STORY if (previous_same_depth_node is None or previous_same_depth_node.previous_same_depth_node() is None) else MIDDLE_OF_STORY 106 | # separate with ' ' for short text, '\n' for long text for sentences vs paragraphs, maybe 107 | previous_short_text = prefix + ' ' + previous_same_depth_node.short_text.strip() + ' ' if previous_same_depth_node is not None else prefix + ' ' 108 | previous_long_text = prefix + ' ' + previous_same_depth_node.long_text.strip() if previous_same_depth_node is not None else prefix 109 | prompt = SHORT_TEXT_PROMPT + PRE_TEXT_SEP + previous_short_text + self.short_text.strip() + POST_TEXT_SEP + LONG_TEXT_PROMPT + PRE_TEXT_SEP + previous_long_text 110 | 111 | current_node = self 112 | while current_node.parent is not None: 113 | current_node = current_node.parent 114 | prompt = current_node.context_expansion() + prompt 115 | return prompt 116 | 117 | def context_completion(self): 118 | return ' ' + self.long_text.strip() + POST_TEXT_SEP 119 | 120 | def full_text(self, joiner=' '): 121 | return joiner.join([leaf.long_text for leaf in self.ordered_leaves()]) 122 | 123 | 124 | def save_gpt3_prompts(trees, path, shuffle=True): 125 | all_data_dicts = [] 126 | for root in trees: 127 | for node in root.traverse_subtree(): 128 | prompt = node.recursive_context_prompt() 129 | completion = node.context_completion() 130 | all_data_dicts.append({'prompt': prompt, 'completion': completion}) 131 | if shuffle: 132 | random.shuffle(all_data_dicts) 133 | os.makedirs(os.path.dirname(path), exist_ok=True) 134 | with open(path, 'w') as wf: 135 | for data_dict in all_data_dicts: 136 | wf.write(json.dumps(data_dict) + '\n') 137 | 138 | 139 | def save_trees(trees, path, mode='all', replace_newline=True, joiner='***', short_long_sep='@@@'): 140 | assert mode in ['all', 'final_long', 'final_short'] 141 | num_iterations = max([root.max_depth_from_self() for root in trees]) 142 | os.makedirs(os.path.dirname(path), exist_ok=True) 143 | with open(path, 'w') as wf: 144 | writer = csv.writer(wf) 145 | if mode == 'final_short': 146 | writer.writerow(['final_short']) 147 | for root in trees: 148 | writer.writerow([root.short_text.replace('\n', '\\n') if replace_newline else root.short_text]) 149 | elif mode == 'final_long': 150 | writer.writerow(['final_long']) 151 | for root in trees: 152 | long_text = root.full_text(joiner=joiner) 153 | long_text = long_text.replace('\n', '\\n') if replace_newline else long_text 154 | writer.writerow([long_text]) 155 | elif mode == 'all': 156 | writer.writerow(['iter' + str(i) for i in range(num_iterations+1)]) 157 | for root in trees: 158 | iters = [] 159 | current_nodes = [root] 160 | while len(current_nodes) > 0: 161 | iters.append(joiner.join([node.short_text for node in current_nodes]) + short_long_sep + joiner.join([node.long_text for node in current_nodes])) 162 | current_nodes = sum([node.children for node in current_nodes], []) 163 | writer.writerow([t.replace('\n', '\\n') for t in iters] if replace_newline else iters) -------------------------------------------------------------------------------- /story_generation/common/summarizer/models/abstract_summarizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class AbstractSummarizer(ABC): 4 | @abstractmethod 5 | def __call__(self, texts): 6 | pass 7 | 8 | @abstractmethod 9 | def fit(self, dataset): 10 | pass 11 | 12 | @abstractmethod 13 | def save(self, path): 14 | 
pass 15 | 16 | @abstractmethod 17 | def load(self, path): 18 | pass 19 | 20 | @abstractmethod 21 | def add_controller(self, controller): 22 | pass -------------------------------------------------------------------------------- /story_generation/common/summarizer/models/gpt3_summarizer.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import logging 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from transformers import AutoTokenizer 9 | import openai 10 | 11 | from story_generation.common.summarizer.models.abstract_summarizer import AbstractSummarizer 12 | from story_generation.common.data.split_paragraphs import split_paragraphs, cut_last_sentence 13 | 14 | GPT3_SEP = '\n\n###\n\n' 15 | GPT3_END = 'THE END.' 16 | PRETRAINED_MODELS = ['ada', 'babbage', 'curie', 'davinci', 'text-ada-001', 'text-babbage-001', 'text-curie-001', 'text-davinci-001', 'text-davinci-002'] 17 | 18 | class GPT3Summarizer(AbstractSummarizer): 19 | def __init__(self, args): 20 | assert args.gpt3_model is not None 21 | self.model = args.gpt3_model 22 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2") 23 | self.args = args 24 | self.controller = None 25 | 26 | @torch.no_grad() 27 | def __call__(self, texts, generation_max_length=None, top_p=None, temperature=None, coherence_prefixes=None, retry_until_success=True, verbose=False, nodes=None, stop=None, modify_prompt=False, logit_bias={}, num_completions=1, cut_sentence=False, model_string=None): 28 | assert type(texts) == list 29 | if modify_prompt: 30 | logging.warning('modifying prompt for summarization') 31 | if model_string is None: 32 | logging.warning('model string not provided, using default model') 33 | if self.controller is None: 34 | return self._call_helper(texts, generation_max_length=generation_max_length, top_p=top_p, temperature=temperature, retry_until_success=retry_until_success, nodes=nodes, stop=stop, modify_prompt=modify_prompt, logit_bias=logit_bias, num_completions=num_completions, cut_sentence=cut_sentence, model_string=model_string) 35 | else: 36 | assert coherence_prefixes is not None and len(coherence_prefixes) == len(texts) 37 | if self.args.sentence_coherence_control_mode == 'rerank': 38 | generations = [] 39 | for text, prefix in zip(texts, coherence_prefixes): 40 | candidates = [] 41 | for _ in range(self.args.summarizer_beam_size): 42 | candidates += self._call_helper([text], generation_max_length=generation_max_length, top_p=top_p, temperature=temperature, retry_until_success=retry_until_success, nodes=nodes, stop=stop, modify_prompt=modify_prompt, logit_bias=logit_bias, num_completions=num_completions, cut_sentence=cut_sentence, model_string=model_string) 43 | coherence_scores = self.controller.evaluate_full_texts([prefix + c for c in candidates], reduce='none', add_prefix=False) 44 | generations.append(candidates[np.argmax(coherence_scores)]) 45 | return generations 46 | elif self.args.sentence_coherence_control_mode == 'greedy-sentence': # match the argparse choice string 47 | raise NotImplementedError 48 | elif self.args.sentence_coherence_control_mode == 'beam-sentence': # match the argparse choice string 49 | raise NotImplementedError 50 | else: 51 | raise NotImplementedError 52 | 53 | @torch.no_grad() 54 | def _call_helper(self, texts, generation_max_length=None, top_p=None, temperature=None, retry_until_success=True, nodes=None, stop=None, modify_prompt=False, logit_bias={}, num_completions=1, cut_sentence=False, model_string=None): 55 | given_stop = stop 56 | if nodes is not None: 57
| assert self.args.expander 58 | assert len(nodes) == len(texts) 59 | 60 | outputs = [] 61 | for i in range(len(texts)): 62 | text = texts[i] 63 | if not modify_prompt: 64 | prompt = text 65 | stop = None if self.model in PRETRAINED_MODELS else GPT3_END 66 | elif nodes is not None: 67 | node = nodes[i] 68 | prompt = node.recursive_context_prompt() 69 | stop = '"""' 70 | else: 71 | if self.model in PRETRAINED_MODELS: 72 | if self.args.expander: # generate 73 | prompt = 'A summary of a story:\n"""\n' + text.strip() + '\n"""\nFull version:\n"""\n' 74 | else: # summarize 75 | prompt = 'A passage from a story:\n"""\n' + text.strip() + '\n"""\nOne-sentence summary:\n"""\n' 76 | stop = '"""' 77 | else: 78 | prompt = text.strip() + GPT3_SEP # finetuned model 79 | stop = GPT3_END 80 | if given_stop is not None: 81 | stop = given_stop 82 | 83 | retry = True 84 | num_fails = 0 85 | while retry: 86 | try: 87 | context_length = len(self.tokenizer.encode(prompt)) 88 | if context_length > self.args.max_context_length: 89 | logging.warning('context length {} exceeded artificial context length limit {}'.format(context_length, self.args.max_context_length)) 90 | time.sleep(5) # similar interface to gpt3 query failing and retrying 91 | assert False 92 | if generation_max_length is None: 93 | generation_max_length = min(self.args.generation_max_length, self.args.max_context_length - context_length) 94 | engine = self.model if model_string is None else model_string 95 | if engine == 'text-davinci-001': 96 | engine = 'text-davinci-002' # update to latest version 97 | if engine in PRETRAINED_MODELS: 98 | logging.log(21, 'PROMPT') 99 | logging.log(21, prompt) 100 | logging.log(21, 'MODEL STRING: ' + (self.model if model_string is None else model_string)) 101 | completion = openai.Completion.create( 102 | engine=engine, 103 | prompt=prompt, 104 | max_tokens=generation_max_length, 105 | temperature=temperature if temperature is not None else self.args.summarizer_temperature, 106 | top_p=top_p if top_p is not None else self.args.summarizer_top_p, 107 | frequency_penalty=self.args.summarizer_frequency_penalty, 108 | presence_penalty=self.args.summarizer_presence_penalty, 109 | stop=stop, 110 | logit_bias=logit_bias, 111 | n=num_completions) 112 | else: 113 | logging.log(21, 'PROMPT') 114 | logging.log(21, prompt) 115 | logging.log(21, 'MODEL STRING: ' + (self.model if model_string is None else model_string)) 116 | completion = openai.Completion.create( 117 | model=engine, 118 | prompt=prompt, 119 | max_tokens=generation_max_length, 120 | temperature=temperature if temperature is not None else self.args.summarizer_temperature, 121 | top_p=top_p if top_p is not None else self.args.summarizer_top_p, 122 | frequency_penalty=self.args.summarizer_frequency_penalty, 123 | presence_penalty=self.args.summarizer_presence_penalty, 124 | stop=stop, 125 | logit_bias=logit_bias, 126 | n=num_completions) 127 | retry = False 128 | except Exception as e: 129 | logging.warning(str(e)) 130 | retry = retry_until_success 131 | num_fails += 1 132 | if num_fails > 20: 133 | raise e 134 | if retry: 135 | logging.warning('retrying...') 136 | time.sleep(num_fails) 137 | outputs += [completion['choices'][j]['text'] for j in range(num_completions)] 138 | if cut_sentence: 139 | for i in range(len(outputs)): 140 | if len(outputs[i].strip()) > 0: 141 | outputs[i] = cut_last_sentence(outputs[i]) 142 | engine = self.model if model_string is None else model_string 143 | logging.log(21, 'OUTPUTS') 144 | logging.log(21, str(outputs)) 145 | logging.log(21, 'GPT3 CALL'
+ ' ' + engine + ' ' + str(len(self.tokenizer.encode(texts[0])) + sum([len(self.tokenizer.encode(o)) for o in outputs]))) 146 | return outputs 147 | 148 | @torch.no_grad() 149 | def next_sentence(self, text, stop=None): 150 | stop = ['.', '!', '?'] if stop is None else ['.', '!', '?'] + [stop] 151 | return self(text, stop=stop, modify_prompt=False) 152 | 153 | def generate_with_prompt_repetition_penalty(self, prompt, penalty=0, stop=None, bias_component=''): 154 | logit_bias = {key: -penalty for key in set(self.tokenizer.encode(bias_component))} if penalty != 0 else {} 155 | return self._call_helper([prompt], stop=stop, modify_prompt=False, logit_bias=logit_bias) 156 | 157 | def fit(self, dataset): 158 | pass 159 | 160 | def save(self, path): 161 | pass 162 | 163 | def load(self, path): 164 | pass 165 | 166 | def add_controller(self, controller): 167 | assert len(controller) == 1 168 | self.controller = controller[0] 169 | assert self.controller.type == 'sentence' -------------------------------------------------------------------------------- /story_generation/common/summarizer/models/opt_summarizer.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | import time 3 | import logging 4 | import math 5 | from copy import deepcopy 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | from transformers import AutoTokenizer 11 | from scipy.special import expm1 12 | 13 | from story_generation.common.summarizer.models.abstract_summarizer import AbstractSummarizer 14 | from story_generation.common.data.split_paragraphs import split_paragraphs, cut_last_sentence 15 | from story_generation.common.util import * 16 | 17 | OPT_LOGPROBS_MAX_BS = 4 18 | 19 | # tokenizer ids for tokens which contain " or ' which causes some problematic punc/spacing sometimes. prefer to use e.g. “ and ” instead in stories. 
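# A sketch of how such a ban list could be regenerated (illustrative only; this is not
# necessarily how the ids below were produced, and it assumes the same OPT tokenizer
# loaded in OPTSummarizer.__init__):
#   tok = AutoTokenizer.from_pretrained('facebook/opt-30b', use_fast=False)
#   banned = sorted(i for i in range(tok.vocab_size)
#                   if '"' in tok.decode([i]) or "'" in tok.decode([i]))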
20 | OPT_MACHINE_QUOTE_IDS = [22, 60, 72, 113, 845, 1297, 1917, 2901, 4332, 4805, 6697, 7862, 8070, 9957, 10076, 11227, 13198, 14025, 14220, 16844, 17495, 17523, 18456, 18653, 19207, 19651, 22896, 23962, 24095, 24337, 24464, 24681, 24992, 25718, 27223, 28553, 28578, 30550, 30697, 30831, 31051, 33525, 34133, 35290, 35347, 36856, 37008, 37637, 39058, 39732, 40021, 40389, 40635, 41039, 41066, 41758, 42078, 42248, 42255, 42777, 43012, 43074, 43101, 43775, 43809, 44065, 44374, 44690, 44717, 44757, 44926, 45333, 45390, 45751, 45863, 45894, 46150, 46294, 46353, 46469, 46479, 46481, 46671, 46679, 47096, 47460, 47770, 47919, 48110, 48149, 48298, 48336, 48474, 48615, 48742, 48789, 48805, 48880, 48893, 49070, 49177, 49189, 49293, 49329, 49434, 49509, 49608, 49643, 49667, 49713, 49721, 49738, 49761, 49778, 49784, 49799, 49817, 49849, 49852, 49853, 49871, 49900, 49923, 49991, 49995, 50000, 50003, 50020, 50154, 50184, 50206] + [18, 75, 108, 128, 214, 348, 437, 581, 955, 1017, 1598, 2652, 3934, 6600, 9376, 9957, 10076, 10559, 12801, 13373, 13864, 17809, 19651, 22896, 23500, 24095, 24464, 24992, 27144, 27645, 30171, 31509, 32269, 35347, 35661, 41667, 41734, 41833, 44162, 44294, 44403, 45393, 45803, 46117, 46150, 46250, 46495, 47033, 47052, 47429, 47579, 48694, 48759, 48817, 49201, 49515, 49525, 49690, 49836, 49888] 21 | 22 | class OPTSummarizer(AbstractSummarizer): 23 | def __init__(self, args): 24 | assert args.alpa_url is not None 25 | if args.alpa_url.startswith('http'): 26 | alpa_url = args.alpa_url 27 | else: 28 | with open(args.alpa_url, 'r') as rf: 29 | alpa_hostname = rf.read().strip().split()[0] 30 | alpa_url = f'http://{alpa_hostname}:{args.alpa_port}' 31 | self.client = AlpaOPTClient(url=alpa_url, api_key=args.alpa_key) 32 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) 33 | self.tokenizer.add_bos_token = False 34 | self.args = args 35 | self.controller = None 36 | 37 | @torch.no_grad() 38 | def generate(self, 39 | prompt, 40 | generation_max_length=None, 41 | top_p=None, 42 | temperature=None, 43 | retry_until_success=True, 44 | verbose=False, 45 | stop=None, 46 | num_completions=1, 47 | cut_sentence=False): 48 | assert type(prompt) == str 49 | logging.log(21, 'OPT GENERATION PROMPT') 50 | logging.log(21, prompt) 51 | if generation_max_length is None: 52 | generation_max_length = self.args.generation_max_length 53 | if top_p is None: 54 | top_p = self.args.summarizer_top_p 55 | if temperature is None: 56 | temperature = self.args.summarizer_temperature 57 | if stop is None: 58 | stop = [] 59 | if type(stop) == str: 60 | stop = [stop] 61 | retry = True 62 | num_fails = 0 63 | while retry: 64 | try: 65 | completions = self.client.completions([prompt for _ in range(num_completions)], temperature=temperature, top_p=top_p, max_tokens=generation_max_length) 66 | if 'choices' not in completions: 67 | import pdb; pdb.set_trace() 68 | completions = [entry['text'][len(prompt):] for entry in completions['choices']] 69 | retry = False 70 | except Exception as e: 71 | logging.warning(str(e)) 72 | retry = retry_until_success 73 | num_fails += 1 74 | if retry: 75 | logging.warning('retrying...') 76 | time.sleep(num_fails) 77 | logging.warning('old alpa url: ' + self.client.logprobs_url + ' at time ' + str(time.ctime())) 78 | self.client.refresh_url(self.args.alpa_url, self.args.alpa_port) 79 | logging.warning('new alpa url: ' + self.client.logprobs_url) 80 | for i, text in enumerate(completions): 81 | for s in stop: 82 | if s in text: 83 | text = text[:text.index(s)] 84 | 
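# write the completion, truncated at its first stop string above, back into the results list: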
completions[i] = text 85 | if cut_sentence: 86 | completions = [cut_last_sentence(text) for text in completions] 87 | return completions 88 | 89 | @torch.no_grad() 90 | def __call__(self, 91 | texts, 92 | controllers=None, 93 | controller_initial_texts=None, 94 | control_strengths=None, 95 | generation_max_length=None, 96 | top_p=1, 97 | top_k=100, 98 | temperature=None, 99 | retry_until_success=True, 100 | verbose=False, 101 | stop=None, 102 | logit_bias=None, 103 | exclude_strings=None, 104 | num_completions=1, 105 | frequency_penalty=None, 106 | presence_penalty=None, 107 | cut_sentence=False, 108 | bias_machine_quotes=True, 109 | logit_bias_decay=1): 110 | assert type(texts) == list 111 | if logit_bias is None: 112 | logit_bias = {} 113 | assert controller_initial_texts is not None 114 | assert type(controller_initial_texts) == list and len(controllers) == len(control_strengths) and len(controller_initial_texts) == len(texts) 115 | if generation_max_length is None: 116 | generation_max_length = self.args.generation_max_length 117 | if top_p is None: 118 | top_p = self.args.summarizer_top_p 119 | if top_k is None: 120 | top_k = self.args.summarizer_top_k 121 | if temperature is None: 122 | temperature = self.args.summarizer_temperature 123 | if frequency_penalty is None: 124 | frequency_penalty = self.args.summarizer_frequency_penalty 125 | if presence_penalty is None: 126 | presence_penalty = self.args.summarizer_presence_penalty 127 | if stop is None: 128 | stop = [] 129 | if type(stop) == str: 130 | stop = [stop] 131 | exclude_tokens = set() 132 | if exclude_strings is not None: 133 | for s in exclude_strings: 134 | exclude_tokens.update(self.tokenizer.encode(s[0].upper() + s[1:])) 135 | exclude_tokens.update(self.tokenizer.encode(' ' + s[0].upper() + s[1:])) 136 | exclude_tokens.update(self.tokenizer.encode(s[0].lower() + s[1:])) 137 | exclude_tokens.update(self.tokenizer.encode(' ' + s[0].lower() + s[1:])) 138 | sentences = [] 139 | for text_idx, text in enumerate(texts): 140 | context_length = len(self.tokenizer.encode(text)) 141 | if context_length > self.args.max_context_length - generation_max_length: 142 | logging.warning('context length' + ' ' + str(context_length) + ' ' + 'exceeded artificial context length limit' + ' ' + str(self.args.max_context_length - generation_max_length)) 143 | # time.sleep(5) # similar interface to gpt3 query failing and retrying 144 | print('TOO LONG CONTEXT: ' + text) 145 | print('CONTEXT LENGTH:' + str(context_length)) 146 | current_controller_initial_texts = controller_initial_texts[text_idx] 147 | assert len(current_controller_initial_texts) == len(controllers) 148 | logging.log(21, 'OPT CALL PROMPT') 149 | logging.log(21, text) 150 | device = controllers[0].device if (controllers is not None and len(controllers) > 0) else ('cuda' if torch.cuda.is_available() else 'cpu') 151 | expanded_logit_bias = torch.zeros(num_completions, self.tokenizer.vocab_size + 10).to(device) 152 | for token, bias in logit_bias.items(): 153 | if token not in exclude_tokens: 154 | expanded_logit_bias[:, token] = bias 155 | if bias_machine_quotes: 156 | for token in OPT_MACHINE_QUOTE_IDS: 157 | expanded_logit_bias[:, token] = -100 158 | frequency_bias = torch.zeros_like(expanded_logit_bias) 159 | prompt = [[int(x) for x in self.tokenizer.encode(text)] for _ in range(num_completions)] 160 | if controllers is not None: 161 | controller_ids = [] 162 | for ci in range(len(controllers)): 163 | controller_ids.append([[int(x) for x in 
self.tokenizer.encode(current_controller_initial_texts[ci])] for _ in range(num_completions)]) 164 | initial_prompt_length = len(prompt[0]) 165 | cache_id = None 166 | for _ in range(generation_max_length): 167 | retry = True 168 | num_fails = 0 169 | while retry: 170 | try: 171 | with time_limit(30): 172 | output = self.client.logprobs(prompt, top_p=top_p, top_k=top_k, cache_id=cache_id) 173 | assert 'indices' in output and 'logprobs' in output 174 | retry = False 175 | except Exception as e: 176 | logging.warning(str(e)) 177 | cache_id = None # not reentrant; restart cache 178 | retry = retry_until_success 179 | num_fails += 1 180 | if retry: 181 | logging.warning('retrying...') 182 | time.sleep(num_fails) 183 | logging.warning('old alpa url: ' + self.client.logprobs_url + ' at time ' + str(time.ctime())) 184 | self.client.refresh_url(self.args.alpa_url, self.args.alpa_port) 185 | logging.warning('new alpa url: ' + self.client.logprobs_url) 186 | distribution = (torch.zeros(num_completions, self.tokenizer.vocab_size + 10) - 1e8).to(device) 187 | distribution.scatter_(1, torch.LongTensor(output['indices']).to(device), torch.Tensor(output['logprobs']).to(device)) 188 | if controllers is not None: 189 | """ 190 | lm_logits: beam x 1 x vocab 191 | input_ids: beam x seqlen 192 | optionally, top_logits and top_indices, both beam x 1 x topk 193 | """ 194 | for ci in range(len(controllers)): 195 | # this call modifies and returns the distribution based on the given control string 196 | distribution = controllers[ci](distribution.view(num_completions, 1, -1).to(device), 197 | torch.LongTensor(controller_ids[ci]).view(num_completions, -1).to(device), 198 | top_logits=torch.Tensor(output['logprobs']).view(num_completions, 1, -1).to(device), 199 | top_indices=torch.LongTensor(output['indices']).view(num_completions, 1, -1).to(device), 200 | control_strength=control_strengths[ci]) 201 | distribution = distribution.squeeze(1) 202 | distribution /= temperature 203 | distribution += expanded_logit_bias + frequency_bias 204 | distribution = torch.softmax(distribution, dim=1) 205 | next_tokens = torch.multinomial(distribution, 1).squeeze(1) 206 | for i in range(num_completions): 207 | prompt[i].append(next_tokens[i].item()) 208 | if controllers is not None: 209 | for ci in range(len(controllers)): 210 | controller_ids[ci][i].append(next_tokens[i].item()) 211 | if next_tokens[i].item() not in exclude_tokens: 212 | frequency_bias[i, next_tokens[i].item()] -= frequency_penalty 213 | if frequency_bias[i, next_tokens[i].item()] > -presence_penalty: 214 | frequency_bias[i, next_tokens[i].item()] -= presence_penalty 215 | frequency_bias = frequency_bias * logit_bias_decay 216 | cache_id = output["cache_id"] 217 | for completion in prompt: 218 | decoded_completion = self.tokenizer.decode(completion[initial_prompt_length:]) 219 | sentences.append(decoded_completion) 220 | 221 | for i in range(len(sentences)): 222 | sentence = sentences[i] 223 | while len(sentence) > 0 and sentence[-1] not in string.printable: # sometimes you get half of a special char at the end 224 | sentence = sentence[:-1] 225 | if len(self.tokenizer.encode(sentence.split('\n')[-1])) < 10: # if we just barely started a new paragraph, don't include it; you can get led down bad paths 226 | sentence = '\n'.join(sentence.split('\n')[:-1]).rstrip() 227 | sentence = sentence.rstrip() 228 | for s in stop: 229 | if s in sentence: 230 | sentence = sentence[:sentence.index(s)] 231 | sentences[i] = sentence.rstrip() 232 | if cut_sentence: 233 | sentences 
= [cut_last_sentence(sentence) for sentence in sentences] 234 | return sentences 235 | 236 | @torch.no_grad() 237 | def generate_with_controller(self, 238 | controllers, 239 | controller_initial_texts, 240 | prompt, 241 | control_strengths=None, 242 | generation_max_length=1, 243 | num_completions=1, 244 | fudge_top_k=100, 245 | fudge_top_p=1, 246 | temperature=None, 247 | logit_bias=None, 248 | exclude_strings=None, 249 | cut_sentence=False, 250 | logit_bias_decay=1): 251 | if logit_bias is None: 252 | logit_bias = {} 253 | completions = [] 254 | for i in range(0, num_completions, OPT_LOGPROBS_MAX_BS): 255 | completions += self([prompt], 256 | controllers=controllers, 257 | control_strengths=control_strengths, 258 | controller_initial_texts=[controller_initial_texts], 259 | generation_max_length=generation_max_length, 260 | top_p=fudge_top_p, 261 | top_k=fudge_top_k, 262 | temperature=temperature, 263 | logit_bias=logit_bias, 264 | exclude_strings=exclude_strings, 265 | num_completions=min(num_completions - i, OPT_LOGPROBS_MAX_BS), 266 | cut_sentence=cut_sentence, 267 | logit_bias_decay=logit_bias_decay) 268 | return completions 269 | 270 | def create_logit_bias_for_prompt(self, prompt, bias=0, exclude_strings=None, logit_bias=None, use_frequency_penalty=False, decay=1): 271 | if logit_bias is None: 272 | logit_bias = {} 273 | exclude_tokens = set() 274 | if exclude_strings is not None: 275 | for s in exclude_strings: 276 | exclude_tokens.update(self.tokenizer.encode(s[0].upper() + s[1:])) 277 | exclude_tokens.update(self.tokenizer.encode(' ' + s[0].upper() + s[1:])) 278 | exclude_tokens.update(self.tokenizer.encode(s[0].lower() + s[1:])) 279 | exclude_tokens.update(self.tokenizer.encode(' ' + s[0].lower() + s[1:])) 280 | for i, token in enumerate(reversed(self.tokenizer.encode(prompt))): 281 | if token not in exclude_tokens: 282 | if token in logit_bias: 283 | if use_frequency_penalty: 284 | logit_bias[token] += bias * decay ** i 285 | else: 286 | logit_bias[token] = bias * decay ** i 287 | for i, token in enumerate(reversed(self.tokenizer.encode(prompt.upper()))): 288 | if token not in exclude_tokens: 289 | if token in logit_bias: 290 | pass # don't re-penalize tokens based on this upper/lowercase heuristic 291 | else: 292 | logit_bias[token] = bias 293 | for i, token in enumerate(reversed(self.tokenizer.encode(prompt.lower()))): 294 | if token not in exclude_tokens: 295 | if token in logit_bias: 296 | pass # don't re-penalize tokens based on this upper/lowercase heuristic 297 | else: 298 | logit_bias[token] = bias 299 | return logit_bias 300 | 301 | def fit(self, dataset): 302 | pass 303 | 304 | def save(self, path): 305 | pass 306 | 307 | def load(self, path): 308 | pass 309 | 310 | def add_controller(self, controller): 311 | raise NotImplementedError -------------------------------------------------------------------------------- /story_generation/common/summarizer/summarizer_util.py: -------------------------------------------------------------------------------- 1 | from story_generation.common.summarizer.models.gpt3_summarizer import GPT3Summarizer 2 | from story_generation.common.summarizer.models.opt_summarizer import OPTSummarizer 3 | 4 | SUMMARIZER_CHOICES=['gpt3_summarizer', 'opt_summarizer'] 5 | 6 | def add_summarizer_args(parser): 7 | parser.add_argument('--summarizer', type=str, default='gpt3_summarizer', choices=SUMMARIZER_CHOICES, help='model architecture') 8 | parser.add_argument('--summarizer-save-dir', type=str, default=None, help='directory to save summarizer') 9 | 
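# note: GPT3Summarizer and OPTSummarizer implement save/load as no-ops, so the save/load dirs below presumably only matter for trainable summarizer variants: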
parser.add_argument('--summarizer-load-dir', type=str, default=None, help='directory to load summarizer') 10 | parser.add_argument('--expander', action='store_true', help='swap source and target to learn expanding a summary') 11 | parser.add_argument('--summarizer-temperature', type=float, default=0.8, help='temperature for summarizer') 12 | parser.add_argument('--opt-summarizer-temperature', type=float, default=0.8, help='temperature for OPT summarizer during main story generation') 13 | parser.add_argument('--summarizer-top-p', type=float, default=1.0, help='top p for summarizer') 14 | parser.add_argument('--summarizer-frequency-penalty', type=float, default=0.5, help='frequency penalty for summarizer') 15 | parser.add_argument('--summarizer-prompt-penalty', type=float, default=0.5, help='OPT control penalty for prompt tokens for summarizer, excluding stopwords/punc/names') 16 | parser.add_argument('--summarizer-frequency-penalty-decay', type=float, default=0.98, help='frequency penalty decay for OPT summarizer') 17 | parser.add_argument('--summarizer-presence-penalty', type=float, default=0, help='presence penalty for summarizer') 18 | parser.add_argument('--generation-max-length', type=int, default=256, help='max length for generation, not including prompt') 19 | parser.add_argument('--summarizer-beam-size', type=int, default=1, help='beam size for summarizer') 20 | parser.add_argument('--gpt3-model', type=str, default='text-davinci-002', help='gpt3 model or finetuned ckpt for GPT3Summarizer') 21 | parser.add_argument('--max-context-length', type=int, default=1024, help='max length for context to facilitate toy version') 22 | parser.add_argument('--alpa-url', type=str, default=None, help='url for alpa API') 23 | parser.add_argument('--alpa-port', type=str, default=None, help='port for alpa API, if alpa-url is a filename to read server location from. 
convenient for slurm') 24 | parser.add_argument('--alpa-key', type=str, default='', help='key for alpa API, if using the public API') 25 | return parser 26 | 27 | def load_summarizer(args): 28 | if args.summarizer == 'gpt3_summarizer': 29 | summarizer = GPT3Summarizer(args) 30 | elif args.summarizer == 'opt_summarizer': 31 | summarizer = OPTSummarizer(args) 32 | else: 33 | raise NotImplementedError 34 | return summarizer -------------------------------------------------------------------------------- /story_generation/draft_module/beam_candidate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from enum import auto 4 | import os 5 | from copy import deepcopy 6 | import pickle 7 | from collections import defaultdict 8 | import multiprocessing as mp 9 | import random 10 | import string 11 | import logging 12 | 13 | import torch 14 | import Levenshtein 15 | import numpy as np 16 | from transformers import AutoTokenizer 17 | import openai 18 | from scipy.special import softmax 19 | 20 | from story_generation.edit_module.entity import * 21 | from story_generation.rewrite_module.heuristics import * 22 | from story_generation.common.util import * 23 | from story_generation.common.data.data_util import add_data_args, load_dataset 24 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 25 | from story_generation.common.summarizer.models.gpt3_summarizer import GPT3_SEP, GPT3_END 26 | from story_generation.common.controller.controller_util import add_controller_args, load_controller 27 | from story_generation.common.controller.loaders.alignment_loader import create_prefix_completion 28 | from story_generation.common.data.split_paragraphs import * 29 | 30 | class BeamCandidate: 31 | def __init__(self, 32 | args, 33 | all_entities_dict, 34 | infer_attributes_string, 35 | model=None, 36 | opt_model=None, 37 | controllers=None, 38 | step=0, 39 | alignment_score=-1e8, 40 | best_alignment_so_far=-1e8, 41 | all_paragraphs=None, 42 | outline_sections=None, 43 | paragraphs_by_outline_section=None): 44 | self.args = args 45 | self.all_entities_dict = all_entities_dict 46 | self.infer_attributes_string = infer_attributes_string 47 | self.model = model 48 | self.opt_model = opt_model 49 | self.controllers = controllers 50 | self.step = step 51 | self.alignment_score = alignment_score 52 | self.best_alignment_so_far = best_alignment_so_far 53 | self.all_paragraphs = all_paragraphs if all_paragraphs is not None else [] 54 | self.outline_sections = outline_sections if outline_sections is not None else [] 55 | self.paragraphs_by_outline_section = paragraphs_by_outline_section if paragraphs_by_outline_section is not None else {} 56 | self.is_consistent = False 57 | 58 | def story(self): 59 | # technically there could be a missing newline here instead of a space. 
low priority 60 | return ' '.join(self.all_paragraphs) 61 | 62 | def previous_passage(self, max_tokens, suffix=None): 63 | if len(self.all_paragraphs) == 0: 64 | return '' 65 | passage = self.story() 66 | if len(self.story().strip()) == 0: 67 | return '' 68 | if suffix is not None: 69 | passage = passage[:len(passage) - len(suffix)].rstrip() 70 | if len(passage.strip()) == 0: 71 | return '' 72 | passage = self.model.tokenizer.decode(self.model.tokenizer.encode(passage)[-max_tokens:]) 73 | return cut_first_sentence(passage) 74 | 75 | def print_section(self, section_idx): 76 | return ' '.join(self.paragraphs_by_outline_section[self.outline_sections[section_idx]]) 77 | 78 | def select_entities(self, outline_section, previous_paragraph=None): 79 | # TODO lot of things to try here... 80 | matched_entities, _, _ = deduplicate_match_entities(detect_entities(outline_section), self.all_entities_dict.keys()) 81 | matched_entities = list(matched_entities) 82 | dpr_query_encoder, dpr_context_encoder = load_dpr() 83 | if previous_paragraph is not None: 84 | summary_prompt = previous_paragraph.strip() + '\n\n\n\nOne-sentence summary:\n\n\n\n' 85 | summary = self.model([summary_prompt], modify_prompt=False, model_string='text-curie-001')[0].strip() 86 | query_encoding = dpr_query_encoder.encode('Previous passage summary: ' + summary + '\n\nCurrent story outline: ' + outline_section.strip() + '\n\nWho or what appears in the upcoming paragraphs?') 87 | else: 88 | query_encoding = dpr_query_encoder.encode('Current story outline: ' + outline_section.strip() + '\n\nWho or what appears in the upcoming paragraphs?') 89 | entities = [key for key in list(self.all_entities_dict.keys()) if self.all_entities_dict[key].is_character] 90 | if len(entities) == 0: 91 | return [] 92 | context_encodings = dpr_context_encoder.encode(entities) 93 | scores = (query_encoding.reshape(1, -1) * context_encodings).sum(axis=1) 94 | selected_entities = matched_entities 95 | total_tokens = sum([len(self.model.tokenizer.encode(self.all_entities_dict[entity].description)) for entity in selected_entities]) 96 | if total_tokens > self.args.max_entity_context_tokens: 97 | logging.warning('Warning: truncating entity context to fit context length limit') 98 | selected_entities = [] 99 | total_tokens = 0 100 | for entity in matched_entities: 101 | total_tokens += len(self.model.tokenizer.encode(self.all_entities_dict[entity].description)) 102 | if total_tokens > self.args.max_entity_context_tokens: 103 | break 104 | selected_entities.append(entity) 105 | return selected_entities 106 | # sample additional entities without repeats, up to context length 107 | for i, ent in enumerate(entities): # mask out ones we already selected 108 | if ent in selected_entities: 109 | scores[i] = -1e8 110 | unselected_entities = [ent for ent in entities if ent not in selected_entities] 111 | while total_tokens < self.args.max_entity_context_tokens and len(unselected_entities) > 0: 112 | probs = softmax(scores) 113 | next_entity = np.random.choice(entities, p=probs) 114 | total_tokens += len(self.model.tokenizer.encode(self.all_entities_dict[next_entity].description)) 115 | if total_tokens > self.args.max_entity_context_tokens: 116 | break 117 | selected_entities.append(next_entity) 118 | scores[entities.index(next_entity)] = -1e8 119 | unselected_entities = [ent for ent in entities if ent not in selected_entities] 120 | 121 | return selected_entities 122 | 123 | def create_updated_entities(self, new_passage, cached_update_dict=None): 124 | # detect and make 
entries for new entities, run inference for description / is_character on new entities, update attributes 125 | new_entities_dict = deepcopy(self.all_entities_dict) 126 | entities = [str(ent) for ent in detect_entities(new_passage)] 127 | matched_entities, new_entities, _ = deduplicate_match_entities(entities, self.all_entities_dict.keys()) 128 | new_entities_dict = deepcopy(self.all_entities_dict) 129 | for ent in new_entities: 130 | entity = Entity(ent) 131 | entity.infer_description(new_passage, self.model, max_length=self.args.entity_description_max_length) 132 | entity.infer_is_character(new_passage, self.model) 133 | entity.infer_attributes(new_passage, self.model, other_names=[name for name in matched_entities if name != entity.name] + [name for name in new_entities if name != entity.name]) 134 | new_entities_dict[ent] = entity 135 | for ent in matched_entities: 136 | if cached_update_dict is not None and ent in cached_update_dict: 137 | new_entities_dict[ent] = cached_update_dict[ent] 138 | else: 139 | new_entities_dict[ent].infer_attributes(new_passage, self.model, other_names=[name for name in matched_entities if name != ent] + list(new_entities), detect_contradictions=False) 140 | complete_mutual_relations(new_entities_dict, self.model) 141 | return new_entities_dict 142 | 143 | def detect_attribute_contradictions(self, completion, detect_contradictions=True): 144 | matched_entities, new_entities, _ = deduplicate_match_entities(detect_entities(completion, add_dpr_entities=False, all_entities_dict=self.all_entities_dict), self.all_entities_dict.keys()) 145 | matched_entities = list(matched_entities) 146 | contradictions = {} 147 | cached_update_dict = {} 148 | copied_entities = deepcopy(self.all_entities_dict) 149 | for ent in matched_entities: 150 | entity = copied_entities[ent] 151 | contradictions[ent] = entity.infer_attributes(completion, self.model, detect_contradictions=detect_contradictions, other_names=[name for name in matched_entities if name != entity.name] + list(new_entities)) 152 | cached_update_dict[ent] = entity 153 | _, additional_contradictions = complete_mutual_relations(copied_entities, self.model) 154 | for ent in additional_contradictions: 155 | for key in additional_contradictions[ent]: 156 | if ent not in contradictions: 157 | contradictions[ent] = {} 158 | contradictions[ent][key] = additional_contradictions[ent][key] 159 | return matched_entities, contradictions, cached_update_dict 160 | 161 | def condense_outline_sections(self, outline): 162 | if type(outline) != tuple: 163 | return 164 | logging.log(23, 'CONDENSING OUTLINE') 165 | logging.log(23, 'BEFORE') 166 | logging.log(23, str(self.outline_sections)) 167 | high_level_outline = split_list(outline[0]) 168 | for i in range(len(high_level_outline)): 169 | if high_level_outline[i] in self.outline_sections: 170 | assert self.outline_sections[i] == high_level_outline[i] 171 | continue 172 | detailed_outline = split_list(outline[1][i]) 173 | if len(self.outline_sections) - i == len(detailed_outline): 174 | self.outline_sections = deepcopy(high_level_outline[:i+1]) 175 | break 176 | logging.log(23, 'AFTER') 177 | logging.log(23, str(self.outline_sections)) 178 | 179 | def construct_prompt(self, outline_section, selected_entities=[]): 180 | presumed_max_prompt_length = 2*self.args.generation_max_length + self.args.max_entity_context_tokens + 128 181 | if self.args.no_planner: 182 | if len(self.model.tokenizer.encode(self.story())) <= self.args.max_context_length - 2*self.args.generation_max_length: # early 
on enough to fit the premise in the rolling window 183 | prompt = 'Write a story with the following premise.\n\n' + self.all_entities_dict['Premise'].description + '\n\n' 184 | prompt += 'Chapter 1\n\n' 185 | if len(self.story()) > 0: 186 | prompt += self.story() 187 | return prompt 188 | else: 189 | return self.previous_passage(self.args.max_context_length - self.args.generation_max_length) 190 | if len(self.all_paragraphs) == 0: 191 | prompt = self.infer_attributes_string + '\n\n\n\n' 192 | else: 193 | if len(selected_entities) > 0: 194 | selected_entity_strings = [self.all_entities_dict[ent].description for ent in selected_entities] 195 | prompt = 'Relevant Context:\n\n' + '\n\n'.join(selected_entity_strings) + '\n\n\n\n' 196 | prompt += 'The story is written in third person.' 197 | if self.step > 1: 198 | prompt += '\n\n\n\nPrevious story summary: ' + ' '.join(self.outline_sections[:-1]) 199 | previous_text = self.previous_passage(self.args.previous_prompt_length) 200 | if len(self.all_paragraphs) > 0: 201 | previous_passage = self.previous_passage(int(self.args.max_context_length/2), suffix=previous_text) 202 | if len(self.model.tokenizer.encode(previous_passage)) > int(self.args.max_context_length/4): # no need to do this extra summary if it's really short 203 | max_preceding_summary_tokens = 128 204 | preceding_summary = self.model([previous_passage + '\n\nSummarize the events in this passage.'], generation_max_length=max_preceding_summary_tokens, model_string='text-curie-001')[0].strip() 205 | if len(self.model.tokenizer.encode(preceding_summary)) == max_preceding_summary_tokens: 206 | logging.warning('Warning: preceding events summary is too long, truncating') 207 | prompt += '\n\n\n\nEvents immediately prior to the upcoming passage: ' + preceding_summary 208 | if self.step == 1: 209 | prompt += '\n\n\n\nChapter 1 Summary: ' + outline_section.strip() 210 | else: 211 | prompt += '\n\n\n\nIn the upcoming passage, ' + outline_section.strip()[0].lower() + outline_section.strip()[1:] # uncapitalize the first letter if needed 212 | prompt += '\n\n\n\nFull text below:\n\n\n\n' 213 | if len(self.all_paragraphs) == 0: 214 | prompt = prompt + 'Chapter 1\n\n' 215 | prompt = prompt + previous_text 216 | length_model = self.model if self.args.extension_method == 'gpt3' else self.opt_model 217 | if len(length_model.tokenizer.encode(prompt)) > presumed_max_prompt_length: 218 | # generation max length from selected entities and outline, max entity context tokens from previous context, then some padding 219 | logging.warning('Warning: prompt is too long, please inspect') 220 | prompt = length_model.tokenizer.decode(length_model.tokenizer.encode(prompt)[-presumed_max_prompt_length:]) # left truncate prompt to fit our imposed limit on context window size 221 | # import pdb; pdb.set_trace() 222 | return prompt 223 | 224 | @torch.no_grad() 225 | def edit_update_contradictions(self): 226 | assert not self.is_consistent 227 | completion = self.all_paragraphs[-1] 228 | autoregressive_context = self.all_paragraphs[-2].lstrip(string.punctuation) if len(self.all_paragraphs) > 1 else '' 229 | matched_entities, contradictions, cached_update_dict = self.detect_attribute_contradictions(completion.strip(), detect_contradictions=True) 230 | edited_sentences = set() 231 | if any([len(contradictions[ent]) > 0 for ent in matched_entities]) and len(autoregressive_context) > 0: # don't do it on the first paragraph, if we don't have autoregressive context to help check we're not messing something up 232 | 
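# when contradictions were detected (and there is preceding context to anchor the edit), each flagged sentence is sent to the GPT3 edit API below with an instruction to minimally rewrite the completion so that the expected fact holds: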
logging.log(23, 'editing completion based on contradictions') 233 | logging.log(23, 'AUTOREGRESSIVE CONTEXT ' + autoregressive_context) 234 | logging.log(23, 'BEFORE ' + completion) 235 | for ent in matched_entities: 236 | for contradiction_key in contradictions[ent]: 237 | for contradicted_sentence in contradictions[ent][contradiction_key][0]['text'].strip().split('\n'): 238 | if contradicted_sentence in edited_sentences: # no need to edit again if the sentence was contradicted more than once 239 | continue 240 | edited_sentences.add(contradicted_sentence) 241 | instruction = 'Edit so that ' + contradicted_sentence + ' Keep the text unchanged as much as possible.' 242 | logging.log(23, 'INSTRUCTION ' + instruction) 243 | completion = gpt3_edit(completion, instruction, prefix=None if len(autoregressive_context.strip()) == 0 else autoregressive_context).strip() 244 | if len(self.model.tokenizer.encode(completion)) > self.args.generation_max_length + 64: # give some leeway for editing to expand text 245 | logging.warning('WARNING: completion is too long after editing. Truncating...') 246 | completion = self.model.tokenizer.decode(self.model.tokenizer.encode(completion)[:self.args.generation_max_length + 64]) 247 | completion = cut_last_sentence(completion) 248 | logging.log(23, 'AFTER ' + completion) 249 | _, _, cached_update_dict = self.detect_attribute_contradictions(completion.strip(), detect_contradictions=False) # only reupdate the cache, and allow appending any new entries; presumably GPT3 fixed any "real" contradictions 250 | self.all_paragraphs[-1] = completion 251 | self.paragraphs_by_outline_section[self.outline_sections[-1]][-1] = completion 252 | self.all_entities_dict = self.create_updated_entities(completion.strip(), cached_update_dict=cached_update_dict) 253 | self.is_consistent = True 254 | 255 | @torch.no_grad() 256 | def extend(self, outline_section): 257 | # return a list of up to max_beam_size new BeamCandidates with their respective alignment scores before moving on to the next outline sentence 258 | logging.log(25, 'extension step ' + str(self.step)) 259 | self.step += 1 260 | self.alignment_score = -1e8 261 | self.best_alignment_so_far = -1e8 262 | self.outline_sections.append(outline_section) 263 | self.paragraphs_by_outline_section[outline_section] = [] 264 | completed_candidates = [] 265 | beam = [self] 266 | substep = 0 267 | while len(completed_candidates) < self.args.max_beam_size: 268 | logging.log(25, 'substep ' + str(substep)) 269 | next_candidates = [] 270 | for beam_idx, prev_candidate in enumerate(beam): 271 | candidates = [] 272 | for candidate in prev_candidate.extend_single(outline_section, batch_size=self.args.max_candidates, top_p=self.args.draft_top_p): 273 | candidates.append(candidate) 274 | logging.log(25, 'beam idx ' + str(beam_idx) + ' single extension with score ' + str(candidates[-1].alignment_score)) 275 | candidates = sorted(candidates, key=lambda x: x.alignment_score, reverse=True) 276 | if candidates[0].alignment_score < prev_candidate.best_alignment_so_far - self.args.continuation_threshold: # early termination of expansion of this outline point 277 | logging.log(25, 'beam idx ' + str(beam_idx) + ' adding completed candidate with score ' + str(prev_candidate.alignment_score)) 278 | assert self.args.no_editor or prev_candidate.is_consistent 279 | completed_candidates.append(prev_candidate) 280 | else: 281 | if candidates[0].alignment_score < prev_candidate.best_alignment_so_far: 282 | logging.log(25, 'continuation with slightly worse 
score') 283 | next_candidates.extend(candidates) 284 | next_candidates = sorted(next_candidates, key=lambda x: x.alignment_score, reverse=True)[:self.args.max_beam_size - len(completed_candidates)] 285 | beam = next_candidates 286 | if not self.args.no_editor: 287 | for c in beam: 288 | c.edit_update_contradictions() 289 | substep += 1 290 | if substep >= self.args.max_continuation_substeps: # fill out the rest of the completed candidates 291 | for c in beam: 292 | logging.log(25, 'beam idx ' + str(beam_idx) + ' adding completed candidate with score ' + str(c.alignment_score)) 293 | assert self.args.no_editor or c.is_consistent 294 | completed_candidates.append(c) 295 | break 296 | return sorted(completed_candidates, key=lambda x: x.alignment_score, reverse=True)[:self.args.max_beam_size] 297 | 298 | def calculate_alignment(self, completions, prompt, outline_section): 299 | if self.args.max_candidates == 1: 300 | return np.zeros(len(completions)) # in this case, we're doing no reranking, and this will also prevent the reranking from being used to decide when to stop. 301 | repetition_penalty = np.array([calculate_repetition_length_penalty(c, [prompt]) for c in completions]) 302 | is_first_person = np.array([1 if detect_first_second_person(c) else 0 for c in completions]) # could have some false positives if the quotations are off, but whatever. 303 | repetition_penalty += is_first_person * 10 304 | alignment_score = 0 305 | if not self.args.no_planner: 306 | alignment_input = [create_prefix_completion(c, outline_section)[1] for c in completions] 307 | multiplier = 1 if self.args.control_strength is None else self.args.control_strength[0] 308 | alignment_score = multiplier * self.controllers[0].evaluate_overall_texts(alignment_input).cpu().numpy() # logprob for alignment with outline 309 | multiplier = 1 if self.args.control_strength is None else self.args.control_strength[1] 310 | if len(self.story().strip()) > 0: 311 | alignment_score += multiplier * self.controllers[1]([cut_first_sentence(self.previous_passage(1000)) for _ in range(len(completions))], completions).cpu().numpy() # logprob for alignment with previous story, up to 1k prev tokens 312 | alignment_score += -repetition_penalty * self.args.repetition_penalty_weight 313 | return alignment_score 314 | 315 | def extend_single(self, outline_section, batch_size=1, top_p=None): 316 | if self.args.outline_levels == 1: 317 | assert self.step == len(self.outline_sections) 318 | if self.args.no_planner: 319 | selected_entities = None 320 | else: 321 | selected_entities = self.select_entities(outline_section, previous_paragraph=self.all_paragraphs[-1] if len(self.all_paragraphs) > 0 else None) 322 | prompt = self.construct_prompt(outline_section, selected_entities=selected_entities) 323 | logging.log(21, 'PROMPT') 324 | logging.log(21, prompt) 325 | if self.args.extension_method == 'gpt3': 326 | completions = self.model([prompt], model_string=self.args.draft_model_string, modify_prompt=False, num_completions=batch_size, top_p=top_p, temperature=self.args.summarizer_temperature, cut_sentence=True, logit_bias={50256:-100, 14126:-100, 7006:-100, 6843:-100, 43582:-100}) # don't let it end prematurely, and don't let it repeatedly generate variants of the word "chapter" since we used it to prompt it initially # stop=['Chapter', 'Chapters', 'Full text', '\n\n\n\n\n'] 327 | elif self.args.extension_method == 'opt': 328 | exclude_strings = stopwords.words('english') + list("!\"“”‘’'(),-.:;?") + ['\n', '\n\n'] + ([] if selected_entities is None else 
selected_entities) 329 | if self.args.no_planner: 330 | opt_control_logit_bias = {} 331 | else: 332 | assert '\n\nFull text below:\n\n' in prompt 333 | previous_paragraph = prompt.split('\n\nFull text below:\n\n')[-1].strip() 334 | opt_control_logit_bias = self.opt_model.create_logit_bias_for_prompt( 335 | previous_paragraph, 336 | bias=-self.args.summarizer_frequency_penalty, 337 | decay=self.args.summarizer_frequency_penalty_decay, 338 | ) 339 | prompt_logit_bias_string = prompt[:len(prompt) - len(previous_paragraph)] 340 | for character in self.all_entities_dict: 341 | prompt_logit_bias_string = prompt_logit_bias_string.replace(self.all_entities_dict[character].description, '') # don't bias against char descriptions? 342 | opt_control_logit_bias_prompt = self.opt_model.create_logit_bias_for_prompt( 343 | prompt_logit_bias_string, 344 | bias=-self.args.summarizer_prompt_penalty, 345 | exclude_strings=exclude_strings, 346 | ) 347 | for key in opt_control_logit_bias_prompt: 348 | if key in opt_control_logit_bias: 349 | opt_control_logit_bias[key] = min(opt_control_logit_bias[key], opt_control_logit_bias_prompt[key]) 350 | else: 351 | opt_control_logit_bias[key] = opt_control_logit_bias_prompt[key] 352 | opt_control_logit_bias[2] = -1e8 # ban 353 | completions = self.opt_model.generate_with_controller( 354 | [], 355 | [], 356 | prompt, 357 | control_strengths=[], 358 | generation_max_length=self.args.generation_max_length, 359 | temperature=self.args.opt_summarizer_temperature, 360 | logit_bias=opt_control_logit_bias, 361 | num_completions=batch_size, 362 | cut_sentence=self.args.cut_sentence, 363 | logit_bias_decay=self.args.summarizer_frequency_penalty_decay, 364 | ) 365 | else: 366 | raise NotImplementedError 367 | for i in range(len(completions)): 368 | completions[i] = completions[i].strip() 369 | while '\n\n\n' in completions[i]: # just improve the formatting a bit 370 | completions[i] = completions[i].replace('\n\n\n', '\n\n') 371 | for i in range(len(completions)): 372 | _, _, replacements = deduplicate_match_entities(detect_entities(completions[i]), self.all_entities_dict.keys()) 373 | if not self.args.no_editor: 374 | for key, value in replacements.items(): 375 | completions[i] = completions[i].replace(key, value) 376 | alignment_score = self.calculate_alignment(completions, prompt, outline_section) 377 | new_candidates = [] 378 | for c, s in zip(completions, alignment_score): 379 | new_paragraphs_by_outline_section = deepcopy(self.paragraphs_by_outline_section) 380 | new_paragraphs_by_outline_section[outline_section].append(c) 381 | new_candidates.append(BeamCandidate(self.args, 382 | self.all_entities_dict, 383 | self.infer_attributes_string, 384 | model=self.model, 385 | opt_model=self.opt_model, 386 | controllers=self.controllers, 387 | step=self.step, 388 | alignment_score=s, 389 | best_alignment_so_far=max(s, self.best_alignment_so_far), 390 | all_paragraphs=deepcopy(self.all_paragraphs) + [c], 391 | outline_sections=deepcopy(self.outline_sections), 392 | paragraphs_by_outline_section=new_paragraphs_by_outline_section)) 393 | return new_candidates 394 | 395 | def complete_ending(self): 396 | outline_section = self.outline_sections[-1] 397 | if self.args.no_planner: 398 | selected_entities = None 399 | else: 400 | selected_entities = self.select_entities(outline_section, previous_paragraph=self.all_paragraphs[-1] if len(self.all_paragraphs) > 0 else None) 401 | prompt = self.construct_prompt(outline_section, selected_entities=selected_entities) 402 | completions = 
gpt3_insert(prompt, 403 | '\n\n\n\n' + GPT3_END, 404 | top_p=self.args.draft_top_p, 405 | temperature=self.args.summarizer_temperature, 406 | n=self.args.max_candidates, 407 | max_tokens=self.args.generation_max_length, 408 | frequency_penalty=self.args.summarizer_frequency_penalty, 409 | presence_penalty=self.args.summarizer_presence_penalty) 410 | completions = [c.replace('\n\n\n\n', '\n\n') for c in completions] 411 | alignment_score = self.calculate_alignment(completions, prompt, outline_section) 412 | logging.log(23, 'ENDING ALIGNMENT SCORES ' + str(alignment_score)) 413 | ranked_completions = sorted(zip(completions, alignment_score), key=lambda x: x[1], reverse=True) 414 | ending = ranked_completions[0][0] 415 | should_continue = len(self.model.tokenizer.encode(ending))==self.args.generation_max_length # ending didn't finish writing; should generate more toward the ending after this 416 | ending = cut_last_sentence(ending) 417 | logging.log(23, 'ENDING' + ' ' + ending) 418 | new_paragraphs_by_outline_section = deepcopy(self.paragraphs_by_outline_section) 419 | if outline_section not in new_paragraphs_by_outline_section: 420 | new_paragraphs_by_outline_section[outline_section] = [] 421 | new_paragraphs_by_outline_section[outline_section].append(ending) 422 | new_candidate = BeamCandidate(self.args, 423 | self.all_entities_dict, 424 | self.infer_attributes_string, 425 | model=self.model, 426 | opt_model=self.opt_model, 427 | controllers=self.controllers, 428 | step=self.step, 429 | alignment_score=self.alignment_score, 430 | best_alignment_so_far=self.best_alignment_so_far, 431 | all_paragraphs=deepcopy(self.all_paragraphs) + [ending], 432 | outline_sections=deepcopy(self.outline_sections), 433 | paragraphs_by_outline_section=new_paragraphs_by_outline_section) 434 | if not self.args.no_editor: 435 | new_candidate.edit_update_contradictions() 436 | return new_candidate, should_continue -------------------------------------------------------------------------------- /story_generation/edit_module/evaluate_consistency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | from copy import deepcopy 5 | import math 6 | import random 7 | 8 | from tqdm import trange 9 | from sklearn.metrics import roc_auc_score 10 | 11 | from story_generation.edit_module.entity import Entity, complete_mutual_relations 12 | from story_generation.common.util import * 13 | from story_generation.common.data.data_util import add_data_args, load_dataset 14 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 15 | 16 | 17 | def evaluate_consistency(save_info_file, story_file, instruct_model, method='structured', reinfer=True, verbose=True, contradiction_threshold=0.5): 18 | assert method in ['entailment', 'entailment-dpr', 'structured'] # TODO add baselines 19 | if method == 'entailment': 20 | with open(story_file, 'r') as rf: 21 | story = rf.read() 22 | if story.strip() == 'SKIPPED': 23 | if verbose: 24 | print('SKIPPED STORY') 25 | return None 26 | with open(save_info_file, 'rb') as f: 27 | save_info = pickle.load(f) 28 | base_info = save_info['infer_attributes_string'] 29 | base_info_sentences = split_paragraphs(base_info, mode='sentence') 30 | story_sentences = split_paragraphs(story, mode='sentence') 31 | premises, hypotheses = [], [] 32 | for s1 in base_info_sentences: 33 | for s2 in story_sentences: 34 | premises.append(s1) 35 | hypotheses.append(s2) 36 | logprobs, _ = 
score_entailment(premises, hypotheses) 37 | return np.exp(logprobs[:, 0]).max() 38 | # if np.exp(logprobs[:, 0]).max() > contradiction_threshold: # contradiction detected 39 | # return 0 40 | # return 1 41 | elif method == 'entailment-dpr': 42 | with open(story_file, 'r') as rf: 43 | story = rf.read() 44 | if story.strip() == 'SKIPPED': 45 | if verbose: 46 | print('SKIPPED STORY') 47 | return None 48 | with open(save_info_file, 'rb') as f: 49 | save_info = pickle.load(f) 50 | base_info = save_info['infer_attributes_string'] 51 | base_info_sentences = split_paragraphs(base_info, mode='sentence') 52 | story_sentences = split_paragraphs(story, mode='sentence') 53 | premises, hypotheses = [], [] 54 | for premise in story_sentences: 55 | scores = score_dpr(premise + ' Is this sentence consistent with the previous story?', base_info_sentences) 56 | premises.append(premise) 57 | hypotheses.append(base_info_sentences[scores.argmax()]) 58 | logprobs, _ = score_entailment(premises, hypotheses) 59 | return np.exp(logprobs[:, 0]).max() 60 | # if np.exp(logprobs[:, 0]).max() > contradiction_threshold: # contradiction detected 61 | # return 0 62 | # return 1 63 | else: 64 | # if the story file is SKIPPED, return None 65 | with open(story_file, 'r') as rf: 66 | story = rf.read() 67 | if story.strip() == 'SKIPPED': 68 | if verbose: 69 | print('SKIPPED STORY') 70 | return None 71 | 72 | # infer attributes on the characters if it's not already there, and resave if necessary 73 | with open(save_info_file, 'rb') as f: 74 | save_info = pickle.load(f) 75 | for character in save_info['character_strings']: 76 | if type(save_info['character_strings'][character]) == dict: # for the data from the paper, we converted it to a dict when refactoring the code to avoid pkl reloading problems 77 | char_info = save_info['character_strings'][character] 78 | save_info['character_strings'][character] = Entity(char_info['name'], char_info['description'], char_info['is_character'], char_info['attributes']) 79 | 80 | story_detected_characters = deduplicate_match_entities(detect_entities(story, all_entities_dict=save_info['character_strings']), save_info['character_strings'].keys())[0] 81 | 82 | if all([len(save_info['character_strings'][character].attributes) == 0 for character in save_info['character_strings']]) or reinfer: # haven't inferred attributes yet, or want to reinfer 83 | if verbose: 84 | print('INFERRING INITIAL ATTRIBUTES') 85 | # for character in story_detected_characters: # no need to infer undetected characters since we wouldn't contradict against them anyway 86 | for character in save_info['character_strings'].keys(): 87 | entity = save_info['character_strings'][character] 88 | entity.reset_attributes() 89 | entity.infer_attributes(save_info['infer_attributes_string'], instruct_model, other_names=[name for name in save_info['character_strings'].keys() if name != entity.name]) 90 | complete_mutual_relations(save_info['character_strings'], instruct_model) 91 | with open(save_info_file, 'wb') as f: 92 | pickle.dump(save_info, f) 93 | 94 | if verbose: 95 | print('INFER ATTRIBUTES STRING') 96 | print(save_info['infer_attributes_string']) 97 | print('STORY') 98 | print(story) 99 | 100 | # infer attributes on the story file, only for the characters detected in this story passage 101 | contradiction_prob = 0 102 | for character in story_detected_characters: 103 | if verbose: 104 | print('CHARACTER') 105 | print(character) 106 | entity = save_info['character_strings'][character] 107 | if verbose: 108 | print('ATTRIBUTES
BEFORE') 109 | print(entity.attributes) 110 | new_prob = entity.infer_attributes(story, instruct_model, other_names=[name for name in save_info['character_strings'].keys() if name != entity.name], return_contradiction_prob=True) 111 | contradiction_prob = max(contradiction_prob, new_prob) 112 | if verbose: 113 | print('ATTRIBUTES AFTER') 114 | print(entity.attributes) 115 | 116 | _, new_prob = complete_mutual_relations(save_info['character_strings'], instruct_model, return_contradiction_prob=True) 117 | return max(contradiction_prob, new_prob) 118 | # if len(mutual_relation_contradictions) > 0: 119 | # print('CONTRADICTIONS') 120 | # print(mutual_relation_contradictions) 121 | # return 0 122 | # return 1 123 | 124 | 125 | def evaluate_consistency_dataset(data_dir, instruct_model, method='structured', verbose=True, max_num_files=1000000, contradiction_threshold=0.5): 126 | num_files = len(os.listdir(os.path.join(data_dir, 'original_stories'))) 127 | num_files = min(num_files, max_num_files) 128 | same_scores = [] 129 | diff_scores = [] 130 | 131 | for i in trange(0, num_files): 132 | for save_info_folder, story_folder in [('original', 'original_stories'), ('altered', 'altered_stories')]: 133 | 134 | if verbose: 135 | print('\n\n\n\nEXAMPLE ' + str(i)) 136 | print('\n\nSAME PAIR:', save_info_folder, story_folder) 137 | score = evaluate_consistency(os.path.join(data_dir, save_info_folder, str(i) + '.pkl'), os.path.join(data_dir, story_folder, str(i) + '.txt'), instruct_model, method=method, verbose=verbose, contradiction_threshold=contradiction_threshold) 138 | if verbose: 139 | print('SCORE:', score) 140 | # print({1: 'CORRECT', 0: 'WRONG', None: 'N/A'}[score]) 141 | if score is not None: 142 | same_scores.append(score) 143 | for save_info_folder, story_folder in [('original', 'altered_stories'), ('altered', 'original_stories')]: 144 | 145 | if verbose: 146 | print('\n\n\n\nEXAMPLE ' + str(i)) 147 | print('\n\nDIFF PAIR:', save_info_folder, story_folder) 148 | score = evaluate_consistency(os.path.join(data_dir, save_info_folder, str(i) + '.pkl'), os.path.join(data_dir, story_folder, str(i) + '.txt'), instruct_model, method=method, verbose=verbose, contradiction_threshold=contradiction_threshold) 149 | if verbose: 150 | print('SCORE:', score) 151 | # print({0: 'CORRECT', 1: 'WRONG', None: 'N/A'}[score]) 152 | if score is not None: 153 | diff_scores.append(score) 154 | assert len(same_scores) == len(diff_scores) 155 | print('ROC AUC', roc_auc_score([0 for _ in range(len(same_scores))] + [1 for _ in range(len(diff_scores))], same_scores + diff_scores)) 156 | # if verbose: 157 | # print('TOTAL', len(same_scores)) 158 | # print('SAME CONSISTENCY FRAC', sum(same_scores) / len(same_scores)) 159 | # print('DIFF CONSISTENCY FRAC', sum(diff_scores) / len(diff_scores)) 160 | # return {'total': len(same_scores), 161 | # 'same_frac': sum(same_scores) / len(same_scores), 162 | # 'diff_frac': sum(diff_scores) / len(diff_scores), 163 | # 'same_scores': same_scores, 164 | # 'diff_scores': diff_scores} 165 | 166 | 167 | if __name__=='__main__': 168 | parser = argparse.ArgumentParser() 169 | parser = add_general_args(parser) 170 | parser = add_summarizer_args(parser) 171 | parser.add_argument('--consistency-dataset-dir', type=str, required=True, help='dataset directory') 172 | parser.add_argument('--consistency-method', type=str, default='structured', choices=['structured', 'entailment', 'entailment-dpr'], help='consistency method') 173 | parser.add_argument('--contradiction-threshold', type=float, 
default=0.5, help='threshold for contradiction prob when using entailment baselines') 174 | args = parser.parse_args() 175 | 176 | base_model = load_summarizer(args) 177 | instruct_args = deepcopy(args) 178 | instruct_args.gpt3_model = 'text-' + args.gpt3_model + '-001' 179 | instruct_model = load_summarizer(instruct_args) 180 | 181 | results = evaluate_consistency_dataset(args.consistency_dataset_dir, instruct_model, method=args.consistency_method, verbose=not args.quiet, contradiction_threshold=args.contradiction_threshold) 182 | -------------------------------------------------------------------------------- /story_generation/edit_module/example_library.csv: -------------------------------------------------------------------------------- 1 | text,name,key,value 2 | Elle Woods is a powerful witch.,Elle Woods,identity,witch 3 | Cara loves Lenny.,Cara,"Lenny's,love interest's name","admirer,Lenny" 4 | Harold Mayfleet is a young boy who loves books.,Harold Mayfleet,"gender,hobby,age","male,books,young" 5 | "Jenna Chen is beautiful, even though she doesn't care about appearances.",Jenna Chen,"gender,appearance","female,beautiful" 6 | Sarah Winters is a woman who is trying to make ends meet.,Sarah Winters,"gender,financial situation","female,trying to make ends meet" 7 | "Molly is curious and determined, and she wants to figure out what is going on in the game world.",Molly,gender,female 8 | Christopher Borden is haunted by his past.,Christopher Borden,gender,male 9 | Scout Moore is a girl who is adjusting to a family move from a small town to the New York City.,Scout Moore,"gender,home location,previous home location","female,New York City,small town" 10 | John Smith was an engineer who created intelligent machines that now run the world.,John Smith,occupation,engineer 11 | Jane Doe is John Doe's wife.,Jane Doe,"gender,John's,husband's name","female,wife,John Doe" 12 | "William ""Billy"" Bensington is an adventurous 12-year-old who loves exploring new things.","William ""Billy"" Bensington",age,12 13 | Mr. Ochocinski works at the hospital in their hometown.,Mr. 
Ochocinski,"gender,workplace","male,hospital" 14 | Christopher Butler is the son of Ashley Butler.,Christopher Butler,"gender,Ashley's,mother's name","male,son,Ashley Butler" 15 | Anna Fleur is an ambitious young woman.,Anna Fleur,gender,female 16 | Jordan Marshall is Naomi Nakahara's boyfriend.,Jordan Marshall,"gender,Naomi's,girlfriend's name","male,boyfriend,Naomi Nakahara" 17 | David Monroe is not afraid of danger.,David Monroe,, 18 | Tina Palmer befriends Amy Sinkhorn.,Tina Palmer,"Amy's,friend's name","friend,Amy Sinkhorn" 19 | Anabel has faced many challenges in her life.,Anabel,"gender,background","female,many challenges in life" 20 | Arnold Oparin was born into a wealthy family.,Arnold Oparin,background,wealthy 21 | Mallory Saunders is not happy about having to spend her weekend with her parents.,Mallory Saunders,gender,female 22 | Violet Nightingale has been forced to work in the sap harvesting business since she was a child.,Violet Nightingale,"gender,occupation","female,sap harvesting" 23 | Tomas Serrat is Maria Serrat's brother.,Tomas Serrat,"gender,Maria's,sister's name","male,brother,Maria Serrat" 24 | Jensen Wojciehowski is strong and muscular.,Jensen Wojciehowski,appearance,strong and muscular 25 | Mayor Jameson came up with the idea to harvest sap from her own citizens,Mayor Jameson,gender,female 26 | Andrea Montgomery is the daughter of William Montgomery.,Andrea Montgomery,"gender,William's,father's name","female,daughter,William Montgomery" 27 | Eric Beetleson is a human who dreams of being a beetle.,Eric Beetleson,dream,to be a beetle 28 | Sydney Wells is a woman who has been assigned by Montgomery's will to close down his Estate.,Sydney Wells,gender,female 29 | Anderson Peters has spent his life trying to escape his father's shadow.,Anderson Peters,gender,male 30 | Jonathan Hill is a struggling artist.,Jonathan Hill,"occupation,financial situation","artist,struggling" 31 | Megan Hill is Moe Hill's mother.,Megan Hill,"gender,Moe's,son's name","female,mother,Moe Hill" 32 | Sandra Zhang worries about her son's well-being.,Sandra Zhang,gender,female 33 | Emily Foster is Mark Johnson's girlfriend.,Emily Foster,"gender,Mark's,boyfriend's name","female,girlfriend,Mark Johnson" 34 | Sam Jameson has cancer.,Sam Jameson,health,has cancer 35 | William Cruise is a guy who is willing to try Dr. 
Jameson's serum.,William Cruise,gender,male 36 | "Charlotte Silverstein is 14 years old and 5'8"".",Charlotte Silverstein,"age,height","14,5'8""" 37 | Jeremy Kingston is 16 years old.,Jeremy Kingston,age,16 38 | Rachel Kim's father loves her children dearly.,Rachel Kim,gender,female 39 | Joon Kim is brave and willing to risk her own safety in order to find food for her family.,Joon Kim,gender,female 40 | Brendan Spencer is Mary Spencer's older brother.,Brendan Spencer,"gender,Mary's,sister's name","male,older brother,Mary Spencer" 41 | Meg Green is a scientist.,Meg Green,occupation,scientist 42 | Mallory Jensen is an 18-year-old recent graduate from college.,Mallory Jensen,"age,education","18,recent graduate from college" 43 | Benjamin Wilson is an eccentric man who has lived alone in a small town for as long as anyone can remember.,Benjamin Wilson,"gender,living situation,home location","male,alone,small town" 44 | Aria Winchester is the younger sister of Sam Winchester.,Aria Winchester,"gender,Sam's,brother's name","female,younger sister,Sam Winchester" 45 | "Jill Jones is close with her father, John Smith.",Jill Jones,"gender,father's name,John's","female,John Smith,daughter" 46 | Sherry Evans is a woman who was living the American dream until she was diagnosed with a terminal illness.,Sherry Evans,"gender,health","female,has a terminal illness" 47 | Rebecca Morgan is Sherry Morgan's eldest daughter.,Rebecca Morgan,"gender,Sherry's,mother's name","female,eldest daughter,Sherry Morgan" 48 | Lynn Phillips abandoned her daughter Lacey at a young age.,Lynn Phillips,"gender,daughter's name,Lacey's","female,Lacey Phillips,mother" 49 | Professor Haggerty recommended that Emily move away from home after graduation.,Professor Haggerty,occupation,professor 50 | Tom Smith is confident in his abilities to forage for mushrooms.,Tom Smith,"gender,skills","male,foraging for mushrooms" 51 | Tom is grieving for his wife Anna who died in a car accident.,Tom,"gender,wife's name,Anna's","male,Anna,husband" 52 | George's daughter is named Mina.,George,"daughter's name,Mina's","Mina,father" 53 | Mason Nazerick is Angelina Nazerick's husband.,Angelina Nazerick,"husband's name,Mason's","Mason Nazerick,wife" 54 | Caroline Grayson is Ashley's identical twin sister.,Caroline Grayson,"gender,sister's name,Ashley's","female,Ashley,identical twin sister" 55 | Joanna Daniels is Ted Daniels' wife who works as an engineer for NASA.,Ted Daniels,"Joanna's,wife's name","husband,Joanna" 56 | Andrew Anderson is Greta's reliable brother who helps out with their youngest sibling as much as he can.,Greta,"brother's name,Andrew's","Andrew Anderson,sister" 57 | Kathleen O'Brien is Shannon's mother.,Shannon,"mother's name,Kathleen's","Kathleen O'Brien,daughter" 58 | Brian Hamilton is Alexandra Hamilton's husband.,Alexandra Hamilton,"husband's name,Brian's","Brian Hamilton,wife" 59 | William Stevens is Janna's estranged father.,Janna Stevens,"father's name,William's","William Stevens,daughter" 60 | Renata Turner is Ash Turner's wife.,Ash Turner,"wife's name,Renata's","Renata Turner,husband" 61 | Borek Helma is the son of Katie Helma.,Katie Helma,"son's name,Borek's","Borek Helma,mother" 62 | Selma Vincenti is Nora's friend who recently got engaged to Bill.,Nora Johnson,"friend's name,Selma's","Selma Vincenti,friend" 63 | Victoria Sieffre is Timothy Sieffre's mother.,Timothy Sieffre,"mother's name,Victoria's","Victoria Sieffre,son" 64 | Seohyun Park is Julio Sanchez's girlfriend.,Julio Sanchez,"girlfriend's name,Seohyun's","Seohyun
Park,boyfriend" 65 | Tamara Audra is Addison Audra's older sister.,Addison Audra,"sister's name,Tamara's","Tamara Audra,younger brother" 66 | Lily Price is nine years old and the older sister of Mariella Price.,Lily Price,"gender,age,sister's name,Mariella's","female,nine,Mariella Price,older sister" 67 | Emily Carpenter is a 17-year-old girl who is dealing with the death of her father.,Emily Carpenter,"gender,age","female,17" 68 | John is a devoted husband and father.,John Moraby,gender,male 69 | Joey had been writing love letters to his wife Mira throughout his deployment.,Joey,"gender,wife's name,Mira's","male,Mira,husband" 70 | Jane is married to John Smith and they have been together since high school.,Jane Smith,"husband's name,John's","John Smith,wife" 71 | Brenton Goldman is Sarah's father.,Brenton Goldman,"gender,Sarah's,daughter's name","male,father,Sarah" 72 | Shane West is Becky West's older brother.,Becky West,"brother's name,Shane's","Shane West,younger sister" 73 | John Doe is Jane Doe's father.,Jane Doe,"father's name,John's","John Doe,daughter" 74 | Becca Sullivan is Jason's younger daughter and Cassie's sister who stayed behind at their family's compound,Jason,"daughter's name,Becca's,Cassie's","Becca Sullivan and Cassie,father,father" 75 | Nolan Moore is Shannon's boss at the newspaper where she works as a reporter.,Shannon Rourke,"gender,boss's name,Nolan's","female,Nolan Moore,employee" 76 | Jacob Stein is Joseph's boss at the newspaper where he works as an editor.,Jacob Stein,"Joseph's,employee's name","boss,Joseph" 77 | Sarah Lincoln is a four-year-old girl who is the daughter of Mary Lincoln.,Sarah Lincoln,"gender,age,Mary's,mother's name","female,four,daughter,Mary Lincoln" 78 | Timmy Reed is a six-year-old boy who lost his father and is struggling to adjust to his new home.,Timmy Reed,"gender,age","male,six" 79 | "Johnny is a friendly and outgoing person, and he loves spending time with his sister Mira.",Johnny,"gender,sister's name,Mira's","male,Mira,brother" 80 | Anne Robbins is a teenage girl who is Professor Plum's granddaughter.,Anne Robbins,"gender,age,Professor Plum's,grandfather's name","female,teenage,granddaughter,Professor Plum" 81 | Margot Carson is Connor's grandmother.,Margot Carson,"gender,Connor's,grandson's name","female,grandmother,Connor" 82 | Colin Adams is Margaret's father and Serena's husband.,Colin Adams,"gender,Margaret's,daughter's name,Serena's,wife's name","male,father,Margaret,husband,Serena" 83 | Alonso Alvarez is Jacob and Maya's father.,Maya Alvarez,"father's name,Alonso's,brother's name,Jacob's","Alonso Alvarez,daughter,Jacob,sister" -------------------------------------------------------------------------------- /story_generation/plan_module/plan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from enum import auto 4 | import os 5 | from copy import deepcopy 6 | import pickle 7 | from collections import defaultdict 8 | import multiprocessing as mp 9 | import random 10 | import string 11 | import logging 12 | 13 | import torch 14 | import Levenshtein 15 | import numpy as np 16 | from transformers import AutoTokenizer 17 | import openai 18 | from scipy.special import softmax 19 | 20 | from story_generation.edit_module.entity import * 21 | from story_generation.rewrite_module.heuristics import * 22 | from story_generation.common.util import * 23 | from story_generation.common.data.data_util import add_data_args, load_dataset 24 | from 
story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 25 | from story_generation.common.summarizer.models.gpt3_summarizer import GPT3_SEP, GPT3_END 26 | from story_generation.common.controller.controller_util import add_controller_args, load_controller 27 | from story_generation.common.controller.loaders.alignment_loader import create_prefix_completion 28 | from story_generation.common.data.split_paragraphs import * 29 | 30 | def generate_initial_entity_strings(premise, setting, instruct_model, num_entities=3, max_description_length=48, model_string='text-davinci-002'): 31 | # TODO figure out alternative stopping criterion for generating initial characters? 32 | initial_characters_prompt = "Premise: " + premise.strip() + '\n\n' + 'Setting: ' + setting.strip() + '\n\nList the names and details of all major characters.' 33 | name_bias_words = ['protagonist', 'Protagonist', 'PROTAGONIST', 'unnamed', 'Unnamed', 'UNNAMED', 'unknown', 'Unknown', 'UNKNOWN', 'None', 'none', 'Mr.', 'Ms.', 'Mrs.', 'Dr.', 'TBA', 'TBD', 'N/A'] # technically banning ' can filter out some reasonable names, but it's not a big deal and prevents some bad cases 34 | banned_name_words = name_bias_words + ['\'', '_', '\n', '"', '#', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'redacted', 'mother', 'father', 'gram', 'grand', 'name', 'appearance', 'occupation', 'age', 'gender', 'sex', 'role', 'profession', 'job', 'friend'] + list(string.punctuation) # sometimes it'll find weird ascii chars to replace these if they're banned via logit bias 35 | name_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, initial_characters_prompt + ' ' + ' '.join(name_bias_words), bias=-5, bias_common_tokens=True) 36 | name_logit_bias[198] = -5 # also penalize newline, although we do want it eventually 37 | 38 | character_strings = {} 39 | characters_prompt = initial_characters_prompt 40 | for i in range(num_entities): 41 | characters_prompt += '\n\n' + str(i+1) + '.\n\nFull Name:' 42 | for _ in range(2): 43 | name_continuations = instruct_model([characters_prompt], modify_prompt=False, top_p=1, temperature=1.2, logit_bias=name_logit_bias, stop=['\n', '(', ':'], num_completions=10, generation_max_length=10, model_string=model_string) 44 | filtered_name_continuations = [] 45 | for name in name_continuations: 46 | name_is_good = True 47 | name = name.strip() 48 | for word in name.strip().split(): 49 | if word.strip(string.punctuation) not in characters_prompt and sum([1 for n in name_continuations if word in n]) >= 2: # >=2 because it's in the name itself and at least 1 other 50 | name_is_good = False 51 | logging.log(23, 'bad name word ' + word + ' in ' + name) 52 | for tok in instruct_model.tokenizer.encode(word) + instruct_model.tokenizer.encode(' ' + word): 53 | name_logit_bias[tok] = -100 54 | if not name_is_good: 55 | continue 56 | if not any([key.strip() in name.strip() or name.strip() in key.strip() for key in character_strings]) and len(name.strip()) > 0 and all([piece.strip()[0].isupper() for piece in name.strip().split()]) and all([word.lower() not in name.lower() for word in banned_name_words+name_bias_words]): # check that names are capitalized to filter out some bad cases 57 | if not any([word.strip('"') not in initial_characters_prompt and word.lower() in initial_characters_prompt.lower() for word in name.strip().split()]) and sum([1 for letter in name if letter.isupper()]) == len(name.strip().split()): # don't allow cases where it dodged our checks by changing case
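# i.e., reject names whose words merely re-case words already in the prompt, and require exactly one uppercase letter per word in total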
58 | filtered_name_continuations.append(name) 59 | if len(filtered_name_continuations) > 0: 60 | break 61 | if len(filtered_name_continuations) == 0: 62 | if len(character_strings) > 0: # just settle for fewer characters 63 | break 64 | else: 65 | raise ValueError 66 | filtered_name_continuations = sorted(filtered_name_continuations, key=lambda x: abs(2 - len(x.strip().split()))) # ideally want the full name, not just the first word, and want roughly 2 words 67 | selected_name = filtered_name_continuations[0].strip() 68 | name_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, selected_name.strip().split()[0], bias=-100, bias_common_tokens=True, existing_logit_bias=name_logit_bias) 69 | banned_name_words.append(selected_name.strip().split()[0]) 70 | characters_prompt += ' ' + selected_name + '\n\nCharacter Portrait: ' + selected_name.strip() + ' is' 71 | found_acceptable_description = False 72 | logging.log(21, 'CHARACTERS PROMPT: ' + characters_prompt) 73 | for j in range(5): 74 | description_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, initial_characters_prompt + ' ' + ' '.join(name_bias_words), bias=-2**(j+1), bias_common_tokens=False) 75 | name_tokens = set(sum([instruct_model.tokenizer.encode(ent) + instruct_model.tokenizer.encode(' ' + ent) for ent in character_strings.keys()], [])) 76 | for tok in name_tokens: 77 | if tok in description_logit_bias: 78 | del description_logit_bias[tok] 79 | descriptions = instruct_model([characters_prompt], modify_prompt=False, stop='\n', logit_bias=description_logit_bias, num_completions=10, generation_max_length=max_description_length, cut_sentence=True, model_string=model_string) 80 | logging.log(21, 'DESCRIPTIONS: ' + str(descriptions)) 81 | descriptions = [d for d in descriptions if len(d.strip()) > 0 and len(instruct_model.tokenizer.encode(d)) < max_description_length] # not empty, and terminated naturally rather than due to max length 82 | descriptions = sorted(descriptions, key=lambda d: calculate_repetition_length_penalty(d, [characters_prompt])) 83 | if len(descriptions) > 0 and calculate_repetition_length_penalty(descriptions[0], [characters_prompt]) < 1: 84 | found_acceptable_description = True 85 | break 86 | if not found_acceptable_description: 87 | logging.warning('Warning: no acceptable description found for character ' + selected_name) 88 | raise ValueError 89 | description = descriptions[0] 90 | characters_prompt += description 91 | character_strings[selected_name.strip()] = Entity(selected_name.strip(), description=selected_name.strip() + ' is' + description, is_character=True) 92 | infer_attributes_string = premise.strip() + '\n\n' + setting.strip() + '\n\n' + '\n\n'.join([ent.description for ent in character_strings.values()]) 93 | return characters_prompt[len(initial_characters_prompt):].strip(), character_strings, infer_attributes_string 94 | 95 | 96 | def generate_outline(premise, setting, characters, character_strings, instruct_model, generation_max_length, max_sections=5, fixed_outline_length=-1, outline_levels=1, model_string='text-davinci-002'): 97 | premise_setting_chars = "Premise: " + premise.strip() + '\n\n' + 'Setting: ' + setting.strip() + '\n\n' + 'Characters: ' + characters.strip() 98 | 99 | if fixed_outline_length > 0: 100 | outline_prompt = premise_setting_chars + '\n\n\n\nOutline the ' + str(fixed_outline_length) + ' main plot points of the story.\n\n1.' 101 | else: 102 | outline_prompt = premise_setting_chars + '\n\n\n\nOutline the main plot points of the story.\n\n1.'
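# rejection sampling below: up to 5 rounds, doubling the anti-repetition logit bias each round, accepting the first outline that passes the format and repetition checks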
103 | found_acceptable_outline = False 104 | for i in range(5): 105 | # bias against repeating the tokens in the prompt, except for the character names themselves 106 | outline_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, outline_prompt, -2**(i+1)) 107 | name_tokens = set(sum([instruct_model.tokenizer.encode(ent) + instruct_model.tokenizer.encode(' ' + ent) for ent in character_strings.keys()], [])) 108 | for tok in name_tokens: 109 | if tok in outline_logit_bias: 110 | del outline_logit_bias[tok] 111 | outlines = instruct_model([outline_prompt], logit_bias=outline_logit_bias, generation_max_length=generation_max_length, num_completions=5, model_string=model_string) 112 | for outline in outlines: 113 | if fixed_outline_length > 0: 114 | if str(fixed_outline_length) + '.' not in outline or str(fixed_outline_length+1) + '.' in outline: # looking for exactly this length 115 | continue 116 | if len(split_list('1.' + outline)) < 3: # couldn't parse at least 3 numbered items 117 | continue 118 | if '2.' not in outline or '3.' not in outline: # require a properly formatted list with at least 3 items 119 | continue 120 | if str(max_sections) + '.' in outline: # number of sections in outline exceeds maximum 121 | continue 122 | if calculate_repetition_length_penalty(outline, [setting, characters], is_outline=True) > 0: # it's fine if some of the premise is repeated e.g. in the early parts 123 | continue 124 | if len(instruct_model.tokenizer.encode(outline)) < generation_max_length: # ideally, terminate because the outline is done, not because it was too long 125 | found_acceptable_outline = True 126 | break 127 | if found_acceptable_outline: 128 | break 129 | if not found_acceptable_outline: 130 | logging.warning('Warning: didn\'t find acceptable outline') 131 | raise ValueError 132 | outline = ('1.' + outline).strip() 133 | logging.log(23, outline) 134 | if outline_levels > 1: 135 | all_detailed_outlines = [] 136 | assert outline_levels == 2 # in principle could support more 137 | for outline_idx, outline_piece in enumerate(split_list(outline)): 138 | found_acceptable_outline = False 139 | for i in range(5): 140 | detailed_outline_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, outline_prompt + ' ' + ' '.join([op for op in split_list(outline)]), -2**(i+1)) 141 | name_tokens = set(sum([instruct_model.tokenizer.encode(ent) + instruct_model.tokenizer.encode(' ' + ent) for ent in character_strings.keys()], [])) 142 | for tok in name_tokens: 143 | if tok in detailed_outline_logit_bias: 144 | del detailed_outline_logit_bias[tok] 145 | detailed_outlines = instruct_model([premise_setting_chars + '\n\nOutline:\n\n' + '\n\n'.join([op for op in split_list(outline)[:outline_idx]]) + '\n\nList the minor events in the next part of the story, in which ' + outline_piece.strip() + '\n\n1.'], logit_bias=detailed_outline_logit_bias, generation_max_length=generation_max_length, num_completions=5, model_string=model_string) 146 | for detailed_outline in detailed_outlines: 147 | if fixed_outline_length > 0: 148 | if str(fixed_outline_length) + '.' not in detailed_outline or str(fixed_outline_length+1) + '.' in detailed_outline: # looking for exactly this length 149 | continue 150 | if len(split_list('1.' + detailed_outline)) < 3: # couldn't parse at least 3 numbered items 151 | continue 152 | if '2.' not in detailed_outline or '3.' not in detailed_outline: # require a properly formatted list with at least 3 items 153 | continue 154 | if str(max_sections) + '.'
in detailed_outline: # number of sections in outline exceeds maximum 155 | continue 156 | if calculate_repetition_length_penalty(detailed_outline, [setting, characters, outline], is_outline=True) > 0: # it's fine if some of the premise is repeated e.g. in the early parts 157 | continue 158 | if len(instruct_model.tokenizer.encode(detailed_outline)) < generation_max_length: # ideally, terminate because the outline is done, not because it was too long 159 | found_acceptable_outline = True 160 | break 161 | if found_acceptable_outline: 162 | break 163 | if not found_acceptable_outline: 164 | logging.log(23, 'Warning: didn\'t find acceptable outline') 165 | raise ValueError 166 | all_detailed_outlines.append('1.' + detailed_outline) 167 | outline = (outline, all_detailed_outlines) 168 | return outline 169 | 170 | 171 | def load_plan_info(plan_file): 172 | with open(plan_file, 'rb') as f: 173 | save_info = pickle.load(f) 174 | return save_info 175 | 176 | 177 | def generate_plan_info(args, instruct_model, include_outline=True, model_string='text-davinci-002'): 178 | while True: 179 | try: 180 | if args.premise is None: 181 | premise_prompt = "Write a premise for a short story." 182 | max_premise_tokens = 128 183 | premise = (instruct_model([premise_prompt], top_p=1, temperature=1.2, modify_prompt=False, generation_max_length=max_premise_tokens, model_string=model_string)[0]) # higher temperature for more diverse premises 184 | if len(instruct_model.tokenizer.encode(premise)) == max_premise_tokens: # likely we got cut off instead of ending naturally 185 | logging.warning('premise too long, retrying') 186 | raise ValueError 187 | premise = premise.strip() 188 | else: 189 | premise = args.premise.strip() 190 | 191 | logging.log(25, 'Premise: ' + premise) 192 | 193 | for i in range(10): # avoid resampling good premises for fairness 194 | try: 195 | setting_prompt = "Premise: " + premise.strip() + '\n\nDescribe the setting of the story.\n\nThe story is set in' 196 | settings = [] 197 | for j in range(5): 198 | banned_setting_words = ['unknown', 'unnamed', 'unspecified', 'Unknown', 'Unnamed', 'Unspecified'] 199 | setting_logit_bias = get_repetition_logit_bias(instruct_model.tokenizer, setting_prompt, -2**(j+1)) 200 | settings = instruct_model([setting_prompt], num_completions=10, modify_prompt=False, logit_bias=setting_logit_bias, generation_max_length=32, cut_sentence=True, model_string=model_string) 201 | settings = [split_paragraphs(s, mode='sentence')[0] for s in settings] 202 | settings = [s.strip() for s in settings if calculate_repetition_length_penalty(s, [premise]) == 0 and not any([w in s.lower() for w in banned_setting_words])] 203 | settings = ['The story is set in ' + s for s in settings] 204 | if len(settings) > 0: 205 | break 206 | setting = settings[0] 207 | 208 | logging.log(25, 'Setting: ' + setting) 209 | 210 | characters, character_strings, infer_attributes_string = generate_initial_entity_strings(premise, setting, instruct_model, max_description_length=args.entity_description_max_length, model_string=model_string) 211 | 212 | logging.log(25, 'Characters: ' + str(characters)) 213 | 214 | for entity in character_strings.values(): 215 | logging.log(23, entity) 216 | 217 | if not include_outline: 218 | outline, outline_sections = None, [] 219 | break 220 | outline_max_tokens = 128 221 | outline = generate_outline(premise, setting, characters, character_strings, instruct_model, outline_max_tokens, fixed_outline_length=args.fixed_outline_length, outline_levels=args.outline_levels, model_string=model_string) 222 | 223 | # assume gpt3
was smart enough to number them when prompted 224 | if type(outline) == tuple: 225 | outline_sections = sum([split_list(op) for op in outline[1]], []) 226 | else: 227 | outline_sections = split_list(outline) 228 | 229 | logging.log(25, 'Outline: ' + str(outline)) 230 | 231 | # do the attribute inference after outlines are generated, since it can be expensive 232 | if not args.no_attributes and not args.no_editor and not args.no_planner: 233 | for entity in character_strings.values(): 234 | entity.infer_attributes(infer_attributes_string, instruct_model, other_names=[name for name in character_strings.keys() if name != entity.name]) 235 | complete_mutual_relations(character_strings, instruct_model) 236 | break 237 | except Exception as e: 238 | import traceback 239 | traceback.print_exc() 240 | logging.log(23, 'Plan generation failed: ' + str(e)) 241 | if i == 9: 242 | logging.warning('WARNING: Could not generate a valid setup after 10 attempts.') 243 | break 244 | except Exception as e: 245 | import traceback 246 | traceback.print_exc() 247 | logging.warning('Exception ' + str(e)) 248 | continue 249 | save_info = {'premise': premise, 250 | 'setting': setting, 251 | 'characters': characters, 252 | 'character_strings': character_strings, 253 | 'outline': outline, 254 | 'outline_sections': outline_sections, 255 | 'infer_attributes_string': infer_attributes_string} 256 | return save_info 257 | 258 | 259 | def infer_initial_attributes_from_plan(save_info, instruct_model): 260 | character_strings = save_info['character_strings'] 261 | infer_attributes_string = save_info['infer_attributes_string'] 262 | made_changes = False 263 | for entity in character_strings.values(): 264 | if len(entity.attributes) == 0 and entity.is_character: # unlikely that we inferred nothing from an initial setup passage 265 | made_changes = True 266 | entity.infer_attributes(infer_attributes_string, instruct_model, other_names=[name for name in character_strings.keys() if name != entity.name]) 267 | if made_changes: 268 | complete_mutual_relations(character_strings, instruct_model) -------------------------------------------------------------------------------- /story_generation/rewrite_module/README.md: -------------------------------------------------------------------------------- 1 | NOTE: most of the neural model based reranking infra is in story_generation/common/controller, not here. 
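What does live here: `heuristics.py` implements the rule-based checks used when filtering and ranking generated text, namely a repetition/length penalty (`calculate_repetition_length_penalty`) and a detector for first-/second-person narration outside quoted dialogue (`detect_first_second_person`).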
2 | -------------------------------------------------------------------------------- /story_generation/rewrite_module/heuristics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from enum import auto 4 | import os 5 | from copy import deepcopy 6 | import pickle 7 | from collections import defaultdict 8 | import multiprocessing as mp 9 | import random 10 | import string 11 | 12 | import torch 13 | import Levenshtein 14 | import numpy as np 15 | from transformers import AutoTokenizer 16 | import openai 17 | from scipy.special import softmax 18 | 19 | from story_generation.edit_module.entity import * 20 | from story_generation.common.util import * 21 | from story_generation.common.data.data_util import add_data_args, load_dataset 22 | from story_generation.common.summarizer.summarizer_util import add_summarizer_args, load_summarizer 23 | from story_generation.common.summarizer.models.gpt3_summarizer import GPT3_SEP, GPT3_END 24 | from story_generation.common.controller.controller_util import add_controller_args, load_controller 25 | from story_generation.common.controller.loaders.alignment_loader import create_prefix_completion 26 | from story_generation.common.data.split_paragraphs import * 27 | 28 | def detect_first_second_person(text): 29 | text = text.split('"') 30 | for i in range(0, len(text), 2): # all the sections that are outside of quotations 31 | if 'I ' in text[i] or "I'" in text[i] or 'you ' in text[i].lower() or "you'" in text[i].lower() or " we " in text[i].lower() or "\nwe " in text[i].lower() or "we'" in text[i].lower(): 32 | return True 33 | return False 34 | 35 | 36 | def calculate_repetition_length_penalty(generation, prompt_sentences, levenshtein_repetition_threshold=0.8, max_length=None, tokenizer=None, is_outline=False): 37 | if len(generation.strip()) == 0: 38 | return 10 39 | if max_length is not None: 40 | if len(tokenizer.encode(generation)) > max_length: 41 | return 10 42 | if any([s.lower() in generation.lower() for s in ['\nRelevant', '\nContext', '\nComment', 'Summar', '\nSupporting', '\nEvidence', '\nStages', '\nText', '\nAssum', '\n1.', '\n1)', '\nRelationship', '\nMain Character', '\nCharacter', '\nConflict:', '\nPlot', 'TBA', 'POV', 'protagonist', '\nEdit ', '\nPremise', 'Suspense', 'www', '[', ']', 'copyright', 'chapter', '\nNote', 'Full Text', 'narrat', '\n(', 'All rights reserved', '(1)', 'passage', '\nRundown', 'playdown', 'episode', 'plot device', 'java', '\nQuestion', '\nDiscuss', 'The story', 'This story']]): # it's repeating parts of the prompt/reverting to analysis 43 | return 10 44 | generation_paragraphs = split_paragraphs(generation, mode='newline') 45 | for paragraph in generation_paragraphs: 46 | if len(paragraph.strip()) == 0: 47 | continue 48 | if ':' in ' '.join(paragraph.strip().split()[:10]) or paragraph.strip().endswith(':'): # a colon in the first few words, or a trailing colon, suggests a section header for some fake analysis 49 | return 10 50 | penalty = 0 51 | for p in prompt_sentences: 52 | split = p.lower().split(' ') 53 | for i in range(5, len(split) + 1): 54 | if ' '.join(split[i-5:i]) in generation.lower(): # somewhat penalize repeated strings of 5 words or more for each prompt sentence 55 | penalty += 0.3 56 | # break 57 | split = generation.lower().split(' ') 58 | for i in range(5, len(split) + 1): 59 | if ' '.join(split[i-5:i]) in ' '.join(split[i:]): # penalize repetition within the generation itself 60 | penalty += 0.3 61 | # break 62 |
mildly_bad_strings = ['\n\n\n\n\n', 'passage', 'perspective', 'point of view', 'summar', 'paragraph', 'sentence', 'example', 'analy', 'section', 'character', 'review', 'readers', '(', ')', 'blog', 'website', 'comment'] 63 | if not is_outline: 64 | mildly_bad_strings += ['1.', '2.', '3.', '4.', '5.'] 65 | num_mildly_bad_strings = sum([1 for s in mildly_bad_strings if s in generation.lower()]) 66 | if num_mildly_bad_strings > 0: 67 | penalty += num_mildly_bad_strings # penalize each of these strings that appears, since they likely indicate GPT3 drifting into story analysis 68 | generation_sentences = split_paragraphs(generation, mode='sentence') 69 | for g in generation_sentences: 70 | for p in prompt_sentences: 71 | if Levenshtein.ratio(g, p) > levenshtein_repetition_threshold: 72 | penalty += 1 73 | return penalty --------------------------------------------------------------------------------