├── README.md
├── config
│   ├── sp
│   │   ├── SR3_sp_cosine_128.json
│   │   ├── SR3_sp_cosine_128_finetune.json
│   │   └── SR3_sp_cosine_128_finetune_infer.json
│   ├── t2m
│   │   ├── SR3_t2m_cosine_128.json
│   │   ├── SR3_t2m_cosine_128_finetune.json
│   │   └── SR3_t2m_cosine_128_finetune_infer.json
│   ├── tp
│   │   ├── SR3_tp_cosine_128.json
│   │   ├── SR3_tp_cosine_128_finetune.json
│   │   └── SR3_tp_cosine_128_finetune_infer.json
│   ├── u
│   │   ├── SR3_u.json
│   │   ├── SR3_u_finetune.json
│   │   └── SR3_u_finetune_infer.json
│   └── v
│       ├── SR3_v.json
│       ├── SR3_v_finetune.json
│       └── SR3_v_infer.json
├── configs.py
├── data
│   ├── .ipynb_checkpoints
│   │   └── mydataset_patch-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── mydataset_patch.cpython-39.pyc
│   ├── mydataset_patch.py
│   └── test.png
├── finetune_all.py
├── inference_2_monthly.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── base_model.cpython-39.pyc
│   │   ├── ema.cpython-39.pyc
│   │   ├── model.cpython-39.pyc
│   │   └── networks.cpython-39.pyc
│   ├── base_model.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── ddim.cpython-39.pyc
│   │   │   ├── ddpm.cpython-39.pyc
│   │   │   └── unet.cpython-39.pyc
│   │   ├── ddim.py
│   │   ├── ddpm.py
│   │   ├── dpm_solver
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── dpm_solver.cpython-39.pyc
│   │   │   │   └── sampler.cpython-39.pyc
│   │   │   ├── dpm_solver.py
│   │   │   └── sampler.py
│   │   └── unet.py
│   ├── ema.py
│   ├── model.py
│   └── networks.py
├── pytorch_ssim.py
├── src.png
├── trainer_all.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion_4_downscaling
2 | 
3 | Diffusion model-based probabilistic downscaling for long-term East Asian climate reconstruction.
4 | ![image](./src.png)
5 | 
6 | ## References
7 | 
8 | - Liangwei Jiang (2021) Image-Super-Resolution-via-Iterative-Refinement [[Source code](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement#readme)]
9 | - Song et al. (2021) Score-Based Generative Modeling through Stochastic Differential Equations [[Source code](https://github.com/yang-song/score_sde_pytorch)]
10 | - Davit Papikyan et al. (2022) Probabilistic Downscaling of Climate Variables Using Denoising Diffusion Probabilistic Models [[Source code](https://github.com/davitpapikyan/Probabilistic-Downscaling-of-Climate-Variables/)]
11 | - Robin Rombach et al. (2022) High-Resolution Image Synthesis with Latent Diffusion Models [[Source code](https://github.com/CompVis/stable-diffusion)]
12 | - Cheng Lu et al. (2022) DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps [[Source code](https://github.com/LuChengTHU/dpm-solver/)]
13 | 
--------------------------------------------------------------------------------
/config/sp/SR3_sp_cosine_128.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_sp_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": null
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
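
All of the configs in this repository share this beta_schedule block: a cosine schedule with 1000 diffusion steps (the linear_start/linear_end entries are only consulted by linear schedules). The schedule construction itself lives in model/diffusion, which this dump does not reproduce; as a point of reference only, a minimal sketch of the standard cosine schedule this name usually refers to (Nichol & Dhariwal, 2021):

import numpy as np

def cosine_beta_schedule(n_timestep=1000, s=8e-3):
    """Standard cosine noise schedule; a sketch for orientation,
    not the repo's exact implementation."""
    steps = np.arange(n_timestep + 1, dtype=np.float64)
    # alpha_bar(t) parameterized by a squared cosine, normalized to start at 1.
    alphas_cumprod = np.cos(((steps / n_timestep) + s) / (1 + s) * np.pi / 2) ** 2
    alphas_cumprod /= alphas_cumprod[0]
    # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1), clipped for numerical safety.
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)
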
/config/sp/SR3_sp_cosine_128_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_sp_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_new/experiments/SR3_sp_norm_230605_103122/checkpoint/I200000_E332"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/sp/SR3_sp_cosine_128_finetune_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_sp_norm",
3 |     "phase": "val",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_sp_norm_230605_103122/checkpoint/I300000_E498"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 100,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
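
The *_finetune_infer variant differs from the training configs in three places: "phase" is "val", "resume_state" points at a trained checkpoint, and "sample_size" jumps from 5 to 100 (33 for the other variables). Because reverse diffusion starts from fresh Gaussian noise on every call, re-running the sampler under the same low-resolution conditioning yields an ensemble of plausible high-resolution fields, which is what makes the downscaling probabilistic. A sketch of that idea (the sampler callable here is hypothetical; the repo's actual entry point is inference_2_monthly.py, not shown in this section):

import torch

@torch.no_grad()
def sample_ensemble(sampler, conditioning, sample_size):
    """Draw `sample_size` super-resolved fields for one conditioning input
    and summarize them; `sampler` stands in for the diffusion reverse process."""
    members = torch.stack([sampler(conditioning) for _ in range(sample_size)])
    return {
        "mean": members.mean(dim=0),  # deterministic best estimate
        "std": members.std(dim=0),    # per-pixel sampling spread
        "members": members,           # full ensemble, e.g. for probabilistic metrics
    }
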
/config/t2m/SR3_t2m_cosine_128.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_t2m_cosine_end1e-3_drop_0.1",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_new/experiments/SR3_t2m_cosine_end1e-3_drop_0.1_230515_170651/checkpoint/I170000_E282"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/t2m/SR3_t2m_cosine_128_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_t2m_cosine_end1e-3_drop_0.1",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_new/experiments/SR3_t2m_cosine_end1e-3_drop_0.1_230515_170651/checkpoint/I170000_E282"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/t2m/SR3_t2m_cosine_128_finetune_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_t2m_cosine_end1e-3_drop_0.1",
3 |     "phase": "val",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_t2m_cosine_end1e-3_drop_0.1_230515_170651/checkpoint/I300000_E498"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 33,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
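
Every config sets "in_channel": 4 with the comment [noisy, lr, land, mask]. The dataset code later in this dump builds a 3-channel INTERPOLATED tensor (upsampled low-resolution field, sea-land mask, terrain field); during training the noised high-resolution target is concatenated on top of it, giving the UNet its four input channels, while "out_channel": 1 is the single predicted HR channel. Schematically (shapes follow the 128x128 patches used throughout; the actual concatenation happens inside the diffusion code, which this section does not show):

import torch

noisy_hr = torch.randn(1, 1, 128, 128)      # x_t: noised HR target
interpolated = torch.randn(1, 3, 128, 128)  # [lr_upsampled, mask, land] from the dataset

unet_input = torch.cat([noisy_hr, interpolated], dim=1)
assert unet_input.shape[1] == 4             # matches "in_channel": 4
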
/config/tp/SR3_tp_cosine_128.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_tp_cosine_end1e-3_drop_0.1",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": null
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 85, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/tp/SR3_tp_cosine_128_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_tp_cosine_end1e-3_drop_0.1",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_new/experiments/SR3_tp_cosine_end1e-3_drop_0.1_230512_191807/checkpoint/I330000_E547"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6, // adjustable
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 85, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/tp/SR3_tp_cosine_128_finetune_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_tp_cosine_end1e-3_drop_0.1",
3 |     "phase": "val",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_tp_cosine_end1e-3_drop_0.1_230512_191807/checkpoint/I480000_E796"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 85, //30
63 |         "sample_size": 33,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/u/SR3_u.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_u_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": null
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/u/SR3_u_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_u_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_u_norm_230605_103243/checkpoint/I340000_E566"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": true,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/u/SR3_u_finetune_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_u_norm",
3 |     "phase": "val",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_u_norm_230605_103243/checkpoint/I350000_E583"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 33,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/v/SR3_v.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_v_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": null
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/v/SR3_v_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_v_norm",
3 |     "phase": "train",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_v_norm_230605_104034/checkpoint/I190000_E316"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 5,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/config/v/SR3_v_infer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "SR3_v_norm",
3 |     "phase": "val",
4 |     "gpu_ids": [2],
5 |     "path": {
6 |         "log": "logs",
7 |         "tb_logger": "tb_logger",
8 |         "results": "results",
9 |         "checkpoint": "checkpoint",
10 |         "resume_state": "/home/data/downscaling/downscaling_1023/DDIM/SR3_publish/experiments/SR3_v_norm_230605_104034/checkpoint/I200000_E333"
11 |     },
12 |     "data": {
13 |         "batch_size": 128,
14 |         "num_workers": 6,
15 |         "use_shuffle": true,
16 |         "height": 128
17 |     },
18 |     "model": {
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 4, // [noisy, lr, land, mask].
22 |             "out_channel": 1, // [HR].
23 |             "inner_channel": 64, // base channel width.
24 |             "norm_groups": 32, // GroupNorm groups.
25 |             "channel_multiplier": [1, 1, 2, 4], // channel multipliers at successive resolution levels (64, 32, 16).
26 |             "attn_res": [32, 16], // apply self-attention at these feature-map resolutions.
27 |             "res_blocks": 1,
28 |             "dropout": 0.1,
29 |             "init_method": "kaiming"
30 |         },
31 |         "beta_schedule": {
32 |             "train": {
33 |                 "schedule": "cosine",
34 |                 "n_timestep": 1000,
35 |                 "linear_start": 1e-6,
36 |                 "linear_end": 1e-3
37 |             },
38 |             "val": {
39 |                 "schedule": "cosine",
40 |                 "n_timestep": 1000,
41 |                 "linear_start": 1e-6,
42 |                 "linear_end": 1e-3
43 |             },
44 |             "test": {
45 |                 "schedule": "cosine",
46 |                 "n_timestep": 1000,
47 |                 "linear_start": 1e-6,
48 |                 "linear_end": 1e-3
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "conditional": true,
53 |             "loss": "l2"
54 |         }
55 |     },
56 |     "training": {
57 |         "epoch_n_iter": 1000000,
58 |         "val_freq": 10000, //25000
59 |         "save_checkpoint_freq": 10000,
60 |         "print_freq": 2500,
61 |         "n_val_vis": 1,
62 |         "val_vis_freq": 170, //30
63 |         "sample_size": 33,
64 |         "optimizer": {
65 |             "type": "adamw", // Possible types are ['adam', 'adamw']
66 |             "amsgrad": false,
67 |             "lr": 1e-4
68 |         }
69 |     }
70 | }
71 | 
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | """Defines configuration parameters for the whole model and dataset.
2 | """
3 | import argparse
4 | import json
5 | import os
6 | from collections import OrderedDict
7 | from datetime import datetime
8 | 
9 | 
10 | def get_current_datetime() -> str:
11 |     """Converts the current datetime to string.
12 |     Returns:
13 |         String version of current datetime of the form: %y%m%d_%H%M%S.
14 |     """
15 |     return datetime.now().strftime("%y%m%d_%H%M%S")
16 | 
17 | 
18 | def mkdirs(paths) -> None:
19 |     """Creates directories represented by paths argument.
20 |     Args:
21 |         paths: Either list of paths or a single path.
22 |     """
23 |     if isinstance(paths, str):
24 |         os.makedirs(paths, exist_ok=True)
25 |     else:
26 |         for path in paths:
27 |             os.makedirs(path, exist_ok=True)
28 | 
29 | 
30 | class Config:
31 |     """Configuration class.
32 |     Attributes:
33 |         args: Command-line arguments.
34 |         root: Path to the configuration JSON file.
35 |         gpu_ids: List of GPU ids.
36 |         params: Dictionary holding the parameters parsed from the JSON file.
37 |         name: Experiment name.
38 |         phase: Either train or val.
39 |         distributed: Whether the computation will be distributed among multiple GPUs or not.
40 |         log: Path to logs.
41 |         tb_logger: Tensorboard logging directory.
42 |         results: Validation results directory.
43 |         checkpoint: Model checkpoints directory.
44 |         resume_state: The path to load the network.
45 |         dataset_name: The name of dataset.
46 |         dataroot: The path to dataset.
47 |         batch_size: Batch size.
48 |         num_workers: The number of processes for multi-process data loading.
49 |         use_shuffle: Whether to shuffle the training data or not.
50 |         train_min_date: Minimum date starting from which to read the data for training.
51 |         train_max_date: Maximum date until which to read the data for training.
52 |         val_min_date: Minimum date starting from which to read the data for validation.
53 |         val_max_date: Maximum date until which to read the data for validation.
54 |         train_subset_min_date: Minimum date starting from which to read the data for model evaluation on train subset.
55 |         train_subset_max_date: Maximum date until which to read the data for model evaluation on train subset.
56 |         variables: A list of WeatherBench variables.
57 |         finetune_norm: Whether to fine-tune or to train from scratch.
58 |         in_channel: The number of channels of input tensor of U-Net.
59 |         out_channel: The number of channels of output tensor of U-Net.
60 |         inner_channel: Timestep embedding dimension.
61 |         norm_groups: The number of groups for group normalization.
62 |         channel_multiplier: A tuple specifying the scaling factors of channels.
63 |         attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer.
64 |         res_blocks: The number of residual blocks.
65 |         dropout: Dropout probability.
66 |         init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
67 |         train_schedule: Defines the type of beta schedule for training.
68 |         train_n_timestep: Number of diffusion timesteps for training.
69 |         train_linear_start: Minimum value of the linear schedule for training.
70 |         train_linear_end: Maximum value of the linear schedule for training.
71 |         val_schedule: Defines the type of beta schedule for validation.
72 |         val_n_timestep: Number of diffusion timesteps for validation.
73 |         val_linear_start: Minimum value of the linear schedule for validation.
74 |         val_linear_end: Maximum value of the linear schedule for validation.
75 |         test_schedule: Defines the type of beta schedule for inference.
76 |         test_n_timestep: Number of diffusion timesteps for inference.
77 |         test_linear_start: Minimum value of the linear schedule for inference.
78 |         test_linear_end: Maximum value of the linear schedule for inference.
79 |         conditional: Whether to condition on INTERPOLATED image or not.
80 |         diffusion_loss: Either 'l1' or 'l2'.
81 |         n_iter: Number of iterations to train.
82 |         val_freq: Validation frequency.
83 |         save_checkpoint_freq: Model checkpoint frequency.
84 |         print_freq: The frequency of displaying training information.
85 |         n_val_vis: Number of data points to visualize.
86 |         val_vis_freq: Validation data points visualization frequency.
87 |         sample_size: Number of SR images to generate to calculate metrics.
88 |         optimizer_type: The name of optimization algorithm. Supported values are 'adam', 'adamw'.
89 |         amsgrad: Whether to use the AMSGrad variant of optimizer.
90 |         lr: The learning rate.
91 |         experiments_root: The path to experiment.
92 |         transform_monthly: Whether to apply transformation monthly or on the whole dataset.
93 |         height: U-Net input tensor height value.
94 |     """
95 | 
96 |     def __init__(self, args: argparse.Namespace):
97 |         self.args = args
98 |         self.root = self.args.config
99 |         self.gpu_ids = self.args.gpu_ids
100 |         # self.members=self.args.members
101 |         self.params = {}
102 |         self.experiments_root = None
103 |         self.__parse_configs()
104 |         self.name = self.params["name"]
105 |         self.phase = self.params["phase"]
106 |         self.gpu_ids = self.params["gpu_ids"]
107 |         self.distributed = self.params["distributed"]
108 |         self.log = self.params["path"]["log"]
109 |         self.tb_logger = self.params["path"]["tb_logger"]
110 |         self.results = self.params["path"]["results"]
111 |         self.checkpoint = self.params["path"]["checkpoint"]
112 |         self.resume_state = self.params["path"]["resume_state"]
113 |         self.batch_size = self.params["data"]["batch_size"]
114 |         self.num_workers = self.params["data"]["num_workers"]
115 |         self.use_shuffle = self.params["data"]["use_shuffle"]
116 |         self.height = self.params["data"]["height"]
117 |         self.finetune_norm = self.params["model"]["finetune_norm"]
118 |         self.in_channel = self.params["model"]["unet"]["in_channel"]
119 |         self.out_channel = self.params["model"]["unet"]["out_channel"]
120 |         self.inner_channel = self.params["model"]["unet"]["inner_channel"]
121 |         self.norm_groups = self.params["model"]["unet"]["norm_groups"]
122 |         self.channel_multiplier = self.params["model"]["unet"]["channel_multiplier"]
123 |         self.attn_res = self.params["model"]["unet"]["attn_res"]
124 |         self.res_blocks = self.params["model"]["unet"]["res_blocks"]
125 |         self.dropout = self.params["model"]["unet"]["dropout"]
126 |         self.init_method = self.params["model"]["unet"]["init_method"]
127 |         self.train_schedule = self.params["model"]["beta_schedule"]["train"]["schedule"]
128 |         self.train_n_timestep = self.params["model"]["beta_schedule"]["train"]["n_timestep"]
129 |         self.train_linear_start = self.params["model"]["beta_schedule"]["train"]["linear_start"]
130 |         self.train_linear_end = self.params["model"]["beta_schedule"]["train"]["linear_end"]
131 |         self.val_schedule = self.params["model"]["beta_schedule"]["val"]["schedule"]
132 |         self.val_n_timestep = self.params["model"]["beta_schedule"]["val"]["n_timestep"]
133 |         self.val_linear_start = self.params["model"]["beta_schedule"]["val"]["linear_start"]
134 |         self.val_linear_end = self.params["model"]["beta_schedule"]["val"]["linear_end"]
135 |         self.test_schedule = self.params["model"]["beta_schedule"]["test"]["schedule"]
136 |         self.test_n_timestep = self.params["model"]["beta_schedule"]["test"]["n_timestep"]
137 |         self.test_linear_start = self.params["model"]["beta_schedule"]["test"]["linear_start"]
138 |         self.test_linear_end = self.params["model"]["beta_schedule"]["test"]["linear_end"]
139 |         self.conditional = self.params["model"]["diffusion"]["conditional"]
140 |         self.diffusion_loss = self.params["model"]["diffusion"]["loss"]
141 |         self.n_iter = self.params["training"]["epoch_n_iter"]
142 |         self.val_freq = self.params["training"]["val_freq"]
143 |         self.save_checkpoint_freq = self.params["training"]["save_checkpoint_freq"]
144 |         self.print_freq = self.params["training"]["print_freq"]
145 |         self.n_val_vis = self.params["training"]["n_val_vis"]
146 |         self.val_vis_freq = self.params["training"]["val_vis_freq"]
147 |         self.sample_size = self.params["training"]["sample_size"]
148 |         self.optimizer_type = self.params["training"]["optimizer"]["type"]
149 |         self.amsgrad = self.params["training"]["optimizer"]["amsgrad"]
150 |         self.lr = self.params["training"]["optimizer"]["lr"]
151 | 
152 |     def __parse_configs(self):
153 |         """Reads the configuration JSON file and stores it in the params attribute."""
154 |         json_str = ""
155 |         with open(self.root, "r") as f:
156 |             for line in f:
157 |                 json_str = f"{json_str}{line.split('//')[0]}\n"
158 | 
159 |         self.params = json.loads(json_str, object_pairs_hook=OrderedDict)
160 | 
161 |         if not self.params["path"]["resume_state"]:
162 |             self.experiments_root = os.path.join("experiments", f"{self.params['name']}_{get_current_datetime()}")
163 |         else:
164 |             self.experiments_root = "/".join(self.params["path"]["resume_state"].split("/")[:-2])
165 | 
166 |         for key, path in self.params["path"].items():
167 |             if not key.startswith("resume"):
168 |                 self.params["path"][key] = os.path.join(self.experiments_root, path)
169 |                 mkdirs(self.params["path"][key])
170 | 
171 |         if self.gpu_ids:
172 |             self.params["gpu_ids"] = [int(gpu_id) for gpu_id in self.gpu_ids.split(",")]
173 |             gpu_list = self.gpu_ids
174 |         else:
175 |             gpu_list = ",".join(str(x) for x in self.params["gpu_ids"])
176 | 
177 |         os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
178 |         self.params["distributed"] = len(gpu_list.split(",")) > 1
179 | 
180 |     def __getattr__(self, item):
181 |         """Returns None when attribute doesn't exist.
182 |         Args:
183 |             item: Attribute to retrieve.
184 |         Returns:
185 |             None
186 |         """
187 |         return None
188 | 
189 |     def get_hyperparameters_as_dict(self):
190 |         """Returns dictionary containing the parsed configuration JSON file.
191 |         """
192 |         return self.params
193 | 
194 | if __name__ == '__main__':
195 |     parser = argparse.ArgumentParser()
196 |     parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
197 |     parser.add_argument("-p", "--phase", type=str, choices=["train", "val"],
198 |                         help="Run either training or validation(inference).", default="train")
199 |     parser.add_argument("-gpu", "--gpu_ids", type=str, default=None)
200 |     args = parser.parse_args()
201 |     configs = Config(args)
202 |     print(configs.name)
--------------------------------------------------------------------------------
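
For orientation, a minimal way to drive Config programmatically, reusing the argparse flags defined in the __main__ block above (note that instantiating Config creates the experiment directories and sets CUDA_VISIBLE_DEVICES as side effects):

import argparse
from configs import Config

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", type=str)
parser.add_argument("-p", "--phase", type=str, default="train")
parser.add_argument("-gpu", "--gpu_ids", type=str, default=None)

args = parser.parse_args(["-c", "config/tp/SR3_tp_cosine_128.json"])
cfg = Config(args)            # strips // comments, parses the JSON, makes dirs

print(cfg.name)               # "SR3_tp_cosine_end1e-3_drop_0.1"
print(cfg.train_schedule)     # "cosine"
print(cfg.train_n_timestep)   # 1000
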
/data/.ipynb_checkpoints/mydataset_patch-checkpoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 | from torch.utils.data import Dataset, DataLoader
4 | from torch.nn.functional import interpolate
5 | import torch
6 | import glob
7 | import torch
8 | from bisect import bisect
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | import random
12 | 
13 | class SR3_Dataset_train(torch.utils.data.Dataset):
14 |     def __init__(self,hr_paths,land_paths,mask_paths,lr_paths,var,patch_size):
15 |         var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,}
16 |         index_var=var_list[var]
17 |         # for path1,path2 in zip(hr_paths,physical_paths):
18 |         #     print(path1,path2)
19 |         self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths]
20 |         self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths]
21 | 
22 |         #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r
23 |         self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0)
24 |         self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0)
25 |         self.start_indices = [0] * len(self.target_hr)
26 |         self.data_count = 0
27 |         # self.scale=scale
28 |         self.patch_size=patch_size
29 |         for index, memmap in enumerate(self.target_hr):
30 |             self.start_indices[index] = self.data_count
31 |             self.data_count += memmap.shape[0]
32 |     def get_patch(self,hr,mask,hr_land,lr_inter):
33 |         ih_hr, iw_hr = hr.shape[1:]
34 |         ip=self.patch_size
35 |         ix = random.randrange(0, iw_hr - ip + 1)
36 |         iy = random.randrange(0, ih_hr - ip + 1)
37 |         mask_data=torch.from_numpy(mask[:,iy:iy + ip, ix:ix + ip]).float()
38 |         land_data=torch.from_numpy(hr_land[:,iy:iy + ip, ix:ix + ip]).float()
39 |         lr_data=lr_inter[:,iy:iy + ip, ix:ix + ip].float()
40 |         ret = {
41 |             "HR":torch.from_numpy(hr[:,iy:iy + ip, ix:ix + ip]).float(),
42 |             "mask":mask_data,
43 |             "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0),
44 |             "LAND":land_data
45 |         }
46 |         return ret
47 | 
48 |     def __len__(self):
49 |         return self.data_count
50 | 
51 |     def __getitem__(self, index):
52 |         memmap_index = bisect(self.start_indices, index) - 1
53 |         index_in_memmap = index - self.start_indices[memmap_index]
54 | 
55 |         land_01_data=self.land_01
56 |         mask_data=self.mask_data
57 |         hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data
58 | 
59 |         lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data
60 | 
61 |         return self.get_patch(hr_target,mask_data,land_01_data,lr_inter)
62 | 
63 | 
64 | 
65 | 
66 | class SR3_Dataset_val_new(torch.utils.data.Dataset):
67 |     def __init__(self,hr_paths,land_paths,mask_paths,lr_paths,var,patch_size,loc):
68 |         index_list = []
69 |         for i, i_start in enumerate(np.arange(0, 400, 
patch_size)): 70 | for j, j_start in enumerate(np.arange(0, 700, patch_size)): 71 | i_end = i_start + patch_size 72 | j_end = j_start + patch_size 73 | if i_end > 400: 74 | i_end = 400 75 | i_start=400-128 76 | if j_end > 700: 77 | j_end = 700 78 | j_start=700-128 79 | index_list.append((i_start, i_end, j_start, j_end)) 80 | loc_dict={} 81 | for i,index in enumerate(index_list): 82 | loc_dict[str(i)]=index 83 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 84 | index_var=var_list[var] 85 | self.loc_index=loc_dict[str(loc)] 86 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths] 87 | self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 88 | 89 | #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r 90 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 91 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 92 | self.start_indices = [0] * len(self.target_hr) 93 | self.data_count = 0 94 | self.patch_size=patch_size 95 | for index, memmap in enumerate(self.target_hr): 96 | self.start_indices[index] = self.data_count 97 | self.data_count += memmap.shape[0] 98 | def get_patch(self,hr,mask,hr_land,lr_inter): 99 | i_start,i_end, j_start,j_end=self.loc_index 100 | mask_data=torch.from_numpy(mask[:,i_start:i_end, j_start:j_end]).float() 101 | land_data=torch.from_numpy(hr_land[:,i_start:i_end, j_start:j_end]).float() 102 | lr_data=lr_inter[:,i_start:i_end, j_start:j_end].float() 103 | ret = { 104 | "HR":torch.from_numpy(hr[:,i_start:i_end, j_start:j_end]).float(), 105 | "mask":mask_data, 106 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 107 | "LAND":land_data 108 | } 109 | return ret 110 | 111 | def __len__(self): 112 | return self.data_count 113 | 114 | def __getitem__(self, index): 115 | memmap_index = bisect(self.start_indices, index) - 1 116 | index_in_memmap = index - self.start_indices[memmap_index] 117 | 118 | land_01_data=self.land_01 119 | mask_data=self.mask_data 120 | hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data 121 | 122 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 123 | 124 | return self.get_patch(hr_target,mask_data,land_01_data,lr_inter) 125 | 126 | 127 | 128 | class SR3_Dataset_val(torch.utils.data.Dataset): 129 | def __init__(self,hr_paths,land_paths,mask_paths,lr_paths,var): 130 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 131 | index_var=var_list[var] 132 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths] 133 | self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 134 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 135 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 136 | self.start_indices = [0] * len(self.target_hr) 137 | self.data_count = 0 138 | for index, memmap in enumerate(self.target_hr): 139 | self.start_indices[index] = self.data_count 140 | self.data_count += memmap.shape[0] 141 | 142 | 143 | 144 | def get_patch(self,hr,mask,hr_land,lr_inter): 145 | ih_hr, iw_hr = hr.shape[1:] 146 | mask_data=torch.from_numpy(mask).float() 147 | land_data=torch.from_numpy(hr_land).float() 148 | lr_data=lr_inter.float() 149 | ret = { 150 | "HR":torch.from_numpy(hr).float(), 151 | 
"mask":mask_data, 152 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 153 | "LAND":land_data 154 | } 155 | return ret 156 | 157 | 158 | 159 | def __len__(self): 160 | return self.data_count 161 | 162 | def __getitem__(self, index): 163 | memmap_index = bisect(self.start_indices, index) - 1 164 | index_in_memmap = index - self.start_indices[memmap_index] 165 | mask_data=self.mask_data 166 | land_01_data=self.land_01 167 | hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data 168 | 169 | 170 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 171 | 172 | return self.get_patch(hr_target,mask_data,land_01_data,lr_inter) 173 | 174 | 175 | 176 | 177 | 178 | class BigDataset_test(torch.utils.data.Dataset): 179 | def __init__(self,hr_paths,land_paths,mask_paths): 180 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2) for path in hr_paths] 181 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 182 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 183 | self.start_indices = [0] * len(self.target_hr) 184 | self.data_count = 0 185 | # self.scale=scale 186 | # self.max_01=np.load(max_paths, mmap_mode='r+') 187 | for index, memmap in enumerate(self.target_hr): 188 | self.start_indices[index] = self.data_count 189 | self.data_count += memmap.shape[0] 190 | 191 | 192 | 193 | def get_patch(self,hr,mask,hr_land): 194 | mask_data=torch.from_numpy(mask).float() 195 | land_data=torch.from_numpy(hr_land).float() 196 | random_index=random.random() 197 | if random_index<0: 198 | ret = { 199 | "HR":torch.from_numpy(hr).float(), 200 | "mask":mask_data, 201 | "INTERPOLATED":torch.cat([mask_data,land_data],axis=0), 202 | "LAND":land_data, 203 | } 204 | else: 205 | #patch_list=[256] 206 | ip=256#patch_list[random.randint(0, 2)] 207 | ih_hr, iw_hr = hr.shape[1:] 208 | ix = random.randrange(0, iw_hr - ip + 1) 209 | iy = random.randrange(0, ih_hr - ip + 1) 210 | mask_data=torch.from_numpy(mask[:,iy:iy + ip, ix:ix + ip]).float() 211 | land_data=torch.from_numpy(hr_land[:,iy:iy + ip, ix:ix + ip]).float() 212 | ret = { 213 | "HR":torch.from_numpy(hr[:,iy:iy + ip, ix:ix + ip]).float(), 214 | "mask":mask_data, 215 | "INTERPOLATED":torch.cat([mask_data,land_data],axis=0), 216 | "LAND":land_data 217 | } 218 | return ret 219 | 220 | 221 | 222 | def __len__(self): 223 | return self.data_count 224 | 225 | def __getitem__(self, index): 226 | memmap_index = bisect(self.start_indices, index) - 1 227 | index_in_memmap = index - self.start_indices[memmap_index] 228 | 229 | land_01_data=self.land_01 230 | hr_target = self.target_hr[memmap_index][index_in_memmap] 231 | # physical=self.data_physical[memmap_index][index_in_memmap] 232 | mask_data=self.mask_data 233 | 234 | 235 | return self.get_patch(hr_target,mask_data,land_01_data) 236 | 237 | 238 | 239 | 240 | class BigDataset_cascade_infer(torch.utils.data.Dataset): 241 | def __init__(self,lr_paths,mask_paths,mask_paths_2x,var): 242 | variable={"u10":0,"v10":1,"sp":2,"t2m":3,"tp":4} 243 | idx=variable[var] 244 | self.data_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,idx:idx+1] for path in lr_paths] 245 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 246 | #2倍 247 | self.mask_02=np.expand_dims(np.load(mask_paths_2x, mmap_mode='r+'),axis=0) 248 | self.start_indices = [0] * len(self.data_lr) 249 | self.data_count = 0 250 | # 
self.scale=scale 251 | # self.patch_size=patch_size 252 | 253 | for index, memmap in enumerate(self.data_lr): 254 | self.start_indices[index] = self.data_count 255 | self.data_count += memmap.shape[0] 256 | 257 | 258 | def __len__(self): 259 | return self.data_count 260 | 261 | def __getitem__(self, index): 262 | memmap_index = bisect(self.start_indices, index) - 1 263 | index_in_memmap = index - self.start_indices[memmap_index] 264 | lr_data = self.data_lr[memmap_index][index_in_memmap] 265 | mask_data=torch.from_numpy(self.mask_data).float() 266 | mask_data_2x=torch.from_numpy(self.mask_02).float() 267 | 268 | inter=interpolate(torch.from_numpy(np.expand_dims(lr_data,axis=0)).float(),scale_factor=2, mode="bicubic").squeeze(0) 269 | ret = { 270 | "LR":torch.from_numpy(lr_data).float(), 271 | "INTERPOLATED":inter*mask_data_2x,#/max_ 272 | "mask": mask_data 273 | 274 | } 275 | 276 | return ret 277 | 278 | 279 | 280 | if __name__ == '__main__': 281 | data_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/sl/*npy")) 282 | target_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/hr/*npy")) 283 | physical_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/pl/*npy")) 284 | land_01_path="/home/data/downscaling/downscaling_1023/data/land10.npy" 285 | mask_path="/home/data/downscaling/downscaling_1023/data/mask10.npy" 286 | data_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/sl/*npy")) 287 | 288 | max_path="/home/data/downscaling/downscaling_1023/data/max_50.npy" 289 | random_dataset_index= random.sample(range(0, len(target_paths)), 2) 290 | data_index=np.arange(0,len(target_paths)) 291 | train_index=np.delete(data_index,random_dataset_index) 292 | 293 | print(f"split_random dataset is {random_dataset_index}" ) 294 | train_data = SR3_Dataset_train(np.array(target_paths)[train_index],land_01_path,mask_path,np.array(data_paths)[train_index],'tp',patch_size=128) 295 | 296 | # val_data=SR3_Dataset_val(np.array(target_paths)[random_dataset_index],land_01_path,mask_path,np.array(data_paths)[random_dataset_index],'tp',patch=128) 297 | val_data=SR3_Dataset_val(np.array(target_paths)[random_dataset_index],land_01_path,mask_path,np.array(data_paths)[random_dataset_index],'tp') 298 | # dataset = Control_Dataset_val(target_paths,land_01_path,mask_path,physical_paths) 299 | 300 | 301 | train_loader = DataLoader(train_data, batch_size=3,drop_last=True) 302 | val_loader = DataLoader(val_data, batch_size=32,drop_last=True) 303 | print(len(val_data),len(train_data)) 304 | # train_size = int(len(dataset) * 0.8) 305 | # validate_size = len(dataset)-int(len(dataset) * 0.8) 306 | # train_dataset, validate_dataset = torch.utils.data.random_split(dataset, [train_size, validate_size]) 307 | # train_loader=torch.utils.data.DataLoader(dataset, batch_size=10,shuffle=False,num_workers=2) 308 | 309 | for i,ret in enumerate(train_loader): 310 | for key in ret.keys(): 311 | print(key) 312 | print(ret[key].shape) 313 | # x = {key: (item.to(self.device) if item.numel() else item) for key, item in x.items()} 314 | break 315 | 316 | # # print(lr_data.shape,hr_target.shape,mask_data.shape,land_01_data.shape,physical.shape) 317 | # figure,ax=plt.subplots(3,2,figsize=(5,10)) 318 | # ax[0,0].imshow(ret["INTERPOLATED"][0,0],vmin=0,vmax=0.5) 319 | # ax[0,1].imshow(ret["INTERPOLATED"][0,1],vmin=0,vmax=0.5) 320 | # ax[1,0].imshow(ret["Control_data"]['850hpa'][0,0],vmin=-5,vmax=5) 321 | # 
ax[1,1].imshow(ret["Control_data"]['850hpa'][0,1],vmin=-5,vmax=5) 322 | # ax[2,0].imshow(ret["Control_data"]['850hpa'][0,2],vmin=-5,vmax=5) 323 | # ax[2,1].imshow(ret["Control_data"]['850hpa'][0,3],vmin=-5,vmax=5) 324 | # plt.savefig("./test.png") 325 | # break 326 | # # exit() 327 | 328 | # if torch.isnan(dataset).any() or torch.isinf(dataset).any(): 329 | # print(i,"dataset_TRUE") 330 | # if torch.isnan(label).any() or torch.isinf(label).any(): 331 | # print(i,"label_TRUE") 332 | # if torch.isnan(land_01).any() or torch.isinf(land_01).any(): 333 | # print(i,"land_True") 334 | # # print(i,"max",label[:,0].max(),label[:,1].max(),label[:,2].max(),label[:,3].max(),label[:,4].max()) 335 | # # #print("min",dataset.min(),label.min(),land_01.min(),land_04.min()) 336 | # if torch.isnan(dataset).int().sum()>=1 or torch.isnan(label).int().sum()>=1 or torch.isnan(physical).int().sum()>=1: 337 | # print("Nan") 338 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/data/__init__.py -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/mydataset_patch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/data/__pycache__/mydataset_patch.cpython-39.pyc -------------------------------------------------------------------------------- /data/mydataset_patch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.nn.functional import interpolate 5 | import torch 6 | import glob 7 | import torch 8 | from bisect import bisect 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import random 12 | 13 | class SR3_Dataset_patch(torch.utils.data.Dataset): 14 | def __init__(self,hr_paths,land_paths,mask_paths,lr_paths,var,patch_size): 15 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 16 | index_var=var_list[var] 17 | self.variable_name=var 18 | # for path1,path2 in zip(hr_paths,physical_paths): 19 | # print(path1,path2) 20 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths] 21 | self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 22 | 23 | #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r 24 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 25 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 26 | self.start_indices = [0] * len(self.target_hr) 27 | self.data_count = 0 28 | # self.scale=scale 29 | self.patch_size=patch_size 30 | for index, memmap in enumerate(self.target_hr): 31 | self.start_indices[index] = self.data_count 32 | self.data_count += memmap.shape[0] 33 | self.max = 
torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/max_new_10.npy", mmap_mode='r+')).float() 34 | self.min = torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/min_new_10.npy", mmap_mode='r+')).float() 35 | # print(self.max.shape) 36 | def normal_max_min(self,data,iy,ix,ip): 37 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 38 | index_var=var_list[self.variable_name] 39 | 40 | max_=self.max[index_var:index_var+1,iy:iy + ip, ix:ix + ip] 41 | min_=self.min[index_var:index_var+1,iy:iy + ip, ix:ix + ip] 42 | 43 | # print(max_.max(),max_.min()) 44 | # print((max_-min_).max(),(max_-min_).min()) 45 | return (data-min_)/(max_-min_+1e-6) 46 | def get_patch(self,hr,mask,hr_land,lr_inter): 47 | 48 | ih_hr, iw_hr = hr.shape[1:] 49 | ip=self.patch_size 50 | ix = random.randrange(0, iw_hr - ip + 1) 51 | iy = random.randrange(0, ih_hr - ip + 1) 52 | mask_data=torch.from_numpy(mask[:,iy:iy + ip, ix:ix + ip]).float() 53 | land_data=torch.from_numpy(hr_land[:,iy:iy + ip, ix:ix + ip]).float() 54 | 55 | if self.variable_name in ["u","v","t2m","sp"]: 56 | lr_data=self.normal_max_min(lr_inter[:,iy:iy + ip, ix:ix + ip].float(),iy,ix,ip) 57 | ret = { 58 | "HR":self.normal_max_min(torch.from_numpy(hr[:,iy:iy + ip, ix:ix + ip]).float(),iy,ix,ip), 59 | "mask":mask_data, 60 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 61 | "LAND":land_data 62 | } 63 | else: 64 | lr_data=lr_inter[:,iy:iy + ip, ix:ix + ip].float() 65 | ret = { 66 | "HR":torch.from_numpy(hr[:,iy:iy + ip, ix:ix + ip]).float(), 67 | "mask":mask_data, 68 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 69 | "LAND":land_data 70 | } 71 | 72 | return ret 73 | 74 | def __len__(self): 75 | return self.data_count 76 | 77 | def __getitem__(self, index): 78 | memmap_index = bisect(self.start_indices, index) - 1 79 | index_in_memmap = index - self.start_indices[memmap_index] 80 | 81 | land_01_data=self.land_01 82 | mask_data=self.mask_data 83 | hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data 84 | if self.variable_name=='tp': 85 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bilinear").squeeze(0)*mask_data 86 | else: 87 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 88 | return self.get_patch(hr_target,mask_data,land_01_data,lr_inter) 89 | 90 | 91 | 92 | class SR3_Dataset_finetune_patch(torch.utils.data.Dataset): 93 | def __init__(self,hr_paths,land_paths,mask_paths,lr_paths,var,patch_size): 94 | index_list = [] 95 | for i, i_start in enumerate(np.arange(0, 400, patch_size)): 96 | for j, j_start in enumerate(np.arange(0, 700, patch_size)): 97 | i_end = i_start + patch_size 98 | j_end = j_start + patch_size 99 | if i_end > 400: 100 | i_end = 400 101 | i_start=400-128 102 | if j_end > 700: 103 | j_end = 700 104 | j_start=700-128 105 | index_list.append((i_start, i_end, j_start, j_end)) 106 | self.loc_dict={} 107 | for i,index in enumerate(index_list): 108 | self.loc_dict[str(i)]=index 109 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 110 | index_var=var_list[var] 111 | self.variable_name=var 112 | # for path1,path2 in zip(hr_paths,physical_paths): 113 | # print(path1,path2) 114 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths] 115 | self.target_lr = 
[np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 116 | 117 | #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r 118 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 119 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 120 | self.start_indices = [0] * len(self.target_hr) 121 | self.data_count = 0 122 | # self.scale=scale 123 | self.patch_size=patch_size 124 | for index, memmap in enumerate(self.target_hr): 125 | self.start_indices[index] = self.data_count 126 | self.data_count += memmap.shape[0] 127 | self.max = torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/max_new_10.npy", mmap_mode='r+')).float() 128 | self.min = torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/min_new_10.npy", mmap_mode='r+')).float() 129 | # print(self.max.shape) 130 | def normal_max_min(self,data,i_start,i_end, j_start,j_end): 131 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 132 | index_var=var_list[self.variable_name] 133 | 134 | max_=self.max[index_var:index_var+1,i_start:i_end, j_start:j_end] 135 | min_=self.min[index_var:index_var+1,i_start:i_end, j_start:j_end] 136 | 137 | # print(max_.max(),max_.min()) 138 | # print((max_-min_).max(),(max_-min_).min()) 139 | return (data-min_)/(max_-min_+1e-6) 140 | def get_patch(self,hr,mask,hr_land,lr_inter): 141 | loc=random.randrange(0, len(self.loc_dict)) 142 | i_start,i_end, j_start,j_end=self.loc_dict[str(loc)] 143 | # ih_hr, iw_hr = hr.shape[1:] 144 | # ip=self.patch_size 145 | # ix = random.randrange(0, iw_hr - ip + 1) 146 | # iy = random.randrange(0, ih_hr - ip + 1) 147 | mask_data=torch.from_numpy(mask[:,i_start:i_end, j_start:j_end]).float() 148 | land_data=torch.from_numpy(hr_land[:,i_start:i_end, j_start:j_end]).float() 149 | 150 | if self.variable_name in ["u","v","t2m","sp"]: 151 | lr_data=self.normal_max_min(lr_inter[:,i_start:i_end, j_start:j_end].float(),i_start,i_end, j_start,j_end) 152 | ret = { 153 | "HR":self.normal_max_min(torch.from_numpy(hr[:,i_start:i_end, j_start:j_end]).float(),i_start,i_end, j_start,j_end), 154 | "mask":mask_data, 155 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 156 | "LAND":land_data 157 | } 158 | else: 159 | lr_data=lr_inter[:,i_start:i_end, j_start:j_end].float() 160 | ret = { 161 | "HR":torch.from_numpy(hr[:,i_start:i_end, j_start:j_end]).float(), 162 | "mask":mask_data, 163 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 164 | "LAND":land_data 165 | } 166 | 167 | return ret 168 | 169 | def __len__(self): 170 | return self.data_count 171 | 172 | def __getitem__(self, index): 173 | memmap_index = bisect(self.start_indices, index) - 1 174 | index_in_memmap = index - self.start_indices[memmap_index] 175 | 176 | land_01_data=self.land_01 177 | mask_data=self.mask_data 178 | hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data 179 | if self.variable_name=='tp': 180 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bilinear").squeeze(0)*mask_data 181 | else: 182 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 183 | 184 | return self.get_patch(hr_target,mask_data,land_01_data,lr_inter) 185 | 186 | 187 | # class SR3_Dataset_val_new(torch.utils.data.Dataset): 188 | # def 
__init__(self,hr_paths,land_paths,mask_paths,lr_paths,var,patch_size,loc): 189 | # index_list = [] 190 | # for i, i_start in enumerate(np.arange(0, 400, patch_size)): 191 | # for j, j_start in enumerate(np.arange(0, 700, patch_size)): 192 | # i_end = i_start + patch_size 193 | # j_end = j_start + patch_size 194 | # if i_end > 400: 195 | # i_end = 400 196 | # i_start=400-128 197 | # if j_end > 700: 198 | # j_end = 700 199 | # j_start=700-128 200 | # index_list.append((i_start, i_end, j_start, j_end)) 201 | # loc_dict={} 202 | # for i,index in enumerate(index_list): 203 | # loc_dict[str(i)]=index 204 | # var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 205 | # index_var=var_list[var] 206 | # self.loc_index=loc_dict[str(loc)] 207 | # self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in hr_paths] 208 | # self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 209 | 210 | # #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r 211 | # self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 212 | # self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 213 | # self.start_indices = [0] * len(self.target_hr) 214 | # self.data_count = 0 215 | # self.patch_size=patch_size 216 | # for index, memmap in enumerate(self.target_hr): 217 | # self.start_indices[index] = self.data_count 218 | # self.data_count += memmap.shape[0] 219 | # def get_patch(self,hr,mask,hr_land,lr_inter): 220 | # i_start,i_end, j_start,j_end=self.loc_index 221 | # mask_data=torch.from_numpy(mask[:,i_start:i_end, j_start:j_end]).float() 222 | # land_data=torch.from_numpy(hr_land[:,i_start:i_end, j_start:j_end]).float() 223 | # lr_data=lr_inter[:,i_start:i_end, j_start:j_end].float() 224 | # ret = { 225 | # "HR":torch.from_numpy(hr[:,i_start:i_end, j_start:j_end]).float(), 226 | # "mask":mask_data, 227 | # "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 228 | # "LAND":land_data 229 | # } 230 | # return ret 231 | 232 | # def __len__(self): 233 | # return self.data_count 234 | 235 | # def __getitem__(self, index): 236 | # memmap_index = bisect(self.start_indices, index) - 1 237 | # index_in_memmap = index - self.start_indices[memmap_index] 238 | 239 | # land_01_data=self.land_01 240 | # mask_data=self.mask_data 241 | # hr_target = self.target_hr[memmap_index][index_in_memmap]*mask_data 242 | 243 | # lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 244 | 245 | # return self.get_patch(hr_target,mask_data,land_01_data,lr_inter) 246 | 247 | 248 | 249 | class SR3_Dataset_all(torch.utils.data.Dataset): 250 | def __init__(self,land_paths,mask_paths,lr_paths,var): 251 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 252 | index_var=var_list[var] 253 | self.variable_name=var 254 | self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 255 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 256 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 257 | self.start_indices = [0] * len(self.target_lr) 258 | self.data_count = 0 259 | for index, memmap in enumerate(self.target_lr): 260 | self.start_indices[index] = self.data_count 261 | self.data_count += memmap.shape[0] 262 | self.max = 
torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/max_new_10.npy", mmap_mode='r+')).float() 263 | self.min = torch.from_numpy(np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/min_new_10.npy", mmap_mode='r+')).float() 264 | def normal_max_min(self,data): 265 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 266 | index_var=var_list[self.variable_name] 267 | 268 | max_=self.max[index_var:index_var+1] 269 | min_=self.min[index_var:index_var+1] 270 | 271 | # print(max_.max(),max_.min()) 272 | # print((max_-min_).max(),(max_-min_).min()) 273 | return (data-min_)/(max_-min_+1e-6) 274 | 275 | def get_patch(self,mask,hr_land,lr_inter): 276 | mask_data=torch.from_numpy(mask).float() 277 | land_data=torch.from_numpy(hr_land).float() 278 | if self.variable_name in ["u","v","t2m","sp"]: 279 | lr_data=self.normal_max_min(lr_inter.float()) 280 | ret = { 281 | "mask":mask_data, 282 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 283 | "LAND":land_data 284 | } 285 | else: 286 | lr_data=lr_inter.float() 287 | ret = { 288 | "mask":mask_data, 289 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 290 | "LAND":land_data 291 | } 292 | 293 | return ret 294 | 295 | 296 | 297 | def __len__(self): 298 | return self.data_count 299 | 300 | def __getitem__(self, index): 301 | memmap_index = bisect(self.start_indices, index) - 1 302 | index_in_memmap = index - self.start_indices[memmap_index] 303 | mask_data=self.mask_data 304 | land_01_data=self.land_01 305 | if self.variable_name=='tp': 306 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bilinear").squeeze(0)*mask_data 307 | else: 308 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 309 | 310 | 311 | return self.get_patch(mask_data,land_01_data,lr_inter) 312 | 313 | 314 | class SR3_Dataset_test(torch.utils.data.Dataset): 315 | def __init__(self,land_paths,mask_paths,lr_paths,var,patch_size,loc): 316 | index_list = [] 317 | for i, i_start in enumerate(np.arange(0, 400, patch_size)): 318 | for j, j_start in enumerate(np.arange(0, 700, patch_size)): 319 | i_end = i_start + patch_size 320 | j_end = j_start + patch_size 321 | if i_end > 400: 322 | i_end = 400 323 | i_start=400-128 324 | if j_end > 700: 325 | j_end = 700 326 | j_start=700-128 327 | index_list.append((i_start, i_end, j_start, j_end)) 328 | loc_dict={} 329 | for i,index in enumerate(index_list): 330 | loc_dict[str(i)]=index 331 | var_list={"u":0,"v":1,"t2m":2,"sp":3,"tp":4,} 332 | index_var=var_list[var] 333 | self.loc_index=loc_dict[str(loc)] 334 | self.target_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,index_var:index_var+1] for path in lr_paths] 335 | 336 | #[0,2,4,6,8]# 500 zrtuv #[6,8,4,0,2]u v t z r 337 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 338 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 339 | self.start_indices = [0] * len(self.target_lr) 340 | self.data_count = 0 341 | self.patch_size=patch_size 342 | for index, memmap in enumerate(self.target_lr): 343 | self.start_indices[index] = self.data_count 344 | self.data_count += memmap.shape[0] 345 | def get_patch(self,mask,hr_land,lr_inter): 346 | i_start,i_end, j_start,j_end=self.loc_index 347 | mask_data=torch.from_numpy(mask[:,i_start:i_end, j_start:j_end]).float() 348 | 
land_data=torch.from_numpy(hr_land[:,i_start:i_end, j_start:j_end]).float() 349 | lr_data=lr_inter[:,i_start:i_end, j_start:j_end].float() 350 | ret = { 351 | "mask":mask_data, 352 | "INTERPOLATED":torch.cat([lr_data,mask_data,land_data],axis=0), 353 | "LAND":land_data 354 | } 355 | return ret 356 | 357 | def __len__(self): 358 | return self.data_count 359 | 360 | def __getitem__(self, index): 361 | memmap_index = bisect(self.start_indices, index) - 1 362 | index_in_memmap = index - self.start_indices[memmap_index] 363 | 364 | land_01_data=self.land_01 365 | mask_data=self.mask_data 366 | 367 | lr_inter=interpolate(torch.from_numpy(np.expand_dims(self.target_lr[memmap_index][index_in_memmap],axis=0)).float(),scale_factor=10, mode="bicubic").squeeze(0)*mask_data 368 | 369 | return self.get_patch(mask_data,land_01_data,lr_inter) 370 | 371 | 372 | class BigDataset_test(torch.utils.data.Dataset): 373 | def __init__(self,hr_paths,land_paths,mask_paths): 374 | self.target_hr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2) for path in hr_paths] 375 | self.land_01=np.expand_dims(np.load(land_paths, mmap_mode='r+'),axis=0) 376 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 377 | self.start_indices = [0] * len(self.target_hr) 378 | self.data_count = 0 379 | # self.scale=scale 380 | # self.max_01=np.load(max_paths, mmap_mode='r+') 381 | for index, memmap in enumerate(self.target_hr): 382 | self.start_indices[index] = self.data_count 383 | self.data_count += memmap.shape[0] 384 | 385 | 386 | 387 | def get_patch(self,hr,mask,hr_land): 388 | mask_data=torch.from_numpy(mask).float() 389 | land_data=torch.from_numpy(hr_land).float() 390 | random_index=random.random() 391 | if random_index<0: 392 | ret = { 393 | "HR":torch.from_numpy(hr).float(), 394 | "mask":mask_data, 395 | "INTERPOLATED":torch.cat([mask_data,land_data],axis=0), 396 | "LAND":land_data, 397 | } 398 | else: 399 | #patch_list=[256] 400 | ip=256#patch_list[random.randint(0, 2)] 401 | ih_hr, iw_hr = hr.shape[1:] 402 | ix = random.randrange(0, iw_hr - ip + 1) 403 | iy = random.randrange(0, ih_hr - ip + 1) 404 | mask_data=torch.from_numpy(mask[:,iy:iy + ip, ix:ix + ip]).float() 405 | land_data=torch.from_numpy(hr_land[:,iy:iy + ip, ix:ix + ip]).float() 406 | ret = { 407 | "HR":torch.from_numpy(hr[:,iy:iy + ip, ix:ix + ip]).float(), 408 | "mask":mask_data, 409 | "INTERPOLATED":torch.cat([mask_data,land_data],axis=0), 410 | "LAND":land_data 411 | } 412 | return ret 413 | 414 | 415 | 416 | def __len__(self): 417 | return self.data_count 418 | 419 | def __getitem__(self, index): 420 | memmap_index = bisect(self.start_indices, index) - 1 421 | index_in_memmap = index - self.start_indices[memmap_index] 422 | 423 | land_01_data=self.land_01 424 | hr_target = self.target_hr[memmap_index][index_in_memmap] 425 | # physical=self.data_physical[memmap_index][index_in_memmap] 426 | mask_data=self.mask_data 427 | 428 | 429 | return self.get_patch(hr_target,mask_data,land_01_data) 430 | 431 | 432 | 433 | 434 | class BigDataset_cascade_infer(torch.utils.data.Dataset): 435 | def __init__(self,lr_paths,mask_paths,mask_paths_2x,var): 436 | variable={"u10":0,"v10":1,"sp":2,"t2m":3,"tp":4} 437 | idx=variable[var] 438 | self.data_lr = [np.load(path, mmap_mode='r+').transpose(0,3,1,2)[:,idx:idx+1] for path in lr_paths] 439 | self.mask_data=np.expand_dims(np.load(mask_paths, mmap_mode='r+'),axis=0) 440 | # Mask at 2x resolution, for the cascade's intermediate stage. 441 | self.mask_02=np.expand_dims(np.load(mask_paths_2x, mmap_mode='r+'),axis=0) 442 | self.start_indices = [0] * 
len(self.data_lr) 443 | self.data_count = 0 444 | # self.scale=scale 445 | # self.patch_size=patch_size 446 | 447 | for index, memmap in enumerate(self.data_lr): 448 | self.start_indices[index] = self.data_count 449 | self.data_count += memmap.shape[0] 450 | 451 | 452 | def __len__(self): 453 | return self.data_count 454 | 455 | def __getitem__(self, index): 456 | memmap_index = bisect(self.start_indices, index) - 1 457 | index_in_memmap = index - self.start_indices[memmap_index] 458 | lr_data = self.data_lr[memmap_index][index_in_memmap] 459 | mask_data=torch.from_numpy(self.mask_data).float() 460 | mask_data_2x=torch.from_numpy(self.mask_02).float() 461 | 462 | inter=interpolate(torch.from_numpy(np.expand_dims(lr_data,axis=0)).float(),scale_factor=2, mode="bicubic").squeeze(0) 463 | ret = { 464 | "LR":torch.from_numpy(lr_data).float(), 465 | "INTERPOLATED":inter*mask_data_2x,#/max_ 466 | "mask": mask_data 467 | 468 | } 469 | 470 | return ret 471 | 472 | 473 | 474 | if __name__ == '__main__': 475 | data_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/sl/*npy")) 476 | target_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/hr/*npy")) 477 | physical_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/pl/*npy")) 478 | land_01_path="/home/data/downscaling/downscaling_1023/data/land10.npy" 479 | -------------------------------------------------------------------------------- /data/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/data/test.png -------------------------------------------------------------------------------- /finetune_all.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import warnings 6 | from collections import OrderedDict, defaultdict 7 | import numpy as np 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | from torch.nn.functional import mse_loss, l1_loss 11 | from torch.utils.data import DataLoader 12 | import model 13 | # from x2_data.mydataset_patch import BigDataset_train 14 | from data.mydataset_patch import SR3_Dataset_patch,SR3_Dataset_finetune_patch 15 | from configs import Config 16 | import matplotlib 17 | import matplotlib.pyplot as plt 18 | matplotlib.use('Agg') 19 | import glob 20 | from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \ 21 | accumulate_statistics, \ 22 | get_optimizer, construct_mask, set_seeds,psnr 23 | import random 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | if __name__ == "__main__": 28 | set_seeds() # For reproducibility. 
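    # [Editor's sketch, not part of the original file] `set_seeds` comes from utils.py,
    # which is not included in this dump. A typical implementation seeds every RNG the
    # script touches; the body below is an assumption and may differ from the real helper:
    def _set_seeds_sketch(seed: int = 0) -> None:
        random.seed(seed)                 # Python RNG (drives the train/val split below)
        np.random.seed(seed)              # NumPy RNG
        torch.manual_seed(seed)           # PyTorch CPU RNG
        torch.cuda.manual_seed_all(seed)  # PyTorch GPU RNGs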
29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-c", "--config", type=str, help="JSON file for configuration") 31 | parser.add_argument("-p", "--phase", type=str, choices=["train", "val"], 32 | help="Run either training or validation(inference).", default="train") 33 | parser.add_argument("-gpu", "--gpu_ids", type=str, default=None) 34 | parser.add_argument("-var", "--variable_name", type=str, default=None) 35 | args = parser.parse_args() 36 | variable_name=args.variable_name 37 | configs = Config(args) 38 | torch.backends.cudnn.enabled = True 39 | torch.backends.cudnn.benchmark = True 40 | 41 | setup_logger(None, configs.log, "train", screen=True) 42 | setup_logger("val", configs.log, "val") 43 | 44 | logger = logging.getLogger("base") 45 | val_logger = logging.getLogger("val") 46 | logger.info(dict2str(configs.get_hyperparameters_as_dict())) 47 | tb_logger = SummaryWriter(log_dir=configs.tb_logger) 48 | 49 | 50 | target_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/hr/*npy")) 51 | land_01_path="/home/data/downscaling/downscaling_1023/data/land10.npy" 52 | mask_path="/home/data/downscaling/downscaling_1023/data/mask10.npy" 53 | # physical_paths= sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/pl/*npy")) 54 | lr_paths= sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/sl/*npy")) 55 | random_dataset_index= random.sample(range(0, len(target_paths)), 2) 56 | data_index=np.arange(0,len(target_paths)) 57 | train_index=np.delete(data_index,random_dataset_index) 58 | logger.info(f"Randomly held-out validation file indices: {random_dataset_index}") 59 | train_data = SR3_Dataset_finetune_patch(np.array(target_paths)[train_index],land_01_path,mask_path,lr_paths=np.array(lr_paths)[train_index],var=variable_name,patch_size=configs.height) 60 | val_data=SR3_Dataset_finetune_patch(np.array(target_paths)[random_dataset_index],land_01_path,mask_path,lr_paths=np.array(lr_paths)[random_dataset_index],var=variable_name,patch_size=configs.height) 61 | 62 | logger.info(f"Train size: {len(train_data)}, Val size: {len(val_data)}.") 63 | train_loader = DataLoader(train_data, batch_size=configs.batch_size,shuffle=configs.use_shuffle, num_workers=configs.num_workers,drop_last=True) 64 | val_loader = DataLoader(val_data, batch_size=configs.batch_size // 12,shuffle=False, num_workers=configs.num_workers,drop_last=True) # np.int was removed in NumPy 1.24; integer division is equivalent here. 65 | logger.info("Training and Validation dataloaders are ready.") 66 | # Defining the model. 
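    # [Editor's sketch, not part of the original file] get_optimizer is imported from
    # utils.py (not shown). The JSON configs list the possible optimizer types as
    # ['adam', 'adamw'], so it plausibly maps that string to a torch.optim class,
    # roughly as below (assumed, may differ from the real helper):
    def _get_optimizer_sketch(optimizer_type: str):
        return {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}[optimizer_type]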
67 | optimizer = get_optimizer(configs.optimizer_type) 68 | diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel, 69 | norm_groups=configs.norm_groups, inner_channel=configs.inner_channel, 70 | channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res, 71 | res_blocks=configs.res_blocks, dropout=configs.dropout, 72 | diffusion_loss=configs.diffusion_loss, conditional=configs.conditional, 73 | gpu_ids=configs.gpu_ids, distributed=configs.distributed, 74 | init_method=configs.init_method, train_schedule=configs.train_schedule, 75 | train_n_timestep=configs.train_n_timestep, 76 | train_linear_start=configs.train_linear_start, 77 | train_linear_end=configs.train_linear_end, 78 | val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep, 79 | val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end, 80 | finetune_norm=configs.finetune_norm, optimizer=optimizer, amsgrad=configs.amsgrad, 81 | learning_rate=configs.lr, checkpoint=configs.checkpoint, 82 | resume_state=configs.resume_state,phase=configs.phase, height=configs.height) 83 | logger.info("Model initialization is finished.") 84 | 85 | current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch 86 | if configs.resume_state: 87 | logger.info(f"Resuming training from epoch: {current_epoch}, iter: {current_step}.") 88 | 89 | logger.info("Starting the training.") 90 | diffusion.register_schedule(beta_schedule=configs.train_schedule, timesteps=configs.train_n_timestep, 91 | linear_start=configs.train_linear_start, linear_end=configs.train_linear_end) 92 | 93 | accumulated_statistics = OrderedDict() 94 | 95 | val_metrics_dict={"MSE": 0.0, "MAE": 0.0,"MAE_inter":0.0} 96 | val_metrics_dict["PSNR_"+variable_name]=0.0 97 | val_metrics_dict["PSNR_inter_"+variable_name]=0.0 98 | val_metrics_dict["RMSE_"+variable_name]=0.0 99 | val_metrics_dict["RMSE_inter_"+variable_name]=0.0 100 | 101 | 102 | val_metrics = OrderedDict(val_metrics_dict) 103 | 104 | # Training. 105 | while current_step < configs.n_iter: 106 | current_epoch += 1 107 | 108 | for train_data in train_loader: 109 | current_step += 1 110 | 111 | if current_step > configs.n_iter: 112 | break 113 | 114 | # Training. 115 | diffusion.feed_data(train_data) 116 | diffusion.optimize_parameters() 117 | diffusion.lr_scheduler_step() # For lr scheduler updates per iteration. 118 | accumulate_statistics(diffusion.get_current_log(), accumulated_statistics) 119 | 120 | # Logging the training information. 121 | if current_step % configs.print_freq == 0: 122 | message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}" 123 | 124 | for metric, values in accumulated_statistics.items(): 125 | mean_value = np.mean(values) 126 | message = f"{message} | {metric:s}: {mean_value:.5f}" 127 | tb_logger.add_scalar(f"{metric}/train", mean_value, current_step) 128 | 129 | logger.info(message) 130 | # tb_logger.add_scalar(f"learning_rate", diffusion.get_lr(), current_step) 131 | 132 | # Visualizing distributions of parameters. 133 | # for name, param in diffusion.get_named_parameters(): 134 | # tb_logger.add_histogram(name, param.clone().cpu().data.numpy(), current_step) 135 | 136 | accumulated_statistics = OrderedDict() 137 | # Validation. 
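            # [Editor's sketch, not part of the original file] utils.psnr is used for the
            # validation metrics below but is not shown in this dump. For data normalized
            # to [0, 1] the conventional definition is PSNR = 10 * log10(MAX^2 / MSE);
            # the repo's version may use a different peak value. A minimal sketch:
            def _psnr_sketch(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
                mse = torch.mean((x - y) ** 2)           # mean squared error over all pixels
                return 10.0 * torch.log10(max_val ** 2 / mse)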
138 | if current_step % configs.val_freq == 0: 139 | logger.info("Starting validation.") 140 | idx = 0 141 | result_path = f"{configs.results}/{current_epoch}" 142 | os.makedirs(result_path, exist_ok=True) 143 | diffusion.register_schedule(beta_schedule=configs.val_schedule, 144 | timesteps=configs.val_n_timestep, 145 | linear_start=configs.val_linear_start, 146 | linear_end=configs.val_linear_end) 147 | 148 | # A dictionary for storing a list of mean temperatures for each month. 149 | # month2mean_temperature = defaultdict(list) 150 | 151 | for val_data in val_loader: 152 | idx += 1 153 | diffusion.feed_data(val_data) 154 | # Experiment 1 used 250 DDIM steps; experiment 2 used 50. 155 | diffusion.test(continuous=False,use_ddim=True,ddim_steps=250,use_dpm_solver=False) # continuous=False returns only the final timestep's outcome. 156 | 157 | # Computing metrics on validation data. 158 | visuals = diffusion.get_current_visuals() 159 | # Computing MSE and RMSE on original data. 160 | mask=val_data["mask"] 161 | mse_value = mse_loss(visuals["HR"]*mask, visuals["SR"]*mask) 162 | val_metrics["MSE"] += mse_value 163 | val_metrics["MAE"] += l1_loss(visuals["HR"]*mask, visuals["SR"]*mask) 164 | val_metrics["MAE_inter"] += l1_loss(visuals["HR"]*mask, visuals["INTERPOLATED"]*mask) 165 | 166 | val_metrics["RMSE_"+variable_name] += torch.sqrt(mse_loss(visuals["HR"]*mask, visuals["SR"]*mask)) 167 | val_metrics["RMSE_inter_"+variable_name] += torch.sqrt(mse_loss(visuals["HR"]*mask, visuals["INTERPOLATED"]*mask)) 168 | val_metrics["PSNR_"+variable_name] += psnr(visuals["HR"]*mask, visuals["SR"]*mask) 169 | val_metrics["PSNR_inter_"+variable_name] += psnr(visuals["HR"]*mask, visuals["INTERPOLATED"]*mask) 170 | 171 | 172 | if idx % configs.val_vis_freq == 0: 173 | 174 | logger.info(f"[{idx//configs.val_vis_freq}] Visualizing and storing some examples.") 175 | 176 | sr_candidates = diffusion.generate_multiple_candidates(n=configs.sample_size,ddim_steps=100,use_dpm_solver=False) 177 | 178 | mean_candidate = sr_candidates.mean(dim=0) # [B, C, H, W] 179 | std_candidate = sr_candidates.std(dim=0) # [B, C, H, W] 180 | bias = mean_candidate - visuals["HR"] 181 | mean_bias_over_pixels = bias.mean() # Scalar. 182 | std_bias_over_pixels = bias.std() # Scalar. 183 | 184 | 185 | 186 | # # Choosing the first n_val_vis number of samples to visualize. 187 | # variable_id=0 188 | random_idx=np.random.randint(0,configs.batch_size // 12,5) 189 | 190 | path = f"{result_path}/{current_epoch}_{current_step}_{idx}" 191 | figure,axs=plt.subplots(5,9,figsize=(25,12)) 192 | if variable_name=="tp": 193 | vmin=0 194 | cmap="BrBG" 195 | vmax=2 196 | elif variable_name in ["t2m","sp","u","v"]: 197 | vmin=0 198 | cmap="RdBu_r" 199 | vmax=1 200 | else: 201 | vmin=-2 202 | cmap="RdBu_r" 203 | vmax=2 204 | for idx_i,num in enumerate(random_idx): 205 | axs[idx_i,0].imshow(visuals["HR"][num,0],vmin=vmin,vmax=vmax,cmap=cmap) 206 | axs[idx_i,1].imshow(visuals["SR"][num,0],vmin=vmin,vmax=vmax,cmap=cmap) 207 | axs[idx_i,2].imshow(visuals["INTERPOLATED"][num,0],vmin=vmin,vmax=vmax,cmap=cmap) 208 | 209 | axs[idx_i,3].imshow(mean_candidate[num,0],vmin=vmin,vmax=vmax,cmap=cmap) 210 | axs[idx_i,4].imshow(std_candidate[num,0],vmin=0,vmax=2,cmap='Reds') 211 | axs[idx_i,5].imshow(np.abs(visuals["HR"][num,0]-visuals["SR"][num,0]),vmin=0,vmax=2,cmap="Reds") 212 | axs[idx_i,7].imshow(np.abs(visuals["HR"][num,0]-visuals["INTERPOLATED"][num,0]),vmin=0,vmax=2,cmap="Reds") 213 | axs[idx_i,6].imshow(np.abs(bias)[num,0],vmin=0,vmax=2,cmap="Reds") 214 | axs[idx_i,8].imshow(val_data['mask'][num,0],vmin=0,vmax=2,cmap="RdBu_r") 215 | axs[idx_i,8].set_title("mean_mae:%.3f,inter_mae:%.3f,sr_mae:%.3f"%(np.abs(bias)[num,0].mean(),np.abs(visuals["HR"][num,0]-visuals["INTERPOLATED"][num,0]).mean(),np.abs(visuals["HR"][num,0]-visuals["SR"][num,0]).mean())) 216 | for title, col in zip(["HR","Diffusion","INTERPOLATED","mean","std","mae_sr","mae_mean","mae_inter"],range(8)): 217 | axs[0,col].set_title(title) 218 | plt.savefig(f"{path}_.png", bbox_inches="tight") 219 | plt.close("all") 220 | 221 | val_metrics["MSE"] /= idx 222 | val_metrics["MAE"] /= idx 223 | val_metrics["MAE_inter"] /= idx 224 | 225 | val_metrics["RMSE_"+variable_name] /= idx 226 | val_metrics["RMSE_inter_"+variable_name] /= idx 227 | val_metrics["PSNR_"+variable_name] /= idx 228 | val_metrics["PSNR_inter_"+variable_name] /= idx 229 | diffusion.register_schedule(beta_schedule=configs.train_schedule, 230 | timesteps=configs.train_n_timestep, 231 | linear_start=configs.train_linear_start, 232 | linear_end=configs.train_linear_end) 233 | message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}" 234 | for metric, value in val_metrics.items(): 235 | message = f"{message} | {metric:s}: {value:.5f}" 236 | tb_logger.add_scalar(f"{metric}/val", value, current_step) 237 | 238 | val_logger.info(message) 239 | 240 | val_metrics = val_metrics.fromkeys(val_metrics, 0.0) # Sets all metrics to zero. 241 | 242 | if current_step % configs.save_checkpoint_freq == 0: 243 | logger.info("Saving models and training states.") 244 | diffusion.save_network(current_epoch, current_step) 245 | 246 | tb_logger.close() 247 | 248 | logger.info("End of training.") 249 | 250 | -------------------------------------------------------------------------------- /inference_2_monthly.py: -------------------------------------------------------------------------------- 1 | """The inference script for DDIM model. 
2 | """ 3 | import argparse 4 | import logging 5 | import os 6 | import pickle 7 | import warnings 8 | from collections import OrderedDict, defaultdict 9 | import numpy as np 10 | import torch 11 | from torch.nn.functional import mse_loss, l1_loss 12 | from torch.utils.data import DataLoader 13 | from data.mydataset_patch import SR3_Dataset_all 14 | import model 15 | from configs import Config, get_current_datetime 16 | from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \ 17 | construct_mask, set_seeds,psnr 18 | import xarray as xr 19 | import glob 20 | import matplotlib.pyplot as plt 21 | import matplotlib 22 | matplotlib.use('Agg') 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | def loop_prediction(start_year,end_year,ddim_steps): 33 | max_normal = np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/max_new_10.npy", mmap_mode='r+')[list_idx] 34 | min_normal = np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/min_new_10.npy", mmap_mode='r+')[list_idx] 35 | for year in range(start_year,end_year): 36 | all_data=[] 37 | member_data=[] 38 | idx=0 39 | batch=12 40 | data_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/test_dataset/sl/*{0}_monthly*npy".format(year))) 41 | information=xr.open_dataset("/home/data/downscaling/downscaling_1023/data/ERA_deal/different_grid/10km/ERA5_land_10km_East_china_{0}_monthly.nc".format(year)) 42 | val_logger.info(f"Dataset- Testing] is created=========year: "+str(year)) 43 | val_dataset = SR3_Dataset_all(land_01_path,mask_path,lr_paths=data_paths,var=variable_name) 44 | val_loader = DataLoader(val_dataset, batch_size=batch,shuffle=False, num_workers=3) 45 | idx=0 46 | with torch.no_grad(): 47 | for val_data in val_loader: 48 | if idx % 5==0: 49 | print(idx*batch) 50 | idx = idx+1 51 | diffusion.feed_data(val_data) 52 | diffusion.infer_patch(continuous=False,use_ddim=True,use_dpm_solver=False,ddim_steps=ddim_steps)#infer_patch这个是平均,v2是不带平均,两个需要实验看看 53 | visuals = diffusion.get_current_visuals(only_rec=True) 54 | pred_norm=visuals["SR"] 55 | all_data.append(pred_norm)# 56 | if need_member: 57 | sr_candidates = diffusion.infer_generate_multiple_candidates(n=sample_size,use_ddim=True,use_dpm_solver=False,ddim_steps=ddim_steps) 58 | # mem_candidate = sr_candidates* torch.from_numpy(std_hr).float() +torch.from_numpy(mean_hr).float() # [n,B, C, H, W] 59 | member_data.append(sr_candidates) 60 | 61 | if variable_name == "tp": 62 | new_data=torch.clamp(torch.cat(all_data,dim=0),0,5).numpy() 63 | new_data=np.exp(new_data[:,0,:,:])-1 64 | new_data[new_data<0]=0 65 | else: 66 | new_data=torch.clamp(torch.cat(all_data,dim=0),0,1).numpy() 67 | new_data=new_data[:,0,:,:]*(max_normal-min_normal)+min_normal 68 | new_data=new_data*std_hr[list_idx]+mean_hr[list_idx] 69 | 70 | if need_member: 71 | if variable_name == "tp": 72 | new_member_data=torch.clamp(torch.cat(member_data,dim=1),0,5).numpy() 73 | new_member_data=np.exp(new_member_data[:,:,0,:,:])-1 74 | new_member_data[new_member_data<0]=0 75 | else: 76 | new_member_data=torch.clamp(torch.cat(member_data,dim=1),0,1).numpy() 77 | new_member_data=new_member_data[:,:,0,:,:]*(max_normal-min_normal)+min_normal 78 | new_member_data=new_member_data*std_hr[list_idx]+mean_hr[list_idx] 79 | 80 | 81 | 82 | 83 | dataset_new=xr.Dataset({ 84 | variable_name:(["time", "latitude", "longitude"],new_data[:,:,:]) 85 | }, 86 | coords={ 87 | "time":information.time, 88 | "latitude":information.latitude, 89 | 
"longitude":information.longitude} 90 | ) 91 | dataset_new.to_netcdf(result_path+"/"+"single_results/predict_{0}_.nc".format(year)) 92 | # np.save(result_path+"/"+f"single_results/predict_{year}_{locs}_.npy",new_data) 93 | print(new_member_data.shape) 94 | if need_member: 95 | dataset_new=xr.Dataset({ 96 | variable_name:(["member","time", "latitude", "longitude"],new_member_data[:,:,:,:]) ,#if tp need exp 97 | # "v10":(["time", "latitude", "longitude"], new_data[:,1,:,:]), 98 | # "t2m":(["time", "latitude", "longitude"], new_data[:,2,:,:]), 99 | # "sp":(["time", "latitude", "longitude"], new_data[:,3,:,:]), 100 | # "tp":(["time", "latitude", "longitude"], np.exp(new_data[:,4,:,:])-1), 101 | }, 102 | coords={"member":np.arange(sample_size), 103 | "time":information.time, 104 | "latitude":information.latitude, 105 | "longitude":information.longitude} 106 | ) 107 | dataset_new.to_netcdf(result_path+"/"+"multi_member/predict_{0}_.nc".format(year)) 108 | # np.save(result_path+"/"+f"multi_member/predict_{year}_{locs}_.npy",new_member_data) 109 | val_logger.info(f"{year} member data is finished.") 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | set_seeds() 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("-c", "--config", type=str, help="JSON file for configuration") 117 | parser.add_argument("-p", "--phase", type=str, choices=["train", "val"], 118 | help="Run either training or validation(inference).", default="train") 119 | parser.add_argument("-gpu", "--gpu_ids", type=str, default=None) 120 | parser.add_argument("-var", "--variable_name", type=str, default=None) 121 | parser.add_argument("-member", "--member", type=str, default=None) 122 | 123 | step_s=25 124 | need_member=True 125 | inference_version="v1" 126 | args = parser.parse_args() 127 | configs = Config(args) 128 | variable_name=args.variable_name 129 | sample_size=int(args.member) 130 | variable_list={"u":0,"v":1,"sp":3,"t2m":2,"tp":4} 131 | list_idx=variable_list[variable_name] 132 | if variable_name == "tp": 133 | mean_hr=0 134 | std_hr=1 135 | else: 136 | mean_hr=np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/mean&std/hr_mean.npy").transpose(2,0,1) 137 | 138 | std_hr=np.load("/home/data/downscaling/downscaling_1023/data/train_dataset/mean&std/hr_std.npy").transpose(2,0,1) 139 | torch.backends.cudnn.enabled = True 140 | torch.backends.cudnn.benchmark = True 141 | 142 | # index_list=got_index_list(patch_size=128) 143 | 144 | test_root = f"{configs.experiments_root}/test_{sample_size}member_{step_s}_{get_current_datetime()}" 145 | os.makedirs(test_root, exist_ok=True) 146 | setup_logger("test", test_root, "test", screen=True) 147 | val_logger = logging.getLogger("test") 148 | val_logger.info(dict2str(configs.get_hyperparameters_as_dict())) 149 | land_01_path="/home/data/downscaling/downscaling_1023/data/land10.npy" 150 | mask_path="/home/data/downscaling/downscaling_1023/data/mask10.npy" 151 | 152 | diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel, 153 | norm_groups=configs.norm_groups, inner_channel=configs.inner_channel, 154 | channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res, 155 | res_blocks=configs.res_blocks, dropout=configs.dropout, 156 | diffusion_loss=configs.diffusion_loss, conditional=configs.conditional, 157 | gpu_ids=configs.gpu_ids, distributed=configs.distributed, 158 | init_method=configs.init_method, train_schedule=configs.train_schedule, 159 | train_n_timestep=configs.train_n_timestep, 160 | 
train_linear_start=configs.train_linear_start, 161 | train_linear_end=configs.train_linear_end, 162 | val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep, 163 | val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end, 164 | finetune_norm=configs.finetune_norm, optimizer=None, amsgrad=configs.amsgrad, 165 | learning_rate=configs.lr, checkpoint=configs.checkpoint, 166 | resume_state=configs.resume_state,phase=configs.phase, height=configs.height) 167 | result_path = f"{test_root}/results" 168 | os.makedirs(result_path+"/"+"single_results", exist_ok=True) 169 | os.makedirs(result_path+"/"+"multi_member", exist_ok=True) 170 | val_logger.info("Model initialization is finished.") 171 | 172 | 173 | val_logger.info("Testing dataset is ready.") 174 | current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch 175 | val_logger.info(f"Testing the model at epoch: {current_epoch}, iter: {current_step}.") 176 | 177 | diffusion.register_schedule(beta_schedule=configs.test_schedule, 178 | timesteps=configs.test_n_timestep, 179 | linear_start=configs.test_linear_start, 180 | linear_end=configs.test_linear_end) 181 | 182 | loop_prediction(2016,2022,step_s) 183 | 184 | val_logger.info("End of testing.") 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for creating end-to-end network for 2 | Single Image Super-Resolution task with DDPM. 3 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 4 | """ 5 | import logging 6 | 7 | from .model import DDPM 8 | 9 | logger = logging.getLogger(name="base") 10 | 11 | 12 | def create_model(in_channel, out_channel, norm_groups, inner_channel, 13 | channel_multiplier, attn_res, res_blocks, dropout, 14 | diffusion_loss, conditional, gpu_ids, distributed, init_method, 15 | train_schedule, train_n_timestep, train_linear_start, train_linear_end, 16 | val_schedule, val_n_timestep, val_linear_start, val_linear_end, 17 | finetune_norm, optimizer, amsgrad, learning_rate, checkpoint, resume_state, 18 | phase, height): 19 | """Creates DDPM model. 20 | Args: 21 | in_channel: The number of channels of input tensor of U-Net. 22 | out_channel: The number of channels of output tensor of U-Net. 23 | norm_groups: The number of groups for group normalization. 24 | inner_channel: Timestep embedding dimension. 25 | channel_multiplier: A tuple specifying the scaling factors of channels. 26 | attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer. 27 | res_blocks: The number of residual blocks. 28 | dropout: Dropout probability. 29 | diffusion_loss: Either l1 or l2. 30 | conditional: Whether to condition on INTERPOLATED image or not. 31 | gpu_ids: IDs of gpus. 32 | distributed: Whether the computation will be distributed among multiple GPUs or not. 33 | init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations. 34 | train_schedule: Defines the type of beta schedule for training. 35 | train_n_timestep: Number of diffusion timesteps for training. 36 | train_linear_start: Minimum value of the linear schedule for training. 37 | train_linear_end: Maximum value of the linear schedule for training. 38 | val_schedule: Defines the type of beta schedule for validation. 39 | val_n_timestep: Number of diffusion timesteps for validation. 
40 | val_linear_start: Minimum value of the linear schedule for validation. 41 | val_linear_end: Maximum value of the linear schedule for validation. 42 | finetune_norm: Whether to fine-tune or train from scratch. 43 | optimizer: The optimization algorithm. 44 | amsgrad: Whether to use the AMSGrad variant of optimizer. 45 | learning_rate: The learning rate. 46 | checkpoint: Path to the checkpoint file. 47 | resume_state: The path to load the network. 48 | phase: Either train or val. 49 | height: U-Net input tensor height value. 50 | Returns: 51 | Returns DDPM model. 52 | """ 53 | diffusion_model = DDPM(in_channel=in_channel, out_channel=out_channel, norm_groups=norm_groups, 54 | inner_channel=inner_channel, channel_multiplier=channel_multiplier, 55 | attn_res=attn_res, res_blocks=res_blocks, dropout=dropout, 56 | diffusion_loss=diffusion_loss, conditional=conditional, 57 | gpu_ids=gpu_ids, distributed=distributed, init_method=init_method, 58 | train_schedule=train_schedule, train_n_timestep=train_n_timestep, 59 | train_linear_start=train_linear_start, train_linear_end=train_linear_end, 60 | val_schedule=val_schedule, val_n_timestep=val_n_timestep, 61 | val_linear_start=val_linear_start, val_linear_end=val_linear_end, 62 | finetune_norm=finetune_norm, optimizer=optimizer, amsgrad=amsgrad, 63 | learning_rate=learning_rate, checkpoint=checkpoint, 64 | resume_state=resume_state, phase=phase, height=height) 65 | logger.info("Model [{:s}] is created.".format(diffusion_model.__class__.__name__)) 66 | return diffusion_model -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/__pycache__/base_model.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | """Defines a base class for DDPM model. 
2 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 3 | """ 4 | import typing 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class BaseModel: 11 | """A skeleton for DDPM models. 12 | Attributes: 13 | gpu_ids: IDs of gpus. 14 | """ 15 | 16 | def __init__(self, gpu_ids): 17 | self.gpu_ids = gpu_ids 18 | self.device = torch.device("cuda" if torch.cuda.is_available() and self.gpu_ids else "cpu") 19 | self.begin_step, self.begin_epoch = 0, 0 20 | 21 | def feed_data(self, data) -> None: 22 | """Provides model with data. 23 | Args: 24 | data: A batch of data. 25 | """ 26 | pass 27 | 28 | def optimize_parameters(self) -> None: 29 | """Computes loss and performs GD step on learnable parameters. 30 | """ 31 | pass 32 | 33 | def get_current_visuals(self) -> dict: 34 | """Returns reconstructed data points. 35 | """ 36 | pass 37 | 38 | def print_network(self) -> None: 39 | """Prints the network architecture. 40 | """ 41 | pass 42 | 43 | def set_device(self, x): 44 | """Sets values of x onto device specified by an attribute of the same name. 45 | Args: 46 | x: Value storage. 47 | Returns: 48 | x set on self.device. 49 | """ 50 | if isinstance(x, dict): 51 | x = {key: (item.to(self.device) if item.numel() else item) for key, item in x.items()} 52 | elif isinstance(x, list): 53 | x = [item.to(self.device) if item else item for item in x] 54 | else: 55 | x = x.to(self.device) 56 | return x 57 | 58 | @staticmethod 59 | def get_network_description(network: nn.Module) -> typing.Tuple[str, int]: 60 | """Get the network name and parameters. 61 | Args: 62 | network: The neural network. 63 | Returns: 64 | Name of the network and the number of parameters. 65 | """ 66 | if isinstance(network, nn.DataParallel): 67 | network = network.module 68 | n_params = sum(map(lambda x: x.numel(), network.parameters())) 69 | return str(network), n_params -------------------------------------------------------------------------------- /model/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/__init__.py -------------------------------------------------------------------------------- /model/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/__pycache__/ddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/__pycache__/ddpm.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from functools import partial 5 | from einops import repeat 6 | 7 | 8 | def noise_like(shape, device, repeat=False): 9 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 10 | noise = lambda: torch.randn(shape, device=device) 11 | return repeat_noise() if repeat else noise() 12 | 13 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 14 | # select alphas for computing the variance schedule 15 | alphas = alphacums[ddim_timesteps] 16 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 17 | 18 | # according to the formula provided in https://arxiv.org/abs/2010.02502 19 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 20 | if verbose: 21 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 22 | print(f'For the chosen value of eta, which is {eta}, ' 23 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 24 | return sigmas, alphas, alphas_prev 25 | 26 | 27 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 28 | if ddim_discr_method == 'uniform': 29 | c = num_ddpm_timesteps // num_ddim_timesteps 30 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 31 | elif ddim_discr_method == 'quad': 32 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 33 | else: 34 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 35 | 36 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 37 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 38 | steps_out = ddim_timesteps + 1 39 | if verbose: 40 | print(f'Selected timesteps for ddim sampler: {steps_out}') 41 | return steps_out 42 | 43 | 44 | class DDIMSampler(object): 45 | def __init__(self, model, schedule="linear", **kwargs): 46 | super().__init__() 47 | self.model = model 48 | self.ddpm_num_timesteps = model.num_timesteps 49 | self.schedule = schedule 50 | 51 | def register_buffer(self, name, attr): 52 | if type(attr) == torch.Tensor: 53 | if attr.device != torch.device("cuda"): 54 | attr = attr.to(torch.device("cuda")) 55 | setattr(self, name, attr) 56 | 57 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 58 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 59 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 60 | alphas_cumprod = self.model.alphas_cumprod 61 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 62 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 63 | 64 | self.register_buffer('betas', to_torch(self.model.betas)) 65 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 66 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 67 | 68 | # calculations 
for diffusion q(x_t | x_{t-1}) and others 69 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 70 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 71 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 72 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 73 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 74 | 75 | # ddim sampling parameters 76 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 77 | ddim_timesteps=self.ddim_timesteps, 78 | eta=ddim_eta,verbose=verbose) 79 | self.register_buffer('ddim_sigmas', ddim_sigmas) 80 | self.register_buffer('ddim_alphas', ddim_alphas) 81 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 82 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 83 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 84 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 85 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 86 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 87 | 88 | @torch.no_grad() 89 | def sample(self, 90 | S, 91 | batch_size, 92 | shape, 93 | conditioning=None, 94 | callback=None, 95 | return_intermediates=False, 96 | normals_sequence=None, 97 | img_callback=None, 98 | quantize_x0=False, 99 | eta=0., 100 | mask=None, 101 | x0=None, 102 | temperature=1., 103 | noise_dropout=0., 104 | score_corrector=None, 105 | corrector_kwargs=None, 106 | verbose=True, 107 | x_T=None, 108 | log_every_t=10, 109 | unconditional_guidance_scale=1., 110 | unconditional_conditioning=None, 111 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 112 | **kwargs 113 | ): 114 | if conditioning is not None: 115 | if isinstance(conditioning, dict): 116 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 117 | if cbs != batch_size: 118 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 119 | else: 120 | if conditioning.shape[0] != batch_size: 121 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 122 | 123 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 124 | # sampling 125 | C, H, W = shape 126 | size = (batch_size, C, H, W) 127 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 128 | 129 | samples, intermediates = self.ddim_sampling(conditioning, size, 130 | callback=callback, 131 | img_callback=img_callback, 132 | quantize_denoised=quantize_x0, 133 | mask=mask, x0=x0, 134 | ddim_use_original_steps=False, 135 | noise_dropout=noise_dropout, 136 | temperature=temperature, 137 | score_corrector=score_corrector, 138 | corrector_kwargs=corrector_kwargs, 139 | x_T=x_T, 140 | log_every_t=log_every_t, 141 | unconditional_guidance_scale=unconditional_guidance_scale, 142 | unconditional_conditioning=unconditional_conditioning, 143 | ) 144 | # samples.clamp_(-1., 1.) 
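        # [Editor's note] Worked example of the timestep subset selected by make_schedule
        # above: with the repo's n_timestep = 1000 and S = 250 DDIM steps,
        # make_ddim_timesteps("uniform", ...) strides by c = 1000 // 250 = 4 and then
        # adds 1, giving [1, 5, 9, ..., 997]; ddim_sampling therefore visits only these
        # 250 of the 1000 trained DDPM timesteps.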
145 | if return_intermediates: 146 | return samples, intermediates 147 | else: 148 | return samples 149 | 150 | @torch.no_grad() 151 | def ddim_sampling(self, cond, shape, 152 | x_T=None, ddim_use_original_steps=False, 153 | callback=None, timesteps=None, quantize_denoised=False, 154 | mask=None, x0=None, img_callback=None, log_every_t=10, 155 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 156 | unconditional_guidance_scale=1., unconditional_conditioning=None,mycondition=True): 157 | device = self.model.betas.device 158 | b = shape[0] 159 | # if x_T is None: 160 | # img = torch.randn(shape, device=device) 161 | # else: 162 | # img = x_T 163 | img = torch.randn(shape, device=device) 164 | if timesteps is None: 165 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 166 | elif timesteps is not None and not ddim_use_original_steps: 167 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 168 | timesteps = self.ddim_timesteps[:subset_end] 169 | 170 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 171 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 172 | 173 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 174 | # print(f"Running DDIM Sampling with {total_steps} timesteps") 175 | # print(total_steps) 176 | # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 177 | 178 | for i, step in enumerate(time_range):#enumerate(iterator): 179 | index = total_steps - i - 1 180 | ts = torch.full((b,), step, device=device, dtype=torch.long) 181 | if mask is not None: 182 | assert x0 is not None 183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 184 | img = img_orig * mask + (1. - mask) * img 185 | # print("img",img.max(),img.min()) 186 | outs = self.p_sample_ddim(img, x_T, ts, index=index, use_original_steps=ddim_use_original_steps, 187 | quantize_denoised=quantize_denoised, temperature=temperature, 188 | noise_dropout=noise_dropout, score_corrector=score_corrector, 189 | corrector_kwargs=corrector_kwargs, 190 | unconditional_guidance_scale=unconditional_guidance_scale, 191 | unconditional_conditioning=unconditional_conditioning,mycondition=mycondition) 192 | img, pred_x0 = outs 193 | # img.clamp_(-1., 1.) 
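            # [Editor's note] Bookkeeping in this loop: time_range walks the DDIM steps
            # from largest to smallest, while `index` walks the alpha tables backwards.
            # E.g. with 250 uniform steps, i = 0 gives step = 997 and index = 249, and
            # the final iteration (i = 249) gives step = 1, index = 0.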
194 | if callback: callback(i) 195 | if img_callback: img_callback(pred_x0, i) 196 | 197 | if index % log_every_t == 0 or index == total_steps - 1: 198 | intermediates['x_inter'].append(img) 199 | intermediates['pred_x0'].append(pred_x0) 200 | 201 | return img, intermediates 202 | 203 | @torch.no_grad() 204 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 205 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 206 | unconditional_guidance_scale=1., unconditional_conditioning=None,mycondition=True): 207 | b, *_, device = *x.shape, x.device 208 | 209 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 210 | # e_t = self.model.apply_model(x, t, c) 211 | # img = torch.randn(x.shape, device=device) 212 | if mycondition: 213 | e_t = self.model.model(torch.cat([c,x],dim=1), t) 214 | else: 215 | e_t = self.model.model(x, t) 216 | # print(e_t.shape) 217 | else: 218 | x_in = torch.cat([x] * 2) 219 | t_in = torch.cat([t] * 2) 220 | c_in = torch.cat([unconditional_conditioning, c]) 221 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 222 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 223 | 224 | if score_corrector is not None: 225 | assert self.model.parameterization == "eps" 226 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 227 | 228 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 229 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 230 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 231 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 232 | # select parameters corresponding to the currently considered timestep 233 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 234 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 235 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 236 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 237 | 238 | # current prediction for x_0 239 | # x.clamp_(-1., 1.) 240 | 241 | # print("x",x.max(),x.min()) 242 | # print("e_t",e_t.max(),e_t.min()) 243 | # print("a_t",a_t.max(),a_t.min()) 244 | # print("sqrt_one_minus_at",sqrt_one_minus_at.max(),sqrt_one_minus_at.min()) 245 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 246 | # print("pred_x0",pred_x0.max(),pred_x0.min()) 247 | if quantize_denoised: 248 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 249 | # direction pointing to x_t 250 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 251 | # dir_xt.clamp_(-1., 1.) 
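        # Added note: the DDIM update (Song et al., 2021) assembled below is
        #     x_prev = sqrt(a_prev) * pred_x0 + dir_xt + sigma_t * z,
        # where dir_xt = sqrt(1 - a_prev - sigma_t^2) * e_t points towards x_t and
        # z is fresh Gaussian noise. With eta = 0 every sigma_t vanishes and the
        # update is fully deterministic; eta = 1 recovers DDPM-like ancestral sampling.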
252 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 253 | if noise_dropout > 0.: 254 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 255 | 256 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 257 | # print("x_prev",x_prev.abs().max(),x_prev.min()) 258 | # scale=torch.abs(x_prev).view(x_prev.shape[0],x_prev.shape[1],-1) 259 | # x_prev=x_prev/scale.max(dim=-1)[0].view(x_prev.shape[0],x_prev.shape[1],1,1) 260 | return x_prev, pred_x0 261 | -------------------------------------------------------------------------------- /model/diffusion/ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.optim.lr_scheduler import LambdaLR 5 | from einops import rearrange, repeat 6 | from functools import partial 7 | from tqdm import tqdm 8 | from model.diffusion.ddim import DDIMSampler 9 | from model.diffusion.dpm_solver import DPMSolverSampler 10 | import math 11 | # from contextlib import contextmanager 12 | 13 | def exists(x): 14 | return x is not None 15 | 16 | def extract_into_tensor(a, t, x_shape): 17 | b, *_ = t.shape 18 | out = a.gather(-1, t) 19 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 20 | 21 | def noise_like(shape, device, repeat=False): 22 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 23 | noise = lambda: torch.randn(shape, device=device) 24 | return repeat_noise() if repeat else noise() 25 | 26 | #warmup_trick 27 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 28 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 29 | warmup_time = int(n_timestep * warmup_frac) 30 | betas[:warmup_time] = np.linspace( 31 | linear_start, linear_end, warmup_time, dtype=np.float64) 32 | return betas 33 | 34 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 35 | if schedule == 'quad': 36 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 37 | n_timestep, dtype=np.float64) ** 2 38 | elif schedule == 'linear': 39 | betas = np.linspace(linear_start, linear_end, 40 | n_timestep, dtype=np.float64) 41 | elif schedule == 'warmup10': 42 | betas = _warmup_beta(linear_start, linear_end, 43 | n_timestep, 0.1) 44 | elif schedule == 'warmup50': 45 | betas = _warmup_beta(linear_start, linear_end, 46 | n_timestep, 0.5) 47 | elif schedule == 'const': 48 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 49 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 50 | betas = 1. 
/ np.linspace(n_timestep,
51 |                            1, n_timestep, dtype=np.float64)
52 |     elif schedule == "cosine":
53 |         timesteps = (
54 |             torch.arange(n_timestep + 1, dtype=torch.float64) /
55 |             n_timestep + cosine_s
56 |         )
57 |         alphas = timesteps / (1 + cosine_s) * math.pi / 2
58 |         alphas = torch.cos(alphas).pow(2)
59 |         alphas = alphas / alphas[0]
60 |         betas = 1 - alphas[1:] / alphas[:-1]
61 |         betas = betas.clamp(max=0.999)
62 |     else:
63 |         raise NotImplementedError(schedule)
64 |     return betas
65 | 
66 | def disabled_train(self, mode=True):
67 |     """Overwrite model.train with this function to make sure train/eval mode
68 |     does not change anymore."""
69 |     return self
70 | 
71 | 
72 | def uniform_on_device(r1, r2, shape, device):
73 |     return (r1 - r2) * torch.rand(*shape, device=device) + r2
74 | 
75 | 
76 | class DDPM(nn.Module):
77 |     # classic DDPM with Gaussian diffusion, in image space
78 |     def __init__(self,
79 |                  denoise_net: nn.Module,
80 |                  timesteps=1000,
81 |                  beta_schedule="linear",
82 |                  loss_type="l2",
83 |                  ckpt_path=None,
84 |                  gpu_ids=[0],
85 |                  # ignore_keys=[],
86 |                  # load_only_unet=False,
87 |                  # monitor="val/loss",
88 |                  use_ema=False,
89 |                  log_every_t=100,
90 |                  clip_denoised=True,
91 |                  linear_start=1e-4,
92 |                  linear_end=2e-2,
93 |                  cosine_s=8e-3,
94 |                  given_betas=None,
95 |                  original_elbo_weight=0.,
96 |                  v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
97 |                  l_simple_weight=1.,
98 |                  conditioning_key=None,
99 |                  parameterization="eps",  # all assuming fixed variance schedules
100 |                  scheduler_config=None,
101 |                  use_positional_encodings=False,
102 |                  learn_logvar=False,
103 |                  logvar_init=0.,
104 |                  conditional=True
105 |                  ):
106 |         super().__init__()
107 |         assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
108 |         self.parameterization = parameterization
109 |         print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
110 |         self.cond_stage_model = None
111 |         self.clip_denoised = clip_denoised
112 |         self.log_every_t = log_every_t
113 |         self.num_timesteps = None
114 |         self.use_positional_encodings = use_positional_encodings
115 |         self.model = denoise_net
116 |         self.conditional = conditional
117 |         self.gpu_ids = gpu_ids
118 |         self.device = torch.device("cuda" if torch.cuda.is_available() and self.gpu_ids else "cpu")
119 |         # count_params(self.model, verbose=True)
120 |         self.use_ema = use_ema  # defaults to False
121 |         # if self.use_ema:  # TODO: revisit whether EMA should be enabled here.
122 |         #     self.model_ema = LitEma(self.model)
123 |         #     print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
124 | 
125 |         self.use_scheduler = scheduler_config is not None
126 |         if self.use_scheduler:
127 |             self.scheduler_config = scheduler_config
128 | 
129 |         self.v_posterior = v_posterior
130 |         self.original_elbo_weight = original_elbo_weight
131 |         self.l_simple_weight = l_simple_weight
132 | 
133 |         # if monitor is not None:
134 |         #     self.monitor = monitor
135 |         # if ckpt_path is not None:
136 |         #     self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
137 | 
138 |         self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
139 |                                linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
140 | 
141 |         self.loss_type = loss_type
142 |         # Is a learnable variance best left disabled (False)?
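        # Added note: `logvar` below is a per-timestep log-variance vector; it only
        # becomes a trainable nn.Parameter when learn_logvar=True, otherwise it
        # stays a fixed constant tensor.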
143 |         self.learn_logvar = learn_logvar
144 |         self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
145 |         if self.learn_logvar:
146 |             self.logvar = nn.Parameter(self.logvar, requires_grad=True)
147 | 
148 | 
149 |     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
150 |                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, device=None):
151 |         if exists(given_betas):
152 |             betas = given_betas
153 |         else:
154 |             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
155 |                                        cosine_s=cosine_s)
156 |         betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
157 |         alphas = 1. - betas
158 |         alphas_cumprod = np.cumprod(alphas, axis=0)
159 |         alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
160 | 
161 |         timesteps, = betas.shape
162 |         self.num_timesteps = int(timesteps)
163 |         self.linear_start = linear_start
164 |         self.linear_end = linear_end
165 |         assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
166 | 
167 |         to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
168 | 
169 |         self.register_buffer('betas', to_torch(betas))
170 |         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
171 |         self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
172 | 
173 |         # calculations for diffusion q(x_t | x_{t-1}) and others
174 |         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
175 |         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
176 |         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
177 |         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
178 |         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
179 | 
180 |         # calculations for posterior q(x_{t-1} | x_t, x_0)
181 |         posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
182 |                 1. - alphas_cumprod) + self.v_posterior * betas
183 |         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
184 |         self.register_buffer('posterior_variance', to_torch(posterior_variance))
185 |         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
186 |         self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
187 |         self.register_buffer('posterior_mean_coef1', to_torch(
188 |             betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
189 |         self.register_buffer('posterior_mean_coef2', to_torch(
190 |             (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
191 |         # 'eps' is the original DDPM parameterization and needs to be kept.
192 |         if self.parameterization == "eps":
193 |             lvlb_weights = self.betas ** 2 / (
194 |                     2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
195 |         elif self.parameterization == "x0":
196 |             lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.
* (1. - torch.Tensor(alphas_cumprod)))
197 |         else:
198 |             raise NotImplementedError("mu not supported")
199 |         # TODO how to choose this term
200 |         lvlb_weights[0] = lvlb_weights[1]
201 |         self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
202 |         assert not torch.isnan(self.lvlb_weights).all()
203 | 
204 |     def set_loss(self, device: torch.device):
205 |         '''Set the loss function.
206 |         '''
207 |         if self.loss_type == "l1":
208 |             self.loss_func = nn.L1Loss(reduction="sum").to(device)
209 |         elif self.loss_type == "l2":
210 |             self.loss_func = nn.MSELoss(reduction="sum").to(device)
211 |         else:
212 |             raise NotImplementedError("Specify loss_type attribute to be either 'l1' or 'l2'.")
213 |     # @contextmanager
214 |     # def ema_scope(self, context=None):
215 |     #     if self.use_ema:
216 |     #         self.model_ema.store(self.model.parameters())
217 |     #         self.model_ema.copy_to(self.model)
218 |     #         if context is not None:
219 |     #             print(f"{context}: Switched to EMA weights")
220 |     #     try:
221 |     #         yield None
222 |     #     finally:
223 |     #         if self.use_ema:
224 |     #             self.model_ema.restore(self.model.parameters())
225 |     #             if context is not None:
226 |     #                 print(f"{context}: Restored training weights")
227 | 
228 |     # def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
229 |     #     sd = torch.load(path, map_location="cpu")
230 |     #     if "state_dict" in list(sd.keys()):
231 |     #         sd = sd["state_dict"]
232 |     #     keys = list(sd.keys())
233 |     #     for k in keys:
234 |     #         for ik in ignore_keys:
235 |     #             if k.startswith(ik):
236 |     #                 print("Deleting key {} from state_dict.".format(k))
237 |     #                 del sd[k]
238 |     #     missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
239 |     #         sd, strict=False)
240 |     #     print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
241 |     #     if len(missing) > 0:
242 |     #         print(f"Missing Keys: {missing}")
243 |     #     if len(unexpected) > 0:
244 |     #         print(f"Unexpected Keys: {unexpected}")
245 | 
246 |     def q_mean_variance(self, x_start, t):
247 |         """
248 |         Get the distribution q(x_t | x_0).
249 |         :param x_start: the [N x C x ...] tensor of noiseless inputs.
250 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
251 |         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
252 |         """
253 |         mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
254 |         variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
255 |         log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
256 |         return mean, variance, log_variance
257 | 
258 |     def predict_start_from_noise(self, x_t, t, noise):
259 |         return (
260 |                 extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
261 |                 extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
262 |         )
263 | 
264 |     def q_posterior(self, x_start, x_t, t):
265 |         posterior_mean = (
266 |                 extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
267 |                 extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
268 |         )
269 |         posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
270 |         posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
271 |         return posterior_mean, posterior_variance, posterior_log_variance_clipped
272 | 
273 |     def p_mean_variance(self, x, t, clip_denoised: bool, condition_x: torch.Tensor = None):
274 |         if condition_x is not None:
275 |             model_out = self.model(torch.cat([condition_x, x], dim=1), t)
276 |         else:
277 |             model_out = self.model(x, t)
278 |         if self.parameterization == "eps":
279 |             x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
280 |         elif self.parameterization == "x0":
281 |             x_recon = model_out
282 |         if clip_denoised:
283 |             x_recon.clamp_(-5., 5.)
284 | 
285 |         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
286 |         return model_mean, posterior_variance, posterior_log_variance
287 | 
288 |     @torch.no_grad()
289 |     def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x: torch.Tensor = None):
290 |         b, *_, device = *x.shape, x.device
291 |         model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
292 |         noise = noise_like(x.shape, device, repeat_noise)
293 |         # no noise when t == 0
294 |         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
295 |         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
296 | 
297 |     @torch.no_grad()
298 |     def p_sample_loop(self, x_in: torch.Tensor, return_intermediates=False):
299 |         device = self.betas.device
300 |         b = x_in.size(0)
301 |         shape = (b, 1, x_in.size(2), x_in.size(3))
302 |         img = torch.randn(shape, device=device)
303 |         intermediates = [img]
304 |         if self.conditional:
305 |             condition_x = x_in
306 |         else:
307 |             condition_x = None
308 |         for i in reversed(range(0, self.num_timesteps)):
309 |             img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
310 |                                 clip_denoised=self.clip_denoised, condition_x=condition_x)
311 |             if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
312 |                 intermediates.append(img)
313 |         if return_intermediates:
314 |             return img, intermediates
315 |         else:
316 |             return img
317 | 
318 |     @torch.no_grad()
319 |     def sample(self, x_in: torch.Tensor, return_intermediates=False, ddim=True, use_dpm_solver=True, ddim_steps=200):  # equivalent to super-resolution inference
320 |         if ddim:
321 |             if use_dpm_solver:
322 |                 sampler = DPMSolverSampler(self)
323 |                 shape = (1, x_in.size(2), x_in.size(3))
324 |                 batch_size = x_in.size(0)
325 |                 return sampler.sample(x_in, ddim_steps, batch_size, shape)
326 |             else:
327 |                 ddim_sampler = DDIMSampler(self)
328 |                 shape = (1, x_in.size(2), x_in.size(3))
329 |                 batch_size = x_in.size(0)
330 |                 return ddim_sampler.sample(ddim_steps, batch_size,
331 |                                            shape, x_T=x_in, verbose=False, return_intermediates=return_intermediates)
332 |         else:
333 |             return self.p_sample_loop(x_in,
334 |                                       return_intermediates=return_intermediates)
335 | 
336 |     def q_sample(self, x_start, t, noise=None):
337 |         if noise is None:
338 |             noise = torch.randn_like(x_start)
339 |         return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
340 |                 extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
341 | 
342 |     def get_loss(self, pred, target, mean=True):
343 |         if self.loss_type == 'l1':
344 |             loss = (target - pred).abs()
345 |             if mean:
346 |                 loss = loss.mean()
347 |         elif self.loss_type == 'l2':
348 |             if mean:
349 |                 loss = torch.nn.functional.mse_loss(target, pred)
350 |             else:
351 |                 loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
352 |         else:
353 |             raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
354 | 
355 |         return loss
356 | 
357 |     def p_losses(self, x_in, t, noise=None):
358 |         x_start = x_in["HR"]  # the ground-truth HR field (the label)
359 |         if noise is None:
360 |             noise = torch.randn_like(x_start)
361 |         x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
362 |         if not self.conditional:
363 |             model_out = self.model(x_noisy, t)
364 |         else:
365 |             model_out = self.model(torch.cat([x_in["INTERPOLATED"], x_noisy], dim=1), t)
366 | 
367 |         loss_dict = {}
368 |         if self.parameterization == "eps":
369 |             target = noise
370 |         elif self.parameterization == "x0":
371 |             target = x_start
372 |         else:
373 |             raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
374 | 
375 |         loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])  # per-sample loss
376 |         # print(loss.shape)
377 |         # log_prefix = 'train' if self.training else 'val'
378 | 
379 |         # loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
380 |         loss_simple = loss * self.l_simple_weight
381 | 
382 |         loss_vlb = self.lvlb_weights[t] * loss
383 |         # loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
384 | 
385 |         loss = loss_simple + self.original_elbo_weight * loss_vlb
386 |         # loss_dict.update({f'{log_prefix}/loss': loss})
387 | 
388 |         return loss  # , loss_dict
389 | 
390 |     def forward(self, x: dict, *args, **kwargs):
391 | 
392 |         # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
393 |         # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
394 |         t = torch.randint(0, self.num_timesteps, (x["INTERPOLATED"].shape[0],), device=self.device).long()
395 |         # print("t:",t.shape)
396 |         return self.p_losses(x, t, *args, **kwargs)
397 | 
398 | 
--------------------------------------------------------------------------------
/model/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/model/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/model/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/model/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc -------------------------------------------------------------------------------- /model/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | img_in, 29 | S, 30 | batch_size, 31 | shape, 32 | conditioning=None, 33 | callback=None, 34 | normals_sequence=None, 35 | img_callback=None, 36 | quantize_x0=False, 37 | eta=0., 38 | mask=None, 39 | x0=None, 40 | temperature=1., 41 | noise_dropout=0., 42 | score_corrector=None, 43 | corrector_kwargs=None, 44 | verbose=True, 45 | x_T=None, 46 | log_every_t=100, 47 | unconditional_guidance_scale=1., 48 | unconditional_conditioning=None, 49 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
50 |                **kwargs
51 |                ):
52 |         if conditioning is not None:
53 |             if isinstance(conditioning, dict):
54 |                 cbs = conditioning[list(conditioning.keys())[0]].shape[0]
55 |                 if cbs != batch_size:
56 |                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
57 |             else:
58 |                 if conditioning.shape[0] != batch_size:
59 |                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
60 | 
61 |         # sampling
62 |         C, H, W = shape
63 |         size = (batch_size, C, H, W)
64 | 
65 |         print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
66 | 
67 |         device = self.model.betas.device
68 |         # if x_T is None:
69 |         #     img = torch.randn(size, device=device)
70 |         # else:
71 |         #     img = x_T
72 |         img = torch.randn(size, device=device)
73 |         ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
74 | 
75 |         model_fn = model_wrapper(
76 |             lambda x, t: self.model.model(x, t),
77 |             ns,
78 |             model_type=MODEL_TYPES[self.model.parameterization],
79 |             guidance_type="classifier-free",
80 |             condition=img_in,
81 |             unconditional_condition=unconditional_conditioning,
82 |             guidance_scale=unconditional_guidance_scale,
83 |         )
84 | 
85 |         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=False, thresholding=False)
86 |         x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="adaptive", order=2, lower_order_final=True)
87 | 
88 |         return x.to(device)
--------------------------------------------------------------------------------
/model/diffusion/unet.py:
--------------------------------------------------------------------------------
1 | """U-Net model for Denoising Diffusion Probabilistic Model.
2 | This implementation contains a number of modifications to
3 | original U-Net (residual blocks, multi-head attention)
4 | and also adds diffusion timestep embeddings t.
5 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement.
6 | https://github.com/davitpapikyan/Probabilistic-Downscaling-of-Climate-Variables/
7 | """
8 | import math
9 | import torch
10 | from torch import nn
11 | 
12 | 
13 | class PositionalEncoding(nn.Module):
14 |     """Sinusoidal Positional Encoding component.
15 |     Attributes:
16 |         dim: Embedding dimension.
17 |     """
18 | 
19 |     def __init__(self, dim):
20 |         super().__init__()
21 |         self.dim = dim
22 | 
23 |     def forward(self, noise_level):
24 |         """Sinusoidal positional encoding. Note that it encodes the sampled diffusion timestep, not a conventional spatial position.
25 |         Args:
26 |             noise_level: An array of size [B, 1] representing the diffusion timesteps.
27 |         Returns:
28 |             Positional encodings of size [B, 1, D].
29 |         """
30 |         half_dim = self.dim // 2
31 |         step = torch.arange(half_dim, dtype=noise_level.dtype, device=noise_level.device) / half_dim
32 |         encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
33 |         encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
34 |         return encoding
35 | 
36 | 
37 | class FeatureWiseAffine(nn.Module):
38 |     """Why an extra affine transformation here? It transforms the timestep
39 |     embedding and injects it into the input tensor.
40 |     Attributes:
41 |         in_channels: Input tensor channels.
42 |         out_channels: Output tensor channels.
43 |         use_affine_level: Whether to apply an affine transformation on input or add a noise.
44 |     """
45 | 
46 |     def __init__(self, in_channels: int, out_channels: int, use_affine_level: bool = False):
47 |         super().__init__()
48 |         self.use_affine_level = use_affine_level
49 |         self.noise_func = nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))  # doubled (*2) when the affine (gamma, beta) form is used
50 | 
51 |     def forward(self, x, time_emb):
52 |         """Forward pass.
53 |         Args:
54 |             x: Input tensor of size [B, D, H, W].
55 |             time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding.
56 |         Returns:
57 |             Transformed tensor of size [B, D, H, W].
58 |         """
59 |         batch_size = x.shape[0]
60 |         if self.use_affine_level:
61 |             gamma, beta = self.noise_func(time_emb).view(batch_size, -1, 1, 1).chunk(2, dim=1)  # split into gamma and beta
62 |             # The size of gamma and beta is (batch_size, out_channels, 1, 1).
63 |             x = (1 + gamma) * x + beta  # the affine form a*x + b, i.e. an extra gamma (scale) term is added
64 |         else:
65 |             x = x + self.noise_func(time_emb).view(batch_size, -1, 1, 1)  # essentially just adds a beta (shift) term
66 |         return x
67 | 
68 | 
69 | class Upsample(nn.Module):
70 |     """Scales the feature map by a factor of 2, i.e. upscale the feature map.
71 |     Attributes:
72 |         dim: Input/output tensor channels.
73 |     """
74 | 
75 |     def __init__(self, dim: int):
76 |         super().__init__()
77 |         self.up = nn.Upsample(scale_factor=2, mode="bicubic")
78 |         self.conv = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=1)
79 | 
80 |     def forward(self, x):
81 |         """Upscales the spatial dimensions of the input tensor two times.
82 |         Args:
83 |             x: Input tensor of size [B, 8*D, H, W].
84 |         Returns:
85 |             Upscaled tensor of size [B, 8*D, 2*H, 2*W].
86 |         """
87 |         return self.conv(self.up(x))
88 | 
89 | 
90 | class Downsample(nn.Module):
91 |     """Scale the feature map by a factor of 1/2, i.e. downscale the feature map.
92 |     Attributes:
93 |         dim: Input/output tensor channels.
94 |     """
95 | 
96 |     def __init__(self, dim: int):
97 |         super().__init__()
98 |         self.conv = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, stride=2, padding=1)
99 | 
100 |     def forward(self, x):
101 |         """Downscales the spatial dimensions of the input tensor two times.
102 |         Args:
103 |             x: Input tensor of size [B, D, H, W].
104 |         Returns:
105 |             Downscaled tensor of size [B, D, H/2, W/2].
106 |         """
107 |         return self.conv(x)
108 | 
109 | 
110 | class Block(nn.Module):
111 |     """A building component of the Residual block.
112 |     Attributes:
113 |         dim: Input tensor channels.
114 |         dim_out: Output tensor channels.
115 |         groups: Number of groups to separate the channels into.
116 |         dropout: Dropout probability.
117 |     """
118 | 
119 |     def __init__(self, dim: int, dim_out: int, groups: int = 32, dropout: float = 0):
120 |         super().__init__()
121 |         self.block = nn.Sequential(nn.GroupNorm(num_groups=groups, num_channels=dim),
122 |                                    nn.SiLU(),
123 |                                    nn.Dropout2d(dropout) if dropout != 0 else nn.Identity(),
124 |                                    nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=3, padding=1))
125 | 
126 |     def forward(self, x):
127 |         """Applies block transformations on input tensor.
128 |         Args:
129 |             x: Input tensor of size [B, D, H, W].
130 |         Returns:
131 |             Transformed tensor of size [B, D, H, W].
132 |         """
133 |         return self.block(x)
134 | 
135 | 
136 | class ResnetBlock(nn.Module):
137 |     """Residual block.
138 |     Attributes:
139 |         dim: Input tensor channels.
140 |         dim_out: Output tensor channels.
141 |         noise_level_emb_dim: Timestep embedding dimension.
142 |         dropout: Dropout probability.
143 |         use_affine_level: Whether to apply an affine transformation on input or add a noise.
144 |         norm_groups: The number of groups for group normalization.
145 |     """
146 | 
147 |     def __init__(self, dim: int, dim_out: int, noise_level_emb_dim: int = None, dropout: float = 0,
148 |                  use_affine_level: bool = False, norm_groups: int = 32):
149 |         super().__init__()
150 |         self.noise_func = FeatureWiseAffine(in_channels=noise_level_emb_dim, out_channels=dim_out,
151 |                                             use_affine_level=use_affine_level)
152 |         self.block1 = Block(dim, dim_out, groups=norm_groups)
153 |         self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
154 |         self.res_conv = nn.Conv2d(in_channels=dim, out_channels=dim_out, kernel_size=1) \
155 |             if dim != dim_out else nn.Identity()
156 | 
157 |     def forward(self, x, time_emb):
158 |         """Applies the residual block to the input tensors.
159 |         Args:
160 |             x: Input tensor of size [B, D, H, W].
161 |             time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding.
162 |         Returns:
163 |             Transformed tensor of size [B, D, H, W].
164 |         """
165 |         h = self.block1(x)  # encode
166 |         h = self.noise_func(h, time_emb)  # inject the timestep embedding (affine transform)
167 |         h = self.block2(h)  # encode
168 |         return h + self.res_conv(x)  # encoding plus residual connection
169 | 
170 | 
171 | class SelfAttention(nn.Module):
172 |     """Multi-head attention.
173 |     Attributes:
174 |         in_channel: Input tensor channels.
175 |         n_head: The number of heads in multi-head attention.
176 |         norm_groups: The number of groups for group normalization.
177 |     """
178 | 
179 |     def __init__(self, in_channel: int, n_head: int = 1, norm_groups: int = 32):
180 |         super().__init__()
181 | 
182 |         self.n_head = n_head
183 |         self.norm = nn.GroupNorm(norm_groups, in_channel)
184 |         self.qkv = nn.Conv2d(in_channels=in_channel, out_channels=3*in_channel, kernel_size=1, bias=False)
185 |         self.out = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=1)
186 | 
187 |     def forward(self, x):
188 |         """Applies self-attention to input tensor.
189 |         Args:
190 |             x: Input tensor of size [B, 8*D, H, W].
191 |         Returns:
192 |             Transformed tensor of size [B, 8*D, H, W].
193 |         """
194 |         batch_size, channel, height, width = x.shape
195 |         head_dim = channel // self.n_head
196 | 
197 |         norm = self.norm(x)
198 |         qkv = self.qkv(norm).view(batch_size, self.n_head, head_dim * 3, height, width)
199 |         query, key, value = qkv.chunk(3, dim=2)
200 | 
201 |         attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
202 |         attn = attn.view(batch_size, self.n_head, height, width, -1)
203 |         attn = torch.softmax(attn, -1)
204 |         attn = attn.view(batch_size, self.n_head, height, width, height, width)
205 | 
206 |         out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()  # merge the heads with an einsum
207 |         out = self.out(out.view(batch_size, channel, height, width))
208 | 
209 |         return out + x  # residual connection after multi-head attention
210 | 
211 | 
212 | class ResnetBlocWithAttn(nn.Module):
213 |     """ResnetBlock combined with a self-attention layer.
214 |     Attributes:
215 |         dim: Input tensor channels.
216 |         dim_out: Output tensor channels.
217 |         noise_level_emb_dim: Timestep embedding dimension.
218 |         norm_groups: The number of groups for group normalization.
219 |         dropout: Dropout probability.
220 |         with_attn: Whether to add self-attention layer or not.
221 | """ 222 | 223 | def __init__(self, dim: int, dim_out: int, *, noise_level_emb_dim: int = None, 224 | norm_groups: int = 32, dropout: float = 0, with_attn: bool = True): 225 | super().__init__() 226 | self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 227 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) if with_attn else nn.Identity() 228 | 229 | def forward(self, x, time_emb): 230 | """Forward pass. 231 | Args: 232 | x: Input tensor of size [B, D, H, W]. 233 | time_emb: Timestep embeddings of size [B, 1, D] where D is the dimension of embedding. 234 | Returns: 235 | Transformed tensor of size [B, D, H, W]. 236 | """ 237 | x = self.res_block(x, time_emb) 238 | x = self.attn(x) 239 | return x 240 | 241 | 242 | class UNet(nn.Module): 243 | """Defines U-Net network. 244 | Attributes: 245 | in_channel: Input tensor channels. 246 | out_channel: Output tensor channels. 247 | inner_channel: Timestep embedding dimension. 248 | norm_groups: The number of groups for group normalization. 249 | channel_mults: A tuple specifying the scaling factors of channels. 250 | attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer. 251 | res_blocks: The number of residual blocks. 252 | dropout: Dropout probability. 253 | with_noise_level_emb: Whether to apply timestep encodings or not. 254 | height: Height of input tensor. 255 | """ 256 | 257 | def __init__(self, in_channel: int, out_channel: int, inner_channel: int, 258 | norm_groups: int, channel_mults: tuple, attn_res: tuple, 259 | res_blocks: int, dropout: float, with_noise_level_emb: bool = True, height: int = 128): 260 | super().__init__() 261 | self.out_channel=out_channel 262 | if with_noise_level_emb: 263 | noise_level_channel = inner_channel 264 | 265 | # Time embedding layer that returns 266 | self.time_embedding = nn.Sequential(PositionalEncoding(inner_channel), 267 | nn.Linear(inner_channel, 4*inner_channel), 268 | nn.SiLU(), 269 | nn.Linear(4*inner_channel, inner_channel)) 270 | else: 271 | noise_level_channel, self.time_embedding = None, None 272 | 273 | num_mults = len(channel_mults) 274 | pre_channel = inner_channel 275 | feat_channels = [pre_channel] 276 | current_height = height 277 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)] 278 | 279 | for ind in range(num_mults): # For each channel growing factor. 280 | is_last = (ind == num_mults - 1) 281 | 282 | use_attn = current_height in attn_res 283 | channel_mult = inner_channel * channel_mults[ind] 284 | 285 | for _ in range(res_blocks): # Add res_blocks number of ResnetBlocWithAttn layer. 286 | downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, 287 | norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 288 | feat_channels.append(channel_mult) 289 | pre_channel = channel_mult 290 | 291 | # If the newly added ResnetBlocWithAttn layer to downs list is not the last one, 292 | # then add a Downsampling layer. 
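            # Added note: current_height tracks the spatial resolution after each
            # Downsample; it is compared against attn_res above to decide at which
            # resolutions the self-attention layers are inserted.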
293 |             if not is_last:
294 |                 downs.append(Downsample(pre_channel))
295 |                 feat_channels.append(pre_channel)
296 |                 current_height //= 2
297 | 
298 |         self.downs = nn.ModuleList(downs)
299 |         self.mid = nn.ModuleList([ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
300 |                                                      norm_groups=norm_groups, dropout=dropout),
301 |                                   ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
302 |                                                      norm_groups=norm_groups, dropout=dropout, with_attn=False)])
303 | 
304 |         ups = []
305 |         for ind in reversed(range(num_mults)):  # For each channel growing factor (in decreasing order).
306 |             is_last = (ind < 1)
307 |             use_attn = (current_height in attn_res)
308 |             channel_mult = inner_channel * channel_mults[ind]
309 | 
310 |             for _ in range(res_blocks+1):  # Add res_blocks+1 number of ResnetBlocWithAttn layer.
311 |                 ups.append(ResnetBlocWithAttn(pre_channel+feat_channels.pop(), channel_mult,
312 |                                               noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
313 |                                               dropout=dropout, with_attn=use_attn))
314 |                 pre_channel = channel_mult
315 | 
316 |             # If the newly added ResnetBlocWithAttn layer to ups list is not the last one,
317 |             # then add an Upsample layer.
318 |             if not is_last:
319 |                 ups.append(Upsample(pre_channel))
320 |                 current_height *= 2
321 | 
322 |         self.ups = nn.ModuleList(ups)
323 | 
324 |         # Final convolution layer to transform the spatial dimensions to the desired shapes.
325 |         self.final_conv = Block(pre_channel, out_channel if out_channel else in_channel, groups=norm_groups)
326 | 
327 |     def forward(self, x, time):
328 |         """Forward pass.
329 |         Args:
330 |             x: Input tensor of size: [B, C, H, W], for WeatherBench C=2.
331 |             time: Diffusion timesteps of size: [B, 1].
332 |         Returns:
333 |             Estimation of Gaussian noise.
334 |         """
335 | 
336 |         t = self.time_embedding(time) if self.time_embedding else None  # [B, 1, D]
337 |         feats = []
338 | 
339 |         for layer in self.downs:
340 |             x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
341 |             feats.append(x)
342 | 
343 |         for layer in self.mid:
344 |             x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
345 | 
346 |         # for i in feats:
347 |         #     print("feats",i.shape)
348 |         for layer in self.ups:
349 |             # print(x.shape)
350 |             if x.shape[3] == 176:  # crop the extra column produced when an odd width was upsampled
351 |                 x = x[:, :, :, :-1]
352 |             elif x.shape[3] == 36:  # same fix at a coarser level
353 |                 x = x[:, :, :, :-1]
354 |             x = layer(torch.cat((x, feats.pop()), dim=1), t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
355 | 
356 |         return self.final_conv(x)
357 | 
358 | if __name__ == '__main__':
359 |     import numpy as np
360 |     model = UNet(10, 5, 64, 32, [1, 2, 4, 8], attn_res=[16], res_blocks=1, dropout=0.6)
361 |     print(model)
362 |     # print(model(torch.randn(2,10,300,400),torch.arange(1000)).shape)
--------------------------------------------------------------------------------
/model/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | import torch.nn as nn  # (redundant with the import above)
5 | 
6 | 
7 | class EMA(object):
8 |     """An Exponential Moving Average class.
9 |     Attributes:
10 |         mu: The EMA decay rate (momentum).
11 |         shadow: The storage for parameter values.
12 |     """
13 | 
14 |     def __init__(self, mu=0.999):
15 |         self.mu = mu
16 |         self.shadow = {}
17 | 
18 |     def register(self, module):
19 |         """Registers network parameters.
20 |         Args:
21 |             module: A parameter module, typically a neural network.
22 | """ 23 | if isinstance(module, nn.DataParallel): 24 | module = module.module 25 | for name, param in module.named_parameters(): 26 | if param.requires_grad: 27 | self.shadow[name] = param.data.clone() 28 | 29 | def update(self, module): 30 | """Updates parameters with a decay rate mu and stores in a storage. 31 | Args: 32 | module: A parameter module, typically a neural network. 33 | """ 34 | if isinstance(module, nn.DataParallel): 35 | module = module.module 36 | for name, param in module.named_parameters(): 37 | if param.requires_grad: 38 | self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data 39 | 40 | def ema(self, module): 41 | """Updates network parameters from the storage. 42 | Args: 43 | module: A parameter module, typically a neural network. 44 | """ 45 | if isinstance(module, nn.DataParallel): 46 | module = module.module 47 | for name, param in module.named_parameters(): 48 | if param.requires_grad: 49 | param.data.copy_(self.shadow[name].data) 50 | 51 | def ema_copy(self, module): 52 | """Updates network parameters from the storage and returns a copy of it. 53 | Args: 54 | module: A parameter module, typically a neural network. 55 | Returns: 56 | A copy of network parameters. 57 | """ 58 | if isinstance(module, nn.DataParallel): 59 | inner_module = module.module 60 | module_copy = type(inner_module)( 61 | inner_module.config).to(inner_module.config.device) 62 | module_copy.load_state_dict(inner_module.state_dict()) 63 | module_copy = nn.DataParallel(module_copy) 64 | else: 65 | module_copy = type(module)(module.config).to(module.config.device) 66 | module_copy.load_state_dict(module.state_dict()) 67 | 68 | self.ema(module_copy) 69 | return module_copy 70 | 71 | def state_dict(self): 72 | """Returns current state of model parameters. 73 | Returns: 74 | Current state of model parameters stored in a local storage. 75 | """ 76 | return self.shadow 77 | 78 | def load_state_dict(self, state_dict): 79 | """Update local storage of parameters. 80 | Args: 81 | state_dict: A state of network parameters for updating local storage. 
82 | """ 83 | self.shadow = state_dict 84 | 85 | 86 | 87 | 88 | 89 | 90 | class LitEma(nn.Module): 91 | def __init__(self, model, decay=0.9999, use_num_upates=True): 92 | super().__init__() 93 | if decay < 0.0 or decay > 1.0: 94 | raise ValueError('Decay must be between 0 and 1') 95 | 96 | self.m_name2s_name = {} 97 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 98 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 99 | else torch.tensor(-1,dtype=torch.int)) 100 | 101 | for name, p in model.named_parameters(): 102 | if p.requires_grad: 103 | #remove as '.'-character is not allowed in buffers 104 | s_name = name.replace('.','') 105 | self.m_name2s_name.update({name:s_name}) 106 | self.register_buffer(s_name,p.clone().detach().data) 107 | 108 | self.collected_params = [] 109 | 110 | def forward(self,model): 111 | decay = self.decay 112 | 113 | if self.num_updates >= 0: 114 | self.num_updates += 1 115 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 116 | 117 | one_minus_decay = 1.0 - decay 118 | 119 | with torch.no_grad(): 120 | m_param = dict(model.named_parameters()) 121 | shadow_params = dict(self.named_buffers()) 122 | 123 | for key in m_param: 124 | if m_param[key].requires_grad: 125 | sname = self.m_name2s_name[key] 126 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 127 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 128 | else: 129 | assert not key in self.m_name2s_name 130 | 131 | def copy_to(self, model): 132 | m_param = dict(model.named_parameters()) 133 | shadow_params = dict(self.named_buffers()) 134 | for key in m_param: 135 | if m_param[key].requires_grad: 136 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 137 | else: 138 | assert not key in self.m_name2s_name 139 | 140 | def store(self, parameters): 141 | """ 142 | Save the current parameters for restoring later. 143 | Args: 144 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 145 | temporarily stored. 146 | """ 147 | self.collected_params = [param.clone() for param in parameters] 148 | 149 | def restore(self, parameters): 150 | """ 151 | Restore the parameters stored with the `store` method. 152 | Useful to validate the model with EMA parameters without affecting the 153 | original optimization process. Store the parameters before the 154 | `copy_to` method. After validation (or model saving), use this to 155 | restore the former parameters. 156 | Args: 157 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 158 | updated with the stored parameters. 159 | """ 160 | for c_param, param in zip(self.collected_params, parameters): 161 | param.data.copy_(c_param.data) 162 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | """Declares network weight initialization functions and a function 2 | to define final single image super-resolution solver architecture. 3 | Implements neural netowrk weight initialization methods such as 4 | normal, kaiming and orthogonal. Defines a function that 5 | creates a returns a network to train on single image 6 | super-resolution task. 7 | The work is based on https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement. 
8 | """ 9 | import functools 10 | import logging 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import init 15 | 16 | from .diffusion.ddpm import DDPM 17 | from .diffusion.unet import UNet 18 | 19 | logger = logging.getLogger("base") 20 | 21 | 22 | def weights_init_normal(model: nn.Module, std: float = 0.02) -> None: 23 | """Initializes model weights from Gaussian distribution. 24 | Args: 25 | model: The network. 26 | std: Standard deviation of Gaussian distrbiution. 27 | """ 28 | classname = model.__class__.__name__ 29 | if classname.find("Conv") != -1: 30 | init.normal_(model.weight.data, 0.0, std) 31 | if model.bias is not None: 32 | model.bias.data.zero_() 33 | elif classname.find("Linear") != -1: 34 | init.normal_(model.weight.data, 0.0, std) 35 | if model.bias is not None: 36 | model.bias.data.zero_() 37 | elif classname.find("BatchNorm2d") != -1: 38 | init.normal_(model.weight.data, 1.0, std) 39 | init.constant_(model.bias.data, 0.0) 40 | 41 | 42 | def weights_init_kaiming(model: nn.Module, scale: float = 1) -> None: 43 | """He initialization of model weights. 44 | Args: 45 | model: The network. 46 | scale: Scaling factor of weights. 47 | """ 48 | classname = model.__class__.__name__ 49 | if classname.find("Conv2d") != -1: 50 | init.kaiming_normal_(model.weight.data) 51 | model.weight.data *= scale 52 | if model.bias is not None: 53 | model.bias.data.zero_() 54 | elif classname.find("Linear") != -1: 55 | init.kaiming_normal_(model.weight.data) 56 | model.weight.data *= scale 57 | if model.bias is not None: 58 | model.bias.data.zero_() 59 | elif classname.find("BatchNorm2d") != -1: 60 | init.constant_(model.weight.data, 1.0) 61 | init.constant_(model.bias.data, 0.0) 62 | 63 | 64 | def weights_init_orthogonal(model: nn.Module) -> None: 65 | """Fills the model weights to be orthogonal matrices. 66 | Args: 67 | model: The network. 68 | """ 69 | classname = model.__class__.__name__ 70 | if classname.find("Conv") != -1: 71 | init.orthogonal_(model.weight.data) 72 | if model.bias is not None: 73 | model.bias.data.zero_() 74 | elif classname.find("Linear") != -1: 75 | init.orthogonal_(model.weight.data) 76 | if model.bias is not None: 77 | model.bias.data.zero_() 78 | elif classname.find("BatchNorm2d") != -1: 79 | init.constant_(model.weight.data, 1.0) 80 | init.constant_(model.bias.data, 0.0) 81 | 82 | 83 | def init_weights(net: nn.Module, init_type: str = "kaiming", scale: float = 1, std: float = 0.02) -> None: 84 | """Initializes network weights. 85 | Args: 86 | net: The neural network. 87 | init_type: One of "normal", "kaiming" or "orthogonal". 88 | scale: Scaling factor of weights used in kaiming initialization. 89 | std: Standard deviation of Gaussian distrbiution used in normal initialization. 
90 |     """
91 |     logger.info("Initialization method [{:s}]".format(init_type))
92 |     if init_type == "normal":
93 |         weights_init_normal_ = functools.partial(weights_init_normal, std=std)
94 |         net.apply(weights_init_normal_)
95 |     elif init_type == "kaiming":
96 |         weights_init_kaiming_ = functools.partial(
97 |             weights_init_kaiming, scale=scale)
98 |         net.apply(weights_init_kaiming_)
99 |     elif init_type == "orthogonal":
100 |         net.apply(weights_init_orthogonal)
101 |     else:
102 |         raise NotImplementedError("Initialization method [{:s}] not implemented".format(init_type))
103 | 
104 | 
105 | def define_network(in_channel, out_channel, norm_groups, inner_channel,
106 |                    channel_multiplier, attn_res, res_blocks, dropout,
107 |                    diffusion_loss, conditional, gpu_ids, distributed, init_method, height) -> nn.Module:
108 |     """Defines Gaussian Diffusion model for single image super-resolution task.
109 |     Args:
110 |         in_channel: The number of channels of input tensor of U-Net.
111 |         out_channel: The number of channels of output tensor of U-Net.
112 |         norm_groups: The number of groups for group normalization.
113 |         inner_channel: Timestep embedding dimension.
114 |         channel_multiplier: A tuple specifying the scaling factors of channels.
115 |         attn_res: A tuple of spatial dimensions indicating in which resolutions to use self-attention layer.
116 |         res_blocks: The number of residual blocks.
117 |         dropout: Dropout probability.
118 |         diffusion_loss: Either l1 or l2.
119 |         conditional: Whether to condition on INTERPOLATED image or not.
120 |         gpu_ids: IDs of gpus.
121 |         distributed: Whether the computation will be distributed among multiple GPUs or not.
122 |         init_method: NN weight initialization method. One of normal, kaiming or orthogonal initializations.
123 |         height: U-Net input tensor height value.
124 |     Returns:
125 |         A Gaussian Diffusion model.
126 | """ 127 | 128 | network = UNet(in_channel=in_channel, 129 | out_channel=out_channel, 130 | norm_groups=norm_groups if norm_groups else 32, 131 | inner_channel=inner_channel, 132 | channel_mults=channel_multiplier, 133 | attn_res=attn_res, 134 | res_blocks=res_blocks, 135 | dropout=dropout, 136 | height=height) 137 | 138 | model = DDPM(network, loss_type=diffusion_loss, conditional=conditional,gpu_ids=gpu_ids) 139 | init_weights(model, init_type=init_method)#ddpm是框架,unet是内部模块嵌套关系 140 | 141 | if gpu_ids and distributed: 142 | assert torch.cuda.is_available() 143 | model = nn.DataParallel(model) 144 | 145 | return model -------------------------------------------------------------------------------- /pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map#.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- 
/src.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LingFH/Diffusion_4_downscaling/5dea4cbaf404803c6bedd1abebb90a2effa8686c/src.png -------------------------------------------------------------------------------- /trainer_all.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import warnings 6 | from collections import OrderedDict, defaultdict 7 | import numpy as np 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | from torch.nn.functional import mse_loss, l1_loss 11 | from torch.utils.data import DataLoader 12 | import model 13 | # from x2_data.mydataset_patch import BigDataset_train 14 | from data.mydataset_patch import SR3_Dataset_patch 15 | from configs import Config 16 | import matplotlib 17 | import matplotlib.pyplot as plt 18 | matplotlib.use('Agg') 19 | import glob 20 | from utils import dict2str, setup_logger, construct_and_save_wbd_plots, \ 21 | accumulate_statistics, \ 22 | get_optimizer, construct_mask, set_seeds,psnr 23 | import random 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | if __name__ == "__main__": 28 | set_seeds() # For reproducability. 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-c", "--config", type=str, help="JSON file for configuration") 31 | parser.add_argument("-p", "--phase", type=str, choices=["train", "val"], 32 | help="Run either training or validation(inference).", default="train") 33 | parser.add_argument("-gpu", "--gpu_ids", type=str, default=None) 34 | parser.add_argument("-var", "--variable_name", type=str, default=None) 35 | args = parser.parse_args() 36 | variable_name=args.variable_name 37 | configs = Config(args) 38 | torch.backends.cudnn.enabled = True 39 | torch.backends.cudnn.benchmark = True 40 | 41 | setup_logger(None, configs.log, "train", screen=True) 42 | setup_logger("val", configs.log, "val") 43 | 44 | logger = logging.getLogger("base") 45 | val_logger = logging.getLogger("val") 46 | logger.info(dict2str(configs.get_hyperparameters_as_dict())) 47 | tb_logger = SummaryWriter(log_dir=configs.tb_logger) 48 | 49 | 50 | target_paths = sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/hr/*npy")) 51 | land_01_path="/home/data/downscaling/downscaling_1023/data/land10.npy" 52 | mask_path="/home/data/downscaling/downscaling_1023/data/mask10.npy" 53 | # physical_paths= sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/pl/*npy")) 54 | lr_paths= sorted(glob.glob("/home/data/downscaling/downscaling_1023/data/train_dataset/sl/*npy")) 55 | random_dataset_index= random.sample(range(0, len(target_paths)), 2) 56 | data_index=np.arange(0,len(target_paths)) 57 | train_index=np.delete(data_index,random_dataset_index) 58 | logger.info(f"split_random dataset is {random_dataset_index}" ) 59 | train_data = SR3_Dataset_patch(np.array(target_paths)[train_index],land_01_path,mask_path,lr_paths=np.array(lr_paths)[train_index],var=variable_name,patch_size=configs.height) 60 | val_data=SR3_Dataset_patch(np.array(target_paths)[random_dataset_index],land_01_path,mask_path,lr_paths=np.array(lr_paths)[random_dataset_index],var=variable_name,patch_size=configs.height) 61 | 62 | logger.info(f"Train size: {len(train_data)}, Val size: {len(val_data)}.") 63 | train_loader = DataLoader(train_data, batch_size=configs.batch_size,shuffle=configs.use_shuffle, num_workers=configs.num_workers,drop_last=True) 
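    # Added note: the validation loader below uses a much smaller batch
    # (batch_size/12), presumably because validation runs full reverse-diffusion
    # sampling, which is far more memory- and compute-intensive than a training step.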
64 |     val_loader = DataLoader(val_data, batch_size=int(configs.batch_size/12), shuffle=False, num_workers=configs.num_workers, drop_last=True)
65 |     logger.info("Training and Validation dataloaders are ready.")
66 |     # Defining the model.
67 |     optimizer = get_optimizer(configs.optimizer_type)
68 |     diffusion = model.create_model(in_channel=configs.in_channel, out_channel=configs.out_channel,
69 |                                    norm_groups=configs.norm_groups, inner_channel=configs.inner_channel,
70 |                                    channel_multiplier=configs.channel_multiplier, attn_res=configs.attn_res,
71 |                                    res_blocks=configs.res_blocks, dropout=configs.dropout,
72 |                                    diffusion_loss=configs.diffusion_loss, conditional=configs.conditional,
73 |                                    gpu_ids=configs.gpu_ids, distributed=configs.distributed,
74 |                                    init_method=configs.init_method, train_schedule=configs.train_schedule,
75 |                                    train_n_timestep=configs.train_n_timestep,
76 |                                    train_linear_start=configs.train_linear_start,
77 |                                    train_linear_end=configs.train_linear_end,
78 |                                    val_schedule=configs.val_schedule, val_n_timestep=configs.val_n_timestep,
79 |                                    val_linear_start=configs.val_linear_start, val_linear_end=configs.val_linear_end,
80 |                                    finetune_norm=configs.finetune_norm, optimizer=optimizer, amsgrad=configs.amsgrad,
81 |                                    learning_rate=configs.lr, checkpoint=configs.checkpoint,
82 |                                    resume_state=configs.resume_state, phase=configs.phase, height=configs.height)
83 |     logger.info("Model initialization is finished.")
84 | 
85 |     current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch
86 |     if configs.resume_state:
87 |         logger.info(f"Resuming training from epoch: {current_epoch}, iter: {current_step}.")
88 | 
89 |     logger.info("Starting the training.")
90 |     diffusion.register_schedule(beta_schedule=configs.train_schedule, timesteps=configs.train_n_timestep,
91 |                                 linear_start=configs.train_linear_start, linear_end=configs.train_linear_end)
92 | 
93 |     accumulated_statistics = OrderedDict()
94 | 
95 |     val_metrics_dict = {"MSE": 0.0, "MAE": 0.0, "MAE_inter": 0.0}
96 |     val_metrics_dict["PSNR_"+variable_name] = 0.0
97 |     val_metrics_dict["PSNR_inter_"+variable_name] = 0.0
98 |     val_metrics_dict["RMSE_"+variable_name] = 0.0
99 |     val_metrics_dict["RMSE_inter_"+variable_name] = 0.0
100 | 
101 | 
102 |     val_metrics = OrderedDict(val_metrics_dict)
103 | 
104 |     # Training.
105 |     while current_step < configs.n_iter:
106 |         current_epoch += 1
107 | 
108 |         for train_data in train_loader:
109 |             current_step += 1
110 | 
111 |             if current_step > configs.n_iter:
112 |                 break
113 | 
114 |             # Training.
115 |             diffusion.feed_data(train_data)
116 |             diffusion.optimize_parameters()
117 |             diffusion.lr_scheduler_step()  # For lr scheduler updates per iteration.
118 |             accumulate_statistics(diffusion.get_current_log(), accumulated_statistics)
119 | 
120 |             # Logging the training information.
121 |             if current_step % configs.print_freq == 0:
122 |                 message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}"
123 | 
124 |                 for metric, values in accumulated_statistics.items():
125 |                     mean_value = np.mean(values)
126 |                     message = f"{message} | {metric:s}: {mean_value:.5f}"
127 |                     tb_logger.add_scalar(f"{metric}/train", mean_value, current_step)
128 | 
129 |                 logger.info(message)
130 |                 # tb_logger.add_scalar(f"learning_rate", diffusion.get_lr(), current_step)
131 | 
132 |                 # Visualizing distributions of parameters.
133 |                 # for name, param in diffusion.get_named_parameters():
134 |                 #     tb_logger.add_histogram(name, param.clone().cpu().data.numpy(), current_step)
135 | 
136 |                 accumulated_statistics = OrderedDict()
137 |             # Validation.
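            # Added note: the validation branch below switches to the validation
            # beta schedule, runs DDIM sampling on each batch, accumulates masked
            # MSE/MAE/RMSE/PSNR for the diffusion output (SR) and for the
            # INTERPOLATED baseline, and finally restores the training schedule.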
137 |             # Validation.
138 |             if current_step % configs.val_freq == 0:
139 |                 logger.info("Starting validation.")
140 |                 idx = 0
141 |                 result_path = f"{configs.results}/{current_epoch}"
142 |                 os.makedirs(result_path, exist_ok=True)
143 |                 diffusion.register_schedule(beta_schedule=configs.val_schedule,
144 |                                             timesteps=configs.val_n_timestep,
145 |                                             linear_start=configs.val_linear_start,
146 |                                             linear_end=configs.val_linear_end)
147 | 
148 |                 # A dictionary for storing a list of mean temperatures for each month.
149 |                 # month2mean_temperature = defaultdict(list)
150 | 
151 |                 for val_data in val_loader:
152 |                     idx += 1
153 |                     diffusion.feed_data(val_data)
154 |                     # Experiment 1 used 250 DDIM sampling steps; experiment 2 used 50.
155 |                     diffusion.test(continuous=False, use_ddim=True, ddim_steps=250, use_dpm_solver=False)  # continuous=False returns only the final timestep's outcome.
156 | 
157 |                     # Computing metrics on validation data.
158 |                     visuals = diffusion.get_current_visuals()
159 |                     # Computing MSE and MAE on the masked fields.
160 |                     mask = val_data["mask"]
161 |                     mse_value = mse_loss(visuals["HR"] * mask, visuals["SR"] * mask)
162 |                     val_metrics["MSE"] += mse_value
163 |                     val_metrics["MAE"] += l1_loss(visuals["HR"] * mask, visuals["SR"] * mask)
164 |                     val_metrics["MAE_inter"] += l1_loss(visuals["HR"] * mask, visuals["INTERPOLATED"] * mask)
165 | 
166 |                     val_metrics["RMSE_" + variable_name] += torch.sqrt(mse_value)
167 |                     val_metrics["RMSE_inter_" + variable_name] += torch.sqrt(mse_loss(visuals["HR"] * mask, visuals["INTERPOLATED"] * mask))
168 |                     val_metrics["PSNR_" + variable_name] += psnr(visuals["HR"] * mask, visuals["SR"] * mask)
169 |                     val_metrics["PSNR_inter_" + variable_name] += psnr(visuals["HR"] * mask, visuals["INTERPOLATED"] * mask)
170 | 
171 | 
172 |                     if idx % configs.val_vis_freq == 0:
173 | 
174 |                         logger.info(f"[{idx//configs.val_vis_freq}] Visualizing and storing some examples.")
175 | 
176 |                         sr_candidates = diffusion.generate_multiple_candidates(n=configs.sample_size, ddim_steps=100, use_dpm_solver=False)
177 | 
178 |                         mean_candidate = sr_candidates.mean(dim=0)  # [B, C, H, W]
179 |                         std_candidate = sr_candidates.std(dim=0)  # [B, C, H, W]
180 |                         bias = mean_candidate - visuals["HR"]
181 | 
182 | 
183 | 
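# Shape sketch for the ensemble reduction above (a stand-in tensor, not the
# real model output): generate_multiple_candidates stacks n diffusion samples,
# so the reductions below give the per-pixel ensemble mean and spread that are
# plotted further down.
#
#     sr_candidates = torch.randn(5, 10, 1, 128, 128)   # [n, B, C, H, W]
#     sr_candidates.mean(dim=0)                         # [B, C, H, W] ensemble mean
#     sr_candidates.std(dim=0)                          # [B, C, H, W] per-pixel spread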
184 |                         # Choosing 5 random batch samples to visualize.
185 | 
186 |                         random_idx = np.random.randint(0, configs.batch_size // 12, 5)
187 | 
188 |                         path = f"{result_path}/{current_epoch}_{current_step}_{idx}"
189 |                         figure, axs = plt.subplots(5, 9, figsize=(25, 12))
190 |                         if variable_name == "tp":
191 |                             vmin = 0
192 |                             cmap = "BrBG"
193 |                             vmax = 2
194 |                         elif variable_name in ["u", "v", "t2m", "sp"]:
195 |                             vmin = 0
196 |                             cmap = "RdBu_r"
197 |                             vmax = 1
198 |                         else:
199 |                             vmin = -2
200 |                             cmap = "RdBu_r"
201 |                             vmax = 2
202 |                         for idx_i, num in enumerate(random_idx):
203 |                             axs[idx_i, 0].imshow(visuals["HR"][num, 0], vmin=vmin, vmax=vmax, cmap=cmap)
204 |                             axs[idx_i, 1].imshow(visuals["SR"][num, 0], vmin=vmin, vmax=vmax, cmap=cmap)
205 |                             axs[idx_i, 2].imshow(visuals["INTERPOLATED"][num, 0], vmin=vmin, vmax=vmax, cmap=cmap)
206 | 
207 |                             axs[idx_i, 3].imshow(mean_candidate[num, 0], vmin=vmin, vmax=vmax, cmap=cmap)
208 |                             axs[idx_i, 4].imshow(std_candidate[num, 0], vmin=0, vmax=2, cmap="Reds")
209 |                             axs[idx_i, 5].imshow(np.abs(visuals["HR"][num, 0] - visuals["SR"][num, 0]), vmin=0, vmax=2, cmap="Reds")
210 |                             axs[idx_i, 7].imshow(np.abs(visuals["HR"][num, 0] - visuals["INTERPOLATED"][num, 0]), vmin=0, vmax=2, cmap="Reds")
211 |                             axs[idx_i, 6].imshow(np.abs(bias)[num, 0], vmin=0, vmax=2, cmap="Reds")
212 |                             axs[idx_i, 8].imshow(val_data["mask"][num, 0], vmin=0, vmax=2, cmap="RdBu_r")
213 |                             axs[idx_i, 8].set_title("mean_mae:%.3f, inter_mae:%.3f, sr_mae:%.3f" % (np.abs(bias)[num, 0].mean(), np.abs(visuals["HR"][num, 0] - visuals["INTERPOLATED"][num, 0]).mean(), np.abs(visuals["HR"][num, 0] - visuals["SR"][num, 0]).mean()))
214 |                         for title, col in zip(["HR", "Diffusion", "INTERPOLATED", "mean", "std", "mae_sr", "mae_mean", "mae_inter"], range(8)):
215 |                             axs[0, col].set_title(title)
216 |                         plt.savefig(f"{path}_.png", bbox_inches="tight")
217 |                         plt.close("all")
218 | 
219 |                 val_metrics["MSE"] /= idx
220 |                 val_metrics["MAE"] /= idx
221 |                 val_metrics["MAE_inter"] /= idx
222 | 
223 |                 val_metrics["RMSE_" + variable_name] /= idx
224 |                 val_metrics["RMSE_inter_" + variable_name] /= idx
225 |                 val_metrics["PSNR_" + variable_name] /= idx
226 |                 val_metrics["PSNR_inter_" + variable_name] /= idx
227 |                 diffusion.register_schedule(beta_schedule=configs.train_schedule,
228 |                                             timesteps=configs.train_n_timestep,
229 |                                             linear_start=configs.train_linear_start,
230 |                                             linear_end=configs.train_linear_end)
231 |                 message = f"Epoch: {current_epoch:5} | Iteration: {current_step:8}"
232 |                 for metric, value in val_metrics.items():
233 |                     message = f"{message} | {metric:s}: {value:.5f}"
234 |                     tb_logger.add_scalar(f"{metric}/val", value, current_step)
235 | 
236 |                 val_logger.info(message)
237 | 
238 |                 val_metrics = val_metrics.fromkeys(val_metrics, 0.0)  # Resets all metrics to zero.
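# A note on the masked metrics accumulated above (assuming a 0/1 land-sea
# mask): multiplying by the mask zeroes excluded grid points, but mse_loss and
# l1_loss keep their default "mean" reduction over *all* pixels, so masked-out
# zeros still count in the denominator. A minimal sketch:
#
#     hr, sr = torch.rand(4, 1, 128, 128), torch.rand(4, 1, 128, 128)
#     m = (torch.rand(4, 1, 128, 128) > 0.5).float()    # hypothetical 0/1 mask
#     masked_rmse = torch.sqrt(mse_loss(hr * m, sr * m))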
239 | 
240 |             if current_step % configs.save_checkpoint_freq == 0:
241 |                 logger.info("Saving models and training states.")
242 |                 diffusion.save_network(current_epoch, current_step)
243 | 
244 |     tb_logger.close()
245 | 
246 |     logger.info("End of training.")
247 | 
248 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Defines auxiliary functions for fixing the seeds, setting up
2 | a logger and visualizing WeatherBench data."""
3 | import logging
4 | import os
5 | import random
6 | 
7 | import cartopy.crs as ccrs
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | from cartopy.mpl.geoaxes import GeoAxes
13 | from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
14 | from cartopy.util import add_cyclic_point
15 | from matplotlib.figure import Figure
16 | from mpl_toolkits.axes_grid1 import AxesGrid
17 | 
18 | 
19 | # Tensorboard visualization titles.
20 | TITLES = ("Upsampled with interpolation",
21 |           "Super-resolution reconstruction",
22 |           "High-resolution original")
23 | 
24 | 
25 | def set_seeds(seed: int = 0):
26 |     """Sets the random seeds of Python, NumPy and PyTorch.
27 |     Args:
28 |         seed: Seed value.
29 |     """
30 |     random.seed(seed)
31 |     os.environ["PYTHONHASHSEED"] = str(seed)
32 |     np.random.seed(seed)
33 |     torch.manual_seed(seed)
34 |     torch.cuda.manual_seed(seed)
35 |     torch.backends.cudnn.deterministic = True
36 | 
37 | 
38 | def dict2str(dict_obj: dict, indent_l: int = 4) -> str:
39 |     """Converts a dictionary to a string for printing out.
40 |     Args:
41 |         dict_obj: Dictionary or OrderedDict.
42 |         indent_l: Left indentation level.
43 |     Returns:
44 |         The string version of dict_obj.
45 |     """
46 |     msg = ""
47 |     for k, v in dict_obj.items():
48 |         if isinstance(v, dict):
49 |             msg = f"{msg}{' '*(indent_l*2)}{k}:[\n{dict2str(v, indent_l+1)}{' '*(indent_l*2)}]\n"
50 |         else:
51 |             msg = f"{msg}{' '*(indent_l*2)}{k}: {v}\n"
52 |     return msg
53 | 
54 | 
55 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
56 |     """Sets up the logger.
57 |     Args:
58 |         logger_name: The logger name.
59 |         root: The directory of the log files.
60 |         phase: Either train or val.
61 |         level: The logging level.
62 |         screen: If True, also write logging records to a stream.
63 |     """
64 |     logger = logging.getLogger(logger_name)
65 |     formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
66 |     log_file = os.path.join(root, "{}.log".format(phase))
67 |     fh = logging.FileHandler(log_file, mode="w")
68 |     fh.setFormatter(formatter)
69 |     logger.setLevel(level)
70 |     logger.addHandler(fh)
71 |     if screen:
72 |         sh = logging.StreamHandler()
73 |         sh.setFormatter(formatter)
74 |         logger.addHandler(sh)
75 | 
76 | 
77 | def construct_and_save_wbd_plot(latitude: np.array, longitude: np.array, single_variable: torch.tensor,
78 |                                 path: str, title: str = None, label: str = None, dpi: int = 200,
79 |                                 figsize: tuple = (11, 8.5), cmap: str = "coolwarm", vmin=None,
80 |                                 vmax=None, coastline_color="black"):
81 |     """Creates and saves a WeatherBench data visualization for a single variable.
82 |     Args:
83 |         latitude: An array of latitudes.
84 |         longitude: An array of longitudes.
85 |         single_variable: A tensor to visualize.
86 |         path: Path of a directory to save the visualization.
87 |         title: Title of the figure.
88 |         label: Label of the colorbar.
89 |         dpi: Resolution of the figure.
90 |         figsize: Tuple of (width, height) in inches.
91 |         cmap: A matplotlib colormap.
92 |         vmin: Minimum value for colormap.
93 |         vmax: Maximum value for colormap.
94 |         coastline_color: Matplotlib color.
95 |     """
96 |     # single_variable, longitude = add_cyclic_point(single_variable, coord=np.array(longitude))
97 |     plt.figure(dpi=dpi, figsize=figsize)
98 |     projection = ccrs.PlateCarree()
99 |     ax = plt.axes()  # projection=projection
100 | 
101 |     if cmap == "binary":
102 |         # For mask visualization.
103 |         p = plt.contourf(longitude, latitude, single_variable, 10, transform=projection,
104 |                          cmap=(matplotlib.colors.ListedColormap(["white", "gray", "black"])
105 |                                .with_extremes(over="0.25", under="0.75")),
106 |                          vmin=-2, vmax=2)
107 |         boundaries, ticks = [-1, -0.33, 0.33, 1], [-1, 0, 1]
108 |     elif cmap == "coolwarm":
109 |         # For temperature visualization.
110 |         p = plt.contourf(longitude, latitude, single_variable, 60, transform=projection, cmap=cmap,
111 |                          levels=np.linspace(vmin, vmax, max(int(np.abs(vmax - vmin)) // 2, 3)))
112 |         boundaries, ticks = None, np.round(np.linspace(vmin, vmax, 7), 2)
113 | 
114 |     elif cmap == "Greens":
115 |         # For visualization of standard deviation.
116 |         p = plt.contourf(longitude, latitude, single_variable, 60, transform=projection, cmap=cmap,
117 |                          extend="max")
118 |         boundaries, ticks = None, np.linspace(single_variable.min(), single_variable.max(), 5)
119 | 
120 |     # ax.set_xticks(np.linspace(-180, 180, 5), crs=projection)
121 |     # ax.set_yticks(np.linspace(-90, 90, 5), crs=projection)
122 |     # lon_formatter = LongitudeFormatter(zero_direction_label=True)
123 |     # lat_formatter = LatitudeFormatter()
124 |     # ax.xaxis.set_major_formatter(lon_formatter)
125 |     # ax.yaxis.set_major_formatter(lat_formatter)
126 |     # ax.coastlines(color=coastline_color)
127 | 
128 |     # plt.colorbar(p, pad=0.06, label=label, orientation="horizontal", shrink=0.75,
129 |     #              boundaries=boundaries, ticks=ticks)
130 | 
131 |     plt.title(title)
132 |     plt.savefig(path, bbox_inches="tight")
133 |     plt.close("all")
134 | 
135 | 
136 | def add_batch_index(path: str, index: int):
137 |     """Appends the index of the batch obtained from the data loader to the filename in path.
138 |     Args:
139 |         path: The path to which the function needs to add the batch index.
140 |         index: The batch index.
141 |     Returns:
142 |         The path with the index appended to the filename.
143 |     """
144 |     try:
145 |         filename, extension = path.split(".")
146 |     except ValueError:
147 |         split_parts = path.split(".")
148 |         filename, extension = ".".join(split_parts[:-1]), split_parts[-1]
149 |     return f"{filename}_{index}.{extension}"
150 | 
151 | 
152 | def construct_and_save_wbd_plots(latitude: np.array, longitude: np.array, data: torch.tensor,
153 |                                  path: str, title: str = None, label: str = None,
154 |                                  dpi: int = 200, figsize: tuple = (11, 8.5), cmap: str = "coolwarm",
155 |                                  vmin=None, vmax=None, coastline_color="black"):
156 |     """Creates and saves a WeatherBench data visualization.
157 |     Args:
158 |         latitude: An array of latitudes.
159 |         longitude: An array of longitudes.
160 |         data: A batch of variables to visualize.
161 |         path: Path of a directory to save the visualization.
162 |         title: Title of the figure.
163 |         label: Label of the colorbar.
164 |         dpi: Resolution of the figure.
165 |         figsize: Tuple of (width, height) in inches.
166 |         cmap: A matplotlib colormap.
167 |         vmin: Minimum value for colormap.
168 |         vmax: Maximum value for colormap.
169 |         coastline_color: Matplotlib color.
170 |     """
171 |     if len(data.shape) > 2:
172 |         data = data.squeeze()
173 | 
174 |     if len(data.shape) > 2:
175 |         for batch_index in range(data.shape[0]):
176 |             path_for_sample = add_batch_index(path, batch_index)
177 |             construct_and_save_wbd_plot(latitude, longitude, data[batch_index], path_for_sample,
178 |                                         title, label, dpi, figsize, cmap, vmin, vmax, coastline_color)
179 |     else:
180 |         construct_and_save_wbd_plot(latitude, longitude, data, path, title, label, dpi, figsize, cmap,
181 |                                     vmin, vmax, coastline_color)
182 | 
183 | 
184 | def construct_tb_visualization(latitude: np.array, longitude: np.array, data: tuple, label=None,
185 |                                dpi: int = 300, figsize: tuple = (22, 6), cmap: str = "coolwarm") -> Figure:
186 |     """Constructs the tensorboard visualization figure.
187 |     Args:
188 |         latitude: An array of latitudes.
189 |         longitude: An array of longitudes.
190 |         data: A tuple of variables to visualize (interpolated, super-resolved, high-resolution).
191 |         label: Label of the colorbar.
192 |         dpi: Resolution of the figure.
193 |         figsize: Tuple of (width, height) in inches.
194 |         cmap: A matplotlib colormap.
195 |     Returns:
196 |         A matplotlib Figure.
197 |     """
198 |     max_value = max(tensor.max() for tensor in data)
199 |     min_value = min(tensor.min() for tensor in data)
200 |     projection = ccrs.PlateCarree()
201 |     axes_class = (GeoAxes, dict(map_projection=projection))
202 |     fig = plt.figure(figsize=figsize, dpi=dpi)
203 |     axgr = AxesGrid(fig, 111, axes_class=axes_class, nrows_ncols=(1, 3), axes_pad=0.95, cbar_location="bottom",
204 |                     cbar_mode="single", cbar_pad=0.01, cbar_size="2%", label_mode="")
205 |     lon_formatter = LongitudeFormatter(zero_direction_label=True)
206 |     lat_formatter = LatitudeFormatter()
207 | 
208 |     for i, ax in enumerate(axgr):
209 |         single_variable, lon = add_cyclic_point(data[i], coord=np.array(longitude))
210 |         ax.set_title(TITLES[i])
211 |         ax.gridlines(draw_labels=True, xformatter=lon_formatter, yformatter=lat_formatter,
212 |                      xlocs=np.linspace(-180, 180, 5), ylocs=np.linspace(-90, 90, 5))
213 |         p = ax.contourf(lon, latitude, single_variable, transform=projection, cmap=cmap,
214 |                         vmin=min_value, vmax=max_value)
215 |         ax.coastlines()
216 | 
217 |     axgr.cbar_axes[0].colorbar(p, pad=0.01, label=label, shrink=0.95)
218 |     fig.tight_layout()
219 |     plt.close("all")
220 |     return fig
221 | 
222 | 
223 | def accumulate_statistics(new_info: dict, storage: dict):
224 |     """Accumulates the statistics provided with new_info into storage.
225 |     Args:
226 |         new_info: A dictionary containing new information.
227 |         storage: A dictionary in which to accumulate the new information.
228 |     """
229 |     for key, value in new_info.items():
230 |         if key in storage:
231 |             storage[key].append(value)
232 |         else:
233 |             storage[key] = [value]
234 | 
235 | 
236 | # def get_transformation(name: str) -> Transform:
237 | #     """Returns the data transformation class corresponding to name.
238 | #     Args:
239 | #         name: The name of the transformation.
240 | #     Returns:
241 | #         A data transformer.
242 | #     """
243 | #     if name == "LocalStandardScaling":
244 | #         from weatherbench_data.transforms import LocalStandardScaling as Transformation
245 | #     elif name == "GlobalStandardScaling":
246 | #         from weatherbench_data.transforms import GlobalStandardScaling as Transformation
247 | #     return Transformation
248 | 
249 | 
250 | def get_optimizer(name: str):
251 |     """Returns the optimization algorithm class corresponding to name.
252 |     Args:
253 |         name: The name of the optimizer.
254 |     Returns:
255 |         A torch optimizer class (not an instance).
256 |     """
257 |     if name == "adam":
258 |         from torch.optim import Adam as Optimizer
259 |     elif name == "adamw":
260 |         from torch.optim import AdamW as Optimizer
261 |     return Optimizer
262 | 
263 | 
264 | def construct_mask(x: torch.tensor) -> torch.tensor:
265 |     """Constructs a signum(x) tensor with tolerance around 0 specified
266 |     by the torch.isclose function.
267 |     Args:
268 |         x: The input tensor.
269 |     Returns:
270 |         Signum(x) with slight tolerance around 0.
271 |     """
272 |     values = torch.ones_like(x)
273 |     zero_mask = torch.isclose(x, torch.zeros_like(x))
274 |     neg_mask = x < 0
275 |     values[neg_mask] = -1
276 |     values[zero_mask] = 0
277 |     return values
278 | 
279 | 
280 | def normal(y_pred):
281 |     """Min-max scales each field to [0, 1] over its last two (spatial) dimensions."""
282 |     y_pred_max = torch.max(y_pred, -1, keepdim=True).values.max(-2, keepdim=True).values
283 |     y_pred_min = torch.min(y_pred, -1, keepdim=True).values.min(-2, keepdim=True).values
284 |     return (y_pred - y_pred_min) / (y_pred_max - y_pred_min + 1e-16)
285 | 
286 | 
287 | def psnr(y_pred, y_true):
288 |     """Computes the PSNR after min-max scaling both fields to [0, 1].
289 |     Since both fields are scaled to [0, 1], the peak signal value is 1
290 |     (not 255), and PSNR reduces to -10 * log10(MSE)."""
291 |     y_pred = normal(y_pred)
292 |     y_true = normal(y_true)
293 |     loss = 10. * torch.log10(1. / torch.mean((y_pred - y_true) ** 2))
294 |     return loss
295 | 
296 | 
297 | # def my_plot(self, data_list):
298 | #     for channel in range(5):
299 | #         figure, axs = plt.subplots(5, 5, figsize=(16, 20))
300 | #         for i in range(5):
301 | #             axs[0, i].imshow(pred[i, channel, :, :], vmin=-2, vmax=2, cmap='RdBu_r')
302 | #             axs[1, i].imshow(true[i, channel, :, :], vmin=-2, vmax=2, cmap='RdBu_r')
303 | #             axs[2, i].imshow(mask[i, 0, :, :], vmin=-2, vmax=2, cmap='RdBu_r')
304 | #             axs[3, i].imshow(land[i, 0, :, :], vmin=-2, vmax=2, cmap='RdBu_r')
305 | #             axs[4, i].imshow(test[i, channel, :, :], vmin=-2, vmax=2, cmap='RdBu_r')
306 | #             axs[0, i].set_title("RMSE:" + str(np.sqrt(np.mean((pred[i, channel, :, :] - true[i, channel, :, :]) ** 2))))
307 | #         axs[0, 0].set_ylabel("y_pred")
308 | #         plt.savefig(self.configs_train.draw_path[channel] + "/epoch{}.png".format(epoch))
309 | #         plt.close()
310 | #     return fig
311 | 
312 | 
313 | # def reverse_transform_candidates(candidates: torch.tensor, reverse_transform: Transform,
314 | #                                  transformations: dict, variables: list, data_type: str,
315 | #                                  months: list, transform_monthly: bool):
316 | #     """Reverse-transforms candidates.
317 | #     Args:
318 | #         candidates: A tensor of shape [n, B, C, H, W].
319 | #         reverse_transform: A reverse transformation.
320 | #         transformations: A dictionary of transformations.
321 | #         variables: WeatherBench data variables.
322 | #         data_type: Either 'lr' or 'hr'.
323 | #         months: A list of months for each batch sample of length (B, ).
324 | #         transform_monthly: Whether to apply the transformation month-wise.
325 | #     Returns:
326 | #         Reverse-transformed candidates.
327 | #     """
328 | #     for i in range(candidates.shape[0]):
329 | #         candidates[i] = reverse_transform(candidates[i], transformations, variables,
330 | #                                           data_type, months, transform_monthly)
--------------------------------------------------------------------------------
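A quick usage sketch for the helpers in utils.py above (illustrative values only;
normal() min-max scales each field to [0, 1], so psnr() compares fields on a
common scale):

    import torch
    from utils import psnr, construct_mask

    x = torch.rand(2, 1, 8, 8)
    y = x + 0.01 * torch.randn_like(x)
    print(psnr(y, x))                                      # large dB value for near-identical fields
    print(construct_mask(torch.tensor([-3.0, 0.0, 5.0])))  # tensor([-1., 0., 1.])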