├── .gitignore ├── LICENSE.txt ├── NOTICE.txt ├── README.md ├── configs └── terra_4m.json ├── data_pipeline ├── README.md ├── __init__.py ├── bake_texture_maps.py ├── bake_texture_maps_internal.py ├── blur_images.py ├── generate_eval_set.py ├── generate_test_set.py ├── geotiff_to_png.py ├── image_dataset_to_parquet.py ├── isolate_uncorrupted_heightmaps.py ├── organize_heightmaps_into_folders.py ├── read_tiff.py ├── rename_dataset_images.py ├── rename_heightmaps.py ├── scrape_earthexplorer.py ├── split_heightmaps.py └── train_corrupted_heightmap_discriminator.py ├── fid ├── __init__.py ├── __main__.py └── inception.py ├── flake.lock ├── flake.nix ├── images └── display_heightmaps.png ├── models ├── common │ └── config_utils.py └── terra.py ├── sampling └── diffusion.py └── utilities ├── inception_test.py └── tf_to_onnx.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | wandb/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2023 Hayden Donnelly 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Terrain Generation 2 | Neural Terrain Generation (NTG) is a collection of generative neural networks that output heightmaps for 3D terrain. This repository contains code for the entire NTG training pipeline. 3 | 4 | ## Development Environment 5 | The NTG development environment is managed with Nix. You can follow the steps below to get started. 6 | 1. Install Nix with the [official installer](https://nixos.org/download/) or the [determinate installer](https://github.com/DeterminateSystems/nix-installer). 7 | 2. Enable the experimental Nix Flakes feature by adding the following line to ``~/.config/nix/nix.conf`` or ``/etc/nix/nix.conf`` 8 | (this step can be skipped if you installed Nix with the [determinate installer](https://github.com/DeterminateSystems/nix-installer)). 9 | ``` 10 | experimental-features = nix-command flakes 11 | ``` 12 | 3. Run the following command to open a development shell with all the dependencies installed. 13 | ``` 14 | nix develop --impure 15 | ``` 16 | -------------------------------------------------------------------------------- /configs/terra_4m.json: -------------------------------------------------------------------------------- 1 | { 2 | "embedding_dim": 32, 3 | "embedding_max_frequency": 1000.0, 4 | "num_features": [16, 32, 64, 128, 256], 5 | "num_groups": [16, 32, 32, 32, 32], 6 | "kernel_size": 3, 7 | "block_depth": 2, 8 | "output_channels": 1, 9 | "activation_fn": "silu", 10 | "weight_decay": 0.0001, 11 | "optimizer": "adamw", 12 | "adam_b1": 0.9, 13 | "adam_b2": 0.99, 14 | "adam_eps": 0.00000001, 15 | "adam_eps_root": 0.0, 16 | "lr_base": 0.00001, 17 | "lr_max": 0.0001, 18 | "lr_min": 0.00001, 19 | "lr_warmup_epochs": 2, 20 | "lr_decay_rate": 0.8, 21 | "lr_decay_epochs": 1, 22 | "image_size": 256, 23 | "epochs": 1000, 24 | "min_signal_rate": 0.01, 25 | "max_signal_rate": 0.99, 26 | "batch_size": 16, 27 | "adaptive_grad_clip": 2.0, 28 | "ema_decay": 0.999, 29 | "ema_warmup": 300, 30 | "dtype": "bfloat16", 31 | "param_dtype": "float32" 32 | } 33 | -------------------------------------------------------------------------------- /data_pipeline/README.md: -------------------------------------------------------------------------------- 1 | # Data Pipeline 2 | 3 | This folder contains the scripts required to create datasets for NTG. This file explains the functionality of each script, outlines the various datasets created with them, and explains how to recreate those datasets. 4 | 5 | ## Script Explanations 6 | 7 | ### scrape_earthexplorer 8 | 9 | - Logs in to [Earth Explorer](https://earthexplorer.usgs.gov/) and downloads the SRTM 1 arc-second dataset. This is a dataset of approximately 15k high-resolution GEOTIFF heightmaps. 10 | 11 | ### geotiff_to_png 12 | 13 | - Converts GEOTIFFs to PNGs with the rasterio library. Primarily used to convert the SRTM 1 arc-second dataset into a more readable form.
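  For reference, the core of the conversion is just reading the elevation band and re-saving it as a grayscale image. Below is a minimal sketch for a single tile; the tile name is a hypothetical placeholder, and ``geotiff_to_png.py`` (included later in this repository) is the actual batch version.
  ```python
  import rasterio
  import matplotlib.pyplot as plt

  # Read the first band (elevation) of a GeoTIFF heightmap and
  # save it as a grayscale PNG. The filename is a placeholder.
  with rasterio.open('n40_w106_1arc_v3.tif') as geotiff:
      plt.imsave('n40_w106_1arc_v3.png', geotiff.read(1), cmap='gray')
  ```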
14 | 15 | ### split_heightmaps 16 | 17 | - Splits each heightmap into 100 sub-heightmaps. Useful for transforming the SRTM 1 arc-second PNG dataset into something more suitable for training an image generation model. 18 | 19 | ### train_corrupted_heightmap_discriminator 20 | 21 | - The SRTM 1 arc-second PNG dataset contains a number of undesirable or otherwise corrupted heightmaps. These include heightmaps with very little terrain visible, heightmaps with padding issues, and heightmaps with only black and white values. This script trains a convolutional NN to discriminate between corrupted and uncorrupted heightmaps. A handpicked discrimination dataset is required for this. 22 | 23 | ### isolate_uncorrupted_heightmaps 24 | 25 | - Uses an NN (trained by the previous script) to identify and isolate uncorrupted heightmaps from a mixed dataset, thereby creating a new uncorrupted dataset. 26 | 27 | ### blur_images 28 | 29 | - Blurs images to create input for style transfer. Style transfer was never implemented, so this script was never used. 30 | 31 | ## Datasets 32 | 33 | ### SRTM 1 arc-second 34 | 35 | - This is the base dataset scraped from [Earth Explorer](https://earthexplorer.usgs.gov/). It contains approximately 15k high-resolution GEOTIFF heightmaps. 36 | - To "recreate" it, simply run ``scrape_earthexplorer.py``. 37 | 38 | ### SRTM 1 arc-second PNG 39 | 40 | - The SRTM 1 arc-second dataset converted into PNGs. 41 | - Use ``geotiff_to_png.py`` on SRTM 1 arc-second to recreate it. -------------------------------------------------------------------------------- /data_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/novaia/ntg/cb770a5dc7e1be81e2f143ec1a75f937c3b229c4/data_pipeline/__init__.py -------------------------------------------------------------------------------- /data_pipeline/bake_texture_maps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Opens a .blend file and executes code to bake texture maps from a procedural material.
3 | """ 4 | 5 | import bpy 6 | import os 7 | 8 | # Define the path of the blend file to open 9 | blend_file = os.path.join(os.path.dirname(bpy.data.filepath), 'example.blend') 10 | 11 | # Define the path of the script to execute 12 | script_file = os.path.join(os.path.dirname(bpy.data.filepath), 'bake_texture_maps_internal.py') 13 | 14 | # Define a function that executes the script after opening the blend file 15 | def execute_script(scene): 16 | # Unregister the handler 17 | bpy.app.handlers.load_post.remove(execute_script) 18 | # Execute the script 19 | exec(compile(open(script_file).read(), script_file, 'exec')) 20 | 21 | # Register the function as a persistent handler 22 | bpy.app.handlers.load_post.append(execute_script) 23 | 24 | # Open the blend file 25 | bpy.ops.wm.open_mainfile(filepath=blend_file) -------------------------------------------------------------------------------- /data_pipeline/bake_texture_maps_internal.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import os 3 | 4 | # Get the active object 5 | obj = bpy.context.active_object 6 | 7 | # Get the material of the object 8 | mat = obj.data.materials[0] 9 | 10 | # Enable nodes for the material 11 | mat.use_nodes = True 12 | 13 | # Create a texture node and link it to the material output 14 | tex_node = mat.node_tree.nodes.new('ShaderNodeTexImage') 15 | mat_output = mat.node_tree.nodes['Material Output'] 16 | mat.node_tree.links.new(tex_node.outputs['Color'], mat_output.inputs['Surface']) 17 | 18 | # Create a new image with the desired resolution and assign it to the texture node 19 | img = bpy.data.images.new(name=obj.name + '_Diffuse', width=4096, height=4096) 20 | tex_node.image = img 21 | 22 | # Select the texture node 23 | tex_node.select = True 24 | mat.node_tree.nodes.active = tex_node 25 | 26 | # Bake the diffuse map 27 | bpy.ops.object.bake(type='DIFFUSE', margin=16) 28 | 29 | # Save the image to a file in the same folder as the blend file 30 | img.filepath_raw = os.path.join(os.path.dirname(bpy.data.filepath), obj.name + '_Diffuse.png') 31 | img.file_format = 'PNG' 32 | img.save() 33 | 34 | # Create a new image with the desired resolution and assign it to the texture node 35 | img = bpy.data.images.new(name=obj.name + '_Normal', width=4096, height=4096) 36 | tex_node.image = img 37 | 38 | # Bake the normal map 39 | bpy.ops.object.bake(type='NORMAL', margin=16) 40 | 41 | # Save the image to a file in the same folder as the blend file 42 | img.filepath_raw = os.path.join(os.path.dirname(bpy.data.filepath), obj.name + '_Normal.png') 43 | img.file_format = 'PNG' 44 | img.save() -------------------------------------------------------------------------------- /data_pipeline/blur_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | input_directory = '/Users/Hayden/Desktop/ml/split_heightmaps/' 5 | output_directory = '../../Desktop/ml/blurred_heightmaps/' 6 | 7 | file_list = os.listdir(input_directory) 8 | 9 | for i in range(len(file_list)): 10 | image = cv2.imread('../../Desktop/ml/split_heightmaps/' + file_list[i]) 11 | resized_image = cv2.resize(image, (256, 256)) 12 | blurred_image = cv2.blur(resized_image, (40, 40)) 13 | cv2.imwrite(output_directory + 'blurred-' + file_list[i], blurred_image) -------------------------------------------------------------------------------- /data_pipeline/generate_eval_set.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os 4 | import random 5 | from tqdm import tqdm 6 | 7 | dataset_path = '../../heightmaps/uncorrupted_split_heightmaps_second_pass' 8 | out_path = '../../heightmaps/uncorrupted_split_heightmaps_second_pass_eval' 9 | num_samples = 4500 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | help_text = 'Path to original dataset.' 14 | parser.add_argument('--dataset_path', type=str, default=dataset_path, help=help_text) 15 | help_text = 'Path to output directory.' 16 | parser.add_argument('--out_path', type=str, default=out_path, help=help_text) 17 | help_text = 'Number of random samples to take from original dataset.' 18 | parser.add_argument('--num_samples', type=int, default=num_samples, help=help_text) 19 | help_text = 'Force script to run even if output directory is not empty.' 20 | parser.add_argument('--force_non_empty', type=bool, default=False, help=help_text) 21 | args = parser.parse_args() 22 | 23 | if os.path.exists(args.out_path) and not args.force_non_empty: 24 | out_path_is_empty = len(os.listdir(args.out_path)) == 0 25 | assertion_text = 'Out path is not empty and --force_non_empty is False' 26 | assert out_path_is_empty, assertion_text 27 | else: 28 | # exist_ok so --force_non_empty works when the directory already exists 29 | os.makedirs(args.out_path, exist_ok=True) 30 | 31 | assertion_text = 'Dataset path does not exist' 32 | assert os.path.exists(args.dataset_path), assertion_text 33 | 34 | dataset_list = os.listdir(args.dataset_path) 35 | assertion_text = 'Dataset has fewer files than num_samples' 36 | assert not len(dataset_list) < args.num_samples, assertion_text 37 | 38 | random_samples = random.sample(dataset_list, args.num_samples) 39 | print('Randomly sampled', args.num_samples, 'files from:', args.dataset_path) 40 | print('Copying samples to:', args.out_path) 41 | for sample in tqdm(random_samples): 42 | shutil.copy(os.path.join(args.dataset_path, sample), args.out_path) -------------------------------------------------------------------------------- /data_pipeline/generate_test_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | dataset_path = '../../heightmaps/world-heightmaps-01/train' 8 | output_path = '../../heightmaps/world-heightmaps-01/test' 9 | test_split = 0.05 10 | file_extension = '.png' 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | help_text = 'Path to dataset directory.' 15 | parser.add_argument('--dataset_path', type=str, default=dataset_path, help=help_text) 16 | help_text = 'Path to output directory for test set.' 17 | parser.add_argument('--output_path', type=str, default=output_path, help=help_text) 18 | help_text = 'Percentage of dataset to be reserved for test set.' 19 | parser.add_argument('--test_split', type=float, default=test_split, help=help_text) 20 | help_text = 'File extension of dataset files.'
21 | parser.add_argument('--file_extension', type=str, default=file_extension, help=help_text) 22 | args = parser.parse_args() 23 | 24 | file_list = [] 25 | 26 | for root, dirs, files in os.walk(args.dataset_path): 27 | for file in files: 28 | if file.endswith(args.file_extension): 29 | file_list.append(os.path.join(root, file)) 30 | print(file_list[0]) 31 | print(f'Found {len(file_list)} files.') 32 | test_set_size = int(len(file_list) * args.test_split) 33 | print(f'Moving {test_set_size} files to {args.output_path}.') 34 | 35 | selected_files = random.sample(file_list, test_set_size) 36 | for i in tqdm(range(len(selected_files))): 37 | file = selected_files[i] 38 | split_file_path = file.split('/') 39 | class_folder = split_file_path[-2] 40 | file_name = split_file_path[-1] 41 | 42 | folder_output_path = os.path.join(args.output_path, class_folder) 43 | if not os.path.exists(folder_output_path): 44 | os.makedirs(folder_output_path) 45 | 46 | file_output_path = os.path.join(folder_output_path, file_name) 47 | shutil.move(file, file_output_path) 48 | -------------------------------------------------------------------------------- /data_pipeline/geotiff_to_png.py: -------------------------------------------------------------------------------- 1 | import rasterio 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | input_directory = '/Users/Hayden/Desktop/geotiffs/' 6 | output_directory = 'D:/heightmap_pngs/' 7 | 8 | file_list = os.listdir(input_directory) 9 | 10 | for i in range(len(file_list)): 11 | image = rasterio.open(input_directory + file_list[i]) 12 | plt.imsave(output_directory + file_list[i][0:-4] + '.png', image.read(1), cmap = 'gray') -------------------------------------------------------------------------------- /data_pipeline/image_dataset_to_parquet.py: -------------------------------------------------------------------------------- 1 | import pyarrow as pa 2 | import pyarrow.parquet as pq 3 | import pandas as pd 4 | from PIL import Image 5 | import os 6 | import io 7 | import json 8 | 9 | samples_per_file = 10_000 10 | 11 | root_dir = 'data/datasets/world-heightmaps-256-v1' 12 | df = pd.read_csv(os.path.join(root_dir, 'metadata.csv')) 13 | df = df.sample(frac=1).reset_index(drop=True) 14 | 15 | print(df.head()) 16 | 17 | def save_table(image_data, table_number): 18 | print(f'Entries in table {table_number}: {len(image_data)}') 19 | schema = pa.schema( 20 | fields=[ 21 | ('heightmap', pa.struct([('bytes', pa.binary()), ('path', pa.string())])), 22 | ('latitude', pa.string()), 23 | ('longitude', pa.string()) 24 | ], 25 | metadata={ 26 | b'huggingface': json.dumps({ 27 | 'info': { 28 | 'features': { 29 | 'heightmap': {'_type': 'Image'}, 30 | 'latitude': {'_type': 'Value', 'dtype': 'string'}, 31 | 'longitude': {'_type': 'Value', 'dtype': 'string'} 32 | } 33 | } 34 | }).encode('utf-8') 35 | } 36 | ) 37 | 38 | table = pa.Table.from_pylist(image_data, schema=schema) 39 | pq.write_table(table, f'data/world-heightmaps-256-parquet/{str(table_number).zfill(4)}.parquet') 40 | 41 | image_data = [] 42 | samples_in_current_file = 0 43 | current_file_number = 0 44 | for i, row in df.iterrows(): 45 | if samples_in_current_file >= samples_per_file: 46 | save_table(image_data, current_file_number) 47 | image_data = [] 48 | samples_in_current_file = 0 49 | current_file_number += 1 50 | samples_in_current_file += 1 51 | image_path = row['file_name'] 52 | with Image.open(os.path.join(root_dir, image_path)) as image: 53 | image_bytes = io.BytesIO() 54 | image.save(image_bytes, format='PNG') 55 | image_dict = { 56 |
'heightmap': { 57 | 'bytes': image_bytes.getvalue(), 58 | 'path': image_path 59 | }, 60 | 'latitude': str(row['latitude']), 61 | 'longitude': str(row['longitude']) 62 | } 63 | image_data.append(image_dict) 64 | 65 | save_table(image_data, current_file_number) 66 | -------------------------------------------------------------------------------- /data_pipeline/isolate_uncorrupted_heightmaps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Needs PyTorch implementation. 3 | """ -------------------------------------------------------------------------------- /data_pipeline/organize_heightmaps_into_folders.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os 4 | from tqdm import tqdm 5 | 6 | dataset_path = '../../heightmaps/uncorrupted_split_heightmaps_second_pass' 7 | coordinate_list = [] 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | help_text = 'Path to dataset.' 12 | parser.add_argument('--dataset_path', default=dataset_path, type=str, help=help_text) 13 | args = parser.parse_args() 14 | assert os.path.exists(args.dataset_path), 'Dataset path does not exist' 15 | 16 | dataset_list = os.listdir(args.dataset_path) 17 | for image in tqdm(dataset_list): 18 | coordinate_string = image[:8] 19 | destination_path = os.path.join(args.dataset_path, coordinate_string) 20 | if coordinate_string not in coordinate_list: 21 | coordinate_list.append(coordinate_string) 22 | os.makedirs(os.path.join(args.dataset_path, coordinate_string)) 23 | shutil.move(os.path.join(args.dataset_path, image), destination_path) 24 | -------------------------------------------------------------------------------- /data_pipeline/read_tiff.py: -------------------------------------------------------------------------------- 1 | import os, argparse, glob 2 | 3 | def main(): 4 | os.environ['GDAL_PAM_ENABLED'] = 'NO' 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--input_path', type=str, required=True) 7 | parser.add_argument('--output_path', type=str, required=True) 8 | parser.add_argument('--temp_tif', type=str, required=True) 9 | args = parser.parse_args() 10 | 11 | if not os.path.exists(args.output_path): 12 | os.makedirs(args.output_path) 13 | 14 | input_paths = glob.glob(f'{args.input_path}/*/*.tif') 15 | #input_paths = os.listdir(args.input_path) 16 | for input_tif in input_paths: 17 | print(input_tif) 18 | fillnodata_cmd = f'gdal_fillnodata.py -of GTiff -md 100 {input_tif} {args.temp_tif}' 19 | os.system(fillnodata_cmd) 20 | 21 | output_png = os.path.join(args.output_path, f'{os.path.basename(input_tif)[:-4]}_2.png') 22 | translate_cmd = f'gdal_translate -of PNG -scale {args.temp_tif} {output_png}' 23 | os.system(translate_cmd) 24 | 25 | os.remove(args.temp_tif) 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /data_pipeline/rename_dataset_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | 5 | dataset_path = '../../heightmaps/uncorrupted_split_heightmaps_second_pass' 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | help_text = 'Path to dataset.'
10 | parser.add_argument('--dataset_path', default=dataset_path, type=str, help=help_text) 11 | args = parser.parse_args() 12 | assert os.path.exists(args.dataset_path), 'Dataset path does not exist' 13 | 14 | dataset_list = os.listdir(args.dataset_path) 15 | for directory in tqdm(dataset_list): 16 | directory_path = os.path.join(args.dataset_path, directory) 17 | directory_list = os.listdir(directory_path) 18 | for i, image in enumerate(directory_list): 19 | old_name = os.path.join(directory_path, image) 20 | new_name = os.path.join(directory_path, str(i) + '.png') 21 | os.rename(old_name, new_name) -------------------------------------------------------------------------------- /data_pipeline/rename_heightmaps.py: -------------------------------------------------------------------------------- 1 | # This script is meant to be used for heightmaps that have already been 2 | # organized into longitude and latitude folders by 3 | # organize_heightmaps_into_folders.py. This script will remove redundant 4 | # .png postfixes from the heightmap images and rename them to their slice ID. 5 | # The slice ID is the 1-dimensional coordinate of the heightmap with respect to 6 | # its latitude and longitude. There were a maximum of 100 slices per 7 | # latitude/longitude, so the slice ID is a number between 0 and 99. 8 | 9 | import argparse 10 | import os 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | help_text = 'Path to directory containing heightmaps.' 15 | parser.add_argument('--directory', type=str, help=help_text) 16 | help_text = 'Postfix to remove from heightmap filenames.' 17 | parser.add_argument('--postfix_to_remove', type=str, default='.png.png.png', help=help_text) 18 | args = parser.parse_args() 19 | 20 | assert os.path.exists(args.directory), 'Directory does not exist' 21 | 22 | for folder in os.listdir(args.directory): 23 | folder_path = os.path.join(args.directory, folder) 24 | for filename in os.listdir(folder_path): 25 | if filename.endswith(args.postfix_to_remove): 26 | new_filename = filename[:-len(args.postfix_to_remove)] 27 | slice_id = new_filename.split('_')[-1] 28 | if slice_id == 'squared': 29 | new_filename = 'whole' 30 | else: 31 | new_filename = str(slice_id) 32 | new_filename += '.png' 33 | old_path = os.path.join(folder_path, filename) 34 | new_path = os.path.join(folder_path, new_filename) 35 | os.rename(old_path, new_path) -------------------------------------------------------------------------------- /data_pipeline/scrape_earthexplorer.py: -------------------------------------------------------------------------------- 1 | from selenium import webdriver 2 | from selenium.webdriver.common.by import By 3 | from selenium.webdriver.common.keys import Keys 4 | from selenium.webdriver.support.wait import WebDriverWait 5 | from selenium.webdriver.support import expected_conditions as EC 6 | import time 7 | 8 | username = '' 9 | password = '' 10 | 11 | driver = webdriver.Chrome() 12 | driver.implicitly_wait(1000) 13 | 14 | # Login. 15 | driver.get('https://ers.cr.usgs.gov/login') 16 | 17 | username_input = driver.find_element(By.NAME, 'username') 18 | username_input.send_keys(username) 19 | 20 | password_input = driver.find_element(By.NAME, 'password') 21 | password_input.send_keys(password) 22 | 23 | login_button = driver.find_element(By.ID, 'loginButton') 24 | login_button.click() 25 | 26 | # Navigate to SRTM 1 Arc-Second results.
27 | driver.get('https://earthexplorer.usgs.gov/') 28 | 29 | datasets_tab = driver.find_element(By.ID, 'tab2') 30 | datasets_tab.click() 31 | print('Started loading datasets tab.') 32 | 33 | digital_elevation_li = driver.find_element(By.ID, 'cat_207') 34 | digital_elevation_expander = digital_elevation_li.find_element(By.CLASS_NAME, 'folder') 35 | digital_elevation_expander.click() 36 | 37 | srtm_li = driver.find_element(By.ID, 'cat_1103') 38 | srtm_expander = srtm_li.find_element(By.CLASS_NAME, 'folder') 39 | srtm_expander.click() 40 | 41 | one_arcsecond_checkbox = driver.find_element(By.ID, 'coll_5e83a3ee1af480c5') 42 | one_arcsecond_checkbox.click() 43 | print('Selected SRTM 1 arc-second.') 44 | 45 | results_tab = driver.find_element(By.ID, 'tab4') 46 | results_tab.click() 47 | print('Started loading results tab.') 48 | 49 | # Download SRTM heightmaps. 50 | for k in range(1, 1428): 51 | page_selector = driver.find_element(By.ID, 'pageSelector_5e83a3ee1af480c5_F') 52 | 53 | if int(page_selector.get_attribute('value')) != k: 54 | print('Waiting for page ' + str(k) + ' to load.') 55 | while int(page_selector.get_attribute('value')) != k: 56 | time.sleep(0.5) 57 | print('Page ' + str(k) + ' has loaded.') 58 | 59 | download_options_buttons = driver.find_elements(By.CLASS_NAME, 'download') 60 | 61 | for i in range(len(download_options_buttons)): 62 | current_result_number = (k - 1) * len(download_options_buttons) + i + 1 63 | current_result_number_string = str(current_result_number) 64 | 65 | download_options_buttons[i].click() 66 | print('Opened download menu for result ' + current_result_number_string + '.') 67 | 68 | download_options_container = driver.find_element(By.ID, 'optionsContainer') 69 | geotiff_download_button = WebDriverWait(driver, 10).until(EC.element_to_be_clickable((By.XPATH, '/html/body/div[7]/div[2]/div/div[2]/div[3]/div[1]/button'))) 70 | #download_buttons = download_options_container.find_elements(By.CLASS_NAME, 'downloadButtons') 71 | #geotiff_download_button = download_buttons[2] 72 | geotiff_download_button.click() 73 | print('Started downloading result ' + current_result_number_string + '.') 74 | 75 | close_button = driver.find_element(By.XPATH, '/html/body/div[7]/div[1]/button') 76 | close_button.click() 77 | print('Closed download menu for result ' + current_result_number_string + '.') 78 | 79 | page_selector = driver.find_element(By.ID, 'pageSelector_5e83a3ee1af480c5_F') 80 | page_selector.send_keys(Keys.DELETE, Keys.DELETE, Keys.DELETE, Keys.DELETE) 81 | page_selector.send_keys(str(k + 1)) 82 | page_selector.send_keys(Keys.RETURN) 83 | print('Started loading page ' + str(k + 1) + '.') 84 | time.sleep(10) -------------------------------------------------------------------------------- /data_pipeline/split_heightmaps.py: -------------------------------------------------------------------------------- 1 | from split_image import split_image 2 | import os 3 | 4 | input_directory = 'D:/heightmap_pngs/' 5 | output_directory = 'D:/split_heightmaps/' 6 | 7 | file_list = os.listdir(input_directory) 8 | 9 | for i in range(len(file_list)): 10 | split_image(input_directory + file_list[i], 10, 10, should_square=True, should_cleanup=False, output_dir=output_directory) 11 | -------------------------------------------------------------------------------- /data_pipeline/train_corrupted_heightmap_discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Needs PyTorch implementation.
3 | """ -------------------------------------------------------------------------------- /fid/__init__.py: -------------------------------------------------------------------------------- 1 | # Frechet Inception Distance. 2 | # Mostly taken from: https://github.com/matthias-wright/jax-fid/tree/main 3 | # License: https://github.com/matthias-wright/jax-fid/blob/main/LICENSE 4 | # The code in this file was modified from the original. 5 | 6 | import jax 7 | import functools 8 | from tqdm import tqdm 9 | import os 10 | import numpy as np 11 | import jax.numpy as jnp 12 | from fid.inception import InceptionV3 13 | import scipy 14 | 15 | def compute_statistics_with_mmap( 16 | params, apply_fn, num_batches, batch_size, 17 | get_batch_fn, filename, dtype, num_activations 18 | ): 19 | activation_dim = 2048 20 | mm = np.memmap(filename, dtype=dtype, mode='w+', shape=(num_activations, activation_dim)) 21 | 22 | activation_sum = np.zeros((activation_dim)) 23 | for i in tqdm(range(num_batches)): 24 | x = get_batch_fn(seed = i) 25 | x = np.asarray(x) 26 | x = 2 * x - 1 27 | activation_batch = apply_fn(params, jax.lax.stop_gradient(x)) 28 | activation_batch = activation_batch.squeeze(axis=1).squeeze(axis=1) 29 | 30 | current_batch_size = activation_batch.shape[0] 31 | start_index = i * batch_size 32 | end_index = start_index + current_batch_size 33 | mm[start_index : end_index] = activation_batch 34 | 35 | activation_sum += activation_batch.sum(axis=0) 36 | 37 | mu = activation_sum / num_activations 38 | sigma = np.cov(mm, rowvar=False) 39 | 40 | return mu, sigma 41 | 42 | def compute_statistics(params, apply_fn, num_batches, get_batch_fn): 43 | activations = [] 44 | 45 | for i in tqdm(range(num_batches)): 46 | x = get_batch_fn(seed = i) 47 | x = np.asarray(x) 48 | x = 2 * x - 1 49 | pred = apply_fn(params, jax.lax.stop_gradient(x)) 50 | activations.append(pred.squeeze(axis=1).squeeze(axis=1)) 51 | activations = jnp.concatenate(activations, axis=0) 52 | 53 | mu = np.mean(activations, axis=0) 54 | sigma = np.cov(activations, rowvar=False) 55 | return mu, sigma 56 | 57 | def load_statistics(path): 58 | stats = np.load(path) 59 | mu, sigma = stats["mu"], stats["sigma"] 60 | return mu, sigma 61 | 62 | def save_statistics(path, mu, sigma): 63 | np.savez(path, mu=mu, sigma=sigma) 64 | 65 | # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 66 | def compute_frechet_distance(mu1, mu2, sigma1, sigma2, eps=1e-6): 67 | mu1 = np.atleast_1d(mu1) 68 | mu2 = np.atleast_1d(mu2) 69 | sigma1 = np.atleast_1d(sigma1) 70 | sigma2 = np.atleast_1d(sigma2) 71 | 72 | assertion_text = f'mu shapes must be the same but are {mu1.shape} and {mu2.shape}' 73 | assert mu1.shape == mu2.shape, assertion_text 74 | assertion_text = f'sigma shapes must be the same but are {sigma1.shape} and {sigma2.shape}' 75 | assert sigma1.shape == sigma2.shape, assertion_text 76 | 77 | diff = mu1 - mu2 78 | 79 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 80 | if not np.isfinite(covmean).all(): 81 | msg = ( 82 | "fid calculation produces singular product; " 83 | "adding %s to diagonal of cov estimates" 84 | ) % eps 85 | print(msg) 86 | offset = np.eye(sigma1.shape[0]) * eps 87 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 88 | 89 | # Numerical error might give slight imaginary component. 
90 | if np.iscomplexobj(covmean): 91 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 92 | m = np.max(np.abs(covmean.imag)) 93 | raise ValueError("Imaginary component {}".format(m)) 94 | covmean = covmean.real 95 | 96 | tr_covmean = np.trace(covmean) 97 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 98 | 99 | def get_inception_model(): 100 | rng = jax.random.PRNGKey(0) 101 | model = InceptionV3() 102 | params = model.init(rng, jnp.ones((1, 256, 256, 3))) 103 | apply_fn = jax.jit(functools.partial(model.apply, train=False)) 104 | return params, apply_fn 105 | 106 | def preprocessing_function(image): 107 | image = image.astype(float) / 255 108 | return image 109 | 110 | # Checks if FID has been set up correctly. 111 | # Useful to catch errors before training. 112 | def check_for_correct_setup(fid_stats_path=None): 113 | assertion_text = 'Must specify path to FID stats file if using FID.' 114 | assert fid_stats_path is not None, assertion_text 115 | assertion_text = f'FID stats file does not exist at {fid_stats_path}' 116 | assert os.path.isfile(fid_stats_path), assertion_text 117 | assertion_text = 'FID stats file must be a .npz file.' 118 | assert fid_stats_path.endswith('.npz'), assertion_text 119 | inception = get_inception_model() 120 | assertion_text = 'Failed to load Inception model.' 121 | assert inception is not None, assertion_text -------------------------------------------------------------------------------- /fid/__main__.py: -------------------------------------------------------------------------------- 1 | # Frechet Inception Distance. 2 | # Mostly taken from: https://github.com/matthias-wright/jax-fid/tree/main 3 | # License: https://github.com/matthias-wright/jax-fid/blob/main/LICENSE 4 | # The code in this file was modified from the original. 5 | 6 | import argparse 7 | import fid 8 | import os 9 | from keras.preprocessing.image import ImageDataGenerator 10 | 11 | # Wraps the compute statistics functions and decides which one to call based on whether mmap is True.
12 | def _compute_statistics_wrapper( 13 | params, apply_fn, num_batches, get_batch_fn, num_activations, args 14 | ): 15 | if args.mmap: 16 | mu, sigma = fid.compute_statistics_with_mmap( 17 | params = params, 18 | apply_fn = apply_fn, 19 | num_batches = num_batches, 20 | batch_size = args.batch_size, 21 | get_batch_fn = get_batch_fn, 22 | filename = args.mmap_filename, 23 | dtype = 'float32', 24 | num_activations = num_activations 25 | ) 26 | else: 27 | mu, sigma = fid.compute_statistics(params, apply_fn, num_batches, get_batch_fn) 28 | return mu, sigma 29 | 30 | def _get_directory_iterator(path, args, data_generator): 31 | directory_iterator = data_generator.flow_from_directory( 32 | path, 33 | target_size = args.img_size, 34 | batch_size = args.batch_size, 35 | color_mode = 'rgb', 36 | classes = [''] 37 | ) 38 | return directory_iterator 39 | 40 | def _precompute_and_save_statistics(args, params, apply_fn, data_generator): 41 | error_text = 'img_dir must be specified if precompute is True' 42 | assert args.img_dir is not None, error_text 43 | error_text = 'out_dir must be specified if precompute is True' 44 | assert args.out_dir is not None, error_text 45 | 46 | directory_iterator = _get_directory_iterator(args.img_dir, args, data_generator) 47 | mu, sigma = _compute_statistics_wrapper( 48 | params = params, 49 | apply_fn = apply_fn, 50 | num_batches = len(directory_iterator), 51 | get_batch_fn = lambda seed: directory_iterator.next()[0], # seed is unused here; the iterator simply advances 52 | num_activations = directory_iterator.samples, 53 | args = args 54 | ) 55 | 56 | os.makedirs(args.out_dir, exist_ok=True) 57 | fid.save_statistics(os.path.join(args.out_dir, args.out_name), mu=mu, sigma=sigma) 58 | print( 59 | 'Saved pre-computed statistics at:', 60 | os.path.join(args.out_dir, args.out_name + '.npz') 61 | ) 62 | 63 | def _get_statistics_and_compute_fid(args, params, apply_fn, data_generator): 64 | error_text = 'path1 must be specified if precompute is False' 65 | assert args.path1 is not None, error_text 66 | error_text = 'path2 must be specified if precompute is False' 67 | assert args.path2 is not None, error_text 68 | 69 | if args.path1.endswith('.npz'): 70 | mu1, sigma1 = fid.load_statistics(args.path1) 71 | else: 72 | directory_iterator1 = _get_directory_iterator(args.path1, args, data_generator) 73 | mu1, sigma1 = _compute_statistics_wrapper( 74 | params = params, 75 | apply_fn = apply_fn, 76 | num_batches = len(directory_iterator1), 77 | get_batch_fn = lambda seed: directory_iterator1.next()[0], 78 | num_activations = directory_iterator1.samples, 79 | args = args 80 | ) 81 | 82 | if args.path2.endswith('.npz'): 83 | mu2, sigma2 = fid.load_statistics(args.path2) 84 | else: 85 | directory_iterator2 = _get_directory_iterator(args.path2, args, data_generator) 86 | mu2, sigma2 = _compute_statistics_wrapper( 87 | params = params, 88 | apply_fn = apply_fn, 89 | num_batches = len(directory_iterator2), 90 | get_batch_fn = lambda seed: directory_iterator2.next()[0], 91 | num_activations = directory_iterator2.samples, 92 | args = args 93 | ) 94 | 95 | frechet_distance = fid.compute_frechet_distance(mu1, mu2, sigma1, sigma2) 96 | print('FID:', frechet_distance) 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | help_text = 'Path to image directory or .npz file containing pre-computed statistics.' 101 | parser.add_argument('--path1', type=str, help=help_text) 102 | help_text = 'Path to image directory or .npz file containing pre-computed statistics.'
103 | parser.add_argument('--path2', type=str, help=help_text) 104 | help_text = 'Batch size per device for computing the Inception activations.' 105 | parser.add_argument('--batch_size', type=int, default=50, help=help_text) 106 | help_text = 'Resize images to this size. The format is (height, width).' 107 | parser.add_argument('--img_size', type=int, nargs=2, help=help_text) 108 | help_text = 'If True, pre-compute statistics for given image directory.' 109 | parser.add_argument('--precompute', type=bool, default=False, help=help_text) 110 | help_text = 'Path to image directory for pre-computing statistics.' 111 | parser.add_argument('--img_dir', type=str, help=help_text) 112 | help_text = 'Path where pre-computed statistics are stored.' 113 | parser.add_argument('--out_dir', type=str, help=help_text) 114 | help_text = 'Name of outputted statistics file.' 115 | parser.add_argument('--out_name', type=str, default='stats', help=help_text) 116 | help_text = 'If True, use mmap to compute statistics. Helpful for large datasets.' 117 | parser.add_argument('--mmap', type=bool, default=True, help=help_text) 118 | mmap_filename = 'data/temp/mmap_file' 119 | help_text = 'Name for mmap file. Only used if mmap is True.' 120 | parser.add_argument('--mmap_filename', type=str, default=mmap_filename, help=help_text) 121 | args = parser.parse_args() 122 | 123 | params, apply_fn = fid.get_inception_model() 124 | idg = ImageDataGenerator(preprocessing_function = fid.preprocessing_function) 125 | 126 | if args.precompute: 127 | _precompute_and_save_statistics(args, params, apply_fn, idg) 128 | else: 129 | _get_statistics_and_compute_fid(args, params, apply_fn, idg) -------------------------------------------------------------------------------- /fid/inception.py: -------------------------------------------------------------------------------- 1 | # Inception for calculating FID. 2 | # Doesn't include head and only loads precomputed weights. 3 | # Mostly taken from: https://github.com/matthias-wright/jax-fid/tree/main 4 | # License: https://github.com/matthias-wright/jax-fid/blob/main/LICENSE 5 | # The code in this file was modified from the original. 
6 | 7 | import flax.linen as nn 8 | import jax.numpy as jnp 9 | import jax 10 | from jax import lax 11 | from jax.nn import initializers 12 | import pickle 13 | from typing import Callable, Iterable, Optional, Tuple, Union, Any 14 | import os 15 | 16 | PRNGKey = Any 17 | Array = Any 18 | Shape = Tuple[int] 19 | Dtype = Any 20 | 21 | class InceptionV3(nn.Module): 22 | checkpoint_path: str = 'data/inception_v3_fid.pickle' 23 | dtype: str = 'float32' 24 | num_classes: int = 1000 25 | 26 | def setup(self): 27 | assert os.path.isfile(self.checkpoint_path), 'Inception checkpoint not found' 28 | self.params_dict = pickle.load(open(self.checkpoint_path, 'rb')) 29 | 30 | # when I left off I was testing train (bool) values 31 | # TODO: here 32 | @nn.compact 33 | def __call__(self, x, train=True): 34 | x = BasicConv2d( 35 | out_channels=32, 36 | kernel_size=(3, 3), 37 | strides=(2, 2), 38 | params_dict=get_from_dict(self.params_dict, 'Conv2d_1a_3x3'), 39 | dtype=self.dtype 40 | )(x, train) 41 | 42 | x = BasicConv2d( 43 | out_channels=32, 44 | kernel_size=(3, 3), 45 | params_dict=get_from_dict(self.params_dict, 'Conv2d_2a_3x3'), 46 | dtype=self.dtype 47 | )(x, train) 48 | 49 | x = BasicConv2d( 50 | out_channels=64, 51 | kernel_size=(3, 3), 52 | padding=((1, 1), (1, 1)), 53 | params_dict=get_from_dict(self.params_dict, 'Conv2d_2b_3x3'), 54 | dtype=self.dtype 55 | )(x, train) 56 | 57 | x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) 58 | x = BasicConv2d( 59 | out_channels=80, 60 | kernel_size=(1, 1), 61 | params_dict=get_from_dict(self.params_dict, 'Conv2d_3b_1x1'), 62 | dtype=self.dtype 63 | )(x, train) 64 | 65 | x = BasicConv2d( 66 | out_channels=192, 67 | kernel_size=(3, 3), 68 | params_dict=get_from_dict(self.params_dict, 'Conv2d_4a_3x3'), 69 | dtype=self.dtype 70 | )(x, train) 71 | 72 | x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) 73 | x = InceptionA( 74 | pool_features=32, 75 | params_dict=get_from_dict(self.params_dict, 'Mixed_5b'), 76 | dtype=self.dtype 77 | )(x, train) 78 | 79 | x = InceptionA( 80 | pool_features=64, 81 | params_dict=get_from_dict(self.params_dict, 'Mixed_5c'), 82 | dtype=self.dtype 83 | )(x, train) 84 | 85 | x = InceptionA( 86 | pool_features=64, 87 | params_dict=get_from_dict(self.params_dict, 'Mixed_5d'), 88 | dtype=self.dtype 89 | )(x, train) 90 | 91 | x = InceptionB( 92 | params_dict=get_from_dict(self.params_dict, 'Mixed_6a'), 93 | dtype=self.dtype 94 | )(x, train) 95 | 96 | x = InceptionC( 97 | channels_7x7=128, 98 | params_dict=get_from_dict(self.params_dict, 'Mixed_6b'), 99 | dtype=self.dtype 100 | )(x, train) 101 | 102 | x = InceptionC( 103 | channels_7x7=160, 104 | params_dict=get_from_dict(self.params_dict, 'Mixed_6c'), 105 | dtype=self.dtype 106 | )(x, train) 107 | 108 | x = InceptionC( 109 | channels_7x7=160, 110 | params_dict=get_from_dict(self.params_dict, 'Mixed_6d'), 111 | dtype=self.dtype 112 | )(x, train) 113 | 114 | x = InceptionC( 115 | channels_7x7=192, 116 | params_dict=get_from_dict(self.params_dict, 'Mixed_6e'), 117 | dtype=self.dtype 118 | )(x, train) 119 | 120 | x = InceptionD( 121 | params_dict=get_from_dict(self.params_dict, 'Mixed_7a'), 122 | dtype=self.dtype 123 | )(x, train) 124 | 125 | x = InceptionE( 126 | avg_pool, params_dict=get_from_dict(self.params_dict, 'Mixed_7b'), 127 | dtype=self.dtype 128 | )(x, train) 129 | 130 | # Following the implementation by @mseitzer, we use max pooling instead 131 | # of average pooling here. 
132 | # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320 133 | x = InceptionE( 134 | nn.max_pool, params_dict=get_from_dict(self.params_dict, 'Mixed_7c'), 135 | dtype=self.dtype 136 | )(x, train) 137 | x = jnp.mean(x, axis=(1, 2), keepdims=True) 138 | return x 139 | 140 | class Dense(nn.Module): 141 | features: int 142 | params_dict: dict=None 143 | dtype: str='float32' 144 | 145 | @nn.compact 146 | def __call__(self, x): 147 | x = nn.Dense( 148 | features=self.features, 149 | kernel_init = lambda *_ : jnp.array(get_from_dict(self.params_dict, 'kernel')), 150 | bias_init = lambda *_ : jnp.array(get_from_dict(self.params_dict, 'bias')) 151 | )(x) 152 | 153 | return x 154 | 155 | class BasicConv2d(nn.Module): 156 | out_channels: int 157 | kernel_size: Union[int, Iterable[int]]=(3, 3) 158 | strides: Optional[Iterable[int]]=(1, 1) 159 | padding: Union[str, Iterable[Tuple[int, int]]]='valid' 160 | use_bias: bool=False 161 | params_dict: dict=None 162 | dtype: str='float32' 163 | 164 | @nn.compact 165 | def __call__(self, x, train=True): 166 | x = nn.Conv( 167 | features = self.out_channels, 168 | kernel_size = self.kernel_size, 169 | strides = self.strides, 170 | padding = self.padding, 171 | use_bias = self.use_bias, 172 | kernel_init = lambda *_ : jnp.array(get_from_dict(self.params_dict['conv'], 'kernel')), 173 | bias_init = lambda *_ : jnp.array(get_from_dict(self.params_dict['conv'], 'bias')), 174 | dtype=self.dtype 175 | )(x) 176 | 177 | x = BatchNorm( 178 | epsilon=0.001, 179 | momentum=0.1, 180 | bias_init = lambda *_ : jnp.array(self.params_dict['bn']['bias']), 181 | scale_init = lambda *_ : jnp.array(self.params_dict['bn']['scale']), 182 | mean_init = lambda *_ : jnp.array(self.params_dict['bn']['mean']), 183 | var_init = lambda *_ : jnp.array(self.params_dict['bn']['var']), 184 | use_running_average=not train, 185 | dtype=self.dtype 186 | )(x) 187 | 188 | x = nn.relu(x) 189 | return x 190 | 191 | # Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py 192 | class BatchNorm(nn.Module): 193 | use_running_average: Optional[bool] = None 194 | axis: int = -1 195 | momentum: float = 0.99 196 | epsilon: float = 1e-5 197 | dtype: Dtype = jnp.float32 198 | use_bias: bool = True 199 | use_scale: bool = True 200 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros 201 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones 202 | mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32) 203 | var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32) 204 | axis_name: Optional[str] = None 205 | axis_index_groups: Any = None 206 | 207 | @nn.compact 208 | def __call__(self, x, use_running_average: Optional[bool] = None): 209 | use_running_average = nn.module.merge_param( 210 | 'use_running_average', 211 | self.use_running_average, 212 | use_running_average 213 | ) 214 | x = jnp.asarray(x, jnp.float32) 215 | axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) 216 | axis = absolute_dims(x.ndim, axis) 217 | feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) 218 | reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) 219 | reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) 220 | 221 | # see NOTE above on initialization behavior 222 | initializing = self.is_mutable_collection('params') 223 | 224 | ra_mean = self.variable( 225 | 'batch_stats', 226 | 'mean', 227 | self.mean_init, 228 | 
reduced_feature_shape 229 | ) 230 | ra_var = self.variable( 231 | 'batch_stats', 232 | 'var', 233 | self.var_init, 234 | reduced_feature_shape 235 | ) 236 | 237 | if use_running_average: 238 | mean, var = ra_mean.value, ra_var.value 239 | else: 240 | mean = jnp.mean(x, axis=reduction_axis, keepdims=False) 241 | mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) 242 | if self.axis_name is not None and not initializing: 243 | concatenated_mean = jnp.concatenate([mean, mean2]) 244 | mean, mean2 = jnp.split( 245 | lax.pmean( 246 | concatenated_mean, 247 | axis_name=self.axis_name, 248 | axis_index_groups=self.axis_index_groups), 2) 249 | var = mean2 - lax.square(mean) 250 | 251 | if not initializing: 252 | ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean 253 | ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var 254 | 255 | y = x - mean.reshape(feature_shape) 256 | mul = lax.rsqrt(var + self.epsilon) 257 | if self.use_scale: 258 | scale = self.param( 259 | 'scale', 260 | self.scale_init, 261 | reduced_feature_shape 262 | ).reshape(feature_shape) 263 | mul = mul * scale 264 | y = y * mul 265 | if self.use_bias: 266 | bias = self.param( 267 | 'bias', 268 | self.bias_init, 269 | reduced_feature_shape 270 | ).reshape(feature_shape) 271 | y = y + bias 272 | return jnp.asarray(y, self.dtype) 273 | 274 | # Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py 275 | def pool(inputs, init, reduce_fn, window_shape, strides, padding): 276 | strides = strides or (1,) * len(window_shape) 277 | assert len(window_shape) == len(strides), ( 278 | f"len({window_shape}) == len({strides})") 279 | strides = (1,) + strides + (1,) 280 | dims = (1,) + window_shape + (1,) 281 | 282 | is_single_input = False 283 | if inputs.ndim == len(dims) - 1: 284 | # Add singleton batch dimension because lax.reduce_window always 285 | # needs a batch dimension. 
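# For example, a single (H, W, C) image becomes (1, H, W, C).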
286 | inputs = inputs[None] 287 | is_single_input = True 288 | 289 | assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" 290 | if not isinstance(padding, str): 291 | padding = tuple(map(tuple, padding)) 292 | assert(len(padding) == len(window_shape)), ( 293 | f"padding {padding} must specify pads for same number of dims as " 294 | f"window_shape {window_shape}") 295 | assert(all([len(x) == 2 for x in padding])), ( 296 | f"each entry in padding {padding} must be length 2") 297 | padding = ((0,0),) + padding + ((0,0),) 298 | y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) 299 | if is_single_input: 300 | y = jnp.squeeze(y, axis=0) 301 | return y 302 | 303 | def avg_pool(inputs, window_shape, strides=None, padding='VALID'): 304 | assert inputs.ndim == 4 305 | assert len(window_shape) == 2 306 | 307 | y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding) 308 | ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype) 309 | counts = jax.lax.conv_general_dilated( 310 | ones, 311 | jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)), 312 | window_strides=(1, 1), 313 | padding=((1, 1), (1, 1)), 314 | dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape), 315 | feature_group_count=1 316 | ) 317 | y = y / counts 318 | return y 319 | 320 | class InceptionA(nn.Module): 321 | pool_features: int 322 | params_dict: dict=None 323 | dtype: str='float32' 324 | 325 | @nn.compact 326 | def __call__(self, x, train=True): 327 | branch1x1 = BasicConv2d( 328 | out_channels=64, 329 | kernel_size=(1, 1), 330 | params_dict=get_from_dict(self.params_dict, 'branch1x1'), 331 | dtype=self.dtype 332 | )(x, train) 333 | 334 | branch5x5 = BasicConv2d( 335 | out_channels=48, 336 | kernel_size=(1, 1), 337 | params_dict=get_from_dict(self.params_dict, 'branch5x5_1'), 338 | dtype=self.dtype 339 | )(x, train) 340 | 341 | branch5x5 = BasicConv2d( 342 | out_channels=64, 343 | kernel_size=(5, 5), 344 | padding=((2, 2), (2, 2)), 345 | params_dict=get_from_dict(self.params_dict, 'branch5x5_2'), 346 | dtype=self.dtype 347 | )(branch5x5, train) 348 | 349 | branch3x3dbl = BasicConv2d( 350 | out_channels=64, 351 | kernel_size=(1, 1), 352 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_1'), 353 | dtype=self.dtype 354 | )(x, train) 355 | 356 | branch3x3dbl = BasicConv2d( 357 | out_channels=96, 358 | kernel_size=(3, 3), 359 | padding=((1, 1), (1, 1)), 360 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_2'), 361 | dtype=self.dtype 362 | )(branch3x3dbl, train) 363 | 364 | branch3x3dbl = BasicConv2d( 365 | out_channels=96, 366 | kernel_size=(3, 3), 367 | padding=((1, 1), (1, 1)), 368 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_3'), 369 | dtype=self.dtype 370 | )(branch3x3dbl, train) 371 | 372 | branch_pool = avg_pool( 373 | x, 374 | window_shape=(3, 3), 375 | strides=(1, 1), 376 | padding=((1, 1), (1, 1)) 377 | ) 378 | 379 | branch_pool = BasicConv2d( 380 | out_channels=self.pool_features, 381 | kernel_size=(1, 1), 382 | params_dict=get_from_dict(self.params_dict, 'branch_pool'), 383 | dtype=self.dtype 384 | )(branch_pool, train) 385 | 386 | output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1) 387 | return output 388 | 389 | class InceptionB(nn.Module): 390 | params_dict: dict=None 391 | dtype: str='float32' 392 | 393 | @nn.compact 394 | def __call__(self, x, train=True): 395 | branch3x3 = BasicConv2d( 396 | out_channels=384, 397 | kernel_size=(3, 3), 398 
| strides=(2, 2), 399 | params_dict=get_from_dict(self.params_dict, 'branch3x3'), 400 | dtype=self.dtype 401 | )(x, train) 402 | 403 | branch3x3dbl = BasicConv2d( 404 | out_channels=64, 405 | kernel_size=(1, 1), 406 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_1'), 407 | dtype=self.dtype 408 | )(x, train) 409 | 410 | branch3x3dbl = BasicConv2d( 411 | out_channels=96, 412 | kernel_size=(3, 3), 413 | padding=((1, 1), (1, 1)), 414 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_2'), 415 | dtype=self.dtype 416 | )(branch3x3dbl, train) 417 | 418 | branch3x3dbl = BasicConv2d( 419 | out_channels=96, 420 | kernel_size=(3, 3), 421 | strides=(2, 2), 422 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_3'), 423 | dtype=self.dtype 424 | )(branch3x3dbl, train) 425 | 426 | branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) 427 | 428 | output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1) 429 | return output 430 | 431 | class InceptionC(nn.Module): 432 | channels_7x7: int 433 | params_dict: dict=None 434 | dtype: str='float32' 435 | 436 | @nn.compact 437 | def __call__(self, x, train=True): 438 | branch1x1 = BasicConv2d( 439 | out_channels=192, 440 | kernel_size=(1, 1), 441 | params_dict=get_from_dict(self.params_dict, 'branch1x1'), 442 | dtype=self.dtype 443 | )(x, train) 444 | 445 | branch7x7 = BasicConv2d( 446 | out_channels=self.channels_7x7, 447 | kernel_size=(1, 1), 448 | params_dict=get_from_dict(self.params_dict, 'branch7x7_1'), 449 | dtype=self.dtype 450 | )(x, train) 451 | 452 | branch7x7 = BasicConv2d( 453 | out_channels=self.channels_7x7, 454 | kernel_size=(1, 7), 455 | padding=((0, 0), (3, 3)), 456 | params_dict=get_from_dict(self.params_dict, 'branch7x7_2'), 457 | dtype=self.dtype 458 | )(branch7x7, train) 459 | 460 | branch7x7 = BasicConv2d( 461 | out_channels=192, 462 | kernel_size=(7, 1), 463 | padding=((3, 3), (0, 0)), 464 | params_dict=get_from_dict(self.params_dict, 'branch7x7_3'), 465 | dtype=self.dtype 466 | )(branch7x7, train) 467 | 468 | branch7x7dbl = BasicConv2d( 469 | out_channels=self.channels_7x7, 470 | kernel_size=(1, 1), 471 | params_dict=get_from_dict(self.params_dict, 'branch7x7dbl_1'), 472 | dtype=self.dtype 473 | )(x, train) 474 | 475 | branch7x7dbl = BasicConv2d( 476 | out_channels=self.channels_7x7, 477 | kernel_size=(7, 1), 478 | padding=((3, 3), (0, 0)), 479 | params_dict=get_from_dict(self.params_dict, 'branch7x7dbl_2'), 480 | dtype=self.dtype 481 | )(branch7x7dbl, train) 482 | 483 | branch7x7dbl = BasicConv2d( 484 | out_channels=self.channels_7x7, 485 | kernel_size=(1, 7), 486 | padding=((0, 0), (3, 3)), 487 | params_dict=get_from_dict(self.params_dict, 'branch7x7dbl_3'), 488 | dtype=self.dtype 489 | )(branch7x7dbl, train) 490 | 491 | branch7x7dbl = BasicConv2d( 492 | out_channels=self.channels_7x7, 493 | kernel_size=(7, 1), 494 | padding=((3, 3), (0, 0)), 495 | params_dict=get_from_dict(self.params_dict, 'branch7x7dbl_4'), 496 | dtype=self.dtype 497 | )(branch7x7dbl, train) 498 | 499 | branch7x7dbl = BasicConv2d( 500 | out_channels=self.channels_7x7, 501 | kernel_size=(1, 7), 502 | padding=((0, 0), (3, 3)), 503 | params_dict=get_from_dict(self.params_dict, 'branch7x7dbl_5'), 504 | dtype=self.dtype 505 | )(branch7x7dbl, train) 506 | 507 | branch_pool = avg_pool( 508 | x, 509 | window_shape=(3, 3), 510 | strides=(1, 1), 511 | padding=((1, 1), (1, 1)) 512 | ) 513 | 514 | branch_pool = BasicConv2d( 515 | out_channels=192, 516 | kernel_size=(1, 1), 517 | 
params_dict=get_from_dict(self.params_dict, 'branch_pool'), 518 | dtype=self.dtype 519 | )(branch_pool, train) 520 | 521 | output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1) 522 | return output 523 | 524 | class InceptionD(nn.Module): 525 | params_dict: dict=None 526 | dtype: str='float32' 527 | 528 | @nn.compact 529 | def __call__(self, x, train=True): 530 | branch3x3 = BasicConv2d( 531 | out_channels=192, 532 | kernel_size=(1, 1), 533 | params_dict=get_from_dict(self.params_dict, 'branch3x3_1'), 534 | dtype=self.dtype 535 | )(x, train) 536 | 537 | branch3x3 = BasicConv2d( 538 | out_channels=320, 539 | kernel_size=(3, 3), 540 | strides=(2, 2), 541 | params_dict=get_from_dict(self.params_dict, 'branch3x3_2'), 542 | dtype=self.dtype 543 | )(branch3x3, train) 544 | 545 | branch7x7x3 = BasicConv2d( 546 | out_channels=192, 547 | kernel_size=(1, 1), 548 | params_dict=get_from_dict(self.params_dict, 'branch7x7x3_1'), 549 | dtype=self.dtype 550 | )(x, train) 551 | 552 | branch7x7x3 = BasicConv2d( 553 | out_channels=192, 554 | kernel_size=(1, 7), 555 | padding=((0, 0), (3, 3)), 556 | params_dict=get_from_dict(self.params_dict, 'branch7x7x3_2'), 557 | dtype=self.dtype 558 | )(branch7x7x3, train) 559 | 560 | branch7x7x3 = BasicConv2d( 561 | out_channels=192, 562 | kernel_size=(7, 1), 563 | padding=((3, 3), (0, 0)), 564 | params_dict=get_from_dict(self.params_dict, 'branch7x7x3_3'), 565 | dtype=self.dtype 566 | )(branch7x7x3, train) 567 | 568 | branch7x7x3 = BasicConv2d( 569 | out_channels=192, 570 | kernel_size=(3, 3), 571 | strides=(2, 2), 572 | params_dict=get_from_dict(self.params_dict, 'branch7x7x3_4'), 573 | dtype=self.dtype 574 | )(branch7x7x3, train) 575 | 576 | branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) 577 | 578 | output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1) 579 | return output 580 | 581 | class InceptionE(nn.Module): 582 | pooling: Callable 583 | params_dict: dict=None 584 | dtype: str='float32' 585 | 586 | @nn.compact 587 | def __call__(self, x, train=True): 588 | branch1x1 = BasicConv2d( 589 | out_channels=320, 590 | kernel_size=(1, 1), 591 | params_dict=get_from_dict(self.params_dict, 'branch1x1'), 592 | dtype=self.dtype 593 | )(x, train) 594 | 595 | branch3x3 = BasicConv2d( 596 | out_channels=384, 597 | kernel_size=(1, 1), 598 | params_dict=get_from_dict(self.params_dict, 'branch3x3_1'), 599 | dtype=self.dtype 600 | )(x, train) 601 | 602 | branch3x3_a = BasicConv2d( 603 | out_channels=384, 604 | kernel_size=(1, 3), 605 | padding=((0, 0), (1, 1)), 606 | params_dict=get_from_dict(self.params_dict, 'branch3x3_2a'), 607 | dtype=self.dtype 608 | )(branch3x3, train) 609 | 610 | branch3x3_b = BasicConv2d( 611 | out_channels=384, 612 | kernel_size=(3, 1), 613 | padding=((1, 1), (0, 0)), 614 | params_dict=get_from_dict(self.params_dict, 'branch3x3_2b'), 615 | dtype=self.dtype 616 | )(branch3x3, train) 617 | 618 | branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1) 619 | 620 | branch3x3dbl = BasicConv2d( 621 | out_channels=448, 622 | kernel_size=(1, 1), 623 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_1'), 624 | dtype=self.dtype 625 | )(x, train) 626 | 627 | branch3x3dbl = BasicConv2d( 628 | out_channels=384, 629 | kernel_size=(3, 3), 630 | padding=((1, 1), (1, 1)), 631 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_2'), 632 | dtype=self.dtype 633 | )(branch3x3dbl, train) 634 | 635 | branch3x3dbl_a = BasicConv2d( 636 | out_channels=384, 637 | kernel_size=(1, 3), 638 | 
padding=((0, 0), (1, 1)), 639 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_3a'), 640 | dtype=self.dtype 641 | )(branch3x3dbl, train) 642 | 643 | branch3x3dbl_b = BasicConv2d( 644 | out_channels=384, 645 | kernel_size=(3, 1), 646 | padding=((1, 1), (0, 0)), 647 | params_dict=get_from_dict(self.params_dict, 'branch3x3dbl_3b'), 648 | dtype=self.dtype 649 | )(branch3x3dbl, train) 650 | 651 | branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1) 652 | 653 | branch_pool = self.pooling( 654 | x, 655 | window_shape=(3, 3), 656 | strides=(1, 1), 657 | padding=((1, 1), (1, 1)) 658 | ) 659 | 660 | branch_pool = BasicConv2d( 661 | out_channels=192, 662 | kernel_size=(1, 1), 663 | params_dict=get_from_dict(self.params_dict, 'branch_pool'), 664 | dtype=self.dtype 665 | )(branch_pool, train) 666 | 667 | output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1) 668 | return output 669 | 670 | def absolute_dims(rank, dims): 671 | return tuple([rank + dim if dim < 0 else dim for dim in dims]) 672 | 673 | def get_from_dict(dictionary, key): 674 | if dictionary is None or key not in dictionary: 675 | return None 676 | return dictionary[key] 677 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1710146030, 9 | "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "flake-utils_2": { 22 | "locked": { 23 | "lastModified": 1659877975, 24 | "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", 25 | "owner": "numtide", 26 | "repo": "flake-utils", 27 | "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "numtide", 32 | "repo": "flake-utils", 33 | "type": "github" 34 | } 35 | }, 36 | "nixgl": { 37 | "inputs": { 38 | "flake-utils": "flake-utils_2", 39 | "nixpkgs": "nixpkgs" 40 | }, 41 | "locked": { 42 | "lastModified": 1710942276, 43 | "narHash": "sha256-+K0UtpWhsA5X0kZpN2eaBI6TP2OvrQAWFvtoeICWmhM=", 44 | "owner": "hayden-donnelly", 45 | "repo": "nixGL", 46 | "rev": "c58137326ef0bac8a504111591692ad849f49029", 47 | "type": "github" 48 | }, 49 | "original": { 50 | "owner": "hayden-donnelly", 51 | "repo": "nixGL", 52 | "type": "github" 53 | } 54 | }, 55 | "nixpkgs": { 56 | "locked": { 57 | "lastModified": 1660551188, 58 | "narHash": "sha256-a1LARMMYQ8DPx1BgoI/UN4bXe12hhZkCNqdxNi6uS0g=", 59 | "owner": "nixos", 60 | "repo": "nixpkgs", 61 | "rev": "441dc5d512153039f19ef198e662e4f3dbb9fd65", 62 | "type": "github" 63 | }, 64 | "original": { 65 | "owner": "nixos", 66 | "repo": "nixpkgs", 67 | "type": "github" 68 | } 69 | }, 70 | "nixpkgs-unstable": { 71 | "locked": { 72 | "lastModified": 1712192574, 73 | "narHash": "sha256-LbbVOliJKTF4Zl2b9salumvdMXuQBr2kuKP5+ZwbYq4=", 74 | "owner": "nixos", 75 | "repo": "nixpkgs", 76 | "rev": "f480f9d09e4b4cf87ee6151eba068197125714de", 77 | "type": "github" 78 | }, 79 | "original": { 80 | "owner": "nixos", 81 | "ref": "nixpkgs-unstable", 82 | "repo": "nixpkgs", 83 | "type": "github" 84 | } 85 | }, 86 | "nixpkgs_2": { 87 | "locked": { 88 | 
"lastModified": 1701282334, 89 | "narHash": "sha256-MxCVrXY6v4QmfTwIysjjaX0XUhqBbxTWWB4HXtDYsdk=", 90 | "owner": "nixos", 91 | "repo": "nixpkgs", 92 | "rev": "057f9aecfb71c4437d2b27d3323df7f93c010b7e", 93 | "type": "github" 94 | }, 95 | "original": { 96 | "owner": "nixos", 97 | "ref": "23.11", 98 | "repo": "nixpkgs", 99 | "type": "github" 100 | } 101 | }, 102 | "root": { 103 | "inputs": { 104 | "flake-utils": "flake-utils", 105 | "nixgl": "nixgl", 106 | "nixpkgs": "nixpkgs_2", 107 | "nixpkgs-unstable": "nixpkgs-unstable" 108 | } 109 | }, 110 | "systems": { 111 | "locked": { 112 | "lastModified": 1681028828, 113 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 114 | "owner": "nix-systems", 115 | "repo": "default", 116 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 117 | "type": "github" 118 | }, 119 | "original": { 120 | "owner": "nix-systems", 121 | "repo": "default", 122 | "type": "github" 123 | } 124 | } 125 | }, 126 | "root": "root", 127 | "version": 7 128 | } 129 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Generative neural networks for 3D terrain"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:nixos/nixpkgs/23.11"; 6 | nixpkgs-unstable.url = "github:nixos/nixpkgs/nixpkgs-unstable"; 7 | flake-utils.url = "github:numtide/flake-utils"; 8 | # Patched version of nixGL from kenrandunderscore. 9 | # PR: https://github.com/nix-community/nixGL/pull/165 10 | # TODO: switch back to github:nix-community/nixGL when PR is merged. 11 | nixgl.url = "github:hayden-donnelly/nixGL"; 12 | }; 13 | outputs = inputs@{ self, nixpkgs, nixpkgs-unstable, flake-utils, nixgl, ... }: 14 | flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: let 15 | inherit (nixpkgs-unstable) lib; 16 | in { 17 | devShells = let 18 | pyVer = "311"; 19 | py = "python${pyVer}"; 20 | overlays = [ 21 | nixgl.overlay 22 | (final: prev: { 23 | ${py} = prev.${py}.override { 24 | packageOverrides = finalPkgs: prevPkgs: { 25 | jax = prevPkgs.jax.overridePythonAttrs (o: { 26 | nativeCheckInputs = []; 27 | pythonImportsCheck = []; 28 | pytestFlagsArray = []; 29 | passthru.tests = []; 30 | doCheck = false; 31 | }); 32 | # For some reason Flax has jaxlib as a builtInput and tensorflow as a nativeCheckInput, 33 | # so set jaxlib to jaxlib-bin in order to avoid building jaxlib and turn off all checks 34 | # to avoid building tensorflow. 
35 | jaxlib = prevPkgs.jaxlib-bin; 36 | flax = prevPkgs.flax.overridePythonAttrs (o: { 37 | nativeCheckInputs = []; 38 | pythonImportsCheck = []; 39 | pytestFlagsArray = []; 40 | doCheck = false; 41 | }); 42 | wandb = prevPkgs.wandb.overridePythonAttrs(o: { 43 | nativeCheckInputs = []; 44 | pythonImportsCheck = []; 45 | doCheck = false; 46 | }); 47 | }; 48 | }; 49 | }) 50 | ]; 51 | unstableCudaPkgs = import nixpkgs-unstable { 52 | inherit system overlays; 53 | config = { 54 | allowUnfree = true; 55 | cudaSupport = true; 56 | }; 57 | }; 58 | in rec { 59 | default = unstableCudaPkgs.mkShell { 60 | name = "cuda"; 61 | buildInputs = [ 62 | (unstableCudaPkgs.${py}.withPackages (pyp: with pyp; [ 63 | jax 64 | jaxlib-bin 65 | flax 66 | pyarrow 67 | pillow 68 | pandas 69 | datasets 70 | wandb 71 | ])) 72 | unstableCudaPkgs.gdalMinimal 73 | unstableCudaPkgs.cudaPackages.cudatoolkit 74 | unstableCudaPkgs.cudaPackages.cuda_cudart 75 | unstableCudaPkgs.cudaPackages.cudnn 76 | ]; 77 | shellHook = '' 78 | source <(sed -Ee '/\$@/d' ${lib.getExe unstableCudaPkgs.nixgl.nixGLIntel}) 79 | source <(sed -Ee '/\$@/d' ${lib.getExe unstableCudaPkgs.nixgl.auto.nixGLNvidia}*) 80 | ''; 81 | }; 82 | }; 83 | }); 84 | } 85 | -------------------------------------------------------------------------------- /images/display_heightmaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/novaia/ntg/cb770a5dc7e1be81e2f143ec1a75f937c3b229c4/images/display_heightmaps.png -------------------------------------------------------------------------------- /models/common/config_utils.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | from flax import linen as nn 3 | import optax 4 | import argparse 5 | 6 | ACTIVATION_FN_MAP = {'gelu': nn.gelu, 'silu': nn.silu} 7 | DTYPE_MAP = {'float32': jnp.float32, 'bfloat16': jnp.bfloat16} 8 | 9 | def load_activation_fn(activation_fn_name:str): 10 | assert activation_fn_name in ACTIVATION_FN_MAP.keys(), ( 11 | f'Invalid activation function: {activation_fn_name}. ' 12 | f'Must be one of the following: {list(ACTIVATION_FN_MAP.keys())}.' 13 | ) 14 | return ACTIVATION_FN_MAP[activation_fn_name] 15 | 16 | def load_dtype(dtype_name:str): 17 | assert dtype_name in DTYPE_MAP.keys(), ( 18 | f'Invalid dtype: {dtype_name}. Must be one of the following: {list(DTYPE_MAP.keys())}.' 19 | ) 20 | return DTYPE_MAP[dtype_name] 21 | 22 | def load_optimizer(config:dict, learning_rate): 23 | optimizer_name = config['optimizer'] 24 | if optimizer_name == 'sgd': 25 | return optax.sgd( 26 | learning_rate=learning_rate, momentum=config['sgd_momentum'], nesterov=config['sgd_nesterov'] 27 | ) 28 | elif optimizer_name == 'adam': 29 | return optax.adam( 30 | learning_rate=learning_rate, b1=config['adam_b1'], b2=config['adam_b2'], eps=config['adam_eps'], 31 | eps_root=config['adam_eps_root'] 32 | ) 33 | elif optimizer_name == 'adamw': 34 | return optax.adamw( 35 | learning_rate=learning_rate, b1=config['adam_b1'], b2=config['adam_b2'], eps=config['adam_eps'], 36 | eps_root=config['adam_eps_root'], weight_decay=config['weight_decay'] 37 | ) 38 | else: 39 | raise ValueError( 40 | f'Invalid optimizer: {optimizer_name}. ' 41 | 'Must be one of the following: sgd, adam, adamw.'
42 | ) 43 | 44 | def parse_args(default_run_dir:str): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--config', type=str, required=True) 47 | parser.add_argument('--dataset', type=str, required=True) 48 | parser.add_argument('--wandb', type=int, choices=[0, 1], default=1) 49 | parser.add_argument('--epochs_between_previews', type=int, default=1) 50 | parser.add_argument('--steps_between_wandb_logs', type=int, default=200) 51 | parser.add_argument('--save_checkpoints', type=int, choices=[0, 1], default=1) 52 | parser.add_argument('--checkpoint', type=str, default=None) 53 | parser.add_argument('--run_dir', type=str, default=default_run_dir) 54 | parser.add_argument('--tabulate', type=int, choices=[0, 1], default=0) 55 | return parser.parse_args() 56 | -------------------------------------------------------------------------------- /models/terra.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | import numpy as np 4 | from flax import struct 5 | from flax import linen as nn 6 | from flax.training.train_state import TrainState 7 | from orbax import checkpoint as ocp 8 | import optax 9 | 10 | import glob 11 | from datasets import load_dataset, Dataset 12 | 13 | from PIL import Image 14 | import pandas as pd 15 | 16 | from typing import Any, Callable, List 17 | from functools import partial 18 | from datetime import datetime 19 | from copy import deepcopy 20 | import json 21 | import math 22 | import os 23 | 24 | from models.common import config_utils 25 | from sampling.diffusion import implicit as sample_implicit 26 | 27 | def save_samples(samples:jax.Array, step:int, save_dir:str): 28 | samples = ((samples + 1.0) / 2.0) * 255.0 29 | samples = jnp.clip(samples, 0.0, 255.0) 30 | samples = np.array(samples, dtype=np.uint8) 31 | for i in range(samples.shape[0]): 32 | image = Image.fromarray(samples[i].squeeze(axis=-1)) 33 | image.save(os.path.join(save_dir, f'step{step}_image{i}.png')) 34 | 35 | def get_dataset(dataset_path, batch_size): 36 | if dataset_path.endswith('/'): 37 | glob_pattern = f'{dataset_path}*.parquet' 38 | else: 39 | glob_pattern = f'{dataset_path}/*.parquet' 40 | parquet_files = glob.glob(glob_pattern) 41 | assert len(parquet_files) > 0, 'No parquet files were found in dataset directory.' 
42 | print(f'Found {len(parquet_files)} parquet files in dataset directory.') 43 | dataset = load_dataset( 44 | 'parquet', 45 | data_files={'train': parquet_files}, 46 | split='train', 47 | num_proc=8 48 | ) 49 | steps_per_epoch = len(dataset) // batch_size 50 | dataset = dataset.with_format('jax') 51 | return dataset, steps_per_epoch 52 | 53 | class SinusoidalEmbedding(nn.Module): 54 | embedding_dim:int 55 | embedding_max_frequency:float 56 | embedding_min_frequency:float = 1.0 57 | dtype:Any = jnp.float32 58 | 59 | @nn.compact 60 | def __call__(self, x): 61 | frequencies = jnp.exp( 62 | jnp.linspace( 63 | jnp.log(self.embedding_min_frequency), 64 | jnp.log(self.embedding_max_frequency), 65 | self.embedding_dim // 2, 66 | dtype=self.dtype 67 | ) 68 | ) 69 | angular_speeds = 2.0 * math.pi * frequencies 70 | embeddings = jnp.concatenate( 71 | [jnp.sin(angular_speeds * x), jnp.cos(angular_speeds * x)], 72 | axis=-1, 73 | dtype=self.dtype 74 | ) 75 | return embeddings 76 | 77 | class ResidualBlock(nn.Module): 78 | num_features: int 79 | num_groups: int 80 | kernel_size: int 81 | activation_fn: Callable 82 | dtype: Any = jnp.float32 83 | param_dtype: Any = jnp.float32 84 | 85 | @nn.compact 86 | def __call__(self, x): 87 | input_features = x.shape[-1] 88 | if input_features == self.num_features: 89 | residual = x 90 | else: 91 | residual = nn.Conv( 92 | self.num_features, kernel_size=(1, 1), 93 | dtype=self.dtype, param_dtype=self.param_dtype 94 | )(x) 95 | x = nn.Conv( 96 | self.num_features, kernel_size=(self.kernel_size, self.kernel_size), 97 | dtype=self.dtype, param_dtype=self.param_dtype 98 | )(x) 99 | x = nn.GroupNorm(self.num_groups, dtype=self.dtype, param_dtype=self.param_dtype)(x) 100 | x = self.activation_fn(x) 101 | x = nn.Conv( 102 | self.num_features, kernel_size=(self.kernel_size, self.kernel_size), 103 | dtype=self.dtype, param_dtype=self.param_dtype 104 | )(x) 105 | x = self.activation_fn(x) 106 | x = x + residual 107 | return x 108 | 109 | class DownBlock(nn.Module): 110 | num_features: int 111 | num_groups: int 112 | block_depth: int 113 | kernel_size: int 114 | activation_fn: Callable 115 | dtype: Any = jnp.float32 116 | param_dtype: Any = jnp.float32 117 | 118 | @nn.compact 119 | def __call__(self, x, skips): 120 | for _ in range(self.block_depth): 121 | x = ResidualBlock( 122 | num_features=self.num_features, 123 | num_groups=self.num_groups, 124 | kernel_size=self.kernel_size, 125 | activation_fn=self.activation_fn, 126 | dtype=self.dtype, 127 | param_dtype=self.param_dtype 128 | )(x) 129 | skips.append(x) 130 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 131 | return x, skips 132 | 133 | class UpBlock(nn.Module): 134 | num_features: int 135 | num_groups: int 136 | block_depth: int 137 | kernel_size: int 138 | activation_fn: Callable 139 | dtype: Any = jnp.float32 140 | param_dtype: Any = jnp.float32 141 | 142 | @nn.compact 143 | def __call__(self, x, skips): 144 | upsample_shape = (x.shape[0], x.shape[1] * 2, x.shape[2] * 2, x.shape[3]) 145 | x = jax.image.resize(x, upsample_shape, method='bilinear') 146 | 147 | for _ in range(self.block_depth): 148 | x = jnp.concatenate([x, skips.pop()], axis=-1) 149 | x = ResidualBlock( 150 | num_features=self.num_features, 151 | num_groups=self.num_groups, 152 | kernel_size=self.kernel_size, 153 | activation_fn=self.activation_fn, 154 | dtype=self.dtype, 155 | param_dtype=self.param_dtype 156 | )(x) 157 | return x, skips 158 | 159 | class Terra(nn.Module): 160 | embedding_dim: int 161 | embedding_max_frequency: float 162 
| num_features: List[int] 163 | num_groups: List[int] 164 | kernel_size: int 165 | block_depth: int 166 | output_channels: int 167 | activation_fn: Callable 168 | dtype: Any = jnp.float32 169 | param_dtype: Any = jnp.float32 170 | 171 | @nn.compact 172 | def __call__(self, x, diffusion_time): 173 | time_emb = SinusoidalEmbedding( 174 | embedding_dim=self.embedding_dim, 175 | embedding_max_frequency=self.embedding_max_frequency, 176 | dtype=self.dtype 177 | )(diffusion_time) 178 | time_emb = jax.image.resize(time_emb, shape=(*x.shape[0:3], 1), method='nearest') 179 | x = jnp.concatenate([x, time_emb], axis=-1) 180 | 181 | skips = [] 182 | block_params = list(zip(self.num_features, self.num_groups))[:-1] 183 | for features, groups in block_params: 184 | x, skips = DownBlock( 185 | num_features=features, 186 | num_groups=groups, 187 | block_depth=self.block_depth, 188 | kernel_size=self.kernel_size, 189 | activation_fn=self.activation_fn, 190 | dtype=self.dtype, 191 | param_dtype=self.param_dtype 192 | )(x, skips) 193 | for _ in range(self.block_depth): 194 | x = ResidualBlock( 195 | num_features=self.num_features[-1], 196 | num_groups=self.num_groups[-1], 197 | kernel_size=self.kernel_size, 198 | activation_fn=self.activation_fn, 199 | dtype=self.dtype, 200 | param_dtype=self.param_dtype 201 | )(x) 202 | for features, groups in list(reversed(block_params)): 203 | x, skips = UpBlock( 204 | num_features=features, 205 | num_groups=groups, 206 | block_depth=self.block_depth, 207 | kernel_size=self.kernel_size, 208 | activation_fn=self.activation_fn, 209 | dtype=self.dtype, 210 | param_dtype=self.param_dtype 211 | )(x, skips) 212 | 213 | x = nn.Conv( 214 | self.output_channels, kernel_size=(1, 1), 215 | dtype=jnp.float32, param_dtype=jnp.float32 216 | )(x) 217 | return x 218 | 219 | class EmaTrainState(TrainState): 220 | ema_warmup: int = struct.field(pytree_node=False) 221 | ema_decay: float = struct.field(pytree_node=False) 222 | ema_params: dict = struct.field(pytree_node=True) 223 | 224 | def update_ema(self): 225 | def ema_update_fn(state): 226 | def _update_ema(ema_param, base_param): 227 | return state.ema_decay * ema_param + (1 - state.ema_decay) * base_param 228 | 229 | new_ema_params = jax.tree_util.tree_map(_update_ema, state.ema_params, state.params) 230 | return state.replace(ema_params=new_ema_params) 231 | 232 | def copy_params_fn(state): 233 | return state.replace(ema_params=state.params) 234 | 235 | return jax.lax.cond(self.step <= self.ema_warmup, copy_params_fn, ema_update_fn, self) 236 | 237 | def diffusion_schedule(diffusion_times, min_signal_rate, max_signal_rate): 238 | start_angle = jnp.arccos(max_signal_rate) 239 | end_angle = jnp.arccos(min_signal_rate) 240 | diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle) 241 | 242 | signal_rates = jnp.cos(diffusion_angles) 243 | noise_rates = jnp.sin(diffusion_angles) 244 | return noise_rates, signal_rates 245 | 246 | @partial(jax.jit, static_argnames=['min_signal_rate', 'max_signal_rate']) 247 | def train_step(state, images, min_signal_rate, max_signal_rate): 248 | key = jax.random.PRNGKey(state.step) 249 | images = (images / 127.5) - 1.0 250 | noise_key, diffusion_time_key = jax.random.split(key, 2) 251 | noises = jax.random.normal(noise_key, images.shape, dtype=jnp.float32) 252 | diffusion_times = jax.random.uniform(diffusion_time_key, (images.shape[0], 1, 1, 1)) 253 | noise_rates, signal_rates = diffusion_schedule( 254 | diffusion_times, min_signal_rate, max_signal_rate 255 | ) 256 | noisy_images = signal_rates * images + noise_rates *
noises 257 | 258 | def loss_fn(params): 259 | pred_noises = state.apply_fn({'params': params}, noisy_images, noise_rates**2) 260 | return jnp.mean((pred_noises - noises)**2) 261 | 262 | grad_fn = jax.value_and_grad(loss_fn) 263 | loss, grads = grad_fn(state.params) 264 | state = state.apply_gradients(grads=grads).update_ema() 265 | return loss, state 266 | 267 | def main(): 268 | gpu = jax.devices('gpu')[0] 269 | print(gpu) 270 | 271 | args = config_utils.parse_args(default_run_dir='data/terra_runs/0') 272 | 273 | checkpoint_save_dir = os.path.join(args.run_dir, 'checkpoints') 274 | fixed_seed_save_dir = os.path.join(args.run_dir, 'images/fixed') 275 | dynamic_seed_save_dir = os.path.join(args.run_dir, 'images/dynamic') 276 | if not os.path.exists(checkpoint_save_dir): 277 | os.makedirs(checkpoint_save_dir) 278 | if not os.path.exists(fixed_seed_save_dir): 279 | os.makedirs(fixed_seed_save_dir) 280 | if not os.path.exists(dynamic_seed_save_dir): 281 | os.makedirs(dynamic_seed_save_dir) 282 | 283 | with open(args.config, 'r') as f: 284 | config = json.load(f) 285 | assert len(config['num_features']) == len(config['num_groups']), ( 286 | 'len(num_features) must equal len(num_groups).' 287 | ) 288 | 289 | dataset, steps_per_epoch = get_dataset( 290 | dataset_path=args.dataset, 291 | batch_size=config['batch_size'] 292 | ) 293 | print(f'Steps per epoch: {steps_per_epoch:,}') 294 | 295 | activation_fn = config_utils.load_activation_fn(config['activation_fn']) 296 | dtype = config_utils.load_dtype(config['dtype']) 297 | param_dtype = config_utils.load_dtype(config['param_dtype']) 298 | 299 | model = Terra( 300 | embedding_dim=config['embedding_dim'], 301 | embedding_max_frequency=config['embedding_max_frequency'], 302 | num_features=config['num_features'], 303 | num_groups=config['num_groups'], 304 | block_depth=config['block_depth'], 305 | kernel_size=config['kernel_size'], 306 | output_channels=config['output_channels'], 307 | activation_fn=activation_fn, 308 | dtype=dtype, 309 | param_dtype=param_dtype 310 | ) 311 | x = jnp.ones( 312 | ( 313 | config['batch_size'], 314 | config['image_size'], 315 | config['image_size'], 316 | config['output_channels'] 317 | ), 318 | dtype=dtype 319 | ) 320 | diffusion_times = jnp.ones((config['batch_size'], 1, 1, 1), dtype=dtype) 321 | model_key = jax.random.PRNGKey(0) 322 | if args.tabulate: 323 | print(model.tabulate(model_key, x, diffusion_times)) 324 | exit(0) 325 | params = model.init(model_key, x, diffusion_times)['params'] 326 | 327 | epochs_to_steps = partial(lambda steps, epochs: int(steps * epochs), steps=steps_per_epoch) 328 | lr_schedule = optax.warmup_exponential_decay_schedule( 329 | init_value=config['lr_base'], 330 | peak_value=config['lr_max'], 331 | warmup_steps=epochs_to_steps(epochs=config['lr_warmup_epochs']), 332 | transition_steps=epochs_to_steps(epochs=config['lr_decay_epochs']), 333 | decay_rate=config['lr_decay_rate'], 334 | staircase=False, 335 | end_value=config['lr_min'] 336 | ) 337 | tx = optax.chain( 338 | optax.zero_nans(), 339 | optax.adaptive_grad_clip(clipping=config['adaptive_grad_clip']), 340 | config_utils.load_optimizer(config=config, learning_rate=lr_schedule) 341 | ) 342 | state = EmaTrainState.create( 343 | apply_fn=model.apply, params=params, tx=tx, 344 | ema_warmup=config['ema_warmup'], ema_decay=config['ema_decay'], 345 | ema_params=deepcopy(params) 346 | ) 347 | 348 | param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params)) 349 | config['param_count'] = param_count 350 | print(f'Param count: 
{param_count:,}') 351 | 352 | checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True)) 353 | if args.checkpoint is not None: 354 | state = checkpointer.restore(args.checkpoint, item=state) 355 | 356 | if args.wandb == 1: 357 | import wandb 358 | wandb.init(project='ntg-terra', config=config) 359 | min_signal_rate = config['min_signal_rate'] 360 | max_signal_rate = config['max_signal_rate'] 361 | sample_fn = partial( 362 | sample_implicit, 363 | num_images=16, 364 | diffusion_steps=20, 365 | diffusion_schedule=diffusion_schedule, 366 | image_width=config['image_size'], 367 | image_height=config['image_size'], 368 | channels=config['output_channels'], 369 | min_signal_rate=min_signal_rate, 370 | max_signal_rate=max_signal_rate, 371 | ) 372 | 373 | steps_between_loss_report = 300 374 | steps_since_last_loss_report = 0 375 | accumulated_losses = [] 376 | for epoch in range(config['epochs']): 377 | dataset = dataset.shuffle(seed=epoch) 378 | data_iterator = dataset.iter(batch_size=config['batch_size']) 379 | epoch_start_time = datetime.now() 380 | for _ in range(steps_per_epoch): 381 | images = jnp.expand_dims(next(data_iterator)['heightmap'], axis=-1) 382 | loss, state = train_step(state, images, min_signal_rate, max_signal_rate) 383 | accumulated_losses.append(loss) 384 | steps_since_last_loss_report += 1 385 | if steps_since_last_loss_report >= steps_between_loss_report: 386 | average_loss = sum(accumulated_losses) / len(accumulated_losses) 387 | if args.wandb == 1: 388 | if state.step % args.steps_between_wandb_logs == 0: 389 | wandb.log({'loss': average_loss}, step=state.step) 390 | else: 391 | print(state.step, average_loss) 392 | steps_since_last_loss_report = 0 393 | accumulated_losses = [] 394 | epoch_end_time = datetime.now() 395 | print( 396 | f'Epoch {epoch} completed in {epoch_end_time-epoch_start_time} at {epoch_end_time}' 397 | ) 398 | 399 | if args.save_checkpoints == 1: 400 | checkpointer.save( 401 | (os.path.join(os.path.abspath(checkpoint_save_dir), f'step{state.step}')), 402 | state, force=True 403 | ) 404 | if (epoch+1) % args.epochs_between_previews != 0: 405 | continue 406 | 407 | fixed_seed_samples = sample_fn(apply_fn=state.apply_fn, params=state.ema_params, seed=0) 408 | dynamic_seed_samples = sample_fn(apply_fn=state.apply_fn, params=state.ema_params, seed=state.step) 409 | save_samples(fixed_seed_samples, state.step, fixed_seed_save_dir) 410 | save_samples(dynamic_seed_samples, state.step, dynamic_seed_save_dir) 411 | 412 | if __name__ == '__main__': 413 | main() 414 | -------------------------------------------------------------------------------- /sampling/diffusion.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import lax 3 | from jax import numpy as jnp 4 | 5 | from typing import Callable, Any 6 | 7 | # This is DDIM sampling (Denoising Diffusion Implicit Models, Song et al. 2020): 8 | # a deterministic reverse-diffusion sampler that can skip steps.
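# Each step below estimates the clean image from the current noisy image and
# the predicted noise, then re-noises that estimate to the next (lower) noise level:
#     pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
#     next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
# Because no fresh noise is injected (eta = 0 in DDIM notation), the trajectory
# is fully deterministic given the initial noise.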
9 | def implicit( 10 | apply_fn:Callable, 11 | params:Any, 12 | num_images:int, 13 | diffusion_steps:int, 14 | diffusion_schedule:Callable, 15 | image_width:int, 16 | image_height:int, 17 | channels:int, 18 | min_signal_rate:float, 19 | max_signal_rate:float, 20 | seed:int, 21 | ): 22 | @jax.jit 23 | def inference_fn(noisy_images, diffusion_times): 24 | return lax.stop_gradient(apply_fn({'params': params}, noisy_images, diffusion_times)) 25 | 26 | initial_noise = jax.random.normal( 27 | jax.random.PRNGKey(seed), 28 | shape=(num_images, image_height, image_width, channels) 29 | ) 30 | step_size = 1.0 / diffusion_steps 31 | 32 | next_noisy_images = initial_noise 33 | for step in range(diffusion_steps): 34 | noisy_images = next_noisy_images 35 | 36 | diffusion_times = jnp.ones((num_images, 1, 1, 1)) - step * step_size 37 | noise_rates, signal_rates = diffusion_schedule( 38 | diffusion_times, min_signal_rate, max_signal_rate 39 | ) 40 | pred_noises = inference_fn(noisy_images, noise_rates**2) 41 | pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates 42 | 43 | next_diffusion_times = diffusion_times - step_size 44 | next_noise_rates, next_signal_rates = diffusion_schedule( 45 | next_diffusion_times, min_signal_rate, max_signal_rate 46 | ) 47 | next_noisy_images = (next_signal_rates * pred_images + next_noise_rates * pred_noises) 48 | return pred_images 49 | -------------------------------------------------------------------------------- /utilities/inception_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Compare two saved FID statistics files to check whether they contain 4 | # identical 'mu' and 'sigma' arrays. 5 | file1 = np.load('../data/dataset_info/second_pass_fid_stats.npz') 6 | file2 = np.load('../data/dataset_info/stats.npz') 7 | 8 | mu1, sigma1 = file1['mu'], file1['sigma'] 9 | mu2, sigma2 = file2['mu'], file2['sigma'] 10 | 11 | mu_equal = np.array_equal(mu1, mu2) 12 | sigma_equal = np.array_equal(sigma1, sigma2) 13 | 14 | print("The mu from both files are the same." if mu_equal else "The mu from both files are different.") 15 | print("The sigma from both files are the same." if sigma_equal else "The sigma from both files are different.") 16 | if mu_equal and sigma_equal: 17 | print("The mu and sigma from both files are the same.") 18 | else: 19 | print("The mu and sigma from both files are different.") -------------------------------------------------------------------------------- /utilities/tf_to_onnx.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tf2onnx 3 | 4 | input_path = "../data/temp/" 5 | model_name = "saved_model" 6 | file_type = '' 7 | output_path = "../data/temp" 8 | 9 | #pre_model = tf.keras.models.load_model(input_path + model_name + file_type) 10 | #tf2onnx.convert.from_keras( 11 | # pre_model, 12 | # output_path = output_path + model_name + ".onnx", 13 | # opset = 9 14 | #) 15 | 16 | # maybe export model graph then convert from graph to onnx? 17 | pre_model = tf.saved_model.load(input_path + model_name + file_type) 18 | print(dir(pre_model)) 19 | print(pre_model.vars) 20 | #tf2onnx.convert.from_tflite( 21 | # pre_model, 22 | # output_path = output_path + model_name + ".onnx", 23 | # opset = 9 24 | #) 25 | 26 | # An untested alternative: tf2onnx can also convert a SavedModel directory 27 | # directly from the command line, without loading the model in Python: 28 | # python -m tf2onnx.convert --saved-model ../data/temp/saved_model --output ../data/temp/saved_model.onnx --opset 13 --------------------------------------------------------------------------------