├── LICENSE
├── configs
│   ├── compressible_NS_cfg_config.yaml
│   ├── compressible_NS_uncond_config.yaml
│   ├── compressible_NS_vt_config.yaml
│   ├── compressible_NS_xattn_config.yaml
│   ├── darcy_cfg_config.yaml
│   ├── darcy_uncond_config.yaml
│   ├── darcy_vt_config.yaml
│   ├── darcy_xattn_config.yaml
│   ├── diffusion_reaction_cfg_config.yaml
│   ├── diffusion_reaction_uncond_config.yaml
│   ├── diffusion_reaction_vt_config.yaml
│   ├── diffusion_reaction_xattn_config.yaml
│   ├── shallow_water_cfg_config.yaml
│   ├── shallow_water_uncond_config.yaml
│   ├── shallow_water_vt_config.yaml
│   └── shallow_water_xattn_config.yaml
├── dataloader
│   └── dataset_class.py
├── evaluate.py
├── git_assest
│   ├── bar_chart_0.01_largefont.png
│   ├── darcy.png
│   ├── dr_hf.png
│   ├── encoding_block.png
│   └── error_hist.png
├── losses
│   ├── loss.py
│   └── metric.py
├── models
│   ├── unet2D.py
│   └── unet2DCondition.py
├── noise_schedulers
│   └── noise_sampler.py
├── pipelines
│   └── pipeline_inv_prob.py
├── readme.md
├── requirements.txt
├── requirements_noversion.txt
├── train_cond.py
├── train_uncond.py
├── train_vt.py
└── utils
    ├── attn_utils.py
    ├── general_utils.py
    ├── inverse_utils.py
    ├── pipeline_utils.py
    └── vt_utils.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

Standard Apache License 2.0 text: the terms and conditions for use, reproduction, and distribution (Sections 1 through 9: Definitions; Grant of Copyright License; Grant of Patent License; Redistribution; Submission of Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability; Accepting Warranty or Additional Liability), followed by the appendix describing how to apply the license to a work.
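The YAML files under configs/ each pair a unet block (the network definition), a noise_scheduler and loss_fn target, optimizer and lr_scheduler settings, a dataloader block with per-channel normalization statistics, and a general block of training options. As a rough, non-authoritative illustration of how such a file might be consumed: the *_uncond_config files name stock diffusers classes (UNet2DModel and diffusers.EDMDPMSolverMultistepScheduler), so the sketch below loads one with PyYAML and rebuilds those two objects. It assumes PyYAML and diffusers are installed; the conditional configs instead name diffuserUNet2DCondition, presumably defined in models/unet2DCondition.py, and the actual entry points are train_uncond.py, train_cond.py, and train_vt.py, whose argument handling is not shown here.

```python
# Illustrative sketch only -- the real setup lives in the train_*.py scripts.
import yaml
from diffusers import UNet2DModel, EDMDPMSolverMultistepScheduler

with open("configs/darcy_uncond_config.yaml") as f:
    cfg = yaml.safe_load(f)

# The `unet` block of the *_uncond_config files carries `_class_name: UNet2DModel`,
# so diffusers' from_config can rebuild the network directly from that dict.
unet = UNet2DModel.from_config(cfg["unet"])

# `noise_scheduler.target` names diffusers.EDMDPMSolverMultistepScheduler;
# its `params` sub-dict is passed through as keyword arguments.
scheduler = EDMDPMSolverMultistepScheduler(**cfg["noise_scheduler"]["params"])

print(unet.config.in_channels, scheduler.config.num_train_timesteps)
```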
202 | -------------------------------------------------------------------------------- /configs/compressible_NS_cfg_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 512 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - DownBlock2D 25 | - DownBlock2D 26 | - DownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 4 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2D 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 4 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - UpBlock2D 59 | - UpBlock2D 60 | - UpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 512 66 | intermediate_size: 2048 67 | projection_dim: 512 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 4 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: false 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578] 116 | std: [0.0422693, 0.0422711, 2.9536870, 22.0486488] 117 | target_std: 0.5 118 | data_name: compressible_NS 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [0,1,2,3] 138 
| same_mask: True 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_cfg 141 | logging_dir: comp_NS 142 | tracker_project_name: compNS_tracker 143 | save_image_epochs: 5 144 | save_model_epochs: 50 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["Vx", "Vy", "pressure", "density"] -------------------------------------------------------------------------------- /configs/compressible_NS_uncond_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 4 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 4 33 | resnet_time_scale_shift: scale_shift 34 | sample_size: 128 35 | time_embedding_type: positional 36 | up_block_types: 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | - UpBlock2D 41 | upsample_type: conv 42 | 43 | noise_scheduler: 44 | target: diffusers.EDMDPMSolverMultistepScheduler 45 | params: 46 | num_train_timesteps: 1000 47 | 48 | loss_fn: 49 | target: losses.loss.EDMLoss 50 | params: 51 | sigma_data: 0.5 52 | 53 | optimizer: 54 | betas: 55 | - 0.9 56 | - 0.999 57 | eps: 1e-08 58 | lr: 1e-4 59 | weight_decay: 1e-2 60 | 61 | lr_scheduler: 62 | #name: cosine 63 | name: constant 64 | num_warmup_steps: 500 65 | num_cycles: 0.5 66 | power: 1.0 67 | 68 | dataloader: 69 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5 70 | batch_size: 16 71 | num_workers: 0 72 | split_ratios: 73 | - 0.8 74 | - 0.2 75 | - 0.0 76 | transform: normalize 77 | transform_args: 78 | mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578] 79 | std: [0.0422693, 0.0422711, 2.9536870, 22.0486488] 80 | target_std: 0.5 81 | data_name: compressible_NS 82 | 83 | accelerator: 84 | mixed_precision: fp16 85 | gradient_accumulation_steps: 1 86 | log_with: tensorboard 87 | 88 | ema: 89 | use_ema: True 90 | offload_ema: False 91 | ema_max_decay: 0.9999 92 | ema_inv_gamma: 1.0 93 | ema_power: 0.75 94 | foreach: True 95 | 96 | general: 97 | seed: 42 98 | num_epochs: null 99 | num_training_steps: 100000 100 | known_channels: [0,1,2,3] 101 | same_mask: True 102 | scale_lr: False 103 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_uncond 104 | logging_dir: comp_NS 105 | tracker_project_name: compNS_tracker 106 | save_image_epochs: 5 107 | save_model_epochs: 50 108 | checkpointing_steps: 25000 109 | eval_batch_size: 8 110 | cond_drop_prob: null 111 | do_edm_style_training: True 112 | snr_gamma: null 113 | channel_names: ["Vx", "Vy", "pressure", "density"] -------------------------------------------------------------------------------- /configs/compressible_NS_vt_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: 
UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 4 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 4 33 | resnet_time_scale_shift: scale_shift 34 | sample_size: 128 35 | time_embedding_type: positional 36 | up_block_types: 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | - UpBlock2D 41 | upsample_type: conv 42 | 43 | noise_scheduler: 44 | target: diffusers.EDMDPMSolverMultistepScheduler 45 | params: 46 | num_train_timesteps: 1000 47 | 48 | loss_fn: 49 | target: losses.loss.EDMLoss 50 | params: 51 | sigma_data: 0.5 52 | 53 | optimizer: 54 | betas: 55 | - 0.9 56 | - 0.999 57 | eps: 1e-08 58 | lr: 1e-4 59 | weight_decay: 1e-2 60 | 61 | lr_scheduler: 62 | #name: cosine 63 | name: constant 64 | num_warmup_steps: 500 65 | num_cycles: 0.5 66 | power: 1.0 67 | 68 | dataloader: 69 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5 70 | batch_size: 16 71 | num_workers: 0 72 | split_ratios: 73 | - 0.8 74 | - 0.2 75 | - 0.0 76 | transform: normalize 77 | transform_args: 78 | mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578] 79 | std: [0.0422693, 0.0422711, 2.9536870, 22.0486488] 80 | target_std: 0.5 81 | data_name: compressible_NS 82 | 83 | accelerator: 84 | mixed_precision: fp16 85 | gradient_accumulation_steps: 1 86 | log_with: tensorboard 87 | 88 | ema: 89 | use_ema: True 90 | offload_ema: False 91 | ema_max_decay: 0.9999 92 | ema_inv_gamma: 1.0 93 | ema_power: 0.75 94 | foreach: True 95 | 96 | general: 97 | seed: 42 98 | num_epochs: null 99 | num_training_steps: 100000 100 | known_channels: [0,1,2,3] 101 | same_mask: True 102 | scale_lr: False 103 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_vt 104 | logging_dir: comp_NS 105 | tracker_project_name: compNS_tracker 106 | save_image_epochs: 5 107 | save_model_epochs: 50 108 | checkpointing_steps: 25000 109 | eval_batch_size: 8 110 | do_edm_style_training: True 111 | channel_names: ["Vx", "Vy", "pressure", "density"] -------------------------------------------------------------------------------- /configs/compressible_NS_xattn_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 512 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - CrossAttnDownBlock2D 25 | - CrossAttnDownBlock2D 26 | - CrossAttnDownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: 
null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 4 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2DCrossAttn 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 4 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - CrossAttnUpBlock2D 59 | - CrossAttnUpBlock2D 60 | - CrossAttnUpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 512 66 | intermediate_size: 2048 67 | projection_dim: 512 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 4 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: true 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.1_Zeta0.1_periodic_128_Train.hdf5 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [0.0001482, 0.0000293, 5.0561895, 25.3754578] 116 | std: [0.0422693, 0.0422711, 2.9536870, 22.0486488] 117 | target_std: 0.5 118 | data_name: compressible_NS 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [0,1,2,3] 138 | same_mask: True 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/compressible_NS_xattn 141 | logging_dir: comp_NS 142 | tracker_project_name: compNS_tracker 143 | save_image_epochs: 5 144 | save_model_epochs: 50 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["Vx", "Vy", "pressure", "density"] -------------------------------------------------------------------------------- /configs/darcy_cfg_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | 
block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 256 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - DownBlock2D 25 | - DownBlock2D 26 | - DownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 2 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2D 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 2 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - UpBlock2D 59 | - UpBlock2D 60 | - UpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 256 66 | intermediate_size: 1024 67 | projection_dim: 256 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 1 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: false 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [1.1482742, 46.6775513] 116 | std: [0.6740283, 30.9050026] 117 | target_std: 0.5 118 | data_name: darcy 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [1] 138 | same_mask: False 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_cfg 141 | logging_dir: darcy 142 | tracker_project_name: darcy_tracker 143 | save_image_epochs: 50 144 | save_model_epochs: 200 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["permeability", "pressure"] -------------------------------------------------------------------------------- /configs/darcy_uncond_config.yaml: 
-------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 2 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 2 33 | resnet_time_scale_shift: scale_shift 34 | sample_size: 128 35 | time_embedding_type: positional 36 | up_block_types: 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | - UpBlock2D 41 | upsample_type: conv 42 | 43 | noise_scheduler: 44 | target: diffusers.EDMDPMSolverMultistepScheduler 45 | params: 46 | num_train_timesteps: 1000 47 | 48 | loss_fn: 49 | target: losses.loss.EDMLoss 50 | params: 51 | sigma_data: 0.5 52 | 53 | optimizer: 54 | betas: 55 | - 0.9 56 | - 0.999 57 | eps: 1e-08 58 | lr: 1e-4 59 | weight_decay: 1e-2 60 | 61 | lr_scheduler: 62 | #name: cosine 63 | name: constant 64 | num_warmup_steps: 500 65 | num_cycles: 0.5 66 | power: 1.0 67 | 68 | dataloader: 69 | data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy 70 | batch_size: 16 71 | num_workers: 0 72 | split_ratios: 73 | - 0.8 74 | - 0.2 75 | - 0.0 76 | transform: normalize 77 | transform_args: 78 | mean: [1.1482742, 46.6775513] 79 | std: [0.6740283, 30.9050026] 80 | target_std: 0.5 81 | data_name: darcy 82 | 83 | accelerator: 84 | mixed_precision: fp16 85 | gradient_accumulation_steps: 1 86 | log_with: tensorboard 87 | 88 | ema: 89 | use_ema: True 90 | offload_ema: False 91 | ema_max_decay: 0.9999 92 | ema_inv_gamma: 1.0 93 | ema_power: 0.75 94 | foreach: True 95 | 96 | general: 97 | seed: 42 98 | num_epochs: null 99 | num_training_steps: 100000 100 | known_channels: [1] 101 | same_mask: False 102 | scale_lr: False 103 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_uncond 104 | logging_dir: darcy 105 | tracker_project_name: darcy_tracker 106 | save_image_epochs: 50 107 | save_model_epochs: 200 108 | checkpointing_steps: 25000 109 | eval_batch_size: 8 110 | do_edm_style_training: True 111 | snr_gamma: null 112 | channel_names: ["permeability", "pressure"] -------------------------------------------------------------------------------- /configs/darcy_vt_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 1 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 2 33 | resnet_time_scale_shift: scale_shift 34 | sample_size: 128 35 | 
time_embedding_type: positional 36 | up_block_types: 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | - UpBlock2D 41 | upsample_type: conv 42 | 43 | noise_scheduler: 44 | target: diffusers.EDMDPMSolverMultistepScheduler 45 | params: 46 | num_train_timesteps: 1000 47 | 48 | loss_fn: 49 | target: losses.loss.EDMLoss 50 | params: 51 | sigma_data: 0.5 52 | 53 | optimizer: 54 | betas: 55 | - 0.9 56 | - 0.999 57 | eps: 1e-08 58 | lr: 1e-4 59 | weight_decay: 1e-2 60 | 61 | lr_scheduler: 62 | #name: cosine 63 | name: constant 64 | num_warmup_steps: 500 65 | num_cycles: 0.5 66 | power: 1.0 67 | 68 | dataloader: 69 | data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy 70 | batch_size: 16 71 | num_workers: 0 72 | split_ratios: 73 | - 0.8 74 | - 0.2 75 | - 0.0 76 | transform: normalize 77 | transform_args: 78 | mean: [1.1482742, 46.6775513] 79 | std: [0.6740283, 30.9050026] 80 | target_std: 0.5 81 | data_name: darcy 82 | 83 | accelerator: 84 | mixed_precision: fp16 85 | gradient_accumulation_steps: 1 86 | log_with: tensorboard 87 | 88 | ema: 89 | use_ema: True 90 | offload_ema: False 91 | ema_max_decay: 0.9999 92 | ema_inv_gamma: 1.0 93 | ema_power: 0.75 94 | foreach: True 95 | 96 | general: 97 | seed: 42 98 | num_epochs: null 99 | num_training_steps: 100000 100 | known_channels: [1] 101 | same_mask: False 102 | scale_lr: False 103 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_vt 104 | logging_dir: darcy 105 | tracker_project_name: darcy_tracker 106 | save_image_epochs: 50 107 | save_model_epochs: 200 108 | checkpointing_steps: 25000 109 | eval_batch_size: 8 110 | do_edm_style_training: True 111 | channel_names: ["permeability", "pressure"] -------------------------------------------------------------------------------- /configs/darcy_xattn_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 256 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - CrossAttnDownBlock2D 25 | - CrossAttnDownBlock2D 26 | - CrossAttnDownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 2 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2DCrossAttn 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 2 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - CrossAttnUpBlock2D 59 | - CrossAttnUpBlock2D 60 | - 
CrossAttnUpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 256 66 | intermediate_size: 1024 67 | projection_dim: 256 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 1 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: true 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | data_dir: /scratch/kdur_root/kdur/ylzhuang/darcy_mod.npy 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [1.1482742, 46.6775513] 116 | std: [0.6740283, 30.9050026] 117 | target_std: 0.5 118 | data_name: darcy 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [1] 138 | same_mask: False 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/darcy_xattn 141 | logging_dir: darcy 142 | tracker_project_name: darcy_tracker 143 | save_image_epochs: 50 144 | save_model_epochs: 200 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["permeability", "pressure"] -------------------------------------------------------------------------------- /configs/diffusion_reaction_cfg_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 256 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - DownBlock2D 25 | - DownBlock2D 26 | - DownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 2 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2D 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 2 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | 
reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - UpBlock2D 59 | - UpBlock2D 60 | - UpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 256 66 | intermediate_size: 1024 67 | projection_dim: 256 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 2 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: false 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [-0.0311127, -0.0199022] 116 | std: [0.1438150, 0.1117546] 117 | target_std: 0.5 118 | data_name: diffusion_reaction 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [0,1] 138 | same_mask: True 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_cfg 141 | logging_dir: diff_react 142 | tracker_project_name: diffreact_tracker 143 | save_image_epochs: 5 144 | save_model_epochs: 50 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["u", "v"] -------------------------------------------------------------------------------- /configs/diffusion_reaction_uncond_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 2 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 2 33 | resnet_time_scale_shift: scale_shift 34 | sample_size: 128 35 | time_embedding_type: positional 36 | up_block_types: 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | - UpBlock2D 41 | upsample_type: conv 42 | 
43 | noise_scheduler: 44 | target: diffusers.EDMDPMSolverMultistepScheduler 45 | params: 46 | num_train_timesteps: 1000 47 | 48 | loss_fn: 49 | target: losses.loss.EDMLoss 50 | params: 51 | sigma_data: 0.5 52 | 53 | optimizer: 54 | betas: 55 | - 0.9 56 | - 0.999 57 | eps: 1e-08 58 | lr: 1e-4 59 | weight_decay: 1e-2 60 | 61 | lr_scheduler: 62 | #name: cosine 63 | name: constant 64 | num_warmup_steps: 500 65 | num_cycles: 0.5 66 | power: 1.0 67 | 68 | dataloader: 69 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5 70 | batch_size: 16 71 | num_workers: 0 72 | split_ratios: 73 | - 0.8 74 | - 0.2 75 | - 0.0 76 | transform: normalize 77 | transform_args: 78 | mean: [-0.0311127, -0.0199022] 79 | std: [0.1438150, 0.1117546] 80 | target_std: 0.5 81 | data_name: diffusion_reaction 82 | 83 | accelerator: 84 | mixed_precision: fp16 85 | gradient_accumulation_steps: 1 86 | log_with: tensorboard 87 | 88 | ema: 89 | use_ema: True 90 | offload_ema: False 91 | ema_max_decay: 0.9999 92 | ema_inv_gamma: 1.0 93 | ema_power: 0.75 94 | foreach: True 95 | 96 | general: 97 | seed: 42 98 | num_epochs: null 99 | num_training_steps: 100000 100 | known_channels: [0,1] 101 | same_mask: True 102 | scale_lr: False 103 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_uncond 104 | logging_dir: diff_react 105 | tracker_project_name: diffreact_tracker 106 | save_image_epochs: 5 107 | save_model_epochs: 50 108 | checkpointing_steps: 25000 109 | eval_batch_size: 8 110 | cond_drop_prob: null 111 | do_edm_style_training: True 112 | snr_gamma: null 113 | channel_names: ["u", "v"] -------------------------------------------------------------------------------- /configs/diffusion_reaction_vt_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | - 256 13 | center_input_sample: false 14 | class_embed_type: null 15 | down_block_types: 16 | - DownBlock2D 17 | - DownBlock2D 18 | - DownBlock2D 19 | - DownBlock2D 20 | downsample_padding: 1 21 | downsample_type: conv 22 | dropout: 0.0 23 | flip_sin_to_cos: true 24 | freq_shift: 0 25 | in_channels: 2 26 | layers_per_block: 2 27 | mid_block_scale_factor: 1 28 | norm_eps: 1e-05 29 | norm_num_groups: 32 30 | num_class_embeds: null 31 | num_train_timesteps: null 32 | out_channels: 2 33 | sample_size: 128 34 | time_embedding_type: positional 35 | up_block_types: 36 | - UpBlock2D 37 | - UpBlock2D 38 | - UpBlock2D 39 | - UpBlock2D 40 | upsample_type: conv 41 | 42 | noise_scheduler: 43 | target: diffusers.EDMDPMSolverMultistepScheduler 44 | params: 45 | num_train_timesteps: 1000 46 | 47 | loss_fn: 48 | target: losses.loss.EDMLoss 49 | params: 50 | sigma_data: 0.5 51 | 52 | optimizer: 53 | betas: 54 | - 0.9 55 | - 0.999 56 | eps: 1e-08 57 | lr: 1e-4 58 | weight_decay: 1e-2 59 | 60 | lr_scheduler: 61 | #name: cosine 62 | name: constant 63 | num_warmup_steps: 500 64 | num_cycles: 0.5 65 | power: 1.0 66 | dataloader: 67 | data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5 68 | batch_size: 16 69 | num_workers: 0 70 | split_ratios: 71 | - 0.8 72 | - 0.2 73 | - 0.0 74 | transform: normalize 75 | transform_args: 76 | mean: [-0.0311127, -0.0199022] 77 | std: [0.1438150, 0.1117546] 78 | target_std: 
0.5 79 | data_name: diffusion_reaction 80 | 81 | accelerator: 82 | mixed_precision: fp16 83 | gradient_accumulation_steps: 1 84 | log_with: tensorboard 85 | 86 | ema: 87 | use_ema: True 88 | offload_ema: False 89 | ema_max_decay: 0.9999 90 | ema_inv_gamma: 1.0 91 | ema_power: 0.75 92 | foreach: True 93 | 94 | general: 95 | seed: 42 96 | num_epochs: null 97 | num_training_steps: 100000 98 | known_channels: [0,1] 99 | same_mask: True 100 | scale_lr: False 101 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_vt 102 | logging_dir: diff_react 103 | tracker_project_name: diffreact_tracker 104 | save_image_epochs: 5 105 | save_model_epochs: 50 106 | checkpointing_steps: 25000 107 | eval_batch_size: 8 108 | do_edm_style_training: True 109 | channel_names: ["u", "v"] -------------------------------------------------------------------------------- /configs/diffusion_reaction_xattn_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | - 256 15 | center_input_sample: false 16 | class_embed_type: null 17 | class_embeddings_concat: false 18 | conv_in_kernel: 3 19 | conv_out_kernel: 3 20 | cross_attention_dim: 256 21 | cross_attention_norm: null 22 | down_block_types: 23 | - DownBlock2D 24 | - CrossAttnDownBlock2D 25 | - CrossAttnDownBlock2D 26 | - CrossAttnDownBlock2D 27 | downsample_padding: 1 28 | dropout: 0.0 29 | dual_cross_attention: false 30 | encoder_hid_dim: null 31 | encoder_hid_dim_type: null 32 | flip_sin_to_cos: true 33 | freq_shift: 0 34 | in_channels: 2 35 | layers_per_block: 2 36 | mid_block_only_cross_attention: null 37 | mid_block_scale_factor: 1 38 | mid_block_type: UNetMidBlock2DCrossAttn 39 | norm_eps: 1e-05 40 | norm_num_groups: 32 41 | num_attention_heads: null 42 | num_class_embeds: null 43 | only_cross_attention: false 44 | out_channels: 2 45 | projection_class_embeddings_input_dim: null 46 | resnet_out_scale_factor: 1.0 47 | resnet_skip_time_act: false 48 | resnet_time_scale_shift: scale_shift 49 | reverse_transformer_layers_per_block: null 50 | sample_size: 128 51 | time_cond_proj_dim: null 52 | time_embedding_act_fn: null 53 | time_embedding_dim: null 54 | time_embedding_type: positional 55 | timestep_post_act: null 56 | transformer_layers_per_block: 1 57 | up_block_types: 58 | - CrossAttnUpBlock2D 59 | - CrossAttnUpBlock2D 60 | - CrossAttnUpBlock2D 61 | - UpBlock2D 62 | upcast_attention: false 63 | use_linear_projection: false 64 | field_encoder_dict: 65 | hidden_size: 256 66 | intermediate_size: 1024 67 | projection_dim: 256 68 | image_size: 69 | - 128 70 | - 128 71 | patch_size: 8 72 | num_channels: 2 73 | num_hidden_layers: 4 74 | num_attention_heads: 8 75 | input_padding: 76 | - 0 77 | - 0 78 | output_hidden_state: true 79 | 80 | noise_scheduler: 81 | target: diffusers.EDMDPMSolverMultistepScheduler 82 | params: 83 | num_train_timesteps: 1000 84 | 85 | loss_fn: 86 | target: losses.loss.EDMLoss 87 | params: 88 | sigma_data: 0.5 89 | 90 | optimizer: 91 | betas: 92 | - 0.9 93 | - 0.999 94 | eps: 1e-08 95 | lr: 1e-4 96 | weight_decay: 1e-2 97 | 98 | lr_scheduler: 99 | #name: cosine 100 | name: constant 101 | num_warmup_steps: 500 102 | num_cycles: 0.5 103 | power: 1.0 104 | 105 | dataloader: 106 | 
data_dir: /scratch/kdur_root/kdur/shared_data/pdebench_data/2D/diffusion-reaction/2D_diff-react_NA_NA_reorg2.h5 107 | batch_size: 16 108 | num_workers: 0 109 | split_ratios: 110 | - 0.8 111 | - 0.2 112 | - 0.0 113 | transform: normalize 114 | transform_args: 115 | mean: [-0.0311127, -0.0199022] 116 | std: [0.1438150, 0.1117546] 117 | target_std: 0.5 118 | data_name: diffusion_reaction 119 | 120 | accelerator: 121 | mixed_precision: fp16 122 | gradient_accumulation_steps: 1 123 | log_with: tensorboard 124 | 125 | ema: 126 | use_ema: True 127 | offload_ema: False 128 | ema_max_decay: 0.9999 129 | ema_inv_gamma: 1.0 130 | ema_power: 0.75 131 | foreach: True 132 | 133 | general: 134 | seed: 42 135 | num_epochs: null 136 | num_training_steps: 100000 137 | known_channels: [0,1] 138 | same_mask: True 139 | scale_lr: False 140 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/diffusion_reaction_xattn 141 | logging_dir: diff_react 142 | tracker_project_name: diffreact_tracker 143 | save_image_epochs: 5 144 | save_model_epochs: 50 145 | checkpointing_steps: 25000 146 | eval_batch_size: 8 147 | cond_drop_prob: null 148 | do_edm_style_training: True 149 | snr_gamma: null 150 | channel_names: ["u", "v"] -------------------------------------------------------------------------------- /configs/shallow_water_cfg_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | center_input_sample: false 15 | class_embed_type: null 16 | class_embeddings_concat: false 17 | conv_in_kernel: 3 18 | conv_out_kernel: 3 19 | cross_attention_dim: 384 20 | cross_attention_norm: null 21 | down_block_types: 22 | - DownBlock2D 23 | - DownBlock2D 24 | - DownBlock2D 25 | downsample_padding: 1 26 | dropout: 0.0 27 | dual_cross_attention: false 28 | encoder_hid_dim: null 29 | encoder_hid_dim_type: null 30 | flip_sin_to_cos: true 31 | freq_shift: 0 32 | in_channels: 3 33 | layers_per_block: 2 34 | mid_block_only_cross_attention: null 35 | mid_block_scale_factor: 1 36 | mid_block_type: UNetMidBlock2D 37 | norm_eps: 1e-05 38 | norm_num_groups: 32 39 | num_attention_heads: null 40 | num_class_embeds: null 41 | only_cross_attention: false 42 | out_channels: 3 43 | projection_class_embeddings_input_dim: null 44 | resnet_out_scale_factor: 1.0 45 | resnet_skip_time_act: false 46 | resnet_time_scale_shift: scale_shift 47 | reverse_transformer_layers_per_block: null 48 | sample_size: 64 49 | time_cond_proj_dim: null 50 | time_embedding_act_fn: null 51 | time_embedding_dim: null 52 | time_embedding_type: positional 53 | timestep_post_act: null 54 | transformer_layers_per_block: 1 55 | up_block_types: 56 | - UpBlock2D 57 | - UpBlock2D 58 | - UpBlock2D 59 | upcast_attention: false 60 | use_linear_projection: false 61 | field_encoder_dict: 62 | hidden_size: 384 63 | intermediate_size: 1536 64 | projection_dim: 384 65 | image_size: 66 | - 64 67 | - 64 68 | patch_size: 4 69 | num_channels: 3 70 | num_hidden_layers: 4 71 | num_attention_heads: 8 72 | input_padding: 73 | - 0 74 | - 0 75 | output_hidden_state: false 76 | 77 | 78 | noise_scheduler: 79 | target: diffusers.EDMDPMSolverMultistepScheduler 80 | params: 81 | num_train_timesteps: 1000 82 | 83 | loss_fn: 84 | target: losses.loss.EDMLoss 85 | 
params: 86 | sigma_data: 0.5 87 | 88 | optimizer: 89 | betas: 90 | - 0.9 91 | - 0.999 92 | eps: 1e-08 93 | lr: 1e-4 94 | weight_decay: 1e-2 95 | 96 | lr_scheduler: 97 | #name: cosine 98 | name: constant 99 | num_warmup_steps: 500 100 | num_cycles: 0.5 101 | power: 1.0 102 | 103 | dataloader: 104 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 105 | batch_size: 16 106 | num_workers: 0 107 | split_ratios: 108 | - 0.8 109 | - 0.2 110 | - 0.0 111 | transform: normalize 112 | transform_args: 113 | mean: [0., 0., 1.0055307] 114 | std: [0.0089056, 0.0089056, 0.0137398] 115 | target_std: 0.5 116 | data_name: shallow_water 117 | 118 | accelerator: 119 | mixed_precision: fp16 120 | gradient_accumulation_steps: 1 121 | log_with: tensorboard 122 | 123 | ema: 124 | use_ema: True 125 | offload_ema: False 126 | ema_max_decay: 0.9999 127 | ema_inv_gamma: 1.0 128 | ema_power: 0.75 129 | foreach: True 130 | 131 | general: 132 | seed: 42 133 | num_epochs: null 134 | num_training_steps: 100000 135 | known_channels: [0, 1, 2] 136 | same_mask: True 137 | scale_lr: False 138 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_cfg 139 | logging_dir: sw 140 | tracker_project_name: sw_tracker 141 | save_image_epochs: 50 142 | save_model_epochs: 200 143 | checkpointing_steps: 25000 144 | eval_batch_size: 8 145 | cond_drop_prob: null 146 | do_edm_style_training: True 147 | snr_gamma: null 148 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- /configs/shallow_water_uncond_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | center_input_sample: false 13 | class_embed_type: null 14 | down_block_types: 15 | - DownBlock2D 16 | - DownBlock2D 17 | - DownBlock2D 18 | downsample_padding: 1 19 | downsample_type: conv 20 | dropout: 0.0 21 | flip_sin_to_cos: true 22 | freq_shift: 0 23 | in_channels: 3 24 | layers_per_block: 2 25 | mid_block_scale_factor: 1 26 | norm_eps: 1e-05 27 | norm_num_groups: 32 28 | num_class_embeds: null 29 | num_train_timesteps: null 30 | out_channels: 3 31 | resnet_time_scale_shift: scale_shift 32 | sample_size: 64 33 | time_embedding_type: positional 34 | up_block_types: 35 | - UpBlock2D 36 | - UpBlock2D 37 | - UpBlock2D 38 | upsample_type: conv 39 | 40 | noise_scheduler: 41 | target: diffusers.EDMDPMSolverMultistepScheduler 42 | params: 43 | num_train_timesteps: 1000 44 | 45 | loss_fn: 46 | target: losses.loss.EDMLoss 47 | params: 48 | sigma_data: 0.5 49 | 50 | optimizer: 51 | betas: 52 | - 0.9 53 | - 0.999 54 | eps: 1e-08 55 | lr: 1e-4 56 | weight_decay: 1e-2 57 | 58 | lr_scheduler: 59 | #name: cosine 60 | name: constant 61 | num_warmup_steps: 500 62 | num_cycles: 0.5 63 | power: 1.0 64 | 65 | dataloader: 66 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 67 | batch_size: 16 68 | num_workers: 0 69 | split_ratios: 70 | - 0.8 71 | - 0.2 72 | - 0.0 73 | transform: normalize 74 | transform_args: 75 | mean: [0., 0., 1.0055307] 76 | std: [0.0089056, 0.0089056, 0.0137398] 77 | target_std: 0.5 78 | data_name: shallow_water 79 | 80 | accelerator: 81 | mixed_precision: fp16 82 | gradient_accumulation_steps: 1 83 | log_with: tensorboard 84 | 85 | ema: 86 | use_ema: True 87 | offload_ema: False 88 | ema_max_decay: 0.9999 89 | ema_inv_gamma: 
1.0 90 | ema_power: 0.75 91 | foreach: True 92 | 93 | general: 94 | seed: 42 95 | num_epochs: null 96 | num_training_steps: 100000 97 | known_channels: [0, 1, 2] 98 | same_mask: True 99 | scale_lr: False 100 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_uncond 101 | logging_dir: sw 102 | tracker_project_name: sw_tracker 103 | save_image_epochs: 50 104 | save_model_epochs: 200 105 | checkpointing_steps: 25000 106 | eval_batch_size: 8 107 | cond_drop_prob: null 108 | do_edm_style_training: True 109 | snr_gamma: null 110 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- /configs/shallow_water_vt_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: UNet2DModel 3 | _diffusers_version: 0.28.2 4 | act_fn: silu 5 | add_attention: true 6 | attention_head_dim: 8 7 | attn_norm_num_groups: null 8 | block_out_channels: 9 | - 128 10 | - 256 11 | - 256 12 | center_input_sample: false 13 | class_embed_type: null 14 | down_block_types: 15 | - DownBlock2D 16 | - DownBlock2D 17 | - DownBlock2D 18 | downsample_padding: 1 19 | downsample_type: conv 20 | dropout: 0.0 21 | flip_sin_to_cos: true 22 | freq_shift: 0 23 | in_channels: 3 24 | layers_per_block: 2 25 | mid_block_scale_factor: 1 26 | norm_eps: 1e-05 27 | norm_num_groups: 32 28 | num_class_embeds: null 29 | num_train_timesteps: null 30 | out_channels: 3 31 | sample_size: 64 32 | time_embedding_type: positional 33 | up_block_types: 34 | - UpBlock2D 35 | - UpBlock2D 36 | - UpBlock2D 37 | upsample_type: conv 38 | 39 | noise_scheduler: 40 | target: diffusers.EDMDPMSolverMultistepScheduler 41 | params: 42 | num_train_timesteps: 1000 43 | 44 | loss_fn: 45 | target: losses.loss.EDMLoss 46 | params: 47 | sigma_data: 0.5 48 | 49 | optimizer: 50 | betas: 51 | - 0.9 52 | - 0.999 53 | eps: 1e-08 54 | lr: 1e-4 55 | weight_decay: 1e-2 56 | 57 | lr_scheduler: 58 | #name: cosine 59 | name: constant 60 | num_warmup_steps: 500 61 | num_cycles: 0.5 62 | power: 1.0 63 | 64 | dataloader: 65 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 66 | batch_size: 16 67 | num_workers: 0 68 | split_ratios: 69 | - 0.8 70 | - 0.2 71 | - 0.0 72 | transform: normalize 73 | transform_args: 74 | mean: [0., 0., 1.0055307] 75 | std: [0.0089056, 0.0089056, 0.0137398] 76 | target_std: 0.5 77 | data_name: shallow_water 78 | 79 | accelerator: 80 | mixed_precision: fp16 81 | gradient_accumulation_steps: 1 82 | log_with: tensorboard 83 | 84 | ema: 85 | use_ema: True 86 | offload_ema: False 87 | ema_max_decay: 0.9999 88 | ema_inv_gamma: 1.0 89 | ema_power: 0.75 90 | foreach: True 91 | 92 | general: 93 | seed: 42 94 | num_epochs: null 95 | num_training_steps: 100000 96 | known_channels: [0, 1, 2] 97 | same_mask: True 98 | scale_lr: False 99 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_vt 100 | logging_dir: sw 101 | tracker_project_name: sw_tracker 102 | save_image_epochs: 50 103 | save_model_epochs: 200 104 | checkpointing_steps: 25000 105 | eval_batch_size: 8 106 | do_edm_style_training: True 107 | snr_gamma: null 108 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- /configs/shallow_water_xattn_config.yaml: -------------------------------------------------------------------------------- 1 | unet: 2 | _class_name: diffuserUNet2DCondition 3 | _diffusers_version: 0.29.2 4 | act_fn: silu 5 | addition_embed_type: null 6 | addition_embed_type_num_heads: 64 7 | 
addition_time_embed_dim: null 8 | attention_head_dim: 8 9 | attention_type: default 10 | block_out_channels: 11 | - 128 12 | - 256 13 | - 256 14 | center_input_sample: false 15 | class_embed_type: null 16 | class_embeddings_concat: false 17 | conv_in_kernel: 3 18 | conv_out_kernel: 3 19 | cross_attention_dim: 384 20 | cross_attention_norm: null 21 | down_block_types: 22 | - DownBlock2D 23 | - CrossAttnDownBlock2D 24 | - CrossAttnDownBlock2D 25 | downsample_padding: 1 26 | dropout: 0.0 27 | dual_cross_attention: false 28 | encoder_hid_dim: null 29 | encoder_hid_dim_type: null 30 | flip_sin_to_cos: true 31 | freq_shift: 0 32 | in_channels: 3 33 | layers_per_block: 2 34 | mid_block_only_cross_attention: null 35 | mid_block_scale_factor: 1 36 | mid_block_type: UNetMidBlock2DCrossAttn 37 | norm_eps: 1e-05 38 | norm_num_groups: 32 39 | num_attention_heads: null 40 | num_class_embeds: null 41 | only_cross_attention: false 42 | out_channels: 3 43 | projection_class_embeddings_input_dim: null 44 | resnet_out_scale_factor: 1.0 45 | resnet_skip_time_act: false 46 | resnet_time_scale_shift: scale_shift 47 | reverse_transformer_layers_per_block: null 48 | sample_size: 64 49 | time_cond_proj_dim: null 50 | time_embedding_act_fn: null 51 | time_embedding_dim: null 52 | time_embedding_type: positional 53 | timestep_post_act: null 54 | transformer_layers_per_block: 1 55 | up_block_types: 56 | - CrossAttnUpBlock2D 57 | - CrossAttnUpBlock2D 58 | - UpBlock2D 59 | upcast_attention: false 60 | use_linear_projection: false 61 | field_encoder_dict: 62 | hidden_size: 384 63 | intermediate_size: 1536 64 | projection_dim: 384 65 | image_size: 66 | - 64 67 | - 64 68 | patch_size: 4 69 | num_channels: 3 70 | num_hidden_layers: 4 71 | num_attention_heads: 8 72 | input_padding: 73 | - 0 74 | - 0 75 | output_hidden_state: true 76 | 77 | 78 | noise_scheduler: 79 | target: diffusers.EDMDPMSolverMultistepScheduler 80 | params: 81 | num_train_timesteps: 1000 82 | 83 | loss_fn: 84 | target: losses.loss.EDMLoss 85 | params: 86 | sigma_data: 0.5 87 | 88 | optimizer: 89 | betas: 90 | - 0.9 91 | - 0.999 92 | eps: 1e-08 93 | lr: 1e-4 94 | weight_decay: 1e-2 95 | 96 | lr_scheduler: 97 | #name: cosine 98 | name: constant 99 | num_warmup_steps: 500 100 | num_cycles: 0.5 101 | power: 1.0 102 | 103 | dataloader: 104 | data_dir: /scratch/kdur_root/kdur/ylzhuang/uvh_unrolled.npy 105 | batch_size: 16 106 | num_workers: 0 107 | split_ratios: 108 | - 0.8 109 | - 0.2 110 | - 0.0 111 | transform: normalize 112 | transform_args: 113 | mean: [0., 0., 1.0055307] 114 | std: [0.0089056, 0.0089056, 0.0137398] 115 | target_std: 0.5 116 | data_name: shallow_water 117 | 118 | accelerator: 119 | mixed_precision: fp16 120 | gradient_accumulation_steps: 1 121 | log_with: tensorboard 122 | 123 | ema: 124 | use_ema: True 125 | offload_ema: False 126 | ema_max_decay: 0.9999 127 | ema_inv_gamma: 1.0 128 | ema_power: 0.75 129 | foreach: True 130 | 131 | general: 132 | seed: 42 133 | num_epochs: null 134 | num_training_steps: 100000 135 | known_channels: [0, 1, 2] 136 | same_mask: True 137 | scale_lr: False 138 | output_dir: /scratch/kdur_root/kdur/ylzhuang/log/shallow_water_xattn 139 | logging_dir: sw 140 | tracker_project_name: sw_tracker 141 | save_image_epochs: 50 142 | save_model_epochs: 200 143 | checkpointing_steps: 25000 144 | eval_batch_size: 8 145 | cond_drop_prob: null 146 | do_edm_style_training: True 147 | snr_gamma: null 148 | channel_names: ["u", "v", "h"] -------------------------------------------------------------------------------- 
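Note: the four config families above (uncond, vt, cfg, xattn) share essentially the same noise_scheduler, loss_fn, optimizer, dataloader, ema and general sections and differ mainly in the `unet` block and output paths. As a minimal sketch (not a file in the repo), this is how such a config is consumed, mirroring the pattern used in `evaluate.py` later in this listing; the config path here is just an example:

```python
# Minimal sketch, assuming the repo's own helpers shown later in this listing.
from omegaconf import OmegaConf
from utils.general_utils import instantiate_from_config
from dataloader.dataset_class import pdedata2dataloader

config = OmegaConf.load("configs/shallow_water_xattn_config.yaml")  # example path
noise_scheduler_config = config.pop("noise_scheduler", OmegaConf.create())
dataloader_config = config.pop("dataloader", OmegaConf.create())
general_config = config.pop("general", OmegaConf.create())

# Builds the class named in `target` with `params`, here
# diffusers.EDMDPMSolverMultistepScheduler(num_train_timesteps=1000).
noise_scheduler = instantiate_from_config(noise_scheduler_config)

# Returns train/val/test DataLoaders using the split_ratios and the
# normalization statistics declared in the config.
train_loader, val_loader, test_loader = pdedata2dataloader(**dataloader_config)
```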
/dataloader/dataset_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import random_split 4 | from einops import rearrange 5 | from torch.utils.data import Dataset, DataLoader, Subset 6 | import xarray as xr 7 | 8 | import sys 9 | import os 10 | 11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 12 | sys.path.append(parent_dir) 13 | from utils.general_utils import read_hdf5_to_numpy 14 | 15 | def normalize_transform(sample, mean, std, target_std=1): 16 | # (C, H, W) 17 | mean = torch.tensor(mean, device=sample.device) 18 | std = torch.tensor(std, device=sample.device) 19 | return ((sample - mean[:, None, None]) / std[:, None, None]) * target_std 20 | 21 | def inverse_normalize_transform(normalized_sample, mean, std, target_std=1): 22 | # (C, H, W) 23 | mean = torch.tensor(mean, device=normalized_sample.device) 24 | std = torch.tensor(std, device=normalized_sample.device) 25 | return (normalized_sample / target_std) * std[:, None, None] + mean[:, None, None] 26 | 27 | class FullDataset(Dataset): 28 | def __init__(self, data, transform=None, transform_args=None): 29 | """ 30 | Args: 31 | data (numpy.ndarray): Full dataset. 32 | transform (callable, optional): Optional transform to be applied on a sample. 33 | """ 34 | self.data = data 35 | self.transform = transform 36 | if transform_args is None: 37 | self.transform_args = {} 38 | else: 39 | self.transform_args = transform_args 40 | 41 | def __len__(self): 42 | return len(self.data) 43 | 44 | def __getitem__(self, idx): 45 | sample = self.data[idx] 46 | if self.transform: 47 | sample = self.transform(sample, **self.transform_args) 48 | return sample 49 | 50 | class XarrayDataset2D(Dataset): 51 | def __init__( 52 | self, 53 | data: xr.Dataset, # (n, t, x, y) preferable, but must have dims named 'n' and 't' 54 | transform: str = None, 55 | transform_args: dict = None, 56 | load_in_memory: bool = False, # TODO: figure out how to save only one copy in memory 57 | ): 58 | if not load_in_memory: 59 | self.data = data 60 | else: 61 | self.data = data.load() 62 | self.transform = self._get_transform(transform, transform_args) 63 | self.transform_args = transform_args or {} 64 | self.var_names_list = list(data.data_vars.keys()) 65 | self.dims_dict = dict(data.dims) 66 | assert 'n' in self.dims_dict, 'Dataset must have dimension named "n".' 67 | assert 't' in self.dims_dict, 'Dataset must have dimension named "t".'
68 | self.length = self.dims_dict['n'] * self.dims_dict['t'] 69 | 70 | def _get_transform(self, transform, transform_args): 71 | if transform == 'normalize': 72 | mean = transform_args['mean'] 73 | std = transform_args['std'] 74 | if 'target_std' in transform_args: 75 | target_std = transform_args['target_std'] 76 | return lambda x: normalize_transform(x, mean, std, target_std) 77 | return lambda x: normalize_transform(x, mean, std) 78 | elif transform is None: 79 | return lambda x: x 80 | else: 81 | raise NotImplementedError(f'Transform: {transform} not implemented.') 82 | 83 | def _preprocess_data(self, data): 84 | data = torch.from_numpy(data).float() 85 | return self.transform(data) 86 | 87 | def _idx2nt(self, idx): 88 | return divmod(idx, self.dims_dict['t']) 89 | 90 | def get_array_from_xrdataset_2D(self, n_idx, t_idx): 91 | sliced_ds = self.data.isel({'n': n_idx, 't': t_idx}) 92 | # xr.to_array() is slower 93 | return np.stack([sliced_ds[var].values for var in self.var_names_list], axis=0) 94 | 95 | def __len__(self): 96 | return self.length 97 | 98 | def __getitem__(self, idx): 99 | assert idx >= 0, f'Index must be non-negative, got: {idx}.' 100 | n_idx, t_idx = self._idx2nt(idx) 101 | return self._preprocess_data(self.get_array_from_xrdataset_2D(n_idx, t_idx)) 102 | 103 | def npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=(0.7, 0.2, 0.1), transform=None, transform_args=None, 104 | rearrange_args=None, random_dataset=True, generator=None, return_dataset=False): 105 | if rearrange_args is not None: 106 | full_dataset = rearrange(full_dataset, rearrange_args) 107 | full_dataset = torch.tensor(full_dataset, dtype=torch.float32) 108 | 109 | train_size = int(split_ratios[0] * len(full_dataset)) 110 | val_size = int(split_ratios[1] * len(full_dataset)) 111 | test_size = len(full_dataset) - train_size - val_size 112 | 113 | if random_dataset: 114 | print('Randomly splitting dataset.') 115 | train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size], generator=generator) 116 | else: 117 | print('Splitting dataset by length.') 118 | train_dataset = Subset(full_dataset, list(range(0, train_size))) 119 | val_dataset = Subset(full_dataset, list(range(train_size, train_size+val_size))) 120 | test_dataset = Subset(full_dataset, list(range(train_size+val_size, len(full_dataset)))) 121 | 122 | if transform is not None: 123 | if transform == 'normalize': 124 | mean = transform_args['mean'] 125 | std = transform_args['std'] 126 | transform = normalize_transform 127 | if 'target_std' in transform_args: 128 | target_std = transform_args['target_std'] 129 | transform_args = {'mean': mean, 'std': std, 'target_std': target_std} 130 | else: 131 | transform_args = {'mean': mean, 'std': std} 132 | else: 133 | raise NotImplementedError(f'Transform: {transform} not implemented.') 134 | 135 | if not isinstance(full_dataset.data, xr.Dataset): 136 | # Apply the same transform to all splits if needed 137 | train_dataset = FullDataset(train_dataset, transform=transform, transform_args=transform_args) 138 | val_dataset = FullDataset(val_dataset, transform=transform, transform_args=transform_args) 139 | test_dataset = FullDataset(test_dataset, transform=transform, transform_args=transform_args) 140 | 141 | if return_dataset: 142 | return train_dataset, val_dataset, test_dataset 143 | 144 | # Create dataloaders 145 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=generator) 146 | 
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 147 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 148 | 149 | return train_loader, val_loader, test_loader 150 | 151 | def dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=(0.7, 0.2, 0.1), random_dataset=True, generator=None, return_dataset=False): 152 | train_size = int(split_ratios[0] * len(full_dataset)) 153 | val_size = int(split_ratios[1] * len(full_dataset)) 154 | test_size = len(full_dataset) - train_size - val_size 155 | 156 | if random_dataset: 157 | print('Randomly splitting dataset.') 158 | train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size], generator=generator) 159 | else: 160 | print('Splitting dataset by length.') 161 | train_dataset = Subset(full_dataset, list(range(0, train_size))) 162 | val_dataset = Subset(full_dataset, list(range(train_size, train_size+val_size))) 163 | test_dataset = Subset(full_dataset, list(range(train_size+val_size, len(full_dataset)))) 164 | 165 | if return_dataset: 166 | return train_dataset, val_dataset, test_dataset 167 | 168 | # Create dataloaders 169 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=generator) 170 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 171 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, generator=generator) 172 | 173 | return train_loader, val_loader, test_loader 174 | 175 | def pdedata2dataloader(data_dir, batch_size=32, num_workers=1, split_ratios=(0.7, 0.2, 0.1), transform=None, transform_args=None, 176 | rearrange_args=None, generator=None, data_name=None, return_dataset=False, load_in_memory=False): 177 | 178 | if data_name == 'darcy': 179 | full_dataset = np.load(data_dir) 180 | return npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, transform=transform, transform_args=transform_args, 181 | rearrange_args=rearrange_args, random_dataset=False, generator=generator, return_dataset=return_dataset) 182 | elif data_name == 'shallow_water': 183 | full_dataset = np.load(data_dir) 184 | return npy2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, transform=transform, transform_args=transform_args, 185 | rearrange_args=rearrange_args, random_dataset=False, generator=generator, return_dataset=return_dataset) 186 | elif data_name == 'compressible_NS': 187 | # keys = ['Vx' , 'Vy', 'density', 'pressure'] 188 | xarray = xr.open_dataset(data_dir, phony_dims='access', engine='h5netcdf', 189 | drop_variables=['t-coordinate', 'x-coordinate', 'y-coordinate'], 190 | chunks='auto') 191 | xarray= xarray.rename({ 192 | 'phony_dim_0': 'n', 193 | 'phony_dim_1': 't', 194 | 'phony_dim_2': 'x', 195 | 'phony_dim_3': 'y' 196 | }) 197 | full_dataset = XarrayDataset2D(xarray, transform=transform, transform_args=transform_args, load_in_memory=load_in_memory) 198 | return dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, random_dataset=False, generator=generator, return_dataset=return_dataset) 199 | elif data_name == 'diffusion_reaction': 200 | # keys = ['u', 'v'] 201 | xarray = xr.open_dataset(data_dir, phony_dims='access', engine='h5netcdf', 202 | drop_variables=['t-coordinate', 
'x-coordinate', 'y-coordinate'], 203 | chunks='auto') 204 | xarray= xarray.rename({ 205 | 'phony_dim_0': 't', 206 | 'phony_dim_1': 'n', 207 | 'phony_dim_2': 'x', 208 | 'phony_dim_3': 'y' 209 | }) 210 | full_dataset = XarrayDataset2D(xarray, transform=transform, transform_args=transform_args, load_in_memory=load_in_memory) 211 | return dataset2dataloader(full_dataset, batch_size, num_workers, split_ratios=split_ratios, random_dataset=False, generator=generator, return_dataset=return_dataset) 212 | else: 213 | raise NotImplementedError(f'Dataset: {data_name} not implemented.') -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | from accelerate import Accelerator 4 | from accelerate.utils import set_seed 5 | from accelerate.logging import get_logger 6 | from diffusers import UNet2DModel 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | from dataloader.dataset_class import pdedata2dataloader 11 | import os 12 | import copy 13 | from omegaconf import OmegaConf 14 | 15 | from pipelines.pipeline_inv_prob import InverseProblem2DPipeline, InverseProblem2DCondPipeline 16 | from models.unet2D import diffuserUNet2D 17 | from models.unet2DCondition import diffuserUNet2DCondition, diffuserUNet2DCFG 18 | from utils.general_utils import instantiate_from_config 19 | from utils.vt_utils import vt_obs 20 | from losses.metric import get_metrics_2D 21 | 22 | 23 | logger = get_logger(__name__, log_level="INFO") 24 | 25 | def parse_list_int(value): 26 | try: 27 | # Try to split by commas and convert to a list of integers 28 | steps = [int(x) for x in value.split(',')] 29 | return steps 30 | except ValueError: 31 | # If it fails, assume it's a single integer 32 | return [int(value)] 33 | 34 | def parse_list_float(value): 35 | try: 36 | # Try to split by commas and convert to a list of floats 37 | ratios = [float(x) for x in value.split(',')] 38 | return ratios 39 | except ValueError: 40 | # If it fails, assume it's a single float 41 | return [float(value)] 42 | 43 | def parse_args(): 44 | parser = argparse.ArgumentParser(description="Evaluate trained Diffusers model.") 45 | parser.add_argument('--config', type=str, required=True, help="Path to the YAML configuration file.") 46 | parser.add_argument('--repo_name', type=str, help="Repository name.") 47 | parser.add_argument('--subfolder', type=str, help="Subfolder in the repository.") 48 | parser.add_argument('--path_to_ckpt', type=str, help="Path to the checkpoint.") 49 | parser.add_argument( 50 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
51 | ) 52 | parser.add_argument('--path_to_csv', type=str, required=True, help="Path to the CSV file for saving the results; if none exists, a new one will be created.") 53 | 54 | # eval 55 | parser.add_argument('--batch_size', type=int, default=32, help="Batch size for evaluation.") 56 | parser.add_argument('--num_inference_steps', type=parse_list_int, help='Number of inference steps, can be a single integer or a comma-separated list of integers.') 57 | parser.add_argument('--mask_ratios', type=parse_list_float, help='Mask ratios, can be a single float or a comma-separated list of floats.') 58 | parser.add_argument('--vt_spacing', type=parse_list_int, help="Spacing for dividing the domain.") 59 | parser.add_argument('--mode', type=str, required=True, help="Mode for evaluation: 'edm', 'pipeline', or 'vt'.") 60 | parser.add_argument('--conditioning_type', type=str, default='xattn', help="Conditioning type for evaluation.") 61 | parser.add_argument('--ensemble_size', type=int, default=25, help="Ensemble size for evaluation.") 62 | parser.add_argument('--channel_mean', action='store_true', help="Whether to output the channel mean.") 63 | parser.add_argument('--structure_sampling', action='store_true', default=False, help="Whether to sample with vt.") 64 | parser.add_argument('--noise_level', type=float, default=0., help="Noise level for evaluation.") 65 | parser.add_argument('--noise_type', type=str, default='white', help="Type of noise to add. Options: 'white', 'pink', 'red', 'blue', 'purple'.") 66 | parser.add_argument('--verbose', action='store_true', help="Whether to print verbose information.") 67 | parser.add_argument('--total_eval', type=int, default=1000, help="Total number of evaluation samples.") 68 | 69 | return parser.parse_args() 70 | 71 | def main(args): 72 | print(f"Received args: {args}") 73 | if args.config is None: 74 | raise ValueError("The config argument is missing.
Please provide the path to the configuration file using --config.") 75 | config = OmegaConf.load(args.config) 76 | 77 | accelerator = Accelerator() 78 | 79 | use_vt = False 80 | if 'vt' in args.config: 81 | use_vt = True 82 | assert args.mode == 'vt', 'Mode should be vt when using vt' 83 | model_cls = diffuserUNet2D 84 | logger.info('Using vt model') 85 | else: 86 | if args.conditioning_type == 'cfg': 87 | model_cls = diffuserUNet2DCFG 88 | logger.info('Using cfg model') 89 | elif args.conditioning_type == 'xattn': 90 | model_cls = diffuserUNet2DCondition 91 | logger.info('Using xattn model') 92 | elif args.conditioning_type == 'uncond': 93 | model_cls = UNet2DModel 94 | logger.info('Using uncond model') 95 | else: 96 | raise NotImplementedError 97 | 98 | 99 | noise_scheduler_config = config.pop("noise_scheduler", OmegaConf.create()) 100 | dataloader_config = config.pop("dataloader", OmegaConf.create()) 101 | general_config = config.pop("general", OmegaConf.create()) 102 | 103 | set_seed(general_config.seed) 104 | 105 | repo_name = args.repo_name 106 | subfolder = args.subfolder 107 | 108 | if args.path_to_ckpt is None: 109 | unet = model_cls.from_pretrained(repo_name, 110 | subfolder=subfolder, 111 | use_safetensors=True, 112 | ) 113 | else: 114 | unet = model_cls.from_pretrained(args.path_to_ckpt, 115 | use_safetensors=True, 116 | ) 117 | 118 | noise_scheduler = instantiate_from_config(noise_scheduler_config) 119 | 120 | generator = torch.Generator(device=accelerator.device).manual_seed(general_config.seed) 121 | _, val_dataset, _ = pdedata2dataloader(**dataloader_config, generator=generator, 122 | return_dataset=True) 123 | 124 | #select_idx = np.random.choice(len(val_dataset), args.total_eval, replace=False) 125 | select_idx = np.arange(0, args.total_eval) 126 | reduced_val_dataset = torch.utils.data.Subset(val_dataset, select_idx) 127 | 128 | unet = accelerator.prepare_model(unet, evaluation_mode=True) 129 | resolution = unet.config.sample_size 130 | 131 | if args.conditioning_type == 'xattn' or args.conditioning_type == 'cfg': 132 | pipeline = InverseProblem2DCondPipeline(unet, scheduler=copy.deepcopy(noise_scheduler)) 133 | elif args.conditioning_type == 'uncond': 134 | pipeline = InverseProblem2DPipeline(unet, scheduler=copy.deepcopy(noise_scheduler)) 135 | else: 136 | raise NotImplementedError 137 | 138 | for num_inference_steps, mask_ratios, vt_spacing in itertools.product(args.num_inference_steps, args.mask_ratios, args.vt_spacing): 139 | 140 | sampler_kwargs = { 141 | "num_inference_steps": num_inference_steps, 142 | "known_channels": general_config.known_channels, 143 | #"same_mask": general_config.same_mask, 144 | } 145 | x_idx, y_idx = torch.meshgrid(torch.arange(vt_spacing, resolution, vt_spacing), torch.arange(vt_spacing, resolution, vt_spacing)) 146 | x_idx = x_idx.flatten().to(accelerator.device) 147 | y_idx = y_idx.flatten().to(accelerator.device) 148 | if "darcy" in args.subfolder: 149 | mask_kwargs = { 150 | "x_idx": x_idx, 151 | "y_idx": y_idx, 152 | "channels": general_config.known_channels, 153 | } 154 | else: 155 | mask_kwargs = { 156 | "ratio": mask_ratios, 157 | "channels": general_config.known_channels, 158 | } 159 | 160 | tmp_dim = unet.config.sample_size 161 | vt = vt_obs(x_dim=tmp_dim, y_dim=tmp_dim, x_spacing=vt_spacing, y_spacing=vt_spacing, known_channels=general_config.known_channels, device=accelerator.device) 162 | if not 'ratio' in mask_kwargs: 163 | if 'x_idx' in mask_kwargs: 164 | print('Total number of observation points: ', 
mask_kwargs["x_idx"].shape[0], ' Perceage of known points: ', mask_kwargs["x_idx"].shape[0] / tmp_dim**2) 165 | else: 166 | print('Total number of observation points: ', vt.x_start_grid.numel(), ' Perceage of known points: ', vt.x_start_grid.numel() / tmp_dim**2) 167 | else: 168 | print('Total number of observation points: ', mask_kwargs['ratio']*tmp_dim**2, ' Perceage of known points: ', mask_kwargs['ratio']) 169 | 170 | err_RMSE, err_nRMSE, err_CSV = get_metrics_2D(reduced_val_dataset, pipeline=pipeline, 171 | vt = vt, 172 | vt_model = unet if use_vt else None, 173 | batch_size = args.batch_size, 174 | ensemble_size = args.ensemble_size, 175 | sampler_kwargs = sampler_kwargs, 176 | mask_kwargs = mask_kwargs, 177 | known_channels=general_config.known_channels, 178 | device = accelerator.device, 179 | mode = args.mode, #'edm', 'pipeline', 'vt' 180 | conditioning_type = args.conditioning_type, 181 | inverse_transform = dataloader_config.transform, # 'normalize' 182 | inverse_transform_args = dataloader_config.transform_args, 183 | channel_mean = args.channel_mean, 184 | structure_sampling = args.structure_sampling, 185 | noise_level = args.noise_level, 186 | noise_type = args.noise_type, 187 | verbose = args.verbose, 188 | ) 189 | 190 | csv_filename = args.path_to_csv 191 | 192 | if args.mode != 'mean': 193 | if args.structure_sampling: 194 | index_value = f"{args.config.split('/')[-1].split('.')[0]}_spacing_{str(vt_spacing)}_mode_{args.mode}_step_{str(num_inference_steps)}" 195 | else: 196 | index_value = f"{args.config.split('/')[-1].split('.')[0]}_ratio_{str(mask_kwargs['ratio'])}_mode_{args.mode}_step_{str(num_inference_steps)}" 197 | else: 198 | index_value = f"{dataloader_config.data_name}_mean" 199 | 200 | if os.path.exists(csv_filename): 201 | df_existing = pd.read_csv(csv_filename, index_col=0) 202 | else: 203 | df_existing = pd.DataFrame(columns=['RMSE', 'nRMSE', 'CSV']) 204 | df_existing.index.name = 'Index' 205 | 206 | if "darcy" in args.subfolder: 207 | # For darcy we only save the permeability field 208 | df_new = pd.DataFrame({'RMSE': [err_RMSE.cpu().numpy()[0]], 209 | 'nRMSE': [err_nRMSE.cpu().numpy()[0]], 210 | 'CSV': [err_CSV.cpu().numpy()[0]]}, 211 | index=[index_value]) 212 | else: 213 | df_new = pd.DataFrame({'RMSE': [err_RMSE.cpu().numpy()], 214 | 'nRMSE': [err_nRMSE.cpu().numpy()], 215 | 'CSV': [err_CSV.cpu().numpy()]}, 216 | index=[index_value]) 217 | 218 | df_combined = pd.concat([df_existing, df_new]) 219 | df_combined.to_csv(csv_filename) 220 | 221 | logger.info(f'RMSE: {err_RMSE.cpu().numpy()}, nRMSE: {err_nRMSE.cpu().numpy()}, CSV: {err_CSV.cpu().numpy()}') 222 | 223 | if __name__ == "__main__": 224 | args = parse_args() 225 | main(args) -------------------------------------------------------------------------------- /git_assest/bar_chart_0.01_largefont.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/bar_chart_0.01_largefont.png -------------------------------------------------------------------------------- /git_assest/darcy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/darcy.png -------------------------------------------------------------------------------- /git_assest/dr_hf.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/dr_hf.png -------------------------------------------------------------------------------- /git_assest/encoding_block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/encoding_block.png -------------------------------------------------------------------------------- /git_assest/error_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyzyl/DiffusionReconstruct/e1c11ad8034ebf27460c79b87ff5527cbb94a4f3/git_assest/error_hist.png -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LpLoss(object): 4 | #loss function with rel/abs Lp loss, modified from neuralop: 5 | #https://github.com/neuraloperator/neuraloperator/blob/main/neuralop/losses/data_losses.py 6 | ''' 7 | LpLoss: Lp loss function, return the relative loss by default 8 | Args: 9 | d: int, start dimension of the field. E,g., for shape like (b, c, h, w), d=2 (default 1) 10 | p: int, p in Lp norm, default 2 11 | reduce_dims: int or list of int, dimensions to reduce 12 | reductions: str or list of str, 'sum' or 'mean' 13 | 14 | Call: (y_pred, y) 15 | ''' 16 | def __init__(self, d=1, p=2, reduce_dims=0, reductions='sum'): 17 | super().__init__() 18 | 19 | self.d = d 20 | self.p = p 21 | 22 | if isinstance(reduce_dims, int): 23 | self.reduce_dims = [reduce_dims] 24 | else: 25 | self.reduce_dims = reduce_dims 26 | 27 | if self.reduce_dims is not None: 28 | if isinstance(reductions, str): 29 | assert reductions == 'sum' or reductions == 'mean' 30 | self.reductions = [reductions]*len(self.reduce_dims) 31 | else: 32 | for j in range(len(reductions)): 33 | assert reductions[j] == 'sum' or reductions[j] == 'mean' 34 | self.reductions = reductions 35 | 36 | def reduce_all(self, x): 37 | for j in range(len(self.reduce_dims)): 38 | if self.reductions[j] == 'sum': 39 | x = torch.sum(x, dim=self.reduce_dims[j], keepdim=True) 40 | else: 41 | x = torch.mean(x, dim=self.reduce_dims[j], keepdim=True) 42 | 43 | return x 44 | 45 | def abs(self, x, y): 46 | diff = torch.norm(torch.flatten(x, start_dim=-self.d) - torch.flatten(y, start_dim=-self.d), \ 47 | p=self.p, dim=-1, keepdim=False) 48 | 49 | if self.reduce_dims is not None: 50 | diff = self.reduce_all(diff).squeeze() 51 | 52 | return diff 53 | 54 | def rel(self, x, y): 55 | 56 | diff = torch.norm(torch.flatten(x, start_dim=-self.d) - torch.flatten(y, start_dim=-self.d), \ 57 | p=self.p, dim=-1, keepdim=False) 58 | ynorm = torch.norm(torch.flatten(y, start_dim=-self.d), p=self.p, dim=-1, keepdim=False) 59 | 60 | diff = diff/ynorm 61 | 62 | if self.reduce_dims is not None: 63 | diff = self.reduce_all(diff).squeeze() 64 | 65 | return diff 66 | 67 | def __call__(self, y_pred, y, **kwargs): 68 | return self.rel(y_pred, y) 69 | 70 | class EDMLoss: 71 | def __init__(self, sigma_data=0.5): 72 | self.sigma_data = sigma_data 73 | 74 | def __call__(self, y_pred, y, sigma, **kwargs): 75 | # (comment from diffuser) We are not doing weighting here because it tends result in numerical problems. 
76 | # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 77 | # There might be other alternatives for weighting as well: 78 | # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 79 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 80 | loss = weight * ((y - y_pred) ** 2) 81 | return loss.mean() 82 | 83 | class EDMLoss_reg: 84 | def __init__(self, sigma_data=0.5, reg_weight=0.001): 85 | self.sigma_data = sigma_data 86 | self.reg_weight = reg_weight 87 | 88 | def __call__(self, y_pred, y, sigma, **kwargs): 89 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 90 | squared_err = ((y - y_pred) ** 2) 91 | csv_reg = (torch.sum(y_pred, dim=[-2, -1], keepdim=True) - torch.sum(y, dim=[-2, -1], keepdim=True))**2 92 | loss = weight * squared_err + self.reg_weight * csv_reg 93 | return loss.mean() -------------------------------------------------------------------------------- /losses/metric.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import torch 5 | import math as mt 6 | from tqdm.auto import tqdm 7 | from torch.utils.data import DataLoader 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from einops import rearrange, repeat 10 | 11 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 12 | sys.path.append(parent_dir) 13 | from utils.inverse_utils import create_scatter_mask, ensemble_sample, colored_noise 14 | from dataloader.dataset_class import inverse_normalize_transform 15 | 16 | @torch.no_grad() 17 | def metric_func_2D(y_pred, y_true, mask=None, channel_mean=True): 18 | # Adapted from: https://github.com/pdebench/PDEBench/blob/main/pdebench/models/metrics.py 19 | # (B, C, H, W) 20 | ''' 21 | y_pred: torch.Tensor, shape (B, C, H, W) 22 | y_true: torch.Tensor, shape (B, C, H, W) 23 | mask: torch.Tensor, shape (B, H, W), optional 24 | 25 | Returns: err_RMSE, err_nRMSE, err_CSV 26 | ''' 27 | if mask is not None: 28 | unknown_mask = (1 - mask).float() 29 | num_unknowns = torch.sum(unknown_mask, dim=[-2, -1], keepdim=True) 30 | assert torch.all(num_unknowns) > 0, "All values are known. Cannot compute with mask." 
31 | mse = torch.sum(((y_pred - y_true)**2) * unknown_mask, dim=[-2, -1], keepdim=True) / num_unknowns 32 | mse = rearrange(mse, 'b c 1 1 -> b c') 33 | nrm = torch.sqrt(torch.sum((y_true**2) * unknown_mask, dim=[-2, -1], keepdim=True) / num_unknowns) 34 | nrm = rearrange(nrm, 'b c 1 1 -> b c') 35 | csv_error = torch.mean((torch.sum(y_pred * unknown_mask, dim=[-2, -1]) - torch.sum(y_true * unknown_mask, dim=[-2, -1]))**2, dim=0) 36 | else: 37 | mse = torch.mean((y_pred - y_true)**2, dim=[-2, -1]) 38 | nrm = torch.sqrt(torch.mean(y_true**2, dim=[-2, -1])) 39 | csv_error = torch.mean((torch.sum(y_pred, dim=[-2, -1]) - torch.sum(y_true, dim=[-2, -1]))**2, dim=0) 40 | 41 | err_mean = torch.sqrt(mse) # (B, C) 42 | err_RMSE = torch.mean(err_mean, axis=0) # -> (C) 43 | err_nRMSE = torch.mean(err_mean / nrm, dim=0) # -> (C) 44 | err_CSV = torch.sqrt(csv_error) # (C) 45 | if mask is not None: 46 | err_CSV /= torch.mean(torch.sum(unknown_mask, dim=[-2, -1]), dim=0) 47 | else: 48 | err_CSV /= (y_true.shape[-2] * y_true.shape[-1]) 49 | 50 | # Channel mean 51 | if channel_mean: 52 | err_RMSE = torch.mean(err_RMSE, axis=0) 53 | err_nRMSE = torch.mean(err_nRMSE, axis=0) 54 | err_CSV = torch.mean(err_CSV, axis=0) 55 | 56 | return err_RMSE, err_nRMSE, err_CSV 57 | 58 | @torch.no_grad() 59 | def get_metrics_2D(val_dataset, pipeline=None, vt=None, vt_model=None, batch_size=64, ensemble_size=25, 60 | sampler_kwargs=None, mask_kwargs=None, known_channels=None, device='cpu', mode='edm', #'edm', 'pipeline', 'vt', 'mean' 61 | conditioning_type = None, #xattn, cfg 62 | inverse_transform=None, inverse_transform_args=None, channel_mean=True, 63 | structure_sampling=False, noise_level=0, noise_type='white', # 'white', 'pink', 'red', 'blue', 'purple' 64 | verbose=False): 65 | 66 | if inverse_transform == 'normalize': 67 | inverse_transform = inverse_normalize_transform 68 | 69 | if mode != 'mean': 70 | if mask_kwargs is None: 71 | mask_kwargs = {} 72 | if sampler_kwargs is None: 73 | sampler_kwargs = {} 74 | 75 | if mode == 'edm': 76 | model = pipeline.unet 77 | noise_scheduler = copy.deepcopy(pipeline.scheduler) 78 | 79 | if "x_idx" in mask_kwargs and "y_idx" in mask_kwargs: 80 | print("Using structure sampling, x_idx and y_idx are provided.") 81 | assert structure_sampling, "Structure sampling must be enabled when x_idx and y_idx are provided." 82 | x_idx, y_idx = mask_kwargs["x_idx"], mask_kwargs["y_idx"] 83 | mask_kwargs.pop("x_idx") 84 | mask_kwargs.pop("y_idx") 85 | ignore_idx_check = True 86 | else: 87 | if structure_sampling: 88 | print("Using structure sampling, x_idx and y_idx are not provided, selecting random points from each grid.") 89 | else: 90 | print("Not using structure sampling. points are uniformly sampled.") 91 | ignore_idx_check = False 92 | 93 | 94 | print(f'Calculating metrics for {mode}, ensemble size: {ensemble_size}') 95 | else: 96 | print('Calculating metrics for mean, all other arguments are ignored.') 97 | 98 | if structure_sampling and noise_type != 'white': 99 | Warning("Colored noise is not supported with structure sampling. 
Ignoring noise_type.") 100 | 101 | count = 0 102 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 103 | for step, y_true in tqdm(enumerate(val_loader), total=len(val_loader)): 104 | # (B, C, H, W) 105 | y_true = y_true.to(device) 106 | num_sample, C, _, _ = y_true.shape 107 | if mode != 'mean': 108 | generator = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in range(count, count+num_sample)] 109 | if mode == 'vt': 110 | in_channels = vt_model.config.in_channels 111 | else: 112 | in_channels = C if known_channels is None else len(known_channels) 113 | interpolated_fields = torch.empty((num_sample, in_channels, y_true.shape[-2], y_true.shape[-1]), device=device, dtype=y_true.dtype) 114 | if structure_sampling: 115 | scatter_mask = torch.empty_like(y_true) 116 | 117 | for b in range(num_sample): 118 | if ("x_idx" not in mask_kwargs and "y_idx" not in mask_kwargs) and not ignore_idx_check: 119 | x_idx, y_idx = vt.structure_obs(generator=generator[b]) 120 | 121 | grid_points = vt._get_grid_points(x_idx, y_idx, generator=generator[b]) 122 | 123 | for idx, known_channel in enumerate(range(C) if known_channels is None else known_channels): 124 | field = y_true[b, known_channel][grid_points[:,1], grid_points[:,0]].flatten() 125 | if noise_level > 0: 126 | # drop support for colored noise in this case 127 | noise = randn_tensor(field.shape, device=device) 128 | field += noise * noise_level * field # y_obs = y_true + noise_level * y_true 129 | interpolated_values = vt.interpolate(grid_points, field) 130 | interpolated_fields[b, idx] = torch.tensor(interpolated_values, 131 | dtype=y_true.dtype, 132 | device=device) 133 | tmp_mask = tmp_mask = create_scatter_mask(y_true, x_idx=x_idx, y_idx=y_idx, device=device, **mask_kwargs) 134 | scatter_mask[b] = tmp_mask[0] 135 | else: 136 | scatter_mask = create_scatter_mask(y_true, generator=generator, device=device, **mask_kwargs) 137 | tmp_y_true = y_true.clone() 138 | if noise_level > 0: 139 | if noise_type == 'white': 140 | noise = randn_tensor(tmp_y_true.shape, device=device) 141 | else: 142 | noise = colored_noise(tmp_y_true.shape, noise_type=noise_type, device=device) 143 | tmp_y_true += noise * noise_level * tmp_y_true # y_obs = y_true + noise_level * y_true 144 | interpolated_fields = vt(known_fields=tmp_y_true, mask=scatter_mask) 145 | 146 | if mode == 'edm': 147 | y_pred = torch.empty_like(y_true) 148 | for b in range(num_sample): 149 | y_pred[b] = ensemble_sample(pipeline, ensemble_size, scatter_mask[[b]], sampler_kwargs=sampler_kwargs, conditioning_type=conditioning_type, 150 | class_labels=None, known_latents=interpolated_fields[[b]], sampler_type=mode, device=device).mean(dim=0) 151 | elif mode == 'pipeline': 152 | y_pred = torch.empty_like(y_true) 153 | for b in range(num_sample): 154 | y_pred[b] = ensemble_sample(pipeline, ensemble_size, scatter_mask[[b]], sampler_kwargs=sampler_kwargs, conditioning_type=conditioning_type, 155 | class_labels=None, known_latents=interpolated_fields[[b]], sampler_type=mode, device=device).mean(dim=0) 156 | 157 | elif mode == 'vt': 158 | y_pred = vt_model(interpolated_fields, return_dict=False)[0] 159 | 160 | else: 161 | raise NotImplementedError(f'Mode: {mode} not implemented.') 162 | 163 | if inverse_transform is not None: 164 | y_pred = inverse_transform(y_pred, **inverse_transform_args) 165 | y_true = inverse_transform(y_true, **inverse_transform_args) 166 | else: 167 | y_true = inverse_transform(y_true, **inverse_transform_args) 168 | y_pred = 
torch.ones_like(y_true) * repeat(torch.tensor(inverse_transform_args['mean'], device=device), 'c -> b c 1 1', b=num_sample) 169 | 170 | 171 | _err_RMSE, _err_nRMSE, _err_CSV = metric_func_2D(y_pred, y_true, 172 | mask=scatter_mask if mode != 'mean' else None, 173 | channel_mean=channel_mean) 174 | 175 | if step == 0: 176 | err_RMSE = _err_RMSE * num_sample 177 | err_nRMSE = _err_nRMSE * num_sample 178 | err_CSV = _err_CSV * num_sample 179 | else: 180 | err_RMSE += _err_RMSE * num_sample 181 | err_nRMSE += _err_nRMSE * num_sample 182 | err_CSV += _err_CSV * num_sample 183 | 184 | count += num_sample 185 | 186 | if verbose: 187 | print(f'RMSE: {err_RMSE / count}, nRMSE: {err_nRMSE / count}, CSV: {err_CSV / count}') 188 | 189 | return err_RMSE / count, err_nRMSE / count, err_CSV / count -------------------------------------------------------------------------------- /models/unet2D.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from diffusers.utils import BaseOutput 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models.modeling_utils import ModelMixin 7 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block 8 | from dataclasses import dataclass 9 | from typing import Optional, Tuple, Union 10 | 11 | @dataclass 12 | class UNet2DOutput(BaseOutput): 13 | 14 | """ 15 | The output of [`UNet2DModel`]. 16 | 17 | Args: 18 | sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): 19 | The hidden states output from the last layer of the model. 20 | """ 21 | 22 | sample: torch.Tensor 23 | 24 | class diffuserUNet2D(ModelMixin, ConfigMixin): 25 | r""" 26 | A 2D UNet model with building blocks from diffusers. 27 | 28 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 29 | for all models (such as downloading or saving). 30 | """ 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | sample_size: Optional[Union[int, Tuple[int, int]]] = None, 36 | in_channels: int = 3, 37 | out_channels: int = 3, 38 | center_input_sample: bool = False, 39 | time_embedding_type: str = "positional", 40 | freq_shift: int = 0, 41 | flip_sin_to_cos: bool = True, 42 | down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), 43 | up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), 44 | block_out_channels: Tuple[int, ...] 
= (224, 448, 672, 896), 45 | layers_per_block: int = 2, 46 | mid_block_scale_factor: float = 1, 47 | downsample_padding: int = 1, 48 | downsample_type: str = "conv", 49 | upsample_type: str = "conv", 50 | dropout: float = 0.0, 51 | act_fn: str = "silu", 52 | attention_head_dim: Optional[int] = 8, 53 | norm_num_groups: int = 32, 54 | attn_norm_num_groups: Optional[int] = None, 55 | norm_eps: float = 1e-5, 56 | resnet_time_scale_shift: str = "default", 57 | add_attention: bool = True, 58 | class_embed_type: Optional[str] = None, 59 | num_class_embeds: Optional[int] = None, 60 | num_train_timesteps: Optional[int] = None, 61 | ): 62 | super().__init__() 63 | 64 | self.sample_size = sample_size 65 | time_embed_dim = block_out_channels[0] * 4 if num_class_embeds is not None else None 66 | 67 | # Check inputs 68 | if len(down_block_types) != len(up_block_types): 69 | raise ValueError( 70 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 71 | ) 72 | 73 | if len(block_out_channels) != len(down_block_types): 74 | raise ValueError( 75 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 76 | ) 77 | 78 | # input 79 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 80 | 81 | # class embedding 82 | if class_embed_type is None and num_class_embeds is not None: 83 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 84 | elif class_embed_type == "identity": 85 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 86 | else: 87 | self.class_embedding = None 88 | 89 | self.down_blocks = nn.ModuleList([]) 90 | self.mid_block = None 91 | self.up_blocks = nn.ModuleList([]) 92 | 93 | # down 94 | output_channel = block_out_channels[0] 95 | for i, down_block_type in enumerate(down_block_types): 96 | input_channel = output_channel 97 | output_channel = block_out_channels[i] 98 | is_final_block = i == len(block_out_channels) - 1 99 | 100 | down_block = get_down_block( 101 | down_block_type, 102 | num_layers=layers_per_block, 103 | in_channels=input_channel, 104 | out_channels=output_channel, 105 | temb_channels=time_embed_dim, 106 | add_downsample=not is_final_block, 107 | resnet_eps=norm_eps, 108 | resnet_act_fn=act_fn, 109 | resnet_groups=norm_num_groups, 110 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 111 | downsample_padding=downsample_padding, 112 | resnet_time_scale_shift=resnet_time_scale_shift, 113 | downsample_type=downsample_type, 114 | dropout=dropout, 115 | ) 116 | self.down_blocks.append(down_block) 117 | 118 | # mid 119 | self.mid_block = UNetMidBlock2D( 120 | in_channels=block_out_channels[-1], 121 | temb_channels=time_embed_dim, 122 | dropout=dropout, 123 | resnet_eps=norm_eps, 124 | resnet_act_fn=act_fn, 125 | output_scale_factor=mid_block_scale_factor, 126 | resnet_time_scale_shift=resnet_time_scale_shift, 127 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], 128 | resnet_groups=norm_num_groups, 129 | attn_groups=attn_norm_num_groups, 130 | add_attention=add_attention, 131 | ) 132 | 133 | # up 134 | reversed_block_out_channels = list(reversed(block_out_channels)) 135 | output_channel = reversed_block_out_channels[0] 136 | for i, up_block_type in enumerate(up_block_types): 137 | prev_output_channel = 
output_channel 138 | output_channel = reversed_block_out_channels[i] 139 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 140 | 141 | is_final_block = i == len(block_out_channels) - 1 142 | 143 | up_block = get_up_block( 144 | up_block_type, 145 | num_layers=layers_per_block + 1, 146 | in_channels=input_channel, 147 | out_channels=output_channel, 148 | prev_output_channel=prev_output_channel, 149 | temb_channels=time_embed_dim, 150 | add_upsample=not is_final_block, 151 | resnet_eps=norm_eps, 152 | resnet_act_fn=act_fn, 153 | resnet_groups=norm_num_groups, 154 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 155 | resnet_time_scale_shift=resnet_time_scale_shift, 156 | upsample_type=upsample_type, 157 | dropout=dropout, 158 | ) 159 | self.up_blocks.append(up_block) 160 | prev_output_channel = output_channel 161 | 162 | # out 163 | num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) 164 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) 165 | self.conv_act = nn.SiLU() 166 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 167 | 168 | def forward( 169 | self, 170 | sample: torch.Tensor, 171 | class_labels: Optional[torch.Tensor] = None, 172 | return_dict: bool = True, 173 | ) -> Union[UNet2DOutput, Tuple]: 174 | r""" 175 | The [`UNet2DModel`] forward method. 176 | 177 | Args: 178 | sample (`torch.Tensor`): 179 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 180 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. 181 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 182 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 183 | return_dict (`bool`, *optional*, defaults to `True`): 184 | Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. 185 | 186 | Returns: 187 | [`~models.unet_2d.UNet2DOutput`] or `tuple`: 188 | If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is 189 | returned where the first element is the sample tensor. 190 | """ 191 | # 0. center input if necessary 192 | if self.config.center_input_sample: 193 | sample = 2 * sample - 1.0 194 | 195 | emb = None 196 | 197 | if self.class_embedding is not None: 198 | if class_labels is None: 199 | raise ValueError("class_labels should be provided when doing class conditioning") 200 | 201 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 202 | emb = class_emb 203 | elif self.class_embedding is None and class_labels is not None: 204 | raise ValueError("class_embedding needs to be initialized in order to use class conditioning") 205 | 206 | # 2. pre-process 207 | skip_sample = sample 208 | sample = self.conv_in(sample) 209 | 210 | # 3. down 211 | down_block_res_samples = (sample,) 212 | for downsample_block in self.down_blocks: 213 | if hasattr(downsample_block, "skip_conv"): 214 | sample, res_samples, skip_sample = downsample_block( 215 | hidden_states=sample, temb=emb, skip_sample=skip_sample 216 | ) 217 | else: 218 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 219 | 220 | down_block_res_samples += res_samples 221 | 222 | # 4. mid 223 | sample = self.mid_block(sample, emb) 224 | 225 | # 5. 
up 226 | skip_sample = None 227 | for upsample_block in self.up_blocks: 228 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 229 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 230 | 231 | if hasattr(upsample_block, "skip_conv"): 232 | sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) 233 | else: 234 | sample = upsample_block(sample, res_samples, emb) 235 | 236 | # 6. post-process 237 | sample = self.conv_norm_out(sample) 238 | sample = self.conv_act(sample) 239 | sample = self.conv_out(sample) 240 | 241 | if skip_sample is not None: 242 | sample += skip_sample 243 | 244 | if not return_dict: 245 | return (sample,) 246 | 247 | return UNet2DOutput(sample=sample) -------------------------------------------------------------------------------- /noise_schedulers/noise_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Karras_sigmas_lognormal: 4 | def __init__(self, sigmas, P_mean=-1.2, P_std=1.2, sigma_min=0.002, sigma_max=80, num_train_timesteps=1000): 5 | self.P_mean = P_mean 6 | self.P_std = P_std 7 | self.sigma_min = sigma_min 8 | self.sigma_max = sigma_max 9 | self.num_train_timesteps = num_train_timesteps 10 | self.sigmas = sigmas 11 | 12 | def __call__(self, batch_size, generator=None, device='cpu'): 13 | rnd_normal = torch.randn([batch_size, 1, 1, 1], device=device, generator=generator) 14 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 15 | # Find the indices of the closest matches 16 | # Expand self.sigmas to match the batch size 17 | # sigmas get concatenated with 0 at the end 18 | sigmas_expanded = self.sigmas[:-1].view(1, -1).to(device) 19 | sigma_expanded = sigma.view(batch_size, 1) 20 | 21 | # Calculate the difference and find the indices of the minimum difference 22 | diff = torch.abs(sigmas_expanded - sigma_expanded) 23 | indices = torch.argmin(diff, dim=1) 24 | 25 | return indices -------------------------------------------------------------------------------- /pipelines/pipeline_inv_prob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 3 | from diffusers.utils.torch_utils import randn_tensor 4 | from typing import Optional, Union, Tuple, List 5 | 6 | from utils.pipeline_utils import Fields2DPipelineOutput 7 | 8 | class InverseProblem2DPipeline(DiffusionPipeline): 9 | 10 | model_cpu_offload_seq = "unet" 11 | 12 | def __init__(self, unet, scheduler, scheduler_step_kwargs: Optional[dict] = None): 13 | super().__init__() 14 | 15 | self.register_modules(unet=unet, scheduler=scheduler) 16 | self.scheduler_step_kwargs = scheduler_step_kwargs or {} 17 | 18 | @torch.no_grad() 19 | def __call__( 20 | self, 21 | batch_size: int = 1, 22 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 23 | num_inference_steps: int = 50, 24 | return_dict: bool = True, 25 | mask: torch.Tensor = None, 26 | known_channels: List[int] = None, 27 | known_latents: torch.Tensor = None, 28 | ) -> Union[Fields2DPipelineOutput, Tuple]: 29 | 30 | # Sample gaussian noise to begin loop 31 | if isinstance(self.unet.config.sample_size, int): 32 | image_shape = ( 33 | batch_size, 34 | self.unet.config.out_channels, 35 | self.unet.config.sample_size, 36 | self.unet.config.sample_size, 37 | ) 38 | else: 39 | image_shape = (batch_size, self.unet.config.out_channels, *self.unet.config.sample_size) 
40 | 41 | if isinstance(generator, list) and len(generator) != batch_size: 42 | raise ValueError( 43 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 44 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 45 | ) 46 | 47 | assert known_latents is not None, "known_latents must be provided" 48 | 49 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 50 | noise = image.clone() 51 | 52 | # set step values 53 | self.scheduler.set_timesteps(num_inference_steps) 54 | 55 | for i, t in enumerate(self.scheduler.timesteps): 56 | x_in = self.scheduler.scale_model_input(image, t) 57 | # 1. predict noise model_output 58 | model_output = self.unet(x_in, t, return_dict=False)[0] 59 | model_output = model_output * (1 - mask) + known_latents * mask 60 | 61 | # 2. do x_t -> x_t-1 62 | image = self.scheduler.step( 63 | model_output, t, image[:, :self.unet.config.out_channels], **self.scheduler_step_kwargs, 64 | return_dict=False 65 | )[0] 66 | 67 | tmp_known_latents = known_latents.clone() 68 | if i < len(self.scheduler.timesteps) - 1: 69 | noise_timestep = self.scheduler.timesteps[i + 1] 70 | tmp_known_latents = self.scheduler.add_noise(tmp_known_latents, noise, torch.tensor([noise_timestep])) 71 | 72 | image = image * (1 - mask) + tmp_known_latents * mask 73 | 74 | if not return_dict: 75 | return (image,) 76 | 77 | return Fields2DPipelineOutput(fields=image) 78 | 79 | class InverseProblem2DCondPipeline(DiffusionPipeline): 80 | 81 | model_cpu_offload_seq = "unet" 82 | 83 | def __init__(self, unet, scheduler, scheduler_step_kwargs: Optional[dict] = None): 84 | super().__init__() 85 | 86 | self.register_modules(unet=unet, scheduler=scheduler) 87 | self.scheduler_step_kwargs = scheduler_step_kwargs or {} 88 | 89 | @torch.no_grad() 90 | def __call__( 91 | self, 92 | batch_size: int = 1, 93 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 94 | num_inference_steps: int = 50, 95 | return_dict: bool = True, 96 | mask: torch.Tensor = None, 97 | known_channels: List[int] = None, 98 | known_latents: torch.Tensor = None, 99 | add_noise_to_obs: bool = False, 100 | ) -> Union[Fields2DPipelineOutput, Tuple]: 101 | 102 | # Sample gaussian noise to begin loop 103 | if isinstance(self.unet.config.sample_size, int): 104 | image_shape = ( 105 | batch_size, 106 | self.unet.config.out_channels, 107 | self.unet.config.sample_size, 108 | self.unet.config.sample_size, 109 | ) 110 | else: 111 | image_shape = (batch_size, self.unet.config.out_channels, *self.unet.config.sample_size) 112 | 113 | if isinstance(generator, list) and len(generator) != batch_size: 114 | raise ValueError( 115 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 116 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
117 | ) 118 | 119 | assert known_latents is not None, "known_latents must be provided" 120 | 121 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 122 | if add_noise_to_obs: 123 | noise = image.clone() 124 | conditioning_tensors = torch.cat((known_latents, mask[:, [known_channels[0]]]), dim=1) 125 | 126 | # set step values 127 | self.scheduler.set_timesteps(num_inference_steps) 128 | 129 | for i, t in enumerate(self.scheduler.timesteps): 130 | if mask is not None and not add_noise_to_obs: 131 | image = image * (1 - mask) + known_latents * mask 132 | x_in = self.scheduler.scale_model_input(image, t) 133 | # 1. predict noise model_output 134 | model_output = self.unet(x_in, t, conditioning_tensors=conditioning_tensors, return_dict=False)[0] 135 | 136 | # 2. do x_t -> x_t-1 137 | image = self.scheduler.step( 138 | model_output, t, image, **self.scheduler_step_kwargs, 139 | return_dict=False 140 | )[0] 141 | if add_noise_to_obs: 142 | tmp_known_latents = known_latents.clone() 143 | if i < len(self.scheduler.timesteps) - 1: 144 | noise_timestep = self.scheduler.timesteps[i + 1] 145 | tmp_known_latents = self.scheduler.add_noise(tmp_known_latents, noise, torch.tensor([noise_timestep])) 146 | 147 | if mask is not None: 148 | image = image * (1 - mask) + known_latents * mask 149 | 150 | if not return_dict: 151 | return (image,) 152 | 153 | return Fields2DPipelineOutput(fields=image) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Diffusion models for field reconstruction 2 | --- 3 | 4 | Demo: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RzcvX7jHDVc1VTkyUAe8bRA3C93xEffd?usp=sharing); [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/tonyzyl/DiffusionReconstruct) 5 | 6 | This repository contains the code for the paper "Spatially-Aware Diffusion Models with Cross-Attention for Global Field Reconstruction with Sparse Observations". 7 | The paper is available on [arXiv](https://doi.org/10.48550/arXiv.2409.00230) and has been accepted by CMAME. 8 | 9 | ### Summary of the work 10 | 11 | * We propose a cross-attention diffusion model for global field reconstruction. 12 | * We compare different conditioning methods against a strong deterministic baseline and conduct a comprehensive benchmark of the effect of hyperparameters on reconstruction quality. 13 | * The cross-attention diffusion model outperforms the others in noisy settings, while the deterministic model is easier to train and excels in noiseless conditions. 14 | * The ensemble mean of the diffusion model's probabilistic reconstructions converges to the deterministic output. 15 | * Latent diffusion might be beneficial for handling unevenly distributed fields and for reducing the overall training cost. 16 | * The tested PDEs include Darcy flow, diffusion-reaction, the shallow water equations, and the compressible Navier-Stokes equations. 17 | 18 | 19 | Tested conditioning methods include: 20 | * Guided Sampling (or inpainting) 21 | * Classifier-Free Guided Sampling (CFG) 22 | * Cross-Attention with the proposed encoding block 23 | 24 |
25 | 26 |

Figure 1: Results of the diffusion models are computed from an ensemble of 25 trajectories over 20 steps. The red dashed line represents the baseline for reconstruction using the mean.

27 |
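The diffusion numbers above are ensemble statistics. Below is a minimal sketch of how such an ensemble could be drawn with the repository's `InverseProblem2DPipeline` (from `pipelines/pipeline_inv_prob.py`); the `unet`, `scheduler`, `mask`, and `known_latents` objects are assumed to exist already and their names are placeholders, and the paper's timings additionally use a predictor-corrector sampler.

```python
import torch
from pipelines.pipeline_inv_prob import InverseProblem2DPipeline

# Assumed to already exist: a trained `unet`, a compatible diffusers `scheduler`,
# and observation tensors `mask` / `known_latents` of shape (1, C, H, W).
pipe = InverseProblem2DPipeline(unet=unet, scheduler=scheduler)

n_ensemble = 25                      # ensemble size used in Table 1
samples = []
for seed in range(n_ensemble):
    generator = torch.Generator(device="cpu").manual_seed(seed)
    out = pipe(
        batch_size=1,
        generator=generator,
        num_inference_steps=20,      # 20 denoising steps, as in Table 1
        mask=mask,
        known_latents=known_latents,
    )
    samples.append(out.fields)

ensemble = torch.cat(samples, dim=0)   # (n_ensemble, C, H, W)
mean_field = ensemble.mean(dim=0)      # ensemble mean compared against the deterministic baseline
std_field = ensemble.std(dim=0)        # pixel-wise spread of the reconstruction
```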
28 | 29 | **Table 1: Comparison of nRMSE and computation cost per sample for the Darcy flow. The computation cost of the diffusion models is computed from an ensemble of 25 trajectories with predictor-corrector sampling and 20 steps.** 30 | | | **Guided Sampling** | **CFG** | **Cross-Attention** | **VT-UNet** | **Numerical** | 31 | |--------------------------------|---------------------|---------|---------------------|-------------|---------------| 32 | | **nRMSE (0.3%)** | 0.476 | 0.291 | 0.178 | 0.176 | 0.202 | 33 | | **nRMSE (1.37%)** | 0.474 | 0.261 | 0.129 | 0.092 | 0.180 | 34 | | **Computation cost (s)** | 0.944 | 0.931 | 1.769 | 0.00206 | 62 | 35 | 36 | --- 37 | ### Proposed measurement encoding block 38 | 39 | 40 |
41 | 42 |

Figure 2: Schematic of the proposed condition encoding block. For CFG, mean-pooling is performed to reduce the dimensionality so that it can be combined with the noise scale embedding.

43 |
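At the core of the block shown in Figure 2, the patch embeddings of the VT-interpolated fields are modulated by a scale and shift predicted from the binary sensing-position array, following FiLM (see `FieldsEmbeddings` in `utils/attn_utils.py`). A stripped-down sketch of this fusion step, with hypothetical tensor names and shapes:

```python
import torch

# Hypothetical patch-token embeddings of the VT-interpolated fields and of the
# binary sensing-position array (the latter is embedded with twice the width).
field_embeds = torch.randn(4, 16, 768)        # (batch, num_patches, embed_dim)
sensing_embeds = torch.randn(4, 16, 2 * 768)  # (batch, num_patches, 2 * embed_dim)

# FiLM-style fusion: split the sensing embedding into a feature-wise scale and shift.
scale, shift = torch.chunk(sensing_embeds, 2, dim=-1)
fused = field_embeds * (1 + scale) + shift    # (batch, num_patches, embed_dim)

# `fused`, plus a 2D sin-cos positional embedding, is what the self-attention
# encoder consumes to build the conditioning representation.
```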
44 | 45 | The Voronoi-tessellated (VT) fields serve as an inductive bias, and the proposed condition encoding block integrates the interpolated fields with the sensing positions. The two sources of information are fused using a learnable modulation following Feature-wise Linear Modulation (FiLM), as sketched above. The refined representation then leverages a self-attention mechanism to extract the conditioning representation and to establish a mapping between the observed and unobserved regions. A stand-alone sketch of the VT interpolation itself is given below, after Figure 3. 46 | 47 | --- 48 | ### Ability to capture possible realizations 49 | 50 |
51 | Bar Chart 52 |

Figure 3: Comparison of the fields generated by VT-UNet, a single trajectory, and the ensemble mean of the cross-attention diffusion model for the diffusion-reaction equations with 1% observed data points.

53 |
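The VT fields compared in Figure 3 (and used as the inductive bias of the conditional models) assign every grid point the value of its nearest observed sensor, i.e. a Voronoi tessellation of the observations; `vt_obs` in `utils/vt_utils.py` provides this in the training scripts. A rough stand-alone equivalent using SciPy, with hypothetical array names, might look like:

```python
import numpy as np
from scipy.spatial import cKDTree

def voronoi_fill(field: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Assign every pixel of a (H, W) field the value of its nearest observed pixel."""
    obs_rows, obs_cols = np.nonzero(mask)                 # observed sensor locations
    tree = cKDTree(np.stack([obs_rows, obs_cols], axis=1))

    grid_rows, grid_cols = np.indices(field.shape)
    grid_points = np.stack([grid_rows.ravel(), grid_cols.ravel()], axis=1)
    _, nearest = tree.query(grid_points)                  # index of the nearest observation

    filled = field[obs_rows[nearest], obs_cols[nearest]]
    return filled.reshape(field.shape)

# Example: roughly 1% randomly observed points on a 128x128 field.
rng = np.random.default_rng(0)
field = rng.standard_normal((128, 128))
mask = (rng.random((128, 128)) < 0.01).astype(int)
vt_field = voronoi_fill(field, mask)
```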
54 | 55 | The diffusion-reaction equation describes the evolution of activator and inhibitor concentration profiles from a Gaussian initialization, governed by a cubic reaction relation. 56 | 57 | The diffusion model, equipped with the proposed encoding block, captures possible realizations, with the ensemble mean converging to the deterministic output. 58 | 59 | --- 60 | ### Ensemble sampling for enhancing Data Assimilation (DA) 61 | 62 |
63 | 64 |

Figure 4: Histogram of the relative error improvement with different DA error covariances on the shallow water equations.

65 |
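Figure 4 compares DA corrections that use an identity background-error covariance against ones that use a covariance estimated from the diffusion ensemble. The ensemble covariance of a chosen channel can be computed as in `calculate_covariance` from `utils/general_utils.py`; a minimal sketch with a hypothetical ensemble tensor:

```python
import torch

# Hypothetical ensemble of diffusion reconstructions: (n_ensemble, C, H, W).
ensemble = torch.randn(25, 2, 64, 64)
channel = 0

flat = ensemble[:, channel].reshape(ensemble.shape[0], -1)   # (n_ensemble, H*W)
centered = flat - flat.mean(dim=0, keepdim=True)
cov = centered.t() @ centered / (flat.shape[0] - 1)          # (H*W, H*W) sample covariance

# `cov` can then replace the identity matrix as the background-error covariance
# in the DA update applied to the diffusion model output.
```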
66 | 67 | Overall, both data assimilation methods, whether employing the diffusion ensemble covariance matrix or the identity covariance matrix, enhance the average accuracy of field reconstruction. However, the diffusion ensemble covariance matrix outperforms the identity covariance in the majority of corrections applied to the diffusion model output. 68 | 69 | --- 70 | ### Dataset 71 | **Table 2: Summary of the datasets used in the experiments.** 72 | | **PDE** | **$N_d$ (spatial resolution)** | **$N_t$ (time steps)** | **Boundary Condition** | **Number of Simulations** | **Data Source** | 73 | |--------------------------------|-----------------|-----------|------------------------|---------------------------|-----------------------------------| 74 | | Darcy flow | $128 \times 128$| N/A | Dirichlet | 10,000 | [Huang et al., 2022](https://doi.org/10.1016/j.jcp.2022.111262) | 75 | | Shallow water | $64 \times 64$ | 50 | Periodic | 250 | [Cheng et al., 2024](https://doi.org/10.1016/j.jcp.2023.112581) | 76 | | 2D Diffusion reaction | $128 \times 128$| 101 | Neumann | 1,000 | [Takamoto et al., 2022](https://arxiv.org/abs/2210.07182) | 77 | | 2D Compressible Navier-Stokes | $128 \times 128$| 21 | Periodic | 10,000 | [Takamoto et al., 2022](https://arxiv.org/abs/2210.07182) | 78 | 79 | --- 80 | ### Installation 81 | 82 | We recommend using a virtual environment (Python 3.11) to install the required packages. Using the provided `requirements.txt` file, you can install the dependencies with the following command: 83 | 84 | ```bash 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | --- 89 | ### Training 90 | 91 | To train the model, once you have configured the corresponding config file, run the following command (distributed example): 92 | 93 | ```bash 94 | srun accelerate launch --config_file \ 95 | --multi_gpu \ 96 | --rdzv_backend c10d \ 97 | --machine_rank $SLURM_NODEID \ 98 | --num_machines $SLURM_NNODES \ 99 | --main_process_ip $MASTER_ADDR \ 100 | --main_process_port $MASTER_PORT \ 101 | train_cond.py --config configs/darcy_xattn_config.yaml \ 102 | --conditioning_type=xattn \ 103 | --checkpoints_total_limit=5 \ 104 | ``` 105 | 106 | #### Some notes on training: 107 | * When loading from a checkpoint, the EMA parameters are not properly loaded; this is likely a bug in diffusers (v0.29.2). For details on how to fix it, please refer to the discussion in the code comments. 108 | * If the dataset is not loaded into memory, the implemented xarray data loading is not optimized for fetching data from disk. 109 | * The training script follows the diffusers example, which only saves the optimization state of the main process; for proper resumption from a checkpoint, one should save/load the states of all processes. 110 | 111 | ### Changelog 112 | * 2024-12-04: Corrected an error that caused the noise percentage to be overestimated; the conclusion remains the same. Paper: added a comparison between training from scratch and fine-tuning to adapt to different resolutions, and added a comparison for colored noise. A version-free requirements file (`requirements_noversion.txt`) is provided for ease of installation.
113 | 114 | #### Citation 115 | If you find this work useful, please consider citing the following paper: 116 | 117 | ```bibtex 118 | @misc{zhuang2024spatiallyawarediffusionmodelscrossattention, 119 | title={Spatially-Aware Diffusion Models with Cross-Attention for Global Field Reconstruction with Sparse Observations}, 120 | author={Yilin Zhuang and Sibo Cheng and Karthik Duraisamy}, 121 | year={2024}, 122 | eprint={2409.00230}, 123 | archivePrefix={arXiv}, 124 | primaryClass={cs.LG}, 125 | url={https://arxiv.org/abs/2409.00230}, 126 | } 127 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.32.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | anyio==4.4.0 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==2.4.1 11 | async-lru==2.0.4 12 | attrs==23.2.0 13 | Babel==2.15.0 14 | beautifulsoup4==4.12.3 15 | bleach==6.1.0 16 | certifi==2024.7.4 17 | cffi==1.16.0 18 | charset-normalizer==3.3.2 19 | click==8.1.7 20 | cloudpickle==3.0.0 21 | comm==0.2.2 22 | contourpy==1.2.1 23 | cycler==0.12.1 24 | dask==2024.7.0 25 | datasets==2.20.0 26 | debugpy==1.8.2 27 | decorator==5.1.1 28 | defusedxml==0.7.1 29 | diffusers==0.29.2 30 | dill==0.3.8 31 | einops==0.8.0 32 | executing==2.0.1 33 | fastjsonschema==2.20.0 34 | filelock==3.15.4 35 | fonttools==4.53.1 36 | fqdn==1.5.1 37 | frozenlist==1.4.1 38 | fsspec==2024.5.0 39 | grpcio==1.64.1 40 | h11==0.14.0 41 | h5netcdf==1.3.0 42 | h5py==3.11.0 43 | httpcore==1.0.5 44 | httpx==0.27.0 45 | huggingface-hub==0.23.4 46 | idna==3.7 47 | importlib_metadata==8.0.0 48 | ipykernel==6.29.5 49 | ipython==8.26.0 50 | isoduration==20.11.0 51 | jedi==0.19.1 52 | Jinja2==3.1.4 53 | json5==0.9.25 54 | jsonpointer==3.0.0 55 | jsonschema==4.23.0 56 | jsonschema-specifications==2023.12.1 57 | jupyter-events==0.10.0 58 | jupyter-lsp==2.2.5 59 | jupyter_client==8.6.2 60 | jupyter_core==5.7.2 61 | jupyter_server==2.14.1 62 | jupyter_server_terminals==0.5.3 63 | jupyterlab==4.2.3 64 | jupyterlab_pygments==0.3.0 65 | jupyterlab_server==2.27.2 66 | kiwisolver==1.4.5 67 | locket==1.0.0 68 | Markdown==3.6 69 | MarkupSafe==2.1.5 70 | matplotlib==3.9.1 71 | matplotlib-inline==0.1.7 72 | mistune==3.0.2 73 | mpmath==1.3.0 74 | multidict==6.0.5 75 | multiprocess==0.70.16 76 | nbclient==0.10.0 77 | nbconvert==7.16.4 78 | nbformat==5.10.4 79 | nest-asyncio==1.6.0 80 | networkx==3.3 81 | notebook==7.2.1 82 | notebook_shim==0.2.4 83 | numpy==1.26.4 84 | nvidia-cublas-cu12==12.1.3.1 85 | nvidia-cuda-cupti-cu12==12.1.105 86 | nvidia-cuda-nvrtc-cu12==12.1.105 87 | nvidia-cuda-runtime-cu12==12.1.105 88 | nvidia-cudnn-cu12==8.9.2.26 89 | nvidia-cufft-cu12==11.0.2.54 90 | nvidia-curand-cu12==10.3.2.106 91 | nvidia-cusolver-cu12==11.4.5.107 92 | nvidia-cusparse-cu12==12.1.0.106 93 | nvidia-nccl-cu12==2.20.5 94 | nvidia-nvjitlink-cu12==12.5.82 95 | nvidia-nvtx-cu12==12.1.105 96 | omegaconf==2.3.0 97 | overrides==7.7.0 98 | packaging==24.1 99 | pandas==2.2.2 100 | pandocfilters==1.5.1 101 | parso==0.8.4 102 | partd==1.4.2 103 | pexpect==4.9.0 104 | pillow==10.4.0 105 | platformdirs==4.2.2 106 | prometheus_client==0.20.0 107 | prompt_toolkit==3.0.47 108 | protobuf==4.25.3 109 | psutil==6.0.0 110 | ptyprocess==0.7.0 111 | pure-eval==0.2.2 112 | pyarrow==16.1.0 113 | pyarrow-hotfix==0.6 114 | pycparser==2.22 115 | Pygments==2.18.0 116 | pyparsing==3.1.2 117 | 
python-dateutil==2.9.0.post0 118 | python-json-logger==2.0.7 119 | pytz==2024.1 120 | PyYAML==6.0.1 121 | pyzmq==26.0.3 122 | referencing==0.35.1 123 | regex==2024.5.15 124 | requests==2.32.3 125 | rfc3339-validator==0.1.4 126 | rfc3986-validator==0.1.1 127 | rpds-py==0.19.0 128 | safetensors==0.4.3 129 | scipy==1.14.0 130 | Send2Trash==1.8.3 131 | six==1.16.0 132 | sniffio==1.3.1 133 | soupsieve==2.5 134 | stack-data==0.6.3 135 | sympy==1.13.0 136 | tensorboard==2.17.0 137 | tensorboard-data-server==0.7.2 138 | terminado==0.18.1 139 | tinycss2==1.3.0 140 | tokenizers==0.19.1 141 | toolz==0.12.1 142 | torch==2.3.1 143 | torchvision==0.18.1 144 | tornado==6.4.1 145 | tqdm==4.66.4 146 | traitlets==5.14.3 147 | transformers==4.42.4 148 | triton==2.3.1 149 | types-python-dateutil==2.9.0.20240316 150 | typing_extensions==4.12.2 151 | tzdata==2024.1 152 | uri-template==1.3.0 153 | urllib3==2.2.2 154 | wcwidth==0.2.13 155 | webcolors==24.6.0 156 | webencodings==0.5.1 157 | websocket-client==1.8.0 158 | Werkzeug==3.0.3 159 | xarray==2023.7.0 160 | xxhash==3.4.1 161 | yarl==1.9.4 162 | zipp==3.19.2 163 | -------------------------------------------------------------------------------- /requirements_noversion.txt: -------------------------------------------------------------------------------- 1 | diffusers[torch] 2 | transformers 3 | dask 4 | xarray 5 | tensorboard 6 | einops 7 | h5netcdf 8 | h5py 9 | scipy 10 | pandas 11 | matplotlib 12 | Omegaconf 13 | ipykernel 14 | notebook -------------------------------------------------------------------------------- /train_vt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from datetime import timedelta 5 | import torch 6 | import shutil 7 | from packaging import version 8 | import accelerate 9 | from accelerate import Accelerator 10 | from accelerate.utils import ProjectConfiguration, set_seed, InitProcessGroupKwargs 11 | from accelerate.logging import get_logger 12 | from diffusers.optimization import get_scheduler 13 | from diffusers.training_utils import EMAModel 14 | from diffusers.utils.import_utils import is_xformers_available 15 | from diffusers.utils.torch_utils import is_compiled_module 16 | from torch.optim import AdamW 17 | import torch.nn.functional as F 18 | from torchvision.utils import make_grid 19 | from tqdm.auto import tqdm 20 | from omegaconf import OmegaConf 21 | from pathlib import Path 22 | import math 23 | from huggingface_hub import upload_folder 24 | 25 | from models.unet2D import diffuserUNet2D 26 | from utils.general_utils import instantiate_from_config, flatten_and_filter_config, convert_to_rgb 27 | from utils.inverse_utils import create_scatter_mask 28 | from utils.vt_utils import vt_obs 29 | from dataloader.dataset_class import pdedata2dataloader 30 | from losses.metric import metric_func_2D 31 | 32 | logger = get_logger(__name__, log_level="INFO") 33 | 34 | @torch.no_grad() 35 | def evaluate(phase_name, config, epoch, vt, model, accelerator, known_latents=None): 36 | # Generate some sample images 37 | image_dim = known_latents.shape[-1] 38 | generator = torch.Generator(device='cpu').manual_seed(config.seed) # Use a separate torch generator to avoid rewinding the random state of the main training loop 39 | tmp_latents = known_latents[:config.eval_batch_size] 40 | mask = create_scatter_mask(tmp_latents, channels=config.known_channels, ratio=0.02, generator=generator, device=known_latents.device) 41 | interpolated_fields = 
vt(known_latents[:config.eval_batch_size], mask=mask) 42 | #''' 43 | sample_images = model( 44 | interpolated_fields, 45 | return_dict=False, 46 | )[0] 47 | try: 48 | channel_names = config.channel_names 49 | except: 50 | channel_names = ['' for _ in range(sample_images.shape[1])] 51 | 52 | #pressure = convert_to_rgb(sample_images[:, 0].reshape(-1, 1, 64, 64)) 53 | #permeability = convert_to_rgb(sample_images[:, 1].reshape(-1, 1, 64, 64)) 54 | images_list = [] 55 | GT_list = [] 56 | for i in range(sample_images.shape[1]): 57 | tmp_image = convert_to_rgb(sample_images[:, i].reshape(-1, 1, image_dim, image_dim)) 58 | ground_truth = convert_to_rgb(known_latents[:config.eval_batch_size, i].reshape(-1, 1, image_dim, image_dim)) 59 | images_list.append(make_grid(torch.stack(tmp_image))) 60 | GT_list.append(make_grid(torch.stack(ground_truth))) 61 | 62 | err_RMSE, err_nRMSE, err_CSV = metric_func_2D(sample_images, known_latents[:config.eval_batch_size], mask=mask) 63 | for tracker in accelerator.trackers: 64 | if tracker.name == 'tensorboard': 65 | for i, (img, gt) in enumerate(zip(images_list, GT_list)): 66 | tracker.writer.add_image(phase_name + ' sample ' + channel_names[i], img, epoch) 67 | tracker.writer.add_image(phase_name + ' GT ' + channel_names[i], gt, epoch) 68 | 69 | accelerator.log({"RMSE": err_RMSE, "nRMSE": err_nRMSE, "CSV": err_CSV}, step=epoch) 70 | 71 | if torch.cuda.is_available(): 72 | torch.cuda.empty_cache() 73 | 74 | def parse_args(): 75 | parser = argparse.ArgumentParser(description="Train a Diffusers model.") 76 | parser.add_argument('--config', type=str, required=True, help="Path to the YAML configuration file.") 77 | parser.add_argument( 78 | "--resume_from_checkpoint", 79 | type=str, 80 | default=None, 81 | help=( 82 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 83 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 84 | ), 85 | ) 86 | parser.add_argument( 87 | "--checkpoints_total_limit", 88 | type=int, 89 | default=None, 90 | help=("Max number of checkpoints to store."), 91 | ) 92 | parser.add_argument( 93 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
94 | ) 95 | parser.add_argument( 96 | "--hub_model_id", 97 | type=str, 98 | default=None, 99 | help="The name of the repository to keep in sync with the local `output_dir`.", 100 | ) 101 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 102 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 103 | 104 | return parser.parse_args() 105 | 106 | def main(args): 107 | 108 | config = OmegaConf.load(args.config) 109 | tracker_config = flatten_and_filter_config(OmegaConf.to_container(config, resolve=True)) 110 | 111 | unet_config = OmegaConf.to_container(config.pop("unet", OmegaConf.create()), resolve=True) 112 | noise_scheduler_config = config.pop("noise_scheduler", OmegaConf.create()) 113 | accelerator_config = config.pop("accelerator", OmegaConf.create()) 114 | loss_fn_config = config.pop("loss_fn", OmegaConf.create()) 115 | optimizer_config = config.pop("optimizer", OmegaConf.create()) 116 | lr_scheduler_config = config.pop("lr_scheduler", OmegaConf.create()) 117 | dataloader_config = config.pop("dataloader", OmegaConf.create()) 118 | ema_config = config.pop("ema", OmegaConf.create()) 119 | general_config = config.pop("general", OmegaConf.create()) 120 | 121 | set_seed(general_config.seed) 122 | 123 | if 'resnet_time_scale_shift' in unet_config: 124 | del unet_config['resnet_time_scale_shift'] # No temb 125 | unet = diffuserUNet2D.from_config(config=unet_config) 126 | 127 | logging_dir = Path(general_config.output_dir, general_config.logging_dir) 128 | accelerator_project_config = ProjectConfiguration(project_dir=general_config.output_dir, logging_dir=logging_dir) 129 | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800)) 130 | accelerator = Accelerator( 131 | project_config=accelerator_project_config, 132 | **accelerator_config, 133 | kwargs_handlers=[kwargs] 134 | ) 135 | 136 | # spacing can be arbitrary when only using the vt to interpolate fields 137 | vt = vt_obs(x_dim=unet_config['sample_size'], y_dim=unet_config['sample_size'], 138 | known_channels=general_config.known_channels, x_spacing=8, y_spacing=8, device=accelerator.device) 139 | 140 | # Create EMA for the model. 141 | if ema_config.use_ema: 142 | ema_model = EMAModel( 143 | unet.parameters(), 144 | decay=ema_config.ema_max_decay, 145 | use_ema_warmup=True, 146 | inv_gamma=ema_config.ema_inv_gamma, 147 | power=ema_config.ema_power, 148 | model_cls=diffuserUNet2D, 149 | model_config=unet.config, 150 | foreach = ema_config.foreach, 151 | ) 152 | 153 | # Does not work with torch.compile() 154 | if args.enable_xformers_memory_efficient_attention: 155 | if is_xformers_available(): 156 | import xformers 157 | 158 | xformers_version = version.parse(xformers.__version__) 159 | if xformers_version == version.parse("0.0.16"): 160 | logger.warning( 161 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 162 | ) 163 | unet.enable_xformers_memory_efficient_attention() 164 | else: 165 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 166 | 167 | # `accelerate` 0.16.0 will have better support for customized saving 168 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 169 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 170 | def save_model_hook(models, weights, output_dir): 171 | if accelerator.is_main_process: 172 | if ema_config.use_ema: 173 | ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) 174 | 175 | for i, model in enumerate(models): 176 | model.save_pretrained(os.path.join(output_dir, "unet")) 177 | 178 | # make sure to pop weight so that corresponding model is not saved again 179 | weights.pop() 180 | 181 | def load_model_hook(models, input_dir): 182 | if ema_config.use_ema: 183 | # TODO: follow up on loading checkpoint with EMA 184 | load_model = EMAModel.from_pretrained( 185 | #os.path.join(input_dir, "unet_ema"), UNet2DModel 186 | os.path.join(input_dir, "unet_ema"), diffuserUNet2D, #foreach=ema_config.foreach # not yet released in v0.29.2 187 | ) 188 | ema_model.load_state_dict(load_model.state_dict()) 189 | if ema_config.offload_ema: 190 | ema_model.pin_memory() 191 | else: 192 | ema_model.to(accelerator.device) 193 | del load_model 194 | 195 | for _ in range(len(models)): 196 | # pop models so that they are not loaded again 197 | model = models.pop() 198 | 199 | # load diffusers style into model 200 | load_model = diffuserUNet2D.from_pretrained(input_dir, subfolder="unet") 201 | model.register_to_config(**load_model.config) 202 | 203 | model.load_state_dict(load_model.state_dict()) 204 | del load_model 205 | 206 | accelerator.register_save_state_pre_hook(save_model_hook) 207 | accelerator.register_load_state_pre_hook(load_model_hook) 208 | 209 | loss_fn = instantiate_from_config(loss_fn_config) 210 | 211 | generator = torch.Generator(device='cpu').manual_seed(general_config.seed) 212 | with accelerator.main_process_first(): 213 | # https://github.com/huggingface/accelerate/issues/503 214 | # https://discuss.huggingface.co/t/shared-memory-in-accelerate/28619 215 | train_dataloader, val_dataloader, test_dataloader = pdedata2dataloader(**dataloader_config, generator=generator) 216 | 217 | 218 | # Make one log on every process with the configuration for debugging. 219 | logging.basicConfig( 220 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 221 | datefmt="%m/%d/%Y %H:%M:%S", 222 | level=logging.INFO, 223 | ) 224 | logger.info(accelerator.state, main_process_only=False) 225 | 226 | if general_config.scale_lr: 227 | optimizer_config.lr = ( 228 | optimizer_config.lr 229 | * accelerator.num_processes 230 | * accelerator.gradient_accumulation_steps 231 | * dataloader_config.batch_size 232 | ) 233 | 234 | optimizer = AdamW(unet.parameters(), **optimizer_config) 235 | 236 | # Scheduler and math around the number of training steps. 237 | overrode_max_train_steps = False 238 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) 239 | if "num_training_steps" not in general_config: 240 | general_config.num_training_steps = num_update_steps_per_epoch * general_config.num_epochs 241 | logger.info(f"num_training_steps not found in lr_scheduler_config. 
Setting num_training_steps to product of num_epochs and training dataloader length: {general_config.num_training_steps}") 242 | overrode_max_train_steps = True 243 | 244 | lr_scheduler = get_scheduler(lr_scheduler_config.name, optimizer, 245 | num_warmup_steps = lr_scheduler_config.num_warmup_steps * accelerator.num_processes, 246 | num_training_steps = general_config.num_training_steps * accelerator.num_processes, 247 | num_cycles = lr_scheduler_config.num_cycles, 248 | power = lr_scheduler_config.power) 249 | 250 | print('start preparing dataloader') 251 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 252 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 253 | ) 254 | print('finished preparing dataloader') 255 | 256 | if ema_config.use_ema: 257 | if ema_config.offload_ema: 258 | ema_model.pin_memory() 259 | else: 260 | ema_model.to(accelerator.device) 261 | 262 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 263 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) 264 | if overrode_max_train_steps: 265 | general_config.num_training_steps = general_config.num_epochs * num_update_steps_per_epoch 266 | # Afterwards we recalculate our number of training epochs 267 | general_config.num_epochs = math.ceil(general_config.num_training_steps / num_update_steps_per_epoch) 268 | 269 | # We need to initialize the trackers we use, and also store our configuration. 270 | # The trackers initializes automatically on the main process. 271 | if accelerator.is_main_process: 272 | print(tracker_config) 273 | accelerator.init_trackers(general_config.tracker_project_name, config=tracker_config) 274 | 275 | # Function for unwrapping if model was compiled with `torch.compile`. 276 | def unwrap_model(model): 277 | # https://github.com/huggingface/diffusers/issues/6503 278 | model = accelerator.unwrap_model(model) 279 | model = model._orig_mod if is_compiled_module(model) else model 280 | return model 281 | 282 | total_batch_size = dataloader_config.batch_size * accelerator.num_processes * accelerator.gradient_accumulation_steps 283 | 284 | logger.info("***** Running training *****") 285 | #logger.info(f" Num examples = {len(train_dataset)}") 286 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 287 | logger.info(f" Num Epochs = {general_config.num_epochs}") 288 | logger.info(f" Instantaneous batch size per device = {dataloader_config.batch_size}") 289 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 290 | logger.info(f" Gradient Accumulation steps = {accelerator.gradient_accumulation_steps}") 291 | logger.info(f" Total optimization steps = {general_config.num_training_steps}") 292 | logger.info(f" Total training epochs = {general_config.num_epochs}") 293 | global_step = 0 294 | first_epoch = 0 295 | 296 | # Potentially load in the weights and states from a previous save 297 | if args.resume_from_checkpoint: 298 | if args.resume_from_checkpoint != "latest": 299 | path = os.path.basename(args.resume_from_checkpoint) 300 | else: 301 | # Get the mos recent checkpoint 302 | dirs = os.listdir(general_config.output_dir) 303 | dirs = [d for d in dirs if d.startswith("checkpoint")] 304 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 305 | path = dirs[-1] if len(dirs) > 0 else None 306 | 307 | if path is None: 308 | accelerator.print( 309 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 310 | ) 311 | args.resume_from_checkpoint = None 312 | initial_global_step = 0 313 | else: 314 | accelerator.print(f"Resuming from checkpoint {path}") 315 | if accelerator.is_main_process: # temp fix for only having one random state 316 | accelerator.load_state(os.path.join(general_config.output_dir, path)) 317 | global_step = int(path.split("-")[1]) 318 | 319 | initial_global_step = global_step 320 | first_epoch = global_step // num_update_steps_per_epoch 321 | 322 | else: 323 | initial_global_step = 0 324 | 325 | progress_bar = tqdm( 326 | range(0, general_config.num_training_steps), 327 | initial=initial_global_step, 328 | desc="Steps", 329 | # Only show the progress bar once on each machine. 330 | disable=not accelerator.is_local_main_process, 331 | ) 332 | 333 | # Now you train the model 334 | for epoch in range(first_epoch, general_config.num_epochs): 335 | unet.train() 336 | train_loss = 0.0 337 | for step, batch in enumerate(train_dataloader): 338 | with accelerator.accumulate(unet): 339 | clean_images = batch 340 | #interpolated_fields = vt(clean_images) 341 | tmp_ratio = torch.rand(clean_images.shape[0], device=clean_images.device)*0.1 342 | tmp_ratio = torch.where(tmp_ratio<=0.001, 0.001, tmp_ratio) # to avoid no points get sampled 343 | mask = create_scatter_mask(clean_images, channels=general_config.known_channels, 344 | ratio=tmp_ratio) 345 | interpolated_fields = vt(clean_images, mask=mask) 346 | denoised_fields = unet(interpolated_fields, return_dict=False)[0] 347 | loss = ((clean_images.float() - denoised_fields.float()) ** 2).mean() 348 | train_loss += loss.item() / accelerator.gradient_accumulation_steps 349 | 350 | accelerator.backward(loss) 351 | if accelerator.sync_gradients: 352 | accelerator.clip_grad_norm_(unet.parameters(), 1.0) 353 | optimizer.step() 354 | lr_scheduler.step() 355 | optimizer.zero_grad() 356 | 357 | # Checks if the accelerator has performed an optimization step behind the scenes 358 | if accelerator.sync_gradients: 359 | if ema_config.use_ema: 360 | if ema_config.offload_ema: 361 | ema_model.to(device="cuda", non_blocking=True) 362 | ema_model.step(unet.parameters()) 363 | if ema_config.offload_ema: 364 | ema_model.to(device="cpu", non_blocking=True) 365 | progress_bar.update(1) 366 | logs = {"train loss": train_loss, "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 367 | if ema_config.use_ema: 368 | logs["ema_decay"] = ema_model.cur_decay_value 369 | global_step += 1 370 | accelerator.log(logs, step=global_step) 371 | train_loss = 0.0 372 | 373 | 
if accelerator.is_main_process: 374 | if global_step % general_config.checkpointing_steps == 0: 375 | ''' 376 | if config.push_to_hub: 377 | upload_folder( 378 | repo_id=repo_id, 379 | folder_path=config.output_dir, 380 | commit_message=f"Epoch {epoch}", 381 | ignore_patterns=["step_*", "epoch_*"], 382 | ) 383 | else: 384 | ''' 385 | 386 | if args.checkpoints_total_limit is not None: 387 | checkpoints = os.listdir(general_config.output_dir) 388 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 389 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 390 | 391 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 392 | if len(checkpoints) >= args.checkpoints_total_limit: 393 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 394 | removing_checkpoints = checkpoints[0:num_to_remove] 395 | 396 | logger.info( 397 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 398 | ) 399 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 400 | 401 | for removing_checkpoint in removing_checkpoints: 402 | removing_checkpoint = os.path.join(general_config.output_dir, removing_checkpoint) 403 | shutil.rmtree(removing_checkpoint) 404 | 405 | save_path = os.path.join(general_config.output_dir, f"checkpoint-{global_step}") 406 | accelerator.save_state(save_path) 407 | logger.info(f"Saved state to {save_path}") 408 | 409 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 410 | if ema_config.use_ema: 411 | logs["ema_decay"] = ema_model.cur_decay_value 412 | progress_bar.set_postfix(**logs) 413 | 414 | if global_step >= general_config.num_training_steps: 415 | break 416 | 417 | # After each epoch you optionally sample some demo images with evaluate() and save the model 418 | if accelerator.is_main_process: 419 | 420 | if (epoch + 1) % general_config.save_image_epochs == 0 or epoch == general_config.num_epochs - 1: 421 | if ema_config.use_ema: 422 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 423 | ema_model.store(unet.parameters()) 424 | ema_model.copy_to(unet.parameters()) 425 | evaluate('train', general_config, epoch, vt, unwrap_model(unet), accelerator=accelerator, known_latents=batch) 426 | if ema_config.use_ema: 427 | # Restore the UNet parameters. 
428 | ema_model.restore(unet.parameters()) 429 | 430 | if (epoch + 1) % general_config.save_model_epochs == 0 or epoch == general_config.num_epochs - 1: 431 | # save the model 432 | 433 | if ema_config.use_ema: 434 | ema_model.store(unet.parameters()) 435 | ema_model.copy_to(unet.parameters()) 436 | 437 | unwrap_model(unet).save_pretrained(os.path.join(general_config.output_dir, "unet")) 438 | 439 | if ema_config.use_ema: 440 | ema_model.restore(unet.parameters()) 441 | 442 | if args.push_to_hub: 443 | upload_folder( 444 | repo_id=args.hub_model_id, 445 | folder_path=general_config.output_dir+"/unet", 446 | path_in_repo=general_config.output_dir.split("/")[-1], 447 | commit_message="running weight", 448 | ignore_patterns=["checkpoint_"], 449 | token=args.hub_token if args.hub_token else None, 450 | ) 451 | 452 | accelerator.end_training() 453 | 454 | if __name__ == "__main__": 455 | args = parse_args() 456 | main(args) -------------------------------------------------------------------------------- /utils/attn_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | 7 | from diffusers.models.embeddings import PatchEmbed 8 | from diffusers.models.embeddings import get_2d_sincos_pos_embed 9 | 10 | from transformers.configuration_utils import PretrainedConfig 11 | from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPPreTrainedModel, CLIPVisionModelOutput 12 | from transformers.modeling_outputs import BaseModelOutputWithPooling 13 | 14 | class FieldsVisionConfig(PretrainedConfig): 15 | # modified from CLIPVisionConfig 16 | model_type = "fields_vision_model" 17 | def __init__( 18 | self, 19 | hidden_size=768, 20 | intermediate_size=3072, 21 | projection_dim=512, 22 | num_hidden_layers=12, 23 | num_attention_heads=12, 24 | num_channels=3, 25 | image_size: Union[int, Tuple[int, int]] = (128, 128), 26 | patch_size=32, 27 | hidden_act="quick_gelu", 28 | layer_norm_eps=1e-5, 29 | attention_dropout=0.0, 30 | initializer_range=0.02, 31 | initializer_factor=1.0, 32 | input_padding: Union[int, Tuple[int, int]] = (0, 0), 33 | output_hidden_state=False, 34 | **kwargs, 35 | ): 36 | super().__init__(**kwargs) 37 | 38 | self.hidden_size = hidden_size 39 | self.intermediate_size = intermediate_size 40 | self.projection_dim = projection_dim 41 | self.num_hidden_layers = num_hidden_layers 42 | self.num_attention_heads = num_attention_heads 43 | self.num_channels = num_channels 44 | self.patch_size = patch_size 45 | self.image_size = image_size 46 | self.initializer_range = initializer_range 47 | self.initializer_factor = initializer_factor 48 | self.attention_dropout = attention_dropout 49 | self.layer_norm_eps = layer_norm_eps 50 | self.hidden_act = hidden_act 51 | self.input_padding = input_padding 52 | self.output_hidden_state = output_hidden_state 53 | 54 | class FieldsEmbeddings(nn.Module): 55 | # modified from CLIPVisionEmbddings in: 56 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py 57 | def __init__(self, config: FieldsVisionConfig): 58 | super().__init__() 59 | self.config = config 60 | self.embed_dim = config.hidden_size 61 | self.image_size = config.image_size 62 | self.patch_size = config.patch_size 63 | if isinstance(self.image_size, int): 64 | self.image_size = (self.image_size, self.image_size) 65 | self.input_padding = config.input_padding 66 | if 
isinstance(self.input_padding, int): 67 | self.input_padding = (self.input_padding, self.input_padding) 68 | 69 | self.patch_embedding = PatchEmbed( 70 | height=self.image_size[0], 71 | width=self.image_size[1], 72 | patch_size=self.patch_size, 73 | in_channels=config.num_channels, 74 | embed_dim=self.embed_dim, 75 | bias=True, 76 | pos_embed_type=None, 77 | ) 78 | 79 | self.sensing_array_patch_embedding = PatchEmbed( 80 | height=self.image_size[0], 81 | width=self.image_size[1], 82 | patch_size=self.patch_size, 83 | in_channels=1, 84 | embed_dim=2*self.embed_dim, 85 | bias=True, 86 | pos_embed_type=None, 87 | ) 88 | 89 | 90 | padded_inputs = tuple(a + b for a, b in zip(self.input_padding, self.image_size)) 91 | num_patch_height = padded_inputs[0] // self.patch_size 92 | num_patch_width = padded_inputs[1] // self.patch_size 93 | 94 | pos_embed = get_2d_sincos_pos_embed(self.embed_dim, (num_patch_height, num_patch_width), base_size=num_patch_height) 95 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 96 | 97 | self._mod_init_weights() 98 | 99 | def _mod_init_weights(self): 100 | # Since we modify the class name, we need to handle the weight initialization ourselves 101 | factor = self.config.initializer_factor 102 | nn.init.normal_(self.patch_embedding.proj.weight, std=self.config.initializer_range * factor) 103 | nn.init.constant_(self.patch_embedding.proj.bias, 0.0) 104 | nn.init.normal_(self.sensing_array_patch_embedding.proj.weight, std=self.config.initializer_range * factor) 105 | nn.init.constant_(self.sensing_array_patch_embedding.proj.bias, 0.0) 106 | 107 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 108 | batch_size = pixel_values.shape[0] 109 | target_dtype = self.patch_embedding.proj.weight.dtype 110 | patch_embeds = self.patch_embedding(pixel_values[:,:-1].to(dtype=target_dtype)) # shape = [*, num_patches, embed_dim] 111 | sensing_array_patch_embeds = self.sensing_array_patch_embedding(pixel_values[:,[-1]].to(dtype=target_dtype)) 112 | scale, shift = torch.chunk(sensing_array_patch_embeds, 2, dim=-1) 113 | patch_embeds = patch_embeds * (1 + scale) + shift 114 | 115 | embeddings = patch_embeds + self.pos_embed 116 | return embeddings 117 | 118 | class FieldsVisionTransformer(nn.Module): 119 | # modified from CLIPVisionTransformer in: 120 | def __init__(self, config: FieldsVisionConfig): 121 | super().__init__() 122 | self.config = config 123 | embed_dim = config.hidden_size 124 | 125 | self.embeddings = FieldsEmbeddings(config) 126 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 127 | self.encoder = CLIPEncoder(config) 128 | if not config.output_hidden_state: 129 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 130 | 131 | def forward( 132 | self, 133 | pixel_values: Optional[torch.FloatTensor] = None, 134 | output_attentions: Optional[bool] = None, 135 | output_hidden_states: Optional[bool] = None, 136 | return_dict: Optional[bool] = None, 137 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 138 | r""" 139 | Returns: 140 | 141 | """ 142 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 143 | output_hidden_states = ( 144 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 145 | ) 146 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 147 | 148 | if pixel_values is None: 149 | raise ValueError("You have to specify 
pixel_values") 150 | 151 | hidden_states = self.embeddings(pixel_values) 152 | hidden_states = self.pre_layrnorm(hidden_states) 153 | 154 | encoder_outputs = self.encoder( 155 | inputs_embeds=hidden_states, 156 | output_attentions=output_attentions, 157 | output_hidden_states=output_hidden_states, 158 | return_dict=return_dict, 159 | ) 160 | 161 | last_hidden_state = encoder_outputs[0] 162 | if not self.config.output_hidden_state: 163 | pooled_output = last_hidden_state.mean(dim=1) 164 | pooled_output = self.post_layernorm(pooled_output) 165 | else: 166 | pooled_output = None 167 | 168 | if not return_dict: 169 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 170 | 171 | return BaseModelOutputWithPooling( 172 | last_hidden_state=last_hidden_state, 173 | pooler_output=pooled_output, 174 | hidden_states=encoder_outputs.hidden_states, 175 | attentions=encoder_outputs.attentions, 176 | ) 177 | 178 | class FieldsVisionModelWithProjection(CLIPPreTrainedModel): 179 | # modified from: CLIPVisionModelWithProjection 180 | config_class = FieldsVisionConfig 181 | main_input_name = "pixel_values" 182 | 183 | def __init__(self, config: FieldsVisionConfig): 184 | super().__init__(config) 185 | 186 | self.vision_model = FieldsVisionTransformer(config) 187 | self.config = config 188 | 189 | if not config.output_hidden_state: 190 | self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False) 191 | self._mod_init_weights() 192 | 193 | # Initialize weights and apply final processing 194 | self.post_init() 195 | 196 | def _mod_init_weights(self): 197 | # Since we modify the class name, we need to handle the weight initialization ourselves 198 | nn.init.normal_( 199 | self.visual_projection.weight, 200 | std=self.config.hidden_size**-0.5 * self.config.initializer_factor, 201 | ) 202 | 203 | def get_input_embeddings(self) -> nn.Module: 204 | return self.vision_model.embeddings.patch_embedding 205 | 206 | def forward( 207 | self, 208 | pixel_values: Optional[torch.FloatTensor] = None, 209 | output_attentions: Optional[bool] = None, 210 | output_hidden_states: Optional[bool] = None, 211 | return_dict: Optional[bool] = None, 212 | ) -> Union[Tuple, CLIPVisionModelOutput]: 213 | 214 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 215 | 216 | vision_outputs = self.vision_model( 217 | pixel_values=pixel_values, 218 | output_attentions=output_attentions, 219 | output_hidden_states=output_hidden_states, 220 | return_dict=return_dict, 221 | ) 222 | 223 | if not self.config.output_hidden_state: 224 | pooled_output = vision_outputs[1] # pooled_output 225 | image_embeds = self.visual_projection(pooled_output) 226 | else: 227 | image_embeds = None 228 | 229 | if not return_dict: 230 | outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:] 231 | return tuple(output for output in outputs if output is not None) 232 | 233 | return CLIPVisionModelOutput( 234 | image_embeds=image_embeds, 235 | last_hidden_state=vision_outputs.last_hidden_state, 236 | hidden_states=vision_outputs.hidden_states, 237 | attentions=vision_outputs.attentions, 238 | ) -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import h5py 6 | from typing import Union, List, Optional, Tuple 7 | from diffusers.utils import logging 8 | 9 | 
logger = logging.get_logger(__name__) # pylint: disable=invalid-name 10 | 11 | def get_obj_from_str(string): 12 | module, cls = string.rsplit(".", 1) 13 | return getattr(importlib.import_module(module, package=None), cls) 14 | 15 | def instantiate_from_config(config): 16 | if not "target" in config: 17 | raise Exception("target not in config! ", config) 18 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 19 | 20 | def flatten_dict(d, parent_key='', sep='.'): 21 | items = [] 22 | for k, v in d.items(): 23 | new_key = f"{parent_key}{sep}{k}" if parent_key else k 24 | if isinstance(v, dict): 25 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 26 | else: 27 | items.append((new_key, v)) 28 | return dict(items) 29 | 30 | def flatten_and_filter_config(config): 31 | flat_config = flatten_dict(config) 32 | filtered_config = {} 33 | for key, value in flat_config.items(): 34 | if isinstance(value, (int, float, str, bool, torch.Tensor)): 35 | filtered_config[key] = value 36 | else: 37 | filtered_config[key] = str(value) # Convert unsupported types to string 38 | return filtered_config 39 | 40 | def convert_to_rgb(images): 41 | # Get the colormap 42 | cmap = plt.get_cmap('jet') 43 | 44 | # Ensure images are detached and converted to numpy for colormap application 45 | images_np = images.squeeze(1).detach().cpu().numpy() # shape: (B, H, W) 46 | 47 | converted_images = [] 48 | for img in images_np: 49 | # Normalize img to range [0, 1] 50 | img_normalized = (img - img.min()) / (img.max() - img.min() + 1e-5) 51 | 52 | # Apply colormap and convert to RGB 53 | img_rgb = cmap(img_normalized) 54 | 55 | # Convert from RGBA (4 channels) to RGB (3 channels) 56 | img_rgb = img_rgb[..., :3] # shape: (H, W, 3) 57 | 58 | # Convert to PyTorch tensor and scale to range [0, 255] 59 | img_rgb_tensor = torch.tensor(img_rgb * 255, dtype=torch.uint8).permute(2, 0, 1) # shape: (3, H, W) 60 | 61 | converted_images.append(img_rgb_tensor) 62 | 63 | return converted_images 64 | 65 | def read_hdf5_to_numpy(file_path, key, data_type=np.float32): 66 | with h5py.File(file_path, 'r') as f: 67 | #print(f'Reading dataset {key}') 68 | dataset = f[key] 69 | data_dtype = data_type 70 | data_array = np.asarray(dataset, dtype=data_dtype) 71 | return data_array 72 | 73 | @torch.no_grad() 74 | def calculate_covariance(samples, channel): 75 | """ 76 | Calculate the covariance matrix for a selected channel in the ensemble of samples. 77 | 78 | Args: 79 | samples (torch.Tensor): The ensemble of samples with shape (B, C, H, W). 80 | channel (int): The index of the channel for which to calculate the covariance matrix. 81 | 82 | Returns: 83 | torch.Tensor: The covariance matrix of the selected channel with shape (H*W, H*W). 
84 | """ 85 | # Extract the selected channel 86 | selected_channel_data = samples[:, channel, :, :] # Shape: (B, H, W) 87 | 88 | # Flatten the spatial dimensions (H, W) into a single dimension 89 | B, H, W = selected_channel_data.shape 90 | flattened_data = selected_channel_data.view(B, -1) # Shape: (B, H*W) 91 | 92 | # Calculate the covariance matrix 93 | mean = torch.mean(flattened_data, dim=0, keepdim=True) # Mean along batch dimension 94 | centered_data = flattened_data - mean # Centering data 95 | covariance_matrix = centered_data.t() @ centered_data / (B - 1) # Covariance matrix 96 | 97 | return covariance_matrix 98 | 99 | def rand_tensor( 100 | shape: Union[Tuple, List], 101 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 102 | device: Optional["torch.device"] = None, 103 | dtype: Optional["torch.dtype"] = None, 104 | layout: Optional["torch.layout"] = None, 105 | ): 106 | # modified from diffusers.utils.torch_utils.randn_tensor 107 | """A helper function to create random tensors with uniform distribution on the desired `device` with the desired `dtype`. When 108 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 109 | is always created on the CPU. 110 | """ 111 | # device on which tensor is created defaults to device 112 | rand_device = device 113 | batch_size = shape[0] 114 | 115 | layout = layout or torch.strided 116 | device = device or torch.device("cpu") 117 | 118 | if generator is not None: 119 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 120 | if gen_device_type != device.type and gen_device_type == "cpu": 121 | rand_device = "cpu" 122 | if device != "mps": 123 | logger.info( 124 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 125 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 126 | f" slightly speed up this function by passing a generator that was created on the {device} device." 
127 | ) 128 | elif gen_device_type != device.type and gen_device_type == "cuda": 129 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 130 | 131 | # make sure generator list of length 1 is treated like a non-list 132 | if isinstance(generator, list) and len(generator) == 1: 133 | generator = generator[0] 134 | 135 | if isinstance(generator, list): 136 | shape = (1,) + shape[1:] 137 | latents = [ 138 | torch.rand(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 139 | for i in range(batch_size) 140 | ] 141 | latents = torch.cat(latents, dim=0).to(device) 142 | else: 143 | latents = torch.rand(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 144 | 145 | return latents 146 | 147 | def plot_channel(samples, channel, title, cb=False, mask=None, save=False): 148 | # samples need to be divisible by 4 149 | try: 150 | samples = samples.detach().cpu().numpy() 151 | if mask is not None: 152 | mask = mask.detach().cpu().numpy 153 | except: 154 | pass 155 | h2w_ratio = int(samples.shape[0]/4) 156 | fig, axes = plt.subplots(h2w_ratio, 4, figsize=(20, int(h2w_ratio*5)), sharey=True, sharex=True) 157 | fig.suptitle(title, fontsize=16) 158 | for i, ax in enumerate(axes.flatten()): 159 | if i < samples.shape[0]: 160 | im = ax.imshow(samples[i, channel, :, :], cmap='jet') 161 | ax.axis('off') 162 | if cb: 163 | fig.colorbar(im, ax=ax) # Add colorbar to the current axis 164 | if mask is not None: 165 | tmp_mask = mask[i, channel, :, :] 166 | mask_indices = np.where(tmp_mask == 1) 167 | ax.scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=7) 168 | else: 169 | ax.axis('off') # Hide empty plots 170 | plt.tight_layout() 171 | if save: 172 | plt.savefig(title + '.png', dpi=300, bbox_inches='tight') 173 | 174 | def plot_steps(samples, idx_in_batch=0, start_step=0, interval=1, mask=None): 175 | # samples shape: (T, B, C, H, W), mask shape: (B, C, H, W) 176 | try: 177 | samples = samples.detach().cpu().numpy() 178 | if mask is not None: 179 | mask = mask.detach().cpu().numpy() 180 | except: 181 | pass 182 | total_steps = samples.shape[0] 183 | num_images = (total_steps - start_step) // interval 184 | 185 | num_channels = samples.shape[2] 186 | fig, axes = plt.subplots(num_channels, num_images, figsize=(int(3*num_images), int(3*num_channels))) # Create 2 rows of subplots 187 | 188 | for i in range(num_images): 189 | step = start_step + (i+1) * interval - 1 190 | for channel in range(num_channels): 191 | axes[channel, i].imshow(samples[step, idx_in_batch, channel, :, :], cmap='jet') # Plot the image at the current step 192 | if mask is not None: 193 | tmp_mask = mask[idx_in_batch, channel, :, :] # Get the corresponding mask 194 | mask_indices = np.where(tmp_mask == 1) 195 | axes[channel, i].scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=6) # Overlay black crosses 196 | if channel == 0: 197 | axes[channel, i].set_title(f'Step {step+1}') # Set the title to the current step 198 | axes[channel, i].axis('off') # Hide the axis 199 | 200 | plt.tight_layout() 201 | plt.show() 202 | 203 | def plot_one_sample(samples, num_in_batch=0, cb=True, mask=None, channel_names=None, save_name=None, dpi=300): 204 | ''' 205 | samples: (B, C, H, W), mask: (B, C, H, W) 206 | ''' 207 | try: 208 | samples = samples.detach().cpu().numpy() 209 | if mask is not None: 210 | mask = mask.detach().cpu().numpy() 211 | except: 212 | pass 213 | num_images = samples.shape[1] 214 | image = 
samples[num_in_batch, :, :, :] 215 | fig, axes = plt.subplots(num_images, 1, figsize=(4, int(4*num_images)), sharey=True, sharex=True) 216 | for i, ax in enumerate(axes.flatten()): 217 | im = ax.imshow(image[i, :, :], cmap='jet', origin='lower') 218 | if channel_names is not None: 219 | ax.set_title(channel_names[i]) 220 | ax.axis('off') 221 | if cb: 222 | fig.colorbar(im, ax=ax) 223 | if mask is not None: 224 | tmp_mask = mask[num_in_batch, i, :, :] 225 | mask_indices = np.where(tmp_mask == 1) 226 | ax.scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=7) 227 | plt.tight_layout() 228 | #if title is not None: 229 | # plt.suptitle(title, fontsize=16) 230 | if save_name is not None: 231 | plt.savefig(save_name + '.png', dpi=dpi, bbox_inches='tight') 232 | 233 | def plot_horizontal(images_list, channel_names=None, image_names=None, save_name=None, mask=None, plot_mask_idx=[], 234 | which_cb=None, dpi=300, font_size=14): 235 | ''' 236 | images_list: list of images, each with shape (C, H, W) 237 | mask: Optional, shape (C, H, W) 238 | which_cb: Optional, index of the image to be used for scaling the color bar 239 | ''' 240 | try: 241 | tmp_images_list = [] 242 | for img in images_list: 243 | try: 244 | tmp_images_list.append(img.detach().cpu().numpy()) 245 | except: 246 | tmp_images_list.append(img) 247 | images_list = tmp_images_list 248 | if mask is not None: 249 | mask = mask.detach().cpu().numpy() 250 | except: 251 | pass 252 | num_images = len(images_list) 253 | num_channels = images_list[0].shape[0] 254 | 255 | fig, axes = plt.subplots(num_channels, num_images, figsize=(4*num_images, 4*num_channels), sharey=True, sharex=True) 256 | 257 | if which_cb is not None and 0 <= which_cb < num_images: 258 | cb_image = images_list[which_cb] 259 | vmin = cb_image.min(axis=(1, 2)) 260 | vmax = cb_image.max(axis=(1, 2)) 261 | else: 262 | vmin, vmax = None, None 263 | 264 | for img_idx, image in enumerate(images_list): 265 | for ch_idx in range(num_channels): 266 | ax = axes[ch_idx, img_idx] if num_channels > 1 else axes[img_idx] 267 | im = ax.imshow(image[ch_idx, :, :], cmap='jet', origin='lower', vmin=vmin[ch_idx] if vmin is not None else None, vmax=vmax[ch_idx] if vmax is not None else None) 268 | if mask is not None: 269 | if img_idx in plot_mask_idx: 270 | tmp_mask = mask[ch_idx, :, :] 271 | mask_indices = np.where(tmp_mask == 1) 272 | ax.scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=7) 273 | if channel_names is not None: 274 | if image_names is None: 275 | ax.set_title(channel_names[ch_idx], fontsize=font_size) 276 | else: 277 | ax.set_title(f'{channel_names[ch_idx]} - {image_names[img_idx]}', fontsize=font_size) 278 | ax.axis('off') 279 | 280 | if which_cb is not None: 281 | for ch_idx in range(num_channels): 282 | ax_pos = axes[ch_idx, -1].get_position() 283 | cbar_height = ax_pos.height * 0.8 284 | cbar_ax = fig.add_axes([ax_pos.x1 + 0.02, ax_pos.y0 + (ax_pos.height - cbar_height) / 2, 0.02, cbar_height]) 285 | fig.colorbar(plt.cm.ScalarMappable(norm=plt.Normalize(vmin=vmin[ch_idx], vmax=vmax[ch_idx]), cmap='jet'), cax=cbar_ax) 286 | 287 | plt.tight_layout(rect=[0, 0, 0.9, 1]) 288 | if save_name is not None: 289 | plt.savefig(save_name + '.png', dpi=dpi, bbox_inches='tight') 290 | plt.show() 291 | 292 | def plot_ensemble(samples, title, cb=True, mask=None, save=False, GT=None): 293 | # 1st row is mean, 2nd row is std 294 | # samples: (B, C, H, W), mask: (C, H, W) 295 | num_row = 2 if GT is None else 3 296 | w2h_ratio = int(samples.shape[1]/2) 297 | fig, 
axes = plt.subplots(num_row, samples.shape[1], figsize=(int(w2h_ratio*8), 8) ,sharey=True, sharex=True) 298 | fig.suptitle(title, fontsize=16) 299 | sample_mean = torch.mean(samples, dim=0).detach().cpu().numpy() 300 | sample_std = torch.std(samples, dim=0).detach().cpu().numpy() 301 | for i, ax in enumerate(axes.flatten()): 302 | if i < samples.shape[1]: 303 | im = ax.imshow(sample_mean[i, :, :], cmap='jet') 304 | ax.axis('off') 305 | if cb: 306 | fig.colorbar(im, ax=ax) 307 | if mask is not None: 308 | tmp_mask = mask[i, :, :].detach().cpu().numpy() 309 | mask_indices = np.where(tmp_mask == 1) 310 | ax.scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=7) 311 | elif i < 2*samples.shape[1]: 312 | im = ax.imshow(sample_std[i-samples.shape[1], :, :], cmap='jet') 313 | ax.axis('off') 314 | if cb: 315 | fig.colorbar(im, ax=ax) 316 | if mask is not None: 317 | tmp_mask = mask[i-samples.shape[1], :, :].detach().cpu().numpy() 318 | mask_indices = np.where(tmp_mask == 1) 319 | ax.scatter(mask_indices[1], mask_indices[0], c='black', marker='x', s=7) 320 | else: 321 | im = ax.imshow(GT[i-2*samples.shape[1], :, :], cmap='jet') 322 | ax.axis('off') 323 | if cb: 324 | fig.colorbar(im, ax=ax) 325 | plt.tight_layout() 326 | if save: 327 | plt.savefig(title + '.png', dpi=300, bbox_inches='tight') 328 | -------------------------------------------------------------------------------- /utils/inverse_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | from diffusers.utils import make_image_grid 4 | from diffusers.utils.torch_utils import randn_tensor 5 | import os 6 | import numpy as np 7 | from einops import repeat 8 | #from tqdm.auto import tqdm 9 | from typing import List, Optional, Tuple, Union 10 | import copy 11 | 12 | from .general_utils import rand_tensor 13 | 14 | #''' 15 | # differences are negligible, but this one is faster 16 | @torch.no_grad() 17 | def create_scatter_mask(tensor, 18 | channels: List[int] = None, 19 | ratio: Union[float, torch.Tensor] = 0.1, 20 | x_idx = None, 21 | y_idx = None, 22 | generator = None, 23 | device = None): 24 | ''' 25 | return a mask that has the same shape as the input tensor, if multiple channels are specified, the same mask will be applied to all channels 26 | tensor: torch.Tensor 27 | channels: list of ints, denote the idx of known channels, default None. If None, all channels are masked 28 | ratio: float or array-like, default 0.1. The ratio of known elements 29 | x_idx, y_idx: int, default None. If not None, the mask will be applied to the specified indices. 
Orientation: (0,0) is the top left corner 30 | They can be either 2D or 1D tensors 31 | 32 | return: torch.Tensor (B, C, H, W) 33 | ''' 34 | #TODO: handle generator 35 | if device is None: 36 | device = tensor.device 37 | B, C, H, W = tensor.shape 38 | if channels is None: 39 | channels = torch.arange(C, device=device) # Ensure the same device as the input tensor 40 | else: 41 | channels = torch.tensor(channels, device=device) # Ensure the same device as the input tensor 42 | 43 | # Create a random mask for all elements 44 | if x_idx is not None and y_idx is not None: 45 | mask = torch.zeros(B, 1, H, W, device=device) 46 | mask[:, :, y_idx, x_idx] = 1 47 | if len(channels) > 1: 48 | mask = repeat(mask, 'B 1 H W -> B C H W', C=len(channels)) 49 | else: 50 | # For now, only support same mask for all channels 51 | mask = torch.zeros(B, 1, H, W, device=device) 52 | 53 | if isinstance(ratio, float) or ratio.numel() == 1: 54 | num_elements_to_select = int(H * W * ratio) 55 | ratios = [num_elements_to_select] * B 56 | else: 57 | ratios = [int(H * W * r) for r in ratio] 58 | 59 | for b in range(B): 60 | indices = torch.randperm(H * W, device=device)[:ratios[b]] 61 | mask[b, 0].view(-1)[indices] = 1 62 | 63 | if len(channels) > 1: 64 | mask = repeat(mask, 'B 1 H W -> B C H W', C=len(channels)) 65 | 66 | # Initialize the final mask with zeros 67 | final_mask = torch.zeros_like(tensor) 68 | mask = mask.type_as(final_mask) 69 | final_mask[:, channels, :, :] = mask 70 | 71 | return final_mask 72 | 73 | def create_patch_mask(tensor, channels=None, ratio=0.1): 74 | B, C, H, W = tensor.shape 75 | if channels is None: 76 | channels = range(C) # Assume apply to all channels 77 | patch_size = int(min(H, W) * ratio) 78 | start = (H - patch_size) // 2 79 | end = start + patch_size 80 | mask = torch.zeros_like(tensor) 81 | mask[:, channels] = 1 82 | mask[:, channels, start:end, start:end] = 0 83 | return mask 84 | 85 | @torch.no_grad() 86 | def edm_sampler_cond( 87 | net, noise_scheduler, batch_size=1, class_labels=None, randn_like=torch.randn_like, 88 | num_inference_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=0, 89 | deterministic=True, mask=None, known_latents=None, known_channels=None, 90 | return_trajectory=False, add_noise_to_obs=False, 91 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 92 | device = 'cpu' 93 | ): 94 | ''' 95 | mask: torch.Tensor, shape (H, W) or (B, C, H, W), 1 for known values, 0 for unknown 96 | known_latents: torch.Tensor, shape (H, W) or (B, C, H, W), known values 97 | ''' 98 | if known_latents is not None: 99 | assert batch_size == known_latents.shape[0], "Batch size must match the known_latents shape" 100 | # Sample gaussian noise to begin loop 101 | 102 | if isinstance(device, str): 103 | device = torch.device(device) 104 | 105 | if isinstance(net.config.sample_size, int): 106 | latents_shape = ( 107 | batch_size, 108 | net.config.out_channels, 109 | net.config.sample_size, 110 | net.config.sample_size, 111 | ) 112 | else: 113 | latents_shape = (batch_size, net.config.out_channels, *net.config.sample_size) 114 | 115 | latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=net.dtype) 116 | if add_noise_to_obs: 117 | noise = latents.clone() 118 | conditioning_tensors = torch.cat((known_latents, mask[:, [known_channels[0]]]), dim=1) 119 | noise_scheduler.set_timesteps(num_inference_steps, device=device) 120 | 121 | t_steps = noise_scheduler.sigmas.to(device) 122 | 123 | x_next = latents.to(torch.float64) * t_steps[0]
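    # Sampling starts from white noise scaled by the largest sigma in the schedule (as in the
    # unconditional sampler below). The block that follows broadcasts a 2D mask to the batch shape
    # and pastes the known values into the initial state, so the solver only updates unobserved
    # pixels; with add_noise_to_obs=True the observations are additionally re-noised to the current
    # sigma at every step before being pasted back in.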
124 | if mask is not None: 125 | if len(mask.shape) == 2: 126 | mask = mask[None, None, ...].expand_as(x_next) 127 | else: 128 | mask = torch.zeros_like(x_next) 129 | if mask is not None: 130 | x_next = x_next * (1 - mask) + known_latents * mask 131 | 132 | if return_trajectory: 133 | whole_trajectory = torch.zeros((num_inference_steps, *x_next.shape), dtype=torch.float64) 134 | # Main sampling loop. 135 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 136 | x_cur = x_next 137 | if not deterministic: 138 | # Increase noise temporarily. 139 | gamma = min(S_churn / num_inference_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 140 | t_hat = torch.as_tensor(t_cur + gamma * t_cur) 141 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) * (1 - mask) 142 | else: 143 | t_hat = t_cur 144 | x_hat = x_cur 145 | 146 | if known_latents is not None and add_noise_to_obs: 147 | tmp_known_latents = known_latents.clone() 148 | tmp_known_latents = noise_scheduler.add_noise(tmp_known_latents, noise, t_hat.view(-1)) 149 | x_hat = x_hat * (1 - mask) + tmp_known_latents * mask 150 | 151 | tmp_x_hat = x_hat.clone() 152 | c_noise = noise_scheduler.precondition_noise(t_hat) 153 | # Euler step. 154 | tmp_x_hat = noise_scheduler.precondition_inputs(tmp_x_hat, t_hat) 155 | 156 | denoised = net(tmp_x_hat.to(torch.float32), c_noise.reshape(-1).to(torch.float32), conditioning_tensors).sample.to(torch.float64) 157 | denoised = noise_scheduler.precondition_outputs(x_hat, denoised, t_hat) 158 | 159 | d_cur = (x_hat - denoised) / t_hat # denoise has the same shape as x_hat (b, out_channels, h, w) 160 | x_next = x_hat + (t_next - t_hat) * d_cur * (1 - mask) 161 | 162 | # Apply 2nd order correction. 163 | if i < num_inference_steps - 1: 164 | 165 | if known_latents is not None and add_noise_to_obs: 166 | tmp_known_latents = known_latents.clone() 167 | tmp_known_latents = noise_scheduler.add_noise(tmp_known_latents, noise, t_next.view(-1)) 168 | x_next = x_next * (1 - mask) + tmp_known_latents * mask 169 | 170 | tmp_x_next = x_next.clone() 171 | c_noise = noise_scheduler.precondition_noise(t_next) 172 | 173 | tmp_x_next = noise_scheduler.precondition_inputs(tmp_x_next, t_next) 174 | 175 | denoised = net(tmp_x_next.to(torch.float32),c_noise.reshape(-1).to(torch.float32), conditioning_tensors).sample.to(torch.float64) 176 | denoised = noise_scheduler.precondition_outputs(x_next, denoised, t_next) 177 | 178 | d_prime = (x_next - denoised) / t_next 179 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) * (1 - mask) 180 | 181 | if return_trajectory: 182 | whole_trajectory[i] = x_next 183 | 184 | if return_trajectory: 185 | return x_next, whole_trajectory 186 | else: 187 | return x_next 188 | 189 | @torch.no_grad() 190 | def edm_sampler_uncond( 191 | net, noise_scheduler, batch_size=1, class_labels=None, randn_like=torch.randn_like, 192 | num_inference_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=0, 193 | deterministic=True, mask=None, known_channels=None, known_latents=None, 194 | return_trajectory=False, 195 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 196 | device = 'cpu' 197 | ): 198 | ''' 199 | mask: torch.Tensor, shape (H, W) or (B, C, H, W), 1 for known values, 0 for unknown 200 | known_latents: torch.Tensor, shape (H, W) or (B, C, H, W), known values 201 | ''' 202 | if known_latents is not None: 203 | assert batch_size == known_latents.shape[0], "Batch size must match the known_latents shape" 
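    # Unlike edm_sampler_cond above, this sampler never feeds the observations to the network;
    # data consistency is enforced purely in sample space, by re-noising known_latents to the
    # current sigma and overwriting the masked pixels at every solver step (an inpainting-style
    # hard constraint).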
204 | # Sample gaussian noise to begin loop 205 | 206 | if isinstance(device, str): 207 | device = torch.device(device) 208 | 209 | if isinstance(net.config.sample_size, int): 210 | latents_shape = ( 211 | batch_size, 212 | net.config.out_channels, 213 | net.config.sample_size, 214 | net.config.sample_size, 215 | ) 216 | else: 217 | latents_shape = (batch_size, net.config.out_channels, *net.config.sample_size) 218 | 219 | latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=net.dtype) 220 | noise = latents.clone() 221 | noise_scheduler.set_timesteps(num_inference_steps, device=device) 222 | 223 | t_steps = noise_scheduler.sigmas.to(device) 224 | 225 | x_next = latents.to(torch.float64) * t_steps[0] # edm start with max sigma 226 | if mask is not None: 227 | if len(mask.shape) == 2: 228 | mask = mask[None, None, ...].expand_as(x_next) 229 | else: 230 | mask = torch.zeros_like(x_next) 231 | 232 | if return_trajectory: 233 | whole_trajectory = torch.zeros((num_inference_steps, *x_next.shape), dtype=torch.float64) 234 | # Main sampling loop. 235 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 236 | x_cur = x_next 237 | if not deterministic: 238 | # Increase noise temporarily. 239 | gamma = min(S_churn / num_inference_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 240 | t_hat = torch.as_tensor(t_cur + gamma * t_cur) 241 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 242 | else: 243 | t_hat = t_cur 244 | x_hat = x_cur 245 | 246 | if known_latents is not None: 247 | tmp_known_latents = known_latents.clone() 248 | tmp_known_latents = noise_scheduler.add_noise(tmp_known_latents, noise, t_hat.view(-1)) 249 | x_hat = x_hat * (1 - mask) + tmp_known_latents * mask 250 | 251 | tmp_x_hat = x_hat.clone() 252 | c_noise = noise_scheduler.precondition_noise(t_hat) 253 | # Euler step. 254 | tmp_x_hat = noise_scheduler.precondition_inputs(tmp_x_hat, t_hat) 255 | 256 | denoised = net(tmp_x_hat.to(torch.float32), c_noise.reshape(-1).to(torch.float32), class_labels).sample.to(torch.float64) 257 | denoised = noise_scheduler.precondition_outputs(x_hat, denoised, t_hat) 258 | 259 | d_cur = (x_hat - denoised) / t_hat # denoise has the same shape as x_hat (b, out_channels, h, w) 260 | x_next = x_hat + (t_next - t_hat) * d_cur 261 | 262 | # Apply 2nd order correction. 
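        # Heun (trapezoidal) correction: the denoiser is evaluated again at (x_next, t_next) and the
        # slopes d_cur and d_prime are averaged. The correction is skipped on the final step, where
        # t_next is the terminal (zero) sigma and d_prime would be undefined.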
263 | if i < num_inference_steps - 1: 264 | 265 | #""" 266 | if known_latents is not None: 267 | tmp_known_latents = known_latents.clone() 268 | tmp_known_latents = noise_scheduler.add_noise(tmp_known_latents, noise, t_next.view(-1)) 269 | x_next = x_next * (1 - mask) + tmp_known_latents * mask 270 | #""" 271 | 272 | tmp_x_next = x_next.clone() 273 | c_noise = noise_scheduler.precondition_noise(t_next) 274 | """ 275 | if mask is not None: 276 | tmp_x_next = torch.cat((tmp_x_next, concat_mask), dim=1) 277 | """ 278 | 279 | tmp_x_next = noise_scheduler.precondition_inputs(tmp_x_next, t_next) 280 | 281 | denoised = net(tmp_x_next.to(torch.float32),c_noise.reshape(-1).to(torch.float32), class_labels).sample.to(torch.float64) 282 | denoised = noise_scheduler.precondition_outputs(x_next, denoised, t_next) 283 | 284 | d_prime = (x_next - denoised) / t_next 285 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 286 | 287 | if return_trajectory: 288 | whole_trajectory[i] = x_next 289 | 290 | if return_trajectory: 291 | return x_next, whole_trajectory 292 | else: 293 | return x_next 294 | 295 | @torch.no_grad() 296 | def ensemble_sample(pipeline, sample_size, mask, sampler_kwargs=None, class_labels=None, known_latents=None, 297 | batch_size=64, sampler_type: Optional[str] = 'edm', # 'edm' or 'pipeline' 298 | device='cpu', conditioning_type='xattn', # 'xattn' or 'cfg' 299 | ): 300 | batch_size_list = [batch_size] * (sample_size // batch_size) + ([sample_size % batch_size] if sample_size % batch_size else []) # skip the trailing empty batch when sample_size divides evenly 301 | #print(latents.shape, class_labels.shape, mask.shape, known_latents.shape) 302 | count = 0 303 | samples = torch.empty(sample_size, pipeline.unet.config.out_channels, pipeline.unet.config.sample_size, 304 | pipeline.unet.config.sample_size, device=device, dtype=pipeline.unet.dtype) 305 | if sampler_kwargs is None: 306 | sampler_kwargs = {} 307 | if sampler_type == 'edm': 308 | model = pipeline.unet 309 | noise_scheduler = copy.deepcopy(pipeline.scheduler) 310 | for num_sample in batch_size_list: 311 | #tmp_class_labels = repeat(class_labels, 'C -> B C', B=num_sample) 312 | generator = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in range(count, count+num_sample)] 313 | tmp_mask = repeat(mask, '1 C H W -> B C H W', B=num_sample) 314 | tmp_known_latents = repeat(known_latents, '1 C H W -> B C H W', B=num_sample) 315 | if sampler_type == 'edm': 316 | if conditioning_type == 'xattn' or conditioning_type == 'cfg': 317 | tmp_samples = edm_sampler_cond(model, noise_scheduler, batch_size=num_sample, generator=generator, device=device, 318 | class_labels=class_labels, mask=tmp_mask, known_latents=tmp_known_latents, **sampler_kwargs) 319 | elif conditioning_type == 'uncond': 320 | tmp_samples = edm_sampler_uncond(model, noise_scheduler, batch_size=num_sample, generator=generator, device=device, 321 | class_labels=class_labels, mask=tmp_mask, known_latents=tmp_known_latents, **sampler_kwargs) 322 | elif sampler_type == 'pipeline': 323 | tmp_samples = pipeline(batch_size=num_sample, generator=generator, 324 | mask=tmp_mask, known_latents=tmp_known_latents, return_dict=False, **sampler_kwargs)[0] 325 | samples[count:count+num_sample] = tmp_samples 326 | count += num_sample 327 | return samples 328 | 329 | def colored_noise(shape, noise_type='pink', device='cpu', normalize=False): 330 | """ 331 | Generate colored noise (pink, red, blue, purple) in the spatial domain. 332 | 333 | Args: 334 | shape (tuple): Shape of the noise tensor (b, c, h, w).
335 | noise_type (str): Type of noise ('white', 'pink', 'red', 'blue', 'purple'). 336 | device (str): Device for the tensor. 337 | normalize (bool): Whether to normalize the output to [-1, 1] range. 338 | 339 | Returns: 340 | torch.Tensor: Colored noise tensor (b, c, h, w). 341 | """ 342 | if len(shape) != 4: 343 | raise ValueError("Input shape must be of the form (b, c, h, w)") 344 | 345 | valid_noise_types = ['white', 'pink', 'red', 'blue', 'purple'] 346 | if noise_type not in valid_noise_types: 347 | raise ValueError(f"Noise type must be one of {valid_noise_types}") 348 | 349 | b, c, h, w = shape 350 | 351 | # Initialize the output noise tensor 352 | output_noise = torch.zeros(shape, device=device) 353 | 354 | # Loop over the batch and channel dimensions 355 | for batch in range(b): 356 | for channel in range(c): 357 | # Generate white noise for the current (h, w) slice 358 | white_noise = torch.randn(h, w, device=device) 359 | 360 | # Apply Fourier transform to convert to frequency domain 361 | noise_fft = torch.fft.rfftn(white_noise, dim=(-2, -1)) 362 | 363 | # Create frequency grid for both dimensions 364 | freqs_x = torch.fft.fftfreq(h, d=1.0).to(device) 365 | freqs_y = torch.fft.rfftfreq(w, d=1.0).to(device) 366 | 367 | # Generate 2D frequency grid 368 | freq_grid = torch.sqrt(freqs_x[:, None]**2 + freqs_y[None, :]**2) 369 | eps = torch.finfo(freq_grid.dtype).eps 370 | 371 | # Modify the amplitude spectrum based on the type of colored noise 372 | spectral_factors = { 373 | 'white': lambda f: torch.ones_like(f), 374 | 'pink': lambda f: 1.0 / torch.sqrt(f + eps), 375 | 'red': lambda f: 1.0 / (f + eps), 376 | 'blue': lambda f: torch.sqrt(f + eps), 377 | 'purple': lambda f: f + eps 378 | } 379 | 380 | factor = spectral_factors[noise_type](freq_grid) 381 | 382 | # Handle DC component (zero frequency) specially 383 | if noise_type in ['pink', 'red']: 384 | factor[0, 0] = 1.0 385 | 386 | # Multiply the amplitude spectrum by the factor 387 | noise_fft *= factor 388 | 389 | # Inverse Fourier transform back to the spatial domain 390 | colored_noise = torch.fft.irfftn(noise_fft, s=(h, w), dim=(-2, -1)) 391 | 392 | # Normalize if requested 393 | if normalize: 394 | colored_noise = 2.0 * (colored_noise - colored_noise.min()) / (colored_noise.max() - colored_noise.min()) - 1.0 395 | 396 | # Assign the generated noise to the corresponding slice in the output tensor 397 | output_noise[batch, channel] = colored_noise 398 | 399 | return output_noise -------------------------------------------------------------------------------- /utils/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from diffusers.utils import BaseOutput 3 | from dataclasses import dataclass 4 | import torch 5 | import numpy as np 6 | 7 | @dataclass 8 | class Fields2DPipelineOutput(BaseOutput): 9 | """ 10 | Output class for image pipelines. 
11 | 12 | Args: 13 | fields (torch.Tensor or np.ndarray of shape (batch_size, channels, height, width)): the generated fields. 14 | """ 15 | 16 | fields: Union[torch.Tensor, np.ndarray] 17 | 18 | def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device='cpu'): 19 | # modified from diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py 20 | sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) 21 | schedule_timesteps = noise_scheduler.timesteps.to(device) 22 | timesteps = timesteps.to(device) 23 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 24 | 25 | sigma = sigmas[step_indices].flatten() 26 | while len(sigma.shape) < n_dim: 27 | sigma = sigma.unsqueeze(-1) 28 | return sigma -------------------------------------------------------------------------------- /utils/vt_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | def get_grid_points_from_mask(batch_idx, channel_idx, mask): 6 | ''' 7 | mask: torch.Tensor (B, C, H, W) 8 | ''' 9 | if len(mask.shape) == 3: 10 | # mask (C, H, W) 11 | mask = mask.unsqueeze(0) 12 | flatten_batch_idx, flatten_channel_idx, flatten_y_idx, flatten_x_idx = torch.nonzero(mask, as_tuple=True) 13 | target_idx = torch.logical_and(flatten_batch_idx == batch_idx, flatten_channel_idx == channel_idx) 14 | return torch.column_stack((flatten_x_idx[target_idx], flatten_y_idx[target_idx])) 15 | 16 | class vt_obs(object): 17 | def __init__(self, x_dim, y_dim, x_spacing, y_spacing, known_channels=None, device='cpu'): 18 | self.x_dim = x_dim 19 | self.y_dim = y_dim 20 | self.x_spacing = x_spacing 21 | self.y_spacing = y_spacing 22 | self.x_list = torch.arange(0, x_dim-x_spacing+1, x_spacing, device=device) 23 | self.y_list = torch.arange(0, y_dim-y_spacing+1, y_spacing, device=device) 24 | self.x_start_grid, self.y_start_grid = torch.meshgrid(self.x_list, self.y_list, indexing='ij') 25 | self.grid_x, self.grid_y = torch.meshgrid(torch.linspace(0, x_dim-1, x_dim), 26 | torch.linspace(0, y_dim-1, y_dim), 27 | indexing='xy') 28 | self.known_channels = known_channels 29 | self.device = device 30 | 31 | self.x_start_grid = self.x_start_grid.to(device) 32 | self.y_start_grid = self.y_start_grid.to(device) 33 | self.grid_x = self.grid_x.to(device) 34 | self.grid_y = self.grid_y.to(device) 35 | 36 | @torch.no_grad() 37 | def structure_obs(self, generator=None): 38 | x_offset = torch.randint(0, self.x_spacing, self.x_start_grid.shape, 39 | device=self.device, generator=generator) 40 | y_offset = torch.randint(0, self.y_spacing, self.y_start_grid.shape, 41 | device=self.device, generator=generator) 42 | x_coords = (self.x_start_grid + x_offset).flatten() 43 | y_coords = (self.y_start_grid + y_offset).flatten() 44 | return x_coords, y_coords 45 | 46 | @torch.no_grad() 47 | def _get_grid_points(self, x_coords=None, y_coords=None, generator=None): 48 | if x_coords is None and y_coords is None: 49 | x_coords, y_coords = self.structure_obs(generator=generator) 50 | return torch.column_stack((x_coords, y_coords)) 51 | 52 | @torch.no_grad() 53 | def _torch_griddata_nearest(self, points, values, xi): 54 | distances = torch.cdist(xi, points) 55 | nearest_indices = torch.argmin(distances, dim=1) 56 | interpolated_values = values[nearest_indices] 57 | return interpolated_values.reshape(self.y_dim, self.x_dim) 58 | 59 | @torch.no_grad() 60 | def interpolate(self, grid_points, field): 61 | ''' 62 | return griddata( 63 | grid_points, 64 |
field, 65 | (self.grid_x, self.grid_y), 66 | method='nearest' 67 | ) 68 | ''' 69 | return self._torch_griddata_nearest( 70 | grid_points.float(), 71 | field, 72 | torch.stack((self.grid_x.flatten(), self.grid_y.flatten()), dim=1).float(), 73 | ) 74 | 75 | def _plot_vt(self, known_fields, mask=None, x_coords=None, y_coords=None, plot_scatter=True, cb=True): 76 | ''' 77 | known_fields: (C, H, W) 78 | mask: (C, H, W) 79 | mask_channel_idx: int, if using same mask, input the corresponding channel index 80 | ''' 81 | C, H, W = known_fields.shape 82 | if mask is None: 83 | grid_points = self._get_grid_points(x_coords=x_coords, y_coords=y_coords).to(self.device) 84 | in_channels = C if self.known_channels is None else len(self.known_channels) 85 | interpolated_fields = torch.zeros(in_channels, self.y_dim, self.x_dim, dtype=known_fields.dtype) 86 | for idx, known_channel in enumerate(range(C) if self.known_channels is None else self.known_channels): 87 | if mask is not None: 88 | grid_points = get_grid_points_from_mask(0, known_channel, mask).to(self.device) 89 | field = known_fields[known_channel][grid_points[:,1], grid_points[:,0]].flatten() 90 | interpolated_values = self.interpolate(grid_points, field) 91 | interpolated_fields[idx] = torch.tensor(interpolated_values, 92 | dtype=known_fields.dtype, 93 | device=self.device) 94 | if x_coords is None and y_coords is None: 95 | x_coords, y_coords = grid_points[:,0], grid_points[:,1] 96 | fig, axs = plt.subplots(in_channels, 1, figsize=(4, 4*in_channels)) 97 | if in_channels == 1: 98 | axs = [axs] 99 | for c in range(in_channels): 100 | im = axs[c].imshow(interpolated_fields[c].cpu().numpy(), cmap='jet', origin='lower') 101 | axs[c].axis('off') 102 | if plot_scatter: 103 | axs[c].scatter(x_coords.cpu().numpy(), y_coords.cpu().numpy(), c='r', s=1) 104 | if cb: 105 | fig.colorbar(im, ax=axs[c]) 106 | plt.tight_layout() 107 | plt.show() 108 | 109 | @torch.no_grad() 110 | def __call__(self, known_fields, mask=None, x_coords=None, y_coords=None, generator=None): 111 | # known_fields: (B, C, H, W) 112 | B, C, _, _ = known_fields.shape 113 | in_channels = C if self.known_channels is None else len(self.known_channels) 114 | interpolated_fields = torch.zeros(B, in_channels, self.y_dim, self.x_dim, device=known_fields.device, dtype=known_fields.dtype) 115 | 116 | for b in range(B): 117 | if mask is None: 118 | grid_points = self._get_grid_points(x_coords=x_coords, y_coords=y_coords, generator=generator).to(self.device) 119 | 120 | for idx, known_channel in enumerate(range(C) if self.known_channels is None else self.known_channels): 121 | if mask is not None: 122 | grid_points = get_grid_points_from_mask(b, known_channel, mask).to(self.device) 123 | field = known_fields[b, known_channel][grid_points[:,1], grid_points[:,0]].flatten() 124 | interpolated_values = self.interpolate(grid_points, field) 125 | interpolated_fields[b, idx] = interpolated_values 126 | 127 | return interpolated_fields --------------------------------------------------------------------------------
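As a quick reference for how the masking, interpolation, and noise utilities above fit together, here is a minimal usage sketch. It is not part of the repository: it assumes the repository root is on PYTHONPATH (so that utils is importable as a package) and that the dependencies in requirements.txt are installed; the diffusion samplers themselves are omitted because they require a trained model and scheduler.

import torch

from utils.inverse_utils import colored_noise, create_scatter_mask
from utils.vt_utils import vt_obs

# Dummy batch of 2-channel 64x64 fields standing in for a PDE snapshot.
fields = torch.randn(1, 2, 64, 64)

# Mark roughly 5% of the pixels in channel 0 as observed; the mask has the same shape as `fields`.
mask = create_scatter_mask(fields, channels=[0], ratio=0.05)

# Nearest-neighbour (Voronoi) interpolation of the sparse observations back onto the full grid.
obs = vt_obs(x_dim=64, y_dim=64, x_spacing=8, y_spacing=8, known_channels=[0])
vt_fields = obs(fields, mask=mask)  # shape (1, 1, 64, 64)

# Spatially correlated noise as an alternative to white Gaussian noise.
pink = colored_noise((1, 2, 64, 64), noise_type='pink', normalize=True)

print(int(mask.sum()), vt_fields.shape, pink.shape)

Tensors of this form are what the samplers above expect as mask and known_latents.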