├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── deep_daze ├── __init__.py ├── cli.py ├── clip.py ├── data │ └── bpe_simple_vocab_16e6.txt ├── deep_daze.py └── version.py ├── instruction_images └── Windows │ ├── Step_1_DD_Win.png │ └── Step_2_DD_Win.png ├── samples ├── A_man_painting_a_completely_red_image.png ├── A_psychedelic_experience_on_LSD.png ├── A_time_traveler_in_the_crowd.jpg ├── Autumn_1875_Frederic_Edwin_Church.jpg ├── Autumn_1875_Frederic_Edwin_Church_original.jpg ├── Cosmic_love_and_attention.jpg ├── Life_during_the_plague.jpg ├── Meditative_peace_in_a_sunlit_forest.jpg ├── Mist_over_green_hills.jpg ├── Shattered_plates_on_the_grass.jpg ├── cosmic-love.png ├── hot-dog.jpg ├── hot-dog_imagined.png ├── life-plague.png ├── mist-over-green-hills.png ├── peace-sunlit-forest.png ├── prime-orig.jpg ├── prime-trained.png ├── psychedelic_hot_dog.png ├── shattered-plates.png └── time-traveler.png └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDEs 132 | .vscode/ 133 | .idea/ 134 | 135 | output/ 136 | run.py 137 | run.sh 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ryan Murdock, Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include deep_daze *.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deep Daze 2 | 3 | 4 | 5 | *mist over green hills* 6 | 7 | 8 | 9 | *shattered plates on the grass* 10 | 11 | 12 | 13 | *cosmic love and attention* 14 | 15 | 16 | 17 | *a time traveler in the crowd* 18 | 19 | 20 | 21 | *life during the plague* 22 | 23 | 24 | 25 | *meditative peace in a sunlit forest* 26 | 27 | 28 | 29 | *a man painting a completely red image* 30 | 31 | 32 | 33 | *a psychedelic experience on LSD* 34 | 35 | ## What is this? 36 | 37 | Simple command line tool for text to image generation using OpenAI's CLIP and Siren. Credit goes to Ryan Murdock for the discovery of this technique (and for coming up with the great name)! 38 | 39 | Original notebook [![Open In Colab][colab-badge]][colab-notebook] 40 | 41 | New simplified notebook [![Open In Colab][colab-badge]][colab-notebook-2] 42 | 43 | [colab-notebook]: 44 | [colab-notebook-2]: 45 | [colab-badge]: 46 | 47 | This will require that you have an Nvidia GPU or AMD GPU 48 | - Recommended: 16GB VRAM 49 | - Minimum Requirements: 4GB VRAM (Using VERY LOW settings, see usage instructions below) 50 | 51 | ## Install 52 | 53 | ```bash 54 | $ pip install deep-daze 55 | ``` 56 | 57 | ### Windows Install 58 | 59 | 60 | 61 | Presuming Python is installed: 62 | - Open command prompt and navigate to the directory of your current version of Python 63 | ```bash 64 | pip install deep-daze 65 | ``` 66 | 67 | ## Examples 68 | 69 | ```bash 70 | $ imagine "a house in the forest" 71 | ``` 72 | For Windows: 73 | 74 | 75 | 76 | - Open command prompt as administrator 77 | ```bash 78 | imagine "a house in the forest" 79 | ``` 80 | 81 | That's it. 82 | 83 | 84 | If you have enough memory, you can get better quality by adding a `--deeper` flag 85 | 86 | ```bash 87 | $ imagine "shattered plates on the ground" --deeper 88 | ``` 89 | 90 | ### Advanced 91 | 92 | In true deep learning fashion, more layers will yield better results. Default is at `16`, but can be increased to `32` depending on your resources. 93 | 94 | ```bash 95 | $ imagine "stranger in strange lands" --num-layers 32 96 | ``` 97 | 98 | ## Usage 99 | 100 | ### CLI 101 | ```bash 102 | NAME 103 | imagine 104 | 105 | SYNOPSIS 106 | imagine TEXT 107 | 108 | POSITIONAL ARGUMENTS 109 | TEXT 110 | (required) A phrase less than 77 tokens which you would like to visualize. 111 | 112 | FLAGS 113 | --img=IMAGE_PATH 114 | Default: None 115 | Path to png/jpg image or PIL image to optimize on 116 | --encoding=ENCODING 117 | Default: None 118 | User-created custom CLIP encoding. If used, replaces any text or image that was used. 119 | --create_story=CREATE_STORY 120 | Default: False 121 | Creates a story by optimizing each epoch on a new sliding-window of the input words. If this is enabled, much longer texts than 77 tokens can be used. Requires save_progress to visualize the transitions of the story. 122 | --story_start_words=STORY_START_WORDS 123 | Default: 5 124 | Only used if create_story is True. How many words to optimize on for the first epoch. 125 | --story_words_per_epoch=STORY_WORDS_PER_EPOCH 126 | Default: 5 127 | Only used if create_story is True. 
How many words to add to the optimization goal per epoch after the first one.
128 |   --story_separator=STORY_SEPARATOR
129 |     Default: None
130 |     Only used if create_story is True. Defines a separator like '.' that splits the text into groups for each epoch. The separator needs to be in the text, otherwise it will be ignored.
131 |   --lower_bound_cutout=LOWER_BOUND_CUTOUT
132 |     Default: 0.1
133 |     Lower bound of the sampling of the size of the random cut-out of the SIREN image per batch. Should be smaller than 0.8.
134 |   --upper_bound_cutout=UPPER_BOUND_CUTOUT
135 |     Default: 1.0
136 |     Upper bound of the sampling of the size of the random cut-out of the SIREN image per batch. Should probably stay at 1.0.
137 |   --saturate_bound=SATURATE_BOUND
138 |     Default: False
139 |     If True, the LOWER_BOUND_CUTOUT is linearly increased to 0.75 during training.
140 |   --learning_rate=LEARNING_RATE
141 |     Default: 1e-05
142 |     The learning rate of the neural net.
143 |   --num_layers=NUM_LAYERS
144 |     Default: 16
145 |     The number of hidden layers to use in the Siren neural net.
146 |   --batch_size=BATCH_SIZE
147 |     Default: 4
148 |     The number of generated images to pass into Siren before calculating loss. Decreasing this can lower memory and accuracy.
149 |   --gradient_accumulate_every=GRADIENT_ACCUMULATE_EVERY
150 |     Default: 4
151 |     Calculate a weighted loss of n samples for each iteration. Increasing this can help increase accuracy with lower batch sizes.
152 |   --epochs=EPOCHS
153 |     Default: 20
154 |     The number of epochs to run.
155 |   --iterations=ITERATIONS
156 |     Default: 1050
157 |     The number of times to calculate and backpropagate loss in a given epoch.
158 |   --save_every=SAVE_EVERY
159 |     Default: 100
160 |     Generate an image every time iterations is a multiple of this number.
161 |   --image_width=IMAGE_WIDTH
162 |     Default: 512
163 |     The desired resolution of the image.
164 |   --deeper=DEEPER
165 |     Default: False
166 |     Uses a Siren neural net with 32 hidden layers.
167 |   --overwrite=OVERWRITE
168 |     Default: False
169 |     Whether or not to overwrite existing generated images of the same name.
170 |   --save_progress=SAVE_PROGRESS
171 |     Default: False
172 |     Whether or not to save images generated before training Siren is complete.
173 |   --seed=SEED
174 |     Type: Optional[]
175 |     Default: None
176 |     A seed to be used for deterministic runs.
177 |   --open_folder=OPEN_FOLDER
178 |     Default: True
179 |     Whether or not to open a folder showing your generated images.
180 |   --save_date_time=SAVE_DATE_TIME
181 |     Default: False
182 |     Save files with a timestamp prepended, e.g. `%y%m%d-%H%M%S-my_phrase_here`
183 |   --start_image_path=START_IMAGE_PATH
184 |     Default: None
185 |     The generator is trained first on a starting image before being steered towards the textual input.
186 |   --start_image_train_iters=START_IMAGE_TRAIN_ITERS
187 |     Default: 50
188 |     The number of steps for the initial training on the starting image.
189 |   --theta_initial=THETA_INITIAL
190 |     Default: 30.0
191 |     Hyperparameter describing the frequency of the color space. Only applies to the first layer of the network.
192 |   --theta_hidden=THETA_HIDDEN
193 |     Default: 30.0
194 |     Hyperparameter describing the frequency of the color space. Only applies to the hidden layers of the network.
195 |   --save_gif=SAVE_GIF
196 |     Default: False
197 |     Whether or not to save a GIF animation of the generation procedure. Only works if save_progress is set to True.
198 | ``` 199 | 200 | ### Priming 201 | 202 | Technique first devised and shared by Mario Klingemann, it allows you to prime the generator network with a starting image, before being steered towards the text. 203 | 204 | Simply specify the path to the image you wish to use, and optionally the number of initial training steps. 205 | 206 | ```bash 207 | $ imagine 'a clear night sky filled with stars' --start_image_path ./cloudy-night-sky.jpg 208 | ``` 209 | 210 | Primed starting image 211 | 212 | 213 | 214 | Then trained with the prompt `A pizza with green pepper.` 215 | 216 | 217 | 218 | 219 | ### Optimize for the interpretation of an image 220 | 221 | We can also feed in an image as an optimization goal, instead of only priming the generator network. Deepdaze will then render its own interpretation of that image: 222 | ```bash 223 | $ imagine --img samples/Autumn_1875_Frederic_Edwin_Church.jpg 224 | ``` 225 | Original image: 226 | 227 | 228 | 229 | The network's interpretation: 230 | 231 | 232 | 233 | Original image: 234 | 235 | 236 | 237 | The network's interpretation: 238 | 239 | 240 | 241 | #### Optimize for text and image combined 242 | 243 | ```bash 244 | $ imagine "A psychedelic experience." --img samples/hot-dog.jpg 245 | ``` 246 | The network's interpretation: 247 | 248 | 249 | 250 | ### New: Create a story 251 | The regular mode for texts only allows 77 tokens. If you want to visualize a full story/paragraph/song/poem, set `create_story` to `True`. 252 | 253 | Given the poem “Stopping by Woods On a Snowy Evening” by Robert Frost - 254 | "Whose woods these are I think I know. His house is in the village though; He will not see me stopping here To watch his woods fill up with snow. My little horse must think it queer To stop without a farmhouse near Between the woods and frozen lake The darkest evening of the year. He gives his harness bells a shake To ask if there is some mistake. The only other sound’s the sweep Of easy wind and downy flake. The woods are lovely, dark and deep, But I have promises to keep, And miles to go before I sleep, And miles to go before I sleep.". 255 | 256 | We get: 257 | 258 | https://user-images.githubusercontent.com/19983153/109539633-d671ef80-7ac1-11eb-8d8c-380332d7c868.mp4 259 | 260 | 261 | 262 | ### Python 263 | #### Invoke `deep_daze.Imagine` in Python 264 | ```python 265 | from deep_daze import Imagine 266 | 267 | imagine = Imagine( 268 | text = 'cosmic love and attention', 269 | num_layers = 24, 270 | ) 271 | imagine() 272 | ``` 273 | 274 | #### Save progress every fourth iteration 275 | Save images in the format insert_text_here.00001.png, insert_text_here.00002.png, ...up to `(total_iterations % save_every)` 276 | ```python 277 | imagine = Imagine( 278 | text=text, 279 | save_every=4, 280 | save_progress=True 281 | ) 282 | ``` 283 | 284 | #### Prepend current timestamp on each image. 285 | Creates files with both the timestamp and the sequence number. 286 | 287 | e.g. 210129-043928_328751_insert_text_here.00001.png, 210129-043928_512351_insert_text_here.00002.png, ... 288 | ```python 289 | imagine = Imagine( 290 | text=text, 291 | save_every=4, 292 | save_progress=True, 293 | save_date_time=True, 294 | ) 295 | ``` 296 | 297 | #### High GPU memory usage 298 | If you have at least 16 GiB of vram available, you should be able to run these settings with some wiggle room. 
299 | ```python 300 | imagine = Imagine( 301 | text=text, 302 | num_layers=42, 303 | batch_size=64, 304 | gradient_accumulate_every=1, 305 | ) 306 | ``` 307 | 308 | #### Average GPU memory usage 309 | ```python 310 | imagine = Imagine( 311 | text=text, 312 | num_layers=24, 313 | batch_size=16, 314 | gradient_accumulate_every=2 315 | ) 316 | ``` 317 | 318 | #### Very low GPU memory usage (less than 4 GiB) 319 | If you are desperate to run this on a card with less than 8 GiB vram, you can lower the image_width. 320 | ```python 321 | imagine = Imagine( 322 | text=text, 323 | image_width=256, 324 | num_layers=16, 325 | batch_size=1, 326 | gradient_accumulate_every=16 # Increase gradient_accumulate_every to correct for loss in low batch sizes 327 | ) 328 | ``` 329 | 330 | ### VRAM and speed benchmarks: 331 | These experiments were conducted with a 2060 Super RTX and a 3700X Ryzen 5. We first mention the parameters (bs = batch size), then the memory usage and in some cases the training iterations per second: 332 | 333 | For an image resolution of 512: 334 | * bs 1, num_layers 22: 7.96 GB 335 | * bs 2, num_layers 20: 7.5 GB 336 | * bs 16, num_layers 16: 6.5 GB 337 | 338 | For an image resolution of 256: 339 | * bs 8, num_layers 48: 5.3 GB 340 | * bs 16, num_layers 48: 5.46 GB - 2.0 it/s 341 | * bs 32, num_layers 48: 5.92 GB - 1.67 it/s 342 | * bs 8, num_layers 44: 5 GB - 2.39 it/s 343 | * bs 32, num_layers 44, grad_acc 1: 5.62 GB - 4.83 it/s 344 | * bs 96, num_layers 44, grad_acc 1: 7.51 GB - 2.77 it/s 345 | * bs 32, num_layers 66, grad_acc 1: 7.09 GB - 3.7 it/s 346 | 347 | @NotNANtoN recommends a batch size of 32 with 44 layers and training 1-8 epochs. 348 | 349 | 350 | ## Where is this going? 351 | 352 | This is just a teaser. We will be able to generate images, sound, anything at will, with natural language. The holodeck is about to become real in our lifetimes. 353 | 354 | Please join replication efforts for DALL-E for Pytorch or Mesh Tensorflow if you are interested in furthering this technology. 355 | 356 | ## Alternatives 357 | 358 | Big Sleep - CLIP and the generator from Big GAN 359 | 360 | ## Citations 361 | 362 | ```bibtex 363 | @misc{unpublished2021clip, 364 | title = {CLIP: Connecting Text and Images}, 365 | author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal}, 366 | year = {2021} 367 | } 368 | ``` 369 | 370 | ```bibtex 371 | @misc{sitzmann2020implicit, 372 | title = {Implicit Neural Representations with Periodic Activation Functions}, 373 | author = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. 
Lindell and Gordon Wetzstein}, 374 | year = {2020}, 375 | eprint = {2006.09661}, 376 | archivePrefix = {arXiv}, 377 | primaryClass = {cs.CV} 378 | } 379 | ``` 380 | 381 | [colab-notebook]: 382 | -------------------------------------------------------------------------------- /deep_daze/__init__.py: -------------------------------------------------------------------------------- 1 | from deep_daze.deep_daze import DeepDaze, Imagine 2 | -------------------------------------------------------------------------------- /deep_daze/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | 5 | from deep_daze import Imagine 6 | 7 | 8 | def train( 9 | text=None, 10 | img=None, 11 | learning_rate=1e-5, 12 | num_layers=16, 13 | hidden_size=256, 14 | batch_size=4, 15 | gradient_accumulate_every=4, 16 | epochs=20, 17 | iterations=1050, 18 | save_every=100, 19 | image_width=512, 20 | deeper=False, 21 | overwrite=False, 22 | save_progress=True, 23 | seed=None, 24 | open_folder=True, 25 | save_date_time=False, 26 | start_image_path=None, 27 | start_image_train_iters=50, 28 | theta_initial=None, 29 | theta_hidden=None, 30 | start_image_lr=3e-4, 31 | lower_bound_cutout=0.1, 32 | upper_bound_cutout=1.0, 33 | saturate_bound=False, 34 | create_story=False, 35 | story_start_words=5, 36 | story_words_per_epoch=5, 37 | story_separator=None, 38 | averaging_weight=0.3, 39 | gauss_sampling=False, 40 | gauss_mean=0.6, 41 | gauss_std=0.2, 42 | do_cutout=True, 43 | center_bias=False, 44 | center_focus=2, 45 | jit=True, 46 | save_gif=False, 47 | save_video=False, 48 | model_name="ViT-B/32", 49 | optimizer="AdamP" 50 | ): 51 | """ 52 | :param text: (required) A phrase less than 77 tokens which you would like to visualize. 53 | :param img: The path to a jpg or png image which you would like to imagine. Can be combined with text. 54 | :param learning_rate: The learning rate of the neural net. 55 | :param hidden_size: The hidden layer size of the Siren net. 56 | :param num_layers: The number of hidden layers to use in the Siren neural net. 57 | :param batch_size: The number of generated images to pass into Siren before calculating loss. Decreasing this can lower memory and accuracy. 58 | :param gradient_accumulate_every: Calculate a weighted loss of n samples for each iteration. Increasing this can help increase accuracy with lower batch sizes. 59 | :param epochs: The number of epochs to run. 60 | :param iterations: The number of times to calculate and backpropagate loss in a given epoch. 61 | :param save_progress: Whether or not to save images generated before training Siren is complete. 62 | :param save_every: Generate an image every time iterations is a multiple of this number. 63 | :param open_folder: Whether or not to open a folder showing your generated images. 64 | :param overwrite: Whether or not to overwrite existing generated images of the same name. 65 | :param deeper: Uses a Siren neural net with 32 hidden layers. 66 | :param image_width: The desired resolution of the image. 67 | :param seed: A seed to be used for deterministic runs. 68 | :param save_date_time: Save files with a timestamp prepended e.g. `%y%m%d-%H%M%S-my_phrase_here.png` 69 | :param start_image_path: Path to the image you would like to prime the generator with initially 70 | :param start_image_train_iters: Number of iterations for priming, defaults to 50 71 | :param theta_initial: Hyperparameter describing the frequency of the color space. 
Only applies to the first layer of the network. 72 | :param theta_hidden: Hyperparameter describing the frequency of the color space. Only applies to the hidden layers of the network. 73 | :param start_image_lr: Learning rate for the start image training. 74 | :param upper_bound_cutout: The upper bound for the cutouts used in generation. 75 | :param lower_bound_cutout: The lower bound for the cutouts used in generation. 76 | :param saturate_bound: If True, the LOWER_BOUND_CUTOUT is linearly increased to 0.75 during training. 77 | :param create_story: Creates a story by optimizing each epoch on a new sliding-window of the input words. If this is enabled, much longer texts than 77 tokens can be used. Requires save_progress to visualize the transitions of the story. 78 | :param story_start_words: Only used if create_story is True. How many words to optimize on for the first epoch. 79 | :param story_words_per_epoch: Only used if create_story is True. How many words to add to the optimization goal per epoch after the first one. 80 | :param story_separator: Only used if create_story is True. Defines a separator like '.' that splits the text into groups for each epoch. Separator needs to be in the text otherwise it will be ignored! 81 | :param averaging_weight: How much to weigh the averaged features of the random cutouts over the individual random cutouts. Increasing this value leads to more details being represented at the cost of some global coherence and a parcellation into smaller scenes. 82 | :param gauss_sampling: Whether to use sampling from a Gaussian distribution instead of a uniform distribution. 83 | :param gauss_mean: The mean of the Gaussian sampling distribution. 84 | :param gauss_std: The standard deviation of the Gaussian sampling distribution. 85 | :param do_cutouts: Whether to use random cutouts as an augmentation. This basically needs to be turned on unless some new augmentations are added in code eventually. 86 | :param center_bias: Whether to use a Gaussian distribution centered around the center of the image to sample the locations of random cutouts instead of a uniform distribution. Leads to the main generated objects to be more focused in the center. 87 | :param center_focus: How much to focus on the center if using center_bias. std = sampling_range / center_focus. High values lead to a very correct representation in the center but washed out colors and details towards the edges, 88 | :param jit: Whether to use the jit-compiled CLIP model. The jit model is faster, but only compatible with torch version 1.7.1. 89 | :param save_gif: Only used if save_progress is True. Saves a GIF animation of the generation procedure using the saved frames. 90 | :param save_video: Only used if save_progress is True. Saves a MP4 animation of the generation procedure using the saved frames. 91 | """ 92 | # Don't instantiate imagine if the user just wants help. 
93 | if any("--help" in arg for arg in sys.argv): 94 | print("Type `imagine --help` for usage info.") 95 | sys.exit() 96 | 97 | num_layers = 32 if deeper else num_layers 98 | 99 | imagine = Imagine( 100 | text=text, 101 | img=img, 102 | lr=learning_rate, 103 | num_layers=num_layers, 104 | batch_size=batch_size, 105 | gradient_accumulate_every=gradient_accumulate_every, 106 | epochs=epochs, 107 | iterations=iterations, 108 | image_width=image_width, 109 | save_every=save_every, 110 | save_progress=save_progress, 111 | seed=seed, 112 | open_folder=open_folder, 113 | save_date_time=save_date_time, 114 | start_image_path=start_image_path, 115 | start_image_train_iters=start_image_train_iters, 116 | theta_initial=theta_initial, 117 | theta_hidden=theta_hidden, 118 | start_image_lr=start_image_lr, 119 | lower_bound_cutout=lower_bound_cutout, 120 | upper_bound_cutout=upper_bound_cutout, 121 | saturate_bound=saturate_bound, 122 | create_story=create_story, 123 | story_start_words=story_start_words, 124 | story_words_per_epoch=story_words_per_epoch, 125 | story_separator=story_separator, 126 | averaging_weight=averaging_weight, 127 | gauss_sampling=gauss_sampling, 128 | gauss_mean=gauss_mean, 129 | gauss_std=gauss_std, 130 | do_cutout=do_cutout, 131 | center_bias=center_bias, 132 | center_focus=center_focus, 133 | jit=jit, 134 | hidden_size=hidden_size, 135 | model_name=model_name, 136 | optimizer=optimizer, 137 | save_gif=save_gif, 138 | save_video=save_video, 139 | ) 140 | 141 | print('Starting up...') 142 | if not overwrite and imagine.filename.exists(): 143 | answer = input('Imagined image already exists, do you want to overwrite? (y/n) ').lower() 144 | if answer not in ('yes', 'y'): 145 | sys.exit() 146 | 147 | imagine() 148 | 149 | 150 | def main(): 151 | fire.Fire(train) 152 | -------------------------------------------------------------------------------- /deep_daze/clip.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from pathlib import Path 8 | 9 | import hashlib 10 | import os 11 | import urllib 12 | import warnings 13 | from typing import Union, List 14 | from torchvision.transforms import Compose, Normalize 15 | from tqdm import tqdm 16 | 17 | _MODELS = { 18 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 19 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 20 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 21 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 22 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" 23 | } 24 | 25 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 26 | os.makedirs(root, exist_ok=True) 27 | filename = os.path.basename(url) 28 | 29 | expected_sha256 = url.split("/")[-2] 30 | download_target = os.path.join(root, filename) 31 | 32 | if os.path.exists(download_target) and not os.path.isfile(download_target): 33 | raise RuntimeError(f"{download_target} exists and is not a regular file") 34 | 35 | if 
os.path.isfile(download_target): 36 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 37 | return download_target 38 | else: 39 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 40 | 41 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 42 | with tqdm( 43 | total=int(source.info().get("Content-Length")), 44 | unit='iB', 45 | unit_scale=True, 46 | desc=f"Downloading {filename}", 47 | ) as loop: 48 | while True: 49 | buffer = source.read(524288) 50 | if not buffer: 51 | break 52 | 53 | output.write(buffer) 54 | loop.update(len(buffer)) 55 | 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 57 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 58 | 59 | return download_target 60 | 61 | 62 | def _transform(): 63 | return Compose([ 64 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 65 | ]) 66 | 67 | 68 | def available_models() -> List[str]: 69 | """Returns the names of available CLIP models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 74 | """Load a CLIP model 75 | 76 | Parameters 77 | ---------- 78 | name : str 79 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 80 | 81 | device : Union[str, torch.device] 82 | The device to put the loaded model 83 | 84 | jit : bool 85 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 86 | 87 | Returns 88 | ------- 89 | model : torch.nn.Module 90 | The CLIP model 91 | 92 | preprocess : Callable[[PIL.Image], torch.Tensor] 93 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 94 | """ 95 | if name in _MODELS: 96 | model_path = _download(_MODELS[name]) 97 | elif os.path.isfile(name): 98 | model_path = name 99 | else: 100 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 101 | 102 | try: 103 | # loading JIT archive 104 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 105 | state_dict = None 106 | except RuntimeError: 107 | # loading saved state dict 108 | if jit: 109 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 110 | jit = False 111 | state_dict = torch.load(model_path, map_location="cpu") 112 | 113 | if not jit: 114 | model = build_model(state_dict or model.state_dict()).to(device) 115 | if str(device) == "cpu": 116 | model.float() 117 | return model, _transform() 118 | 119 | # patch the device names 120 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 121 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 122 | 123 | def patch_device(module): 124 | graphs = [module.graph] if hasattr(module, "graph") else [] 125 | if hasattr(module, "forward1"): 126 | graphs.append(module.forward1.graph) 127 | 128 | for graph in graphs: 129 | for node in graph.findAllNodes("prim::Constant"): 130 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 131 | node.copyAttributes(device_node) 132 | 133 | model.apply(patch_device) 134 | patch_device(model.encode_image) 135 | patch_device(model.encode_text) 136 | 137 | # patch dtype to float32 on CPU 138 | if str(device) == "cpu": 139 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 140 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 141 | float_node = float_input.node() 142 | 143 | def patch_float(module): 144 | graphs = [module.graph] if hasattr(module, "graph") else [] 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("aten::to"): 150 | inputs = list(node.inputs()) 151 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 152 | if inputs[i].node()["value"] == 5: 153 | inputs[i].node().copyAttributes(float_node) 154 | 155 | model.apply(patch_float) 156 | patch_float(model.encode_image) 157 | patch_float(model.encode_text) 158 | 159 | model.float() 160 | 161 | return model, _transform() 162 | 163 | 164 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 165 | """ 166 | Returns the tokenized representation of given input string(s) 167 | 168 | Parameters 169 | ---------- 170 | texts : Union[str, List[str]] 171 | An input string or a list of input strings to tokenize 172 | 173 | context_length : int 174 | The context length to use; all CLIP models use 77 as the context length 175 | 176 | Returns 177 | ------- 178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 179 | """ 180 | if isinstance(texts, str): 181 | texts = [texts] 182 | 183 | sot_token = _tokenizer.encoder["<|startoftext|>"] 184 | eot_token = _tokenizer.encoder["<|endoftext|>"] 185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 187 | 188 | for i, tokens in enumerate(all_tokens): 189 | if len(tokens) > context_length: 190 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 191 | result[i, :len(tokens)] = torch.tensor(tokens) 192 | 193 | return result 194 | 195 | class Bottleneck(nn.Module): 196 | expansion = 4 197 | 198 | def __init__(self, inplanes, planes, stride=1): 199 | super().__init__() 200 | 201 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 202 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 203 | self.bn1 = nn.BatchNorm2d(planes) 204 | 205 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 206 | self.bn2 = nn.BatchNorm2d(planes) 207 | 208 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 209 | 210 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 211 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 212 | 213 | self.relu = nn.ReLU(inplace=True) 214 | self.downsample = None 215 | self.stride = stride 216 | 217 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 218 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 219 | self.downsample = nn.Sequential(OrderedDict([ 220 | ("-1", nn.AvgPool2d(stride)), 221 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 222 | ("1", nn.BatchNorm2d(planes * self.expansion)) 223 | ])) 224 | 225 | def forward(self, x: torch.Tensor): 226 | identity = x 227 | 228 | out = self.relu(self.bn1(self.conv1(x))) 229 | out = self.relu(self.bn2(self.conv2(out))) 230 | out = self.avgpool(out) 231 | out = self.bn3(self.conv3(out)) 232 | 233 | if self.downsample is not None: 234 | identity = self.downsample(x) 235 | 236 | out += identity 237 | out = self.relu(out) 238 | return out 239 | 240 | 241 | class AttentionPool2d(nn.Module): 242 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 243 | super().__init__() 244 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 245 | self.k_proj = nn.Linear(embed_dim, embed_dim) 246 | self.q_proj = nn.Linear(embed_dim, embed_dim) 247 | self.v_proj = nn.Linear(embed_dim, embed_dim) 248 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 249 | self.num_heads = num_heads 250 | 251 | def forward(self, x): 252 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 253 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 254 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 255 | x, _ = F.multi_head_attention_forward( 256 | query=x, key=x, value=x, 257 | embed_dim_to_check=x.shape[-1], 258 | num_heads=self.num_heads, 259 | q_proj_weight=self.q_proj.weight, 260 | k_proj_weight=self.k_proj.weight, 261 | v_proj_weight=self.v_proj.weight, 262 | in_proj_weight=None, 263 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 264 | bias_k=None, 265 | bias_v=None, 266 | add_zero_attn=False, 267 | dropout_p=0, 268 | out_proj_weight=self.c_proj.weight, 269 | out_proj_bias=self.c_proj.bias, 270 | use_separate_proj_weight=True, 271 | training=self.training, 272 | need_weights=False 273 | ) 274 | 275 | return x[0] 276 | 277 | 278 | class ModifiedResNet(nn.Module): 279 | """ 280 | A ResNet class that is similar to torchvision's but contains the following changes: 281 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
282 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 283 | - The final pooling layer is a QKV attention instead of an average pool 284 | """ 285 | 286 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 287 | super().__init__() 288 | self.output_dim = output_dim 289 | self.input_resolution = input_resolution 290 | 291 | # the 3-layer stem 292 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 293 | self.bn1 = nn.BatchNorm2d(width // 2) 294 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 295 | self.bn2 = nn.BatchNorm2d(width // 2) 296 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 297 | self.bn3 = nn.BatchNorm2d(width) 298 | self.avgpool = nn.AvgPool2d(2) 299 | self.relu = nn.ReLU(inplace=True) 300 | 301 | # residual layers 302 | self._inplanes = width # this is a *mutable* variable used during construction 303 | self.layer1 = self._make_layer(width, layers[0]) 304 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 305 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 306 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 307 | 308 | embed_dim = width * 32 # the ResNet feature dimension 309 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 310 | 311 | def _make_layer(self, planes, blocks, stride=1): 312 | layers = [Bottleneck(self._inplanes, planes, stride)] 313 | 314 | self._inplanes = planes * Bottleneck.expansion 315 | for _ in range(1, blocks): 316 | layers.append(Bottleneck(self._inplanes, planes)) 317 | 318 | return nn.Sequential(*layers) 319 | 320 | def forward(self, x): 321 | def stem(x): 322 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 323 | x = self.relu(bn(conv(x))) 324 | x = self.avgpool(x) 325 | return x 326 | 327 | x = x.type(self.conv1.weight.dtype) 328 | x = stem(x) 329 | x = self.layer1(x) 330 | x = self.layer2(x) 331 | x = self.layer3(x) 332 | x = self.layer4(x) 333 | x = self.attnpool(x) 334 | 335 | return x 336 | 337 | 338 | class LayerNorm(nn.LayerNorm): 339 | """Subclass torch's LayerNorm to handle fp16.""" 340 | 341 | def forward(self, x: torch.Tensor): 342 | orig_type = x.dtype 343 | ret = super().forward(x.type(torch.float32)) 344 | return ret.type(orig_type) 345 | 346 | 347 | class QuickGELU(nn.Module): 348 | def forward(self, x: torch.Tensor): 349 | return x * torch.sigmoid(1.702 * x) 350 | 351 | 352 | class ResidualAttentionBlock(nn.Module): 353 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 354 | super().__init__() 355 | 356 | self.attn = nn.MultiheadAttention(d_model, n_head) 357 | self.ln_1 = LayerNorm(d_model) 358 | self.mlp = nn.Sequential(OrderedDict([ 359 | ("c_fc", nn.Linear(d_model, d_model * 4)), 360 | ("gelu", QuickGELU()), 361 | ("c_proj", nn.Linear(d_model * 4, d_model)) 362 | ])) 363 | self.ln_2 = LayerNorm(d_model) 364 | self.attn_mask = attn_mask 365 | 366 | def attention(self, x: torch.Tensor): 367 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 368 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 369 | 370 | def forward(self, x: torch.Tensor): 371 | x = x + self.attention(self.ln_1(x)) 372 | x = x + self.mlp(self.ln_2(x)) 373 | return x 374 | 375 | 376 | class Transformer(nn.Module): 377 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 378 | super().__init__() 379 | self.width = width 380 | self.layers = layers 381 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 382 | 383 | def forward(self, x: torch.Tensor): 384 | return self.resblocks(x) 385 | 386 | 387 | class VisualTransformer(nn.Module): 388 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 389 | super().__init__() 390 | self.input_resolution = input_resolution 391 | self.output_dim = output_dim 392 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 393 | 394 | scale = width ** -0.5 395 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 396 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 397 | self.ln_pre = LayerNorm(width) 398 | 399 | self.transformer = Transformer(width, layers, heads) 400 | 401 | self.ln_post = LayerNorm(width) 402 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 403 | 404 | def forward(self, x: torch.Tensor): 405 | x = self.conv1(x) # shape = [*, width, grid, grid] 406 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 407 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 408 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 409 | x = x + self.positional_embedding.to(x.dtype) 410 | x = self.ln_pre(x) 411 | 412 | x = x.permute(1, 0, 2) # NLD -> LND 413 | x = self.transformer(x) 414 | x = x.permute(1, 0, 2) # LND -> NLD 415 | 416 | x = self.ln_post(x[:, 0, :]) 417 | 418 | if self.proj is not None: 419 | x = x @ self.proj 420 | 421 | return x 422 | 423 | 424 | class CLIP(nn.Module): 425 | def __init__(self, 426 | embed_dim: int, 427 | # vision 428 | image_resolution: int, 429 | vision_layers: Union[Tuple[int, int, int, int], int], 430 | vision_width: int, 431 | vision_patch_size: int, 432 | # text 433 | context_length: int, 434 | vocab_size: int, 435 | transformer_width: int, 436 | transformer_heads: int, 437 | transformer_layers: int 438 | ): 439 | super().__init__() 440 | 441 | self.context_length = context_length 442 | 443 | if isinstance(vision_layers, (tuple, list)): 444 | vision_heads = vision_width * 32 // 64 445 | self.visual = ModifiedResNet( 446 | layers=vision_layers, 447 | output_dim=embed_dim, 448 | heads=vision_heads, 449 | input_resolution=image_resolution, 450 | width=vision_width 451 | ) 452 | else: 453 | vision_heads = vision_width // 64 454 | self.visual = VisualTransformer( 455 | input_resolution=image_resolution, 456 | patch_size=vision_patch_size, 457 | width=vision_width, 458 | layers=vision_layers, 459 | heads=vision_heads, 460 | output_dim=embed_dim 461 | ) 462 | 463 | self.transformer = Transformer( 464 | width=transformer_width, 465 | layers=transformer_layers, 466 | heads=transformer_heads, 467 | attn_mask=self.build_attention_mask() 468 | ) 469 | 470 | self.vocab_size = vocab_size 471 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 472 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 473 | self.ln_final = LayerNorm(transformer_width) 474 | 475 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 476 | 
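# Learnable temperature for the image/text similarity logits; CLIP.forward() below
# applies .exp() to this scalar before scaling the cosine similarities.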
self.logit_scale = nn.Parameter(torch.ones([])) 477 | 478 | self.initialize_parameters() 479 | 480 | def initialize_parameters(self): 481 | nn.init.normal_(self.token_embedding.weight, std=0.02) 482 | nn.init.normal_(self.positional_embedding, std=0.01) 483 | 484 | if isinstance(self.visual, ModifiedResNet): 485 | if self.visual.attnpool is not None: 486 | std = self.visual.attnpool.c_proj.in_features ** -0.5 487 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 488 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 489 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 490 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 491 | 492 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 493 | for name, param in resnet_block.named_parameters(): 494 | if name.endswith("bn3.weight"): 495 | nn.init.zeros_(param) 496 | 497 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 498 | attn_std = self.transformer.width ** -0.5 499 | fc_std = (2 * self.transformer.width) ** -0.5 500 | for block in self.transformer.resblocks: 501 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 502 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 503 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 504 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 505 | 506 | if self.text_projection is not None: 507 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 508 | 509 | def build_attention_mask(self): 510 | # lazily create causal attention mask, with full attention between the vision tokens 511 | # pytorch uses additive attention mask; fill with -inf 512 | mask = torch.empty(self.context_length, self.context_length) 513 | mask.fill_(float("-inf")) 514 | mask.triu_(1) # zero out the lower diagonal 515 | return mask 516 | 517 | @property 518 | def dtype(self): 519 | return self.visual.conv1.weight.dtype 520 | 521 | def encode_image(self, image): 522 | return self.visual(image.type(self.dtype)) 523 | 524 | def encode_text(self, text): 525 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 526 | 527 | x = x + self.positional_embedding.type(self.dtype) 528 | x = x.permute(1, 0, 2) # NLD -> LND 529 | x = self.transformer(x) 530 | x = x.permute(1, 0, 2) # LND -> NLD 531 | x = self.ln_final(x).type(self.dtype) 532 | 533 | # x.shape = [batch_size, n_ctx, transformer.width] 534 | # take features from the eot embedding (eot_token is the highest number in each sequence) 535 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 536 | 537 | return x 538 | 539 | def forward(self, image, text): 540 | image_features = self.encode_image(image) 541 | text_features = self.encode_text(text) 542 | 543 | # normalized features 544 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 545 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 546 | 547 | # cosine similarity as logits 548 | logit_scale = self.logit_scale.exp() 549 | logits_per_image = logit_scale * image_features @ text_features.t() 550 | logits_per_text = logit_scale * text_features @ image_features.t() 551 | 552 | # shape = [global_batch_size, global_batch_size] 553 | return logits_per_image, logits_per_text 554 | 555 | 556 | def convert_weights(model: nn.Module): 557 | """Convert applicable model parameters to fp16""" 558 | 559 | def _convert_weights_to_fp16(l): 560 | if isinstance(l, 
(nn.Conv1d, nn.Conv2d, nn.Linear)): 561 | l.weight.data = l.weight.data.half() 562 | if l.bias is not None: 563 | l.bias.data = l.bias.data.half() 564 | 565 | if isinstance(l, nn.MultiheadAttention): 566 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 567 | tensor = getattr(l, attr) 568 | if tensor is not None: 569 | tensor.data = tensor.data.half() 570 | 571 | for name in ["text_projection", "proj"]: 572 | if hasattr(l, name): 573 | attr = getattr(l, name) 574 | if attr is not None: 575 | attr.data = attr.data.half() 576 | 577 | model.apply(_convert_weights_to_fp16) 578 | 579 | 580 | def build_model(state_dict: dict): 581 | vit = "visual.proj" in state_dict 582 | 583 | if vit: 584 | vision_width = state_dict["visual.conv1.weight"].shape[0] 585 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 586 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 587 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 588 | image_resolution = vision_patch_size * grid_size 589 | else: 590 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 591 | vision_layers = tuple(counts) 592 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 593 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 594 | vision_patch_size = None 595 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 596 | image_resolution = output_width * 32 597 | 598 | embed_dim = state_dict["text_projection"].shape[1] 599 | context_length = state_dict["positional_embedding"].shape[0] 600 | vocab_size = state_dict["token_embedding.weight"].shape[0] 601 | transformer_width = state_dict["ln_final.weight"].shape[0] 602 | transformer_heads = transformer_width // 64 603 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 604 | 605 | model = CLIP( 606 | embed_dim, 607 | image_resolution, vision_layers, vision_width, vision_patch_size, 608 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 609 | ) 610 | 611 | for key in ["input_resolution", "context_length", "vocab_size"]: 612 | if key in state_dict: 613 | del state_dict[key] 614 | 615 | convert_weights(model) 616 | model.load_state_dict(state_dict) 617 | return model.eval() 618 | 619 | import html 620 | from functools import lru_cache 621 | 622 | import ftfy 623 | import regex as re 624 | 625 | 626 | @lru_cache() 627 | def default_bpe(): 628 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt") 629 | 630 | 631 | @lru_cache() 632 | def bytes_to_unicode(): 633 | """ 634 | Returns list of utf-8 byte and a corresponding list of unicode strings. 635 | The reversible bpe codes work on unicode strings. 636 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 637 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 638 | This is a signficant percentage of your normal, say, 32K bpe vocab. 639 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 640 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
641 | """ 642 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 643 | cs = bs[:] 644 | n = 0 645 | for b in range(2**8): 646 | if b not in bs: 647 | bs.append(b) 648 | cs.append(2**8+n) 649 | n += 1 650 | cs = [chr(n) for n in cs] 651 | return dict(zip(bs, cs)) 652 | 653 | 654 | def get_pairs(word): 655 | """Return set of symbol pairs in a word. 656 | Word is represented as tuple of symbols (symbols being variable-length strings). 657 | """ 658 | pairs = set() 659 | prev_char = word[0] 660 | for char in word[1:]: 661 | pairs.add((prev_char, char)) 662 | prev_char = char 663 | return pairs 664 | 665 | 666 | def basic_clean(text): 667 | text = ftfy.fix_text(text) 668 | text = html.unescape(html.unescape(text)) 669 | return text.strip() 670 | 671 | 672 | def whitespace_clean(text): 673 | text = re.sub(r'\s+', ' ', text) 674 | text = text.strip() 675 | return text 676 | 677 | 678 | class SimpleTokenizer(object): 679 | def __init__(self, bpe_path: str = default_bpe()): 680 | self.byte_encoder = bytes_to_unicode() 681 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 682 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n') 683 | merges = merges[1:49152-256-2+1] 684 | merges = [tuple(merge.split()) for merge in merges] 685 | vocab = list(bytes_to_unicode().values()) 686 | vocab = vocab + [v+'' for v in vocab] 687 | for merge in merges: 688 | vocab.append(''.join(merge)) 689 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 690 | self.encoder = dict(zip(vocab, range(len(vocab)))) 691 | self.decoder = {v: k for k, v in self.encoder.items()} 692 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 693 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 694 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 695 | 696 | def bpe(self, token): 697 | if token in self.cache: 698 | return self.cache[token] 699 | word = tuple(token[:-1]) + ( token[-1] + '',) 700 | pairs = get_pairs(word) 701 | 702 | if not pairs: 703 | return token+'' 704 | 705 | while True: 706 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 707 | if bigram not in self.bpe_ranks: 708 | break 709 | first, second = bigram 710 | new_word = [] 711 | i = 0 712 | while i < len(word): 713 | try: 714 | j = word.index(first, i) 715 | new_word.extend(word[i:j]) 716 | i = j 717 | except: 718 | new_word.extend(word[i:]) 719 | break 720 | 721 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 722 | new_word.append(first+second) 723 | i += 2 724 | else: 725 | new_word.append(word[i]) 726 | i += 1 727 | new_word = tuple(new_word) 728 | word = new_word 729 | if len(word) == 1: 730 | break 731 | else: 732 | pairs = get_pairs(word) 733 | word = ' '.join(word) 734 | self.cache[token] = word 735 | return word 736 | 737 | def encode(self, text): 738 | bpe_tokens = [] 739 | text = whitespace_clean(basic_clean(text)).lower() 740 | for token in re.findall(self.pat, text): 741 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 742 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 743 | return bpe_tokens 744 | 745 | def decode(self, tokens): 746 | text = ''.join([self.decoder[token] for token in tokens]) 747 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 748 | return text 749 | 
_tokenizer = SimpleTokenizer() 750 | -------------------------------------------------------------------------------- /deep_daze/deep_daze.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import random 5 | from datetime import datetime 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from siren_pytorch import SirenNet, SirenWrapper 11 | from torch import nn 12 | from torch.cuda.amp import GradScaler, autocast 13 | from torch_optimizer import DiffGrad, AdamP 14 | import numpy as np 15 | 16 | from PIL import Image 17 | from imageio import imread, mimsave 18 | import torchvision.transforms as T 19 | 20 | 21 | from tqdm import trange, tqdm 22 | 23 | from .clip import load, tokenize 24 | 25 | 26 | # Helpers 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | 32 | def default(val, d): 33 | return val if exists(val) else d 34 | 35 | 36 | def interpolate(image, size): 37 | return F.interpolate(image, (size, size), mode='bilinear', align_corners=False) 38 | 39 | 40 | def rand_cutout(image, size, center_bias=False, center_focus=2): 41 | width = image.shape[-1] 42 | min_offset = 0 43 | max_offset = width - size 44 | if center_bias: 45 | # sample around image center 46 | center = max_offset / 2 47 | std = center / center_focus 48 | offset_x = int(random.gauss(mu=center, sigma=std)) 49 | offset_y = int(random.gauss(mu=center, sigma=std)) 50 | # resample uniformly if over boundaries 51 | offset_x = random.randint(min_offset, max_offset) if (offset_x > max_offset or offset_x < min_offset) else offset_x 52 | offset_y = random.randint(min_offset, max_offset) if (offset_y > max_offset or offset_y < min_offset) else offset_y 53 | else: 54 | offset_x = random.randint(min_offset, max_offset) 55 | offset_y = random.randint(min_offset, max_offset) 56 | cutout = image[:, :, offset_x:offset_x + size, offset_y:offset_y + size] 57 | return cutout 58 | 59 | 60 | def create_clip_img_transform(image_width): 61 | clip_mean = [0.48145466, 0.4578275, 0.40821073] 62 | clip_std = [0.26862954, 0.26130258, 0.27577711] 63 | transform = T.Compose([ 64 | #T.ToPILImage(), 65 | T.Resize(image_width), 66 | T.CenterCrop((image_width, image_width)), 67 | T.ToTensor(), 68 | T.Normalize(mean=clip_mean, std=clip_std) 69 | ]) 70 | return transform 71 | 72 | 73 | def open_folder(path): 74 | if os.path.isfile(path): 75 | path = os.path.dirname(path) 76 | 77 | if not os.path.isdir(path): 78 | return 79 | 80 | cmd_list = None 81 | if sys.platform == 'darwin': 82 | cmd_list = ['open', '--', path] 83 | elif sys.platform == 'linux2' or sys.platform == 'linux': 84 | cmd_list = ['xdg-open', path] 85 | elif sys.platform in ['win32', 'win64']: 86 | cmd_list = ['explorer', path.replace('/', '\\')] 87 | if cmd_list is None: 88 | return 89 | 90 | try: 91 | subprocess.check_call(cmd_list) 92 | except subprocess.CalledProcessError: 93 | pass 94 | except OSError: 95 | pass 96 | 97 | 98 | def norm_siren_output(img): 99 | return ((img + 1) * 0.5).clamp(0.0, 1.0) 100 | 101 | 102 | def create_text_path(context_length, text=None, img=None, encoding=None, separator=None): 103 | if text is not None: 104 | if separator is not None and separator in text: 105 | #Reduces filename to first epoch text 106 | text = text[:text.index(separator, )] 107 | input_name = text.replace(" ", "_")[:context_length] 108 | elif img is not None: 109 | if isinstance(img, str): 110 | input_name = "".join(img.replace(" ", 
"_").split(".")[:-1]) 111 | else: 112 | input_name = "PIL_img" 113 | else: 114 | input_name = "your_encoding" 115 | return input_name 116 | 117 | 118 | class DeepDaze(nn.Module): 119 | def __init__( 120 | self, 121 | clip_perceptor, 122 | clip_norm, 123 | input_res, 124 | total_batches, 125 | batch_size, 126 | num_layers=8, 127 | image_width=512, 128 | loss_coef=100, 129 | theta_initial=None, 130 | theta_hidden=None, 131 | lower_bound_cutout=0.1, # should be smaller than 0.8 132 | upper_bound_cutout=1.0, 133 | saturate_bound=False, 134 | gauss_sampling=False, 135 | gauss_mean=0.6, 136 | gauss_std=0.2, 137 | do_cutout=True, 138 | center_bias=False, 139 | center_focus=2, 140 | hidden_size=256, 141 | averaging_weight=0.3, 142 | ): 143 | super().__init__() 144 | # load clip 145 | self.perceptor = clip_perceptor 146 | self.input_resolution = input_res 147 | self.normalize_image = clip_norm 148 | 149 | self.loss_coef = loss_coef 150 | self.image_width = image_width 151 | 152 | self.batch_size = batch_size 153 | self.total_batches = total_batches 154 | self.num_batches_processed = 0 155 | 156 | w0 = default(theta_hidden, 30.) 157 | w0_initial = default(theta_initial, 30.) 158 | 159 | siren = SirenNet( 160 | dim_in=2, 161 | dim_hidden=hidden_size, 162 | num_layers=num_layers, 163 | dim_out=3, 164 | use_bias=True, 165 | w0=w0, 166 | w0_initial=w0_initial 167 | ) 168 | 169 | self.model = SirenWrapper( 170 | siren, 171 | image_width=image_width, 172 | image_height=image_width 173 | ) 174 | 175 | self.saturate_bound = saturate_bound 176 | self.saturate_limit = 0.75 # cutouts above this value lead to destabilization 177 | self.lower_bound_cutout = lower_bound_cutout 178 | self.upper_bound_cutout = upper_bound_cutout 179 | self.gauss_sampling = gauss_sampling 180 | self.gauss_mean = gauss_mean 181 | self.gauss_std = gauss_std 182 | self.do_cutout = do_cutout 183 | self.center_bias = center_bias 184 | self.center_focus = center_focus 185 | self.averaging_weight = averaging_weight 186 | 187 | def sample_sizes(self, lower, upper, width, gauss_mean): 188 | if self.gauss_sampling: 189 | gauss_samples = torch.zeros(self.batch_size).normal_(mean=gauss_mean, std=self.gauss_std) 190 | outside_bounds_mask = (gauss_samples > upper) | (gauss_samples < upper) 191 | gauss_samples[outside_bounds_mask] = torch.zeros((len(gauss_samples[outside_bounds_mask]),)).uniform_(lower, upper) 192 | sizes = (gauss_samples * width).int() 193 | else: 194 | lower *= width 195 | upper *= width 196 | sizes = torch.randint(int(lower), int(upper), (self.batch_size,)) 197 | return sizes 198 | 199 | def forward(self, text_embed, return_loss=True, dry_run=False): 200 | out = self.model() 201 | out = norm_siren_output(out) 202 | 203 | if not return_loss: 204 | return out 205 | 206 | # determine upper and lower sampling bound 207 | width = out.shape[-1] 208 | lower_bound = self.lower_bound_cutout 209 | if self.saturate_bound: 210 | progress_fraction = self.num_batches_processed / self.total_batches 211 | lower_bound += (self.saturate_limit - self.lower_bound_cutout) * progress_fraction 212 | 213 | # sample cutout sizes between lower and upper bound 214 | sizes = self.sample_sizes(lower_bound, self.upper_bound_cutout, width, self.gauss_mean) 215 | 216 | # create normalized random cutouts 217 | if self.do_cutout: 218 | image_pieces = [rand_cutout(out, size, center_bias=self.center_bias, center_focus=self.center_focus) for size in sizes] 219 | image_pieces = [interpolate(piece, self.input_resolution) for piece in image_pieces] 220 | else: 221 | 
image_pieces = [interpolate(out.clone(), self.input_resolution) for _ in sizes] 222 | 223 | # normalize 224 | image_pieces = torch.cat([self.normalize_image(piece) for piece in image_pieces]) 225 | 226 | # calc image embedding 227 | with autocast(enabled=False): 228 | image_embed = self.perceptor.encode_image(image_pieces) 229 | 230 | # calc loss 231 | # loss over averaged features of cutouts 232 | avg_image_embed = image_embed.mean(dim=0).unsqueeze(0) 233 | averaged_loss = -self.loss_coef * torch.cosine_similarity(text_embed, avg_image_embed, dim=-1).mean() 234 | # loss over all cutouts 235 | general_loss = -self.loss_coef * torch.cosine_similarity(text_embed, image_embed, dim=-1).mean() 236 | # merge losses 237 | loss = averaged_loss * (self.averaging_weight) + general_loss * (1 - self.averaging_weight) 238 | 239 | # count batches 240 | if not dry_run: 241 | self.num_batches_processed += self.batch_size 242 | 243 | return out, loss 244 | 245 | 246 | class Imagine(nn.Module): 247 | def __init__( 248 | self, 249 | *, 250 | text=None, 251 | img=None, 252 | clip_encoding=None, 253 | lr=1e-5, 254 | batch_size=4, 255 | gradient_accumulate_every=4, 256 | save_every=100, 257 | image_width=512, 258 | num_layers=16, 259 | epochs=20, 260 | iterations=1050, 261 | save_progress=True, 262 | seed=None, 263 | open_folder=True, 264 | save_date_time=False, 265 | start_image_path=None, 266 | start_image_train_iters=10, 267 | start_image_lr=3e-4, 268 | theta_initial=None, 269 | theta_hidden=None, 270 | model_name="ViT-B/32", 271 | lower_bound_cutout=0.1, # should be smaller than 0.8 272 | upper_bound_cutout=1.0, 273 | saturate_bound=False, 274 | averaging_weight=0.3, 275 | 276 | create_story=False, 277 | story_start_words=5, 278 | story_words_per_epoch=5, 279 | story_separator=None, 280 | gauss_sampling=False, 281 | gauss_mean=0.6, 282 | gauss_std=0.2, 283 | do_cutout=True, 284 | center_bias=False, 285 | center_focus=2, 286 | optimizer="AdamP", 287 | jit=True, 288 | hidden_size=256, 289 | save_gif=False, 290 | save_video=False, 291 | ): 292 | 293 | super().__init__() 294 | 295 | if exists(seed): 296 | tqdm.write(f'setting seed: {seed}') 297 | torch.manual_seed(seed) 298 | torch.cuda.manual_seed(seed) 299 | random.seed(seed) 300 | torch.backends.cudnn.deterministic = True 301 | 302 | # fields for story creation: 303 | self.create_story = create_story 304 | self.words = None 305 | self.separator = str(story_separator) if story_separator is not None else None 306 | if self.separator is not None and text is not None: 307 | #exit if text is just the separator 308 | if str(text).replace(' ','').replace(self.separator,'') == '': 309 | print('Exiting because the text only consists of the separator! Needs words or phrases that are separated by the separator.') 310 | exit() 311 | #adds a space to each separator and removes double spaces that might be generated 312 | text = text.replace(self.separator,self.separator+' ').replace(' ',' ').strip() 313 | self.all_words = text.split(" ") if text is not None else None 314 | self.num_start_words = story_start_words 315 | self.words_per_epoch = story_words_per_epoch 316 | if create_story: 317 | assert text is not None, "We need text input to create a story..." 
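# Worked example for the epoch arithmetic below (hypothetical numbers): with a 23-word
# story text, story_start_words=5 and story_words_per_epoch=5, this gives
# epochs = 1 + (23 - 5) / 5 = 4.6, which the rounding check below bumps up to 5 epochs.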
318 | # overwrite epochs to match story length 319 | num_words = len(self.all_words) 320 | self.epochs = 1 + (num_words - self.num_start_words) / self.words_per_epoch 321 | # add one epoch if not divisible 322 | self.epochs = int(self.epochs) if int(self.epochs) == self.epochs else int(self.epochs) + 1 323 | if self.separator is not None: 324 | if self.separator not in text: 325 | print("Separator '"+self.separator+"' will be ignored since not in text!") 326 | self.separator = None 327 | else: 328 | self.epochs = len(list(filter(None,text.split(self.separator)))) 329 | print("Running for", self.epochs, "epochs" + (" (split with '"+self.separator+"' as the separator)" if self.separator is not None else "")) 330 | else: 331 | self.epochs = epochs 332 | 333 | # jit models only compatible with version 1.7.1 334 | if "1.7.1" not in torch.__version__: 335 | if jit == True: 336 | print("Setting jit to False because torch version is not 1.7.1.") 337 | jit = False 338 | 339 | # Load CLIP 340 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 341 | clip_perceptor, norm = load(model_name, jit=jit, device=self.device) 342 | self.perceptor = clip_perceptor.eval() 343 | for param in self.perceptor.parameters(): 344 | param.requires_grad = False 345 | if jit == False: 346 | input_res = clip_perceptor.visual.input_resolution 347 | else: 348 | input_res = clip_perceptor.input_resolution.item() 349 | self.clip_transform = create_clip_img_transform(input_res) 350 | 351 | self.iterations = iterations 352 | self.image_width = image_width 353 | total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every 354 | model = DeepDaze( 355 | self.perceptor, 356 | norm, 357 | input_res, 358 | total_batches, 359 | batch_size=batch_size, 360 | image_width=image_width, 361 | num_layers=num_layers, 362 | theta_initial=theta_initial, 363 | theta_hidden=theta_hidden, 364 | lower_bound_cutout=lower_bound_cutout, 365 | upper_bound_cutout=upper_bound_cutout, 366 | saturate_bound=saturate_bound, 367 | gauss_sampling=gauss_sampling, 368 | gauss_mean=gauss_mean, 369 | gauss_std=gauss_std, 370 | do_cutout=do_cutout, 371 | center_bias=center_bias, 372 | center_focus=center_focus, 373 | hidden_size=hidden_size, 374 | averaging_weight=averaging_weight, 375 | ).to(self.device) 376 | self.model = model 377 | self.scaler = GradScaler() 378 | siren_params = model.model.parameters() 379 | if optimizer == "AdamP": 380 | self.optimizer = AdamP(siren_params, lr) 381 | elif optimizer == "Adam": 382 | self.optimizer = torch.optim.Adam(siren_params, lr) 383 | elif optimizer == "DiffGrad": 384 | self.optimizer = DiffGrad(siren_params, lr) 385 | self.gradient_accumulate_every = gradient_accumulate_every 386 | self.save_every = save_every 387 | self.save_date_time = save_date_time 388 | self.open_folder = open_folder 389 | self.save_progress = save_progress 390 | self.text = text 391 | self.image = img 392 | self.textpath = create_text_path(self.perceptor.context_length, text=text, img=img, encoding=clip_encoding, separator=story_separator) 393 | self.filename = self.image_output_path() 394 | 395 | # create coding to optimize for 396 | self.clip_encoding = self.create_clip_encoding(text=text, img=img, encoding=clip_encoding) 397 | 398 | self.start_image = None 399 | self.start_image_train_iters = start_image_train_iters 400 | self.start_image_lr = start_image_lr 401 | if exists(start_image_path): 402 | file = Path(start_image_path) 403 | assert file.exists(), f'file does not exist at given starting 
image path {start_image_path}' 404 | image = Image.open(str(file)) 405 | start_img_transform = T.Compose([T.Resize(image_width), 406 | T.CenterCrop((image_width, image_width)), 407 | T.ToTensor()]) 408 | image_tensor = start_img_transform(image).unsqueeze(0).to(self.device) 409 | self.start_image = image_tensor 410 | 411 | self.save_gif = save_gif 412 | self.save_video = save_video 413 | 414 | def create_clip_encoding(self, text=None, img=None, encoding=None): 415 | self.text = text 416 | self.img = img 417 | if encoding is not None: 418 | encoding = encoding.to(self.device) 419 | elif self.create_story: 420 | encoding = self.update_story_encoding(epoch=0, iteration=1) 421 | elif text is not None and img is not None: 422 | encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2 423 | elif text is not None: 424 | encoding = self.create_text_encoding(text) 425 | elif img is not None: 426 | encoding = self.create_img_encoding(img) 427 | return encoding 428 | 429 | def create_text_encoding(self, text): 430 | tokenized_text = tokenize(text).to(self.device) 431 | with torch.no_grad(): 432 | text_encoding = self.perceptor.encode_text(tokenized_text).detach() 433 | return text_encoding 434 | 435 | def create_img_encoding(self, img): 436 | if isinstance(img, str): 437 | img = Image.open(img) 438 | normed_img = self.clip_transform(img).unsqueeze(0).to(self.device) 439 | with torch.no_grad(): 440 | img_encoding = self.perceptor.encode_image(normed_img).detach() 441 | return img_encoding 442 | 443 | def set_clip_encoding(self, text=None, img=None, encoding=None): 444 | encoding = self.create_clip_encoding(text=text, img=img, encoding=encoding) 445 | self.clip_encoding = encoding.to(self.device) 446 | 447 | def index_of_first_separator(self) -> int: 448 | for c, word in enumerate(self.all_words): 449 | if self.separator in str(word): 450 | return c +1 451 | 452 | def update_story_encoding(self, epoch, iteration): 453 | if self.separator is not None: 454 | self.words = " ".join(self.all_words[:self.index_of_first_separator()]) 455 | #removes separator from epoch-text 456 | self.words = self.words.replace(self.separator,'') 457 | self.all_words = self.all_words[self.index_of_first_separator():] 458 | else: 459 | if self.words is None: 460 | self.words = " ".join(self.all_words[:self.num_start_words]) 461 | self.all_words = self.all_words[self.num_start_words:] 462 | else: 463 | # add words_per_epoch new words 464 | count = 0 465 | while count < self.words_per_epoch and len(self.all_words) > 0: 466 | new_word = self.all_words[0] 467 | self.words = " ".join(self.words.split(" ") + [new_word]) 468 | self.all_words = self.all_words[1:] 469 | count += 1 470 | # remove words until it fits in context length 471 | while len(self.words) > self.perceptor.context_length: 472 | # remove first word 473 | self.words = " ".join(self.words.split(" ")[1:]) 474 | # get new encoding 475 | print("Now thinking of: ", '"', self.words, '"') 476 | sequence_number = self.get_img_sequence_number(epoch, iteration) 477 | # save new words to disc 478 | with open("story_transitions.txt", "a") as f: 479 | f.write(f"{epoch}, {sequence_number}, {self.words}\n") 480 | 481 | encoding = self.create_text_encoding(self.words) 482 | return encoding 483 | 484 | def image_output_path(self, sequence_number=None): 485 | """ 486 | Returns underscore separated Path. 487 | A current timestamp is prepended if `self.save_date_time` is set. 488 | Sequence number left padded with 6 zeroes is appended if `save_every` is set. 
489 | :rtype: Path 490 | """ 491 | output_path = self.textpath 492 | if sequence_number: 493 | sequence_number_left_padded = str(sequence_number).zfill(6) 494 | output_path = f"{output_path}.{sequence_number_left_padded}" 495 | if self.save_date_time: 496 | current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f") 497 | output_path = f"{current_time}_{output_path}" 498 | return Path(f"{output_path}.jpg") 499 | 500 | def train_step(self, epoch, iteration): 501 | total_loss = 0 502 | 503 | for _ in range(self.gradient_accumulate_every): 504 | with autocast(enabled=True): 505 | out, loss = self.model(self.clip_encoding) 506 | loss = loss / self.gradient_accumulate_every 507 | total_loss += loss 508 | self.scaler.scale(loss).backward() 509 | out = out.cpu().float().clamp(0., 1.) 510 | self.scaler.step(self.optimizer) 511 | self.scaler.update() 512 | self.optimizer.zero_grad() 513 | 514 | if (iteration % self.save_every == 0) and self.save_progress: 515 | self.save_image(epoch, iteration, img=out) 516 | 517 | return out, total_loss 518 | 519 | def get_img_sequence_number(self, epoch, iteration): 520 | current_total_iterations = epoch * self.iterations + iteration 521 | sequence_number = current_total_iterations // self.save_every 522 | return sequence_number 523 | 524 | @torch.no_grad() 525 | def save_image(self, epoch, iteration, img=None): 526 | sequence_number = self.get_img_sequence_number(epoch, iteration) 527 | 528 | if img is None: 529 | img = self.model(self.clip_encoding, return_loss=False).cpu().float().clamp(0., 1.) 530 | self.filename = self.image_output_path(sequence_number=sequence_number) 531 | 532 | pil_img = T.ToPILImage()(img.squeeze()) 533 | pil_img.save(self.filename, quality=95, subsampling=0) 534 | pil_img.save(f"{self.textpath}.jpg", quality=95, subsampling=0) 535 | 536 | tqdm.write(f'image updated at "./{str(self.filename)}"') 537 | 538 | def generate_gif(self): 539 | images = [] 540 | for file_name in sorted(os.listdir('./')): 541 | if file_name.startswith(self.textpath) and file_name != f'{self.textpath}.jpg': 542 | images.append(imread(os.path.join('./', file_name))) 543 | 544 | if self.save_video: 545 | mimsave(f'{self.textpath}.mp4', images) 546 | print(f'Generated image generation animation at ./{self.textpath}.mp4') 547 | if self.save_gif: 548 | mimsave(f'{self.textpath}.gif', images) 549 | print(f'Generated image generation animation at ./{self.textpath}.gif') 550 | 551 | def forward(self): 552 | if exists(self.start_image): 553 | tqdm.write('Preparing with initial image...') 554 | optim = DiffGrad(self.model.model.parameters(), lr = self.start_image_lr) 555 | pbar = trange(self.start_image_train_iters, desc='iteration') 556 | try: 557 | for _ in pbar: 558 | loss = self.model.model(self.start_image) 559 | loss.backward() 560 | pbar.set_description(f'loss: {loss.item():.2f}') 561 | 562 | optim.step() 563 | optim.zero_grad() 564 | except KeyboardInterrupt: 565 | print('interrupted by keyboard, gracefully exiting') 566 | return exit() 567 | 568 | del self.start_image 569 | del optim 570 | 571 | tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...') 572 | 573 | with torch.no_grad(): 574 | self.model(self.clip_encoding, dry_run=True) # do one warmup step due to potential issue with CLIP and CUDA 575 | 576 | if self.open_folder: 577 | open_folder('./') 578 | self.open_folder = False 579 | 580 | try: 581 | for epoch in trange(self.epochs, desc='epochs'): 582 | pbar = trange(self.iterations, desc='iteration') 583 | for i in pbar: 584 | _, loss = 
self.train_step(epoch, i) 585 | pbar.set_description(f'loss: {loss.item():.2f}') 586 | 587 | # Update clip_encoding per epoch if we are creating a story 588 | if self.create_story: 589 | self.clip_encoding = self.update_story_encoding(epoch, i) 590 | except KeyboardInterrupt: 591 | print('interrupted by keyboard, gracefully exiting') 592 | return 593 | 594 | self.save_image(epoch, i) # one final save at end 595 | 596 | if (self.save_gif or self.save_video) and self.save_progress: 597 | self.generate_gif() 598 | -------------------------------------------------------------------------------- /deep_daze/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.11.1' 2 | -------------------------------------------------------------------------------- /instruction_images/Windows/Step_1_DD_Win.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/instruction_images/Windows/Step_1_DD_Win.png -------------------------------------------------------------------------------- /instruction_images/Windows/Step_2_DD_Win.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/instruction_images/Windows/Step_2_DD_Win.png -------------------------------------------------------------------------------- /samples/A_man_painting_a_completely_red_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_man_painting_a_completely_red_image.png -------------------------------------------------------------------------------- /samples/A_psychedelic_experience_on_LSD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_psychedelic_experience_on_LSD.png -------------------------------------------------------------------------------- /samples/A_time_traveler_in_the_crowd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_time_traveler_in_the_crowd.jpg -------------------------------------------------------------------------------- /samples/Autumn_1875_Frederic_Edwin_Church.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Autumn_1875_Frederic_Edwin_Church.jpg -------------------------------------------------------------------------------- /samples/Autumn_1875_Frederic_Edwin_Church_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Autumn_1875_Frederic_Edwin_Church_original.jpg -------------------------------------------------------------------------------- /samples/Cosmic_love_and_attention.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Cosmic_love_and_attention.jpg 
-------------------------------------------------------------------------------- /samples/Life_during_the_plague.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Life_during_the_plague.jpg -------------------------------------------------------------------------------- /samples/Meditative_peace_in_a_sunlit_forest.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Meditative_peace_in_a_sunlit_forest.jpg -------------------------------------------------------------------------------- /samples/Mist_over_green_hills.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Mist_over_green_hills.jpg -------------------------------------------------------------------------------- /samples/Shattered_plates_on_the_grass.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Shattered_plates_on_the_grass.jpg -------------------------------------------------------------------------------- /samples/cosmic-love.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/cosmic-love.png -------------------------------------------------------------------------------- /samples/hot-dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/hot-dog.jpg -------------------------------------------------------------------------------- /samples/hot-dog_imagined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/hot-dog_imagined.png -------------------------------------------------------------------------------- /samples/life-plague.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/life-plague.png -------------------------------------------------------------------------------- /samples/mist-over-green-hills.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/mist-over-green-hills.png -------------------------------------------------------------------------------- /samples/peace-sunlit-forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/peace-sunlit-forest.png -------------------------------------------------------------------------------- /samples/prime-orig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/prime-orig.jpg 
-------------------------------------------------------------------------------- /samples/prime-trained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/prime-trained.png -------------------------------------------------------------------------------- /samples/psychedelic_hot_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/psychedelic_hot_dog.png -------------------------------------------------------------------------------- /samples/shattered-plates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/shattered-plates.png -------------------------------------------------------------------------------- /samples/time-traveler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/time-traveler.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages 3 | 4 | sys.path[0:0] = ['deep_daze'] 5 | from version import __version__ 6 | 7 | setup( 8 | name = 'deep-daze', 9 | packages = find_packages(), 10 | include_package_data = True, 11 | entry_points={ 12 | 'console_scripts': [ 13 | 'imagine = deep_daze.cli:main', 14 | ], 15 | }, 16 | version = __version__, 17 | license='MIT', 18 | description = 'Deep Daze', 19 | author = 'Ryan Murdock, Phil Wang', 20 | author_email = 'lucidrains@gmail.com', 21 | url = 'https://github.com/lucidrains/deep-daze', 22 | keywords = [ 23 | 'artificial intelligence', 24 | 'deep learning', 25 | 'transformers', 26 | 'implicit neural representations', 27 | 'text to image' 28 | ], 29 | install_requires=[ 30 | 'einops>=0.3', 31 | 'fire', 32 | 'ftfy', 33 | 'imageio>=2.9.0', 34 | 'siren-pytorch>=0.0.8', 35 | 'torch>=1.10', 36 | 'torch_optimizer', 37 | 'torchvision>=0.8.2', 38 | 'tqdm', 39 | 'regex' 40 | ], 41 | classifiers=[ 42 | 'Development Status :: 4 - Beta', 43 | 'Intended Audience :: Developers', 44 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 45 | 'License :: OSI Approved :: MIT License', 46 | 'Programming Language :: Python :: 3.6', 47 | ], 48 | ) 49 | --------------------------------------------------------------------------------
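Usage sketch: the setup.py above installs an `imagine` console script wrapping deep_daze.cli:main. A roughly equivalent programmatic call, using only names defined in deep_daze/deep_daze.py, looks like the following; the prompt string and parameter values are illustrative, and a CUDA GPU is assumed (the code falls back to CPU, but training will be very slow).

from deep_daze.deep_daze import Imagine

# Illustrative settings; each keyword is a parameter of Imagine.__init__ shown above.
imagine = Imagine(
    text='Mist over green hills',
    num_layers=16,
    image_width=512,
    epochs=20,
    iterations=1050,
    save_progress=True,
    save_every=100,
)
imagine()  # Imagine is an nn.Module; calling it runs forward(), which trains the SIREN and writes JPEGs to the working directory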