├── parrot2.png ├── cog.yaml ├── LICENSE ├── predict.py ├── README.md └── .gitignore /parrot2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyrick/cog-prompt-parrot/HEAD/parrot2.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | image: "r8.im/kyrick/prompt-parrot" 5 | build: 6 | gpu: true 7 | cuda: "11.6.2" 8 | python_version: "3.10" 9 | python_packages: 10 | - "transformers==4.21.1" 11 | - "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116" 12 | - "ftfy==6.1.1" 13 | - "scipy==1.9.0" 14 | 15 | predict: "predict.py:Predictor" 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stephen Young 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
import itertools

import torch
from cog import BasePredictor, Input
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): these special-token strings are empty in this copy of the
# file; presumably angle-bracketed markers were stripped somewhere upstream.
# Confirm against the tokens the model in ./model was fine-tuned with.
start_token = ""
pad_token = ""
end_token = ""

# Line placed between the individual generated prompts in the returned string.
_PROMPT_SEPARATOR = "\n------------------------------------------\n"


def _tidy_prompt(text: str) -> str:
    """Flatten *text* to one line and drop immediately repeated words.

    Newlines become spaces and "/" becomes "," (the slash causes problems in
    namings). Adjacent duplicate words are collapsed, e.g.
    "lamma lamma is cool cool" -> "lamma is cool".
    """
    flattened = text.strip().replace("\n", " ").replace("/", ",")
    return " ".join(word for word, _ in itertools.groupby(flattened.split()))


class Predictor(BasePredictor):
    """Cog predictor that extends a base prompt into several text2image prompts."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient.

        The causal LM weights are read from ``./model``. The tokenizer is the
        ``distilgpt2`` tokenizer (cached under ``./model``) configured with the
        module-level start/end/pad tokens.
        """
        # from_pretrained leaves the weights on the CPU; move them to the GPU
        # when one is present so generation actually uses it. predict() sends
        # its inputs to self.model.device, so it follows automatically.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained("./model").to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "distilgpt2",
            cache_dir="./model",
            bos_token=start_token,
            eos_token=end_token,
            pad_token=pad_token,
        )

    def predict(
        self, prompt: str = Input(description="Input prompt", default="a beautiful painting")
    ) -> str:
        """Generate five sampled continuations of *prompt*.

        Args:
            prompt: Base prompt the model should continue. Must be non-empty.

        Returns:
            The cleaned-up generated prompts joined by a dashed separator line.

        Raises:
            UserWarning: If *prompt* is empty.
        """
        num_prompts_to_generate = 5
        # Per the transformers generate() docs, both limits count the prompt's
        # own tokens as well as the newly generated ones.
        max_prompt_length = 50
        min_prompt_length = 30
        temperature = 1.0
        top_k = 70
        top_p = 0.9

        if not prompt:
            # Exception type kept as-is for backward compatibility.
            raise UserWarning("prompt must be at least 1 letter")

        encoded_prompt = self.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        encoded_prompt = encoded_prompt.to(self.model.device)

        output_sequences = self.model.generate(
            input_ids=encoded_prompt,
            max_length=max_prompt_length,
            min_length=min_prompt_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=num_prompts_to_generate,
            pad_token_id=self.tokenizer.pad_token_id,  # gets rid of warning
        )

        # Built once; plain ints in a set give a cheap membership test below.
        start_token_ids = set(self.tokenizer.encode(start_token))

        generated_prompts = []
        for sequence in output_sequences:
            # Cut the sequence at the first start token that appears after
            # position 0, but only once at least min_prompt_length tokens have
            # been kept. An earlier start token is kept here and removed by
            # skip_special_tokens when decoding.
            kept_ids = []
            for position, token_id in enumerate(sequence.tolist()):
                if (
                    position != 0
                    and token_id in start_token_ids
                    and len(kept_ids) >= min_prompt_length
                ):
                    break
                kept_ids.append(token_id)

            decoded = self.tokenizer.decode(
                kept_ids, clean_up_tokenization_spaces=True, skip_special_tokens=True
            )
            generated_prompts.append(_tidy_prompt(decoded))

        return _PROMPT_SEPARATOR.join(generated_prompts)

4 | 5 |

6 | 7 | [Run Prompt Parrot on Replicate!](https://replicate.com/kyrick/prompt-parrot) 8 | 9 | This is the official "fork" of the [Prompt Parrot Colab](https://colab.research.google.com/drive/1GtyVgVCwnDfRvfsHbeU0AlG-SgQn1p8e?usp=sharing) for Replicate. But you can also run it locally! 10 | 11 | # What is Prompt Parrot? 12 | 13 | Prompt Parrot takes a text input and generates text2image prompts you can use with Stable Diffusion, Midjourney, etc. To use the parrot, you provide a base prompt and the parrot completes the rest of the prompt. It's a bit like mad libs for text prompts except more orderly (in theory)! Prompt Parrot is an idea-generating tool. It may present you with interesting styles and fantastic combinations! 14 | 15 | Example usage: input a base prompt such as "a majestic horse in a field" and see what the Parrot comes up with! 16 | 17 | ### What is it? 18 | 19 | Prompt Parrot started life as a colab to finetune your own distilgpt2 model on your own prompts and then generate prompts in your style. This version is still online and you can try it out here: [Prompt Parrot Colab](https://colab.research.google.com/drive/1GtyVgVCwnDfRvfsHbeU0AlG-SgQn1p8e?usp=sharing) 20 | 21 | Eventually, it became clear that providing many prompts was a barrier to entry. It's difficult to pull together a large number of prompts. So the hosted model is distilgpt2 fine-tuned on a corpus of 51,747 community-provided text prompts. The model uses this training data to generate prompts for text2image prompt generators. 22 | 23 | # Online Runs 24 | 25 | [Run Prompt Parrot on Replicate!](https://replicate.com/kyrick/prompt-parrot) 26 | 27 | # Local Runs 28 | 29 | Steps: 30 | 1. Create a Python virtualenv 31 | 1. `pip install cog` 32 | 1. [Download the model from gdrive](https://drive.google.com/drive/folders/1X7PoqEEN8qxV6wvHnHuOWGm7T3Zr2l2w?usp=sharing). 33 | 1. Download `config.json` and `pytorch_model.bin` into folder `model/`. 34 | 1. 
In a terminal, execute: `sudo cog predict -i prompt="a crystal fortress in a field of flowers"` 35 | 36 | # Credits 37 | Lead Developer - [Stephen Young](https://linktr.ee/KyrickYoung) 38 | 39 | Special thanks to members of the AI art community for providing data to train Prompt Parrot! 40 | 41 | # License 42 | The MIT License (MIT) 43 | 44 | Copyright 2022 Stephen Young 45 | 46 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 47 | 48 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 49 | 50 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 
94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | 171 | ### VisualStudioCode ### 172 | .vscode/* 173 | 174 | # Local History for Visual Studio Code 175 | .history/ 176 | 177 | # Built Visual Studio Code Extensions 178 | *.vsix 179 | 180 | ### VisualStudioCode Patch ### 181 | # Ignore all local history of files 182 | .history 183 | .ionide 184 | 185 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python 186 | 187 | model/ 188 | .cog/ --------------------------------------------------------------------------------