├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── arnheim_1.ipynb ├── arnheim_2.ipynb ├── arnheim_3.ipynb ├── arnheim_3 ├── README.md ├── configs │ ├── .DS_Store │ ├── config.yaml │ ├── config_compositional.yaml │ ├── config_compositional_tiled.yaml │ ├── config_debug.yaml │ ├── config_multipatch.yaml │ ├── config_simple.yaml │ └── config_simple_search.yaml ├── main.py ├── requirements.txt └── src │ ├── collage.py │ ├── collage_generator.py │ ├── patches.py │ ├── rendering.py │ ├── training.py │ ├── transformations.py │ └── video_utils.py ├── arnheim_3_patch_maker.ipynb ├── collage_patches ├── animals.npy ├── fruit.npy ├── handwritten_mnist.npy └── shore_glass.npy └── images ├── arnheim3_examples.png ├── bulls_ballet_faces_nature.jpg ├── chicken.png ├── dancer.png ├── face.png ├── fall_of_the_damned.jpg ├── fruit_bowl_animals.png ├── fruit_bowl_fruit.png ├── objects.png ├── swans_masked_transparency.png └── waves.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # Temporary files generated by command-line version of Arnheim 3 27 | arnheim_3/*.npy 28 | arnheim_3/output_* 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Pull Requests 4 | 5 | Please send in fixes or feature additions through Pull Requests. 6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution, 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. Head over to to see 13 | your current agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Art Using Neural Visual Grammars and Dual Encoders 2 | 3 | ## Arnheim 1 4 | 5 | The original algorithm from the paper 6 | [Generative Art Using Neural Visual Grammars and Dual Encoders](https://arxiv.org/abs/2105.00162) 7 | running on 1 GPU allows optimization of any image using a genetic algorithm. 8 | This is much more general but much slower than using Arnheim 2 which uses 9 | gradients. 10 | 11 | ## Arnheim 2 12 | 13 | A reimplementation of the Arnheim 1 generative architecture in the CLIPDraw 14 | framework allowing optimization of its parameters using gradients. Much more 15 | efficient than Arnheim 1 above but requires differentiating through the image 16 | itself. 17 | 18 | ## Arnheim 3 (aka CLIP-CLOP: CLIP-Guided Collage and Photomontage) 19 | 20 | A spatial transformer-based Arnheim implementation for generating collage images. 21 | It employs a combination of evolution and training to create collages from 22 | opaque to transparent image patches. 23 | 24 | Example patch datasets, with the exception of 'Fruit and veg', are provided under 25 | [CC BY 4.0 licence](https://creativecommons.org/licenses/by/4.0/). 26 | The 'Fruit and veg' patches in `collage_patches/fruit.npy` are based on a subset 27 | of the Kaggle Fruits 360 and are provided under 28 | [CC BY-SA 4.0 licence](https://creativecommons.org/licenses/by-sa/4.0/), 29 | as are all example collages using them. 30 | 31 | ![The Fall of the Damned by Rubens and Eaton.](https://raw.githubusercontent.com/deepmind/arnheim/main/images/fall_of_the_damned.jpg) 32 | ![Collages made of different numbers of tree leaves patches (bulls in the top row), as well as Degas-inspired ballet dancers made from animals, faces made of fruit and still life or landscape made from patches of animals.](https://raw.githubusercontent.com/deepmind/arnheim/main/images/bulls_ballet_faces_nature.jpg) 33 | 34 | ## Usage 35 | 36 | Usage instructions are included in the Colabs which open and run on the 37 | free-to-use Google Colab platform - just click the buttons below! Improved 38 | performance and longer timeouts are available with Colab Pro. 39 | 40 | Arnheim 1 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/arnheim/blob/main/arnheim_1.ipynb) 41 | 42 | Arnheim 2 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/arnheim/blob/main/arnheim_2.ipynb) 43 | 44 | Arnheim 3 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/arnheim/blob/main/arnheim_3.ipynb) 45 | 46 | Arnheim 3 Patch Maker [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/arnheim/blob/main/arnheim_3_patch_maker.ipynb) 47 | 48 | ## Video illustration of the CLIP-CLOP Collage and Photomontage Generator (Arnheim 3) 49 | 50 | [![CLIP-CLOP Collage and Photomontage Generator](https://img.youtube.com/vi/VnO4tibP9cg/0.jpg)](https://youtu.be/VnO4tibP9cg) 51 | 52 | 53 | ## Citing this work 54 | 55 | If you use this code (or any derived code), data or these models in your work, 56 | please cite the relevant accompanying papers on [Generative Art Using Neural Visual Grammars and Dual Encoders](https://arxiv.org/abs/2105.00162) 57 | or on [CLIP-CLOP: CLIP-Guided Collage and Photomontage](https://arxiv.org/abs/2205.03146). 58 | 59 | ``` 60 | @misc{fernando2021genart, 61 | title={Generative Art Using Neural Visual Grammars and Dual Encoders}, 62 | author={Chrisantha Fernando and S. M. Ali Eslami and Jean-Baptiste Alayrac and Piotr Mirowski and Dylan Banarse and Simon Osindero} 63 | year={2021}, 64 | eprint={2105.00162}, 65 | archivePrefix={arXiv}, 66 | primaryClass={cs.CV} 67 | } 68 | ``` 69 | ``` 70 | @inproceedings{mirowski2022clip, 71 | title={CLIP-CLOP: CLIP-Guided Collage and Photomontage}, 72 | author={Piotr Mirowski and Dylan Banarse and Mateusz Malinowski and Simon Osindero and Chrisantha Fernando}, 73 | booktitle={Proceedings of the Thirteenth International Conference on Computational Creativity}, 74 | year={2022} 75 | } 76 | ``` 77 | 78 | ## Disclaimer 79 | 80 | This is not an official Google product. 81 | 82 | CLIPDraw provided under license, Copyright 2021 Kevin Frans. 83 | 84 | Other works may be copyright of the authors of such work. 85 | -------------------------------------------------------------------------------- /arnheim_3/README.md: -------------------------------------------------------------------------------- 1 | # Generative Art Using Neural Visual Grammars and Dual Encoders 2 | 3 | ## Arnheim 3 (Command line version) 4 | 5 | A spatial transformer-based Arnheim implementation for generating collage images. 6 | It employs a combination of evolution and training to create collages from 7 | opaque to transparent image patches. Example patch datasets are provided under 8 | [CC BY-SA 4.0 licence](https://creativecommons.org/licenses/by-sa/4.0/). 9 | The 'Fruit and veg' patches are based on a subset of the 10 | [Kaggle Fruits 360](https://www.kaggle.com/moltean/fruits) under the same 11 | license. 12 | 13 | ## Installation 14 | 15 | Clone this GitHub repository and go the `arnheim_3` directory: 16 | ```sh 17 | git clone https://github.com/piotrmirowski/arnheim.git 18 | cd arnheim/arnheim_3 19 | ``` 20 | 21 | Install the required Python libraries: 22 | ```sh 23 | python3 -m pip install -r requirements.txt 24 | ``` 25 | 26 | Install [CLIP](https://github.com/openai/CLIP) from OpenAI's GitHub repository: 27 | ```sh 28 | python3 -m pip install git+https://github.com/openai/CLIP.git --no-deps 29 | ``` 30 | 31 | When using GCP, it might help to enable remote desktop in both your local Chrome browser and on the GCP virtual machine, which can be done following [these instructions](https://cloud.google.com/architecture/chrome-desktop-remote-on-compute-engine#cinnamon). 32 | 33 | ## Usage 34 | 35 | Configuration files are stored in YAML format in subdirectory `configs`. For instance script `configs/config_compositional_tiled.yaml` generates a composttional collage with global prompt `a photorealistic chicken` and 9 local prompts for `sky`, `sun`, `moon`, `tree`, `field` and `chicken`. 36 | 37 | Please refer to `configs/config.yaml` and to the help for explanation about the config. 38 | ```sh 39 | python3 main.py --help 40 | ``` 41 | 42 | To run with CUDA on a GPU accelerator: 43 | ```sh 44 | python3 main.py --config configs/config_compositional.yaml 45 | ``` 46 | 47 | To run without CUDA (e.g., on Mac OS - note this will be considerably slower): 48 | ```sh 49 | python3 main.py --no-cuda --config configs/config_compositional.yaml 50 | ``` 51 | 52 | By default, results are stored in a directory named `output_YYYYMMDD_hhmmss` (based on the timestamp) and contain the config `.yaml` file, and the resulting collage (and tiles) as `.png` and `.npy` files. 53 | 54 | ## Citing this work 55 | 56 | If you use this code (or any derived code), data or these models in your work, 57 | please cite the relevant accompanying [paper](https://arxiv.org/abs/2105.00162). 58 | 59 | ``` 60 | @misc{fernando2021genart, 61 | title={Generative Art Using Neural Visual Grammars and Dual Encoders}, 62 | author={Chrisantha Fernando and S. M. Ali Eslami and Jean-Baptiste Alayrac and Piotr Mirowski and Dylan Banarse and Simon Osindero} 63 | year={2021}, 64 | eprint={2105.00162}, 65 | archivePrefix={arXiv}, 66 | primaryClass={cs.CV} 67 | } 68 | ``` 69 | 70 | ## Disclaimer 71 | 72 | This is not an official Google product. 73 | 74 | CLIPDraw provided under license, Copyright 2021 Kevin Frans. 75 | 76 | Other works may be copyright of the authors of such work. 77 | -------------------------------------------------------------------------------- /arnheim_3/configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/arnheim_3/configs/.DS_Store -------------------------------------------------------------------------------- /arnheim_3/configs/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | ### Default collage configuration 3 | # Canvas 4 | canvas_width: 224 5 | canvas_height: 224 6 | 7 | # Render methods 8 | # opacity patches overlay each other using a combination of alpha and depth, 9 | # transparency _adds_ patch colours (black therefore appearing transparent), 10 | # and masked transparency blends patches using the alpha channel. 11 | render_method: "masked_transparency_clipped" 12 | num_patches: 100 13 | # Colour transformations can be: "none", "RGB space", "HSV space" 14 | colour_transformations: "RGB space" 15 | # Invert image colours to have a white background? 16 | invert_colours: False 17 | high_res_multiplier: 4 18 | 19 | ### Affine transform settings 20 | # Translation bounds for X and Y. 21 | min_trans: -1. 22 | max_trans: 1. 23 | # Scale bounds (> 1 means zoom out and < 1 means zoom in). 24 | min_scale: 1 25 | max_scale: 2 26 | # Bounds on ratio between X and Y scale (default 1). 27 | min_squeeze: 0.5 28 | max_squeeze: 2.0 29 | # Shear deformation bounds (default 0) 30 | min_shear: -0.2 31 | max_shear: 0.2 32 | # Rotation bounds. 33 | min_rot_deg: -180 34 | max_rot_deg: 180 35 | 36 | ### Colour transform settings 37 | # RGB 38 | min_rgb: -0.2 39 | max_rgb: 1.0 40 | initial_min_rgb: 0.5 41 | initial_max_rgb: 1. 42 | # HSV 43 | min_hue_deg: 0. 44 | max_hue_deg: 360 45 | min_sat: 0. 46 | max_sat: 1. 47 | min_val: 0. 48 | max_val: 1. 49 | 50 | ### Training settings 51 | clip_model: "ViT-B/32" 52 | # Number of training steps 53 | optim_steps: 10000 54 | learning_rate: 0.1 55 | trace_every: 50 56 | # Number of augmentations to use in evaluation 57 | use_image_augmentations: True 58 | num_augs: 4 59 | # Normalize colours for CLIP, generally leave this as True 60 | use_normalized_clip: False 61 | # Gradient clipping during optimisation 62 | gradient_clipping: 10.0 63 | # Initial random search size (1 means no search) 64 | initial_search_size: 1 65 | 66 | ### Accelerator settings 67 | torch_device: "cuda" 68 | 69 | ### Evolution settings 70 | # For evolution set POP_SIZE greater than 1 71 | pop_size: 2 72 | evolution_frequency: 100 73 | # Microbial - loser of randomly selected pair is replaced by mutated winner. A low selection pressure. 74 | # Evolutionary Strategies - mutantions of the best individual replace the rest of the population. Much higher selection pressure than Microbial GA. 75 | ga_method: "Microbial" 76 | # ### Mutation levels 77 | # Scale mutation applied to position and rotation, scale, distortion, colour and patch swaps. 78 | pos_and_rot_mutation_scale: 0.02 79 | scale_mutation_scale: 0.02 80 | distort_mutation_scale: 0.02 81 | colour_mutation_scale: 0.02 82 | patch_mutation_probability: 1 83 | # Limit the number of individuals shown during training 84 | max_multiple_visualizations: 5 85 | 86 | ### Load segmented patches 87 | patch_set: "animals.npy" 88 | patch_repo_root: https://storage.googleapis.com/dm_arnheim_3_assets/collage_patches 89 | url_to_patch_file: "" 90 | 91 | ### Resize image patches to low- and high-res. 92 | fixed_scale_patches: True 93 | fixed_scale_coeff: 0.7 94 | normalize_patch_brightness: False 95 | patch_max_proportion: 5 96 | patch_width_min: 16 97 | patch_height_min: 16 98 | 99 | # Configure a background, e.g. uploaded picture or solid colour. 100 | background_url: "" 101 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 102 | background_use: "Global" 103 | # Colour configuration for solid colour background 104 | background_red: 0 105 | background_green: 0 106 | background_blue: 0 107 | 108 | # @title Configure image prompt and content 109 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 110 | global_prompt: "a photorealistic chicken" 111 | 112 | # @title Tile prompts and tiling settings 113 | tile_images: False 114 | tiles_wide: 1 115 | tiles_high: 1 116 | 117 | # Prompt(s) for tiles 118 | # Global tile prompt uses GLOBAL_PROMPT (previous cell) for *all* tiles (e.g. "Roman mosaic of an unswept floor") 119 | global_tile_prompt: False 120 | 121 | # Otherwise, specify multiple tile prompts with columns separated by | and / to delineate new row. 122 | # E.g. multiple prompts for a 3x2 "landscape" image: "sun | clouds | sky / fields | fields | trees" 123 | tile_prompt_string: "" 124 | 125 | # Composition prompts 126 | # @title Composition prompts (within tiles) 127 | # Use additional prompts for different regions 128 | compositional_image: False 129 | 130 | # Single image (i.e. no tiling) composition prompts 131 | # Specify 3x3 prompts for each composition region (left to right, starting at the top) 132 | prompt_x0_y0: "a photorealistic sky with sun" 133 | prompt_x1_y0: "a photorealistic sky" 134 | prompt_x2_y0: "a photorealistic sky with moon" 135 | prompt_x0_y1: "a photorealistic tree" 136 | prompt_x1_y1: "a photorealistic tree" 137 | prompt_x2_y1: "a photorealistic tree" 138 | prompt_x0_y2: "a photorealistic field" 139 | prompt_x1_y2: "a photorealistic field" 140 | prompt_x2_y2: "a photorealistic chicken" 141 | 142 | # Tile composition prompts 143 | # This string is formated to autogenerate region prompts from tile prompt. e.g. "close-up of {}" 144 | tile_prompt_formating: "close-up of {}" 145 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_compositional.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Render methods 3 | # opacity patches overlay each other using a combination of alpha and depth, 4 | # transparency _adds_ patch colours (black therefore appearing transparent), 5 | # and masked transparency blends patches using the alpha channel. 6 | render_method: "masked_transparency_clipped" 7 | num_patches: 100 8 | # Colour transformations can be: "none", "RGB space", "HSV space" 9 | colour_transformations: "RGB space" 10 | 11 | # Number of training steps 12 | optim_steps: 10000 13 | learning_rate: 0.1 14 | trace_every: 50 15 | 16 | ### Load segmented patches 17 | patch_set: "animals.npy" 18 | 19 | ### Resize image patches to low- and high-res. 20 | fixed_scale_patches: True 21 | fixed_scale_coeff: 0.5 22 | 23 | # Configure a background, e.g. uploaded picture or solid colour. 24 | background_url: "" 25 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 26 | background_use: "Global" 27 | # Colour configuration for solid colour background 28 | background_red: 0 29 | background_green: 0 30 | background_blue: 0 31 | 32 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 33 | global_prompt: "a photorealistic chicken" 34 | 35 | compositional_image: True 36 | 37 | # Single image (i.e. no tiling) composition prompts 38 | # Specify 3x3 prompts for each composition region (left to right, starting at the top) 39 | prompt_x0_y0: "a photorealistic sky with sun" 40 | prompt_x1_y0: "a photorealistic sky" 41 | prompt_x2_y0: "a photorealistic sky with moon" 42 | prompt_x0_y1: "a photorealistic tree" 43 | prompt_x1_y1: "a photorealistic tree" 44 | prompt_x2_y1: "a photorealistic tree" 45 | prompt_x0_y2: "a photorealistic field" 46 | prompt_x1_y2: "a photorealistic field" 47 | prompt_x2_y2: "a photorealistic chicken" 48 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_compositional_tiled.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Render methods 3 | # opacity patches overlay each other using a combination of alpha and depth, 4 | # transparency _adds_ patch colours (black therefore appearing transparent), 5 | # and masked transparency blends patches using the alpha channel. 6 | render_method: "masked_transparency_clipped" 7 | num_patches: 100 8 | # Colour transformations can be: "none", "RGB space", "HSV space" 9 | colour_transformations: "RGB space" 10 | 11 | # Number of training steps 12 | optim_steps: 10000 13 | learning_rate: 0.1 14 | trace_every: 50 15 | 16 | ### Load segmented patches 17 | patch_set: "animals.npy" 18 | 19 | ### Resize image patches to low- and high-res. 20 | fixed_scale_patches: True 21 | fixed_scale_coeff: 0.5 22 | 23 | # Configure a background, e.g. uploaded picture or solid colour. 24 | background_url: "" 25 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 26 | background_use: "Global" 27 | # Colour configuration for solid colour background 28 | background_red: 0 29 | background_green: 0 30 | background_blue: 0 31 | 32 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 33 | global_prompt: "a photorealistic chicken" 34 | global_tile_prompt: True 35 | tile_images: True 36 | tiles_wide: 2 37 | tiles_high: 3 38 | compositional_image: True 39 | 40 | # Single image (i.e. no tiling) composition prompts 41 | # Specify 3x3 prompts for each composition region (left to right, starting at the top) 42 | prompt_x0_y0: "a photorealistic sky with sun" 43 | prompt_x1_y0: "a photorealistic sky" 44 | prompt_x2_y0: "a photorealistic sky with moon" 45 | prompt_x0_y1: "a photorealistic tree" 46 | prompt_x1_y1: "a photorealistic tree" 47 | prompt_x2_y1: "a photorealistic tree" 48 | prompt_x0_y2: "a photorealistic field" 49 | prompt_x1_y2: "a photorealistic field" 50 | prompt_x2_y2: "a photorealistic chicken" 51 | 52 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_debug.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Render methods 3 | # opacity patches overlay each other using a combination of alpha and depth, 4 | # transparency _adds_ patch colours (black therefore appearing transparent), 5 | # and masked transparency blends patches using the alpha channel. 6 | render_method: "masked_transparency_clipped" 7 | num_patches: 100 8 | # Colour transformations can be: "none", "RGB space", "HSV space" 9 | colour_transformations: "RGB space" 10 | 11 | # Number of training steps 12 | optim_steps: 10 13 | learning_rate: 0.1 14 | 15 | ### Load segmented patches 16 | patch_set: "animals.npy" 17 | 18 | ### Resize image patches to low- and high-res. 19 | fixed_scale_patches: True 20 | fixed_scale_coeff: 0.5 21 | 22 | # Configure a background, e.g. uploaded picture or solid colour. 23 | background_url: "" 24 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 25 | background_use: "Global" 26 | # Colour configuration for solid colour background 27 | background_red: 0 28 | background_green: 0 29 | background_blue: 0 30 | 31 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 32 | global_prompt: "a photorealistic chicken" 33 | global_tile_prompt: True 34 | tile_images: True 35 | tiles_wide: 2 36 | tiles_high: 3 37 | compositional_image: True 38 | 39 | # Single image (i.e. no tiling) composition prompts 40 | # Specify 3x3 prompts for each composition region (left to right, starting at the top) 41 | prompt_x0_y0: "a photorealistic sky with sun" 42 | prompt_x1_y0: "a photorealistic sky" 43 | prompt_x2_y0: "a photorealistic sky with moon" 44 | prompt_x0_y1: "a photorealistic tree" 45 | prompt_x1_y1: "a photorealistic tree" 46 | prompt_x2_y1: "a photorealistic tree" 47 | prompt_x0_y2: "a photorealistic field" 48 | prompt_x1_y2: "a photorealistic field" 49 | prompt_x2_y2: "a photorealistic chicken" 50 | 51 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_multipatch.yaml: -------------------------------------------------------------------------------- 1 | # Example of using different patch sets for different cells. 2 | --- 3 | # Render methods 4 | # opacity patches overlay each other using a combination of alpha and depth, 5 | # transparency _adds_ patch colours (black therefore appearing transparent), 6 | # and masked transparency blends patches using the alpha channel. 7 | render_method: "masked_transparency_clipped" 8 | num_patches: 100 9 | # Colour transformations can be: "none", "RGB space", "HSV space" 10 | colour_transformations: "RGB space" 11 | 12 | # Number of training steps 13 | optim_steps: 2000 14 | learning_rate: 0.3 15 | trace_every: 50 16 | 17 | ### Load segmented patches 18 | patch_set: "animals.npy" 19 | ### Resize image patches to low- and high-res. 20 | no-fixed_scale_patches: true 21 | patch_max_proportion: 8 22 | 23 | # The multiple_* lists enable the above patch parameters to be overwritten on a 24 | # per-tile basis. Entries in the lists are used in the order tiles are created, 25 | # i.e. left to right, top to bottom, and the list is repeated if necessary. 26 | multiple_patch_set: ["shore_glass.npy", "animals.npy"] 27 | multiple_fixed_scale_patches: [true, true, false] 28 | multiple_fixed_scale_coeff: [0.8, 0.3] 29 | # multiple_patch_max_proportion: [3, 5, 5] 30 | 31 | global_tile_prompt: True 32 | tile_images: True 33 | tiles_wide: 2 34 | tiles_high: 2 35 | 36 | # Configure a background, e.g. uploaded picture or solid colour. 37 | # background_url: "https://upload.wikimedia.org/wikipedia/commons/0/0e/Lithographic_Drawing-Book_%28BM_1887%2C0722.360.2%29.jpg" 38 | background_url: "biggest_chicken_ever.jpg" 39 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 40 | background_use: "Global" 41 | # Colour configuration for solid colour background 42 | background_red: 0 43 | background_green: 0 44 | background_blue: 0 45 | 46 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 47 | global_prompt: "torso" 48 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_simple.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Render methods 3 | # opacity patches overlay each other using a combination of alpha and depth, 4 | # transparency _adds_ patch colours (black therefore appearing transparent), 5 | # and masked transparency blends patches using the alpha channel. 6 | render_method: "masked_transparency_clipped" 7 | num_patches: 100 8 | # Colour transformations can be: "none", "RGB space", "HSV space" 9 | colour_transformations: "RGB space" 10 | 11 | # Number of training steps 12 | optim_steps: 10000 13 | learning_rate: 0.1 14 | trace_every: 50 15 | 16 | ### Load segmented patches 17 | patch_set: "animals.npy" 18 | 19 | ### Resize image patches to low- and high-res. 20 | fixed_scale_patches: True 21 | fixed_scale_coeff: 0.5 22 | 23 | # Configure a background, e.g. uploaded picture or solid colour. 24 | background_url: "" 25 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 26 | background_use: "Global" 27 | # Colour configuration for solid colour background 28 | background_red: 0 29 | background_green: 0 30 | background_blue: 0 31 | 32 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 33 | global_prompt: "a photorealistic chicken" 34 | -------------------------------------------------------------------------------- /arnheim_3/configs/config_simple_search.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Render methods 3 | # opacity patches overlay each other using a combination of alpha and depth, 4 | # transparency _adds_ patch colours (black therefore appearing transparent), 5 | # and masked transparency blends patches using the alpha channel. 6 | render_method: "masked_transparency_clipped" 7 | num_patches: 50 8 | # Colour transformations can be: "none", "RGB space", "HSV space" 9 | colour_transformations: "RGB space" 10 | pop_size: 2 11 | 12 | # Number of training steps 13 | optim_steps: 10000 14 | learning_rate: 0.07 15 | initial_search_size: 10 16 | initial_search_num_steps: 40 17 | trace_every: 50 18 | 19 | ### Load segmented patches 20 | patch_set: "animals.npy" 21 | 22 | ### Resize image patches to low- and high-res. 23 | fixed_scale_patches: True 24 | fixed_scale_coeff: 0.5 25 | 26 | # Configure a background, e.g. uploaded picture or solid colour. 27 | background_url: "" 28 | # Background usage: Global: use image across whole image; Local: reuse same image for every tile 29 | background_use: "Global" 30 | # Colour configuration for solid colour background 31 | background_red: 0 32 | background_green: 0 33 | background_blue: 0 34 | 35 | # Enter a global description of the image, e.g. 'a photorealistic chicken' 36 | global_prompt: "a photorealistic chicken" 37 | -------------------------------------------------------------------------------- /arnheim_3/main.py: -------------------------------------------------------------------------------- 1 | """Arnheim 3 - Collage Creator 2 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 3 | Simon Osindero, Chrisantha Fernando 4 | DeepMind, 2021-2022 5 | 6 | Copyright 2021 DeepMind Technologies Limited 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | https://www.apache.org/licenses/LICENSE-2.0 12 | Unless required by applicable law or agreed to in writing, 13 | software distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import configargparse 20 | from datetime import datetime 21 | import glob 22 | import os 23 | import pathlib 24 | import subprocess 25 | import sys 26 | import yaml 27 | 28 | import numpy as np 29 | import torch 30 | 31 | import clip 32 | 33 | import src.collage as collage 34 | import src.video_utils as video_utils 35 | 36 | 37 | # Specify (and override) the config. 38 | ap = configargparse.ArgumentParser(default_config_files=["configs/config.yaml"]) 39 | ap.add_argument("-c", "--config", required=True, is_config_file=True, 40 | help="Config file") 41 | 42 | # Use CUDA? 43 | ap.add_argument("--cuda", dest="cuda", action="store_true") 44 | ap.add_argument("--no-cuda", dest="cuda", action="store_false") 45 | ap.set_defaults(cuda=True) 46 | ap.add_argument("--torch_device", type=str, default="cuda", 47 | help="Alternative way of specifying the device: cuda or cpu?") 48 | 49 | # Output directory. 50 | ap.add_argument("--init_checkpoint", type=str, default="", 51 | help="Path to checkpoint") 52 | 53 | # Output directory. 54 | ap.add_argument("--output_dir", type=str, default="", 55 | help="Output directory") 56 | 57 | # Clean-up? 58 | ap.add_argument("--clean_up", dest='clean_up', help="Remove all working files", 59 | action='store_true') 60 | ap.add_argument("--no-clean_up", dest='clean_up', 61 | help="Remove all working files", action='store_false') 62 | ap.set_defaults(clean_up=False) 63 | 64 | # GUI? 65 | ap.add_argument('--gui', dest='gui', action='store_true') 66 | ap.add_argument('--no-gui', dest='gui', action='store_false') 67 | ap.set_defaults(gui=False) 68 | 69 | # Video and tracing. 70 | ap.add_argument("--video_steps", type=int, default=0, 71 | help="Number of steps between two video frames") 72 | ap.add_argument("--trace_every", type=int, default=50, 73 | help="Number of steps between two logging traces") 74 | ap.add_argument('--population_video', dest='population_video', 75 | action='store_true', help='Write the video of population?') 76 | ap.add_argument('--no-population_video', dest='population_video', 77 | action='store_false', help='Write the video of population?') 78 | ap.set_defaults(population_video=False) 79 | 80 | # Canvas size. 81 | ap.add_argument("--canvas_width", type=int, default=224, 82 | help="Image width for CLIP optimization") 83 | ap.add_argument("--canvas_height", type=int, default=224, 84 | help="Image height for CLIP optimization") 85 | ap.add_argument("--max_block_size_high_res", type=int, default=2000, 86 | help="Max block size for high-res image") 87 | 88 | # Render methods. 89 | ap.add_argument("--render_method", type=str, default="transparency", 90 | help="opacity patches overlay each other using combinations of " 91 | "alpha and depth, transparency _adds_ patch RGB values (black " 92 | "therefore appearing transparent), masked_transparency_clipped " 93 | "and masked_transparency_normed blend patches using the alpha " 94 | "channel") 95 | ap.add_argument("--num_patches", type=int, default=100, 96 | help="Number of patches") 97 | ap.add_argument("--colour_transformations", type=str, default="RGB space", 98 | help="Can be none, RGB space or HHSV space") 99 | ap.add_argument("--invert_colours", dest="invert_colours", action='store_true', 100 | help="Invert image colours to have a white background?") 101 | ap.add_argument("--no-invert_colours", dest="invert_colours", 102 | action='store_false', 103 | help="Invert image colours to have a white background?") 104 | ap.set_defaults(invert_colours=False) 105 | ap.add_argument("--high_res_multiplier", type=int, default=4, 106 | help="Ratio between large canvas and CLIP-optimized canvas") 107 | ap.add_argument('--save_all_arrays', dest='save_all_arrays', 108 | action='store_true', 109 | help='Save the optimised patch arrays as an npy file?') 110 | ap.add_argument('--no-save_all_arrays', dest='save_all_arrays', 111 | action='store_false', 112 | help='Save the optimised patch arrays as an npy file?') 113 | ap.set_defaults(save_all_arrays=False) 114 | 115 | # Affine transform settings. 116 | ap.add_argument("--min_trans", type=float, default=-1., 117 | help="Translation min for X and Y") 118 | ap.add_argument("--max_trans", type=float, default=1., 119 | help="Translation max for X and Y") 120 | ap.add_argument("--min_trans_init", type=float, default=-1., 121 | help="Initial translation min for X and Y") 122 | ap.add_argument("--max_trans_init", type=float, default=1., 123 | help="Initial translation max for X and Y") 124 | ap.add_argument("--min_scale", type=float, default=1., 125 | help="Scale min (> 1 means zoom out and < 1 means zoom in)") 126 | ap.add_argument("--max_scale", type=float, default=2., 127 | help="Scale max (> 1 means zoom out and < 1 means zoom in)") 128 | ap.add_argument("--min_squeeze", type=float, default=0.5, 129 | help="Min ratio between X and Y scale") 130 | ap.add_argument("--max_squeeze", type=float, default=2., 131 | help="Max ratio between X and Y scale") 132 | ap.add_argument("--min_shear", type=float, default=-0.2, 133 | help="Min shear deformation") 134 | ap.add_argument("--max_shear", type=float, default=0.2, 135 | help="Max shear deformation") 136 | ap.add_argument("--min_rot_deg", type=float, default=-180, help="Min rotation") 137 | ap.add_argument("--max_rot_deg", type=float, default=180, help="Max rotation") 138 | 139 | # Colour transform settings. 140 | ap.add_argument("--min_rgb", type=float, default=-0.2, 141 | help="Min RGB between -1 and 1") 142 | ap.add_argument("--max_rgb", type=float, default=1.0, 143 | help="Max RGB between -1 and 1") 144 | ap.add_argument("--initial_min_rgb", type=float, default=0.5, 145 | help="Initial min RGB between -1 and 1") 146 | ap.add_argument("--initial_max_rgb", type=float, default=1., 147 | help="Initial max RGB between -1 and 1") 148 | ap.add_argument("--min_hue_deg", type=float, default=0., 149 | help="Min hue between 0 and 360") 150 | ap.add_argument("--max_hue_deg", type=float, default=360, 151 | help="Max hue (in degrees) between 0 and 360") 152 | ap.add_argument("--min_sat", type=float, default=0, 153 | help="Min saturation between 0 and 1") 154 | ap.add_argument("--max_sat", type=float, default=1, 155 | help="Max saturation between 0 and 1") 156 | ap.add_argument("--min_val", type=float, default=0, 157 | help="Min value between 0 and 1") 158 | ap.add_argument("--max_val", type=float, default=1, 159 | help="Max value between 0 and 1") 160 | 161 | # Training settings. 162 | ap.add_argument("--clip_model", type=str, default="ViT-B/32", help="CLIP model") 163 | ap.add_argument("--optim_steps", type=int, default=10000, 164 | help="Number of training steps (between 0 and 20000)") 165 | ap.add_argument("--learning_rate", type=float, default=0.1, 166 | help="Learning rate, typically between 0.05 and 0.3") 167 | ap.add_argument("--use_image_augmentations", dest="use_image_augmentations", 168 | action='store_true', 169 | help="User image augmentations for CLIP evaluation?") 170 | ap.add_argument("--no-use_image_augmentations", dest="use_image_augmentations", 171 | action='store_false', 172 | help="User image augmentations for CLIP evaluation?") 173 | ap.set_defaults(use_image_augmentations=True) 174 | ap.add_argument("--num_augs", type=int, default=4, 175 | help="Number of image augmentations to use in CLIP evaluation") 176 | ap.add_argument("--use_normalized_clip", dest="use_normalized_clip", 177 | action='store_true', 178 | help="Normalize colours for CLIP, generally leave this as True") 179 | ap.add_argument("--no-use_normalized_clip", dest="use_normalized_clip", 180 | action='store_false', 181 | help="Normalize colours for CLIP, generally leave this as True") 182 | ap.set_defaults(use_normalized_clip=False) 183 | ap.add_argument("--gradient_clipping", type=float, default=10.0, 184 | help="Gradient clipping during optimisation") 185 | ap.add_argument("--initial_search_size", type=int, default=1, 186 | help="Initial random search size (1 means no search)") 187 | ap.add_argument("--initial_search_num_steps", type=int, default=1, 188 | help="Number of gradient steps in initial random search size " 189 | "(1 means only random search, more means gradient descent)") 190 | 191 | # Evolution settings. 192 | ap.add_argument("--pop_size", type=int, default=2, 193 | help="For evolution set this to greater than 1") 194 | ap.add_argument("--evolution_frequency", type=int, default= 100, 195 | help="Number of gradient steps between two evolution mutations") 196 | ap.add_argument("--ga_method", type=str, default="Microbial", 197 | help="Microbial: loser of randomly selected pair is replaced " 198 | "by mutated winner. A low selection pressure. Evolutionary " 199 | "Strategies: mutantions of the best individual replace the " 200 | "rest of the population. Much higher selection pressure than " 201 | "Microbial GA") 202 | 203 | # Mutation levels. 204 | ap.add_argument("--pos_and_rot_mutation_scale", type=float, default=0.02, 205 | help="Probability of position and rotation mutations") 206 | ap.add_argument("--scale_mutation_scale", type=float, default=0.02, 207 | help="Probability of scale mutations") 208 | ap.add_argument("--distort_mutation_scale", type=float, default=0.02, 209 | help="Probability of distortion mutations") 210 | ap.add_argument("--colour_mutation_scale", type=float, default=0.02, 211 | help="Probability of colour mutations") 212 | ap.add_argument("--patch_mutation_probability", type=float, default=1, 213 | help="Probability of patch mutations") 214 | 215 | # Visualisation. 216 | ap.add_argument("--max_multiple_visualizations", type=int, default=5, 217 | help="Limit the number of individuals shown during training") 218 | 219 | # Load segmented patches. 220 | ap.add_argument("--multiple_patch_set", default=None, 221 | action='append', dest="multiple_patch_set") 222 | ap.add_argument("--multiple_fixed_scale_patches", default=None, 223 | action='append', dest="multiple_fixed_scale_patches") 224 | ap.add_argument("--multiple_patch_max_proportion", default=None, 225 | action='append', dest="multiple_patch_max_proportion") 226 | ap.add_argument("--multiple_fixed_scale_coeff", default=None, 227 | action='append', dest="multiple_fixed_scale_coeff") 228 | ap.add_argument("--patch_set", type=str, default="animals.npy", 229 | help="Name of Numpy file with patches") 230 | ap.add_argument("--patch_repo_root", type=str, 231 | default= 232 | "https://storage.googleapis.com/dm_arnheim_3_assets/collage_patches", 233 | help="URL to patches") 234 | ap.add_argument("--url_to_patch_file", type=str, default="", 235 | help="URL to a patch file") 236 | 237 | # Resize image patches to low- and high-res. 238 | ap.add_argument("--fixed_scale_patches", dest="fixed_scale_patches", 239 | action='store_true', help="Use fixed scale patches?") 240 | ap.add_argument("--no-fixed_scale_patches", dest="fixed_scale_patches", 241 | action='store_false', help="Use fixed scale patches?") 242 | ap.set_defaults(fixed_scale_patches=True) 243 | ap.add_argument("--fixed_scale_coeff", type=float, default=0.7, 244 | help="Scale coeff for fixed scale patches") 245 | ap.add_argument("--normalize_patch_brightness", 246 | dest="normalize_patch_brightness", action='store_true', 247 | help="Normalize the brightness of patches?") 248 | ap.add_argument("--no-normalize_patch_brightness", 249 | dest="normalize_patch_brightness", action='store_false', 250 | help="Normalize the brightness of patches?") 251 | ap.set_defaults(normalize_patch_brightness=False) 252 | ap.add_argument("--patch_max_proportion", type=int, default= 5, 253 | help="Max proportion of patches, between 2 and 8") 254 | ap.add_argument("--patch_width_min", type=int, default=16, 255 | help="Min width of patches") 256 | ap.add_argument("--patch_height_min", type=int, default=16, 257 | help="Min height of patches") 258 | 259 | # Configure a background, e.g. uploaded picture or solid colour. 260 | ap.add_argument("--background_use", type=str, default="Global", 261 | help="Global: use image across whole image, " 262 | "or Local: reuse same image for every tile") 263 | ap.add_argument("--background_url", type=str, default="", 264 | help="URL for background image") 265 | ap.add_argument("--background_red", type=int, default=0, 266 | help="Red solid colour background (0 to 255)") 267 | ap.add_argument("--background_green", type=int, default=0, 268 | help="Green solid colour background (0 to 255)") 269 | ap.add_argument("--background_blue", type=int, default=0, 270 | help="Blue solid colour background (0 to 255)") 271 | 272 | # Configure image prompt and content. 273 | ap.add_argument("--global_prompt", type=str, 274 | default="Roman mosaic of an unswept floor", 275 | help="Global description of the image") 276 | 277 | # Tile prompts and tiling settings. 278 | ap.add_argument("--tile_images", action='store_true', dest="tile_images", 279 | help="Tile images?") 280 | ap.add_argument("--no-tile_images", action='store_false', dest="tile_images", 281 | help="Tile images?") 282 | ap.set_defaults(tile_images=False) 283 | ap.add_argument("--tiles_wide", type=int, default=1, 284 | help="Number of width tiles") 285 | ap.add_argument("--tiles_high", type=int, default=1, 286 | help="Number of height tiles") 287 | ap.add_argument("--global_tile_prompt", dest="global_tile_prompt", 288 | action='store_true', 289 | help="Global tile prompt uses global_prompt (previous cell) " 290 | "for *all* tiles (e.g. Roman mosaic of an unswept floor)") 291 | ap.add_argument("--no-global_tile_prompt", dest="global_tile_prompt", 292 | action='store_false', 293 | help="Global tile prompt uses global_prompt (previous cell) " 294 | "for *all* tiles (e.g. Roman mosaic of an unswept floor)") 295 | ap.set_defaults(global_tile_prompt=False) 296 | ap.add_argument("--tile_prompt_string", type=str, default="", 297 | help="Otherwise, specify multiple tile prompts with columns " 298 | "separated by | and / to delineate new row. E.g. multiple " 299 | "prompts for a 3x2 'landscape' image: " 300 | "'sun | clouds | sky / fields | fields | trees'") 301 | 302 | # Composition prompts. 303 | ap.add_argument("--compositional_image", dest="compositional_image", 304 | action="store_true", 305 | help="Use additional prompts for different regions") 306 | ap.add_argument("--no-compositional_image", dest="compositional_image", 307 | action="store_false", 308 | help="Do not use additional prompts for different regions") 309 | ap.set_defaults(compositional_image=False) 310 | # Single image (i.e. no tiling) composition prompts: 311 | # specify 3x3 prompts for each composition region. 312 | ap.add_argument("--prompt_x0_y0", type=str, 313 | default="a photorealistic sky with sun", help="Top left prompt") 314 | ap.add_argument("--prompt_x1_y0", type=str, 315 | default="a photorealistic sky", help="Top centre prompt") 316 | ap.add_argument("--prompt_x2_y0", type=str, 317 | default="a photorealistic sky with moon", help="Top right prompt") 318 | ap.add_argument("--prompt_x0_y1", type=str, 319 | default="a photorealistic tree", help="Middle left prompt") 320 | ap.add_argument("--prompt_x1_y1", type=str, 321 | default="a photorealistic tree", help="Centre prompt") 322 | ap.add_argument("--prompt_x2_y1", type=str, 323 | default="a photorealistic tree", help="Middle right prompt") 324 | ap.add_argument("--prompt_x0_y2", type=str, 325 | default="a photorealistic field", help="Bottom left prompt") 326 | ap.add_argument("--prompt_x1_y2", type=str, 327 | default="a photorealistic field", help="Bottom centre prompt") 328 | ap.add_argument("--prompt_x2_y2", type=str, 329 | default="a photorealistic chicken", help="Bottom right prompt") 330 | 331 | # Tile composition prompts. 332 | ap.add_argument("--tile_prompt_formating", type=str, default="close-up of {}", 333 | help="This string is formated to autogenerate region prompts " 334 | "from tile prompt. e.g. close-up of {}") 335 | 336 | # Get the config. 337 | config = vars(ap.parse_args()) 338 | 339 | print(config) 340 | 341 | # Adjust config for compositional image. 342 | if config["compositional_image"] == True: 343 | print("Generating compositional image") 344 | config['canvas_width'] *= 2 345 | config['canvas_height'] *= 2 346 | config['high_res_multiplier'] = int(config['high_res_multiplier'] / 2) 347 | print("Using one image augmentations for compositional image creation.") 348 | config["use_image_augmentations"] = True 349 | config["num_augs"] = 1 350 | 351 | # Turn off tiling if either boolean is set or width/height set to 1. 352 | if (not config["tile_images"] or 353 | (config["tiles_wide"] == 1 and config["tiles_high"] == 1)): 354 | print("No tiling.") 355 | config["tiles_wide"] = 1 356 | config["tiles_high"] = 1 357 | config["tile_images"] = False 358 | 359 | # Default output dir. 360 | if len(config["output_dir"]) == 0: 361 | config["output_dir"] = "output_" 362 | config["output_dir"] += datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S') 363 | config["output_dir"] += '/' 364 | 365 | # Print the config. 366 | print("\n") 367 | yaml.dump(config, sys.stdout, default_flow_style=False, allow_unicode=True) 368 | print("\n\n") 369 | 370 | 371 | # Configure CUDA. 372 | print("Torch version:", torch.__version__) 373 | if not config["cuda"] or config["torch_device"] == "cpu": 374 | config["torch_device"] = "cpu" 375 | config["cuda"] = False 376 | device = torch.device(config["torch_device"]) 377 | 378 | # Configure ffmpeg. 379 | os.environ["FFMPEG_BINARY"] = "ffmpeg" 380 | 381 | 382 | # Initialise and load CLIP model. 383 | print(f"Downloading CLIP model {config['clip_model']}...") 384 | clip_model, _ = clip.load(config["clip_model"], device, jit=False) 385 | 386 | # Make output dir. 387 | output_dir = config["output_dir"] 388 | print(f"Storing results in {output_dir}\n") 389 | pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) 390 | 391 | # Save the config. 392 | config_filename = config["output_dir"] + '/' + "config.yaml" 393 | with open(config_filename, "w") as f: 394 | yaml.dump(config, f, default_flow_style=False, allow_unicode=True) 395 | 396 | # Tiling. 397 | if not config["tile_images"] or config["global_tile_prompt"]: 398 | tile_prompts = ( 399 | [config["global_prompt"]] * config["tiles_high"] * config["tiles_wide"]) 400 | else: 401 | tile_prompts = [] 402 | count_y = 0 403 | count_x = 0 404 | for row in config["tile_prompt_string"].split("/"): 405 | for prompt in row.split("|"): 406 | prompt = prompt.strip() 407 | tile_prompts.append(prompt) 408 | count_x += 1 409 | if count_x != config["tiles_wide"]: 410 | w = config["tiles_wide"] 411 | raise ValueError( 412 | f"Insufficient prompts for row {count_y}; expected {w}, got {count_x}") 413 | count_x = 0 414 | count_y += 1 415 | if count_y != config["tiles_high"]: 416 | h = config["tiles_high"] 417 | raise ValueError(f"Insufficient prompt rows; expected {h}, got {count_y}") 418 | 419 | print("Tile prompts: ", tile_prompts) 420 | # Prepare duplicates of config data if required for tiles. 421 | tile_count = 0 422 | all_prompts = [] 423 | for y in range(config["tiles_high"]): 424 | for x in range(config["tiles_wide"]): 425 | list_tile_prompts = [] 426 | if config["compositional_image"]: 427 | if config["tile_images"]: 428 | list_tile_prompts = [ 429 | config["tile_prompt_formating"].format(tile_prompts[tile_count]) 430 | ] * 9 431 | else: 432 | list_tile_prompts = [ 433 | config["prompt_x0_y0"], config["prompt_x1_y0"], 434 | config["prompt_x2_y0"], 435 | config["prompt_x0_y1"], config["prompt_x1_y1"], 436 | config["prompt_x2_y1"], 437 | config["prompt_x0_y2"], config["prompt_x1_y2"], 438 | config["prompt_x2_y2"]] 439 | list_tile_prompts.append(tile_prompts[tile_count]) 440 | tile_count += 1 441 | all_prompts.append(list_tile_prompts) 442 | print(f"All prompts: {all_prompts}") 443 | 444 | 445 | # Background. 446 | background_image = None 447 | background_url = config["background_url"] 448 | if len(background_url) > 0: 449 | # Load background image from URL. 450 | if background_url.startswith("http"): 451 | background_image = video_utils.cached_url_download(background_url, 452 | format="image_as_np") 453 | else: 454 | background_image = video_utils.load_image(background_url, 455 | show=config["gui"]) 456 | else: 457 | background_image = np.ones((10, 10, 3), dtype=np.float32) 458 | background_image[:, :, 0] = config["background_red"] / 255. 459 | background_image[:, :, 1] = config["background_green"] / 255. 460 | background_image[:, :, 2] = config["background_blue"] / 255. 461 | print('Defined background colour ({}, {}, {})'.format( 462 | config["background_red"], config["background_green"], 463 | config["background_blue"])) 464 | 465 | 466 | # Initialse the collage. 467 | ct = collage.CollageTiler( 468 | prompts=all_prompts, 469 | fixed_background_image=background_image, 470 | clip_model=clip_model, 471 | device=device, 472 | config=config) 473 | ct.initialise() 474 | 475 | # Collage optimisation loop. 476 | output = ct.loop() 477 | 478 | # Render high res image and finish up. 479 | ct.assemble_tiles() 480 | 481 | # Clean-up temporary files. 482 | if config["clean_up"]: 483 | for file_match in ["*.npy", "tile_*.png"]: 484 | output_dir = config["output_dir"] 485 | files = glob.glob(f"{output_dir}/{file_match}") 486 | for f in files: 487 | os.remove(f) 488 | -------------------------------------------------------------------------------- /arnheim_3/requirements.txt: -------------------------------------------------------------------------------- 1 | configargparse 2 | kornia 3 | ftfy 4 | regex 5 | opencv-python 6 | visdom 7 | torch 8 | torch-tools 9 | -------------------------------------------------------------------------------- /arnheim_3/src/collage.py: -------------------------------------------------------------------------------- 1 | """Collage-making class definitions. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | Copyright 2021 DeepMind Technologies Limited 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"); 10 | you may not use this file except in compliance with the License. 11 | You may obtain a copy of the License at 12 | https://www.apache.org/licenses/LICENSE-2.0 13 | Unless required by applicable law or agreed to in writing, 14 | software distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | 20 | import math 21 | import pathlib 22 | from . import training 23 | from . import video_utils 24 | from .collage_generator import PopulationCollage 25 | import cv2 26 | import numpy as np 27 | from .patches import get_segmented_data 28 | import torch 29 | import yaml 30 | 31 | 32 | class CollageMaker(): 33 | """Makes a single collage image. 34 | 35 | A collage image (aka tile) may involve 3x3 parallel evaluations. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | prompts, 41 | segmented_data, 42 | background_image, 43 | clip_model, 44 | file_basename, 45 | device, 46 | config): 47 | """Create a single square collage image. 48 | 49 | Args: 50 | prompts: list of prompts. Optional compositional prompts plus a global one 51 | segmented_data: patches for the collage 52 | background_image: background image for the collage 53 | clip_model: CLIP model 54 | file_basename: string, name to use for the saved files 55 | device: CUDA device 56 | config: dictionary with the following fields. 57 | 58 | Config fields: 59 | compositional_image: bool, whether to use 3x3 CLIPs 60 | output_dir: string, directory to save working and final images 61 | video_steps: int, how many steps between video frames; 0 is never 62 | population_video: bool, create a video with members of the population 63 | use_normalized_clip: bool, colour-correct images for CLIP evaluation 64 | use_image_augmentations: bool, use image augmentations in evaluation 65 | optim_steps: int, training steps for the collage 66 | pop_size: int, size of population being evolved 67 | evolution_frequency: bool, how many steps between evolution evaluations 68 | initial_search_size: int, initial random search size (1 means no search) 69 | """ 70 | self._prompts = prompts 71 | self._segmented_data = segmented_data 72 | self._background_image = background_image 73 | self._clip_model = clip_model 74 | self._file_basename = file_basename 75 | self._device = device 76 | self._config = config 77 | self._compositional_image = self._config["compositional_image"] 78 | self._output_dir = self._config["output_dir"] 79 | self._use_normalized_clip = self._config["use_normalized_clip"] 80 | self._use_image_augmentations = self._config["use_image_augmentations"] 81 | self._optim_steps = self._config["optim_steps"] 82 | self._pop_size = self._config["pop_size"] 83 | self._population_video = self._config["population_video"] 84 | self._use_evolution = self._config["pop_size"] > 1 85 | self._evolution_frequency = self._config["evolution_frequency"] 86 | self._initial_search_size = self._config["initial_search_size"] 87 | 88 | self._video_steps = self._config["video_steps"] 89 | self._video_writer = None 90 | self._population_video_writer = None 91 | if self._video_steps: 92 | self._video_writer = video_utils.VideoWriter( 93 | filename=f"{self._output_dir}/{self._file_basename}.mp4") 94 | if self._population_video: 95 | self._population_video_writer = video_utils.VideoWriter( 96 | filename=f"{self._output_dir}/{self._file_basename}_pop_sample.mp4") 97 | 98 | if self._compositional_image: 99 | if len(self._prompts) != 10: 100 | raise ValueError( 101 | "Missing compositional image prompts; found {len(self._prompts)}") 102 | print("Global prompt is", self._prompts[-1]) 103 | print("Composition prompts", self._prompts) 104 | else: 105 | if len(self._prompts) != 1: 106 | raise ValueError( 107 | "Missing compositional image prompts; found {len(self._prompts)}") 108 | print("CLIP prompt", self._prompts[0]) 109 | 110 | # Prompt to CLIP features. 111 | self._prompt_features = training.compute_text_features( 112 | self._prompts, self._clip_model, self._device) 113 | self._augmentations = training.augmentation_transforms( 114 | 224, 115 | use_normalized_clip=self._use_normalized_clip, 116 | use_augmentation=self._use_image_augmentations) 117 | 118 | # Create population of collage generators. 119 | self._generator = PopulationCollage( 120 | config=self._config, 121 | device=self._device, 122 | is_high_res=False, 123 | pop_size=self._pop_size, 124 | segmented_data=self._segmented_data, 125 | background_image=self._background_image) 126 | 127 | self._optimizer = training.make_optimizer(self._generator, 128 | self._config["learning_rate"]) 129 | self._step = 0 130 | self._losses_history = [] 131 | self._losses_separated_history = [] 132 | 133 | @property 134 | def generator(self): 135 | return self._generator 136 | 137 | @property 138 | def step(self): 139 | return self._step 140 | 141 | def initialise(self): 142 | """Initialise the collage from checkpoint or search over hyper-parameters.""" 143 | 144 | # If we use a checkpoint. 145 | if len(self._config["init_checkpoint"]) > 0: 146 | self.load(self._config["init_checkpoint"]) 147 | return 148 | 149 | # If we do an initial random search. 150 | if self._initial_search_size > 1: 151 | print("\nInitial random search over " 152 | f"{self._initial_search_size} individuals") 153 | for j in range(self._pop_size): 154 | generator_search = PopulationCollage( 155 | config=self._config, 156 | device=self._device, 157 | pop_size=self._initial_search_size, 158 | is_high_res=False, 159 | segmented_data=self._segmented_data, 160 | background_image=self._background_image) 161 | self._optimizer = training.make_optimizer(generator_search, 162 | self._config["learning_rate"]) 163 | 164 | num_steps_search = self._config["initial_search_num_steps"] 165 | if num_steps_search > 1: 166 | # Run several steps of gradient descent? 167 | for step_search in range(num_steps_search): 168 | losses, _, _ = self._train( 169 | step=step_search, last_step=False, 170 | generator=generator_search) 171 | else: 172 | # Or simply let initialise the parameters randomly. 173 | _, _, losses, _ = training.evaluation( 174 | t=0, 175 | clip_enc=self._clip_model, 176 | generator=generator_search, 177 | augment_trans=self._augmentations, 178 | text_features=self._prompt_features, 179 | prompts=self._prompts, 180 | config=self._config, 181 | device=self._device) 182 | print(f"Search {losses}") 183 | idx_best = np.argmin(losses) 184 | print(f"Choose {idx_best} with loss {losses[idx_best]}") 185 | self._generator.copy_from(generator_search, j, idx_best) 186 | del generator_search 187 | print("Initial random search done\n") 188 | 189 | self._optimizer = training.make_optimizer(self._generator, 190 | self._config["learning_rate"]) 191 | 192 | def load(self, path_checkpoint): 193 | """Load an existing generator from state_dict stored in `path`.""" 194 | print(f"\nLoading spatial and colour transforms from {path_checkpoint}...") 195 | state_dict = torch.load(path_checkpoint, map_location=self._device.type) 196 | this_state_dict = self._generator.state_dict() 197 | if state_dict.keys() != this_state_dict.keys(): 198 | print(f"Current and loaded state_dict do not match") 199 | for key in this_state_dict: 200 | this_shape = this_state_dict[key].shape 201 | shape = state_dict[key].shape 202 | if this_shape != shape: 203 | print(f"state_dict[{key}] do not match: {this_shape} vs. {shape}") 204 | print(f"Abort loading from checkpoint.") 205 | return 206 | print(f"Checkpoint {path_checkpoint} restored.") 207 | self._generator.load_state_dict(state_dict) 208 | 209 | def _train(self, step, last_step, generator): 210 | losses, losses_separated, img_batch = training.step_optimization( 211 | t=step, 212 | clip_enc=self._clip_model, 213 | lr_scheduler=self._optimizer, 214 | generator=generator, 215 | augment_trans=self._augmentations, 216 | text_features=self._prompt_features, 217 | prompts=self._prompts, 218 | config=self._config, 219 | device=self._device, 220 | final_step=last_step) 221 | return losses, losses_separated, img_batch 222 | 223 | def loop(self): 224 | """Main optimisation/image generation loop. Can be interrupted.""" 225 | if self._step == 0: 226 | print("\nStarting optimization of collage.") 227 | else: 228 | print(f"\nContinuing optimization of collage at step {self._step}.") 229 | if self._video_steps: 230 | print("Aborting video creation (does not work when interrupted).") 231 | self._video_steps = 0 232 | self._video_writer = None 233 | self._population_video_writer = None 234 | 235 | while self._step < self._optim_steps: 236 | last_step = self._step == (self._optim_steps - 1) 237 | losses, losses_separated, img_batch = self._train( 238 | step=self._step, last_step=last_step, generator=self._generator) 239 | self._add_video_frames(img_batch, losses) 240 | self._losses_history.append(losses) 241 | self._losses_separated_history.append(losses_separated) 242 | 243 | if (self._use_evolution and self._step 244 | and self._step % self._evolution_frequency == 0): 245 | training.population_evolution_step( 246 | self._generator, self._config, losses) 247 | self._step += 1 248 | 249 | def high_res_render(self, 250 | segmented_data_high_res, 251 | background_image_high_res, 252 | gamma=1.0, 253 | show=True, 254 | save=True, 255 | no_background=False): 256 | """Save and/or show a high res render using high-res patches.""" 257 | generator_cpu = PopulationCollage( 258 | config=self._config, 259 | device="cpu", 260 | is_high_res=True, 261 | pop_size=1, 262 | segmented_data=segmented_data_high_res, 263 | background_image=background_image_high_res) 264 | idx_best = np.argmin(self._losses_history[-1]) 265 | lowest_loss = self._losses_history[-1][idx_best] 266 | print(f"Lowest loss: {lowest_loss} @ index {idx_best}: ") 267 | generator_cpu.copy_from(self._generator, 0, idx_best) 268 | generator_cpu = generator_cpu.to("cpu") 269 | generator_cpu.tensors_to("cpu") 270 | 271 | params = {"gamma": gamma, 272 | "max_block_size_high_res": self._config.get( 273 | "max_block_size_high_res")} 274 | if no_background: 275 | params["no_background"] = True 276 | with torch.no_grad(): 277 | img_high_res = generator_cpu.forward_high_res(params) 278 | img = img_high_res.detach().cpu().numpy()[0] 279 | 280 | img = np.clip(img, 0.0, 1.0) 281 | if save or show: 282 | # Swap Red with Blue 283 | if img.shape[2] == 4: 284 | print("Image has alpha channel") 285 | img = img[..., [2, 1, 0, 3]] 286 | else: 287 | img = img[..., [2, 1, 0]] 288 | img = np.clip(img, 0.0, 1.0) * 255 289 | if save: 290 | if no_background: 291 | image_filename = f"{self._output_dir}/{self._file_basename}_no_bkgd.png" 292 | else: 293 | image_filename = f"{self._output_dir}/{self._file_basename}.png" 294 | cv2.imwrite(image_filename, img) 295 | if show: 296 | video_utils.cv2_imshow(img) 297 | 298 | img = img[:, :, :3] 299 | 300 | return img 301 | 302 | def finish(self): 303 | """Finish video writing and save all other data.""" 304 | if self._losses_history: 305 | losses_filename = f"{self._output_dir}/{self._file_basename}_losses" 306 | training.plot_and_save_losses(self._losses_history, 307 | title=f"{self._file_basename} Losses", 308 | filename=losses_filename, 309 | show=self._config["gui"]) 310 | if self._video_steps: 311 | self._video_writer.close() 312 | if self._population_video_writer: 313 | self._population_video_writer.close() 314 | metadata_filename = f"{self._output_dir}/{self._file_basename}.yaml" 315 | with open(metadata_filename, "w") as f: 316 | yaml.dump(self._config, f, default_flow_style=False, allow_unicode=True) 317 | last_step = self._step 318 | last_loss = float(np.amin(self._losses_history[-1])) 319 | return (last_step, last_loss) 320 | 321 | def _add_video_frames(self, img_batch, losses): 322 | """Add images from numpy image batch to video writers. 323 | 324 | Args: 325 | img_batch: numpy array, batch of images (S,H,W,C) 326 | losses: numpy array, losses for each generator (S,N) 327 | """ 328 | if self._video_steps and self._step % self._video_steps == 0: 329 | # Write image to video. 330 | best_img = img_batch[np.argmin(losses)] 331 | self._video_writer.add(cv2.resize( 332 | best_img, (best_img.shape[1] * 3, best_img.shape[0] * 3))) 333 | if self._population_video_writer: 334 | laid_out = video_utils.layout_img_batch(img_batch) 335 | self._population_video_writer.add(cv2.resize( 336 | laid_out, (laid_out.shape[1] * 2, laid_out.shape[0] * 2))) 337 | 338 | 339 | class CollageTiler(): 340 | """Creates a large collage by producing multiple overlapping collages.""" 341 | 342 | def __init__(self, 343 | prompts, 344 | fixed_background_image, 345 | clip_model, 346 | device, 347 | config): 348 | """Create CollageTiler. 349 | 350 | Args: 351 | prompts: list of prompts for the collage maker 352 | fixed_background_image: highest res background image 353 | clip_model: CLIP model 354 | device: CUDA device 355 | config: dictionary with the following fields below: 356 | 357 | Config fields used: 358 | width: number of tiles wide 359 | height: number of tiles high 360 | background_use: how to use the background, e.g. per tile or whole image 361 | compositional_image: bool, compositional for multi-CLIP collage tiles 362 | high_res_multiplier: int, how much bigger is the final high-res image 363 | output_dir: directory for generated files 364 | torch_device: string, either cpu or cuda 365 | """ 366 | self._prompts = prompts 367 | self._fixed_background_image = fixed_background_image 368 | self._clip_model = clip_model 369 | self._device = device 370 | self._config = config 371 | self._tiles_wide = config["tiles_wide"] 372 | self._tiles_high = config["tiles_high"] 373 | self._background_use = config["background_use"] 374 | self._compositional_image = config["compositional_image"] 375 | self._high_res_multiplier = config["high_res_multiplier"] 376 | self._output_dir = config["output_dir"] 377 | self._torch_device = config["torch_device"] 378 | 379 | pathlib.Path(self._output_dir).mkdir(parents=True, exist_ok=True) 380 | self._tile_basename = "tile_y{}_x{}{}" 381 | self._tile_width = 448 if self._compositional_image else 224 382 | self._tile_height = 448 if self._compositional_image else 224 383 | self._overlap = 1. / 3. 384 | 385 | # Size of bigger image 386 | self._width = int(((2 * self._tiles_wide + 1) * self._tile_width) / 3.) 387 | self._height = int(((2 * self._tiles_high + 1) * self._tile_height) / 3.) 388 | 389 | self._high_res_tile_width = self._tile_width * self._high_res_multiplier 390 | self._high_res_tile_height = self._tile_height * self._high_res_multiplier 391 | self._high_res_width = self._high_res_tile_width * self._tiles_wide 392 | self._high_res_height = self._high_res_tile_height * self._tiles_high 393 | 394 | self._print_info() 395 | self._x = 0 396 | self._y = 0 397 | self._collage_maker = None 398 | self._fixed_background = self._scale_fixed_background(high_res=True) 399 | 400 | def _print_info(self): 401 | """Print some debugging information.""" 402 | 403 | print(f"Tiling {self._tiles_wide}x{self._tiles_high} collages") 404 | print("Optimisation:") 405 | print(f"Tile size: {self._tile_width}x{self._tile_height}") 406 | print(f"Global size: {self._width}x{self._height} (WxH)") 407 | print("High res:") 408 | print( 409 | f"Tile size: {self._high_res_tile_width}x{self._high_res_tile_height}") 410 | print(f"Global size: {self._high_res_width}x{self._high_res_height} (WxH)") 411 | for i, tile_prompts in enumerate(self._prompts): 412 | print(f"Tile {i} prompts: {tile_prompts}") 413 | 414 | def initialise(self): 415 | """Initialise the collage maker, optionally from a checkpoint or initial search.""" 416 | 417 | if not self._collage_maker: 418 | # Create new collage maker with its unique background. 419 | print(f"\nNew collage creator for y{self._y}, x{self._x} with bg") 420 | tile_bg, self._tile_high_res_bg = self._get_tile_background() 421 | video_utils.show_and_save(tile_bg, self._config, 422 | img_format="SCHW", stitch=False, 423 | show=self._config["gui"]) 424 | prompts_x_y = self._prompts[self._y * self._tiles_wide + self._x] 425 | segmented_data, self._segmented_data_high_res = ( 426 | get_segmented_data( 427 | self._config, self._x + self._y * self._tiles_wide)) 428 | self._collage_maker = CollageMaker( 429 | prompts=prompts_x_y, 430 | segmented_data=segmented_data, 431 | background_image=tile_bg, 432 | clip_model=self._clip_model, 433 | file_basename=self._tile_basename.format(self._y, self._x, ""), 434 | device=self._device, 435 | config=self._config) 436 | self._collage_maker.initialise() 437 | 438 | def load(self, path): 439 | """Load an existing CollageMaker generator from state_dict stored in `path`.""" 440 | self._collage_maker.load(path) 441 | 442 | def loop(self): 443 | """Re-entrable loop to optmise collage.""" 444 | 445 | res_training = {} 446 | while self._y < self._tiles_high: 447 | while self._x < self._tiles_wide: 448 | if not self._collage_maker: 449 | self.initialise() 450 | self._collage_maker.loop() 451 | collage_img = self._collage_maker.high_res_render( 452 | self._segmented_data_high_res, 453 | self._tile_high_res_bg, 454 | gamma=1.0, 455 | show=self._config["gui"], 456 | save=True) 457 | self._collage_maker.high_res_render( 458 | self._segmented_data_high_res, 459 | self._tile_high_res_bg, 460 | gamma=1.0, 461 | show=False, 462 | save=True, 463 | no_background=True) 464 | self._save_tile(collage_img / 255) 465 | 466 | (last_step, last_loss) = self._collage_maker.finish() 467 | res_training[f"tile_{self._y}_{self._x}_loss"] = last_loss 468 | res_training[f"tile_{self._y}_{self._x}_step"] = last_step 469 | del self._collage_maker 470 | self._collage_maker = None 471 | self._x += 1 472 | self._y += 1 473 | self._x = 0 474 | 475 | # Save results of all optimisations. 476 | res_filename = f"{self._output_dir}/results_training.yaml" 477 | with open(res_filename, "w") as f: 478 | yaml.dump(res_training, f, default_flow_style=False, allow_unicode=True) 479 | 480 | return collage_img # SHWC 481 | 482 | def _save_tile(self, img): 483 | background_image_np = np.asarray(img) 484 | background_image_np = background_image_np[..., ::-1].copy() 485 | filename = self._tile_basename.format(self._y, self._x, ".npy") 486 | np.save(f"{self._output_dir}/{filename}", background_image_np) 487 | 488 | def _save_tile_arrays(self, all_arrays): 489 | filename = self._tile_basename.format(self._y, self._x, "_arrays.npy") 490 | np.save(f"{self._output_dir}/{filename}", all_arrays) 491 | 492 | def _scale_fixed_background(self, high_res=True): 493 | """Get correctly sized background image.""" 494 | 495 | if self._fixed_background_image is None: 496 | return None 497 | multiplier = self._high_res_multiplier if high_res else 1 498 | if self._background_use == "Local": 499 | height = self._tile_height * multiplier 500 | width = self._tile_width * multiplier 501 | elif self._background_use == "Global": 502 | height = self._height * multiplier 503 | width = self._width * multiplier 504 | return cv2.resize(self._fixed_background_image.astype(float), 505 | (width, height)) 506 | 507 | def _get_tile_background(self): 508 | """Get the background for a particular tile. 509 | 510 | This involves getting bordering imagery from left, top left, above and top 511 | right, where appropriate. 512 | i.e. tile (1,1) shares overlap with (0,1), (0,2) and (1,0) 513 | (0,0), (0,1), (0,2), (0,3) 514 | (1,0), (1,1), (1,2), (1,3) 515 | (2,0), (2,1), (2,2), (2,3) 516 | Note that (0,0) is not needed as its contribution is already in (0,1) 517 | 518 | Returns: 519 | background_image: small background for optimisation 520 | background_image_high_res: high resolution background 521 | """ 522 | if self._fixed_background is None: 523 | tile_border_bg = np.zeros((self._high_res_tile_height, 524 | self._high_res_tile_width, 3)) 525 | else: 526 | if self._background_use == "Local": 527 | tile_border_bg = self._fixed_background.copy() 528 | else: # Crop out section for this tile. 529 | orgin_y = self._y * (self._high_res_tile_height 530 | - math.ceil(self._tile_height * self._overlap) 531 | * self._high_res_multiplier) 532 | orgin_x = self._x * (self._high_res_tile_width 533 | - math.ceil(self._tile_width * self._overlap) 534 | * self._high_res_multiplier) 535 | tile_border_bg = self._fixed_background[ 536 | orgin_y : orgin_y + self._high_res_tile_height, 537 | orgin_x : orgin_x + self._high_res_tile_width, :] 538 | tile_idx = dict() 539 | if self._x > 0: 540 | tile_idx["left"] = (self._y, self._x - 1) 541 | if self._y > 0: 542 | tile_idx["above"] = (self._y - 1, self._x) 543 | if self._x < self._tiles_wide - 1: # Penultimate on the row 544 | tile_idx["above_right"] = (self._y - 1, self._x + 1) 545 | 546 | # Get and insert bodering tile content in this order. 547 | if "above" in tile_idx: 548 | self._copy_overlap(tile_border_bg, "above", tile_idx["above"]) 549 | if "above_right" in tile_idx: 550 | self._copy_overlap(tile_border_bg, "above_right", tile_idx["above_right"]) 551 | if "left" in tile_idx: 552 | self._copy_overlap(tile_border_bg, "left", tile_idx["left"]) 553 | 554 | background_image = self._resize_image_for_torch( 555 | tile_border_bg, self._tile_height, self._tile_width) 556 | background_image_high_res = self._resize_image_for_torch( 557 | tile_border_bg, 558 | self._high_res_tile_height, 559 | self._high_res_tile_width).to("cpu") 560 | 561 | return background_image, background_image_high_res 562 | 563 | def _resize_image_for_torch(self, img, height, width): 564 | # Resize and permute to format used by Collage class (SCHW). 565 | img = torch.tensor(cv2.resize(img.astype(float), (width, height))) 566 | if self._torch_device == "cuda": 567 | img = img.cuda() 568 | return img.permute(2, 0, 1).to(torch.float32) 569 | 570 | def _copy_overlap(self, target, location, tile_idx): 571 | """Copy area from tile adjacent to target tile to target tile.""" 572 | 573 | big_height = self._high_res_tile_height 574 | big_width = self._high_res_tile_width 575 | pixel_overlap = int(big_width * self._overlap) 576 | 577 | filename = self._tile_basename.format(tile_idx[0], tile_idx[1], ".npy") 578 | # print(f"Loading tile {filename}) 579 | source = np.load(f"{self._output_dir}/{filename}") 580 | if location == "above": 581 | target[0 : pixel_overlap, 0 : big_width, :] = source[ 582 | big_height - pixel_overlap : big_height, 0 : big_width, :] 583 | if location == "left": 584 | target[:, 0 : pixel_overlap, :] = source[ 585 | :, big_width - pixel_overlap : big_width, :] 586 | elif location == "above_right": 587 | target[ 588 | 0 : pixel_overlap, big_width - pixel_overlap : big_width, :] = source[ 589 | big_height - pixel_overlap : big_height, 0 : pixel_overlap, :] 590 | 591 | def assemble_tiles(self): 592 | """Stitch together the whole image from saved tiles.""" 593 | 594 | big_height = self._high_res_tile_height 595 | big_width = self._high_res_tile_width 596 | full_height = int((big_height + 2 * big_height * self._tiles_high) / 3) 597 | full_width = int((big_width + 2 * big_width * self._tiles_wide) / 3) 598 | full_image = np.zeros((full_height, full_width, 3)).astype("float32") 599 | 600 | for y in range(self._tiles_high): 601 | for x in range(self._tiles_wide): 602 | filename = self._tile_basename.format(y, x, ".npy") 603 | tile = np.load(f"{self._output_dir}/{filename}") 604 | y_offset = int(big_height * y * 2 / 3) 605 | x_offset = int(big_width * x * 2 / 3) 606 | full_image[y_offset : y_offset + big_height, 607 | x_offset : x_offset + big_width, :] = tile[:, :, :] 608 | filename = "final_tiled_image" 609 | print(f"Saving assembled tiles to {filename}") 610 | video_utils.show_and_save( 611 | full_image, self._config, img_format="SHWC", stitch=False, 612 | filename=filename, show=self._config["gui"]) 613 | -------------------------------------------------------------------------------- /arnheim_3/src/collage_generator.py: -------------------------------------------------------------------------------- 1 | """Collage network definition. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | import copy 22 | from . import rendering 23 | from . import transformations 24 | import numpy as np 25 | import torch 26 | 27 | 28 | class PopulationCollage(torch.nn.Module): 29 | """Population-based segmentation collage network. 30 | 31 | Image structure in this class is SCHW. 32 | """ 33 | 34 | def __init__(self, 35 | config, 36 | device, 37 | pop_size=1, 38 | is_high_res=False, 39 | segmented_data=None, 40 | background_image=None): 41 | """Constructor, relying on global parameters.""" 42 | super(PopulationCollage, self).__init__() 43 | 44 | # Config, device, number of patches and population size. 45 | self.config = config 46 | self.device = device 47 | self._canvas_height = config["canvas_height"] 48 | self._canvas_width = config["canvas_width"] 49 | self._high_res_multiplier = config["high_res_multiplier"] 50 | self._num_patches = self.config["num_patches"] 51 | self._pop_size = pop_size 52 | requires_grad = not is_high_res 53 | 54 | # Create the spatial transformer and colour transformer for patches. 55 | self.spatial_transformer = transformations.PopulationAffineTransforms( 56 | config, device, num_patches=self._num_patches, pop_size=pop_size, 57 | requires_grad=requires_grad, is_high_res=is_high_res) 58 | if self.config["colour_transformations"] == "HSV space": 59 | self.colour_transformer = transformations.PopulationColourHSVTransforms( 60 | config, device, num_patches=self._num_patches, pop_size=pop_size, 61 | requires_grad=requires_grad) 62 | elif self.config["colour_transformations"] == "RGB space": 63 | self.colour_transformer = transformations.PopulationColourRGBTransforms( 64 | config, device, num_patches=self._num_patches, pop_size=pop_size, 65 | requires_grad=requires_grad) 66 | else: 67 | self.colour_transformer = transformations.PopulationOrderOnlyTransforms( 68 | config, device, num_patches=self._num_patches, pop_size=pop_size, 69 | requires_grad=requires_grad) 70 | if config["torch_device"] == "cuda": 71 | self.spatial_transformer = self.spatial_transformer.cuda() 72 | self.colour_transformer = self.colour_transformer.cuda() 73 | self.coloured_patches = None 74 | 75 | # Optimisation is run in low-res, final rendering is in high-res. 76 | self._high_res = is_high_res 77 | 78 | # Store the background image (low- and high-res). 79 | self.background_image = background_image 80 | if self.background_image is not None: 81 | print(f"Background image of size {self.background_image.shape}") 82 | 83 | # Store the dataset (low- and high-res). 84 | self._dataset = segmented_data 85 | # print(f"There are {len(self._dataset)} image patches in the dataset") 86 | 87 | # Initial set of indices pointing to self._num_patches first dataset images. 88 | self.patch_indices = [np.arange(self._num_patches) % len(self._dataset) 89 | for _ in range(pop_size)] 90 | 91 | # Patches in low and high-res, will be initialised on demand. 92 | self.patches = None 93 | 94 | def store_patches(self, population_idx=None): 95 | """Store the image patches for each population element.""" 96 | if self._high_res: 97 | for _ in range(20): 98 | print("NOT STORING HIGH-RES PATCHES") 99 | return 100 | 101 | if population_idx is not None and self.patches is not None: 102 | list_indices_population = [population_idx] 103 | self.patches[population_idx, :, :4, :, :] = 0 104 | else: 105 | list_indices_population = np.arange(self._pop_size) 106 | self.patches = torch.zeros( 107 | self._pop_size, self._num_patches, 5, self._canvas_height, 108 | self._canvas_width).to(self.device) 109 | 110 | # Put the segmented data into the patches. 111 | for i in list_indices_population: 112 | for j in range(self._num_patches): 113 | patch_i_j = self._fetch_patch(i, j, self._high_res) 114 | self.patches[i, j, ...] = patch_i_j 115 | 116 | def _fetch_patch(self, idx_population, idx_patch, is_high_res): 117 | """Helper function to fetch a patch and store on the whole canvas.""" 118 | k = self.patch_indices[idx_population][idx_patch] 119 | patch_j = torch.tensor( 120 | self._dataset[k].swapaxes(0, 2) / 255.0).to(self.device) 121 | width_j = patch_j.shape[1] 122 | height_j = patch_j.shape[2] 123 | if is_high_res: 124 | w0 = int((self._canvas_width * self._high_res_multiplier - width_j) 125 | / 2.0) 126 | h0 = int((self._canvas_height * self._high_res_multiplier - height_j) 127 | / 2.0) 128 | mapped_patch = torch.zeros( 129 | 5, 130 | self._canvas_height * self._high_res_multiplier, 131 | self._canvas_width * self._high_res_multiplier 132 | ).to("cpu") 133 | else: 134 | w0 = int((self._canvas_width - width_j) / 2.0) 135 | h0 = int((self._canvas_height - height_j) / 2.0) 136 | mapped_patch = torch.zeros( 137 | 5, self._canvas_height, self._canvas_width).to(self.device) 138 | mapped_patch[4, :, :] = 1.0 139 | mapped_patch[:4, w0:(w0 + width_j), h0:(h0 + height_j)] = patch_j 140 | return mapped_patch 141 | 142 | def copy_and_mutate_s(self, parent, child): 143 | with torch.no_grad(): 144 | # Copy the patches indices from the parent to the child. 145 | self.patch_indices[child] = copy.deepcopy(self.patch_indices[parent]) 146 | 147 | # Mutate the child patches with a single swap from the original dataset. 148 | if self.config["patch_mutation_probability"] > np.random.uniform(): 149 | idx_dataset = np.random.randint(len(self._dataset)) 150 | idx_patch = np.random.randint(self._num_patches) 151 | self.patch_indices[child][idx_patch] = idx_dataset 152 | 153 | # Update all the patches for the child. 154 | self.store_patches(child) 155 | 156 | self.spatial_transformer.copy_and_mutate_s(parent, child) 157 | self.colour_transformer.copy_and_mutate_s(parent, child) 158 | 159 | def copy_from(self, other, idx_to, idx_from): 160 | """Copy parameters from other collage generator, for selected indices.""" 161 | assert idx_to < self._pop_size 162 | with torch.no_grad(): 163 | self.patch_indices[idx_to] = copy.deepcopy(other.patch_indices[idx_from]) 164 | self.spatial_transformer.copy_from( 165 | other.spatial_transformer, idx_to, idx_from) 166 | self.colour_transformer.copy_from( 167 | other.colour_transformer, idx_to, idx_from) 168 | if not self._high_res: 169 | self.store_patches(idx_to) 170 | 171 | def forward(self, params=None): 172 | """Input-less forward function.""" 173 | 174 | assert not self._high_res 175 | if self.patches is None: 176 | self.store_patches() 177 | shifted_patches = self.spatial_transformer(self.patches) 178 | background_image = self.background_image 179 | if params is not None and "no_background" in params: 180 | print("Not using background_image") 181 | background_image = None 182 | 183 | self.coloured_patches = self.colour_transformer(shifted_patches) 184 | if self.config["render_method"] == "transparency": 185 | img = rendering.population_render_transparency( 186 | self.coloured_patches, 187 | invert_colours=self.config["invert_colours"], b=background_image) 188 | elif self.config["render_method"] == "masked_transparency_clipped": 189 | img = rendering.population_render_masked_transparency( 190 | self.coloured_patches, mode="clipped", 191 | invert_colours=self.config["invert_colours"], b=background_image) 192 | elif self.config["render_method"] == "masked_transparency_normed": 193 | img = rendering.population_render_masked_transparency( 194 | self.coloured_patches, mode="normed", 195 | invert_colours=self.config["invert_colours"], b=background_image) 196 | elif self.config["render_method"] == "opacity": 197 | img = rendering.population_render_overlap( 198 | self.coloured_patches, 199 | invert_colours=self.config["invert_colours"], b=background_image) 200 | else: 201 | print("Unhandled render method") 202 | if params is not None and "no_background" in params: 203 | print("Setting alpha to zero outside of patches") 204 | mask = self.coloured_patches[:, :, 3:4, :, :].sum(1) > 0 205 | mask = mask.permute(0, 2, 3, 1) 206 | img = torch.concat([img, mask], axis=-1) 207 | return img 208 | 209 | def forward_high_res(self, params=None): 210 | """Input-less forward function.""" 211 | 212 | assert self._high_res 213 | 214 | max_render_size = params.get("max_block_size_high_res", 1000) 215 | w = self._canvas_width * self._high_res_multiplier 216 | h = self._canvas_height * self._high_res_multiplier 217 | if (self._high_res_multiplier % 8 == 0 and 218 | self._canvas_width * 8 < max_render_size and 219 | self._canvas_height * 8 < max_render_size): 220 | num_w = int(self._high_res_multiplier / 8) 221 | num_h = int(self._high_res_multiplier / 8) 222 | delta_w = self._canvas_width * 8 223 | delta_h = self._canvas_height * 8 224 | elif (self._high_res_multiplier % 4 == 0 and 225 | self._canvas_width * 4 < max_render_size and 226 | self._canvas_height * 4 < max_render_size): 227 | num_w = int(self._high_res_multiplier / 4) 228 | num_h = int(self._high_res_multiplier / 4) 229 | delta_w = self._canvas_width * 4 230 | delta_h = self._canvas_height * 4 231 | elif (self._high_res_multiplier % 2 == 0 and 232 | self._canvas_width * 2 < max_render_size and 233 | self._canvas_height * 2 < max_render_size): 234 | num_w = int(self._high_res_multiplier / 2) 235 | num_h = int(self._high_res_multiplier / 2) 236 | delta_w = self._canvas_width * 2 237 | delta_h = self._canvas_height * 2 238 | else: 239 | num_w = self._high_res_multiplier 240 | num_h = self._high_res_multiplier 241 | delta_w = self._canvas_width 242 | delta_h = self._canvas_height 243 | 244 | img = torch.zeros((1, h, w, 4)) 245 | img[..., 3] = 1.0 246 | 247 | background_image = self.background_image 248 | if params is not None and "no_background" in params: 249 | print("Not using background_image") 250 | background_image = None 251 | 252 | for u in range(num_w): 253 | for v in range(num_h): 254 | x0 = u * delta_w 255 | x1 = (u + 1) * delta_w 256 | y0 = v * delta_h 257 | y1 = (v + 1) * delta_h 258 | print(f"[{u}, {v}] idx [{x0}:{x1}], [{y0}:{y1}]") 259 | 260 | # Extract full patches, apply spatial transform individually and crop. 261 | shifted_patches_uv = [] 262 | for idx_patch in range(self._num_patches): 263 | patch = self._fetch_patch(0, idx_patch, True).unsqueeze(0) 264 | patch_uv = self.spatial_transformer(patch, idx_patch) 265 | patch_uv = patch_uv[:, :, :, y0:y1, x0:x1] 266 | shifted_patches_uv.append(patch_uv) 267 | shifted_patches_uv = torch.cat(shifted_patches_uv, 1) 268 | 269 | # Crop background? 270 | if background_image is not None: 271 | background_image_uv = background_image[:, y0:y1, x0:x1] 272 | else: 273 | background_image_uv = None 274 | 275 | # Appy colour transform and render. 276 | coloured_patches_uv = self.colour_transformer(shifted_patches_uv) 277 | if self.config["render_method"] == "transparency": 278 | img_uv = rendering.population_render_transparency( 279 | coloured_patches_uv, 280 | invert_colours=self.config["invert_colours"], 281 | b=background_image_uv) 282 | elif self.config["render_method"] == "masked_transparency_clipped": 283 | img_uv = rendering.population_render_masked_transparency( 284 | coloured_patches_uv, mode="clipped", 285 | invert_colours=self.config["invert_colours"], 286 | b=background_image_uv) 287 | elif self.config["render_method"] == "masked_transparency_normed": 288 | img_uv = rendering.population_render_masked_transparency( 289 | coloured_patches_uv, mode="normed", 290 | invert_colours=self.config["invert_colours"], 291 | b=background_image_uv) 292 | elif self.config["render_method"] == "opacity": 293 | img_uv = rendering.population_render_overlap( 294 | coloured_patches_uv, 295 | invert_colours=self.config["invert_colours"], 296 | b=background_image_uv) 297 | else: 298 | print("Unhandled render method") 299 | 300 | if params is not None and "no_background" in params: 301 | print("Setting alpha to zero outside of patches") 302 | mask_uv = coloured_patches_uv[:, :, 3:4, :, :].sum(1) > 0 303 | mask_uv = mask_uv.permute(0, 2, 3, 1) 304 | img_uv = torch.concat([img_uv, mask_uv], axis=-1) 305 | img[0, y0:y1, x0:x1, :4] = img_uv 306 | else: 307 | img[0, y0:y1, x0:x1, :3] = img_uv 308 | print(f"Finished [{u}, {v}] idx [{x0}:{x1}], [{y0}:{y1}]") 309 | 310 | print(img.size()) 311 | return img 312 | 313 | def tensors_to(self, device): 314 | self.spatial_transformer.tensor_to(device) 315 | self.colour_transformer.tensor_to(device) 316 | if self.patches is not None: 317 | self.patches = self.patches.to(device) 318 | -------------------------------------------------------------------------------- /arnheim_3/src/patches.py: -------------------------------------------------------------------------------- 1 | """Loading and processing collage patches. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | import cv2 22 | import numpy as np 23 | from .video_utils import cached_url_download 24 | from .video_utils import cv2_imshow 25 | 26 | SHOW_PATCHES = False 27 | 28 | 29 | def add_binary_alpha_mask(patch): 30 | """Black pixels treated as having alpha=0, all other pixels have alpha=255.""" 31 | 32 | mask = ((patch.sum(2) > 0) * 255).astype(np.uint8) 33 | return np.concatenate([patch, np.expand_dims(mask, -1)], axis=-1) 34 | 35 | 36 | def resize_patch(patch, coeff): 37 | return cv2.resize(patch.astype(float), 38 | (int(np.round(patch.shape[1] * coeff)), 39 | int(np.round(patch.shape[0] * coeff)))) 40 | 41 | 42 | def print_size_segmented_data(segmented_data, show=True): 43 | """Print debug information on patch sizes.""" 44 | 45 | size_max = 0 46 | shape_max = None 47 | size_min = np.infty 48 | shape_min = None 49 | for i, segment in enumerate(segmented_data): 50 | segment = segment.swapaxes(0, 1) 51 | shape_i = segment.shape 52 | size_i = shape_i[0] * shape_i[1] 53 | if size_i > size_max: 54 | shape_max = shape_i 55 | size_max = size_i 56 | if size_i < size_min: 57 | shape_min = shape_i 58 | size_min = size_i 59 | print(f"Patch {i} of shape {shape_i}") 60 | if show: 61 | im_i = cv2.cvtColor(segment, cv2.COLOR_RGBA2BGRA) 62 | im_bgr = im_i[:, :, :3] 63 | im_mask = np.tile(im_i[:, :, 3:], (1, 1, 3)) 64 | im_render = np.concatenate([im_bgr, im_mask], 1) 65 | cv2_imshow(im_render) 66 | print(f"{len(segmented_data)} patches, max {shape_max}, min {shape_min}\n") 67 | 68 | 69 | def get_segmented_data_initial(config): 70 | """Load patch file and return segmented image data.""" 71 | 72 | if config["url_to_patch_file"]: 73 | segmented_data_initial = cached_url_download(config["url_to_patch_file"]) 74 | else: 75 | repo_file = config["patch_set"] 76 | repo_root = config["patch_repo_root"] 77 | segmented_data_initial = cached_url_download( 78 | f"{repo_root}/{repo_file}") 79 | 80 | segmented_data_initial_tmp = [] 81 | for i in range(len(segmented_data_initial)): 82 | if segmented_data_initial[i].shape[2] == 3: 83 | segmented_data_initial_tmp.append(add_binary_alpha_mask( 84 | segmented_data_initial[i])) 85 | else: 86 | segmented_data_initial_tmp.append( 87 | segmented_data_initial[i]) 88 | 89 | segmented_data_initial = segmented_data_initial_tmp 90 | return segmented_data_initial 91 | 92 | 93 | def normalise_patch_brightness(patch): 94 | max_intensity = max(patch.max(), 1.0) 95 | return ((patch / max_intensity) * 255).astype(np.uint8) 96 | 97 | 98 | def get_segmented_data(config, index): 99 | """Generate low and high res patch data for a collage. 100 | 101 | Args: 102 | config: dict, config file and command line args 103 | index: int, which subset of options to use when multiple are available. 104 | E.g. selecting patch set based on tile number. 105 | Returns: 106 | numpy arrays: low and high resolution patch data. 107 | """ 108 | # Select tile's patch set and/or parameters if multiple provided. 109 | if ("multiple_patch_set" in config and isinstance( 110 | config["multiple_patch_set"], list) and 111 | config["multiple_patch_set"] != ['null']): 112 | config["patch_set"] = config["multiple_patch_set"][ 113 | index % len(config["multiple_patch_set"])] 114 | if ("multiple_fixed_scale_patches" in config and isinstance( 115 | config["multiple_fixed_scale_patches"], list) and 116 | config["multiple_fixed_scale_patches"] != ['null']): 117 | config["fixed_scale_patches"] = config["multiple_fixed_scale_patches"][ 118 | index % len(config["multiple_fixed_scale_patches"])] == "True" 119 | if ("multiple_patch_max_proportion" in config and isinstance( 120 | config["multiple_patch_max_proportion"], list) and 121 | config["multiple_patch_max_proportion"] != ['null']): 122 | config["patch_max_proportion"] = int(config[ 123 | "multiple_patch_max_proportion"][ 124 | index % len(config["multiple_patch_max_proportion"])]) 125 | if ("multiple_fixed_scale_coeff" in config and isinstance( 126 | config["multiple_fixed_scale_coeff"], list) and 127 | config["multiple_fixed_scale_coeff"] != ['null']): 128 | config["fixed_scale_coeff"] = float(config["multiple_fixed_scale_coeff"][ 129 | index % len(config["multiple_fixed_scale_coeff"])]) 130 | 131 | segmented_data_initial = get_segmented_data_initial(config) 132 | 133 | # Fixed order for the segmented images. 134 | num_patches = len(segmented_data_initial) 135 | order = np.arange(num_patches) 136 | # The following permutes the patches but precludes reloading checkpoints. 137 | # order = np.random.permutation(num_patches) 138 | 139 | # Compress all images until they are at most 1/PATCH_MAX_PROPORTION of the 140 | # large canvas size. 141 | canvas_height = config["canvas_height"] 142 | canvas_width = config["canvas_width"] 143 | hires_height = canvas_height * config["high_res_multiplier"] 144 | hires_width = canvas_width * config["high_res_multiplier"] 145 | height_large_max = hires_height / config["patch_max_proportion"] 146 | width_large_max = hires_width / config["patch_max_proportion"] 147 | print(f"Patch set {config['patch_set']}, fixed_scale_patches? " 148 | f"{config['fixed_scale_patches']}, " 149 | f"fixed_scale_coeff={config['fixed_scale_coeff']}, " 150 | f"patch_max_proportion={config['patch_max_proportion']}") 151 | if config["fixed_scale_patches"]: 152 | print(f"Max size for fixed scale patches: ({hires_height},{hires_width})") 153 | else: 154 | print( 155 | f"Max patch size on large img: ({height_large_max}, {width_large_max})") 156 | print(type(config["fixed_scale_patches"])) 157 | segmented_data = [] 158 | segmented_data_high_res = [] 159 | for patch_i in range(num_patches): 160 | segmented_data_initial_i = segmented_data_initial[ 161 | order[patch_i]].astype(np.float32).swapaxes(0, 1) 162 | shape_i = segmented_data_initial_i.shape 163 | h_i = shape_i[0] 164 | w_i = shape_i[1] 165 | if h_i >= config["patch_height_min"] and w_i >= config["patch_width_min"]: 166 | # Coefficient for resizing the patch. 167 | if config["fixed_scale_patches"]: 168 | coeff_i_large = config["fixed_scale_coeff"] 169 | if h_i * coeff_i_large > hires_height: 170 | coeff_i_large = hires_height / h_i 171 | if w_i * coeff_i_large > width_large_max: 172 | coeff_i_large = min(coeff_i_large, hires_width / w_i) 173 | if coeff_i_large != config["fixed_scale_coeff"]: 174 | print( 175 | f"Patch {patch_i} too large; scaled to {coeff_i_large:.2f}") 176 | else: 177 | coeff_i_large = 1.0 178 | if h_i > height_large_max: 179 | coeff_i_large = height_large_max / h_i 180 | if w_i > width_large_max: 181 | coeff_i_large = min(coeff_i_large, width_large_max / w_i) 182 | 183 | # Resize the high-res patch? 184 | if coeff_i_large < 1.0: 185 | # print(f"Patch {patch_i} scaled by {coeff_i_large:.2f}") 186 | segmented_data_high_res_i = resize_patch(segmented_data_initial_i, 187 | coeff_i_large) 188 | else: 189 | segmented_data_high_res_i = np.copy(segmented_data_initial_i) 190 | 191 | # Resize the low-res patch. 192 | coeff_i = coeff_i_large / config["high_res_multiplier"] 193 | segmented_data_i = resize_patch(segmented_data_initial_i, coeff_i) 194 | shape_i = segmented_data_i.shape 195 | if (shape_i[0] > canvas_height 196 | or shape_i[1] > config["canvas_width"]): 197 | 198 | print(f"{shape_i} exceeds canvas ({canvas_height},{canvas_width})") 199 | if config["normalize_patch_brightness"]: 200 | segmented_data_i[..., :3] = normalise_patch_brightness( 201 | segmented_data_i[..., :3]) 202 | segmented_data_high_res_i[..., :3] = normalise_patch_brightness( 203 | segmented_data_high_res_i[..., :3]) 204 | segmented_data_high_res_i = segmented_data_high_res_i.astype(np.uint8) 205 | segmented_data_high_res.append(segmented_data_high_res_i) 206 | segmented_data_i = segmented_data_i.astype(np.uint8) 207 | segmented_data.append(segmented_data_i) 208 | else: 209 | print(f"Discard patch of size {h_i}x{w_i}") 210 | 211 | if SHOW_PATCHES: 212 | print("Patch sizes during optimisation:") 213 | print_size_segmented_data(segmented_data, show=config["gui"]) 214 | print("Patch sizes for high-resolution final image:") 215 | print_size_segmented_data(segmented_data_high_res, show=config["gui"]) 216 | 217 | return segmented_data, segmented_data_high_res 218 | -------------------------------------------------------------------------------- /arnheim_3/src/rendering.py: -------------------------------------------------------------------------------- 1 | """RGB image rendering from patch data. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | 25 | RENDER_EPSILON = 1e-8 26 | RENDER_OVERLAP_TEMPERATURE = 0.1 27 | RENDER_OVERLAP_ZERO_OFFSET = -5 28 | RENDER_OVERLAP_MASK_THRESHOLD = 0.5 29 | 30 | 31 | def population_render_transparency(x, invert_colours=False, b=None): 32 | """Render image from patches with transparancy. 33 | 34 | Renders patches with transparency using black as the transparent colour. 35 | Args: 36 | x: tensor of transformed RGB image patches of shape [S, B, 5, H, W]. 37 | invert_colours: Invert all RGB values. 38 | b: optional tensor of background RGB image of shape [S, 3, H, W]. 39 | Returns: 40 | Tensor of rendered RGB images of shape [S, 3, H, W]. 41 | """ 42 | # Sum the RGB patches [S, B, 3, H, W] as [S, 3, H, W]. 43 | x = x[:, :, :3, :, :] * x[:, :, 3:4, :, :] 44 | y = x[:, :, :3, :, :].sum(1) 45 | if invert_colours: 46 | y[:, :3, :, :] = 1.0 - y[:, :3, :, :] 47 | # Add backgrounds [S, 3, H, W]. 48 | if b is not None: 49 | b = b.cuda() if x.is_cuda else b.cpu() 50 | y = (y + b).clamp(0., 1.) 51 | return y.clamp(0., 1.).permute(0, 2, 3, 1) 52 | 53 | 54 | def population_render_masked_transparency( 55 | x, mode, invert_colours=False, b=None): 56 | """Render image from patches using alpha channel for patch transparency. 57 | 58 | Args: 59 | x: tensor of transformed RGB image patches of shape [S, B, 5, H, W]. 60 | mode: ["clipped" | "normed"], methods of handling alpha with background. 61 | invert_colours: invert RGB values 62 | b: optional tensor of background RGB image of shape [S, 3, H, W]. 63 | Returns: 64 | Tensor of rendered RGB images of shape [S, 3, H, W]. 65 | """ 66 | # Get the patch mask [S, B, 1, H, W] and sum of masks [S, 1, H, W]. 67 | mask = x[:, :, 3:4, :, :] 68 | mask_sum = mask.sum(1) + RENDER_EPSILON 69 | # Mask the RGB patches [S, B, 4, H, W] -> [S, B, 3, H, W]. 70 | masked_x = x[:, :, :3, :, :] * mask 71 | # Compute mean of the RGB patches [S, B, 3, H, W] as [S, 3, H, W]. 72 | x_sum = masked_x.sum(1) 73 | y = torch.where( 74 | mask_sum > RENDER_EPSILON, x_sum / mask_sum, mask_sum) 75 | if invert_colours: 76 | y[:, :3, :, :] = 1.0 - y[:, :3, :, :] 77 | # Add backgrounds [S, 3, H, W]. 78 | if b is not None: 79 | b = b.cuda() if x.is_cuda else b.cpu() 80 | if mode == "normed": 81 | mask_max = mask_sum.max( 82 | dim=2, keepdim=True).values.max(dim=3, keepdim=True).values 83 | mask = mask_sum / mask_max 84 | elif mode == "clipped": 85 | mask = mask_sum.clamp(0., 1.) 86 | else: 87 | raise ValueError(f"Unknown masked_transparency mode {mode}") 88 | y = y[:, :3, :, :] * mask + b.unsqueeze(0)[:, :3, :, :] * (1 - mask) 89 | return y.clamp(0., 1.).permute(0, 2, 3, 1) 90 | 91 | 92 | def population_render_overlap(x, invert_colours=False, b=None): 93 | """Render image, overlaying patches on top of one another. 94 | 95 | Uses semi-translucent overlap using the alpha chanel as the mask colour 96 | and the 5th channel as the order for the overlapped images. 97 | Args: 98 | x: tensor of transformed RGB image patches of shape [S, B, 5, H, W]. 99 | invert_colours: invert RGB values 100 | b: optional tensor of background RGB image of shape [S, 3, H, W]. 101 | Returns: 102 | Tensor of rendered RGB images of shape [S, 3, H, W]. 103 | """ 104 | # Get the patch mask [S, B, 1, H, W]. 105 | mask = x[:, :, 3:4, :, :] 106 | # Mask the patches [S, B, 4, H, W] -> [S, B, 3, H, W] 107 | masked_x = x[:, :, :3, :, :] * mask * mask 108 | # Mask the orders [S, B, 1, H, W] -> [S, B, 1, H, W] 109 | order = torch.where( 110 | mask > RENDER_OVERLAP_MASK_THRESHOLD, 111 | x[:, :, 4:, :, :] * mask / RENDER_OVERLAP_TEMPERATURE, 112 | mask + RENDER_OVERLAP_ZERO_OFFSET) 113 | # Get weights from orders [S, B, 1, H, W] 114 | weights = F.softmax(order, dim=1) 115 | # Apply weights to masked patches and compute mean over patches [S, 3, H, W]. 116 | y = (weights * masked_x).sum(1) 117 | if invert_colours: 118 | y[:, :3, :, :] = 1.0 - y[:, :3, :, :] 119 | if b is not None: 120 | b = b.cuda() if x.is_cuda else b.cpu() 121 | y = torch.where(mask.sum(1) > RENDER_OVERLAP_MASK_THRESHOLD, y[:, :3, :, :], 122 | b.unsqueeze(0)[:, :3, :, :]) 123 | return y.clamp(0., 1.).permute(0, 2, 3, 1) 124 | -------------------------------------------------------------------------------- /arnheim_3/src/training.py: -------------------------------------------------------------------------------- 1 | """Functions for the optimisation (including evolution) and evaluation. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | import clip 22 | from matplotlib import pyplot as plt 23 | import numpy as np 24 | import time 25 | import torch 26 | import torchvision.transforms as transforms 27 | from .video_utils import show_and_save 28 | 29 | 30 | # Show each image being evaluated for debugging purposes. 31 | VISUALISE_BATCH_IMAGES = False 32 | 33 | 34 | def augmentation_transforms(canvas_width, 35 | use_normalized_clip=False, 36 | use_augmentation=False): 37 | """Image transforms to produce distorted crops to augment the evaluation. 38 | 39 | Args: 40 | canvas_width: width of the drawing canvas 41 | use_normalized_clip: Normalisation to better suit CLIP's training data 42 | use_augmentation: Image augmentation by affine transform 43 | Returns: 44 | transforms 45 | """ 46 | if use_normalized_clip and use_augmentation: 47 | augment_trans = transforms.Compose( 48 | [transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.6), 49 | transforms.RandomResizedCrop(canvas_width, scale=(0.7, 0.9)), 50 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 51 | (0.26862954, 0.26130258, 0.27577711))]) 52 | elif use_augmentation: 53 | augment_trans = transforms.Compose([ 54 | transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.6), 55 | transforms.RandomResizedCrop(canvas_width, scale=(0.7, 0.9)), 56 | ]) 57 | elif use_normalized_clip: 58 | augment_trans = transforms.Normalize( 59 | (0.48145466, 0.4578275, 0.40821073), 60 | (0.26862954, 0.26130258, 0.27577711)) 61 | else: 62 | augment_trans = transforms.RandomPerspective( 63 | fill=1, p=0, distortion_scale=0) 64 | 65 | return augment_trans 66 | 67 | 68 | def moving_average(a, n=3): 69 | ret = np.cumsum(a, dtype=float) 70 | ret[n:] = ret[n:] - ret[:-n] 71 | return ret[n - 1:] / n 72 | 73 | 74 | def plot_and_save_losses( 75 | loss_history, title="Losses", filename=None, show=True): 76 | """Plot losses and save to file.""" 77 | 78 | losses = np.array(loss_history) 79 | if filename: 80 | np.save(filename + ".npy", losses, allow_pickle=True) 81 | if show: 82 | plt.figure(figsize=(10, 10)) 83 | plt.xlabel("Training steps") 84 | plt.ylabel("Loss") 85 | plt.title(title) 86 | plt.plot(moving_average(losses, n=3)) 87 | plt.savefig(filename + ".png") 88 | 89 | 90 | def make_optimizer(generator, learning_rate): 91 | """Make optimizer for generator's parameters. 92 | 93 | Args: 94 | generator: generator model 95 | learning_rate: learning rate 96 | Returns: 97 | optimizer 98 | """ 99 | 100 | my_list = ["positions_top"] 101 | params = list(map(lambda x: x[1], list(filter(lambda kv: kv[0] in my_list, 102 | generator.named_parameters())))) 103 | base_params = list(map( 104 | lambda x: x[1], list(filter( 105 | lambda kv: kv[0] not in my_list, generator.named_parameters())))) 106 | lr_scheduler = torch.optim.SGD([{"params": base_params}, 107 | {"params": params, "lr": learning_rate}], 108 | lr=learning_rate) 109 | return lr_scheduler 110 | 111 | 112 | def compute_text_features(prompts, clip_model, device): 113 | """Compute CLIP features for all prompts.""" 114 | 115 | text_inputs = [] 116 | for prompt in prompts: 117 | text_inputs.append(clip.tokenize(prompt).to(device)) 118 | 119 | features = [] 120 | with torch.no_grad(): 121 | for text_input in text_inputs: 122 | features.append(clip_model.encode_text(text_input)) 123 | return features 124 | 125 | 126 | def create_augmented_batch(images, augment_trans, text_features, config): 127 | """Create batch of images to be evaluated. 128 | 129 | Args: 130 | images: batch of images to be augmented [N, C, H, W] 131 | augment_trans: transformations for augmentations 132 | text_features: text feature per image 133 | config: dictionary with config 134 | Returns: 135 | img_batch: Augmented versions of the original images [N*num_augs, C, H, W] 136 | num_augs: number of images per original image 137 | expanded_text_features: a text feature for each augmentation 138 | loss_weights: weights for the losses corresponding to each augmentation 139 | """ 140 | images = images.permute(0, 3, 1, 2) # NHWC -> NCHW 141 | expanded_text_features = [] 142 | if config["use_image_augmentations"]: 143 | num_augs = config["num_augs"] 144 | img_augs = [] 145 | for _ in range(num_augs): 146 | img_n = augment_trans(images) 147 | img_augs.append(img_n) 148 | expanded_text_features.append(text_features[0]) 149 | img_batch = torch.cat(img_augs) 150 | # Given images [P0, P1] and augmentations [a0(), a1()], output format: 151 | # [a0(P0), a0(P1), a1(P0), a1(P1)] 152 | else: 153 | num_augs = 1 154 | img_batch = augment_trans(images) 155 | expanded_text_features.append(text_features[0]) 156 | return img_batch, num_augs, expanded_text_features, [1] * config["num_augs"] 157 | 158 | 159 | def create_compositional_batch(images, augment_trans, text_features): 160 | """Create 10 sub-images per image by augmenting each with 3x3 crops. 161 | 162 | Args: 163 | images: population of N images, format [N, C, H, W] 164 | augment_trans: transformations for augmentations 165 | text_features: text feature per image 166 | Returns: 167 | Tensor of all compositional sub-images + originals; [N*10, C, H, W] format: 168 | [x0_y0(P0) ... x0_y0(PN), ..., x2_y2(P0) ... x2_y2(PN), P0, ..., PN] 169 | 10: Number of sub-images + whole, per original image. 170 | expanded_text_features: list of text features, 1 for each composition image 171 | loss_weights: weights for the losses corresponding to each composition image 172 | """ 173 | if len(text_features) != 10: 174 | # text_features should already be 10 in size. 175 | raise ValueError( 176 | "10 text prompts required for compositional image creation") 177 | resize_for_clip = transforms.Compose([transforms.Resize((224, 224))]) 178 | img_swap = torch.swapaxes(images, 3, 1) 179 | ims = [] 180 | i = 0 181 | for x in range(3): 182 | for y in range(3): 183 | for k in range(images.shape[0]): 184 | ims.append(resize_for_clip( 185 | img_swap[k][:, y * 112 : y * 112 + 224, x * 112 : x * 112 + 224])) 186 | i += 1 187 | 188 | # Top-level (whole) images 189 | for k in range(images.shape[0]): 190 | ims.append(resize_for_clip(img_swap[k])) 191 | all_img = torch.stack(ims) 192 | all_img = torch.swapaxes(all_img, 1, 3) 193 | all_img = all_img.permute(0, 3, 1, 2) # NHWC -> NCHW 194 | all_img = augment_trans(all_img) 195 | 196 | # Last image gets 9 times as much weight 197 | common_weight = 1 / 5 198 | loss_weights = [common_weight] * 9 199 | loss_weights.append(9 * common_weight) 200 | return all_img, 10, text_features, loss_weights 201 | 202 | 203 | def evaluation(t, clip_enc, generator, augment_trans, text_features, 204 | prompts, config, device): 205 | """Do a step of evaluation, returning images and losses. 206 | 207 | Args: 208 | t: step count 209 | clip_enc: model for CLIP encoding 210 | generator: drawing generator to optimise 211 | augment_trans: transforms for image augmentation 212 | text_features: tuple with the prompt two negative prompts 213 | prompts: for debugging/visualisation - the list of text prompts 214 | config: dictionary with hyperparameters 215 | device: torch device 216 | Returns: 217 | loss: torch.Tensor of single combines loss 218 | losses_separate_np: numpy array of loss for each image 219 | losses_individuals_np: numpy array with loss for each population individual 220 | img_np: numpy array of images from the generator 221 | """ 222 | 223 | # Annealing parameters. 224 | params = {"gamma": t / config["optim_steps"]} 225 | 226 | # Rebuild the generator. 227 | img = generator(params) 228 | img_np = img.detach().cpu().numpy() 229 | 230 | # Create images for different regions 231 | pop_size = img.shape[0] 232 | if config["compositional_image"]: 233 | (img_batch, num_augs, text_features, loss_weights 234 | ) = create_compositional_batch(img, augment_trans, text_features) 235 | else: 236 | (img_batch, num_augs, text_features, loss_weights 237 | ) = create_augmented_batch(img, augment_trans, text_features, config) 238 | losses = torch.zeros(pop_size, num_augs).to(device) 239 | 240 | # Compute and add losses after augmenting the image with transforms. 241 | img_batch = torch.clip(img_batch, 0, 1) # clip the images. 242 | image_features = clip_enc.encode_image(img_batch) 243 | count = 0 244 | for n in range(num_augs): # number of augmentations or composition images 245 | for p in range(pop_size): 246 | loss = torch.cosine_similarity( 247 | text_features[n], image_features[count:count+1], dim=1 248 | )[0] * loss_weights[n] 249 | losses[p, n] -= loss 250 | if VISUALISE_BATCH_IMAGES and t % 500 == 0: 251 | # Show all the images in the batch along with their losses. 252 | if config["compositional_image"]: 253 | print(f"Loss {loss} for image region with prompt {prompts[n]}:") 254 | else: 255 | print(f"Loss {loss} for image augmentation with prompt {prompts[0]}:") 256 | show_and_save(img_batch[count].unsqueeze(0), config, 257 | img_format="SCHW", show=config["gui"]) 258 | count += 1 259 | loss = torch.sum(losses) / pop_size 260 | losses_separate_np = losses.detach().cpu().numpy() 261 | # Sum losses for all each population individual. 262 | losses_individuals_np = losses_separate_np.sum(axis=1) 263 | return loss, losses_separate_np, losses_individuals_np, img_np 264 | 265 | 266 | def step_optimization(t, clip_enc, lr_scheduler, generator, augment_trans, 267 | text_features, prompts, config, device, final_step=False): 268 | """Do a step of optimization. 269 | 270 | Args: 271 | t: step count 272 | clip_enc: model for CLIP encoding 273 | lr_scheduler: optimizer 274 | generator: drawing generator to optimise 275 | augment_trans: transforms for image augmentation 276 | text_features: list or 1 or 9 prompts for normal and compositional creation 277 | prompts: for debugging/visualisation - the list of text prompts 278 | config: dictionary with hyperparameters 279 | device: CUDA device 280 | final_step: if True does extras such as saving the model 281 | Returns: 282 | losses_np: numpy array with loss for each population individual 283 | losses_separate_np: numpy array of loss for each image 284 | """ 285 | 286 | # Anneal learning rate and other parameters. 287 | t0 = time.time() 288 | if t == int(config["optim_steps"] / 3): 289 | for g in lr_scheduler.param_groups: 290 | g["lr"] = g["lr"] / 2.0 291 | if t == int(config["optim_steps"] * (2/3)): 292 | for g in lr_scheduler.param_groups: 293 | g["lr"] = g["lr"] / 2.0 294 | 295 | # Forward pass. 296 | lr_scheduler.zero_grad() 297 | loss, losses_separate_np, losses_np, img_np = evaluation( 298 | t=t, clip_enc=clip_enc, generator=generator, augment_trans=augment_trans, 299 | text_features=text_features, prompts=prompts, config=config, 300 | device=device) 301 | 302 | # Backpropagate the gradients. 303 | loss.backward() 304 | torch.nn.utils.clip_grad_norm(generator.parameters(), 305 | config["gradient_clipping"]) 306 | 307 | # Decay the learning rate. 308 | lr_scheduler.step() 309 | 310 | # Render the big version. 311 | if final_step: 312 | show_and_save( 313 | img_np, config, t=t, img_format="SHWC", show=config["gui"]) 314 | output_dir = config["output_dir"] 315 | print(f"Saving model to {output_dir}...") 316 | torch.save(generator.state_dict(), f"{output_dir}/generator.pt") 317 | 318 | if t % config["trace_every"] == 0: 319 | output_dir = config["output_dir"] 320 | filename = f"{output_dir}/optim_{t}" 321 | show_and_save(img_np, config, 322 | max_display=config["max_multiple_visualizations"], 323 | stitch=True, img_format="SHWC", 324 | show=config["gui"], 325 | filename=filename) 326 | 327 | t1 = time.time() 328 | print("Iteration {:3d}, rendering loss {:.6f}, {:.3f}s/iter".format( 329 | t, loss.item(), t1-t0)) 330 | return losses_np, losses_separate_np, img_np 331 | 332 | 333 | def population_evolution_step(generator, config, losses): 334 | """GA for the population.""" 335 | 336 | if config["ga_method"] == "Microbial": 337 | 338 | # Competition between 2 random individuals; mutated winner replaces loser. 339 | indices = list(range(len(losses))) 340 | np.random.shuffle(indices) 341 | select_1, select_2 = indices[0], indices[1] 342 | if losses[select_1] < losses[select_2]: 343 | generator.copy_and_mutate_s(select_1, select_2) 344 | else: 345 | generator.copy_and_mutate_s(select_2, select_1) 346 | elif config["ga_method"] == "Evolutionary Strategies": 347 | 348 | # Replace rest of population with mutants of the best. 349 | winner = np.argmin(losses) 350 | for other in range(len(losses)): 351 | if other == winner: 352 | continue 353 | generator.copy_and_mutate_s(winner, other) 354 | -------------------------------------------------------------------------------- /arnheim_3/src/transformations.py: -------------------------------------------------------------------------------- 1 | """Colour and affine transform classes. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | from kornia.color import hsv 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | 28 | class PopulationAffineTransforms(torch.nn.Module): 29 | """Population-based Affine Transform network.""" 30 | 31 | def __init__(self, config, device, num_patches=1, pop_size=1, 32 | requires_grad=True, is_high_res=False): 33 | super(PopulationAffineTransforms, self).__init__() 34 | 35 | self.config = config 36 | self.device = device 37 | self._pop_size = pop_size 38 | self._is_high_res = is_high_res 39 | print('PopulationAffineTransforms is_high_res={}, requires_grad={}'.format( 40 | self._is_high_res, requires_grad)) 41 | 42 | self._min_rot = self.config['min_rot_deg'] * np.pi / 180. 43 | self._max_rot = self.config['max_rot_deg'] * np.pi / 180. 44 | matrices_translation = ( 45 | (np.random.rand(pop_size, num_patches, 2, 1) 46 | * (self.config['max_trans_init'] - self.config['min_trans_init'])) 47 | + self.config['min_trans_init']) 48 | matrices_rotation = ( 49 | (np.random.rand(pop_size, num_patches, 1, 1) 50 | * (self._max_rot - self._min_rot)) + self._min_rot) 51 | matrices_scale = ( 52 | (np.random.rand(pop_size, num_patches, 1, 1) 53 | * (self.config['max_scale'] - self.config['min_scale'])) 54 | + self.config['min_scale']) 55 | matrices_squeeze = ( 56 | (np.random.rand(pop_size, num_patches, 1, 1) * ( 57 | (self.config['max_squeeze'] - self.config['min_squeeze']) 58 | + self.config['min_squeeze']))) 59 | matrices_shear = ( 60 | (np.random.rand(pop_size, num_patches, 1, 1) 61 | * (self.config['max_shear'] - self.config['min_shear'])) 62 | + self.config['min_shear']) 63 | self.translation = torch.nn.Parameter( 64 | torch.tensor(matrices_translation, dtype=torch.float), 65 | requires_grad=requires_grad) 66 | self.rotation = torch.nn.Parameter( 67 | torch.tensor(matrices_rotation, dtype=torch.float), 68 | requires_grad=requires_grad) 69 | self.scale = torch.nn.Parameter( 70 | torch.tensor(matrices_scale, dtype=torch.float), 71 | requires_grad=requires_grad) 72 | self.squeeze = torch.nn.Parameter( 73 | torch.tensor(matrices_squeeze, dtype=torch.float), 74 | requires_grad=requires_grad) 75 | self.shear = torch.nn.Parameter( 76 | torch.tensor(matrices_shear, dtype=torch.float), 77 | requires_grad=requires_grad) 78 | self._identity = ( 79 | torch.ones((pop_size, num_patches, 1, 1)) * torch.eye(2).unsqueeze(0) 80 | ).to(self.device) 81 | self._zero_column = torch.zeros( 82 | (pop_size, num_patches, 2, 1)).to(self.device) 83 | self._unit_row = ( 84 | torch.ones((pop_size, num_patches, 1, 1)) * torch.tensor([0., 0., 1.]) 85 | ).to(self.device) 86 | self._zeros = torch.zeros((pop_size, num_patches, 1, 1)).to(self.device) 87 | 88 | def _clamp(self): 89 | self.translation.data = self.translation.data.clamp( 90 | min=self.config['min_trans'], max=self.config['max_trans']) 91 | self.rotation.data = self.rotation.data.clamp( 92 | min=self._min_rot, max=self._max_rot) 93 | self.scale.data = self.scale.data.clamp( 94 | min=self.config['min_scale'], max=self.config['max_scale']) 95 | self.squeeze.data = self.squeeze.data.clamp( 96 | min=self.config['min_squeeze'], max=self.config['max_squeeze']) 97 | self.shear.data = self.shear.data.clamp( 98 | min=self.config['min_shear'], max=self.config['max_shear']) 99 | 100 | def copy_and_mutate_s(self, parent, child): 101 | """Copy parameters to child, mutating transform parameters.""" 102 | with torch.no_grad(): 103 | self.translation[child, ...] = ( 104 | self.translation[parent, ...] 105 | + self.config['pos_and_rot_mutation_scale'] * torch.randn( 106 | self.translation[child, ...].shape).to(self.device)) 107 | self.rotation[child, ...] = ( 108 | self.rotation[parent, ...] 109 | + self.config['pos_and_rot_mutation_scale'] * torch.randn( 110 | self.rotation[child, ...].shape).to(self.device)) 111 | self.scale[child, ...] = ( 112 | self.scale[parent, ...] 113 | + self.config['scale_mutation_scale'] * torch.randn( 114 | self.scale[child, ...].shape).to(self.device)) 115 | self.squeeze[child, ...] = ( 116 | self.squeeze[parent, ...] 117 | + self.config['distort_mutation_scale'] * torch.randn( 118 | self.squeeze[child, ...].shape).to(self.device)) 119 | self.shear[child, ...] = ( 120 | self.shear[parent, ...] 121 | + self.config['distort_mutation_scale'] * torch.randn( 122 | self.shear[child, ...].shape).to(self.device)) 123 | 124 | def copy_from(self, other, idx_to, idx_from): 125 | """Copy parameters from other spatial transform, for selected indices.""" 126 | assert idx_to < self._pop_size 127 | with torch.no_grad(): 128 | self.translation[idx_to, ...] = other.translation[idx_from, ...] 129 | self.rotation[idx_to, ...] = other.rotation[idx_from, ...] 130 | self.scale[idx_to, ...] = other.scale[idx_from, ...] 131 | self.squeeze[idx_to, ...] = other.squeeze[idx_from, ...] 132 | self.shear[idx_to, ...] = other.shear[idx_from, ...] 133 | 134 | def forward(self, x, idx_patch=None): 135 | self._clamp() 136 | scale_affine_mat = torch.cat([ 137 | torch.cat([self.scale, self.shear], 3), 138 | torch.cat([self._zeros, self.scale * self.squeeze], 3)], 2) 139 | scale_affine_mat = torch.cat([ 140 | torch.cat([scale_affine_mat, self._zero_column], 3), 141 | self._unit_row], 2) 142 | rotation_affine_mat = torch.cat([ 143 | torch.cat([torch.cos(self.rotation), -torch.sin(self.rotation)], 3), 144 | torch.cat([torch.sin(self.rotation), torch.cos(self.rotation)], 3)], 2) 145 | rotation_affine_mat = torch.cat([ 146 | torch.cat([rotation_affine_mat, self._zero_column], 3), 147 | self._unit_row], 2) 148 | 149 | scale_rotation_mat = torch.matmul(scale_affine_mat, 150 | rotation_affine_mat)[:, :, :2, :] 151 | # Population and patch dimensions (0 and 1) need to be merged. 152 | # E.g. from (POP_SIZE, NUM_PATCHES, CHANNELS, WIDTH, HEIGHT) 153 | # to (POP_SIZE * NUM_PATCHES, CHANNELS, WIDTH, HEIGHT) 154 | if idx_patch is not None and self._is_high_res: 155 | scale_rotation_mat = scale_rotation_mat[:, idx_patch, :, :] 156 | num_patches = 1 157 | else: 158 | scale_rotation_mat = scale_rotation_mat[:, :, :2, :].view( 159 | 1, -1, *(scale_rotation_mat[:, :, :2, :].size()[2:])).squeeze() 160 | num_patches = x.size()[1] 161 | x = x.view(1, -1, *(x.size()[2:])).squeeze() 162 | # print('scale_rotation_mat', scale_rotation_mat.size()) 163 | # print('x', x.size()) 164 | scaled_rotated_grid = F.affine_grid( 165 | scale_rotation_mat, x.size(), align_corners=True) 166 | scaled_rotated_x = F.grid_sample(x, scaled_rotated_grid, align_corners=True) 167 | 168 | translation_affine_mat = torch.cat([self._identity, self.translation], 3) 169 | if idx_patch is not None and self._is_high_res: 170 | translation_affine_mat = translation_affine_mat[:, idx_patch, :, :] 171 | else: 172 | translation_affine_mat = translation_affine_mat.view( 173 | 1, -1, *(translation_affine_mat.size()[2:])).squeeze() 174 | # print('translation_affine_mat', translation_affine_mat.size()) 175 | # print('scaled_rotated_x', scaled_rotated_x.size()) 176 | translated_grid = F.affine_grid( 177 | translation_affine_mat, scaled_rotated_x.size(), align_corners=True) 178 | y = F.grid_sample(scaled_rotated_x, translated_grid, align_corners=True) 179 | # print('y', y.size()) 180 | # print('num_patches', num_patches) 181 | return y.view(self._pop_size, num_patches, *(y.size()[1:])) 182 | 183 | def tensor_to(self, device): 184 | self.translation = self.translation.to(device) 185 | self.rotation = self.rotation.to(device) 186 | self.scale = self.scale.to(device) 187 | self.squeeze = self.squeeze.to(device) 188 | self.shear = self.shear.to(device) 189 | self._identity = self._identity.to(device) 190 | self._zero_column = self._zero_column.to(device) 191 | self._unit_row = self._unit_row.to(device) 192 | self._zeros = self._zeros.to(device) 193 | 194 | 195 | class PopulationOrderOnlyTransforms(torch.nn.Module): 196 | """No color transforms, just ordering of patches.""" 197 | 198 | def __init__(self, config, device, num_patches=1, pop_size=1, 199 | requires_grad=True): 200 | super(PopulationOrderOnlyTransforms, self).__init__() 201 | 202 | self.config = config 203 | self.device = device 204 | self._pop_size = pop_size 205 | print(f'PopulationOrderOnlyTransforms requires_grad={requires_grad}') 206 | 207 | population_zeros = np.ones((pop_size, num_patches, 1, 1, 1)) 208 | population_orders = np.random.rand(pop_size, num_patches, 1, 1, 1) 209 | 210 | self._zeros = torch.nn.Parameter( 211 | torch.tensor(population_zeros, dtype=torch.float), 212 | requires_grad=False) 213 | self.orders = torch.nn.Parameter( 214 | torch.tensor(population_orders, dtype=torch.float), 215 | requires_grad=requires_grad) 216 | self._hsv_to_rgb = hsv.HsvToRgb() 217 | 218 | def _clamp(self): 219 | self.orders.data = self.orders.data.clamp(min=0.0, max=1.0) 220 | 221 | def copy_and_mutate_s(self, parent, child): 222 | with torch.no_grad(): 223 | self.orders[child, ...] = self.orders[parent, ...] 224 | 225 | def copy_from(self, other, idx_to, idx_from): 226 | """Copy parameters from other colour transform, for selected indices.""" 227 | assert idx_to < self._pop_size 228 | with torch.no_grad(): 229 | self.orders[idx_to, ...] = other.orders[idx_from, ...] 230 | 231 | def forward(self, x): 232 | self._clamp() 233 | colours = torch.cat( 234 | [self._zeros, self._zeros, self._zeros, self._zeros, self.orders], 235 | 2) 236 | return colours * x 237 | 238 | def tensor_to(self, device): 239 | self.orders = self.orders.to(device) 240 | self._zeros = self._zeros.to(device) 241 | 242 | 243 | class PopulationColourHSVTransforms(torch.nn.Module): 244 | """HSV color transforms and ordering of patches.""" 245 | 246 | def __init__(self, config, device, num_patches=1, pop_size=1, 247 | requires_grad=True): 248 | super(PopulationColourHSVTransforms, self).__init__() 249 | 250 | self.config = config 251 | self.device = device 252 | print('PopulationColourHSVTransforms for {} patches, {} individuals'.format( 253 | num_patches, pop_size)) 254 | self._pop_size = pop_size 255 | self._min_hue = self.config['min_hue_deg'] * np.pi / 180. 256 | self._max_hue = self.config['max_hue_deg'] * np.pi / 180. 257 | print(f'PopulationColourHSVTransforms requires_grad={requires_grad}') 258 | 259 | coeff_hue = (0.5 * (self._max_hue - self._min_hue) + self._min_hue) 260 | coeff_sat = (0.5 * (self.config['max_sat'] - self.config['min_sat']) 261 | + self.config['min_sat']) 262 | coeff_val = (0.5 * (self.config['max_val'] - self.config['min_val']) 263 | + self.config['min_val']) 264 | population_hues = (np.random.rand(pop_size, num_patches, 1, 1, 1) 265 | * coeff_hue) 266 | population_saturations = np.random.rand( 267 | pop_size, num_patches, 1, 1, 1) * coeff_sat 268 | population_values = np.random.rand( 269 | pop_size, num_patches, 1, 1, 1) * coeff_val 270 | population_zeros = np.ones((pop_size, num_patches, 1, 1, 1)) 271 | population_orders = np.random.rand(pop_size, num_patches, 1, 1, 1) 272 | 273 | self.hues = torch.nn.Parameter( 274 | torch.tensor(population_hues, dtype=torch.float), 275 | requires_grad=requires_grad) 276 | self.saturations = torch.nn.Parameter( 277 | torch.tensor(population_saturations, dtype=torch.float), 278 | requires_grad=requires_grad) 279 | self.values = torch.nn.Parameter( 280 | torch.tensor(population_values, dtype=torch.float), 281 | requires_grad=requires_grad) 282 | self._zeros = torch.nn.Parameter( 283 | torch.tensor(population_zeros, dtype=torch.float), 284 | requires_grad=False) 285 | self.orders = torch.nn.Parameter( 286 | torch.tensor(population_orders, dtype=torch.float), 287 | requires_grad=requires_grad) 288 | self._hsv_to_rgb = hsv.HsvToRgb() 289 | 290 | def _clamp(self): 291 | self.hues.data = self.hues.data.clamp( 292 | min=self._min_hue, max=self._max_hue) 293 | self.saturations.data = self.saturations.data.clamp( 294 | min=self.config['min_sat'], max=self.config['max_sat']) 295 | self.values.data = self.values.data.clamp( 296 | min=self.config['min_val'], max=self.config['max_val']) 297 | self.orders.data = self.orders.data.clamp(min=0.0, max=1.0) 298 | 299 | def copy_and_mutate_s(self, parent, child): 300 | with torch.no_grad(): 301 | self.hues[child, ...] = ( 302 | self.hues[parent, ...] 303 | + self.config['colour_mutation_scale'] * torch.randn( 304 | self.hues[child, ...].shape).to(self.device)) 305 | self.saturations[child, ...] = ( 306 | self.saturations[parent, ...] 307 | + self.config['colour_mutation_scale'] * torch.randn( 308 | self.saturations[child, ...].shape).to(self.device)) 309 | self.values[child, ...] = ( 310 | self.values[parent, ...] 311 | + self.config['colour_mutation_scale'] * torch.randn( 312 | self.values[child, ...].shape).to(self.device)) 313 | self.orders[child, ...] = self.orders[parent, ...] 314 | 315 | def copy_from(self, other, idx_to, idx_from): 316 | """Copy parameters from other colour transform, for selected indices.""" 317 | assert idx_to < self._pop_size 318 | with torch.no_grad(): 319 | self.hues[idx_to, ...] = other.hues[idx_from, ...] 320 | self.saturations[idx_to, ...] = other.saturations[idx_from, ...] 321 | self.values[idx_to, ...] = other.values[idx_from, ...] 322 | self.orders[idx_to, ...] = other.orders[idx_from, ...] 323 | 324 | def forward(self, image): 325 | self._clamp() 326 | colours = torch.cat( 327 | [self.hues, self.saturations, self.values, self._zeros, self.orders], 2) 328 | hsv_image = colours * image 329 | rgb_image = self._hsv_to_rgb(hsv_image[:, :, :3, :, :]) 330 | return torch.cat([rgb_image, hsv_image[:, :, 3:, :, :]], axis=2) 331 | 332 | def tensor_to(self, device): 333 | self.hues = self.hues.to(device) 334 | self.saturations = self.saturations.to(device) 335 | self.values = self.values.to(device) 336 | self.orders = self.orders.to(device) 337 | self._zeros = self._zeros.to(device) 338 | 339 | 340 | class PopulationColourRGBTransforms(torch.nn.Module): 341 | """RGB color transforms and ordering of patches.""" 342 | 343 | def __init__(self, config, device, num_patches=1, pop_size=1, 344 | requires_grad=True): 345 | super(PopulationColourRGBTransforms, self).__init__() 346 | 347 | self.config = config 348 | self.device = device 349 | print('PopulationColourRGBTransforms for {} patches, {} individuals'.format( 350 | num_patches, pop_size)) 351 | self._pop_size = pop_size 352 | print(f'PopulationColourRGBTransforms requires_grad={requires_grad}') 353 | 354 | rgb_init_range = ( 355 | self.config['initial_max_rgb'] - self.config['initial_min_rgb']) 356 | population_reds = ( 357 | np.random.rand(pop_size, num_patches, 1, 1, 1) 358 | * rgb_init_range) + self.config['initial_min_rgb'] 359 | population_greens = ( 360 | np.random.rand(pop_size, num_patches, 1, 1, 1) 361 | * rgb_init_range) + self.config['initial_min_rgb'] 362 | population_blues = ( 363 | np.random.rand(pop_size, num_patches, 1, 1, 1) 364 | * rgb_init_range) + self.config['initial_min_rgb'] 365 | population_zeros = np.ones((pop_size, num_patches, 1, 1, 1)) 366 | population_orders = np.random.rand(pop_size, num_patches, 1, 1, 1) 367 | 368 | self.reds = torch.nn.Parameter( 369 | torch.tensor(population_reds, dtype=torch.float), 370 | requires_grad=requires_grad) 371 | self.greens = torch.nn.Parameter( 372 | torch.tensor(population_greens, dtype=torch.float), 373 | requires_grad=requires_grad) 374 | self.blues = torch.nn.Parameter( 375 | torch.tensor(population_blues, dtype=torch.float), 376 | requires_grad=requires_grad) 377 | self._zeros = torch.nn.Parameter( 378 | torch.tensor(population_zeros, dtype=torch.float), 379 | requires_grad=False) 380 | self.orders = torch.nn.Parameter( 381 | torch.tensor(population_orders, dtype=torch.float), 382 | requires_grad=requires_grad) 383 | 384 | def _clamp(self): 385 | self.reds.data = self.reds.data.clamp( 386 | min=self.config['min_rgb'], max=self.config['max_rgb']) 387 | self.greens.data = self.greens.data.clamp( 388 | min=self.config['min_rgb'], max=self.config['max_rgb']) 389 | self.blues.data = self.blues.data.clamp( 390 | min=self.config['min_rgb'], max=self.config['max_rgb']) 391 | self.orders.data = self.orders.data.clamp(min=0.0, max=1.0) 392 | 393 | def copy_and_mutate_s(self, parent, child): 394 | with torch.no_grad(): 395 | self.reds[child, ...] = ( 396 | self.reds[parent, ...] 397 | + self.config['colour_mutation_scale'] * torch.randn( 398 | self.reds[child, ...].shape).to(self.device)) 399 | self.greens[child, ...] = ( 400 | self.greens[parent, ...] 401 | + self.config['colour_mutation_scale'] * torch.randn( 402 | self.greens[child, ...].shape).to(self.device)) 403 | self.blues[child, ...] = ( 404 | self.blues[parent, ...] 405 | + self.config['colour_mutation_scale'] * torch.randn( 406 | self.blues[child, ...].shape).to(self.device)) 407 | self.orders[child, ...] = self.orders[parent, ...] 408 | 409 | def copy_from(self, other, idx_to, idx_from): 410 | """Copy parameters from other colour transform, for selected indices.""" 411 | assert idx_to < self._pop_size 412 | with torch.no_grad(): 413 | self.reds[idx_to, ...] = other.reds[idx_from, ...] 414 | self.greens[idx_to, ...] = other.greens[idx_from, ...] 415 | self.blues[idx_to, ...] = other.blues[idx_from, ...] 416 | self.orders[idx_to, ...] = other.orders[idx_from, ...] 417 | 418 | def forward(self, x): 419 | self._clamp() 420 | colours = torch.cat( 421 | [self.reds, self.greens, self.blues, self._zeros, self.orders], 2) 422 | return colours * x 423 | 424 | def tensor_to(self, device): 425 | self.reds = self.reds.to(device) 426 | self.greens = self.greens.to(device) 427 | self.blues = self.blues.to(device) 428 | self.orders = self.orders.to(device) 429 | self._zeros = self._zeros.to(device) 430 | -------------------------------------------------------------------------------- /arnheim_3/src/video_utils.py: -------------------------------------------------------------------------------- 1 | """Video utility functions, image rendering and display. 2 | 3 | Arnheim 3 - Collage 4 | Piotr Mirowski, Dylan Banarse, Mateusz Malinowski, Yotam Doron, Oriol Vinyals, 5 | Simon Osindero, Chrisantha Fernando 6 | DeepMind, 2021-2022 7 | 8 | Copyright 2021 DeepMind Technologies Limited 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | https://www.apache.org/licenses/LICENSE-2.0 14 | Unless required by applicable law or agreed to in writing, 15 | software distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | 21 | import io 22 | import os 23 | import pathlib 24 | import cv2 25 | import numpy as np 26 | import requests 27 | import torch 28 | 29 | 30 | try: 31 | from google.colab.patches import cv2_imshow # pylint: disable=g-import-not-at-top 32 | except: # pylint: disable=bare-except 33 | 34 | def cv2_imshow(img, name="CollageGenerator"): 35 | if img.dtype == np.float32 and img.max() > 1.: 36 | img = img.astype(np.uint8) 37 | cv2.imshow(name, img) 38 | cv2.waitKey(1) 39 | 40 | 41 | def load_image(filename, as_cv2_image=False, show=False): 42 | """Load an image as [0,1] RGB numpy array or cv2 image format.""" 43 | img = cv2.imread(filename) 44 | if show: 45 | cv2_imshow(img) 46 | if as_cv2_image: 47 | return img # With colour format BGR 48 | img = np.asarray(img) 49 | return img[..., ::-1] / 255. # Reverse colour dim to convert BGR to RGB 50 | 51 | 52 | def cached_url_download(url, file_format="np_array"): 53 | """Download file from URL and cache locally.""" 54 | cache_filename = os.path.basename(url) 55 | cache = pathlib.Path(cache_filename) 56 | if not cache.is_file(): 57 | print(f"Downloading {cache_filename} from {url}") 58 | r = requests.get(url) 59 | bytesio_object = io.BytesIO(r.content) 60 | with open(cache_filename, "wb") as f: 61 | f.write(bytesio_object.getbuffer()) 62 | else: 63 | print("Using cached version of " + cache_filename) 64 | if file_format == "np_array": 65 | return np.load(cache, allow_pickle=True) 66 | elif file_format == "cv2_image": 67 | return load_image(cache.name, as_cv2_image=True, show=False) 68 | elif file_format == "image_as_np": 69 | return load_image(cache.name, as_cv2_image=False, show=False) 70 | 71 | 72 | def layout_img_batch(img_batch, max_display=None): 73 | img_np = img_batch.transpose(0, 2, 1, 3).clip(0.0, 1.0) # S, W, H, C 74 | if max_display: 75 | img_np = img_np[:max_display, ...] 76 | sp = img_np.shape 77 | img_np[:, 0, :, :] = 1.0 # White line separator 78 | img_stitch = np.reshape(img_np, (sp[1] * sp[0], sp[2], sp[3])) 79 | img_r = img_stitch.transpose(1, 0, 2) # H, W, C 80 | return img_r 81 | 82 | 83 | def show_stitched_batch(img_batch, max_display=1, show=True): 84 | """Display stitched image batch. 85 | Args: 86 | img: image batch to display 87 | max_display: max number of images to display from population 88 | show: whether to display the image 89 | Returns: 90 | stitched image 91 | """ 92 | 93 | img_np = img_batch.detach().cpu().numpy() 94 | img_np = np.clip(img_np, 0.0, 1.0) 95 | num_images = img_np.shape[0] 96 | img_np = img_np.transpose((0, 2, 3, 1)) 97 | laid_out = layout_img_batch(img_np, max_display) 98 | if show: 99 | cv2_imshow(cv2.cvtColor(laid_out, cv2.COLOR_BGR2RGB) * 255) 100 | return laid_out 101 | 102 | 103 | def show_and_save(img_batch, config, t=None, 104 | max_display=1, stitch=True, 105 | img_format="SCHW", show=True, filename=None): 106 | """Save and display images. 107 | 108 | Args: 109 | img_batch: batch of images to display 110 | config: dictionary of all config settings 111 | t: time step 112 | max_display: max number of images to display from population 113 | stitch: append images side-by-side 114 | img_format: SHWC or SCHW (the latter used by CLIP) 115 | show: whether to display the image 116 | filename: save image using filename, if provided 117 | ) 118 | Returns: 119 | stitched image or None 120 | """ 121 | 122 | if isinstance(img_batch, torch.Tensor): 123 | img_np = img_batch.detach().cpu().numpy() 124 | else: 125 | img_np = img_batch 126 | 127 | if len(img_np.shape) == 3: 128 | # if not a batch make it one. 129 | img_np = np.expand_dims(img_np, axis=0) 130 | 131 | if not stitch: 132 | print(f"image (not stitch) min {img_np.min()}, max {img_np.max()}") 133 | for i in range(min(max_display, img_np.shape[0])): 134 | img = img_np[i] 135 | if img_format == "SCHW": # Convert to SHWC 136 | img = np.transpose(img, (1, 2, 0)) 137 | img = np.clip(img, 0.0, 1.0) 138 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) * 255 139 | if filename is not None: 140 | if img.shape[1] > config["canvas_width"]: 141 | filename = "highres_" + filename 142 | output_dir = config["output_dir"] 143 | filename = f"{output_dir}/{filename}_{str(i)}" 144 | if t is not None: 145 | filename += "_t_" + str(t) 146 | filename += ".png" 147 | print(f"Saving image {filename} (shape={img.shape})") 148 | cv2.imwrite(filename, img) 149 | if show: 150 | cv2_imshow(img) 151 | return None 152 | else: 153 | print(f"image (stitch) min {img_np.min()}, max {img_np.max()}") 154 | img_np = np.clip(img_np, 0.0, 1.0) 155 | if img_format == "SCHW": # Convert to SHWC 156 | img_np = img_np.transpose((0, 2, 3, 1)) 157 | laid_out = layout_img_batch(img_np, max_display) 158 | if filename is not None: 159 | filename += ".png" 160 | print(f"Saving temporary image {filename} (shape={laid_out.shape})") 161 | cv2.imwrite(filename, cv2.cvtColor(laid_out, cv2.COLOR_BGR2RGB) * 255) 162 | if show: 163 | cv2_imshow(cv2.cvtColor(laid_out, cv2.COLOR_BGR2RGB) * 255) 164 | return laid_out 165 | 166 | 167 | class VideoWriter: 168 | """Create a video from image frames.""" 169 | 170 | def __init__(self, filename="_autoplay.mp4", fps=20.0, show=False, **kw): 171 | """Video creator. 172 | 173 | Creates and display a video made from frames. The default 174 | filename causes the video to be displayed on exit. 175 | Args: 176 | filename: name of video file 177 | fps: frames per second for video 178 | show: display video on close 179 | **kw: args to be passed to FFMPEG_VideoWriter 180 | Returns: 181 | VideoWriter instance. 182 | """ 183 | 184 | self.writer = None 185 | self.params = dict(filename=filename, fps=fps, **kw) 186 | self._show = show 187 | print("No video writing implemented") 188 | 189 | def add(self, img): 190 | """Add image to video. 191 | 192 | Add new frame to image file, creating VideoWriter if requried. 193 | Args: 194 | img: array-like frame, shape [X, Y, 3] or [X, Y] 195 | Returns: 196 | None 197 | """ 198 | pass 199 | # img = np.asarray(img) 200 | # if self.writer is None: 201 | # h, w = img.shape[:2] 202 | # self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params) 203 | # if img.dtype in [np.float32, np.float64]: 204 | # img = np.uint8(img.clip(0, 1)*255) 205 | # if len(img.shape) == 2: 206 | # img = np.repeat(img[..., None], 3, -1) 207 | # self.writer.write_frame(img) 208 | 209 | def close(self): 210 | if self.writer: 211 | self.writer.close() 212 | 213 | def __enter__(self): 214 | return self 215 | 216 | def __exit__(self, *kw): 217 | self.close() 218 | if self.params["filename"] == "_autoplay.mp4": 219 | self.show() 220 | 221 | def show(self, **kw): 222 | """Display video. 223 | 224 | Args: 225 | **kw: args to be passed to mvp.ipython_display 226 | Returns: 227 | None 228 | """ 229 | self.close() 230 | fn = self.params["filename"] 231 | if self._show: 232 | display(mvp.ipython_display(fn, **kw)) # pylint: disable=undefined-variable 233 | -------------------------------------------------------------------------------- /arnheim_3_patch_maker.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "9G-3dxY01xJd" 7 | }, 8 | "source": [ 9 | "# Arnheim 3 - Segmented Patch Creator\n", 10 | "DeepMind, 2021\n", 11 | "\n", 12 | "## Intructions\n", 13 | "This Colab is to support the creation of segmented patches for collage creation using Arnheim 3.\n", 14 | "\n", 15 | "The Colab uses [PixelLib](https://github.com/ayoolaolafenwa/PixelLib) and is pretty basic but good enough to get one started creating patches from JPG images.\n", 16 | "\n", 17 | "The process is\n", 18 | "\n", 19 | "1) Provide source images\n", 20 | "\n", 21 | "* Upload images using this Colab to either Google Drive or the temp folder\n", 22 | "* Alternatively use a Google Drive folder that already contains images\n", 23 | "\n", 24 | "2) Create segmented patches\n", 25 | "* The patch file is save to Google Drive. Be sure to copy the location of the file in the Arnheim 3 Colab.\n", 26 | "\n", 27 | "\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "cellView": "form", 35 | "id": "-X9f5OnKJ1-O" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "#@title Installations\n", 40 | "!pip3 install pixellib\n", 41 | "!pip3 install tensorflow==2.0.1\n", 42 | "!pip3 install Keras==2.3.0\n", 43 | "!pip3 install h5py==2.10.0" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "id": "K41eEXVbu1k6" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "#@title Imports\n", 55 | "import glob\n", 56 | "from google.colab import drive\n", 57 | "from google.colab import files\n", 58 | "import io\n", 59 | "import numpy as np\n", 60 | "import os\n", 61 | "import pathlib\n", 62 | "import pixellib\n", 63 | "from pixellib.instance import instance_segmentation\n", 64 | "import requests\n" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "id": "A_k0GI-du3E1" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "#@title Function definitions\n", 76 | "\n", 77 | "def mkdir(path):\n", 78 | " pathlib.Path(path).mkdir(parents=True, exist_ok=True)\n", 79 | "\n", 80 | "def upload_files(target_path):\n", 81 | " \"\"\"Upload files to target directory.\"\"\"\n", 82 | " mkdir(target_path)\n", 83 | " uploaded = files.upload()\n", 84 | " for k, v in uploaded.items():\n", 85 | " open(target_path + \"/\" + k, 'wb').write(v)\n", 86 | " return list(uploaded.keys())\n", 87 | "\n", 88 | "\n", 89 | "def download_from_url(url, force=False):\n", 90 | " \"\"\"Download file from URL and cache it.\"\"\"\n", 91 | "\n", 92 | " cache_dir = \"/content/cache\"\n", 93 | " mkdir(cache_dir)\n", 94 | " cache_filename = f\"{cache_dir}/{os.path.basename(url)}\"\n", 95 | " cache = pathlib.Path(cache_filename)\n", 96 | " if not cache.is_file() or force:\n", 97 | " print(\"Downloading \" + url)\n", 98 | " r = requests.get(url)\n", 99 | " bytesio_object = io.BytesIO(r.content)\n", 100 | " with open(cache_filename, \"wb\") as f:\n", 101 | " f.write(bytesio_object.getbuffer())\n", 102 | " else:\n", 103 | " print(\"Using cached version of \" + url)\n", 104 | " return cache " 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "cellView": "form", 112 | "id": "hpsPBACa4gps" 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "#@title Authorise and mount Google Drive\n", 117 | "ROOT = \"/content\"\n", 118 | "MOUNT_DIR = f\"{ROOT}/drive\"\n", 119 | "drive.mount(MOUNT_DIR)\n", 120 | "# ROOT_PATH = f\"{MOUNT_DIR}/MyDrive/Arnheim3\"\n", 121 | "# \n", 122 | "# mkdir(ROOT)\n", 123 | "# IMAGE_PATH = f\"{ROOT}/source_images\"\n", 124 | "# SEGMENTED_PATH = f\"{ROOT}/segmented\"\n", 125 | "# \n", 126 | "# print(f\"\\nUsing base directory: {ROOT}\")\n", 127 | "# print(f\"Source images directory: {IMAGE_PATH}\")\n", 128 | "# print(f\"Segmented directory: {SEGMENTED_PATH}\")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": { 135 | "cellView": "form", 136 | "collapsed": true, 137 | "id": "f2Dhp16ptHuF" 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "#@title Source images and target file locations\n", 142 | "#@markdown Source images can be stored temporarily with the Colab, be already on Google Drive, or can be uploaded to Google Drive.\n", 143 | "use_google_drive_for_source_images = True #@param {type:\"boolean\"}\n", 144 | "#@markdown Source images (if stored on Google Drive)\n", 145 | "GOOGLE_DRIVE_PATH_SOURCE_IMAGES = \"Art/Collage/Images\" #@param {type:\"string\"}\n", 146 | "#@markdown Target segmentation file will be saved to Google Drive for use with the Arnheim 3 Colab.\n", 147 | "SEGMENTED_DATA_FILENAME = \"fruit.npy\" #@param {type: \"string\"}\n", 148 | "GOOGLE_DRIVE_PATH_SEGMENTED_DATA = \"Art/Collage/Patches\" #@param {type:\"string\"}\n", 149 | "\n", 150 | "data_path = MOUNT_DIR + \"/MyDrive/\" + GOOGLE_DRIVE_PATH_SEGMENTED_DATA\n", 151 | "data_file = data_path + \"/\" + SEGMENTED_DATA_FILENAME\n", 152 | "\n", 153 | "if use_google_drive_for_source_images:\n", 154 | " IMAGE_PATH = MOUNT_DIR + \"/MyDrive/\" + GOOGLE_DRIVE_PATH_SOURCE_IMAGES\n", 155 | "else:\n", 156 | " IMAGE_PATH = f\"{ROOT}/source_images\"\n", 157 | "mkdir(IMAGE_PATH)\n", 158 | "mkdir(data_path)\n", 159 | "\n", 160 | "print(f\"Source images directory: {IMAGE_PATH}\")\n", 161 | "print(f\"Segmented data will be saved to: {data_file}\")" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "cellView": "form", 169 | "id": "Ylun-hVm5iGq" 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "#@title Run this cell to upload a new set of images to segment\n", 174 | "empty_target_dir_before_upload = False #@param {type:\"boolean\"}\n", 175 | "\n", 176 | "if empty_target_dir_before_upload:\n", 177 | " !rm {IMAGE_PATH}/*\n", 178 | "\n", 179 | "upload_files(IMAGE_PATH)\n", 180 | "print(f\"Images uploaded images to {IMAGE_PATH}\")\n", 181 | "\n", 182 | "!ls -l {IMAGE_PATH}" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "id": "TEmqp9LOdv9P" 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "#@title Segment images and save patch file\n", 194 | "\n", 195 | "# https://pixellib.readthedocs.io/en/latest/Image_instance.html\n", 196 | "segment_image = instance_segmentation()\n", 197 | "segmentation_model_file = download_from_url(\n", 198 | " \"https://github.com/ayoolaolafenwa/PixelLib/releases/download/1.2/mask_rcnn_coco.h5\")\n", 199 | "segment_image.load_model(segmentation_model_file)\n", 200 | "\n", 201 | "imagefiles = []\n", 202 | "for file in glob.glob(f\"{IMAGE_PATH}/*.jpg\"):\n", 203 | " imagefiles.append(file)\n", 204 | "\n", 205 | "print(imagefiles)\n", 206 | "print(\"num images to process = \", len(imagefiles))\n", 207 | "\n", 208 | "segmented_images = []\n", 209 | "for imagefile in imagefiles:\n", 210 | " print(imagefile)\n", 211 | " try:\n", 212 | " seg, _ = segment_image.segmentImage(\n", 213 | " imagefile,\n", 214 | " extract_segmented_objects=True,\n", 215 | " save_extracted_objects =False,\n", 216 | " show_bboxes=False,\n", 217 | " output_image_name=str(imagefile) + \"______.tiff\")\n", 218 | " except:\n", 219 | " print(\"Error encounted - skipping\")\n", 220 | " continue\n", 221 | "\n", 222 | " if not len(seg[\"extracted_objects\"]):\n", 223 | " print(\"Failed to segment\", imagefile)\n", 224 | " else:\n", 225 | " for result in seg[\"extracted_objects\"]:\n", 226 | " print(result.shape)\n", 227 | " segmented_image = result[..., ::-1].copy()\n", 228 | " segmented_images.append(segmented_image)\n", 229 | "\n", 230 | "with open(data_file, \"wb\") as f:\n", 231 | " np.save(f, segmented_images)\n", 232 | "print(\"Saved patch file to\", data_file)" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "colab": { 238 | "collapsed_sections": [], 239 | "name": "MakeSegmentedPatches.ipynb", 240 | "provenance": [] 241 | }, 242 | "kernelspec": { 243 | "display_name": "Python 3", 244 | "name": "python3" 245 | }, 246 | "language_info": { 247 | "name": "python" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 0 252 | } 253 | -------------------------------------------------------------------------------- /collage_patches/animals.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/collage_patches/animals.npy -------------------------------------------------------------------------------- /collage_patches/fruit.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/collage_patches/fruit.npy -------------------------------------------------------------------------------- /collage_patches/handwritten_mnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/collage_patches/handwritten_mnist.npy -------------------------------------------------------------------------------- /collage_patches/shore_glass.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/collage_patches/shore_glass.npy -------------------------------------------------------------------------------- /images/arnheim3_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/arnheim3_examples.png -------------------------------------------------------------------------------- /images/bulls_ballet_faces_nature.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/bulls_ballet_faces_nature.jpg -------------------------------------------------------------------------------- /images/chicken.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/chicken.png -------------------------------------------------------------------------------- /images/dancer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/dancer.png -------------------------------------------------------------------------------- /images/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/face.png -------------------------------------------------------------------------------- /images/fall_of_the_damned.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/fall_of_the_damned.jpg -------------------------------------------------------------------------------- /images/fruit_bowl_animals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/fruit_bowl_animals.png -------------------------------------------------------------------------------- /images/fruit_bowl_fruit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/fruit_bowl_fruit.png -------------------------------------------------------------------------------- /images/objects.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/objects.png -------------------------------------------------------------------------------- /images/swans_masked_transparency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/swans_masked_transparency.png -------------------------------------------------------------------------------- /images/waves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/arnheim/2607ea41607a11e44c3de6348edd13672f85c52a/images/waves.png --------------------------------------------------------------------------------