├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── deep_daze
│   ├── __init__.py
│   ├── cli.py
│   ├── clip.py
│   ├── data
│   │   └── bpe_simple_vocab_16e6.txt
│   ├── deep_daze.py
│   └── version.py
├── instruction_images
│   └── Windows
│       ├── Step_1_DD_Win.png
│       └── Step_2_DD_Win.png
├── samples
│   ├── A_man_painting_a_completely_red_image.png
│   ├── A_psychedelic_experience_on_LSD.png
│   ├── A_time_traveler_in_the_crowd.jpg
│   ├── Autumn_1875_Frederic_Edwin_Church.jpg
│   ├── Autumn_1875_Frederic_Edwin_Church_original.jpg
│   ├── Cosmic_love_and_attention.jpg
│   ├── Life_during_the_plague.jpg
│   ├── Meditative_peace_in_a_sunlit_forest.jpg
│   ├── Mist_over_green_hills.jpg
│   ├── Shattered_plates_on_the_grass.jpg
│   ├── cosmic-love.png
│   ├── hot-dog.jpg
│   ├── hot-dog_imagined.png
│   ├── life-plague.png
│   ├── mist-over-green-hills.png
│   ├── peace-sunlit-forest.png
│   ├── prime-orig.jpg
│   ├── prime-trained.png
│   ├── psychedelic_hot_dog.png
│   ├── shattered-plates.png
│   └── time-traveler.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # IDEs
132 | .vscode/
133 | .idea/
134 |
135 | output/
136 | run.py
137 | run.sh
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Ryan Murdock, Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include deep_daze *.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Daze
2 |
3 |
4 |
5 | *mist over green hills*
6 |
7 |
8 |
9 | *shattered plates on the grass*
10 |
11 |
12 |
13 | *cosmic love and attention*
14 |
15 |
16 |
17 | *a time traveler in the crowd*
18 |
19 |
20 |
21 | *life during the plague*
22 |
23 |
24 |
25 | *meditative peace in a sunlit forest*
26 |
27 |
28 |
29 | *a man painting a completely red image*
30 |
31 |
32 |
33 | *a psychedelic experience on LSD*
34 |
35 | ## What is this?
36 |
37 | A simple command-line tool for text-to-image generation using OpenAI's CLIP and Siren. Credit goes to Ryan Murdock for the discovery of this technique (and for coming up with the great name)!
38 |
39 | Original notebook [![Open In Colab][colab-badge]][colab-notebook]
40 |
41 | New simplified notebook [![Open In Colab][colab-badge]][colab-notebook-2]
42 |
43 | [colab-notebook]:
44 | [colab-notebook-2]:
45 | [colab-badge]:
46 |
47 | This requires an Nvidia or AMD GPU.
48 | - Recommended: 16GB VRAM
49 | - Minimum: 4GB VRAM (using very low settings; see the usage instructions below)
50 |
51 | ## Install
52 |
53 | ```bash
54 | $ pip install deep-daze
55 | ```
56 |
57 | ### Windows Install
58 |
59 |
60 |
61 | Presuming Python is installed:
62 | - Open command prompt and navigate to the directory of your current version of Python
63 | ```bash
64 | pip install deep-daze
65 | ```
66 |
67 | ## Examples
68 |
69 | ```bash
70 | $ imagine "a house in the forest"
71 | ```
72 | For Windows:
73 |
74 |
75 |
76 | - Open command prompt as administrator
77 | ```bash
78 | imagine "a house in the forest"
79 | ```
80 |
81 | That's it.
82 |
83 |
84 | If you have enough memory, you can get better quality by adding a `--deeper` flag
85 |
86 | ```bash
87 | $ imagine "shattered plates on the ground" --deeper
88 | ```
89 |
90 | ### Advanced
91 |
92 | In true deep learning fashion, more layers will yield better results. The default is `16`, but it can be increased to `32` depending on your resources.
93 |
94 | ```bash
95 | $ imagine "stranger in strange lands" --num-layers 32
96 | ```
97 |
98 | ## Usage
99 |
100 | ### CLI
101 | ```bash
102 | NAME
103 | imagine
104 |
105 | SYNOPSIS
106 | imagine TEXT
107 |
108 | POSITIONAL ARGUMENTS
109 | TEXT
110 | (required) A phrase of fewer than 77 tokens that you would like to visualize.
111 |
112 | FLAGS
113 | --img=IMAGE_PATH
114 | Default: None
115 | Path to png/jpg image or PIL image to optimize on
116 | --encoding=ENCODING
117 | Default: None
118 | User-created custom CLIP encoding. If provided, it replaces any text or image input.
119 | --create_story=CREATE_STORY
120 | Default: False
121 | Creates a story by optimizing each epoch on a new sliding-window of the input words. If this is enabled, much longer texts than 77 tokens can be used. Requires save_progress to visualize the transitions of the story.
122 | --story_start_words=STORY_START_WORDS
123 | Default: 5
124 | Only used if create_story is True. How many words to optimize on for the first epoch.
125 | --story_words_per_epoch=STORY_WORDS_PER_EPOCH
126 | Default: 5
127 | Only used if create_story is True. How many words to add to the optimization goal per epoch after the first one.
128 | --story_separator=STORY_SEPARATOR
129 | Default: None
130 | Only used if create_story is True. Defines a separator, such as '.', that splits the text into groups for each epoch. The separator needs to appear in the text, otherwise it will be ignored.
131 | --lower_bound_cutout=LOWER_BOUND_CUTOUT
132 | Default: 0.1
133 | Lower bound for the size of the random cut-outs taken from the SIREN image each batch. Should be smaller than 0.8.
134 | --upper_bound_cutout=UPPER_BOUND_CUTOUT
135 | Default: 1.0
136 | Upper bound for the size of the random cut-outs taken from the SIREN image each batch. Should probably stay at 1.0.
137 | --saturate_bound=SATURATE_BOUND
138 | Default: False
139 | If True, the LOWER_BOUND_CUTOUT is linearly increased to 0.75 during training.
140 | --learning_rate=LEARNING_RATE
141 | Default: 1e-05
142 | The learning rate of the neural net.
143 | --num_layers=NUM_LAYERS
144 | Default: 16
145 | The number of hidden layers to use in the Siren neural net.
146 | --batch_size=BATCH_SIZE
147 | Default: 4
148 | The number of generated images to pass into Siren before calculating loss. Decreasing this can lower memory and accuracy.
149 | --gradient_accumulate_every=GRADIENT_ACCUMULATE_EVERY
150 | Default: 4
151 | Calculate a weighted loss of n samples for each iteration. Increasing this can help increase accuracy with lower batch sizes.
152 | --epochs=EPOCHS
153 | Default: 20
154 | The number of epochs to run.
155 | --iterations=ITERATIONS
156 | Default: 1050
157 | The number of times to calculate and backpropagate loss in a given epoch.
158 | --save_every=SAVE_EVERY
159 | Default: 100
160 | Generate an image every time iterations is a multiple of this number.
161 | --image_width=IMAGE_WIDTH
162 | Default: 512
163 | The desired resolution of the image.
164 | --deeper=DEEPER
165 | Default: False
166 | Uses a Siren neural net with 32 hidden layers.
167 | --overwrite=OVERWRITE
168 | Default: False
169 | Whether or not to overwrite existing generated images of the same name.
170 | --save_progress=SAVE_PROGRESS
171 | Default: False
172 | Whether or not to save images generated before training Siren is complete.
173 | --seed=SEED
174 | Type: Optional[]
175 | Default: None
176 | A seed to be used for deterministic runs.
177 | --open_folder=OPEN_FOLDER
178 | Default: True
179 | Whether or not to open a folder showing your generated images.
180 | --save_date_time=SAVE_DATE_TIME
181 | Default: False
182 | Save files with a timestamp prepended e.g. `%y%m%d-%H%M%S-my_phrase_here`
183 | --start_image_path=START_IMAGE_PATH
184 | Default: None
185 | The generator is trained first on a starting image before being steered towards the textual input.
186 | --start_image_train_iters=START_IMAGE_TRAIN_ITERS
187 | Default: 50
188 | The number of steps for the initial training on the starting image
189 | --theta_initial=THETA_INITIAL
190 | Default: 30.0
191 | Hyperparameter describing the frequency of the color space. Only applies to the first layer of the network.
192 | --theta_hidden=THETA_HIDDEN
193 | Default: 30.0
194 | Hyperparameter describing the frequency of the color space. Only applies to the hidden layers of the network.
195 | --save_gif=SAVE_GIF
196 | Default: False
197 | Whether or not to save a GIF animation of the generation procedure. Only works if save_progress is set to True.
198 | ```
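
Under the hood, the CLI simply forwards these flags to `deep_daze.Imagine` (see `deep_daze/cli.py`). As a minimal sketch, a call such as `imagine "an apple" --num_layers 32 --save_every 50` (the prompt is purely an illustrative example) corresponds roughly to:

```python
from deep_daze import Imagine

# Sketch only: the CLI forwards its flags as keyword arguments to Imagine,
# so this mirrors the example call above.
imagine = Imagine(
    text='an apple',     # example prompt, purely illustrative
    num_layers=32,       # --num_layers
    save_every=50,       # --save_every
    save_progress=True,  # --save_progress
)
imagine()
```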
199 |
200 | ### Priming
201 |
202 | This technique was first devised and shared by Mario Klingemann; it allows you to prime the generator network with a starting image before it is steered towards the text.
203 |
204 | Simply specify the path to the image you wish to use, and optionally the number of initial training steps.
205 |
206 | ```bash
207 | $ imagine 'a clear night sky filled with stars' --start_image_path ./cloudy-night-sky.jpg
208 | ```
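
The same priming options are available from Python. A minimal sketch, reusing the example image path from the CLI call above:

```python
from deep_daze import Imagine

# Sketch: prime the generator on a starting image, then steer towards the text.
imagine = Imagine(
    text='a clear night sky filled with stars',
    start_image_path='./cloudy-night-sky.jpg',  # example path from the CLI call above
    start_image_train_iters=50,                 # number of initial priming steps (default)
)
imagine()
```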
209 |
210 | Primed starting image
211 |
212 |
213 |
214 | Then trained with the prompt `A pizza with green pepper.`
215 |
216 |
217 |
218 |
219 | ### Optimize for the interpretation of an image
220 |
221 | We can also feed in an image as an optimization goal, instead of only priming the generator network. Deep Daze will then render its own interpretation of that image:
222 | ```bash
223 | $ imagine --img samples/Autumn_1875_Frederic_Edwin_Church.jpg
224 | ```
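
From Python, a minimal sketch of the same call passes the image via the `img` argument:

```python
from deep_daze import Imagine

# Sketch: optimize towards an image goal instead of a text prompt.
imagine = Imagine(
    img='samples/Autumn_1875_Frederic_Edwin_Church.jpg',
)
imagine()
```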
225 | Original image:
226 |
227 |
228 |
229 | The network's interpretation:
230 |
231 |
232 |
233 | Original image:
234 |
235 |
236 |
237 | The network's interpretation:
238 |
239 |
240 |
241 | #### Optimize for text and image combined
242 |
243 | ```bash
244 | $ imagine "A psychedelic experience." --img samples/hot-dog.jpg
245 | ```
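
From Python, a sketch of the combined goal passes both `text` and `img`:

```python
from deep_daze import Imagine

# Sketch: optimize towards a text prompt and an image goal at the same time.
imagine = Imagine(
    text='A psychedelic experience.',
    img='samples/hot-dog.jpg',
)
imagine()
```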
246 | The network's interpretation:
247 |
248 |
249 |
250 | ### New: Create a story
251 | The regular mode for texts only allows 77 tokens. If you want to visualize a full story/paragraph/song/poem, set `create_story` to `True`.
252 |
253 | Given the poem “Stopping by Woods On a Snowy Evening” by Robert Frost -
254 | "Whose woods these are I think I know. His house is in the village though; He will not see me stopping here To watch his woods fill up with snow. My little horse must think it queer To stop without a farmhouse near Between the woods and frozen lake The darkest evening of the year. He gives his harness bells a shake To ask if there is some mistake. The only other sound’s the sweep Of easy wind and downy flake. The woods are lovely, dark and deep, But I have promises to keep, And miles to go before I sleep, And miles to go before I sleep.".
255 |
256 | We get:
257 |
258 | https://user-images.githubusercontent.com/19983153/109539633-d671ef80-7ac1-11eb-8d8c-380332d7c868.mp4
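
A minimal Python sketch of story mode (the story-related keyword arguments mirror the CLI flags documented above):

```python
from deep_daze import Imagine

# Sketch: story mode optimizes each epoch on a sliding window over a long text.
# save_progress is needed to visualize the transitions between epochs.
long_text = "Whose woods these are I think I know. His house is in the village though; ..."  # truncated example

imagine = Imagine(
    text=long_text,
    create_story=True,
    story_start_words=5,      # words optimized in the first epoch (default)
    story_words_per_epoch=5,  # words added for each following epoch (default)
    save_progress=True,
)
imagine()
```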
259 |
260 |
261 |
262 | ### Python
263 | #### Invoke `deep_daze.Imagine` in Python
264 | ```python
265 | from deep_daze import Imagine
266 |
267 | imagine = Imagine(
268 | text = 'cosmic love and attention',
269 | num_layers = 24,
270 | )
271 | imagine()
272 | ```
273 |
274 | #### Save progress every fourth iteration
275 | Save images in the format insert_text_here.00001.png, insert_text_here.00002.png, ... up to a total of `total_iterations / save_every` images.
276 | ```python
277 | imagine = Imagine(
278 | text=text,
279 | save_every=4,
280 | save_progress=True
281 | )
282 | ```
283 |
284 | #### Prepend current timestamp on each image.
285 | Creates files with both the timestamp and the sequence number.
286 |
287 | e.g. 210129-043928_328751_insert_text_here.00001.png, 210129-043928_512351_insert_text_here.00002.png, ...
288 | ```python
289 | imagine = Imagine(
290 | text=text,
291 | save_every=4,
292 | save_progress=True,
293 | save_date_time=True,
294 | )
295 | ```
296 |
297 | #### High GPU memory usage
298 | If you have at least 16 GiB of VRAM available, you should be able to run these settings with some wiggle room.
299 | ```python
300 | imagine = Imagine(
301 | text=text,
302 | num_layers=42,
303 | batch_size=64,
304 | gradient_accumulate_every=1,
305 | )
306 | ```
307 |
308 | #### Average GPU memory usage
309 | ```python
310 | imagine = Imagine(
311 | text=text,
312 | num_layers=24,
313 | batch_size=16,
314 | gradient_accumulate_every=2
315 | )
316 | ```
317 |
318 | #### Very low GPU memory usage (less than 4 GiB)
319 | If you are desperate to run this on a card with less than 8 GiB of VRAM, you can lower `image_width`.
320 | ```python
321 | imagine = Imagine(
322 | text=text,
323 | image_width=256,
324 | num_layers=16,
325 | batch_size=1,
326 | gradient_accumulate_every=16 # Increase gradient_accumulate_every to correct for loss in low batch sizes
327 | )
328 | ```
329 |
330 | ### VRAM and speed benchmarks:
331 | These experiments were conducted with an Nvidia RTX 2060 Super and an AMD Ryzen 7 3700X. We list the parameters first (bs = batch size), then the memory usage, and in some cases the training iterations per second:
332 |
333 | For an image resolution of 512:
334 | * bs 1, num_layers 22: 7.96 GB
335 | * bs 2, num_layers 20: 7.5 GB
336 | * bs 16, num_layers 16: 6.5 GB
337 |
338 | For an image resolution of 256:
339 | * bs 8, num_layers 48: 5.3 GB
340 | * bs 16, num_layers 48: 5.46 GB - 2.0 it/s
341 | * bs 32, num_layers 48: 5.92 GB - 1.67 it/s
342 | * bs 8, num_layers 44: 5 GB - 2.39 it/s
343 | * bs 32, num_layers 44, grad_acc 1: 5.62 GB - 4.83 it/s
344 | * bs 96, num_layers 44, grad_acc 1: 7.51 GB - 2.77 it/s
345 | * bs 32, num_layers 66, grad_acc 1: 7.09 GB - 3.7 it/s
346 |
347 | @NotNANtoN recommends a batch size of 32 with 44 layers and training 1-8 epochs.
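
As a rough sketch, that recommendation corresponds to settings like the following (256-pixel images as in the benchmark runs above; `text` stands in for your prompt):

```python
from deep_daze import Imagine

# Sketch of the recommended settings: batch size 32, 44 layers, a few epochs,
# and 256-pixel images as in the benchmarks above.
imagine = Imagine(
    text=text,
    image_width=256,
    num_layers=44,
    batch_size=32,
    gradient_accumulate_every=1,
    epochs=8,
)
imagine()
```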
348 |
349 |
350 | ## Where is this going?
351 |
352 | This is just a teaser. We will be able to generate images, sound, anything at will, with natural language. The holodeck is about to become real in our lifetimes.
353 |
354 | Please join replication efforts for DALL-E for Pytorch or Mesh Tensorflow if you are interested in furthering this technology.
355 |
356 | ## Alternatives
357 |
358 | Big Sleep - CLIP and the generator from Big GAN
359 |
360 | ## Citations
361 |
362 | ```bibtex
363 | @misc{unpublished2021clip,
364 | title = {CLIP: Connecting Text and Images},
365 | author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
366 | year = {2021}
367 | }
368 | ```
369 |
370 | ```bibtex
371 | @misc{sitzmann2020implicit,
372 | title = {Implicit Neural Representations with Periodic Activation Functions},
373 | author = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
374 | year = {2020},
375 | eprint = {2006.09661},
376 | archivePrefix = {arXiv},
377 | primaryClass = {cs.CV}
378 | }
379 | ```
380 |
381 | [colab-notebook]:
382 |
--------------------------------------------------------------------------------
/deep_daze/__init__.py:
--------------------------------------------------------------------------------
1 | from deep_daze.deep_daze import DeepDaze, Imagine
2 |
--------------------------------------------------------------------------------
/deep_daze/cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import fire
4 |
5 | from deep_daze import Imagine
6 |
7 |
8 | def train(
9 | text=None,
10 | img=None,
11 | learning_rate=1e-5,
12 | num_layers=16,
13 | hidden_size=256,
14 | batch_size=4,
15 | gradient_accumulate_every=4,
16 | epochs=20,
17 | iterations=1050,
18 | save_every=100,
19 | image_width=512,
20 | deeper=False,
21 | overwrite=False,
22 | save_progress=True,
23 | seed=None,
24 | open_folder=True,
25 | save_date_time=False,
26 | start_image_path=None,
27 | start_image_train_iters=50,
28 | theta_initial=None,
29 | theta_hidden=None,
30 | start_image_lr=3e-4,
31 | lower_bound_cutout=0.1,
32 | upper_bound_cutout=1.0,
33 | saturate_bound=False,
34 | create_story=False,
35 | story_start_words=5,
36 | story_words_per_epoch=5,
37 | story_separator=None,
38 | averaging_weight=0.3,
39 | gauss_sampling=False,
40 | gauss_mean=0.6,
41 | gauss_std=0.2,
42 | do_cutout=True,
43 | center_bias=False,
44 | center_focus=2,
45 | jit=True,
46 | save_gif=False,
47 | save_video=False,
48 | model_name="ViT-B/32",
49 | optimizer="AdamP"
50 | ):
51 | """
52 | :param text: (required) A phrase of fewer than 77 tokens that you would like to visualize.
53 | :param img: The path to a jpg or png image which you would like to imagine. Can be combined with text.
54 | :param learning_rate: The learning rate of the neural net.
55 | :param hidden_size: The hidden layer size of the Siren net.
56 | :param num_layers: The number of hidden layers to use in the Siren neural net.
57 | :param batch_size: The number of generated images to pass into Siren before calculating loss. Decreasing this can lower memory and accuracy.
58 | :param gradient_accumulate_every: Calculate a weighted loss of n samples for each iteration. Increasing this can help increase accuracy with lower batch sizes.
59 | :param epochs: The number of epochs to run.
60 | :param iterations: The number of times to calculate and backpropagate loss in a given epoch.
61 | :param save_progress: Whether or not to save images generated before training Siren is complete.
62 | :param save_every: Generate an image every time iterations is a multiple of this number.
63 | :param open_folder: Whether or not to open a folder showing your generated images.
64 | :param overwrite: Whether or not to overwrite existing generated images of the same name.
65 | :param deeper: Uses a Siren neural net with 32 hidden layers.
66 | :param image_width: The desired resolution of the image.
67 | :param seed: A seed to be used for deterministic runs.
68 | :param save_date_time: Save files with a timestamp prepended e.g. `%y%m%d-%H%M%S-my_phrase_here.png`
69 | :param start_image_path: Path to the image you would like to prime the generator with initially
70 | :param start_image_train_iters: Number of iterations for priming, defaults to 50
71 | :param theta_initial: Hyperparameter describing the frequency of the color space. Only applies to the first layer of the network.
72 | :param theta_hidden: Hyperparameter describing the frequency of the color space. Only applies to the hidden layers of the network.
73 | :param start_image_lr: Learning rate for the start image training.
74 | :param upper_bound_cutout: The upper bound for the cutouts used in generation.
75 | :param lower_bound_cutout: The lower bound for the cutouts used in generation.
76 | :param saturate_bound: If True, the LOWER_BOUND_CUTOUT is linearly increased to 0.75 during training.
77 | :param create_story: Creates a story by optimizing each epoch on a new sliding-window of the input words. If this is enabled, much longer texts than 77 tokens can be used. Requires save_progress to visualize the transitions of the story.
78 | :param story_start_words: Only used if create_story is True. How many words to optimize on for the first epoch.
79 | :param story_words_per_epoch: Only used if create_story is True. How many words to add to the optimization goal per epoch after the first one.
80 | :param story_separator: Only used if create_story is True. Defines a separator like '.' that splits the text into groups for each epoch. Separator needs to be in the text otherwise it will be ignored!
81 | :param averaging_weight: How much to weigh the averaged features of the random cutouts over the individual random cutouts. Increasing this value leads to more details being represented at the cost of some global coherence and a parcellation into smaller scenes.
82 | :param gauss_sampling: Whether to use sampling from a Gaussian distribution instead of a uniform distribution.
83 | :param gauss_mean: The mean of the Gaussian sampling distribution.
84 | :param gauss_std: The standard deviation of the Gaussian sampling distribution.
85 | :param do_cutout: Whether to use random cutouts as an augmentation. This basically needs to be turned on unless some new augmentations are added in code eventually.
86 | :param center_bias: Whether to use a Gaussian distribution centered around the center of the image to sample the locations of random cutouts instead of a uniform distribution. Leads to the main generated objects to be more focused in the center.
87 | :param center_focus: How much to focus on the center if using center_bias. std = sampling_range / center_focus. High values lead to a very correct representation in the center but washed out colors and details towards the edges.
88 | :param jit: Whether to use the jit-compiled CLIP model. The jit model is faster, but only compatible with torch version 1.7.1.
89 | :param save_gif: Only used if save_progress is True. Saves a GIF animation of the generation procedure using the saved frames.
90 | :param save_video: Only used if save_progress is True. Saves a MP4 animation of the generation procedure using the saved frames.
91 | """
92 | # Don't instantiate imagine if the user just wants help.
93 | if any("--help" in arg for arg in sys.argv):
94 | print("Type `imagine --help` for usage info.")
95 | sys.exit()
96 |
97 | num_layers = 32 if deeper else num_layers
98 |
99 | imagine = Imagine(
100 | text=text,
101 | img=img,
102 | lr=learning_rate,
103 | num_layers=num_layers,
104 | batch_size=batch_size,
105 | gradient_accumulate_every=gradient_accumulate_every,
106 | epochs=epochs,
107 | iterations=iterations,
108 | image_width=image_width,
109 | save_every=save_every,
110 | save_progress=save_progress,
111 | seed=seed,
112 | open_folder=open_folder,
113 | save_date_time=save_date_time,
114 | start_image_path=start_image_path,
115 | start_image_train_iters=start_image_train_iters,
116 | theta_initial=theta_initial,
117 | theta_hidden=theta_hidden,
118 | start_image_lr=start_image_lr,
119 | lower_bound_cutout=lower_bound_cutout,
120 | upper_bound_cutout=upper_bound_cutout,
121 | saturate_bound=saturate_bound,
122 | create_story=create_story,
123 | story_start_words=story_start_words,
124 | story_words_per_epoch=story_words_per_epoch,
125 | story_separator=story_separator,
126 | averaging_weight=averaging_weight,
127 | gauss_sampling=gauss_sampling,
128 | gauss_mean=gauss_mean,
129 | gauss_std=gauss_std,
130 | do_cutout=do_cutout,
131 | center_bias=center_bias,
132 | center_focus=center_focus,
133 | jit=jit,
134 | hidden_size=hidden_size,
135 | model_name=model_name,
136 | optimizer=optimizer,
137 | save_gif=save_gif,
138 | save_video=save_video,
139 | )
140 |
141 | print('Starting up...')
142 | if not overwrite and imagine.filename.exists():
143 | answer = input('Imagined image already exists, do you want to overwrite? (y/n) ').lower()
144 | if answer not in ('yes', 'y'):
145 | sys.exit()
146 |
147 | imagine()
148 |
149 |
150 | def main():
151 | fire.Fire(train)
152 |
--------------------------------------------------------------------------------
/deep_daze/clip.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from pathlib import Path
8 |
9 | import hashlib
10 | import os
11 | import urllib
12 | import warnings
13 | from typing import Union, List
14 | from torchvision.transforms import Compose, Normalize
15 | from tqdm import tqdm
16 |
17 | _MODELS = {
18 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
19 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
20 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
21 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
22 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
23 | }
24 |
25 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
26 | os.makedirs(root, exist_ok=True)
27 | filename = os.path.basename(url)
28 |
29 | expected_sha256 = url.split("/")[-2]
30 | download_target = os.path.join(root, filename)
31 |
32 | if os.path.exists(download_target) and not os.path.isfile(download_target):
33 | raise RuntimeError(f"{download_target} exists and is not a regular file")
34 |
35 | if os.path.isfile(download_target):
36 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
37 | return download_target
38 | else:
39 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
40 |
41 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
42 | with tqdm(
43 | total=int(source.info().get("Content-Length")),
44 | unit='iB',
45 | unit_scale=True,
46 | desc=f"Downloading {filename}",
47 | ) as loop:
48 | while True:
49 | buffer = source.read(524288)
50 | if not buffer:
51 | break
52 |
53 | output.write(buffer)
54 | loop.update(len(buffer))
55 |
56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
57 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
58 |
59 | return download_target
60 |
61 |
62 | def _transform():
63 | return Compose([
64 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
65 | ])
66 |
67 |
68 | def available_models() -> List[str]:
69 | """Returns the names of available CLIP models"""
70 | return list(_MODELS.keys())
71 |
72 |
73 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
74 | """Load a CLIP model
75 |
76 | Parameters
77 | ----------
78 | name : str
79 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
80 |
81 | device : Union[str, torch.device]
82 | The device to put the loaded model
83 |
84 | jit : bool
85 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
86 |
87 | Returns
88 | -------
89 | model : torch.nn.Module
90 | The CLIP model
91 |
92 | preprocess : Callable[[PIL.Image], torch.Tensor]
93 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
94 | """
95 | if name in _MODELS:
96 | model_path = _download(_MODELS[name])
97 | elif os.path.isfile(name):
98 | model_path = name
99 | else:
100 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
101 |
102 | try:
103 | # loading JIT archive
104 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
105 | state_dict = None
106 | except RuntimeError:
107 | # loading saved state dict
108 | if jit:
109 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
110 | jit = False
111 | state_dict = torch.load(model_path, map_location="cpu")
112 |
113 | if not jit:
114 | model = build_model(state_dict or model.state_dict()).to(device)
115 | if str(device) == "cpu":
116 | model.float()
117 | return model, _transform()
118 |
119 | # patch the device names
120 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
121 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
122 |
123 | def patch_device(module):
124 | graphs = [module.graph] if hasattr(module, "graph") else []
125 | if hasattr(module, "forward1"):
126 | graphs.append(module.forward1.graph)
127 |
128 | for graph in graphs:
129 | for node in graph.findAllNodes("prim::Constant"):
130 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
131 | node.copyAttributes(device_node)
132 |
133 | model.apply(patch_device)
134 | patch_device(model.encode_image)
135 | patch_device(model.encode_text)
136 |
137 | # patch dtype to float32 on CPU
138 | if str(device) == "cpu":
139 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
140 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
141 | float_node = float_input.node()
142 |
143 | def patch_float(module):
144 | graphs = [module.graph] if hasattr(module, "graph") else []
145 | if hasattr(module, "forward1"):
146 | graphs.append(module.forward1.graph)
147 |
148 | for graph in graphs:
149 | for node in graph.findAllNodes("aten::to"):
150 | inputs = list(node.inputs())
151 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
152 | if inputs[i].node()["value"] == 5:
153 | inputs[i].node().copyAttributes(float_node)
154 |
155 | model.apply(patch_float)
156 | patch_float(model.encode_image)
157 | patch_float(model.encode_text)
158 |
159 | model.float()
160 |
161 | return model, _transform()
162 |
163 |
164 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
165 | """
166 | Returns the tokenized representation of given input string(s)
167 |
168 | Parameters
169 | ----------
170 | texts : Union[str, List[str]]
171 | An input string or a list of input strings to tokenize
172 |
173 | context_length : int
174 | The context length to use; all CLIP models use 77 as the context length
175 |
176 | Returns
177 | -------
178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
179 | """
180 | if isinstance(texts, str):
181 | texts = [texts]
182 |
183 | sot_token = _tokenizer.encoder["<|startoftext|>"]
184 | eot_token = _tokenizer.encoder["<|endoftext|>"]
185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
187 |
188 | for i, tokens in enumerate(all_tokens):
189 | if len(tokens) > context_length:
190 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
191 | result[i, :len(tokens)] = torch.tensor(tokens)
192 |
193 | return result
194 |
195 | class Bottleneck(nn.Module):
196 | expansion = 4
197 |
198 | def __init__(self, inplanes, planes, stride=1):
199 | super().__init__()
200 |
201 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
202 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
203 | self.bn1 = nn.BatchNorm2d(planes)
204 |
205 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
206 | self.bn2 = nn.BatchNorm2d(planes)
207 |
208 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
209 |
210 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
211 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
212 |
213 | self.relu = nn.ReLU(inplace=True)
214 | self.downsample = None
215 | self.stride = stride
216 |
217 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
218 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
219 | self.downsample = nn.Sequential(OrderedDict([
220 | ("-1", nn.AvgPool2d(stride)),
221 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
222 | ("1", nn.BatchNorm2d(planes * self.expansion))
223 | ]))
224 |
225 | def forward(self, x: torch.Tensor):
226 | identity = x
227 |
228 | out = self.relu(self.bn1(self.conv1(x)))
229 | out = self.relu(self.bn2(self.conv2(out)))
230 | out = self.avgpool(out)
231 | out = self.bn3(self.conv3(out))
232 |
233 | if self.downsample is not None:
234 | identity = self.downsample(x)
235 |
236 | out += identity
237 | out = self.relu(out)
238 | return out
239 |
240 |
241 | class AttentionPool2d(nn.Module):
242 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
243 | super().__init__()
244 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
245 | self.k_proj = nn.Linear(embed_dim, embed_dim)
246 | self.q_proj = nn.Linear(embed_dim, embed_dim)
247 | self.v_proj = nn.Linear(embed_dim, embed_dim)
248 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
249 | self.num_heads = num_heads
250 |
251 | def forward(self, x):
252 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
253 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
254 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
255 | x, _ = F.multi_head_attention_forward(
256 | query=x, key=x, value=x,
257 | embed_dim_to_check=x.shape[-1],
258 | num_heads=self.num_heads,
259 | q_proj_weight=self.q_proj.weight,
260 | k_proj_weight=self.k_proj.weight,
261 | v_proj_weight=self.v_proj.weight,
262 | in_proj_weight=None,
263 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
264 | bias_k=None,
265 | bias_v=None,
266 | add_zero_attn=False,
267 | dropout_p=0,
268 | out_proj_weight=self.c_proj.weight,
269 | out_proj_bias=self.c_proj.bias,
270 | use_separate_proj_weight=True,
271 | training=self.training,
272 | need_weights=False
273 | )
274 |
275 | return x[0]
276 |
277 |
278 | class ModifiedResNet(nn.Module):
279 | """
280 | A ResNet class that is similar to torchvision's but contains the following changes:
281 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
282 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
283 | - The final pooling layer is a QKV attention instead of an average pool
284 | """
285 |
286 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
287 | super().__init__()
288 | self.output_dim = output_dim
289 | self.input_resolution = input_resolution
290 |
291 | # the 3-layer stem
292 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
293 | self.bn1 = nn.BatchNorm2d(width // 2)
294 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
295 | self.bn2 = nn.BatchNorm2d(width // 2)
296 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
297 | self.bn3 = nn.BatchNorm2d(width)
298 | self.avgpool = nn.AvgPool2d(2)
299 | self.relu = nn.ReLU(inplace=True)
300 |
301 | # residual layers
302 | self._inplanes = width # this is a *mutable* variable used during construction
303 | self.layer1 = self._make_layer(width, layers[0])
304 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
305 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
306 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
307 |
308 | embed_dim = width * 32 # the ResNet feature dimension
309 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
310 |
311 | def _make_layer(self, planes, blocks, stride=1):
312 | layers = [Bottleneck(self._inplanes, planes, stride)]
313 |
314 | self._inplanes = planes * Bottleneck.expansion
315 | for _ in range(1, blocks):
316 | layers.append(Bottleneck(self._inplanes, planes))
317 |
318 | return nn.Sequential(*layers)
319 |
320 | def forward(self, x):
321 | def stem(x):
322 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
323 | x = self.relu(bn(conv(x)))
324 | x = self.avgpool(x)
325 | return x
326 |
327 | x = x.type(self.conv1.weight.dtype)
328 | x = stem(x)
329 | x = self.layer1(x)
330 | x = self.layer2(x)
331 | x = self.layer3(x)
332 | x = self.layer4(x)
333 | x = self.attnpool(x)
334 |
335 | return x
336 |
337 |
338 | class LayerNorm(nn.LayerNorm):
339 | """Subclass torch's LayerNorm to handle fp16."""
340 |
341 | def forward(self, x: torch.Tensor):
342 | orig_type = x.dtype
343 | ret = super().forward(x.type(torch.float32))
344 | return ret.type(orig_type)
345 |
346 |
347 | class QuickGELU(nn.Module):
348 | def forward(self, x: torch.Tensor):
349 | return x * torch.sigmoid(1.702 * x)
350 |
351 |
352 | class ResidualAttentionBlock(nn.Module):
353 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
354 | super().__init__()
355 |
356 | self.attn = nn.MultiheadAttention(d_model, n_head)
357 | self.ln_1 = LayerNorm(d_model)
358 | self.mlp = nn.Sequential(OrderedDict([
359 | ("c_fc", nn.Linear(d_model, d_model * 4)),
360 | ("gelu", QuickGELU()),
361 | ("c_proj", nn.Linear(d_model * 4, d_model))
362 | ]))
363 | self.ln_2 = LayerNorm(d_model)
364 | self.attn_mask = attn_mask
365 |
366 | def attention(self, x: torch.Tensor):
367 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
368 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
369 |
370 | def forward(self, x: torch.Tensor):
371 | x = x + self.attention(self.ln_1(x))
372 | x = x + self.mlp(self.ln_2(x))
373 | return x
374 |
375 |
376 | class Transformer(nn.Module):
377 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
378 | super().__init__()
379 | self.width = width
380 | self.layers = layers
381 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
382 |
383 | def forward(self, x: torch.Tensor):
384 | return self.resblocks(x)
385 |
386 |
387 | class VisualTransformer(nn.Module):
388 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
389 | super().__init__()
390 | self.input_resolution = input_resolution
391 | self.output_dim = output_dim
392 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
393 |
394 | scale = width ** -0.5
395 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
396 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
397 | self.ln_pre = LayerNorm(width)
398 |
399 | self.transformer = Transformer(width, layers, heads)
400 |
401 | self.ln_post = LayerNorm(width)
402 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
403 |
404 | def forward(self, x: torch.Tensor):
405 | x = self.conv1(x) # shape = [*, width, grid, grid]
406 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
407 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
408 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
409 | x = x + self.positional_embedding.to(x.dtype)
410 | x = self.ln_pre(x)
411 |
412 | x = x.permute(1, 0, 2) # NLD -> LND
413 | x = self.transformer(x)
414 | x = x.permute(1, 0, 2) # LND -> NLD
415 |
416 | x = self.ln_post(x[:, 0, :])
417 |
418 | if self.proj is not None:
419 | x = x @ self.proj
420 |
421 | return x
422 |
423 |
424 | class CLIP(nn.Module):
425 | def __init__(self,
426 | embed_dim: int,
427 | # vision
428 | image_resolution: int,
429 | vision_layers: Union[Tuple[int, int, int, int], int],
430 | vision_width: int,
431 | vision_patch_size: int,
432 | # text
433 | context_length: int,
434 | vocab_size: int,
435 | transformer_width: int,
436 | transformer_heads: int,
437 | transformer_layers: int
438 | ):
439 | super().__init__()
440 |
441 | self.context_length = context_length
442 |
443 | if isinstance(vision_layers, (tuple, list)):
444 | vision_heads = vision_width * 32 // 64
445 | self.visual = ModifiedResNet(
446 | layers=vision_layers,
447 | output_dim=embed_dim,
448 | heads=vision_heads,
449 | input_resolution=image_resolution,
450 | width=vision_width
451 | )
452 | else:
453 | vision_heads = vision_width // 64
454 | self.visual = VisualTransformer(
455 | input_resolution=image_resolution,
456 | patch_size=vision_patch_size,
457 | width=vision_width,
458 | layers=vision_layers,
459 | heads=vision_heads,
460 | output_dim=embed_dim
461 | )
462 |
463 | self.transformer = Transformer(
464 | width=transformer_width,
465 | layers=transformer_layers,
466 | heads=transformer_heads,
467 | attn_mask=self.build_attention_mask()
468 | )
469 |
470 | self.vocab_size = vocab_size
471 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
472 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
473 | self.ln_final = LayerNorm(transformer_width)
474 |
475 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
476 | self.logit_scale = nn.Parameter(torch.ones([]))
477 |
478 | self.initialize_parameters()
479 |
480 | def initialize_parameters(self):
481 | nn.init.normal_(self.token_embedding.weight, std=0.02)
482 | nn.init.normal_(self.positional_embedding, std=0.01)
483 |
484 | if isinstance(self.visual, ModifiedResNet):
485 | if self.visual.attnpool is not None:
486 | std = self.visual.attnpool.c_proj.in_features ** -0.5
487 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
488 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
489 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
490 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
491 |
492 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
493 | for name, param in resnet_block.named_parameters():
494 | if name.endswith("bn3.weight"):
495 | nn.init.zeros_(param)
496 |
497 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
498 | attn_std = self.transformer.width ** -0.5
499 | fc_std = (2 * self.transformer.width) ** -0.5
500 | for block in self.transformer.resblocks:
501 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
502 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
503 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
504 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
505 |
506 | if self.text_projection is not None:
507 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
508 |
509 | def build_attention_mask(self):
510 | # lazily create causal attention mask, with full attention between the vision tokens
511 | # pytorch uses additive attention mask; fill with -inf
512 | mask = torch.empty(self.context_length, self.context_length)
513 | mask.fill_(float("-inf"))
514 | mask.triu_(1) # zero out the lower diagonal
515 | return mask
516 |
517 | @property
518 | def dtype(self):
519 | return self.visual.conv1.weight.dtype
520 |
521 | def encode_image(self, image):
522 | return self.visual(image.type(self.dtype))
523 |
524 | def encode_text(self, text):
525 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
526 |
527 | x = x + self.positional_embedding.type(self.dtype)
528 | x = x.permute(1, 0, 2) # NLD -> LND
529 | x = self.transformer(x)
530 | x = x.permute(1, 0, 2) # LND -> NLD
531 | x = self.ln_final(x).type(self.dtype)
532 |
533 | # x.shape = [batch_size, n_ctx, transformer.width]
534 | # take features from the eot embedding (eot_token is the highest number in each sequence)
535 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
536 |
537 | return x
538 |
539 | def forward(self, image, text):
540 | image_features = self.encode_image(image)
541 | text_features = self.encode_text(text)
542 |
543 | # normalized features
544 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
545 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
546 |
547 | # cosine similarity as logits
548 | logit_scale = self.logit_scale.exp()
549 | logits_per_image = logit_scale * image_features @ text_features.t()
550 | logits_per_text = logit_scale * text_features @ image_features.t()
551 |
552 | # shape = [global_batch_size, global_batch_size]
553 | return logits_per_image, logits_per_text
554 |
555 |
556 | def convert_weights(model: nn.Module):
557 | """Convert applicable model parameters to fp16"""
558 |
559 | def _convert_weights_to_fp16(l):
560 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
561 | l.weight.data = l.weight.data.half()
562 | if l.bias is not None:
563 | l.bias.data = l.bias.data.half()
564 |
565 | if isinstance(l, nn.MultiheadAttention):
566 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
567 | tensor = getattr(l, attr)
568 | if tensor is not None:
569 | tensor.data = tensor.data.half()
570 |
571 | for name in ["text_projection", "proj"]:
572 | if hasattr(l, name):
573 | attr = getattr(l, name)
574 | if attr is not None:
575 | attr.data = attr.data.half()
576 |
577 | model.apply(_convert_weights_to_fp16)
578 |
579 |
580 | def build_model(state_dict: dict):
581 | vit = "visual.proj" in state_dict
582 |
583 | if vit:
584 | vision_width = state_dict["visual.conv1.weight"].shape[0]
585 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
586 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
587 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
588 | image_resolution = vision_patch_size * grid_size
589 | else:
590 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
591 | vision_layers = tuple(counts)
592 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
593 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
594 | vision_patch_size = None
595 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
596 | image_resolution = output_width * 32
597 |
598 | embed_dim = state_dict["text_projection"].shape[1]
599 | context_length = state_dict["positional_embedding"].shape[0]
600 | vocab_size = state_dict["token_embedding.weight"].shape[0]
601 | transformer_width = state_dict["ln_final.weight"].shape[0]
602 | transformer_heads = transformer_width // 64
603 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
604 |
605 | model = CLIP(
606 | embed_dim,
607 | image_resolution, vision_layers, vision_width, vision_patch_size,
608 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
609 | )
610 |
611 | for key in ["input_resolution", "context_length", "vocab_size"]:
612 | if key in state_dict:
613 | del state_dict[key]
614 |
615 | convert_weights(model)
616 | model.load_state_dict(state_dict)
617 | return model.eval()
618 |
619 | import html
620 | from functools import lru_cache
621 |
622 | import ftfy
623 | import regex as re
624 |
625 |
626 | @lru_cache()
627 | def default_bpe():
628 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
629 |
630 |
631 | @lru_cache()
632 | def bytes_to_unicode():
633 | """
634 | Returns list of utf-8 byte and a corresponding list of unicode strings.
635 | The reversible bpe codes work on unicode strings.
636 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
637 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
638 | This is a significant percentage of your normal, say, 32K bpe vocab.
639 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
640 | And avoids mapping to whitespace/control characters the bpe code barfs on.
641 | """
642 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
643 | cs = bs[:]
644 | n = 0
645 | for b in range(2**8):
646 | if b not in bs:
647 | bs.append(b)
648 | cs.append(2**8+n)
649 | n += 1
650 | cs = [chr(n) for n in cs]
651 | return dict(zip(bs, cs))
652 |
653 |
654 | def get_pairs(word):
655 | """Return set of symbol pairs in a word.
656 | Word is represented as tuple of symbols (symbols being variable-length strings).
657 | """
658 | pairs = set()
659 | prev_char = word[0]
660 | for char in word[1:]:
661 | pairs.add((prev_char, char))
662 | prev_char = char
663 | return pairs
664 |
665 |
666 | def basic_clean(text):
667 | text = ftfy.fix_text(text)
668 | text = html.unescape(html.unescape(text))
669 | return text.strip()
670 |
671 |
672 | def whitespace_clean(text):
673 | text = re.sub(r'\s+', ' ', text)
674 | text = text.strip()
675 | return text
676 |
677 |
678 | class SimpleTokenizer(object):
679 | def __init__(self, bpe_path: str = default_bpe()):
680 | self.byte_encoder = bytes_to_unicode()
681 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
682 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
683 | merges = merges[1:49152-256-2+1]
684 | merges = [tuple(merge.split()) for merge in merges]
685 | vocab = list(bytes_to_unicode().values())
686 | vocab = vocab + [v+'</w>' for v in vocab]
687 | for merge in merges:
688 | vocab.append(''.join(merge))
689 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
690 | self.encoder = dict(zip(vocab, range(len(vocab))))
691 | self.decoder = {v: k for k, v in self.encoder.items()}
692 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
693 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
694 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
695 |
696 | def bpe(self, token):
697 | if token in self.cache:
698 | return self.cache[token]
699 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
700 | pairs = get_pairs(word)
701 |
702 | if not pairs:
703 | return token+'</w>'
704 |
705 | while True:
706 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
707 | if bigram not in self.bpe_ranks:
708 | break
709 | first, second = bigram
710 | new_word = []
711 | i = 0
712 | while i < len(word):
713 | try:
714 | j = word.index(first, i)
715 | new_word.extend(word[i:j])
716 | i = j
717 | except:
718 | new_word.extend(word[i:])
719 | break
720 |
721 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
722 | new_word.append(first+second)
723 | i += 2
724 | else:
725 | new_word.append(word[i])
726 | i += 1
727 | new_word = tuple(new_word)
728 | word = new_word
729 | if len(word) == 1:
730 | break
731 | else:
732 | pairs = get_pairs(word)
733 | word = ' '.join(word)
734 | self.cache[token] = word
735 | return word
736 |
737 | def encode(self, text):
738 | bpe_tokens = []
739 | text = whitespace_clean(basic_clean(text)).lower()
740 | for token in re.findall(self.pat, text):
741 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
742 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
743 | return bpe_tokens
744 |
745 | def decode(self, tokens):
746 | text = ''.join([self.decoder[token] for token in tokens])
747 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
748 | return text
749 | _tokenizer = SimpleTokenizer()
750 |
--------------------------------------------------------------------------------
/deep_daze/deep_daze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | import random
5 | from datetime import datetime
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from siren_pytorch import SirenNet, SirenWrapper
11 | from torch import nn
12 | from torch.cuda.amp import GradScaler, autocast
13 | from torch_optimizer import DiffGrad, AdamP
14 | import numpy as np
15 |
16 | from PIL import Image
17 | from imageio import imread, mimsave
18 | import torchvision.transforms as T
19 |
20 |
21 | from tqdm import trange, tqdm
22 |
23 | from .clip import load, tokenize
24 |
25 |
26 | # Helpers
27 |
28 | def exists(val):
29 | return val is not None
30 |
31 |
32 | def default(val, d):
33 | return val if exists(val) else d
34 |
35 |
36 | def interpolate(image, size):
37 | return F.interpolate(image, (size, size), mode='bilinear', align_corners=False)
38 |
39 |
40 | def rand_cutout(image, size, center_bias=False, center_focus=2):
41 | width = image.shape[-1]
42 | min_offset = 0
43 | max_offset = width - size
44 | if center_bias:
45 | # sample around image center
46 | center = max_offset / 2
47 | std = center / center_focus
48 | offset_x = int(random.gauss(mu=center, sigma=std))
49 | offset_y = int(random.gauss(mu=center, sigma=std))
50 | # resample uniformly if over boundaries
51 | offset_x = random.randint(min_offset, max_offset) if (offset_x > max_offset or offset_x < min_offset) else offset_x
52 | offset_y = random.randint(min_offset, max_offset) if (offset_y > max_offset or offset_y < min_offset) else offset_y
53 | else:
54 | offset_x = random.randint(min_offset, max_offset)
55 | offset_y = random.randint(min_offset, max_offset)
56 | cutout = image[:, :, offset_x:offset_x + size, offset_y:offset_y + size]
57 | return cutout
58 |
59 |
60 | def create_clip_img_transform(image_width):
61 | clip_mean = [0.48145466, 0.4578275, 0.40821073]
62 | clip_std = [0.26862954, 0.26130258, 0.27577711]
63 | transform = T.Compose([
64 | #T.ToPILImage(),
65 | T.Resize(image_width),
66 | T.CenterCrop((image_width, image_width)),
67 | T.ToTensor(),
68 | T.Normalize(mean=clip_mean, std=clip_std)
69 | ])
70 | return transform
71 |
72 |
73 | def open_folder(path):
74 | if os.path.isfile(path):
75 | path = os.path.dirname(path)
76 |
77 | if not os.path.isdir(path):
78 | return
79 |
80 | cmd_list = None
81 | if sys.platform == 'darwin':
82 | cmd_list = ['open', '--', path]
83 | elif sys.platform == 'linux2' or sys.platform == 'linux':
84 | cmd_list = ['xdg-open', path]
85 | elif sys.platform in ['win32', 'win64']:
86 | cmd_list = ['explorer', path.replace('/', '\\')]
87 | if cmd_list is None:
88 | return
89 |
90 | try:
91 | subprocess.check_call(cmd_list)
92 | except subprocess.CalledProcessError:
93 | pass
94 | except OSError:
95 | pass
96 |
97 |
98 | def norm_siren_output(img):
99 | return ((img + 1) * 0.5).clamp(0.0, 1.0)
100 |
101 |
102 | def create_text_path(context_length, text=None, img=None, encoding=None, separator=None):
103 | if text is not None:
104 | if separator is not None and separator in text:
105 | # reduce the filename to the text of the first epoch
106 | text = text[:text.index(separator)]
107 | input_name = text.replace(" ", "_")[:context_length]
108 | elif img is not None:
109 | if isinstance(img, str):
110 | input_name = "".join(img.replace(" ", "_").split(".")[:-1])
111 | else:
112 | input_name = "PIL_img"
113 | else:
114 | input_name = "your_encoding"
115 | return input_name
116 |
117 |
118 | class DeepDaze(nn.Module):
119 | def __init__(
120 | self,
121 | clip_perceptor,
122 | clip_norm,
123 | input_res,
124 | total_batches,
125 | batch_size,
126 | num_layers=8,
127 | image_width=512,
128 | loss_coef=100,
129 | theta_initial=None,
130 | theta_hidden=None,
131 | lower_bound_cutout=0.1, # should be smaller than 0.8
132 | upper_bound_cutout=1.0,
133 | saturate_bound=False,
134 | gauss_sampling=False,
135 | gauss_mean=0.6,
136 | gauss_std=0.2,
137 | do_cutout=True,
138 | center_bias=False,
139 | center_focus=2,
140 | hidden_size=256,
141 | averaging_weight=0.3,
142 | ):
143 | super().__init__()
144 | # load clip
145 | self.perceptor = clip_perceptor
146 | self.input_resolution = input_res
147 | self.normalize_image = clip_norm
148 |
149 | self.loss_coef = loss_coef
150 | self.image_width = image_width
151 |
152 | self.batch_size = batch_size
153 | self.total_batches = total_batches
154 | self.num_batches_processed = 0
155 |
156 | w0 = default(theta_hidden, 30.)
157 | w0_initial = default(theta_initial, 30.)
158 |
159 | siren = SirenNet(
160 | dim_in=2,
161 | dim_hidden=hidden_size,
162 | num_layers=num_layers,
163 | dim_out=3,
164 | use_bias=True,
165 | w0=w0,
166 | w0_initial=w0_initial
167 | )
168 |
169 | self.model = SirenWrapper(
170 | siren,
171 | image_width=image_width,
172 | image_height=image_width
173 | )
174 |
175 | self.saturate_bound = saturate_bound
176 | self.saturate_limit = 0.75 # cutouts above this value lead to destabilization
177 | self.lower_bound_cutout = lower_bound_cutout
178 | self.upper_bound_cutout = upper_bound_cutout
179 | self.gauss_sampling = gauss_sampling
180 | self.gauss_mean = gauss_mean
181 | self.gauss_std = gauss_std
182 | self.do_cutout = do_cutout
183 | self.center_bias = center_bias
184 | self.center_focus = center_focus
185 | self.averaging_weight = averaging_weight
186 |
187 | def sample_sizes(self, lower, upper, width, gauss_mean):
188 | if self.gauss_sampling:
189 | gauss_samples = torch.zeros(self.batch_size).normal_(mean=gauss_mean, std=self.gauss_std)
190 | outside_bounds_mask = (gauss_samples > upper) | (gauss_samples < lower)
191 | gauss_samples[outside_bounds_mask] = torch.zeros((len(gauss_samples[outside_bounds_mask]),)).uniform_(lower, upper)
192 | sizes = (gauss_samples * width).int()
193 | else:
194 | lower *= width
195 | upper *= width
196 | sizes = torch.randint(int(lower), int(upper), (self.batch_size,))
197 | return sizes
198 |
199 | def forward(self, text_embed, return_loss=True, dry_run=False):
200 | out = self.model()
201 | out = norm_siren_output(out)
202 |
203 | if not return_loss:
204 | return out
205 |
206 | # determine upper and lower sampling bound
207 | width = out.shape[-1]
208 | lower_bound = self.lower_bound_cutout
209 | if self.saturate_bound:
210 | progress_fraction = self.num_batches_processed / self.total_batches
211 | lower_bound += (self.saturate_limit - self.lower_bound_cutout) * progress_fraction
212 |
213 | # sample cutout sizes between lower and upper bound
214 | sizes = self.sample_sizes(lower_bound, self.upper_bound_cutout, width, self.gauss_mean)
215 |
216 | # create normalized random cutouts
217 | if self.do_cutout:
218 | image_pieces = [rand_cutout(out, size, center_bias=self.center_bias, center_focus=self.center_focus) for size in sizes]
219 | image_pieces = [interpolate(piece, self.input_resolution) for piece in image_pieces]
220 | else:
221 | image_pieces = [interpolate(out.clone(), self.input_resolution) for _ in sizes]
222 |
223 | # normalize
224 | image_pieces = torch.cat([self.normalize_image(piece) for piece in image_pieces])
225 |
226 | # calc image embedding
227 | with autocast(enabled=False):
228 | image_embed = self.perceptor.encode_image(image_pieces)
229 |
230 | # calc loss
231 | # loss over averaged features of cutouts
232 | avg_image_embed = image_embed.mean(dim=0).unsqueeze(0)
233 | averaged_loss = -self.loss_coef * torch.cosine_similarity(text_embed, avg_image_embed, dim=-1).mean()
234 | # loss over all cutouts
235 | general_loss = -self.loss_coef * torch.cosine_similarity(text_embed, image_embed, dim=-1).mean()
236 | # merge losses
237 | loss = averaged_loss * (self.averaging_weight) + general_loss * (1 - self.averaging_weight)
238 |
239 | # count batches
240 | if not dry_run:
241 | self.num_batches_processed += self.batch_size
242 |
243 | return out, loss
244 |
245 |
246 | class Imagine(nn.Module):
247 | def __init__(
248 | self,
249 | *,
250 | text=None,
251 | img=None,
252 | clip_encoding=None,
253 | lr=1e-5,
254 | batch_size=4,
255 | gradient_accumulate_every=4,
256 | save_every=100,
257 | image_width=512,
258 | num_layers=16,
259 | epochs=20,
260 | iterations=1050,
261 | save_progress=True,
262 | seed=None,
263 | open_folder=True,
264 | save_date_time=False,
265 | start_image_path=None,
266 | start_image_train_iters=10,
267 | start_image_lr=3e-4,
268 | theta_initial=None,
269 | theta_hidden=None,
270 | model_name="ViT-B/32",
271 | lower_bound_cutout=0.1, # should be smaller than 0.8
272 | upper_bound_cutout=1.0,
273 | saturate_bound=False,
274 | averaging_weight=0.3,
275 |
276 | create_story=False,
277 | story_start_words=5,
278 | story_words_per_epoch=5,
279 | story_separator=None,
280 | gauss_sampling=False,
281 | gauss_mean=0.6,
282 | gauss_std=0.2,
283 | do_cutout=True,
284 | center_bias=False,
285 | center_focus=2,
286 | optimizer="AdamP",
287 | jit=True,
288 | hidden_size=256,
289 | save_gif=False,
290 | save_video=False,
291 | ):
292 |
293 | super().__init__()
294 |
295 | if exists(seed):
296 | tqdm.write(f'setting seed: {seed}')
297 | torch.manual_seed(seed)
298 | torch.cuda.manual_seed(seed)
299 | random.seed(seed)
300 | torch.backends.cudnn.deterministic = True
301 |
302 | # fields for story creation:
303 | self.create_story = create_story
304 | self.words = None
305 | self.separator = str(story_separator) if story_separator is not None else None
306 | if self.separator is not None and text is not None:
307 | # exit if the text is just the separator
308 | if str(text).replace(' ','').replace(self.separator,'') == '':
309 | print('Exiting because the text consists only of the separator! It needs words or phrases separated by the separator.')
310 | exit()
311 | # add a space after each separator and collapse any double spaces this may create
312 | text = text.replace(self.separator,self.separator+' ').replace('  ',' ').strip()
313 | self.all_words = text.split(" ") if text is not None else None
314 | self.num_start_words = story_start_words
315 | self.words_per_epoch = story_words_per_epoch
316 | if create_story:
317 | assert text is not None, "We need text input to create a story..."
318 | # overwrite epochs to match story length
319 | num_words = len(self.all_words)
320 | self.epochs = 1 + (num_words - self.num_start_words) / self.words_per_epoch
321 | # add one epoch if not divisible
322 | self.epochs = int(self.epochs) if int(self.epochs) == self.epochs else int(self.epochs) + 1
323 | if self.separator is not None:
324 | if self.separator not in text:
325 | print("Separator '"+self.separator+"' will be ignored since not in text!")
326 | self.separator = None
327 | else:
328 | self.epochs = len(list(filter(None,text.split(self.separator))))
329 | print("Running for", self.epochs, "epochs" + (" (split with '"+self.separator+"' as the separator)" if self.separator is not None else ""))
330 | else:
331 | self.epochs = epochs
332 |
333 | # jit models only compatible with version 1.7.1
334 | if "1.7.1" not in torch.__version__:
335 | if jit == True:
336 | print("Setting jit to False because torch version is not 1.7.1.")
337 | jit = False
338 |
339 | # Load CLIP
340 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
341 | clip_perceptor, norm = load(model_name, jit=jit, device=self.device)
342 | self.perceptor = clip_perceptor.eval()
343 | for param in self.perceptor.parameters():
344 | param.requires_grad = False
345 | if jit == False:
346 | input_res = clip_perceptor.visual.input_resolution
347 | else:
348 | input_res = clip_perceptor.input_resolution.item()
349 | self.clip_transform = create_clip_img_transform(input_res)
350 |
351 | self.iterations = iterations
352 | self.image_width = image_width
353 | total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
354 | model = DeepDaze(
355 | self.perceptor,
356 | norm,
357 | input_res,
358 | total_batches,
359 | batch_size=batch_size,
360 | image_width=image_width,
361 | num_layers=num_layers,
362 | theta_initial=theta_initial,
363 | theta_hidden=theta_hidden,
364 | lower_bound_cutout=lower_bound_cutout,
365 | upper_bound_cutout=upper_bound_cutout,
366 | saturate_bound=saturate_bound,
367 | gauss_sampling=gauss_sampling,
368 | gauss_mean=gauss_mean,
369 | gauss_std=gauss_std,
370 | do_cutout=do_cutout,
371 | center_bias=center_bias,
372 | center_focus=center_focus,
373 | hidden_size=hidden_size,
374 | averaging_weight=averaging_weight,
375 | ).to(self.device)
376 | self.model = model
377 | self.scaler = GradScaler()
378 | siren_params = model.model.parameters()
379 | if optimizer == "AdamP":
380 | self.optimizer = AdamP(siren_params, lr)
381 | elif optimizer == "Adam":
382 | self.optimizer = torch.optim.Adam(siren_params, lr)
383 | elif optimizer == "DiffGrad":
384 | self.optimizer = DiffGrad(siren_params, lr)
385 | self.gradient_accumulate_every = gradient_accumulate_every
386 | self.save_every = save_every
387 | self.save_date_time = save_date_time
388 | self.open_folder = open_folder
389 | self.save_progress = save_progress
390 | self.text = text
391 | self.image = img
392 | self.textpath = create_text_path(self.perceptor.context_length, text=text, img=img, encoding=clip_encoding, separator=story_separator)
393 | self.filename = self.image_output_path()
394 |
395 | # create the encoding to optimize for
396 | self.clip_encoding = self.create_clip_encoding(text=text, img=img, encoding=clip_encoding)
397 |
398 | self.start_image = None
399 | self.start_image_train_iters = start_image_train_iters
400 | self.start_image_lr = start_image_lr
401 | if exists(start_image_path):
402 | file = Path(start_image_path)
403 | assert file.exists(), f'file does not exist at given starting image path {start_image_path}'
404 | image = Image.open(str(file))
405 | start_img_transform = T.Compose([T.Resize(image_width),
406 | T.CenterCrop((image_width, image_width)),
407 | T.ToTensor()])
408 | image_tensor = start_img_transform(image).unsqueeze(0).to(self.device)
409 | self.start_image = image_tensor
410 |
411 | self.save_gif = save_gif
412 | self.save_video = save_video
413 |
414 | def create_clip_encoding(self, text=None, img=None, encoding=None):
415 | self.text = text
416 | self.img = img
417 | if encoding is not None:
418 | encoding = encoding.to(self.device)
419 | elif self.create_story:
420 | encoding = self.update_story_encoding(epoch=0, iteration=1)
421 | elif text is not None and img is not None:
422 | encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
423 | elif text is not None:
424 | encoding = self.create_text_encoding(text)
425 | elif img is not None:
426 | encoding = self.create_img_encoding(img)
427 | return encoding
428 |
429 | def create_text_encoding(self, text):
430 | tokenized_text = tokenize(text).to(self.device)
431 | with torch.no_grad():
432 | text_encoding = self.perceptor.encode_text(tokenized_text).detach()
433 | return text_encoding
434 |
435 | def create_img_encoding(self, img):
436 | if isinstance(img, str):
437 | img = Image.open(img)
438 | normed_img = self.clip_transform(img).unsqueeze(0).to(self.device)
439 | with torch.no_grad():
440 | img_encoding = self.perceptor.encode_image(normed_img).detach()
441 | return img_encoding
442 |
443 | def set_clip_encoding(self, text=None, img=None, encoding=None):
444 | encoding = self.create_clip_encoding(text=text, img=img, encoding=encoding)
445 | self.clip_encoding = encoding.to(self.device)
446 |
447 | def index_of_first_separator(self) -> int:
448 | for c, word in enumerate(self.all_words):
449 | if self.separator in str(word):
450 | return c + 1
451 |
452 | def update_story_encoding(self, epoch, iteration):
453 | if self.separator is not None:
454 | self.words = " ".join(self.all_words[:self.index_of_first_separator()])
455 | # remove the separator from this epoch's text
456 | self.words = self.words.replace(self.separator,'')
457 | self.all_words = self.all_words[self.index_of_first_separator():]
458 | else:
459 | if self.words is None:
460 | self.words = " ".join(self.all_words[:self.num_start_words])
461 | self.all_words = self.all_words[self.num_start_words:]
462 | else:
463 | # add words_per_epoch new words
464 | count = 0
465 | while count < self.words_per_epoch and len(self.all_words) > 0:
466 | new_word = self.all_words[0]
467 | self.words = " ".join(self.words.split(" ") + [new_word])
468 | self.all_words = self.all_words[1:]
469 | count += 1
470 | # remove words until it fits in context length
471 | while len(self.words) > self.perceptor.context_length:
472 | # remove first word
473 | self.words = " ".join(self.words.split(" ")[1:])
474 | # get new encoding
475 | print("Now thinking of: ", '"', self.words, '"')
476 | sequence_number = self.get_img_sequence_number(epoch, iteration)
477 | # save new words to disc
478 | with open("story_transitions.txt", "a") as f:
479 | f.write(f"{epoch}, {sequence_number}, {self.words}\n")
480 |
481 | encoding = self.create_text_encoding(self.words)
482 | return encoding
483 |
484 | def image_output_path(self, sequence_number=None):
485 | """
486 | Returns an underscore-separated Path.
487 | The current timestamp is prepended if `self.save_date_time` is set.
488 | The sequence number, left-padded with zeroes to six digits, is appended when one is given.
489 | :rtype: Path
490 | """
491 | output_path = self.textpath
492 | if sequence_number:
493 | sequence_number_left_padded = str(sequence_number).zfill(6)
494 | output_path = f"{output_path}.{sequence_number_left_padded}"
495 | if self.save_date_time:
496 | current_time = datetime.now().strftime("%y%m%d-%H%M%S_%f")
497 | output_path = f"{current_time}_{output_path}"
498 | return Path(f"{output_path}.jpg")
499 |
500 | def train_step(self, epoch, iteration):
501 | total_loss = 0
502 |
503 | for _ in range(self.gradient_accumulate_every):
504 | with autocast(enabled=True):
505 | out, loss = self.model(self.clip_encoding)
506 | loss = loss / self.gradient_accumulate_every
507 | total_loss += loss
508 | self.scaler.scale(loss).backward()
509 | out = out.cpu().float().clamp(0., 1.)
510 | self.scaler.step(self.optimizer)
511 | self.scaler.update()
512 | self.optimizer.zero_grad()
513 |
514 | if (iteration % self.save_every == 0) and self.save_progress:
515 | self.save_image(epoch, iteration, img=out)
516 |
517 | return out, total_loss
518 |
519 | def get_img_sequence_number(self, epoch, iteration):
520 | current_total_iterations = epoch * self.iterations + iteration
521 | sequence_number = current_total_iterations // self.save_every
522 | return sequence_number
523 |
524 | @torch.no_grad()
525 | def save_image(self, epoch, iteration, img=None):
526 | sequence_number = self.get_img_sequence_number(epoch, iteration)
527 |
528 | if img is None:
529 | img = self.model(self.clip_encoding, return_loss=False).cpu().float().clamp(0., 1.)
530 | self.filename = self.image_output_path(sequence_number=sequence_number)
531 |
532 | pil_img = T.ToPILImage()(img.squeeze())
533 | pil_img.save(self.filename, quality=95, subsampling=0)
534 | pil_img.save(f"{self.textpath}.jpg", quality=95, subsampling=0)
535 |
536 | tqdm.write(f'image updated at "./{str(self.filename)}"')
537 |
538 | def generate_gif(self):
539 | images = []
540 | for file_name in sorted(os.listdir('./')):
541 | if file_name.startswith(self.textpath) and file_name != f'{self.textpath}.jpg':
542 | images.append(imread(os.path.join('./', file_name)))
543 |
544 | if self.save_video:
545 | mimsave(f'{self.textpath}.mp4', images)
546 | print(f'Generated image generation animation at ./{self.textpath}.mp4')
547 | if self.save_gif:
548 | mimsave(f'{self.textpath}.gif', images)
549 | print(f'Generated image generation animation at ./{self.textpath}.gif')
550 |
551 | def forward(self):
552 | if exists(self.start_image):
553 | tqdm.write('Preparing with initial image...')
554 | optim = DiffGrad(self.model.model.parameters(), lr = self.start_image_lr)
555 | pbar = trange(self.start_image_train_iters, desc='iteration')
556 | try:
557 | for _ in pbar:
558 | loss = self.model.model(self.start_image)
559 | loss.backward()
560 | pbar.set_description(f'loss: {loss.item():.2f}')
561 |
562 | optim.step()
563 | optim.zero_grad()
564 | except KeyboardInterrupt:
565 | print('interrupted by keyboard, gracefully exiting')
566 | return exit()
567 |
568 | del self.start_image
569 | del optim
570 |
571 | tqdm.write(f'Imagining "{self.textpath}" from the depths of my weights...')
572 |
573 | with torch.no_grad():
574 | self.model(self.clip_encoding, dry_run=True) # do one warmup step due to potential issue with CLIP and CUDA
575 |
576 | if self.open_folder:
577 | open_folder('./')
578 | self.open_folder = False
579 |
580 | try:
581 | for epoch in trange(self.epochs, desc='epochs'):
582 | pbar = trange(self.iterations, desc='iteration')
583 | for i in pbar:
584 | _, loss = self.train_step(epoch, i)
585 | pbar.set_description(f'loss: {loss.item():.2f}')
586 |
587 | # Update clip_encoding per epoch if we are creating a story
588 | if self.create_story:
589 | self.clip_encoding = self.update_story_encoding(epoch, i)
590 | except KeyboardInterrupt:
591 | print('interrupted by keyboard, gracefully exiting')
592 | return
593 |
594 | self.save_image(epoch, i) # one final save at end
595 |
596 | if (self.save_gif or self.save_video) and self.save_progress:
597 | self.generate_gif()
598 |
--------------------------------------------------------------------------------
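A minimal usage sketch (not part of the repository). It assumes Imagine is re-exported by deep_daze/__init__.py; if not, it can be imported from deep_daze.deep_daze instead. Every keyword below appears in the constructor above; the values are illustrative, not recommendations.

    from deep_daze import Imagine

    imagine = Imagine(
        text='Mist over green hills',
        num_layers=16,        # depth of the SIREN network
        image_width=512,      # side length of the generated image
        epochs=1,             # default is 20
        iterations=200,       # optimization steps per epoch (default 1050)
        save_every=50,        # write an intermediate jpg every 50 steps
        open_folder=False,    # skip opening the output folder
    )
    imagine()                 # nn.Module __call__ -> Imagine.forward(), which runs the optimization loop
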
/deep_daze/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.11.1'
2 |
--------------------------------------------------------------------------------
/instruction_images/Windows/Step_1_DD_Win.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/instruction_images/Windows/Step_1_DD_Win.png
--------------------------------------------------------------------------------
/instruction_images/Windows/Step_2_DD_Win.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/instruction_images/Windows/Step_2_DD_Win.png
--------------------------------------------------------------------------------
/samples/A_man_painting_a_completely_red_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_man_painting_a_completely_red_image.png
--------------------------------------------------------------------------------
/samples/A_psychedelic_experience_on_LSD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_psychedelic_experience_on_LSD.png
--------------------------------------------------------------------------------
/samples/A_time_traveler_in_the_crowd.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/A_time_traveler_in_the_crowd.jpg
--------------------------------------------------------------------------------
/samples/Autumn_1875_Frederic_Edwin_Church.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Autumn_1875_Frederic_Edwin_Church.jpg
--------------------------------------------------------------------------------
/samples/Autumn_1875_Frederic_Edwin_Church_original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Autumn_1875_Frederic_Edwin_Church_original.jpg
--------------------------------------------------------------------------------
/samples/Cosmic_love_and_attention.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Cosmic_love_and_attention.jpg
--------------------------------------------------------------------------------
/samples/Life_during_the_plague.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Life_during_the_plague.jpg
--------------------------------------------------------------------------------
/samples/Meditative_peace_in_a_sunlit_forest.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Meditative_peace_in_a_sunlit_forest.jpg
--------------------------------------------------------------------------------
/samples/Mist_over_green_hills.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Mist_over_green_hills.jpg
--------------------------------------------------------------------------------
/samples/Shattered_plates_on_the_grass.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/Shattered_plates_on_the_grass.jpg
--------------------------------------------------------------------------------
/samples/cosmic-love.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/cosmic-love.png
--------------------------------------------------------------------------------
/samples/hot-dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/hot-dog.jpg
--------------------------------------------------------------------------------
/samples/hot-dog_imagined.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/hot-dog_imagined.png
--------------------------------------------------------------------------------
/samples/life-plague.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/life-plague.png
--------------------------------------------------------------------------------
/samples/mist-over-green-hills.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/mist-over-green-hills.png
--------------------------------------------------------------------------------
/samples/peace-sunlit-forest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/peace-sunlit-forest.png
--------------------------------------------------------------------------------
/samples/prime-orig.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/prime-orig.jpg
--------------------------------------------------------------------------------
/samples/prime-trained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/prime-trained.png
--------------------------------------------------------------------------------
/samples/psychedelic_hot_dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/psychedelic_hot_dog.png
--------------------------------------------------------------------------------
/samples/shattered-plates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/shattered-plates.png
--------------------------------------------------------------------------------
/samples/time-traveler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/deep-daze/c3c471e63c30ccabfd8dfc09ced3028a8979ebe4/samples/time-traveler.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from setuptools import setup, find_packages
3 |
4 | sys.path[0:0] = ['deep_daze']
5 | from version import __version__
6 |
7 | setup(
8 | name = 'deep-daze',
9 | packages = find_packages(),
10 | include_package_data = True,
11 | entry_points={
12 | 'console_scripts': [
13 | 'imagine = deep_daze.cli:main',
14 | ],
15 | },
16 | version = __version__,
17 | license='MIT',
18 | description = 'Deep Daze',
19 | author = 'Ryan Murdock, Phil Wang',
20 | author_email = 'lucidrains@gmail.com',
21 | url = 'https://github.com/lucidrains/deep-daze',
22 | keywords = [
23 | 'artificial intelligence',
24 | 'deep learning',
25 | 'transformers',
26 | 'implicit neural representations',
27 | 'text to image'
28 | ],
29 | install_requires=[
30 | 'einops>=0.3',
31 | 'fire',
32 | 'ftfy',
33 | 'imageio>=2.9.0',
34 | 'siren-pytorch>=0.0.8',
35 | 'torch>=1.10',
36 | 'torch_optimizer',
37 | 'torchvision>=0.8.2',
38 | 'tqdm',
39 | 'regex'
40 | ],
41 | classifiers=[
42 | 'Development Status :: 4 - Beta',
43 | 'Intended Audience :: Developers',
44 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
45 | 'License :: OSI Approved :: MIT License',
46 | 'Programming Language :: Python :: 3.6',
47 | ],
48 | )
49 |
--------------------------------------------------------------------------------
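Note on the entry point above: installing the package (e.g. pip install deep-daze, or pip install -e . from a checkout) also installs an "imagine" console command mapped to deep_daze.cli:main, so a prompt can be rendered from the shell with something like: imagine "mist over green hills". The exact flags it accepts are whatever deep_daze/cli.py exposes.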