├── LICENSE
├── configs
│   ├── compressible_NS_cfg_config.yaml
│   ├── compressible_NS_uncond_config.yaml
│   ├── compressible_NS_vt_config.yaml
│   ├── compressible_NS_xattn_config.yaml
│   ├── darcy_cfg_config.yaml
│   ├── darcy_uncond_config.yaml
│   ├── darcy_vt_config.yaml
│   ├── darcy_xattn_config.yaml
│   ├── diffusion_reaction_cfg_config.yaml
│   ├── diffusion_reaction_uncond_config.yaml
│   ├── diffusion_reaction_vt_config.yaml
│   ├── diffusion_reaction_xattn_config.yaml
│   ├── shallow_water_cfg_config.yaml
│   ├── shallow_water_uncond_config.yaml
│   ├── shallow_water_vt_config.yaml
│   └── shallow_water_xattn_config.yaml
├── dataloader
│   └── dataset_class.py
├── evaluate.py
├── git_assest
│   ├── bar_chart_0.01_largefont.png
│   ├── darcy.png
│   ├── dr_hf.png
│   ├── encoding_block.png
│   └── error_hist.png
├── losses
│   ├── loss.py
│   └── metric.py
├── models
│   ├── unet2D.py
│   └── unet2DCondition.py
├── noise_schedulers
│   └── noise_sampler.py
├── pipelines
│   └── pipeline_inv_prob.py
├── readme.md
├── requirements.txt
├── requirements_noversion.txt
├── train_cond.py
├── train_uncond.py
├── train_vt.py
└── utils
    ├── attn_utils.py
    ├── general_utils.py
    ├── inverse_utils.py
    ├── pipeline_utils.py
    └── vt_utils.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
--------------------------------------------------------------------------------
/configs/compressible_NS_cfg_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 512
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 4
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2D
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 4
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 512
    intermediate_size: 2048
    projection_dim: 512
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 4
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: false

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578]
    std: [0.0422693, 0.0422711, 2.9536870, 22.0486488]
    target_std: 0.5
  data_name: compressible_NS

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1,2,3]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_cfg
  logging_dir: comp_NS
  tracker_project_name: compNS_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["Vx", "Vy", "pressure", "density"]
--------------------------------------------------------------------------------
/configs/compressible_NS_uncond_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 4
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 4
  resnet_time_scale_shift: scale_shift
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578]
    std: [0.0422693, 0.0422711, 2.9536870, 22.0486488]
    target_std: 0.5
  data_name: compressible_NS

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1,2,3]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_uncond
  logging_dir: comp_NS
  tracker_project_name: compNS_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["Vx", "Vy", "pressure", "density"]
--------------------------------------------------------------------------------
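Aside: the `noise_scheduler` and `loss_fn` blocks in these configs follow a `target:`/`params:` convention. A minimal sketch of how such blocks are typically instantiated (hypothetical `instantiate_from_config` helper; the repo's actual loader may differ):

```python
import importlib
import yaml

def instantiate_from_config(cfg: dict):
    """Import cfg['target'] (module path + class name) and call it with cfg['params']."""
    module_path, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg.get("params", {}))

with open("configs/compressible_NS_uncond_config.yaml") as f:
    config = yaml.safe_load(f)

# e.g. builds diffusers.EDMDPMSolverMultistepScheduler(num_train_timesteps=1000)
scheduler = instantiate_from_config(config["noise_scheduler"])
```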
/configs/compressible_NS_vt_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 4
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 4
  resnet_time_scale_shift: scale_shift
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578]
    std: [0.0422693, 0.0422711, 2.9536870, 22.0486488]
    target_std: 0.5
  data_name: compressible_NS

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1,2,3]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_vt
  logging_dir: comp_NS
  tracker_project_name: compNS_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  do_edm_style_training: True
  channel_names: ["Vx", "Vy", "pressure", "density"]
--------------------------------------------------------------------------------
/configs/compressible_NS_xattn_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 512
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 4
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2DCrossAttn
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 4
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 512
    intermediate_size: 2048
    projection_dim: 512
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 4
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: true

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578]
    std: [0.0422693, 0.0422711, 2.9536870, 22.0486488]
    target_std: 0.5
  data_name: compressible_NS

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1,2,3]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_xattn
  logging_dir: comp_NS
  tracker_project_name: compNS_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["Vx", "Vy", "pressure", "density"]
--------------------------------------------------------------------------------
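Aside: the `field_encoder_dict` above describes a ViT-style patch encoder (implemented in models/unet2DCondition.py). A quick sanity check of its geometry, assuming a standard non-overlapping patch embedding:

```python
# image_size 128x128 with patch_size 8 yields a 16x16 grid of patch tokens,
# each projected to hidden_size 512 -- matching cross_attention_dim: 512,
# so the UNet's cross-attention layers can attend to these tokens directly.
h, w = 128, 128                     # field_encoder_dict.image_size
p = 8                               # field_encoder_dict.patch_size
num_tokens = (h // p) * (w // p)    # 16 * 16 = 256 tokens
hidden_size = 512                   # per-token width
print(num_tokens, hidden_size)      # 256 512
```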
/configs/darcy_cfg_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 256
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2D
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 2
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 256
    intermediate_size: 1024
    projection_dim: 256
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 1
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: false

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [1.1482742, 46.6775513]
    std: [0.6740283, 30.9050026]
    target_std: 0.5
  data_name: darcy

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [1]
  same_mask: False
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_cfg
  logging_dir: darcy
  tracker_project_name: darcy_tracker
  save_image_epochs: 50
  save_model_epochs: 200
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["permeability", "pressure"]
--------------------------------------------------------------------------------
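Aside: the `transform: normalize` / `transform_args` pattern above implies per-channel standardization rescaled to `target_std: 0.5`, which lines up with `EDMLoss`'s `sigma_data: 0.5`. A minimal sketch of the assumed semantics (the real transform lives in dataloader/dataset_class.py and may differ):

```python
import numpy as np

# Per-channel stats from darcy_cfg_config.yaml, reshaped to broadcast over (C, H, W).
mean = np.array([1.1482742, 46.6775513]).reshape(2, 1, 1)
std = np.array([0.6740283, 30.9050026]).reshape(2, 1, 1)
target_std = 0.5  # matches loss_fn.params.sigma_data

def normalize(x: np.ndarray) -> np.ndarray:
    """(C, H, W) field -> roughly zero mean, per-channel std ~= target_std."""
    return (x - mean) / std * target_std

def denormalize(x: np.ndarray) -> np.ndarray:
    """Invert normalize() to recover physical units."""
    return x / target_std * std + mean
```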
/configs/darcy_uncond_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 2
  resnet_time_scale_shift: scale_shift
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [1.1482742, 46.6775513]
    std: [0.6740283, 30.9050026]
    target_std: 0.5
  data_name: darcy

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [1]
  same_mask: False
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_uncond
  logging_dir: darcy
  tracker_project_name: darcy_tracker
  save_image_epochs: 50
  save_model_epochs: 200
  checkpointing_steps: 25000
  eval_batch_size: 8
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["permeability", "pressure"]
--------------------------------------------------------------------------------
/configs/darcy_vt_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 1
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 2
  resnet_time_scale_shift: scale_shift
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [1.1482742, 46.6775513]
    std: [0.6740283, 30.9050026]
    target_std: 0.5
  data_name: darcy

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [1]
  same_mask: False
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_vt
  logging_dir: darcy
  tracker_project_name: darcy_tracker
  save_image_epochs: 50
  save_model_epochs: 200
  checkpointing_steps: 25000
  eval_batch_size: 8
  do_edm_style_training: True
  channel_names: ["permeability", "pressure"]
--------------------------------------------------------------------------------
/configs/darcy_xattn_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 256
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2DCrossAttn
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 2
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 256
    intermediate_size: 1024
    projection_dim: 256
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 1
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: true

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [1.1482742, 46.6775513]
    std: [0.6740283, 30.9050026]
    target_std: 0.5
  data_name: darcy

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [1]
  same_mask: False
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_xattn
  logging_dir: darcy
  tracker_project_name: darcy_tracker
  save_image_epochs: 50
  save_model_epochs: 200
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["permeability", "pressure"]
--------------------------------------------------------------------------------
/configs/diffusion_reaction_cfg_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 256
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2D
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 2
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 256
    intermediate_size: 1024
    projection_dim: 256
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 2
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: false

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [-0.0311127, -0.0199022]
    std: [0.1438150, 0.1117546]
    target_std: 0.5
  data_name: diffusion_reaction

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_cfg
  logging_dir: diff_react
  tracker_project_name: diffreact_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["u", "v"]
--------------------------------------------------------------------------------
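Aside: in the `*_cfg_*` configs, `general.cond_drop_prob` controls classifier-free-guidance conditioning dropout (null here; a typical value would be around 0.1). A sketch of the assumed semantics, not necessarily how train_cond.py implements it:

```python
import torch

def maybe_drop_condition(cond_emb: torch.Tensor, cond_drop_prob: float | None) -> torch.Tensor:
    """Zero the conditioning embedding for a random subset of the batch so one
    network learns both the conditional and unconditional score."""
    if not cond_drop_prob:  # None or 0.0 -> always conditioned
        return cond_emb
    keep = torch.rand(cond_emb.shape[0], device=cond_emb.device) >= cond_drop_prob
    shape = (-1,) + (1,) * (cond_emb.dim() - 1)
    return cond_emb * keep.view(shape).to(cond_emb.dtype)
```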
/configs/diffusion_reaction_uncond_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 2
  resnet_time_scale_shift: scale_shift
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [-0.0311127, -0.0199022]
    std: [0.1438150, 0.1117546]
    target_std: 0.5
  data_name: diffusion_reaction

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_uncond
  logging_dir: diff_react
  tracker_project_name: diffreact_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["u", "v"]
--------------------------------------------------------------------------------
/configs/diffusion_reaction_vt_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: UNet2DModel
  _diffusers_version: 0.28.2
  act_fn: silu
  add_attention: true
  attention_head_dim: 8
  attn_norm_num_groups: null
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  downsample_type: conv
  dropout: 0.0
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_scale_factor: 1
  norm_eps: 1e-05
  norm_num_groups: 32
  num_class_embeds: null
  num_train_timesteps: null
  out_channels: 2
  sample_size: 128
  time_embedding_type: positional
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upsample_type: conv

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [-0.0311127, -0.0199022]
    std: [0.1438150, 0.1117546]
    target_std: 0.5
  data_name: diffusion_reaction

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_vt
  logging_dir: diff_react
  tracker_project_name: diffreact_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  do_edm_style_training: True
  channel_names: ["u", "v"]
--------------------------------------------------------------------------------
/configs/diffusion_reaction_xattn_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 256
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 2
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2DCrossAttn
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 2
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 128
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 256
    intermediate_size: 1024
    projection_dim: 256
    image_size:
    - 128
    - 128
    patch_size: 8
    num_channels: 2
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: true

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [-0.0311127, -0.0199022]
    std: [0.1438150, 0.1117546]
    target_std: 0.5
  data_name: diffusion_reaction

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0,1]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_xattn
  logging_dir: diff_react
  tracker_project_name: diffreact_tracker
  save_image_epochs: 5
  save_model_epochs: 50
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["u", "v"]
--------------------------------------------------------------------------------
/configs/shallow_water_cfg_config.yaml:
--------------------------------------------------------------------------------
unet:
  _class_name: diffuserUNet2DCondition
  _diffusers_version: 0.29.2
  act_fn: silu
  addition_embed_type: null
  addition_embed_type_num_heads: 64
  addition_time_embed_dim: null
  attention_head_dim: 8
  attention_type: default
  block_out_channels:
  - 128
  - 256
  - 256
  center_input_sample: false
  class_embed_type: null
  class_embeddings_concat: false
  conv_in_kernel: 3
  conv_out_kernel: 3
  cross_attention_dim: 384
  cross_attention_norm: null
  down_block_types:
  - DownBlock2D
  - DownBlock2D
  - DownBlock2D
  downsample_padding: 1
  dropout: 0.0
  dual_cross_attention: false
  encoder_hid_dim: null
  encoder_hid_dim_type: null
  flip_sin_to_cos: true
  freq_shift: 0
  in_channels: 3
  layers_per_block: 2
  mid_block_only_cross_attention: null
  mid_block_scale_factor: 1
  mid_block_type: UNetMidBlock2D
  norm_eps: 1e-05
  norm_num_groups: 32
  num_attention_heads: null
  num_class_embeds: null
  only_cross_attention: false
  out_channels: 3
  projection_class_embeddings_input_dim: null
  resnet_out_scale_factor: 1.0
  resnet_skip_time_act: false
  resnet_time_scale_shift: scale_shift
  reverse_transformer_layers_per_block: null
  sample_size: 64
  time_cond_proj_dim: null
  time_embedding_act_fn: null
  time_embedding_dim: null
  time_embedding_type: positional
  timestep_post_act: null
  transformer_layers_per_block: 1
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - UpBlock2D
  upcast_attention: false
  use_linear_projection: false
  field_encoder_dict:
    hidden_size: 384
    intermediate_size: 1536
    projection_dim: 384
    image_size:
    - 64
    - 64
    patch_size: 4
    num_channels: 3
    num_hidden_layers: 4
    num_attention_heads: 8
    input_padding:
    - 0
    - 0
    output_hidden_state: false

noise_scheduler:
  target: diffusers.EDMDPMSolverMultistepScheduler
  params:
    num_train_timesteps: 1000

loss_fn:
  target: losses.loss.EDMLoss
  params:
    sigma_data: 0.5

optimizer:
  betas:
  - 0.9
  - 0.999
  eps: 1e-08
  lr: 1e-4
  weight_decay: 1e-2

lr_scheduler:
  #name: cosine
  name: constant
  num_warmup_steps: 500
  num_cycles: 0.5
  power: 1.0

dataloader:
  data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy
  batch_size: 16
  num_workers: 0
  split_ratios:
  - 0.8
  - 0.2
  - 0.0
  transform: normalize
  transform_args:
    mean: [0., 0., 1.0055307]
    std: [0.0089056, 0.0089056, 0.0137398]
    target_std: 0.5
  data_name: shallow_water

accelerator:
  mixed_precision: fp16
  gradient_accumulation_steps: 1
  log_with: tensorboard

ema:
  use_ema: True
  offload_ema: False
  ema_max_decay: 0.9999
  ema_inv_gamma: 1.0
  ema_power: 0.75
  foreach: True

general:
  seed: 42
  num_epochs: null
  num_training_steps: 100000
  known_channels: [0, 1, 2]
  same_mask: True
  scale_lr: False
  output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_cfg
  logging_dir: sw
  tracker_project_name: sw_tracker
  save_image_epochs: 50
  save_model_epochs: 200
  checkpointing_steps: 25000
  eval_batch_size: 8
  cond_drop_prob: null
  do_edm_style_training: True
  snr_gamma: null
  channel_names: ["u", "v", "h"]
--------------------------------------------------------------------------------
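Aside: the `ema` block's hyperparameters typically combine into a warmup schedule of the kind used by diffusers-style EMA helpers (assumed, not verified against this repo's training loop):

```python
# decay(step) = min(ema_max_decay, 1 - (1 + step / ema_inv_gamma) ** -ema_power)
ema_inv_gamma, ema_power, ema_max_decay = 1.0, 0.75, 0.9999

def ema_decay(step: int) -> float:
    """EMA decay ramps up from ~0 so the average tracks the model quickly at
    first, then saturates at ema_max_decay for a stable late-training average."""
    return min(ema_max_decay, 1 - (1 + step / ema_inv_gamma) ** -ema_power)

print(ema_decay(10))       # ~0.834 early in training
print(ema_decay(100_000))  # 0.9999, capped by ema_max_decay
```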
1.0 90 | ema_power: 0.75 91 | foreach: True 92 | 93 | general: 94 | seed: 42 95 | num_epochs: null 96 | num_training_steps: 100000 97 | known_channels: [0, 1, 2] 98 | same_mask: True 99 | scale_lr: False 100 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_uncond 101 | logging_dir: sw 102 | tracker_project_name: sw_tracker 103 | save_image_epochs: 50 104 | save_model_epochs: 200 105 | checkpointing_steps: 25000 106 | eval_batch_size: 8 107 | cond_drop_prob: null 108 | do_edm_style_training: True 109 | snr_gamma: null 110 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- /configs/shallow_water_vt_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | center_input_sample: false 13 | class_embed_type: null 14 | down_block_types: 15 | - DownBlock2D 16 | - DownBlock2D 17 | - DownBlock2D 18 | downsample_padding: 1 19 | downsample_type: conv 20 | dropout: 0.0 21 | flip_sin_to_cos: true 22 | freq_shift: 0 23 | in_channels: 3 24 | layers_per_block: 2 25 | mid_block_scale_factor: 1 26 | norm_eps: 1e-05 27 | norm_num_groups: 32 28 | num_class_embeds: null 29 | num_train_timesteps: null 30 | out_channels: 3 31 | sample_size: 64 32 | time_embedding_type: positional 33 | up_block_types: 34 | - UpBlock2D 35 | - UpBlock2D 36 | - UpBlock2D 37 | upsample_type: conv 38 | 39 | noise_scheduler: 40 | target: diffusers.EDMDPMSolverMultistepScheduler 41 | params: 42 | num_train_timesteps: 1000 43 | 44 | loss_fn: 45 | target: losses.loss.EDMLoss 46 | params: 47 | sigma_data: 0.5 48 | 49 | optimizer: 50 | betas: 51 | - 0.9 52 | - 0.999 53 | eps: 1e-08 54 | lr: 1e-4 55 | weight_decay: 1e-2 56 | 57 | lr_scheduler: 58 | #name: cosine 59 | name: constant 60 | num_warmup_steps: 500 61 | num_cycles: 0.5 62 | power: 1.0 63 | 64 | dataloader: 65 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 66 | batch_size: 16 67 | num_workers: 0 68 | split_ratios: 69 | - 0.8 70 | - 0.2 71 | - 0.0 72 | transform: normalize 73 | transform_args: 74 | mean: [0., 0., 1.0055307] 75 | std: [0.0089056, 0.0089056, 0.0137398] 76 | target_std: 0.5 77 | data_name: shallow_water 78 | 79 | accelerator: 80 | mixed_precision: fp16 81 | gradient_accumulation_steps: 1 82 | log_with: tensorboard 83 | 84 | ema: 85 | use_ema: True 86 | offload_ema: False 87 | ema_max_decay: 0.9999 88 | ema_inv_gamma: 1.0 89 | ema_power: 0.75 90 | foreach: True 91 | 92 | general: 93 | seed: 42 94 | num_epochs: null 95 | num_training_steps: 100000 96 | known_channels: [0, 1, 2] 97 | same_mask: True 98 | scale_lr: False 99 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_vt 100 | logging_dir: sw 101 | tracker_project_name: sw_tracker 102 | save_image_epochs: 50 103 | save_model_epochs: 200 104 | checkpointing_steps: 25000 105 | eval_batch_size: 8 106 | do_edm_style_training: True 107 | snr_gamma: null 108 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- /configs/shallow_water_xattn_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | 
addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | center_input_sample: false 15 | class_embed_type: null 16 | class_embeddings_concat: false 17 | conv_in_kernel: 3 18 | conv_out_kernel: 3 19 | cross_attention_dim: 384 20 | cross_attention_norm: null 21 | down_block_types: 22 | - DownBlock2D 23 | - CrossAttnDownBlock2D 24 | - CrossAttnDownBlock2D 25 | downsample_padding: 1 26 | dropout: 0.0 27 | dual_cross_attention: false 28 | encoder_hid_dim: null 29 | encoder_hid_dim_type: null 30 | flip_sin_to_cos: true 31 | freq_shift: 0 32 | in_channels: 3 33 | layers_per_block: 2 34 | mid_block_only_cross_attention: null 35 | mid_block_scale_factor: 1 36 | mid_block_type: UNetMidBlock2DCrossAttn 37 | norm_eps: 1e-05 38 | norm_num_groups: 32 39 | num_attention_heads: null 40 | num_class_embeds: null 41 | only_cross_attention: false 42 | out_channels: 3 43 | projection_class_embeddings_input_dim: null 44 | resnet_out_scale_factor: 1.0 45 | resnet_skip_time_act: false 46 | resnet_time_scale_shift: scale_shift 47 | reverse_transformer_layers_per_block: null 48 | sample_size: 64 49 | time_cond_proj_dim: null 50 | time_embedding_act_fn: null 51 | time_embedding_dim: null 52 | time_embedding_type: positional 53 | timestep_post_act: null 54 | transformer_layers_per_block: 1 55 | up_block_types: 56 | - CrossAttnUpBlock2D 57 | - CrossAttnUpBlock2D 58 | - UpBlock2D 59 | upcast_attention: false 60 | use_linear_projection: false 61 | field_encoder_dict: 62 | hidden_size: 384 63 | intermediate_size: 1536 64 | projection_dim: 384 65 | image_size: 66 | - 64 67 | - 64 68 | patch_size: 4 69 | num_channels: 3 70 | num_hidden_layers: 4 71 | num_attention_heads: 8 72 | input_padding: 73 | - 0 74 | - 0 75 | output_hidden_state: true 76 | 77 | 78 | noise_scheduler: 79 | target: diffusers.EDMDPMSolverMultistepScheduler 80 | params: 81 | num_train_timesteps: 1000 82 | 83 | loss_fn: 84 | target: losses.loss.EDMLoss 85 | params: 86 | sigma_data: 0.5 87 | 88 | optimizer: 89 | betas: 90 | - 0.9 91 | - 0.999 92 | eps: 1e-08 93 | lr: 1e-4 94 | weight_decay: 1e-2 95 | 96 | lr_scheduler: 97 | #name: cosine 98 | name: constant 99 | num_warmup_steps: 500 100 | num_cycles: 0.5 101 | power: 1.0 102 | 103 | dataloader: 104 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 105 | batch_size: 16 106 | num_workers: 0 107 | split_ratios: 108 | - 0.8 109 | - 0.2 110 | - 0.0 111 | transform: normalize 112 | transform_args: 113 | mean: [0., 0., 1.0055307] 114 | std: [0.0089056, 0.0089056, 0.0137398] 115 | target_std: 0.5 116 | data_name: shallow_water 117 | 118 | accelerator: 119 | mixed_precision: fp16 120 | gradient_accumulation_steps: 1 121 | log_with: tensorboard 122 | 123 | ema: 124 | use_ema: True 125 | offload_ema: False 126 | ema_max_decay: 0.9999 127 | ema_inv_gamma: 1.0 128 | ema_power: 0.75 129 | foreach: True 130 | 131 | general: 132 | seed: 42 133 | num_epochs: null 134 | num_training_steps: 100000 135 | known_channels: [0, 1, 2] 136 | same_mask: True 137 | scale_lr: False 138 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_xattn 139 | logging_dir: sw 140 | tracker_project_name: sw_tracker 141 | save_image_epochs: 50 142 | save_model_epochs: 200 143 | checkpointing_steps: 25000 144 | eval_batch_size: 8 145 | cond_drop_prob: null 146 | do_edm_style_training: True 147 | snr_gamma: null 148 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- 
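All of the PDE configs above share the same skeleton: a `unet` block holding the diffusers model config, plus `noise_scheduler` and `loss_fn` blocks in a `target:`/`params:` convention that the training and evaluation scripts resolve through `instantiate_from_config` in `utils/general_utils.py`. As a rough sketch of that convention (this stand-in helper is illustrative, not the repo's exact implementation):

```python
# Illustrative stand-in for utils.general_utils.instantiate_from_config:
# import the dotted path in `target` and construct it with `params`.
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(cfg):
    module_path, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg.get("params", {}))

config = OmegaConf.load("configs/shallow_water_uncond_config.yaml")
noise_scheduler = instantiate_from_config(config.noise_scheduler)  # diffusers.EDMDPMSolverMultistepScheduler
loss_fn = instantiate_from_config(config.loss_fn)                  # losses.loss.EDMLoss(sigma_data=0.5)
```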
/dataloader/dataset_class.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import random_split
4 | from einops import rearrange
5 | from torch.utils.data import Dataset, DataLoader, Subset
6 | import xarray as xr
7 | 
8 | import sys
9 | import os
10 | 
11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
12 | sys.path.append(parent_dir)
13 | from utils.general_utils import read_hdf5_to_numpy
14 | 
15 | def normalize_transform(sample, mean, std, target_std=1):
16 |     # (C, H, W)
17 |     mean = torch.tensor(mean, device=sample.device)
18 |     std = torch.tensor(std, device=sample.device)
19 |     return ((sample - mean[:, None, None]) / std[:, None, None]) * target_std
20 | 
21 | def inverse_normalize_transform(normalized_sample, mean, std, target_std=1):
22 |     # (C, H, W)
23 |     mean = torch.tensor(mean, device=normalized_sample.device)
24 |     std = torch.tensor(std, device=normalized_sample.device)
25 |     return (normalized_sample / target_std) * std[:, None, None] + mean[:, None, None]
26 | 
27 | class FullDataset(Dataset):
28 |     def __init__(self, data, transform=None, transform_args=None):
29 |         """
30 |         Args:
31 |             data (numpy.ndarray): Full dataset.
32 |             transform (callable, optional): Optional transform to be applied on a sample.
33 |         """
34 |         self.data = data
35 |         self.transform = transform
36 |         if transform_args is None:
37 |             self.transform_args = {}
38 |         else:
39 |             self.transform_args = transform_args
40 | 
41 |     def __len__(self):
42 |         return len(self.data)
43 | 
44 |     def __getitem__(self, idx):
45 |         sample = self.data[idx]
46 |         if self.transform:
47 |             sample = self.transform(sample, **self.transform_args)
48 |         return sample
49 | 
50 | class XarrayDataset2D(Dataset):
51 |     def __init__(
52 |         self,
53 |         data: xr.Dataset,  # (n, t, x, y) preferable, but must have dims named 'n' and 't'
54 |         transform: str = None,
55 |         transform_args: dict = None,
56 |         load_in_memory: bool = False,  # TODO: figure out how to save only one copy in memory
57 |     ):
58 |         if not load_in_memory:
59 |             self.data = data
60 |         else:
61 |             self.data = data.load()
62 |         self.transform = self._get_transform(transform, transform_args)
63 |         self.transform_args = transform_args or {}
64 |         self.var_names_list = list(data.data_vars.keys())
65 |         self.dims_dict = dict(data.dims)
66 |         assert 'n' in self.dims_dict, 'Dataset must have dimension named "n".'
67 |         assert 't' in self.dims_dict, 'Dataset must have dimension named "t".'
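        # The (n, t) dims are flattened into one sample index of length n*t; _idx2nt maps it back via divmod.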
68 | self.length = self.dims_dict['n'] * self.dims_dict['t'] 69 | 70 | def _get_transform(self, transform, transform_args): 71 | if transform == 'normalize': 72 | mean = transform_args['mean'] 73 | std = transform_args['std'] 74 | if 'target_std' in transform_args: 75 | target_std = transform_args['target_std'] 76 | return lambda x: normalize_transform(x, mean, std, target_std) 77 | return lambda x: normalize_transform(x, mean, std) 78 | elif transform is None: 79 | return lambda x: x 80 | else: 81 | raise NotImplementedError(f'Transform: {transform} not implemented.') 82 | 83 | def _preprocess_data(self, data): 84 | data = torch.from_numpy(data).float() 85 | return self.transform(data) 86 | 87 | def _idx2nt(self, idx): 88 | return divmod(idx, self.dims_dict['t']) 89 | 90 | def get_array_from_xrdataset_2D(self, n_idx, t_idx): 91 | sliced_ds = self.data.isel({'n': n_idx, 't': t_idx}) 92 | # xr.to_array() is slower 93 | return np.stack([sliced_ds[var].values for var in self.var_names_list], axis=0) 94 | 95 | def __len__(self): 96 | return self.length 97 | 98 | def __getitem__(self, idx): 99 | assert idx >= 0, f'Index must be non-negative, got: {idx}.' 100 | n_idx, t_idx = self._idx2nt(idx) 101 | return self._preprocess_data(self.get_array_from_xrdataset_2D(n_idx, t_idx)) 102 | 103 | def npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=(0.7, 0.2, 0.1), transform=None, transform_args=None, 104 | rearrange_args=None, random_dataset=True, generator=None, return_dataset=False): 105 | if rearrange_args is not None: 106 | full_dataset = rearrange(full_dataset, rearrange_args) 107 | full_dataset = torch.tensor(full_dataset, dtype=torch.float32) 108 | 109 | train_size = int(split_ratios[0] * len(full_dataset)) 110 | val_size = int(split_ratios[1] * len(full_dataset)) 111 | test_size = len(full_dataset) - train_size - val_size 112 | 113 | if random_dataset: 114 | print('Randomly splitting dataset.') 115 | train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size], generator=generator) 116 | else: 117 | print('Splitting dataset by length.') 118 | train_dataset = Subset(full_dataset, list(range(0, train_size))) 119 | val_dataset = Subset(full_dataset, list(range(train_size, train_size+val_size))) 120 | test_dataset = Subset(full_dataset, list(range(train_size+val_size, len(full_dataset)))) 121 | 122 | if transform is not None: 123 | if transform == 'normalize': 124 | mean = transform_args['mean'] 125 | std = transform_args['std'] 126 | transform = normalize_transform 127 | if 'target_std' in transform_args: 128 | target_std = transform_args['target_std'] 129 | transform_args = {'mean': mean, 'std': std, 'target_std': target_std} 130 | else: 131 | transform_args = {'mean': mean, 'std': std} 132 | else: 133 | raise NotImplementedError(f'Transform: {transform} not implemented.') 134 | 135 | if not isinstance(full_dataset.data, xr.Dataset): 136 | # Apply the same transform to all splits if needed 137 | train_dataset = FullDataset(train_dataset, transform=transform, transform_args=transform_args) 138 | val_dataset = FullDataset(val_dataset, transform=transform, transform_args=transform_args) 139 | test_dataset = FullDataset(test_dataset, transform=transform, transform_args=transform_args) 140 | 141 | if return_dataset: 142 | return train_dataset, val_dataset, test_dataset 143 | 144 | # Create dataloaders 145 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=generator) 146 | 
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 147 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 148 | 149 | return train_loader, val_loader, test_loader 150 | 151 | def dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=(0.7, 0.2, 0.1), random_dataset=True, generator=None, return_dataset=False): 152 | train_size = int(split_ratios[0] * len(full_dataset)) 153 | val_size = int(split_ratios[1] * len(full_dataset)) 154 | test_size = len(full_dataset) - train_size - val_size 155 | 156 | if random_dataset: 157 | print('Randomly splitting dataset.') 158 | train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size], generator=generator) 159 | else: 160 | print('Splitting dataset by length.') 161 | train_dataset = Subset(full_dataset, list(range(0, train_size))) 162 | val_dataset = Subset(full_dataset, list(range(train_size, train_size+val_size))) 163 | test_dataset = Subset(full_dataset, list(range(train_size+val_size, len(full_dataset)))) 164 | 165 | if return_dataset: 166 | return train_dataset, val_dataset, test_dataset 167 | 168 | # Create dataloaders 169 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=generator) 170 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 171 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 172 | 173 | return train_loader, val_loader, test_loader 174 | 175 | def pdedata2dataloader(data_dir, batch_size=32, num_workers=1, split_ratios=(0.7, 0.2, 0.1), transform=None, transform_args=None, 176 | rearrange_args=None, generator=None, data_name=None, return_dataset=False, load_in_memory=False): 177 | 178 | if data_name == 'darcy': 179 | full_dataset = np.load(data_dir) 180 | return npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, transform=transform, transform_args=transform_args, 181 | rearrange_args=rearrange_args, random_dataset=False, generator=generator, return_dataset=return_dataset) 182 | elif data_name == 'shallow_water': 183 | full_dataset = np.load(data_dir) 184 | return npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, transform=transform, transform_args=transform_args, 185 | rearrange_args=rearrange_args, random_dataset=False, generator=generator, return_dataset=return_dataset) 186 | elif data_name == 'compressible_NS': 187 | # keys = ['Vx' , 'Vy', 'density', 'pressure'] 188 | xarray = xr.open_dataset(data_dir, phony_dims='access', engine='h5netcdf', 189 | drop_variables=['t-coordinate', 'x-coordinate', 'y-coordinate'], 190 | chunks='auto') 191 | xarray= xarray.rename({ 192 | 'phony_dim_0': 'n', 193 | 'phony_dim_1': 't', 194 | 'phony_dim_2': 'x', 195 | 'phony_dim_3': 'y' 196 | }) 197 | full_dataset = XarrayDataset2D(xarray, transform=transform, transform_args=transform_args, load_in_memory=load_in_memory) 198 | return dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, random_dataset=False, generator=generator, return_dataset=return_dataset) 199 | elif data_name == 'diffusion_reaction': 200 | # keys = ['u', 'v'] 201 | xarray = xr.open_dataset(data_dir, phony_dims='access', engine='h5netcdf', 202 | drop_variables=['t-coordinate', 
'x-coordinate', 'y-coordinate'], 203 | chunks='auto') 204 | xarray= xarray.rename({ 205 | 'phony_dim_0': 't', 206 | 'phony_dim_1': 'n', 207 | 'phony_dim_2': 'x', 208 | 'phony_dim_3': 'y' 209 | }) 210 | full_dataset = XarrayDataset2D(xarray, transform=transform, transform_args=transform_args, load_in_memory=load_in_memory) 211 | return dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, random_dataset=False, generator=generator, return_dataset=return_dataset) 212 | else: 213 | raise NotImplementedError(f'Dataset: {data_name} not implemented.') -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | from accelerate import Accelerator 4 | from accelerate.utils import set_seed 5 | from accelerate.logging import get_logger 6 | from diffusers import UNet2DModel 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | from dataloader.dataset_class import pdedata2dataloader 11 | import os 12 | import copy 13 | from omegaconf import OmegaConf 14 | 15 | from pipelines.pipeline_inv_prob import InverseProblem2DPipeline, InverseProblem2DCondPipeline 16 | from models.unet2D import diffuserUNet2D 17 | from models.unet2DCondition import diffuserUNet2DCondition, diffuserUNet2DCFG 18 | from utils.general_utils import instantiate_from_config 19 | from utils.vt_utils import vt_obs 20 | from losses.metric import get_metrics_2D 21 | 22 | 23 | logger = get_logger(__name__, log_level="INFO") 24 | 25 | def parse_list_int(value): 26 | try: 27 | # Try to split by commas and convert to a list of integers 28 | steps = [int(x) for x in value.split(',')] 29 | return steps 30 | except ValueError: 31 | # If it fails, assume it's a single integer 32 | return [int(value)] 33 | 34 | def parse_list_float(value): 35 | try: 36 | # Try to split by commas and convert to a list of floats 37 | ratios = [float(x) for x in value.split(',')] 38 | return ratios 39 | except ValueError: 40 | # If it fails, assume it's a single float 41 | return [float(value)] 42 | 43 | def parse_args(): 44 | parser = argparse.ArgumentParser(description="Evaluate trained Diffusers model.") 45 | parser.add_argument('--config', type=str, required=True, help="Path to the YAML configuration file.") 46 | parser.add_argument('--repo_name', type=str, help="Repository name.") 47 | parser.add_argument('--subfolder', type=str, help="Subfolder in the repository.") 48 | parser.add_argument('--path_to_ckpt', type=str, help="Path to the checkpoint.") 49 | parser.add_argument( 50 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
51 |     )
52 |     parser.add_argument('--path_to_csv', type=str, required=True, help="Path to the CSV file for saving the results; if none exists, a new one will be created.")
53 | 
54 |     # eval
55 |     parser.add_argument('--batch_size', type=int, default=32, help="Batch size for evaluation.")
56 |     parser.add_argument('--num_inference_steps', type=parse_list_int, help='Number of inference steps, can be a single integer or a comma-separated list of integers.')
57 |     parser.add_argument('--mask_ratios', type=parse_list_float, help='Mask ratios, can be a single float or a comma-separated list of floats.')
58 |     parser.add_argument('--vt_spacing', type=parse_list_int, help="Spacing for dividing the domain.")
59 |     parser.add_argument('--mode', type=str, required=True, help="Mode for evaluation, 'edm', 'pipeline', 'vt'.")
60 |     parser.add_argument('--conditioning_type', type=str, default='xattn', help="Conditioning type for evaluation.")
61 |     parser.add_argument('--ensemble_size', type=int, default=25, help="Ensemble size for evaluation.")
62 |     parser.add_argument('--channel_mean', action='store_true', help="Whether to output the channel mean.")
63 |     parser.add_argument('--structure_sampling', action='store_true', default=False, help="Whether to sample with vt.")
64 |     parser.add_argument('--noise_level', type=float, default=0., help="Noise level for evaluation.")
65 |     parser.add_argument('--noise_type', type=str, default='white', help="Type of noise to add. Options: 'white', 'pink', 'red', 'blue', 'purple'.")
66 |     parser.add_argument('--verbose', action='store_true', help="Whether to print verbose information.")
67 |     parser.add_argument('--total_eval', type=int, default=1000, help="Total number of evaluation samples.")
68 | 
69 |     return parser.parse_args()
70 | 
71 | def main(args):
72 |     print(f"Received args: {args}")
73 |     if args.config is None:
74 |         raise ValueError("The config argument is missing.
Please provide the path to the configuration file using --config.") 75 | config = OmegaConf.load(args.config) 76 | 77 | accelerator = Accelerator() 78 | 79 | use_vt = False 80 | if 'vt' in args.config: 81 | use_vt = True 82 | assert args.mode == 'vt', 'Mode should be vt when using vt' 83 | model_cls = diffuserUNet2D 84 | logger.info('Using vt model') 85 | else: 86 | if args.conditioning_type == 'cfg': 87 | model_cls = diffuserUNet2DCFG 88 | logger.info('Using cfg model') 89 | elif args.conditioning_type == 'xattn': 90 | model_cls = diffuserUNet2DCondition 91 | logger.info('Using xattn model') 92 | elif args.conditioning_type == 'uncond': 93 | model_cls = UNet2DModel 94 | logger.info('Using uncond model') 95 | else: 96 | raise NotImplementedError 97 | 98 | 99 | noise_scheduler_config = config.pop("noise_scheduler", OmegaConf.create()) 100 | dataloader_config = config.pop("dataloader", OmegaConf.create()) 101 | general_config = config.pop("general", OmegaConf.create()) 102 | 103 | set_seed(general_config.seed) 104 | 105 | repo_name = args.repo_name 106 | subfolder = args.subfolder 107 | 108 | if args.path_to_ckpt is None: 109 | unet = model_cls.from_pretrained(repo_name, 110 | subfolder=subfolder, 111 | use_safetensors=True, 112 | ) 113 | else: 114 | unet = model_cls.from_pretrained(args.path_to_ckpt, 115 | use_safetensors=True, 116 | ) 117 | 118 | noise_scheduler = instantiate_from_config(noise_scheduler_config) 119 | 120 | generator = torch.Generator(device=accelerator.device).manual_seed(general_config.seed) 121 | _, val_dataset, _ = pdedata2dataloader(**dataloader_config, generator=generator, 122 | return_dataset=True) 123 | 124 | #select_idx = np.random.choice(len(val_dataset), args.total_eval, replace=False) 125 | select_idx = np.arange(0, args.total_eval) 126 | reduced_val_dataset = torch.utils.data.Subset(val_dataset, select_idx) 127 | 128 | unet = accelerator.prepare_model(unet, evaluation_mode=True) 129 | resolution = unet.config.sample_size 130 | 131 | if args.conditioning_type == 'xattn' or args.conditioning_type == 'cfg': 132 | pipeline = InverseProblem2DCondPipeline(unet, scheduler=copy.deepcopy(noise_scheduler)) 133 | elif args.conditioning_type == 'uncond': 134 | pipeline = InverseProblem2DPipeline(unet, scheduler=copy.deepcopy(noise_scheduler)) 135 | else: 136 | raise NotImplementedError 137 | 138 | for num_inference_steps, mask_ratios, vt_spacing in itertools.product(args.num_inference_steps, args.mask_ratios, args.vt_spacing): 139 | 140 | sampler_kwargs = { 141 | "num_inference_steps": num_inference_steps, 142 | "known_channels": general_config.known_channels, 143 | #"same_mask": general_config.same_mask, 144 | } 145 | x_idx, y_idx = torch.meshgrid(torch.arange(vt_spacing, resolution, vt_spacing), torch.arange(vt_spacing, resolution, vt_spacing)) 146 | x_idx = x_idx.flatten().to(accelerator.device) 147 | y_idx = y_idx.flatten().to(accelerator.device) 148 | if "darcy" in args.subfolder: 149 | mask_kwargs = { 150 | "x_idx": x_idx, 151 | "y_idx": y_idx, 152 | "channels": general_config.known_channels, 153 | } 154 | else: 155 | mask_kwargs = { 156 | "ratio": mask_ratios, 157 | "channels": general_config.known_channels, 158 | } 159 | 160 | tmp_dim = unet.config.sample_size 161 | vt = vt_obs(x_dim=tmp_dim, y_dim=tmp_dim, x_spacing=vt_spacing, y_spacing=vt_spacing, known_channels=general_config.known_channels, device=accelerator.device) 162 | if not 'ratio' in mask_kwargs: 163 | if 'x_idx' in mask_kwargs: 164 | print('Total number of observation points: ', 
mask_kwargs["x_idx"].shape[0], ' Perceage of known points: ', mask_kwargs["x_idx"].shape[0] / tmp_dim**2) 165 | else: 166 | print('Total number of observation points: ', vt.x_start_grid.numel(), ' Perceage of known points: ', vt.x_start_grid.numel() / tmp_dim**2) 167 | else: 168 | print('Total number of observation points: ', mask_kwargs['ratio']*tmp_dim**2, ' Perceage of known points: ', mask_kwargs['ratio']) 169 | 170 | err_RMSE, err_nRMSE, err_CSV = get_metrics_2D(reduced_val_dataset, pipeline=pipeline, 171 | vt = vt, 172 | vt_model = unet if use_vt else None, 173 | batch_size = args.batch_size, 174 | ensemble_size = args.ensemble_size, 175 | sampler_kwargs = sampler_kwargs, 176 | mask_kwargs = mask_kwargs, 177 | known_channels=general_config.known_channels, 178 | device = accelerator.device, 179 | mode = args.mode, #'edm', 'pipeline', 'vt' 180 | conditioning_type = args.conditioning_type, 181 | inverse_transform = dataloader_config.transform, # 'normalize' 182 | inverse_transform_args = dataloader_config.transform_args, 183 | channel_mean = args.channel_mean, 184 | structure_sampling = args.structure_sampling, 185 | noise_level = args.noise_level, 186 | noise_type = args.noise_type, 187 | verbose = args.verbose, 188 | ) 189 | 190 | csv_filename = args.path_to_csv 191 | 192 | if args.mode != 'mean': 193 | if args.structure_sampling: 194 | index_value = f"{args.config.split('/')[-1].split('.')[0]}_spacing_{str(vt_spacing)}_mode_{args.mode}_step_{str(num_inference_steps)}" 195 | else: 196 | index_value = f"{args.config.split('/')[-1].split('.')[0]}_ratio_{str(mask_kwargs['ratio'])}_mode_{args.mode}_step_{str(num_inference_steps)}" 197 | else: 198 | index_value = f"{dataloader_config.data_name}_mean" 199 | 200 | if os.path.exists(csv_filename): 201 | df_existing = pd.read_csv(csv_filename, index_col=0) 202 | else: 203 | df_existing = pd.DataFrame(columns=['RMSE', 'nRMSE', 'CSV']) 204 | df_existing.index.name = 'Index' 205 | 206 | if "darcy" in args.subfolder: 207 | # For darcy we only save the permeability field 208 | df_new = pd.DataFrame({'RMSE': [err_RMSE.cpu().numpy()[0]], 209 | 'nRMSE': [err_nRMSE.cpu().numpy()[0]], 210 | 'CSV': [err_CSV.cpu().numpy()[0]]}, 211 | index=[index_value]) 212 | else: 213 | df_new = pd.DataFrame({'RMSE': [err_RMSE.cpu().numpy()], 214 | 'nRMSE': [err_nRMSE.cpu().numpy()], 215 | 'CSV': [err_CSV.cpu().numpy()]}, 216 | index=[index_value]) 217 | 218 | df_combined = pd.concat([df_existing, df_new]) 219 | df_combined.to_csv(csv_filename) 220 | 221 | logger.info(f'RMSE: {err_RMSE.cpu().numpy()}, nRMSE: {err_nRMSE.cpu().numpy()}, CSV: {err_CSV.cpu().numpy()}') 222 | 223 | if __name__ == "__main__": 224 | args = parse_args() 225 | main(args) -------------------------------------------------------------------------------- /git_assest/bar_chart_0.01_largefont.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/bar_chart_0.01_largefont.png -------------------------------------------------------------------------------- /git_assest/darcy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/darcy.png -------------------------------------------------------------------------------- /git_assest/dr_hf.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/dr_hf.png
--------------------------------------------------------------------------------
/git_assest/encoding_block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/encoding_block.png
--------------------------------------------------------------------------------
/git_assest/error_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/error_hist.png
--------------------------------------------------------------------------------
/losses/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class LpLoss(object):
4 |     # Loss function with rel/abs Lp loss, modified from neuralop:
5 |     # https://github.com/neuraloperator/neuraloperator/blob/main/neuralop/losses/data_losses.py
6 |     '''
7 |     LpLoss: Lp loss function, returns the relative loss by default
8 |     Args:
9 |         d: int, start dimension of the field. E.g., for shape like (b, c, h, w), d=2 (default 1)
10 |         p: int, p in Lp norm, default 2
11 |         reduce_dims: int or list of int, dimensions to reduce
12 |         reductions: str or list of str, 'sum' or 'mean'
13 | 
14 |     Call: (y_pred, y)
15 |     '''
16 |     def __init__(self, d=1, p=2, reduce_dims=0, reductions='sum'):
17 |         super().__init__()
18 | 
19 |         self.d = d
20 |         self.p = p
21 | 
22 |         if isinstance(reduce_dims, int):
23 |             self.reduce_dims = [reduce_dims]
24 |         else:
25 |             self.reduce_dims = reduce_dims
26 | 
27 |         if self.reduce_dims is not None:
28 |             if isinstance(reductions, str):
29 |                 assert reductions == 'sum' or reductions == 'mean'
30 |                 self.reductions = [reductions]*len(self.reduce_dims)
31 |             else:
32 |                 for j in range(len(reductions)):
33 |                     assert reductions[j] == 'sum' or reductions[j] == 'mean'
34 |                 self.reductions = reductions
35 | 
36 |     def reduce_all(self, x):
37 |         for j in range(len(self.reduce_dims)):
38 |             if self.reductions[j] == 'sum':
39 |                 x = torch.sum(x, dim=self.reduce_dims[j], keepdim=True)
40 |             else:
41 |                 x = torch.mean(x, dim=self.reduce_dims[j], keepdim=True)
42 | 
43 |         return x
44 | 
45 |     def abs(self, x, y):
46 |         diff = torch.norm(torch.flatten(x, start_dim=-self.d) - torch.flatten(y, start_dim=-self.d), \
47 |                           p=self.p, dim=-1, keepdim=False)
48 | 
49 |         if self.reduce_dims is not None:
50 |             diff = self.reduce_all(diff).squeeze()
51 | 
52 |         return diff
53 | 
54 |     def rel(self, x, y):
55 | 
56 |         diff = torch.norm(torch.flatten(x, start_dim=-self.d) - torch.flatten(y, start_dim=-self.d), \
57 |                           p=self.p, dim=-1, keepdim=False)
58 |         ynorm = torch.norm(torch.flatten(y, start_dim=-self.d), p=self.p, dim=-1, keepdim=False)
59 | 
60 |         diff = diff/ynorm
61 | 
62 |         if self.reduce_dims is not None:
63 |             diff = self.reduce_all(diff).squeeze()
64 | 
65 |         return diff
66 | 
67 |     def __call__(self, y_pred, y, **kwargs):
68 |         return self.rel(y_pred, y)
69 | 
70 | class EDMLoss:
71 |     def __init__(self, sigma_data=0.5):
72 |         self.sigma_data = sigma_data
73 | 
74 |     def __call__(self, y_pred, y, sigma, **kwargs):
75 |         # (comment from diffusers) We are not doing weighting here because it tends to result in numerical problems.
76 | # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 77 | # There might be other alternatives for weighting as well: 78 | # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 79 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 80 | loss = weight * ((y - y_pred) ** 2) 81 | return loss.mean() 82 | 83 | class EDMLoss_reg: 84 | def __init__(self, sigma_data=0.5, reg_weight=0.001): 85 | self.sigma_data = sigma_data 86 | self.reg_weight = reg_weight 87 | 88 | def __call__(self, y_pred, y, sigma, **kwargs): 89 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 90 | squared_err = ((y - y_pred) ** 2) 91 | csv_reg = (torch.sum(y_pred, dim=[-2, -1], keepdim=True) - torch.sum(y, dim=[-2, -1], keepdim=True))**2 92 | loss = weight * squared_err + self.reg_weight * csv_reg 93 | return loss.mean() -------------------------------------------------------------------------------- /losses/metric.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import torch 5 | import math as mt 6 | from tqdm.auto import tqdm 7 | from torch.utils.data import DataLoader 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from einops import rearrange, repeat 10 | 11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 12 | sys.path.append(parent_dir) 13 | from utils.inverse_utils import create_scatter_mask, ensemble_sample, colored_noise 14 | from dataloader.dataset_class import inverse_normalize_transform 15 | 16 | @torch.no_grad() 17 | def metric_func_2D(y_pred, y_true, mask=None, channel_mean=True): 18 | # Adapted from: https://github.com/pdebench/PDEBench/blob/main/pdebench/models/metrics.py 19 | # (B, C, H, W) 20 | ''' 21 | y_pred: torch.Tensor, shape (B, C, H, W) 22 | y_true: torch.Tensor, shape (B, C, H, W) 23 | mask: torch.Tensor, shape (B, H, W), optional 24 | 25 | Returns: err_RMSE, err_nRMSE, err_CSV 26 | ''' 27 | if mask is not None: 28 | unknown_mask = (1 - mask).float() 29 | num_unknowns = torch.sum(unknown_mask, dim=[-2, -1], keepdim=True) 30 | assert torch.all(num_unknowns) > 0, "All values are known. Cannot compute with mask." 
31 |         mse = torch.sum(((y_pred - y_true)**2) * unknown_mask, dim=[-2, -1], keepdim=True) / num_unknowns
32 |         mse = rearrange(mse, 'b c 1 1 -> b c')
33 |         nrm = torch.sqrt(torch.sum((y_true**2) * unknown_mask, dim=[-2, -1], keepdim=True) / num_unknowns)
34 |         nrm = rearrange(nrm, 'b c 1 1 -> b c')
35 |         csv_error = torch.mean((torch.sum(y_pred * unknown_mask, dim=[-2, -1]) - torch.sum(y_true * unknown_mask, dim=[-2, -1]))**2, dim=0)
36 |     else:
37 |         mse = torch.mean((y_pred - y_true)**2, dim=[-2, -1])
38 |         nrm = torch.sqrt(torch.mean(y_true**2, dim=[-2, -1]))
39 |         csv_error = torch.mean((torch.sum(y_pred, dim=[-2, -1]) - torch.sum(y_true, dim=[-2, -1]))**2, dim=0)
40 | 
41 |     err_mean = torch.sqrt(mse) # (B, C)
42 |     err_RMSE = torch.mean(err_mean, axis=0) # -> (C)
43 |     err_nRMSE = torch.mean(err_mean / nrm, dim=0) # -> (C)
44 |     err_CSV = torch.sqrt(csv_error) # (C)
45 |     if mask is not None:
46 |         err_CSV /= torch.mean(torch.sum(unknown_mask, dim=[-2, -1]), dim=0)
47 |     else:
48 |         err_CSV /= (y_true.shape[-2] * y_true.shape[-1])
49 | 
50 |     # Channel mean
51 |     if channel_mean:
52 |         err_RMSE = torch.mean(err_RMSE, axis=0)
53 |         err_nRMSE = torch.mean(err_nRMSE, axis=0)
54 |         err_CSV = torch.mean(err_CSV, axis=0)
55 | 
56 |     return err_RMSE, err_nRMSE, err_CSV
57 | 
58 | @torch.no_grad()
59 | def get_metrics_2D(val_dataset, pipeline=None, vt=None, vt_model=None, batch_size=64, ensemble_size=25,
60 |                    sampler_kwargs=None, mask_kwargs=None, known_channels=None, device='cpu', mode='edm', #'edm', 'pipeline', 'vt', 'mean'
61 |                    conditioning_type = None, #xattn, cfg
62 |                    inverse_transform=None, inverse_transform_args=None, channel_mean=True,
63 |                    structure_sampling=False, noise_level=0, noise_type='white', # 'white', 'pink', 'red', 'blue', 'purple'
64 |                    verbose=False):
65 | 
66 |     if inverse_transform == 'normalize':
67 |         inverse_transform = inverse_normalize_transform
68 | 
69 |     if mode != 'mean':
70 |         if mask_kwargs is None:
71 |             mask_kwargs = {}
72 |         if sampler_kwargs is None:
73 |             sampler_kwargs = {}
74 | 
75 |         if mode == 'edm':
76 |             model = pipeline.unet
77 |             noise_scheduler = copy.deepcopy(pipeline.scheduler)
78 | 
79 |         if "x_idx" in mask_kwargs and "y_idx" in mask_kwargs:
80 |             print("Using structure sampling, x_idx and y_idx are provided.")
81 |             assert structure_sampling, "Structure sampling must be enabled when x_idx and y_idx are provided."
82 |             x_idx, y_idx = mask_kwargs["x_idx"], mask_kwargs["y_idx"]
83 |             mask_kwargs.pop("x_idx")
84 |             mask_kwargs.pop("y_idx")
85 |             ignore_idx_check = True
86 |         else:
87 |             if structure_sampling:
88 |                 print("Using structure sampling, x_idx and y_idx are not provided, selecting random points from each grid.")
89 |             else:
90 |                 print("Not using structure sampling. Points are uniformly sampled.")
91 |             ignore_idx_check = False
92 | 
93 | 
94 |         print(f'Calculating metrics for {mode}, ensemble size: {ensemble_size}')
95 |     else:
96 |         print('Calculating metrics for mean, all other arguments are ignored.')
97 | 
98 |     if structure_sampling and noise_type != 'white':
99 |         print('Warning: colored noise is not supported with structure sampling. Ignoring noise_type.')
100 | 
101 |     count = 0
102 |     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
103 |     for step, y_true in tqdm(enumerate(val_loader), total=len(val_loader)):
104 |         # (B, C, H, W)
105 |         y_true = y_true.to(device)
106 |         num_sample, C, _, _ = y_true.shape
107 |         if mode != 'mean':
108 |             generator = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in range(count, count+num_sample)]
109 |             if mode == 'vt':
110 |                 in_channels = vt_model.config.in_channels
111 |             else:
112 |                 in_channels = C if known_channels is None else len(known_channels)
113 |             interpolated_fields = torch.empty((num_sample, in_channels, y_true.shape[-2], y_true.shape[-1]), device=device, dtype=y_true.dtype)
114 |             if structure_sampling:
115 |                 scatter_mask = torch.empty_like(y_true)
116 | 
117 |                 for b in range(num_sample):
118 |                     if ("x_idx" not in mask_kwargs and "y_idx" not in mask_kwargs) and not ignore_idx_check:
119 |                         x_idx, y_idx = vt.structure_obs(generator=generator[b])
120 | 
121 |                     grid_points = vt._get_grid_points(x_idx, y_idx, generator=generator[b])
122 | 
123 |                     for idx, known_channel in enumerate(range(C) if known_channels is None else known_channels):
124 |                         field = y_true[b, known_channel][grid_points[:,1], grid_points[:,0]].flatten()
125 |                         if noise_level > 0:
126 |                             # drop support for colored noise in this case
127 |                             noise = randn_tensor(field.shape, device=device)
128 |                             field += noise * noise_level * field # y_obs = y_true + noise_level * noise * y_true
129 |                         interpolated_values = vt.interpolate(grid_points, field)
130 |                         interpolated_fields[b, idx] = torch.tensor(interpolated_values,
131 |                                                                    dtype=y_true.dtype,
132 |                                                                    device=device)
133 |                     tmp_mask = create_scatter_mask(y_true, x_idx=x_idx, y_idx=y_idx, device=device, **mask_kwargs)
134 |                     scatter_mask[b] = tmp_mask[0]
135 |             else:
136 |                 scatter_mask = create_scatter_mask(y_true, generator=generator, device=device, **mask_kwargs)
137 |                 tmp_y_true = y_true.clone()
138 |                 if noise_level > 0:
139 |                     if noise_type == 'white':
140 |                         noise = randn_tensor(tmp_y_true.shape, device=device)
141 |                     else:
142 |                         noise = colored_noise(tmp_y_true.shape, noise_type=noise_type, device=device)
143 |                     tmp_y_true += noise * noise_level * tmp_y_true # y_obs = y_true + noise_level * noise * y_true
144 |                 interpolated_fields = vt(known_fields=tmp_y_true, mask=scatter_mask)
145 | 
146 |             if mode == 'edm':
147 |                 y_pred = torch.empty_like(y_true)
148 |                 for b in range(num_sample):
149 |                     y_pred[b] = ensemble_sample(pipeline, ensemble_size, scatter_mask[[b]], sampler_kwargs=sampler_kwargs, conditioning_type=conditioning_type,
150 |                                                 class_labels=None, known_latents=interpolated_fields[[b]], sampler_type=mode, device=device).mean(dim=0)
151 |             elif mode == 'pipeline':
152 |                 y_pred = torch.empty_like(y_true)
153 |                 for b in range(num_sample):
154 |                     y_pred[b] = ensemble_sample(pipeline, ensemble_size, scatter_mask[[b]], sampler_kwargs=sampler_kwargs, conditioning_type=conditioning_type,
155 |                                                 class_labels=None, known_latents=interpolated_fields[[b]], sampler_type=mode, device=device).mean(dim=0)
156 | 
157 |             elif mode == 'vt':
158 |                 y_pred = vt_model(interpolated_fields, return_dict=False)[0]
159 | 
160 |             else:
161 |                 raise NotImplementedError(f'Mode: {mode} not implemented.')
162 | 
163 |             if inverse_transform is not None:
164 |                 y_pred = inverse_transform(y_pred, **inverse_transform_args)
165 |                 y_true = inverse_transform(y_true, **inverse_transform_args)
166 |         else:
167 |             y_true = inverse_transform(y_true, **inverse_transform_args)
168 |             y_pred =
torch.ones_like(y_true) * repeat(torch.tensor(inverse_transform_args['mean'], device=device), 'c -> b c 1 1', b=num_sample) 169 | 170 | 171 | _err_RMSE, _err_nRMSE, _err_CSV = metric_func_2D(y_pred, y_true, 172 | mask=scatter_mask if mode != 'mean' else None, 173 | channel_mean=channel_mean) 174 | 175 | if step == 0: 176 | err_RMSE = _err_RMSE * num_sample 177 | err_nRMSE = _err_nRMSE * num_sample 178 | err_CSV = _err_CSV * num_sample 179 | else: 180 | err_RMSE += _err_RMSE * num_sample 181 | err_nRMSE += _err_nRMSE * num_sample 182 | err_CSV += _err_CSV * num_sample 183 | 184 | count += num_sample 185 | 186 | if verbose: 187 | print(f'RMSE: {err_RMSE / count}, nRMSE: {err_nRMSE / count}, CSV: {err_CSV / count}') 188 | 189 | return err_RMSE / count, err_nRMSE / count, err_CSV / count -------------------------------------------------------------------------------- /models/unet2D.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from diffusers.utils import BaseOutput 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models.modeling_utils import ModelMixin 7 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block 8 | from dataclasses import dataclass 9 | from typing import Optional, Tuple, Union 10 | 11 | @dataclass 12 | class UNet2DOutput(BaseOutput): 13 | 14 | """ 15 | The output of [`UNet2DModel`]. 16 | 17 | Args: 18 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): 19 | The hidden states output from the last layer of the model. 20 | """ 21 | 22 | sample: torch.Tensor 23 | 24 | class diffuserUNet2D(ModelMixin, ConfigMixin): 25 | r""" 26 | A 2D UNet model with building blocks from diffusers. 27 | 28 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 29 | for all models (such as downloading or saving). 30 | """ 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | sample_size: Optional[Union[int, Tuple[int, int]]] = None, 36 | in_channels: int = 3, 37 | out_channels: int = 3, 38 | center_input_sample: bool = False, 39 | time_embedding_type: str = "positional", 40 | freq_shift: int = 0, 41 | flip_sin_to_cos: bool = True, 42 | down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), 43 | up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), 44 | block_out_channels: Tuple[int, ...] 
= (224, 448, 672, 896), 45 | layers_per_block: int = 2, 46 | mid_block_scale_factor: float = 1, 47 | downsample_padding: int = 1, 48 | downsample_type: str = "conv", 49 | upsample_type: str = "conv", 50 | dropout: float = 0.0, 51 | act_fn: str = "silu", 52 | attention_head_dim: Optional[int] = 8, 53 | norm_num_groups: int = 32, 54 | attn_norm_num_groups: Optional[int] = None, 55 | norm_eps: float = 1e-5, 56 | resnet_time_scale_shift: str = "default", 57 | add_attention: bool = True, 58 | class_embed_type: Optional[str] = None, 59 | num_class_embeds: Optional[int] = None, 60 | num_train_timesteps: Optional[int] = None, 61 | ): 62 | super().__init__() 63 | 64 | self.sample_size = sample_size 65 | time_embed_dim = block_out_channels[0] * 4 if num_class_embeds is not None else None 66 | 67 | # Check inputs 68 | if len(down_block_types) != len(up_block_types): 69 | raise ValueError( 70 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 71 | ) 72 | 73 | if len(block_out_channels) != len(down_block_types): 74 | raise ValueError( 75 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 76 | ) 77 | 78 | # input 79 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 80 | 81 | # class embedding 82 | if class_embed_type is None and num_class_embeds is not None: 83 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 84 | elif class_embed_type == "identity": 85 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 86 | else: 87 | self.class_embedding = None 88 | 89 | self.down_blocks = nn.ModuleList([]) 90 | self.mid_block = None 91 | self.up_blocks = nn.ModuleList([]) 92 | 93 | # down 94 | output_channel = block_out_channels[0] 95 | for i, down_block_type in enumerate(down_block_types): 96 | input_channel = output_channel 97 | output_channel = block_out_channels[i] 98 | is_final_block = i == len(block_out_channels) - 1 99 | 100 | down_block = get_down_block( 101 | down_block_type, 102 | num_layers=layers_per_block, 103 | in_channels=input_channel, 104 | out_channels=output_channel, 105 | temb_channels=time_embed_dim, 106 | add_downsample=not is_final_block, 107 | resnet_eps=norm_eps, 108 | resnet_act_fn=act_fn, 109 | resnet_groups=norm_num_groups, 110 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 111 | downsample_padding=downsample_padding, 112 | resnet_time_scale_shift=resnet_time_scale_shift, 113 | downsample_type=downsample_type, 114 | dropout=dropout, 115 | ) 116 | self.down_blocks.append(down_block) 117 | 118 | # mid 119 | self.mid_block = UNetMidBlock2D( 120 | in_channels=block_out_channels[-1], 121 | temb_channels=time_embed_dim, 122 | dropout=dropout, 123 | resnet_eps=norm_eps, 124 | resnet_act_fn=act_fn, 125 | output_scale_factor=mid_block_scale_factor, 126 | resnet_time_scale_shift=resnet_time_scale_shift, 127 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], 128 | resnet_groups=norm_num_groups, 129 | attn_groups=attn_norm_num_groups, 130 | add_attention=add_attention, 131 | ) 132 | 133 | # up 134 | reversed_block_out_channels = list(reversed(block_out_channels)) 135 | output_channel = reversed_block_out_channels[0] 136 | for i, up_block_type in enumerate(up_block_types): 137 | prev_output_channel = 
output_channel 138 | output_channel = reversed_block_out_channels[i] 139 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 140 | 141 | is_final_block = i == len(block_out_channels) - 1 142 | 143 | up_block = get_up_block( 144 | up_block_type, 145 | num_layers=layers_per_block + 1, 146 | in_channels=input_channel, 147 | out_channels=output_channel, 148 | prev_output_channel=prev_output_channel, 149 | temb_channels=time_embed_dim, 150 | add_upsample=not is_final_block, 151 | resnet_eps=norm_eps, 152 | resnet_act_fn=act_fn, 153 | resnet_groups=norm_num_groups, 154 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 155 | resnet_time_scale_shift=resnet_time_scale_shift, 156 | upsample_type=upsample_type, 157 | dropout=dropout, 158 | ) 159 | self.up_blocks.append(up_block) 160 | prev_output_channel = output_channel 161 | 162 | # out 163 | num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) 164 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) 165 | self.conv_act = nn.SiLU() 166 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 167 | 168 | def forward( 169 | self, 170 | sample: torch.Tensor, 171 | class_labels: Optional[torch.Tensor] = None, 172 | return_dict: bool = True, 173 | ) -> Union[UNet2DOutput, Tuple]: 174 | r""" 175 | The [`UNet2DModel`] forward method. 176 | 177 | Args: 178 | sample (`torch.Tensor`): 179 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 180 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. 181 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 182 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 183 | return_dict (`bool`, *optional*, defaults to `True`): 184 | Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. 185 | 186 | Returns: 187 | [`~models.unet_2d.UNet2DOutput`] or `tuple`: 188 | If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is 189 | returned where the first element is the sample tensor. 190 | """ 191 | # 0. center input if necessary 192 | if self.config.center_input_sample: 193 | sample = 2 * sample - 1.0 194 | 195 | emb = None 196 | 197 | if self.class_embedding is not None: 198 | if class_labels is None: 199 | raise ValueError("class_labels should be provided when doing class conditioning") 200 | 201 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 202 | emb = class_emb 203 | elif self.class_embedding is None and class_labels is not None: 204 | raise ValueError("class_embedding needs to be initialized in order to use class conditioning") 205 | 206 | # 2. pre-process 207 | skip_sample = sample 208 | sample = self.conv_in(sample) 209 | 210 | # 3. down 211 | down_block_res_samples = (sample,) 212 | for downsample_block in self.down_blocks: 213 | if hasattr(downsample_block, "skip_conv"): 214 | sample, res_samples, skip_sample = downsample_block( 215 | hidden_states=sample, temb=emb, skip_sample=skip_sample 216 | ) 217 | else: 218 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 219 | 220 | down_block_res_samples += res_samples 221 | 222 | # 4. mid 223 | sample = self.mid_block(sample, emb) 224 | 225 | # 5. 
up 226 | skip_sample = None 227 | for upsample_block in self.up_blocks: 228 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 229 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 230 | 231 | if hasattr(upsample_block, "skip_conv"): 232 | sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) 233 | else: 234 | sample = upsample_block(sample, res_samples, emb) 235 | 236 | # 6. post-process 237 | sample = self.conv_norm_out(sample) 238 | sample = self.conv_act(sample) 239 | sample = self.conv_out(sample) 240 | 241 | if skip_sample is not None: 242 | sample += skip_sample 243 | 244 | if not return_dict: 245 | return (sample,) 246 | 247 | return UNet2DOutput(sample=sample) -------------------------------------------------------------------------------- /noise_schedulers/noise_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Karras_sigmas_lognormal: 4 | def __init__(self, sigmas, P_mean=-1.2, P_std=1.2, sigma_min=0.002, sigma_max=80, num_train_timesteps=1000): 5 | self.P_mean = P_mean 6 | self.P_std = P_std 7 | self.sigma_min = sigma_min 8 | self.sigma_max = sigma_max 9 | self.num_train_timesteps = num_train_timesteps 10 | self.sigmas = sigmas 11 | 12 | def __call__(self, batch_size, generator=None, device='cpu'): 13 | rnd_normal = torch.randn([batch_size, 1, 1, 1], device=device, generator=generator) 14 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 15 | # Find the indices of the closest matches 16 | # Expand self.sigmas to match the batch size 17 | # sigmas get concatenated with 0 at the end 18 | sigmas_expanded = self.sigmas[:-1].view(1, -1).to(device) 19 | sigma_expanded = sigma.view(batch_size, 1) 20 | 21 | # Calculate the difference and find the indices of the minimum difference 22 | diff = torch.abs(sigmas_expanded - sigma_expanded) 23 | indices = torch.argmin(diff, dim=1) 24 | 25 | return indices -------------------------------------------------------------------------------- /pipelines/pipeline_inv_prob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 3 | from diffusers.utils.torch_utils import randn_tensor 4 | from typing import Optional, Union, Tuple, List 5 | 6 | from utils.pipeline_utils import Fields2DPipelineOutput 7 | 8 | class InverseProblem2DPipeline(DiffusionPipeline): 9 | 10 | model_cpu_offload_seq = "unet" 11 | 12 | def __init__(self, unet, scheduler, scheduler_step_kwargs: Optional[dict] = None): 13 | super().__init__() 14 | 15 | self.register_modules(unet=unet, scheduler=scheduler) 16 | self.scheduler_step_kwargs = scheduler_step_kwargs or {} 17 | 18 | @torch.no_grad() 19 | def __call__( 20 | self, 21 | batch_size: int = 1, 22 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 23 | num_inference_steps: int = 50, 24 | return_dict: bool = True, 25 | mask: torch.Tensor = None, 26 | known_channels: List[int] = None, 27 | known_latents: torch.Tensor = None, 28 | ) -> Union[Fields2DPipelineOutput, Tuple]: 29 | 30 | # Sample gaussian noise to begin loop 31 | if isinstance(self.unet.config.sample_size, int): 32 | image_shape = ( 33 | batch_size, 34 | self.unet.config.out_channels, 35 | self.unet.config.sample_size, 36 | self.unet.config.sample_size, 37 | ) 38 | else: 39 | image_shape = (batch_size, self.unet.config.out_channels, *self.unet.config.sample_size) 
40 | 41 | if isinstance(generator, list) and len(generator) != batch_size: 42 | raise ValueError( 43 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 44 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 45 | ) 46 | 47 | assert known_latents is not None, "known_latents must be provided" 48 | 49 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 50 | noise = image.clone() 51 | 52 | # set step values 53 | self.scheduler.set_timesteps(num_inference_steps) 54 | 55 | for i, t in enumerate(self.scheduler.timesteps): 56 | x_in = self.scheduler.scale_model_input(image, t) 57 | # 1. predict noise model_output 58 | model_output = self.unet(x_in, t, return_dict=False)[0] 59 | model_output = model_output * (1 - mask) + known_latents * mask 60 | 61 | # 2. do x_t -> x_t-1 62 | image = self.scheduler.step( 63 | model_output, t, image[:, :self.unet.config.out_channels], **self.scheduler_step_kwargs, 64 | return_dict=False 65 | )[0] 66 | 67 | tmp_known_latents = known_latents.clone() 68 | if i < len(self.scheduler.timesteps) - 1: 69 | noise_timestep = self.scheduler.timesteps[i + 1] 70 | tmp_known_latents = self.scheduler.add_noise(tmp_known_latents, noise, torch.tensor([noise_timestep])) 71 | 72 | image = image * (1 - mask) + tmp_known_latents * mask 73 | 74 | if not return_dict: 75 | return (image,) 76 | 77 | return Fields2DPipelineOutput(fields=image) 78 | 79 | class InverseProblem2DCondPipeline(DiffusionPipeline): 80 | 81 | model_cpu_offload_seq = "unet" 82 | 83 | def __init__(self, unet, scheduler, scheduler_step_kwargs: Optional[dict] = None): 84 | super().__init__() 85 | 86 | self.register_modules(unet=unet, scheduler=scheduler) 87 | self.scheduler_step_kwargs = scheduler_step_kwargs or {} 88 | 89 | @torch.no_grad() 90 | def __call__( 91 | self, 92 | batch_size: int = 1, 93 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 94 | num_inference_steps: int = 50, 95 | return_dict: bool = True, 96 | mask: torch.Tensor = None, 97 | known_channels: List[int] = None, 98 | known_latents: torch.Tensor = None, 99 | add_noise_to_obs: bool = False, 100 | ) -> Union[Fields2DPipelineOutput, Tuple]: 101 | 102 | # Sample gaussian noise to begin loop 103 | if isinstance(self.unet.config.sample_size, int): 104 | image_shape = ( 105 | batch_size, 106 | self.unet.config.out_channels, 107 | self.unet.config.sample_size, 108 | self.unet.config.sample_size, 109 | ) 110 | else: 111 | image_shape = (batch_size, self.unet.config.out_channels, *self.unet.config.sample_size) 112 | 113 | if isinstance(generator, list) and len(generator) != batch_size: 114 | raise ValueError( 115 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 116 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
117 |             )
118 | 
119 |         assert known_latents is not None, "known_latents must be provided"
120 | 
121 |         image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
122 |         if add_noise_to_obs:
123 |             noise = image.clone()
124 |         conditioning_tensors = torch.cat((known_latents, mask[:, [known_channels[0]]]), dim=1)
125 | 
126 |         # set step values
127 |         self.scheduler.set_timesteps(num_inference_steps)
128 | 
129 |         for i, t in enumerate(self.scheduler.timesteps):
130 |             if mask is not None and not add_noise_to_obs:
131 |                 image = image * (1 - mask) + known_latents * mask
132 |             x_in = self.scheduler.scale_model_input(image, t)
133 |             # 1. predict noise model_output
134 |             model_output = self.unet(x_in, t, conditioning_tensors=conditioning_tensors, return_dict=False)[0]
135 | 
136 |             # 2. do x_t -> x_t-1
137 |             image = self.scheduler.step(
138 |                 model_output, t, image, **self.scheduler_step_kwargs,
139 |                 return_dict=False
140 |             )[0]
141 |             if add_noise_to_obs:
142 |                 tmp_known_latents = known_latents.clone()
143 |                 if i < len(self.scheduler.timesteps) - 1:
144 |                     noise_timestep = self.scheduler.timesteps[i + 1]
145 |                     tmp_known_latents = self.scheduler.add_noise(tmp_known_latents, noise, torch.tensor([noise_timestep]))
146 | 
147 |                 if mask is not None:
148 |                     image = image * (1 - mask) + tmp_known_latents * mask  # composite the re-noised observations into the sample
149 | 
150 |         if not return_dict:
151 |             return (image,)
152 | 
153 |         return Fields2DPipelineOutput(fields=image)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Diffusion models for field reconstruction
2 | ---
3 | 
4 | Demo: [Colab](https://colab.research.google.com/drive/1RzcvX7jHDVc1VTkyUAe8bRA3C93xEffd?usp=sharing); [Hugging Face](https://huggingface.co/tonyzyl/DiffusionReconstruct)
5 | 
6 | This repository contains the code for the paper "Spatially-Aware Diffusion Models with Cross-Attention for Global Field Reconstruction with Sparse Observations". The paper is available on [arXiv](https://doi.org/10.48550/arXiv.2409.00230) and has been accepted by CMAME.
7 | 
8 | 
9 | ### Summary of the work
10 | 
11 | * We propose a cross-attention diffusion model for global field reconstruction.
12 | * We compare different conditioning methods against a strong deterministic baseline and conduct a comprehensive benchmark of the effect of hyperparameters on reconstruction quality.
13 | * The cross-attention diffusion model outperforms the others in noisy settings, while the deterministic model is easier to train and excels in noiseless conditions.
14 | * The ensemble mean of the diffusion model's probabilistic reconstructions converges to the deterministic output.
15 | * Latent diffusion might be beneficial for handling unevenly distributed fields and for reducing the overall training cost.
16 | * The tested PDEs include Darcy flow, diffusion-reaction, the shallow water equations, and the compressible Navier-Stokes equations.
17 | 
18 | 
19 | Tested conditioning methods include:
20 | * Guided Sampling (or inpainting), illustrated in the sketch below
21 | * Classifier-Free Guided Sampling (CFG)
22 | * Cross-Attention with the proposed encoding block
23 | 
24 | 
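As a quick orientation, here is a minimal sketch of how the guided-sampling (inpainting) pipeline defined in `pipelines/pipeline_inv_prob.py` can be invoked. The `unet`, `scheduler`, and `observed_fields` names are placeholders for a trained model, a compatible diffusers noise scheduler, and the observations on the solution grid; the shapes are illustrative assumptions:

```python
import torch

from pipelines.pipeline_inv_prob import InverseProblem2DPipeline

# Placeholders: a trained UNet and a compatible diffusers noise scheduler.
pipe = InverseProblem2DPipeline(unet=unet, scheduler=scheduler)

# mask is 1 at observed grid points and 0 elsewhere; known_latents carries
# the (normalized) observed field values on the same grid.
batch, channels, height, width = 4, 1, 128, 128
mask = torch.zeros(batch, channels, height, width)
mask[..., ::8, ::8] = 1.0                 # e.g., a sparse regular sensor array
known_latents = observed_fields * mask    # zero outside the observed points

out = pipe(
    batch_size=batch,
    num_inference_steps=50,
    mask=mask,
    known_latents=known_latents,
)
reconstruction = out.fields               # (batch, channels, height, width)
```

At every sampling step the pipeline composites the (re-noised) observations back into the current iterate through the mask, so the reverse diffusion only has to fill in the unobserved region.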
Figure 1: Results of the diffusion models are computed from an ensemble of 25 trajectories with 20 sampling steps. The red dashed line represents the baseline of reconstruction using the mean.
27 | Figure 2: Schematic of the proposed condition encoding block. For CFG, mean-pooling is performed to reduce the dimensionality so that the pooled encoding can be combined with the noise scale embedding.
43 | Figure 3: Comparison of the fields generated by the VT-UNet, a single trajectory, and the ensemble mean of the cross-attention diffusion model for the diffusion-reaction equations with 1% of data points observed.
53 | Figure 4: Histogram of the distribution of relative error improvement with different DA (data assimilation) error covariances on the shallow water equations.
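For the conditional variants (CFG and cross-attention), sampling goes through `InverseProblem2DCondPipeline`, which builds its conditioning tensor from the observations and the mask. Below is a sketch under the same placeholder assumptions (`cond_unet`, `scheduler`, `mask`, and `known_latents` are stand-ins):

```python
import torch

from pipelines.pipeline_inv_prob import InverseProblem2DCondPipeline

# Placeholders: a trained conditional UNet (see models/unet2DCondition.py)
# and a compatible diffusers noise scheduler.
pipe = InverseProblem2DCondPipeline(unet=cond_unet, scheduler=scheduler)

out = pipe(
    batch_size=4,
    num_inference_steps=50,
    mask=mask,                    # (B, C, H, W), 1 at observed points
    known_channels=[0],           # indices of the observed channels
    known_latents=known_latents,  # observed values, zero where unobserved
    add_noise_to_obs=False,       # optionally re-noise observations each step
)
fields = out.fields
```

Internally, the pipeline concatenates `known_latents` with the mask slice of the first known channel to form the conditioning tensor passed to the UNet.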
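The `Karras_sigmas_lognormal` sampler in `noise_schedulers/noise_sampler.py` maps lognormal sigma draws to indices of a discrete sigma grid (e.g., a scheduler's Karras sigmas with a trailing zero). A minimal usage sketch; the grid below is an illustrative assumption, not necessarily the exact schedule used by the training scripts:

```python
import torch

from noise_schedulers.noise_sampler import Karras_sigmas_lognormal

# Illustrative Karras-style sigma grid with the trailing zero the sampler
# expects (it drops the last entry internally before matching).
rho, sigma_min, sigma_max, n = 7.0, 0.002, 80.0, 1000
ramp = torch.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
sigmas = torch.cat([sigmas, torch.zeros(1)])

sampler = Karras_sigmas_lognormal(sigmas, P_mean=-1.2, P_std=1.2)
indices = sampler(batch_size=8)   # lognormal draws snapped to grid indices
sigma_batch = sigmas[indices]     # per-sample noise levels, shape (8,)
```

Snapping continuous lognormal draws to the scheduler's discrete grid keeps the EDM-style noise-level distribution while remaining compatible with index-based timestep APIs.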