├── README.md ├── config ├── AFNO_afno_8_8.yaml └── AFNO_afno_8_8_finetune_next.yaml ├── engine_train.py ├── engine_train_next.py ├── inference.py ├── main_train.py ├── main_train_next.py ├── mainnet.py ├── mode_moe_next.py ├── model_moe.py ├── util ├── crop.py ├── datasets.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py ├── misc_finetune.py ├── misc_pre.py └── pos_embed.py └── utils ├── YParams.py ├── darcy_loss.py ├── data_loader_multifiles.py ├── data_loader_multifiles_precip.py ├── data_loader_multifiles_twoStep.py ├── date_time_to_hours.py ├── img_utils.py ├── logging_utils.py └── weighted_acc_rmse.py /README.md: -------------------------------------------------------------------------------- 1 | # EWMoE -------------------------------------------------------------------------------- /config/AFNO_afno_8_8.yaml: -------------------------------------------------------------------------------- 1 | ### base config ### 2 | full_field: &FULL_FIELD 3 | loss: 'l2' 4 | lr: 1E-3 5 | scheduler: 'ReduceLROnPlateau' # lichangyu todo 6 | num_data_workers: 4 # lichangyu todo 7 | dt: 1 # how many timesteps ahead the model will predict 8 | n_history: 0 #how many previous timesteps to consider 9 | prediction_type: 'iterative' 10 | prediction_length: 41 #applicable only if prediction_type == 'iterative' 11 | n_initial_conditions: 5 #applicable only if prediction_type == 'iterative' # lichangyu todo??? 12 | ics_type: "default" 13 | save_raw_forecasts: !!bool True 14 | save_channel: !!bool False 15 | masked_acc: !!bool False 16 | maskpath: None 17 | perturb: !!bool False 18 | add_grid: !!bool False 19 | N_grid_channels: 0 20 | gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear' 21 | roll: !!bool False 22 | max_epochs: 50 23 | batch_size: 64 24 | 25 | #afno hyperparams 26 | num_blocks: 4 27 | nettype: 'afno' 28 | patch_size: 8 29 | width: 56 # lichangyu todo ??? 30 | modes: 32 # lichangyu todo ??? 
31 | #options default, residual 32 | target: 'default' 33 | in_channels: [0,1] 34 | out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative' 35 | normalization: 'zscore' #options zscore (minmax not supported) 36 | train_data_path: '/pscratch/sd/j/jpathak/wind/train' 37 | valid_data_path: '/pscratch/sd/j/jpathak/wind/test' 38 | inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference 39 | exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind' 40 | time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy' 41 | global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy' 42 | global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy' 43 | 44 | orography: !!bool False 45 | orography_path: None 46 | 47 | log_to_screen: !!bool True 48 | log_to_wandb: !!bool True 49 | save_checkpoint: !!bool True 50 | 51 | enable_nhwc: !!bool False 52 | # optimizer_type: 'FusedAdam' 53 | optimizer_type: '' 54 | crop_size_x: None 55 | crop_size_y: None 56 | 57 | two_step_training: !!bool False 58 | plot_animations: !!bool False 59 | 60 | add_noise: !!bool False 61 | noise_std: 0 62 | 63 | afno_backbone: &backbone 64 | <<: *FULL_FIELD 65 | log_to_wandb: !!bool True 66 | lr: 5E-4 67 | # batch_size: 1 68 | max_epochs: 150 69 | scheduler: 'CosineAnnealingLR' 70 | in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 71 | out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 72 | orography: !!bool False 73 | orography_path: None 74 | exp_dir: '/home/manxin/glh/mae' 75 | 76 | 77 | train_data_path: '/home/manxin/codes/fcn/data/train' 78 | valid_data_path: '/home/manxin/codes/fcn/data/test' 79 | inf_data_path: '/home/manxin/codes/fcn/data/out_of_sample' 80 | time_means_path: '/home/manxin/codes/fcn/additional/time_means.npy' 81 | global_means_path: '/home/manxin/codes/fcn/additional/global_means.npy' 82 | global_stds_path: 
'/home/manxin/codes/fcn/additional/global_stds.npy' 83 | # lichangyu todo arguments for mae 84 | batch_size: 2 85 | epochs: 150 86 | accum_iter: 1 87 | model: 'mae_vit_base_patch16' 88 | input_size: [720, 1440] 89 | mask_ratio: 0.75 90 | norm_pix_loss: !!bool False # lichangyu todo ?? 91 | weight_decay: 0.05 92 | lr_new: None 93 | blr: 1e-3 94 | min_lr: 0. 95 | warmup_epochs: 40 96 | data_path: '' 97 | output_dir: '' 98 | # output_dir: '/home/lichangyu/codes/mae/mae-main/output_dir' 99 | log_dir: './output_dir' 100 | device: 'cuda' 101 | seed: 0 102 | start_epoch: 0 103 | num_workers: 0 104 | pin_mem: !!bool True 105 | no_pin_mem: !!bool False 106 | world_size: 1 107 | local_rank: -1 108 | rank: 0 109 | dist_on_itp: !!bool False 110 | dist_url: 'env://' 111 | train_data_path_h5: '/home/manxin/codes/fcn/data/train' 112 | n_in_channels: 20 113 | n_out_channels: 20 114 | run_num: '' 115 | run_mode: 'pretrain' 116 | save_dir: '' 117 | yaml_config: './config/AFNO_afno_8_8.yaml' 118 | config: 'afno_backbone' 119 | patch_size: [8, 8] # lichangyu todo aaa [16, 16] 120 | img_size: [720, 1440] 121 | iters: 0 # 122 | resuming: !!bool False 123 | checkpoint_path: '' 124 | pretrained_ckpt_path: '' 125 | 126 | afno_backbone_orography: &backbone_orography 127 | <<: *backbone 128 | orography: !!bool True # lichangyu todo ?? 
129 | orography_path: '/home/manxin/codes/fcn/data/static/orography.h5' 130 | 131 | afno_backbone_finetune: 132 | <<: *backbone 133 | lr: 1E-4 134 | batch_size: 1 135 | log_to_wandb: !!bool True 136 | max_epochs: 50 137 | pretrained: !!bool True 138 | two_step_training: !!bool True 139 | pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar' 140 | 141 | perturbations: 142 | <<: *backbone 143 | lr: 1E-4 144 | batch_size: 64 145 | max_epochs: 50 146 | pretrained: !!bool True 147 | two_step_training: !!bool True 148 | pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar' 149 | prediction_length: 24 150 | ics_type: "datetime" 151 | n_perturbations: 100 152 | save_channel: !!bool True 153 | save_idx: 4 154 | save_raw_forecasts: !!bool False 155 | date_strings: ["2018-01-01 00:00:00"] 156 | inference_file_tag: " " 157 | valid_data_path: "/pscratch/sd/j/jpathak/ " 158 | perturb: !!bool True 159 | n_level: 0.3 160 | 161 | ### PRECIP ### 162 | precip: &precip 163 | <<: *backbone 164 | in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 165 | out_channels: [0] 166 | nettype: 'afno' 167 | nettype_wind: 'afno' 168 | log_to_wandb: !!bool True 169 | lr: 2.5E-4 170 | batch_size: 1 171 | max_epochs: 25 172 | precip: '/home/lichangyu/codes/FourCastNet/data/FCN_ERA5_data_v0/precip' 173 | time_means_path_tp: '/home/lichangyu/codes/FourCastNet/additional/stats_v0/precip/time_means.npy' 174 | # precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation' 175 | # time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy' 176 | model_wind_path: '/home/lichangyu/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/results_ms_pre/era5_wind/afno_backbone_finetune/check_ft/training_checkpoints/best_ckpt.tar' 177 | precip_eps: !!float 1e-5 178 | 179 | 
#result/afno_backbone_finetune/0/training_checkpoints/best_ckpt: 180 | # tar: 181 | -------------------------------------------------------------------------------- /config/AFNO_afno_8_8_finetune_next.yaml: -------------------------------------------------------------------------------- 1 | ### base config ### 2 | full_field: &FULL_FIELD 3 | loss: 'l2' 4 | lr: 1E-3 5 | scheduler: 'ReduceLROnPlateau' # lichangyu todo 6 | num_data_workers: 4 # lichangyu todo 7 | dt: 1 # how many timesteps ahead the model will predict 8 | n_history: 0 #how many previous timesteps to consider 9 | prediction_type: 'iterative' 10 | prediction_length: 41 #applicable only if prediction_type == 'iterative' 11 | n_initial_conditions: 5 #applicable only if prediction_type == 'iterative' # lichangyu todo??? 12 | ics_type: "default" 13 | save_raw_forecasts: !!bool True 14 | save_channel: !!bool False 15 | masked_acc: !!bool False 16 | maskpath: None 17 | perturb: !!bool False 18 | add_grid: !!bool False 19 | N_grid_channels: 0 20 | gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear' 21 | roll: !!bool False 22 | max_epochs: 50 23 | batch_size: 64 24 | 25 | #afno hyperparams 26 | num_blocks: 4 27 | nettype: 'afno' 28 | patch_size: 8 29 | width: 56 # lichangyu todo ??? 30 | modes: 32 # lichangyu todo ??? 
31 | #options default, residual 32 | target: 'default' 33 | in_channels: [0,1] 34 | out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative' 35 | normalization: 'zscore' #options zscore (minmax not supported) 36 | train_data_path: '/pscratch/sd/j/jpathak/wind/train' 37 | valid_data_path: '/pscratch/sd/j/jpathak/wind/test' 38 | inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference 39 | exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind' 40 | time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy' 41 | global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy' 42 | global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy' 43 | 44 | orography: !!bool False 45 | orography_path: None 46 | 47 | log_to_screen: !!bool True 48 | log_to_wandb: !!bool True 49 | save_checkpoint: !!bool True 50 | 51 | enable_nhwc: !!bool False 52 | # optimizer_type: 'FusedAdam' 53 | optimizer_type: '' 54 | crop_size_x: None 55 | crop_size_y: None 56 | 57 | two_step_training: !!bool False 58 | plot_animations: !!bool False 59 | 60 | add_noise: !!bool False 61 | noise_std: 0 62 | 63 | afno_backbone: &backbone 64 | <<: *FULL_FIELD 65 | log_to_wandb: !!bool True 66 | lr: 5E-4 67 | # batch_size: 1 68 | max_epochs: 150 69 | scheduler: 'CosineAnnealingLR' 70 | in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 71 | out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 72 | orography: !!bool False 73 | orography_path: None 74 | exp_dir: '/home/manxin/glh/mae' 75 | 76 | 77 | train_data_path: '/home/manxin/codes/fcn/data/train' 78 | valid_data_path: '/home/manxin/codes/fcn/data/test' 79 | inf_data_path: '/home/manxin/codes/fcn/data/out_of_sample' 80 | time_means_path: '/home/manxin/codes/fcn/additional/time_means.npy' 81 | global_means_path: '/home/manxin/codes/fcn/additional/global_means.npy' 82 | global_stds_path: 
'/home/manxin/codes/fcn/additional/global_stds.npy' 83 | train_data_path_h5: '/home/manxin/codes/fcn/data/train' 84 | valid_data_path_h5: '/home/manxin/codes/fcn/data/test' 85 | # lichangyu todo arguments for mae 86 | batch_size: 2 87 | epochs: 50 88 | accum_iter: 1 89 | model: 'mae_vit_base_patch16' 90 | input_size: [720, 1440] 91 | mask_ratio: 0. 92 | norm_pix_loss: !!bool False # lichangyu todo ?? 93 | weight_decay: 0.05 94 | lr_new: None 95 | blr: 1e-3 96 | min_lr: 0. 97 | warmup_epochs: 40 98 | data_path: '' 99 | output_dir: '' 100 | # output_dir: '/home/lichangyu/codes/mae/mae-main/output_dir' 101 | log_dir: './output_dir' 102 | device: 'cuda' 103 | seed: 0 104 | start_epoch: 0 105 | num_workers: 0 106 | pin_mem: !!bool True 107 | no_pin_mem: !!bool False 108 | world_size: 1 109 | local_rank: -1 110 | rank: 0 111 | dist_on_itp: !!bool False 112 | dist_url: 'env://' 113 | n_in_channels: 20 114 | n_out_channels: 20 115 | run_num: '' 116 | run_mode: 'pretrain' 117 | save_dir: '' 118 | yaml_config: './config/AFNO_v5_8_8.yaml' 119 | config: 'afno_backbone' 120 | patch_size: [8, 8] # lichangyu todo aaa [16, 16] 121 | img_size: [720, 1440] 122 | iters: 0 # 123 | resuming: !!bool False 124 | checkpoint_path: '' 125 | pretrained_ckpt_path: '' 126 | 127 | afno_backbone_orography: &backbone_orography 128 | <<: *backbone 129 | orography: !!bool True # lichangyu todo ?? 
130 | orography_path: '/home/manxin/codes/fcn/data/static/orography.h5' 131 | 132 | afno_backbone_finetune: 133 | <<: *backbone 134 | lr: 1E-4 135 | batch_size: 1 136 | log_to_wandb: !!bool True 137 | max_epochs: 2 138 | pretrained: !!bool True 139 | two_step_training: !!bool True 140 | pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar' 141 | 142 | perturbations: 143 | <<: *backbone 144 | lr: 1E-4 145 | batch_size: 64 146 | max_epochs: 50 147 | pretrained: !!bool True 148 | two_step_training: !!bool True 149 | pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar' 150 | prediction_length: 24 151 | ics_type: "datetime" 152 | n_perturbations: 100 153 | save_channel: !!bool True 154 | save_idx: 4 155 | save_raw_forecasts: !!bool False 156 | date_strings: ["2018-01-01 00:00:00"] 157 | inference_file_tag: " " 158 | valid_data_path: "/pscratch/sd/j/jpathak/ " 159 | perturb: !!bool True 160 | n_level: 0.3 161 | 162 | ### PRECIP ### 163 | precip: &precip 164 | <<: *backbone 165 | in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 166 | out_channels: [0] 167 | nettype: 'afno' 168 | nettype_wind: 'afno' 169 | log_to_wandb: !!bool True 170 | lr: 2.5E-4 171 | batch_size: 1 172 | max_epochs: 25 173 | precip: '/home/lichangyu/codes/FourCastNet/data/FCN_ERA5_data_v0/precip' 174 | time_means_path_tp: '/home/lichangyu/codes/FourCastNet/additional/stats_v0/precip/time_means.npy' 175 | # precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation' 176 | # time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy' 177 | model_wind_path: '/home/lichangyu/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/results_ms_pre/era5_wind/afno_backbone_finetune/check_ft/training_checkpoints/best_ckpt.tar' 178 | precip_eps: !!float 1e-5 179 | 180 | 
#result/afno_backbone_finetune/0/training_checkpoints/best_ckpt: 181 | # tar: -------------------------------------------------------------------------------- /engine_train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from typing import Iterable 5 | 6 | import torch 7 | import util.misc_pre as misc 8 | import util.lr_sched as lr_sched 9 | from einops import rearrange, repeat 10 | from torchvision.utils import save_image 11 | from tqdm import tqdm 12 | 13 | def train_one_epoch(model: torch.nn.Module, 14 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 15 | device: torch.device, epoch: int, loss_scaler, 16 | log_writer=None, 17 | args=None): 18 | model.train(True) 19 | metric_logger = misc.MetricLogger(delimiter=" ") 20 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 21 | header = 'Epoch: [{}]'.format(epoch) 22 | print_freq = 20 23 | 24 | accum_iter = args.accum_iter 25 | optimizer.zero_grad() 26 | if log_writer is not None: 27 | print('log_dir: {}'.format(log_writer.log_dir)) 28 | sample_iter_step = 0 29 | # data_loader_bar = tqdm(data_loader) 30 | for data_iter_step, (samples) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 31 | 32 | # we use a per iteration (instead of per epoch) lr scheduler 33 | if data_iter_step % accum_iter == 0: 34 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 35 | # samples = torch.load('/home/manxin/codes/fcn/pt_file/' + str(sample_iter_step) + '.pt') 36 | sample_iter_step += 1 37 | #manxin todo 38 | if epoch == 0 and data_iter_step ==1: 39 | # if epoch == 0: 40 | # print("data_iter_step""""""""""""""") 41 | # print(data_iter_step) 42 | inp = samples[0] 43 | for iii in range(20): 44 | try: 45 | os.mkdir(args.save_dir + "/" + str(iii)) 46 | # 
os.mkdir("/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(iii)) 47 | except: 48 | pass 49 | save_image(inp[0][iii], args.save_dir + "/" + str(iii) +"/ori.png") 50 | samples[0] = samples[0].to(device, non_blocking=True) 51 | with torch.cuda.amp.autocast(): 52 | loss, pred, _ = model(samples[0], mask_ratio=args.mask_ratio) 53 | # print(pred.shape) 54 | pred = rearrang_v1(args, pred) 55 | # print(pred.shape) 56 | # if data_iter_step == 1 and (((epoch % 5) == 0) or (epoch == (args.epochs - 2))): 57 | if data_iter_step == 1: 58 | # if data_iter_step == 2 and epoch == 0: 59 | for jjj in range(20): 60 | save_image(pred[0][jjj], args.save_dir + "/" + str(jjj) + "/" + str(epoch)+ ".png") 61 | # save_image(pred[0][jjj], "/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(jjj) + "/" + str(epoch)+ ".png") 62 | loss_value = loss.item() 63 | 64 | if not math.isfinite(loss_value): 65 | print("Loss is {}, stopping training".format(loss_value)) 66 | sys.exit(1) 67 | 68 | loss /= accum_iter 69 | # # manxin todo 70 | # optimizer.zero_grad() 71 | # loss.backward() 72 | # optimizer.step() 73 | # manxin todo 74 | loss_scaler(loss, optimizer, parameters=model.parameters(), 75 | update_grad=(data_iter_step + 1) % accum_iter == 0) 76 | if (data_iter_step + 1) % accum_iter == 0: 77 | optimizer.zero_grad() 78 | 79 | torch.cuda.synchronize() 80 | 81 | metric_logger.update(loss=loss_value) 82 | 83 | lr = optimizer.param_groups[0]["lr"] 84 | metric_logger.update(lr=lr) 85 | 86 | loss_value_reduce = misc.all_reduce_mean(loss_value) 87 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 88 | """ We use epoch_1000x as the x-axis in tensorboard. 89 | This calibrates different curves when batch size changes. 
90 | """ 91 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 92 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 93 | log_writer.add_scalar('lr', lr, epoch_1000x) 94 | 95 | # gather the stats from all processes 96 | metric_logger.synchronize_between_processes() 97 | print("Averaged stats:", metric_logger) 98 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 99 | 100 | def rearrang_v1(args, x): 101 | # print("x.shape::::::::::::::::::::") 102 | # print(x.shape) 103 | # print(args.img_size) 104 | # print(args.input_size) 105 | # print("x.shape::::::::::::::::::::") 106 | B = x.shape[0] 107 | embed_dim = int(args.n_in_channels * (args.patch_size[0] * args.patch_size[1])) 108 | h = int(args.img_size[0] // (args.patch_size[0])) 109 | w = int(args.img_size[1] // (args.patch_size[1])) 110 | x = x.reshape(B, h, w, embed_dim) 111 | x = rearrange( 112 | x, 113 | "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)", 114 | h = h, 115 | w = w, 116 | p1=args.patch_size[0], 117 | p2=args.patch_size[1], 118 | ) 119 | return x -------------------------------------------------------------------------------- /engine_train_next.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from typing import Iterable 5 | 6 | import torch 7 | import util.misc_finetune as misc 8 | import util.lr_sched as lr_sched 9 | from einops import rearrange, repeat 10 | from torchvision.utils import save_image 11 | from tqdm import tqdm 12 | 13 | def train_one_epoch(model: torch.nn.Module, 14 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 15 | device: torch.device, epoch: int, loss_scaler, 16 | log_writer=None, 17 | args=None): 18 | model.train(True) 19 | metric_logger = misc.MetricLogger(delimiter=" ") 20 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 21 | header = 'Epoch: [{}]'.format(epoch) 22 | print_freq = 20 
23 | 24 | accum_iter = args.accum_iter 25 | optimizer.zero_grad() 26 | if log_writer is not None: 27 | print('log_dir: {}'.format(log_writer.log_dir)) 28 | sample_iter_step = 0 29 | # data_loader_bar = tqdm(data_loader) 30 | for data_iter_step, (samples) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 31 | 32 | # we use a per iteration (instead of per epoch) lr scheduler 33 | if data_iter_step % accum_iter == 0: 34 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 35 | sample_iter_step += 1 36 | #manxin todo 37 | if epoch == 0 and data_iter_step ==1: 38 | inp = samples[1] 39 | for iii in range(20): 40 | try: 41 | os.mkdir(args.save_dir + "/" + str(iii)) 42 | # os.mkdir("/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(iii)) 43 | except: 44 | pass 45 | save_image(inp[0][iii], args.save_dir + "/" + str(iii) +"/tar.png") 46 | samples[0] = samples[0].to(device, non_blocking=True) 47 | samples[1] = samples[1].to(device, non_blocking=True) 48 | 49 | with torch.cuda.amp.autocast(): 50 | loss, pred, _ = model(samples, mask_ratio=args.mask_ratio) 51 | if data_iter_step == 1: 52 | pred = rearrang_v1(args, pred) 53 | for jjj in range(20): 54 | #save_image(pred[0][jjj], args.save_dir + "/" + str(jjj) + "/" + str(epoch)+ ".png") 55 | print("111") 56 | loss_value = loss.item() 57 | 58 | if not math.isfinite(loss_value): 59 | print("Loss is {}, stopping training".format(loss_value)) 60 | sys.exit(1) 61 | 62 | loss /= accum_iter 63 | # # manxin todo 64 | # optimizer.zero_grad() 65 | # loss.backward() 66 | # optimizer.step() 67 | # manxin todo 68 | loss_scaler(loss, optimizer, parameters=model.parameters(), 69 | update_grad=(data_iter_step + 1) % accum_iter == 0) 70 | if (data_iter_step + 1) % accum_iter == 0: 71 | optimizer.zero_grad() 72 | 73 | torch.cuda.synchronize() 74 | 75 | metric_logger.update(loss=loss_value) 76 | 77 | lr = 
optimizer.param_groups[0]["lr"] 78 | metric_logger.update(lr=lr) 79 | 80 | loss_value_reduce = misc.all_reduce_mean(loss_value) 81 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 82 | """ We use epoch_1000x as the x-axis in tensorboard. 83 | This calibrates different curves when batch size changes. 84 | """ 85 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 86 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 87 | log_writer.add_scalar('lr', lr, epoch_1000x) 88 | 89 | # gather the stats from all processes 90 | metric_logger.synchronize_between_processes() 91 | print("Averaged stats:", metric_logger) 92 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 93 | 94 | def valid_one_epoch(model: torch.nn.Module, 95 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 96 | device: torch.device, epoch: int, loss_scaler, 97 | log_writer=None, 98 | args=None): 99 | model.eval() 100 | metric_logger = misc.MetricLogger(delimiter=" ") 101 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 102 | header = 'Epoch: [{}]'.format(epoch) 103 | print_freq =10000 104 | 105 | accum_iter = args.accum_iter 106 | if log_writer is not None: 107 | print('log_dir: {}'.format(log_writer.log_dir)) 108 | with torch.no_grad(): 109 | c_count = 0 110 | loss_value = 0. 
111 | data_bar = tqdm(data_loader) 112 | for data_iter_step, (samples) in enumerate(data_bar): 113 | # if data_iter_step == 5: 114 | # break 115 | # we use a per iteration (instead of per epoch) lr scheduler 116 | #manxin todo 117 | samples[0] = samples[0].to(device, non_blocking=True) 118 | samples[1] = samples[1].to(device, non_blocking=True) 119 | loss, pred, _ = model(samples, mask_ratio=args.mask_ratio) 120 | c_count += 1 121 | loss_value += loss.item() 122 | loss_value = loss_value / c_count 123 | 124 | # gather the stats from all processes 125 | print("Valid loss: {}".format(loss_value)) 126 | return {loss_value} 127 | 128 | def rearrang_v1(args, x): 129 | # print("x.shape::::::::::::::::::::") 130 | # print(x.shape) 131 | # print(args.img_size) 132 | # print(args.input_size) 133 | # print("x.shape::::::::::::::::::::") 134 | B = x.shape[0] 135 | embed_dim = int(args.n_in_channels * (args.patch_size[0] * args.patch_size[1])) 136 | h = int(args.img_size[0] // (args.patch_size[0])) 137 | w = int(args.img_size[1] // (args.patch_size[1])) 138 | x = x.reshape(B, h, w, embed_dim) 139 | x = rearrange( 140 | x, 141 | "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)", 142 | h = h, 143 | w = w, 144 | p1=args.patch_size[0], 145 | p2=args.patch_size[1], 146 | ) 147 | return x -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | # todo manxin need to delete when using distributed-train or command line to run 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0' 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | from 
torchvision.utils import save_image 17 | import timm 18 | from tqdm import tqdm 19 | 20 | assert timm.__version__ == "0.3.2" # version check 21 | import timm.optim.optim_factory as optim_factory 22 | 23 | import util.misc_pre as misc 24 | from util.misc_pre import NativeScalerWithGradNormCount as NativeScaler 25 | 26 | import models_mae_afno_8_8 27 | from engine_pretrain_afno_8_8 import train_one_epoch 28 | import torch.distributed as dist 29 | from torch.nn.parallel import DistributedDataParallel 30 | import sys 31 | 32 | # manxin todo for h5 file loading 33 | from utils.data_loader_multifiles import get_data_loader 34 | from utils.YParams import YParams 35 | from collections import OrderedDict 36 | from torchsummary import summary 37 | 38 | # manxin todo save log info 39 | class Logger(object): 40 | def __init__(self, logFile="Default.log"): 41 | self.terminal = sys.stdout 42 | self.log = open(logFile, 'a') 43 | 44 | def write(self, message): 45 | self.terminal.write(message) 46 | self.log.write(message) 47 | 48 | def flush(self): 49 | pass 50 | 51 | def get_args_parser(): 52 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 53 | # parser.add_argument('--batch_size', default=64, type=int, 54 | # help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 55 | # parser.add_argument('--epochs', default=800, type=int) 56 | parser.add_argument('--accum_iter', default=1, type=int, 57 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 58 | 59 | # Model parameters 60 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 61 | help='Name of model to train') 62 | 63 | parser.add_argument('--input_size', default=(720,1440), type=int, 64 | help='images input size') 65 | 66 | parser.add_argument('--mask_ratio', default=0.75, type=float, 67 | help='Masking ratio (percentage of removed patches).') 68 | 69 | 
parser.add_argument('--norm_pix_loss', action='store_true', 70 | help='Use (per-patch) normalized pixels as targets for computing loss') 71 | parser.set_defaults(norm_pix_loss=False) 72 | 73 | # Optimizer parameters 74 | parser.add_argument('--weight_decay', type=float, default=0.05, 75 | help='weight decay (default: 0.05)') 76 | 77 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 78 | help='learning rate (absolute lr)') 79 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 80 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 81 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 82 | help='lower lr bound for cyclic schedulers that hit 0') 83 | 84 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 85 | help='epochs to warmup LR') 86 | 87 | # Dataset parameters 88 | parser.add_argument('--output_dir', 89 | default='', 90 | help='path where to save, empty for no saving') 91 | parser.add_argument('--log_dir', default='./output_dir', 92 | help='path where to tensorboard log') 93 | parser.add_argument('--device', default='cuda', 94 | help='device to use for training / testing') 95 | parser.add_argument('--seed', default=0, type=int) 96 | # manxin todo dataloader to shuffle or seed or fixed (for re-implement) 97 | parser.add_argument('--resume', default='', 98 | help='resume from checkpoint') 99 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 100 | help='start epoch') 101 | parser.add_argument('--num_workers', default=1, type=int) 102 | parser.add_argument('--pin_mem', action='store_true', 103 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 104 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 105 | parser.set_defaults(pin_mem=True) 106 | # distributed training parameters 107 | parser.add_argument('--world_size', default=1, type=int, 108 | help='number of distributed 
processes') 109 | parser.add_argument('--local_rank', default=-1, type=int) 110 | parser.add_argument('--dist_on_itp', action='store_true') 111 | parser.add_argument('--dist_url', default='env://', 112 | help='url used to set up distributed training') 113 | 114 | # manxin load the h5 data file 115 | 116 | parser.add_argument('--dt', default=1, type=int, 117 | help='how many timesteps ahead the model will predict') 118 | parser.add_argument('--n_history', default=0, type=int, 119 | help='how many previous timesteps to consider') 120 | parser.add_argument('--in_channels', default=[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19], type=type([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19])) 121 | parser.add_argument('--out_channels', default=[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19], type=type([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19])) 122 | 123 | parser.add_argument('--n_in_channels', default=20, type=int) 124 | parser.add_argument('--n_out_channels', default=20, type=int) 125 | parser.add_argument('--crop_size_x', default=None, type=bool) 126 | parser.add_argument('--crop_size_y', default=None, type=bool) 127 | 128 | parser.add_argument('--roll', default=False, type=bool) 129 | parser.add_argument('--two_step_training', default=False, type=bool) 130 | parser.add_argument('--orography', default=False, type=bool) 131 | # parser.add_argument('--precip', default=False, type=bool) 132 | parser.add_argument('--num_data_workers', default=10, type=int) 133 | parser.add_argument('--normalization', default='zscore', type=str) 134 | parser.add_argument('--add_grid', default=False, type=bool) 135 | 136 | # manxin todo 137 | parser.add_argument('--run_num', default='run_1', type=str) 138 | parser.add_argument('--run_mode', default='pretraining', type=str) 139 | parser.add_argument('--save_dir', default='', type=str) 140 | parser.add_argument('--pretrained_ckpt_path', default='', 
type=str) 141 | 142 | parser.add_argument("--yaml_config", default='./config/AFNO_afno_8_8.yaml', type=str) 143 | parser.add_argument("--config", default='afno_backbone', type=str) 144 | 145 | return parser 146 | 147 | 148 | def main(args): 149 | # manxin todo alternative command 150 | print(args.batch_size) 151 | print(args.output_dir) 152 | misc.init_distributed_mode(args) 153 | 154 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 155 | print("{}".format(args).replace(', ', ',\n')) 156 | 157 | device = torch.device(args.device) 158 | # fix the seed for reproducibility 159 | seed = args.seed + misc.get_rank() 160 | torch.manual_seed(seed) 161 | np.random.seed(seed) 162 | cudnn.benchmark = True 163 | log_writer = None #manxin 164 | train_data_loader_h5, train_dataset_h5, train_sampler_h5 = get_data_loader(args, args.train_data_path_h5, 165 | args.distributed, 166 | train=True) 167 | 168 | # model = models_mae_afno_8_8.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, img_size= args.img_size, patch_size = args.patch_size) 169 | model = models_mae_afno_8_8.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, img_size= args.img_size, patch_size = args.patch_size) 170 | model.to(device) 171 | model_without_ddp = model 172 | print("Model = %s" % str(model_without_ddp)) 173 | 174 | summary(model,(20,720,1440)) 175 | 176 | # manxin 177 | print(str(args.save_dir) + "/" + "log.log") 178 | # sys.stdout = Logger(str(args.save_dir) + "/" + "log.log") 179 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 180 | if args.lr_new is None: # only base_lr is specified 181 | args.lr_new = args.blr * eff_batch_size / 256 182 | print("base lr: %.2e" % (args.lr_new * 256 / eff_batch_size), flush=True) 183 | print("actual lr: %.2e" % args.lr_new) 184 | print("accumulate grad iterations: %d" % args.accum_iter) 185 | print("effective batch size: %d" % eff_batch_size) 186 | 187 | if args.distributed: 188 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 189 | model_without_ddp = model.module 190 | 191 | # following timm: set wd as 0 for bias and norm layers 192 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 193 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr_new, betas=(0.9, 0.95)) 194 | loss_scaler = NativeScaler() 195 | print(optimizer) 196 | 197 | # misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 198 | # misc.load_model_v1(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, ckpt_path=args.checkpoint_path) 199 | 200 | if args.resuming: 201 | ckpt_path = args.checkpoint_path 202 | else: 203 | ckpt_path = args.pretrained_ckpt_path 204 | misc.load_model_v1(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, 205 | ckpt_path=ckpt_path) 206 | 207 | print(f"Start training for {args.epochs} epochs from epoch: {args.start_epoch}") 208 | start_time = time.time() 209 | 210 | for epoch in range(args.start_epoch, args.epochs): 211 | if args.distributed: 212 | train_data_loader_h5.sampler.set_epoch(epoch) 213 | print(type(train_data_loader_h5)) 214 | train_stats = train_one_epoch( 215 | model, train_data_loader_h5, 216 | optimizer, device, epoch, loss_scaler, 217 | log_writer=log_writer, 218 | args=args 219 | ) 220 | if args.save_dir: 221 | misc.save_model( 222 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 223 | loss_scaler=loss_scaler, epoch=epoch) 224 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 225 | 'epoch': epoch, } 226 | 227 | if args.save_dir and misc.is_main_process(): 228 | if log_writer is not None: 229 | log_writer.flush() 230 | with open(os.path.join(args.save_dir, "log.txt"), mode="a", encoding="utf-8") as f: 231 | f.write(json.dumps(log_stats) + "\n") 232 | 233 | 
total_time = time.time() - start_time 234 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 235 | print('Training time {}'.format(total_time_str)) 236 | if args.save_dir: 237 | if log_writer is not None: 238 | log_writer.flush() 239 | log_stats_time = {**{f'Training time': total_time_str}, 240 | 'epoch': epoch, } 241 | with open(os.path.join(args.save_dir, "log.txt"), mode="a", encoding="utf-8") as f: 242 | f.write(json.dumps(log_stats_time) + "\n") 243 | 244 | 245 | def restore_checkpoint(args, model, optimizer, loss_scaler, ckpt_path): 246 | """ We intentionally require a checkpoint_dir to be passed 247 | in order to allow Ray Tune to use this function """ 248 | checkpoint = torch.load(ckpt_path, map_location='cuda:{}'.format(misc.get_rank())) # manxin todo ??misc.get_rank() 249 | try: 250 | model.load_state_dict(checkpoint['model']) 251 | loss_scaler.load_state_dict(checkpoint['scaler']) 252 | except: 253 | new_state_dict = OrderedDict() 254 | for key, val in checkpoint['model'].items(): 255 | name = key[7:] 256 | new_state_dict[name] = val 257 | model.load_state_dict(new_state_dict) 258 | # args.iters = checkpoint['iters'] # manxin todo 259 | args.startEpoch = checkpoint['epoch'] 260 | if args.resuming: 261 | #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr. 
262 | optimizer.load_state_dict(checkpoint['optimizer']) 263 | return args, model, optimizer, loss_scaler 264 | 265 | if __name__ == '__main__': 266 | args = get_args_parser() 267 | args = args.parse_args() 268 | args.run_mode = 'pretraining' 269 | args.run_num = 'p_afno_8_8' # 'p_1' 270 | args.pretrained_ckpt_path = '' 271 | args.output_dir = '/home/manxin/glh/mae/output' 272 | 273 | if args.output_dir: 274 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 275 | try: 276 | os.makedirs(args.output_dir + "/" + str(args.run_num) + "/" + str(args.run_mode)) 277 | except: 278 | pass 279 | save_dir = args.output_dir + "/" + str(args.run_num) + "/" + str(args.run_mode) 280 | args.save_dir = save_dir 281 | 282 | params = YParams(os.path.abspath(args.yaml_config), args.config) 283 | params['output_dir'] = args.output_dir 284 | params['save_dir'] = args.save_dir 285 | params['run_num'] = args.run_num 286 | params['run_mode'] = args.run_mode 287 | 288 | if args.save_dir: 289 | try: 290 | os.makedirs(args.save_dir + "/" + "ckpt" ) 291 | except: 292 | pass 293 | 294 | params['checkpoint_folder'] = os.path.join(save_dir, 'ckpt') 295 | params['checkpoint_path'] = os.path.join(save_dir, 'ckpt/checkpoint-cur.pth') # manxin todo aaa 296 | 297 | print(params.checkpoint_path) 298 | params['resuming'] = True if os.path.isfile(params.checkpoint_path) else False 299 | params['pretrained_ckpt_path'] = args.pretrained_ckpt_path 300 | params['pretrained'] = True if os.path.isfile(params.pretrained_ckpt_path) else False 301 | # sys.stdout = Logger(str(args.save_dir) + "/" + "log.log") #manxin 302 | sys.stdout = Logger(logFile=str(args.save_dir) + "/" + "log.log") 303 | main(params) 304 | -------------------------------------------------------------------------------- /main_train_next.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | 
from pathlib import Path 8 | # todo manxin need to delete when using distributed-train or command line to run 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0' 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | from torchvision.utils import save_image 17 | import timm 18 | from tqdm import tqdm 19 | 20 | assert timm.__version__ == "0.3.2" # version check 21 | import timm.optim.optim_factory as optim_factory 22 | 23 | import util.misc_finetune as misc 24 | from util.misc_finetune import NativeScalerWithGradNormCount as NativeScaler 25 | 26 | import models_mae_afno_8_8_finetune_next 27 | from engine_finetune_afno_8_8_next import train_one_epoch, valid_one_epoch 28 | import torch.distributed as dist 29 | from torch.nn.parallel import DistributedDataParallel 30 | import sys 31 | 32 | # manxin todo for h5 file loading 33 | from utils.data_loader_multifiles import get_data_loader 34 | from utils.YParams import YParams 35 | from collections import OrderedDict 36 | from torchsummary import summary 37 | 38 | # manxin todo save log info 39 | class Logger(object): 40 | def __init__(self, logFile="Default.log"): 41 | self.terminal = sys.stdout 42 | self.log = open(logFile, 'a') 43 | 44 | def write(self, message): 45 | self.terminal.write(message) 46 | self.log.write(message) 47 | 48 | def flush(self): 49 | pass 50 | 51 | def get_args_parser(): 52 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 53 | # parser.add_argument('--batch_size', default=64, type=int, 54 | # help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 55 | # parser.add_argument('--epochs', default=800, type=int) 56 | parser.add_argument('--accum_iter', default=1, type=int, 57 | help='Accumulate gradient iterations (for increasing the effective batch size 
under memory constraints)') 58 | 59 | # Model parameters 60 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 61 | help='Name of model to train') 62 | 63 | parser.add_argument('--input_size', default=(720,1440), type=int, 64 | help='images input size') 65 | 66 | parser.add_argument('--mask_ratio', default=0.75, type=float, 67 | help='Masking ratio (percentage of removed patches).') 68 | 69 | parser.add_argument('--norm_pix_loss', action='store_true', 70 | help='Use (per-patch) normalized pixels as targets for computing loss') 71 | parser.set_defaults(norm_pix_loss=False) 72 | 73 | # Optimizer parameters 74 | parser.add_argument('--weight_decay', type=float, default=0.05, 75 | help='weight decay (default: 0.05)') 76 | 77 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 78 | help='learning rate (absolute lr)') 79 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 80 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 81 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 82 | help='lower lr bound for cyclic schedulers that hit 0') 83 | 84 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 85 | help='epochs to warmup LR') 86 | 87 | # Dataset parameters 88 | parser.add_argument('--output_dir', 89 | default='', 90 | help='path where to save, empty for no saving') 91 | parser.add_argument('--log_dir', default='./output_dir', 92 | help='path where to tensorboard log') 93 | parser.add_argument('--device', default='cuda', 94 | help='device to use for training / testing') 95 | parser.add_argument('--seed', default=0, type=int) 96 | # manxin todo dataloader to shuffle or seed or fixed (for re-implement) 97 | parser.add_argument('--resume', default='', 98 | help='resume from checkpoint') 99 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 100 | help='start epoch') 101 | 
parser.add_argument('--num_workers', default=1, type=int) 102 | parser.add_argument('--pin_mem', action='store_true', 103 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 104 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 105 | parser.set_defaults(pin_mem=True) 106 | # distributed training parameters 107 | parser.add_argument('--world_size', default=1, type=int, 108 | help='number of distributed processes') 109 | parser.add_argument('--local_rank', default=-1, type=int) 110 | parser.add_argument('--dist_on_itp', action='store_true') 111 | parser.add_argument('--dist_url', default='env://', 112 | help='url used to set up distributed training') 113 | 114 | # manxin load the h5 data file 115 | 116 | parser.add_argument('--dt', default=1, type=int, 117 | help='how many timesteps ahead the model will predict') 118 | parser.add_argument('--n_history', default=0, type=int, 119 | help='how many previous timesteps to consider') 120 | parser.add_argument('--in_channels', default=[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19], type=type([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19])) 121 | parser.add_argument('--out_channels', default=[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19], type=type([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19])) 122 | 123 | parser.add_argument('--n_in_channels', default=20, type=int) 124 | parser.add_argument('--n_out_channels', default=20, type=int) 125 | parser.add_argument('--crop_size_x', default=None, type=bool) 126 | parser.add_argument('--crop_size_y', default=None, type=bool) 127 | 128 | parser.add_argument('--roll', default=False, type=bool) 129 | parser.add_argument('--two_step_training', default=False, type=bool) 130 | parser.add_argument('--orography', default=False, type=bool) 131 | # parser.add_argument('--precip', default=False, type=bool) 132 | 
parser.add_argument('--num_data_workers', default=10, type=int) 133 | parser.add_argument('--normalization', default='zscore', type=str) 134 | parser.add_argument('--add_grid', default=False, type=bool) 135 | 136 | # manxin todo 137 | parser.add_argument('--run_num', default='', type=str) 138 | parser.add_argument('--run_mode', default='pretraining', type=str) 139 | parser.add_argument('--save_dir', default='', type=str) 140 | parser.add_argument('--pretrained_ckpt_path', default='', type=str) 141 | 142 | parser.add_argument("--yaml_config", default='', type=str) 143 | parser.add_argument("--config", default='afno_backbone', type=str) 144 | 145 | return parser 146 | 147 | 148 | def main(args): 149 | # manxin todo alternative command 150 | print(args.batch_size) 151 | print(args.output_dir) 152 | misc.init_distributed_mode(args) 153 | 154 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 155 | print("{}".format(args).replace(', ', ',\n')) 156 | 157 | device = torch.device(args.device) 158 | # fix the seed for reproducibility 159 | seed = args.seed + misc.get_rank() 160 | torch.manual_seed(seed) 161 | np.random.seed(seed) 162 | cudnn.benchmark = True 163 | log_writer = None #manxin 164 | train_data_loader_h5, train_dataset_h5, train_sampler_h5 = get_data_loader(args, args.train_data_path_h5, 165 | args.distributed, 166 | train=True) 167 | valid_data_loader_h5, valid_dataset_h5, valid_sampler_h5 = get_data_loader(args, args.valid_data_path_h5, 168 | args.distributed, 169 | train=True) 170 | 171 | # model = models_mae_afno_8_8_finetune.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, img_size= args.img_size, patch_size = args.patch_size) 172 | model = models_mae_afno_8_8_finetune_next.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, img_size= args.img_size, patch_size = args.patch_size) 173 | model.to(device) 174 | model_without_ddp = model 175 | print("Model = %s" % str(model_without_ddp)) 176 | 177 | #summary(model,(20,720,1440)) 
178 | 179 | # manxin 180 | print(str(args.save_dir) + "/" + "log.log") 181 | # sys.stdout = Logger(str(args.save_dir) + "/" + "log.log") 182 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 183 | if args.lr_new is None: # only base_lr is specified 184 | args.lr_new = args.blr * eff_batch_size / 256 185 | print("base lr: %.2e" % (args.lr_new * 256 / eff_batch_size), flush=True) 186 | print("actual lr: %.2e" % args.lr_new) 187 | print("accumulate grad iterations: %d" % args.accum_iter) 188 | print("effective batch size: %d" % eff_batch_size) 189 | 190 | if args.distributed: 191 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 192 | model_without_ddp = model.module 193 | 194 | # following timm: set wd as 0 for bias and norm layers 195 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 196 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr_new, betas=(0.9, 0.95)) 197 | loss_scaler = NativeScaler() 198 | print(optimizer) 199 | 200 | # misc._model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 201 | # misc.load_model_v1(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, ckpt_path=args.checkpoint_path) 202 | use_pretrained_model = False 203 | if args.resuming: 204 | ckpt_path = args.checkpoint_path 205 | else: 206 | ckpt_path = args.pretrained_ckpt_path 207 | args.resuming = True 208 | use_pretrained_model = True 209 | misc.load_model_v1(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, 210 | ckpt_path=ckpt_path) 211 | # if use_pretrained_model == True: 212 | # args.start_epoch = 0 213 | 214 | print(f"Start training for {args.epochs} epochs from epoch: {args.start_epoch}") 215 | start_time = time.time() 216 | best_loss = 20. 
217 | for epoch in range(args.start_epoch, args.epochs): 218 | if args.distributed: 219 | train_data_loader_h5.sampler.set_epoch(epoch) 220 | train_stats = train_one_epoch( 221 | model, train_data_loader_h5, 222 | optimizer, device, epoch, loss_scaler, 223 | log_writer=log_writer, 224 | args=args 225 | ) 226 | valid_loss = valid_one_epoch( 227 | model, valid_data_loader_h5, 228 | optimizer, device, epoch, loss_scaler, 229 | log_writer=log_writer, 230 | args=args 231 | ) 232 | print("valid_loss:".format(valid_loss)) 233 | vaild_loss_value = valid_loss.pop() 234 | if float(vaild_loss_value) < best_loss: 235 | best_loss = vaild_loss_value 236 | if args.save_dir: 237 | misc.save_model( 238 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 239 | loss_scaler=loss_scaler, epoch=epoch, best_loss_yn=True) 240 | # valid_loss_list = [] 241 | # valid_loss_list.append(best_loss) 242 | # valid_loss = valid_one_epoch( 243 | # model, valid_data_loader_h5, 244 | # optimizer, device, epoch, loss_scaler, 245 | # log_writer=log_writer, 246 | # args=args 247 | # ) 248 | # print("valid_loss:".format(valid_loss)) 249 | # if valid_loss < valid_loss_list[len(valid_loss_list)-1]: 250 | # if int(len(valid_loss_list))<=50: 251 | # valid_loss_list.append(valid_loss) 252 | # 253 | # elif (len(valid_loss_list))>50: 254 | # valid_loss_list[len(valid_loss_list) - 1] = valid_loss 255 | # valid_loss_list.sort() 256 | # best_loss_yn = True 257 | # if args.save_dir: 258 | # misc.save_model( 259 | # args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 260 | # loss_scaler=loss_scaler, epoch=epoch, best_loss_yn=True) 261 | 262 | 263 | if args.save_dir: 264 | misc.save_model( 265 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 266 | loss_scaler=loss_scaler, epoch=epoch) 267 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 268 | 'epoch': epoch, 'valid_loss': vaild_loss_value} 269 | 270 | if 
args.save_dir and misc.is_main_process(): 271 | if log_writer is not None: 272 | log_writer.flush() 273 | with open(os.path.join(args.save_dir, "log.txt"), mode="a", encoding="utf-8") as f: 274 | f.write(json.dumps(log_stats) + "\n") 275 | 276 | total_time = time.time() - start_time 277 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 278 | print('Training time {}'.format(total_time_str)) 279 | if args.save_dir: 280 | if log_writer is not None: 281 | log_writer.flush() 282 | log_stats_time = {**{f'Training time': total_time_str}, 283 | 'epoch': epoch, } 284 | with open(os.path.join(args.save_dir, "log.txt"), mode="a", encoding="utf-8") as f: 285 | f.write(json.dumps(log_stats_time) + "\n") 286 | 287 | 288 | def restore_checkpoint(args, model, optimizer, loss_scaler, ckpt_path): 289 | """ We intentionally require a checkpoint_dir to be passed 290 | in order to allow Ray Tune to use this function """ 291 | checkpoint = torch.load(ckpt_path, map_location='cuda:{}'.format(misc.get_rank())) # manxin todo ??misc.get_rank() 292 | try: 293 | model.load_state_dict(checkpoint['model']) 294 | loss_scaler.load_state_dict(checkpoint['scaler']) 295 | except: 296 | new_state_dict = OrderedDict() 297 | for key, val in checkpoint['model'].items(): 298 | name = key[7:] 299 | new_state_dict[name] = val 300 | model.load_state_dict(new_state_dict) 301 | # args.iters = checkpoint['iters'] # manxin todo 302 | args.startEpoch = checkpoint['epoch'] 303 | if args.resuming: 304 | #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr. 
305 | optimizer.load_state_dict(checkpoint['optimizer']) 306 | return args, model, optimizer, loss_scaler 307 | 308 | if __name__ == '__main__': 309 | args = get_args_parser() 310 | args = args.parse_args() 311 | args.run_mode = 'finetuning' 312 | args.run_num = 'f_afno_8_8_next_614' # 'p_1' 313 | args.pretrained_ckpt_path = '/home/manxin/glh/mae/output/p_afno_8_8/pretraining/ckpt/checkpoint-20.pth' # ckpt45 314 | args.output_dir = '/home/manxin/glh/mae/output' 315 | args.yaml_config = './config/AFNO_afno_8_8_finetune_next.yaml' 316 | 317 | if args.output_dir: 318 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 319 | try: 320 | os.makedirs(args.output_dir + "/" + str(args.run_num) + "/" + str(args.run_mode)) 321 | except: 322 | pass 323 | save_dir = args.output_dir + "/" + str(args.run_num) + "/" + str(args.run_mode) 324 | args.save_dir = save_dir 325 | 326 | params = YParams( 327 | os.path.abspath(args.yaml_config), args.config) 328 | params['output_dir'] = args.output_dir 329 | params['save_dir'] = args.save_dir 330 | params['run_num'] = args.run_num 331 | params['run_mode'] = args.run_mode 332 | 333 | if args.save_dir: 334 | try: 335 | os.makedirs(args.save_dir + "/" + "ckpt" ) 336 | except: 337 | pass 338 | 339 | params['checkpoint_folder'] = os.path.join(save_dir, 'ckpt') 340 | params['checkpoint_path'] = os.path.join(save_dir, 'ckpt/checkpoint-cur.pth') # manxin todo aaa 341 | 342 | print(params.checkpoint_path) 343 | params['resuming'] = True if os.path.isfile(params.checkpoint_path) else False 344 | params['pretrained_ckpt_path'] = args.pretrained_ckpt_path 345 | params['pretrained'] = True if os.path.isfile(params.pretrained_ckpt_path) else False 346 | # sys.stdout = Logger(str(args.save_dir) + "/" + "log.log") #manxin 347 | sys.stdout = Logger(logFile=str(args.save_dir) + "/" + "log.log") 348 | main(params) 349 | 350 | -------------------------------------------------------------------------------- /mode_moe_next.py: 
-------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # from timm.models.vision_transformer import PatchEmbed, Block 7 | from timm.models.vision_transformer import PatchEmbed # todo manxin 8 | 9 | from timm.models.vision_transformer import Mlp # todo manxin added 10 | from timm.models.layers import DropPath, trunc_normal_ # todo manxin added 11 | from util.pos_embed import get_2d_sincos_pos_embed 12 | from util.pos_embed import get_2d_sincos_pos_embed_v1 13 | 14 | from utils.img_utils import PeriodicPad2d 15 | from afnonet_8 import Block,Decoder_Block 16 | 17 | class MaskedAutoencoderViT(nn.Module): 18 | """ Masked Autoencoder with VisionTransformer backbone 19 | """ 20 | 21 | def __init__(self, img_size=(720,1440), patch_size=(8,8), in_chans=20, 22 | embed_dim=768, depth=12, num_heads=16, 23 | decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, 24 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 25 | super().__init__() 26 | # -------------------------------------------------------------------------- 27 | # MAE encoder specifics 28 | self.img_size =img_size 29 | self.patch_size = patch_size 30 | self.in_chans = in_chans 31 | self.embed_dim = embed_dim 32 | self.decoder_embed_dim = decoder_embed_dim 33 | self.patch_embed = PatchEmbed_v1(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 38 | requires_grad=False) # fixed sin-cos embedding 39 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 40 | 41 | # self.blocks = nn.ModuleList([ 42 | # Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 43 | # for i in 
range(depth)]) 44 | self.blocks = nn.ModuleList([ 45 | Block(dim = embed_dim, 46 | mlp_ratio = 4., 47 | drop=0., 48 | drop_path=0., 49 | act_layer = nn.GELU, 50 | norm_layer=norm_layer, 51 | double_skip = True, 52 | num_blocks = 8, 53 | sparsity_threshold = 0.01, 54 | hard_thresholding_fraction = 1.0) 55 | for i in range(depth)]) 56 | self.norm = norm_layer(embed_dim) 57 | # -------------------------------------------------------------------------- 58 | 59 | # -------------------------------------------------------------------------- 60 | # MAE decoder specifics 61 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 62 | 63 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 64 | # manxin todo no cls token 65 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), 66 | requires_grad=False) # fixed sin-cos embedding 67 | self.decoder_absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim)) 68 | # print(self.decoder_pos_embed.shape) 69 | # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), 70 | # requires_grad=False) # fixed sin-cos embedding 71 | 72 | # self.decoder_blocks = nn.ModuleList([ 73 | # decoder_Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 74 | # for i in range(decoder_depth)]) 75 | self.decoder_blocks = nn.ModuleList([ 76 | Decoder_Block(dim=decoder_embed_dim, 77 | mlp_ratio=4., 78 | drop=0., 79 | drop_path=0., 80 | act_layer=nn.GELU, 81 | norm_layer=norm_layer, 82 | double_skip=True, 83 | num_blocks=8, 84 | sparsity_threshold=0.01, 85 | hard_thresholding_fraction=1.0) 86 | for i in range(decoder_depth)]) 87 | 88 | self.decoder_norm = norm_layer(decoder_embed_dim) 89 | self.decoder_pred = nn.Linear(decoder_embed_dim, self.in_chans * self.patch_size[0] * self.patch_size[1], bias=True) # decoder to patch 90 | # manxin todo 91 | # self.head = 
nn.Linear(decoder_embed_dim, self.in_chans * self.patch_size[0] * self.patch_size[1], bias=False) 92 | # -------------------------------------------------------------------------- 93 | 94 | self.norm_pix_loss = norm_pix_loss 95 | 96 | self.initialize_weights() 97 | # todo manxin rewrite pos_emb initial method 98 | def initialize_weights(self): 99 | # initialization 100 | # initialize (and freeze) pos_embed by sin-cos embedding 101 | pos_embed = get_2d_sincos_pos_embed_v1(self.pos_embed.shape[-1], ([self.img_size, self.patch_size]), 102 | cls_token=True) 103 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 104 | 105 | # manxin todo no cls token 106 | decoder_pos_embed = get_2d_sincos_pos_embed_v1(self.decoder_pos_embed.shape[-1], 107 | ([self.img_size, self.patch_size]), cls_token=False) 108 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 109 | 110 | # decoder_pos_embed = get_2d_sincos_pos_embed_v1(self.decoder_pos_embed.shape[-1], 111 | # ([self.img_size, self.patch_size]), cls_token=True) 112 | # self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 113 | 114 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 115 | w = self.patch_embed.proj.weight.data 116 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 117 | 118 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
119 | torch.nn.init.normal_(self.cls_token, std=.02) 120 | torch.nn.init.normal_(self.mask_token, std=.02) 121 | torch.nn.init.normal_(self.absolute_pos_embed, std=.02) 122 | torch.nn.init.normal_(self.decoder_absolute_pos_embed,std=.02) 123 | 124 | # initialize nn.Linear and nn.LayerNorm 125 | self.apply(self._init_weights) 126 | 127 | def _init_weights(self, m): 128 | if isinstance(m, nn.Linear): 129 | # we use xavier_uniform following official JAX ViT: 130 | torch.nn.init.xavier_uniform_(m.weight) 131 | if isinstance(m, nn.Linear) and m.bias is not None: 132 | nn.init.constant_(m.bias, 0) 133 | elif isinstance(m, nn.LayerNorm): 134 | nn.init.constant_(m.bias, 0) 135 | nn.init.constant_(m.weight, 1.0) 136 | 137 | def patchify(self, imgs): 138 | """ 139 | imgs: (N, 20, H, W) 140 | x: (N, L, patch_size**2 *20) 141 | """ 142 | p = self.patch_embed.patch_size[0] 143 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 144 | 145 | h = imgs.shape[2] // p 146 | w = imgs.shape[3] // p 147 | x = imgs.reshape(shape=(imgs.shape[0],20, h, p, w, p)) 148 | x = torch.einsum('nchpwq->nhwpqc', x) 149 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 20)) 150 | return x 151 | 152 | def unpatchify(self, x): 153 | """ 154 | x: (N, L, patch_size**2 *3) 155 | imgs: (N, 3, H, W) 156 | """ 157 | p = self.patch_embed.patch_size[0] 158 | h = w = int(x.shape[1] ** .5) 159 | assert h * w == x.shape[1] 160 | 161 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 20)) 162 | x = torch.einsum('nhwpqc->nchpwq', x) 163 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 164 | return imgs 165 | 166 | def random_masking(self, x, mask_ratio): 167 | """ 168 | Perform per-sample random masking by per-sample shuffling. 169 | Per-sample shuffling is done by argsort random noise. 
170 | x: [N, L, D], sequence 171 | """ 172 | N, L, D = x.shape # batch, length, dim 173 | len_keep = int(L * (1 - mask_ratio)) 174 | 175 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 176 | 177 | # sort noise for each sample 178 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 179 | ids_restore = torch.argsort(ids_shuffle, dim=1) 180 | 181 | # keep the first subset 182 | ids_keep = ids_shuffle[:, :len_keep] 183 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 184 | 185 | # generate the binary mask: 0 is keep, 1 is remove 186 | mask = torch.ones([N, L], device=x.device) 187 | mask[:, :len_keep] = 0 188 | # unshuffle to get the binary mask 189 | mask = torch.gather(mask, dim=1, index=ids_restore) 190 | 191 | return x_masked, mask, ids_restore 192 | 193 | def forward_encoder(self, x, mask_ratio): 194 | # embed patches 195 | x = self.patch_embed(x) 196 | 197 | # add pos embed w/o cls token 198 | x = x + self.pos_embed[:, 1:, :] + self.absolute_pos_embed 199 | 200 | # masking: length -> length * mask_ratio 201 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 202 | 203 | # append cls token 204 | # cls_token = self.cls_token + self.pos_embed[:, :1, :] 205 | # cls_tokens = cls_token.expand(x.shape[0], -1, -1) 206 | # x = torch.cat((cls_tokens, x), dim=1) 207 | if mask_ratio == 0.75: 208 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0] // 2), self.img_size[1] // self.patch_size[1] // 2, self.embed_dim) 209 | elif mask_ratio == 0.: 210 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0]), self.img_size[1] // self.patch_size[1], self.embed_dim) 211 | else: 212 | exit -2 213 | # apply Transformer blocks 214 | for blk in self.blocks: 215 | x = blk(x) 216 | # x = self.norm(x)s 217 | x = x.reshape(shape=( 218 | x.shape[0], (int(x.shape[1]) * int(x.shape[2])), 219 | x.shape[3])) 220 | return x, mask, ids_restore 221 | 222 | def 
forward_decoder(self, x, ids_restore): 223 | # embed tokens 224 | x = self.decoder_embed(x)+self.decoder_absolute_pos_embed 225 | 226 | # append mask tokens to sequence 227 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 228 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 229 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 230 | # manxin todo no cls token 231 | # x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 232 | 233 | # add pos embed 234 | x = x_ + self.decoder_pos_embed 235 | 236 | # manxin todo no cls token 237 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0]), self.img_size[1] // self.patch_size[1], self.decoder_embed_dim) 238 | # [1, 16200, 512] ==> [1, 90, 180, 512] 239 | # manxin todo no cls token no block && simple nn layer 240 | # apply Transformer blocks [1, 90, 180, 512] (90*180 = 16200) 241 | for blk in self.decoder_blocks: 242 | x = blk(x) 243 | x = self.decoder_norm(x) 244 | 245 | # predictor projection 246 | x = self.decoder_pred(x) 247 | 248 | # manxin todo no cls token 249 | # # remove cls token 250 | # x = x[:, 1:, :] 251 | # # x = self.head(x) 252 | # manxin todo no cls token 253 | x = x.reshape(shape=( 254 | x.shape[0], (int(self.img_size[0] // self.patch_size[0]))* int(self.img_size[1] // self.patch_size[1]), 255 | x.shape[3])) 256 | return x 257 | 258 | def forward_loss(self, imgs, pred, mask, mask_ratio): 259 | """ 260 | imgs: [N, 3, H, W] 261 | pred: [N, L, p*p*3] 262 | mask: [N, L], 0 is keep, 1 is remove, 263 | """ 264 | target = self.patchify(imgs) 265 | if self.norm_pix_loss: 266 | mean = target.mean(dim=-1, keepdim=True) 267 | var = target.var(dim=-1, keepdim=True) 268 | target = (target - mean) / (var + 1.e-6) ** .5 269 | loss = (pred - target) ** 2 270 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 271 | if mask_ratio == 0.75: 272 | loss = (loss * mask).sum() / mask.sum() # 
mean loss on removed patches 273 | elif mask_ratio == 0.: 274 | mask = torch.ones_like(loss) 275 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 276 | else: 277 | exit -3 278 | 279 | return loss 280 | 281 | def forward(self, imgs, mask_ratio=0.75): 282 | latent, mask, ids_restore = self.forward_encoder(imgs[0], mask_ratio) 283 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 284 | loss = self.forward_loss(imgs[1], pred, mask, mask_ratio) 285 | return loss, pred, mask 286 | 287 | class PatchEmbed_v1(nn.Module): 288 | def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=20, embed_dim=768): 289 | super().__init__() 290 | self.img_size = img_size 291 | self.patch_size = patch_size 292 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 293 | 294 | self.num_patches = num_patches 295 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 296 | 297 | def forward(self, x): 298 | # print(x) 299 | B, C, H, W = x.shape 300 | assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
301 | x = self.proj(x) 302 | # print("projx=====================") 303 | # print(x) 304 | x = x.flatten(2) 305 | # print("flatenx=====================") 306 | # print(x) 307 | x = x.transpose(1, 2) 308 | # print("transx=====================") 309 | # print(x) 310 | return x 311 | 312 | # todo manxin added 313 | 314 | class Attention_v1(nn.Module): 315 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 316 | super().__init__() 317 | self.num_heads = num_heads 318 | self.window_size = (2,6,12) 319 | head_dim = dim // num_heads 320 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 321 | self.scale = qk_scale or head_dim ** -0.5 322 | 323 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 324 | self.attn_drop = nn.Dropout(attn_drop) 325 | self.proj = nn.Linear(dim, dim) 326 | self.proj_drop = nn.Dropout(proj_drop) 327 | 328 | 329 | def _construct_index(self): 330 | coords_zi = torch.range(0, self.window_size[0]) 331 | coords_zj = -torch.range(0, self.window_size[0]) * self.window_size[0] 332 | 333 | coords_hi = torch.range(0, self.window_size[1]) 334 | coords_hj = -torch.range(0, self.window_size[1]) * self.window_size[1] 335 | 336 | coords_w = torch.range(0, self.window_size[2]) 337 | 338 | coords_1 = torch.stack(torch.meshgrid([coords_zi, coords_hi, coords_w])) 339 | coords_2 = torch.stack(torch.meshgrid([coords_zj, coords_hj, coords_w])) 340 | 341 | coords_flatten_1 = torch.flatten(coords_1, start_dim=1) 342 | coords_flatten_2 = torch.flatten(coords_2, start_dim=1) 343 | coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] 344 | coords = coords.permute((1, 2, 0)) 345 | 346 | coords[:, :, 2] += self.window_size[2] - 1 347 | coords[:, :, 1] *= 2 * self.window_size[2] - 1 348 | coords[:, :, 0] *= (2 * self.window_size[2] - 1) * self.window_size[1] * self.window_size[1] 349 | 350 | position_index = torch.sum(coords, dim=-1) 351 | position_index = 
torch.flatten(position_index) 352 | 353 | def forward(self, x): 354 | B, H, W, C = x.shape 355 | qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) 356 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 357 | 358 | print(q.shape) 359 | print(k.shape) 360 | print(v.shape) 361 | 362 | attn = (q @ k.transpose(-2, -1)) * self.scale 363 | print(attn.shape) 364 | 365 | attn = attn.softmax(dim=-1) 366 | print(attn.shape) 367 | 368 | attn = self.attn_drop(attn) 369 | print(attn.shape) 370 | 371 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) 372 | print(x.shape) 373 | 374 | x = self.proj(x) 375 | print(x.shape) 376 | 377 | x = self.proj_drop(x) 378 | print(x.shape) 379 | 380 | return x 381 | 382 | # def decoder_Block 383 | class decoder_Block(nn.Module): 384 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 385 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 386 | super().__init__() 387 | self.norm1 = norm_layer(dim) 388 | # self.filter = Attention_v1(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) 389 | self.attn = Attention_v1( 390 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 391 | 392 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 393 | #self.drop_path = nn.Identity() 394 | self.norm2 = norm_layer(dim) 395 | mlp_hidden_dim = int(dim * mlp_ratio) 396 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 397 | # self.double_skip = double_skip 398 | 399 | def forward(self, x): 400 | residual = x 401 | x = self.norm1(x) 402 | x = self.attn(x) 403 | # todo manxin 404 | # if self.double_skip: 405 | x = x + residual 406 | residual = x 407 | 408 | x = self.norm2(x) 409 | x = self.mlp(x) 410 | x = self.drop_path(x) 411 | x = x + residual 412 | return x 413 | 414 | def mae_vit_base_patch16_dec512d8b(**kwargs): 415 | model = MaskedAutoencoderViT( 416 | embed_dim=768, depth=6, num_heads=12, 417 | decoder_embed_dim=512, decoder_depth=6, decoder_num_heads=8, 418 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 419 | return model 420 | # model = MaskedAutoencoderViT( 421 | # img_size=kwargs.img_size, patch_size=kwargs.patch_size, embed_dim=768, depth=12, num_heads=12, 422 | # decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 423 | # mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 424 | # return model 425 | 426 | def mae_vit_large_patch16_dec512d8b(**kwargs): 427 | model = MaskedAutoencoderViT( 428 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 429 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 430 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 431 | return model 432 | 433 | 434 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 435 | model = MaskedAutoencoderViT( 436 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 437 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 438 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 439 | return model 440 | 441 | 442 | # set recommended archs 443 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 444 | mae_vit_large_patch16 = 
mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 445 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 446 | -------------------------------------------------------------------------------- /model_moe.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # from timm.models.vision_transformer import PatchEmbed, Block 7 | from timm.models.vision_transformer import PatchEmbed # todo manxin 8 | 9 | from timm.models.vision_transformer import Mlp # todo manxin added 10 | from timm.models.layers import DropPath, trunc_normal_ # todo manxin added 11 | from util.pos_embed import get_2d_sincos_pos_embed 12 | from util.pos_embed import get_2d_sincos_pos_embed_v1 13 | 14 | from utils.img_utils import PeriodicPad2d 15 | from afnonet_8 import Block,Decoder_Block 16 | import random 17 | 18 | class MaskedAutoencoderViT(nn.Module): 19 | """ Masked Autoencoder with VisionTransformer backbone 20 | """ 21 | 22 | def __init__(self, img_size=(720,1440), patch_size=(8,8), in_chans=20, 23 | embed_dim=768, depth=12, num_heads=16, 24 | decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, 25 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 26 | super().__init__() 27 | # -------------------------------------------------------------------------- 28 | # MAE encoder specifics 29 | self.img_size =img_size 30 | self.patch_size = patch_size 31 | self.in_chans = in_chans 32 | self.embed_dim = embed_dim 33 | self.decoder_embed_dim = decoder_embed_dim 34 | self.patch_embed = PatchEmbed_v1(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim) 35 | num_patches = self.patch_embed.num_patches 36 | 37 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 38 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 39 | requires_grad=False) # fixed 
sin-cos embedding 40 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 41 | 42 | # self.blocks = nn.ModuleList([ 43 | # Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 44 | # for i in range(depth)]) 45 | self.blocks = nn.ModuleList([ 46 | Block(dim = embed_dim, 47 | mlp_ratio = 4., 48 | drop=0., 49 | drop_path=0., 50 | act_layer = nn.GELU, 51 | norm_layer=norm_layer, 52 | double_skip = True, 53 | num_blocks = 8, 54 | sparsity_threshold = 0.01, 55 | hard_thresholding_fraction = 1.0) 56 | for i in range(depth)]) 57 | self.norm = norm_layer(embed_dim) 58 | # -------------------------------------------------------------------------- 59 | 60 | # -------------------------------------------------------------------------- 61 | # MAE decoder specifics 62 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 63 | 64 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 65 | # manxin todo no cls token 66 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), 67 | requires_grad=False) # fixed sin-cos embedding 68 | self.decoder_absolute_pos_embed = nn.Parameter(torch.zeros(1,num_patches,decoder_embed_dim)) 69 | 70 | # print(self.decoder_pos_embed.shape) 71 | # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), 72 | # requires_grad=False) # fixed sin-cos embedding 73 | 74 | # self.decoder_blocks = nn.ModuleList([ 75 | # decoder_Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 76 | # for i in range(decoder_depth)]) 77 | self.decoder_blocks = nn.ModuleList([ 78 | Decoder_Block(dim=decoder_embed_dim, 79 | mlp_ratio=4., 80 | drop=0., 81 | drop_path=0., 82 | act_layer=nn.GELU, 83 | norm_layer=norm_layer, 84 | double_skip=True, 85 | num_blocks=8, 86 | sparsity_threshold=0.01, 87 | hard_thresholding_fraction=1.0) 88 | for i in 
range(decoder_depth)]) 89 | 90 | self.decoder_norm = norm_layer(decoder_embed_dim) 91 | self.decoder_pred = nn.Linear(decoder_embed_dim, self.in_chans * self.patch_size[0] * self.patch_size[1], bias=True) # decoder to patch 92 | # manxin todo 93 | # self.head = nn.Linear(decoder_embed_dim, self.in_chans * self.patch_size[0] * self.patch_size[1], bias=False) 94 | # -------------------------------------------------------------------------- 95 | 96 | 97 | self.norm_pix_loss = norm_pix_loss 98 | 99 | self.initialize_weights() 100 | # todo manxin rewrite pos_emb initial method 101 | def initialize_weights(self): 102 | # initialization 103 | # initialize (and freeze) pos_embed by sin-cos embedding 104 | pos_embed = get_2d_sincos_pos_embed_v1(self.pos_embed.shape[-1], ([self.img_size, self.patch_size]), 105 | cls_token=True) 106 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 107 | 108 | # manxin todo no cls token 109 | decoder_pos_embed = get_2d_sincos_pos_embed_v1(self.decoder_pos_embed.shape[-1], 110 | ([self.img_size, self.patch_size]), cls_token=False) 111 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 112 | 113 | # decoder_pos_embed = get_2d_sincos_pos_embed_v1(self.decoder_pos_embed.shape[-1], 114 | # ([self.img_size, self.patch_size]), cls_token=True) 115 | # self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 116 | 117 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 118 | w = self.patch_embed.proj.weight.data 119 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 120 | 121 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
122 | torch.nn.init.normal_(self.cls_token, std=.02) 123 | torch.nn.init.normal_(self.mask_token, std=.02) 124 | torch.nn.init.normal_(self.absolute_pos_embed, std=.02) 125 | torch.nn.init.normal_(self.decoder_absolute_pos_embed,std=.02) 126 | 127 | # initialize nn.Linear and nn.LayerNorm 128 | self.apply(self._init_weights) 129 | 130 | def _init_weights(self, m): 131 | if isinstance(m, nn.Linear): 132 | # we use xavier_uniform following official JAX ViT: 133 | torch.nn.init.xavier_uniform_(m.weight) 134 | if isinstance(m, nn.Linear) and m.bias is not None: 135 | nn.init.constant_(m.bias, 0) 136 | elif isinstance(m, nn.LayerNorm): 137 | nn.init.constant_(m.bias, 0) 138 | nn.init.constant_(m.weight, 1.0) 139 | 140 | def patchify(self, imgs): 141 | """ 142 | imgs: (N, 20, H, W) 143 | x: (N, L, patch_size**2 *20) 144 | """ 145 | p = self.patch_embed.patch_size[0] 146 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 147 | 148 | h = imgs.shape[2] // p 149 | w = imgs.shape[3] // p 150 | x = imgs.reshape(shape=(imgs.shape[0],20, h, p, w, p)) 151 | x = torch.einsum('nchpwq->nhwpqc', x) 152 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 20)) 153 | return x 154 | 155 | def unpatchify(self, x): 156 | """ 157 | x: (N, L, patch_size**2 *3) 158 | imgs: (N, 3, H, W) 159 | """ 160 | p = self.patch_embed.patch_size[0] 161 | h = w = int(x.shape[1] ** .5) 162 | assert h * w == x.shape[1] 163 | 164 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 20)) 165 | x = torch.einsum('nhwpqc->nchpwq', x) 166 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 167 | return imgs 168 | 169 | def random_masking(self, x, mask_ratio): 170 | """ 171 | Perform per-sample random masking by per-sample shuffling. 172 | Per-sample shuffling is done by argsort random noise. 
173 | x: [N, L, D], sequence 174 | """ 175 | N, L, D = x.shape # batch, length, dim 176 | 177 | #print(N) #4 178 | #print(L) #16200 179 | #print(D) #768 180 | 181 | len_keep = int(L * (1 - mask_ratio)) #4050 182 | 183 | #print(len_keep) 184 | 185 | noise = torch.rand(N, L, device=x.device) + len_keep # noise in [0, 1] 186 | 187 | #print(noise.shape) 188 | 189 | for i in range(N): 190 | left, right = 0, 3 191 | target = 0 192 | for j in range(len_keep): 193 | index = random.randint(left, right) 194 | noise[i, index] = target 195 | target += 1 196 | left += 4 197 | right += 4 198 | 199 | # sort noise for each sample 200 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 201 | ids_restore = torch.argsort(ids_shuffle, dim=1) 202 | 203 | # keep the first subset 204 | ids_keep = ids_shuffle[:, :len_keep] #torch.Size([4, 4050]) 205 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 206 | 207 | # generate the binary mask: 0 is keep, 1 is remove 208 | mask = torch.ones([N, L], device=x.device) 209 | mask[:, :len_keep] = 0 210 | # unshuffle to get the binary mask 211 | mask = torch.gather(mask, dim=1, index=ids_restore) 212 | 213 | return x_masked, mask, ids_restore 214 | 215 | def forward_encoder(self, x, mask_ratio): 216 | # embed patches 217 | 218 | 219 | x = self.patch_embed(x) 220 | 221 | 222 | # add pos embed w/o cls token 223 | 224 | x = x + self.pos_embed[:, 1:, :] + self.absolute_pos_embed 225 | 226 | 227 | 228 | # masking: length -> length * mask_ratio 229 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 230 | 231 | 232 | # append cls token 233 | # cls_token = self.cls_token + self.pos_embed[:, :1, :] 234 | # cls_tokens = cls_token.expand(x.shape[0], -1, -1) 235 | # x = torch.cat((cls_tokens, x), dim=1) 236 | if mask_ratio == 0.75: 237 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0] // 2), self.img_size[1] // self.patch_size[1] // 2, self.embed_dim) 238 | elif 
mask_ratio == 0.: 239 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0]), self.img_size[1] // self.patch_size[1], self.embed_dim) 240 | else: 241 | exit -2 242 | # apply Transformer blocks 243 | for blk in self.blocks: 244 | x = blk(x) 245 | 246 | 247 | # x = self.norm(x)s 248 | x = x.reshape(shape=( 249 | x.shape[0], (int(x.shape[1]) * int(x.shape[2])), 250 | x.shape[3])) 251 | return x, mask, ids_restore 252 | 253 | def forward_decoder(self, x, ids_restore): 254 | # embed tokens 255 | x = self.decoder_embed(x) 256 | 257 | # append mask tokens to sequence 258 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 259 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 260 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 261 | # manxin todo no cls token 262 | # x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 263 | 264 | # add pos embed 265 | x = x_ + self.decoder_pos_embed+self.decoder_absolute_pos_embed 266 | 267 | # manxin todo no cls token 268 | x = x.reshape(x.shape[0], int(self.img_size[0] // self.patch_size[0]), self.img_size[1] // self.patch_size[1], self.decoder_embed_dim) 269 | # [1, 16200, 512] ==> [1, 90, 180, 512] 270 | # manxin todo no cls token no block && simple nn layer 271 | # apply Transformer blocks [1, 90, 180, 512] (90*180 = 16200) 272 | for blk in self.decoder_blocks: 273 | x = blk(x) 274 | x = self.decoder_norm(x) 275 | 276 | # predictor projection 277 | x = self.decoder_pred(x) 278 | 279 | # manxin todo no cls token 280 | # # remove cls token 281 | # x = x[:, 1:, :] 282 | # # x = self.head(x) 283 | # manxin todo no cls token 284 | x = x.reshape(shape=( 285 | x.shape[0], (int(self.img_size[0] // self.patch_size[0]))* int(self.img_size[1] // self.patch_size[1]), 286 | x.shape[3])) 287 | return x 288 | 289 | def forward_loss(self, imgs, pred, mask, mask_ratio): 290 | """ 291 | imgs: [N, 3, H, W] 292 | pred: 
[N, L, p*p*3] 293 | mask: [N, L], 0 is keep, 1 is remove, 294 | """ 295 | target = self.patchify(imgs) 296 | if self.norm_pix_loss: 297 | mean = target.mean(dim=-1, keepdim=True) 298 | var = target.var(dim=-1, keepdim=True) 299 | target = (target - mean) / (var + 1.e-6) ** .5 300 | loss = (pred - target) ** 2 301 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 302 | if mask_ratio == 0.75: 303 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 304 | elif mask_ratio == 0.: 305 | mask = torch.ones_like(loss) 306 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 307 | else: 308 | exit -3 309 | 310 | return loss 311 | 312 | def forward(self, imgs, mask_ratio=0.75): 313 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 314 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 315 | loss = self.forward_loss(imgs, pred, mask, mask_ratio) 316 | return loss, pred, mask 317 | 318 | class PatchEmbed_v1(nn.Module): 319 | def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=20, embed_dim=768): 320 | super().__init__() 321 | self.img_size = img_size 322 | self.patch_size = patch_size 323 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 324 | 325 | self.num_patches = num_patches 326 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 327 | 328 | def forward(self, x): 329 | #print(x.shape) 330 | B, C, H, W = x.shape 331 | assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
332 | x = self.proj(x) 333 | # print("projx=====================") 334 | # print(x) 335 | x = x.flatten(2) 336 | # print("flatenx=====================") 337 | # print(x) 338 | x = x.transpose(1, 2) 339 | # print("transx=====================") 340 | # print(x) 341 | return x 342 | 343 | # todo manxin added 344 | 345 | class Attention_v1(nn.Module): 346 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 347 | super().__init__() 348 | self.num_heads = num_heads 349 | self.window_size = (2,6,12) 350 | 351 | head_dim = dim // num_heads 352 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 353 | self.scale = qk_scale or head_dim ** -0.5 354 | 355 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 356 | self.attn_drop = nn.Dropout(attn_drop) 357 | self.proj = nn.Linear(dim, dim) 358 | self.proj_drop = nn.Dropout(proj_drop) 359 | 360 | 361 | def _construct_index(self): 362 | coords_zi = torch.range(0, self.window_size[0]) 363 | coords_zj = -torch.range(0, self.window_size[0]) * self.window_size[0] 364 | 365 | coords_hi = torch.range(0, self.window_size[1]) 366 | coords_hj = -torch.range(0, self.window_size[1]) * self.window_size[1] 367 | 368 | coords_w = torch.range(0, self.window_size[2]) 369 | 370 | coords_1 = torch.stack(torch.meshgrid([coords_zi, coords_hi, coords_w])) 371 | coords_2 = torch.stack(torch.meshgrid([coords_zj, coords_hj, coords_w])) 372 | 373 | coords_flatten_1 = torch.flatten(coords_1, start_dim=1) 374 | coords_flatten_2 = torch.flatten(coords_2, start_dim=1) 375 | coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] 376 | coords = coords.permute((1, 2, 0)) 377 | 378 | coords[:, :, 2] += self.window_size[2] - 1 379 | coords[:, :, 1] *= 2 * self.window_size[2] - 1 380 | coords[:, :, 0] *= (2 * self.window_size[2] - 1) * self.window_size[1] * self.window_size[1] 381 | 382 | position_index = torch.sum(coords, dim=-1) 383 | position_index = 
torch.flatten(position_index) 384 | 385 | def forward(self, x): 386 | B, H, W, C = x.shape 387 | qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) 388 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 389 | 390 | print(q.shape) 391 | print(k.shape) 392 | print(v.shape) 393 | 394 | attn = (q @ k.transpose(-2, -1)) * self.scale 395 | print(attn.shape) 396 | 397 | attn = attn.softmax(dim=-1) 398 | print(attn.shape) 399 | 400 | attn = self.attn_drop(attn) 401 | print(attn.shape) 402 | 403 | x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) 404 | print(x.shape) 405 | 406 | x = self.proj(x) 407 | print(x.shape) 408 | 409 | x = self.proj_drop(x) 410 | print(x.shape) 411 | 412 | return x 413 | 414 | # def decoder_Block 415 | class decoder_Block(nn.Module): 416 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 417 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 418 | super().__init__() 419 | self.norm1 = norm_layer(dim) 420 | # self.filter = Attention_v1(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) 421 | self.attn = Attention_v1( 422 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 423 | 424 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 425 | #self.drop_path = nn.Identity() 426 | self.norm2 = norm_layer(dim) 427 | mlp_hidden_dim = int(dim * mlp_ratio) 428 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 429 | # self.double_skip = double_skip 430 | 431 | def forward(self, x): 432 | residual = x 433 | x = self.norm1(x) 434 | x = self.attn(x) 435 | # todo manxin 436 | # if self.double_skip: 437 | x = x + residual 438 | residual = x 439 | 440 | x = self.norm2(x) 441 | x = self.mlp(x) 442 | x = self.drop_path(x) 443 | x = x + residual 444 | return x 445 | 446 | def mae_vit_base_patch16_dec512d8b(**kwargs): 447 | model = MaskedAutoencoderViT( 448 | embed_dim=768, depth=6, num_heads=4, 449 | decoder_embed_dim=512, decoder_depth=6, decoder_num_heads=8, 450 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 451 | return model 452 | # model = MaskedAutoencoderViT( 453 | # img_size=kwargs.img_size, patch_size=kwargs.patch_size, embed_dim=768, depth=12, num_heads=12, 454 | # decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 455 | # mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 456 | # return model 457 | 458 | def mae_vit_large_patch16_dec512d8b(**kwargs): 459 | model = MaskedAutoencoderViT( 460 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 461 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 462 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 463 | return model 464 | 465 | 466 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 467 | model = MaskedAutoencoderViT( 468 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 469 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 470 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 471 | return model 472 | 473 | 474 | # set recommended archs 475 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 476 | mae_vit_large_patch16 = 
mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 477 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 478 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w 43 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    """Build an ImageFolder dataset rooted under ``args.data_path``.

    Images are read from ``<args.data_path>/train`` when ``is_train`` is true,
    otherwise from ``<args.data_path>/val``; each sample goes through the
    transform returned by ``build_transform(is_train, args)``. The dataset
    summary is printed before it is returned.
    """
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    """Return the training or evaluation image transform.

    Training: timm ``create_transform`` with bicubic interpolation and the
    augmentation options read from ``args`` (``color_jitter``, ``aa``,
    ``reprob``, ``remode``, ``recount``).
    Evaluation: bicubic resize to ``int(input_size / crop_pct)`` (crop_pct is
    224/256 for inputs <= 224, else 1.0), center crop to ``args.input_size``,
    convert to tensor and normalize. Both paths use the ImageNet default
    mean/std from timm.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        # InterpolationMode instead of the PIL integer constant: a bare
        # `import PIL` does not load the `PIL.Image` submodule, and torchvision
        # deprecates integer interpolation values. Same bicubic resampling.
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)

# --------------------------------------------------------------------------------
# /util/lars.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as accepted by ``torch.optim.Optimizer.step``. It is
                called with gradients enabled before the update is applied.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given (the original behaviour).
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure must be able to
            # build a graph and call backward().
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio: trust_coefficient * ||p|| / ||dp||, falling
                    # back to 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                # Momentum buffer: mu = momentum * mu + dp; then p -= lr * mu.
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])

        return loss
-------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs +0.00055 13 | else: 14 | lr = 0.00055 + args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | from collections import OrderedDict 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 | type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 
124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / it)'.format( 168 | header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | 
builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = int(os.environ['LOCAL_RANK']) 231 | elif 'SLURM_PROCID' in os.environ: 232 | args.rank = int(os.environ['SLURM_PROCID']) 233 | args.gpu = args.rank % torch.cuda.device_count() 234 | else: 235 | print('Not using distributed mode') 236 | setup_for_distributed(is_master=True) # hack 237 | args.distributed = False 238 | return 239 | 240 | args.distributed = True 241 | 242 | torch.cuda.set_device(args.gpu) 243 | args.dist_backend = 'nccl' 244 | print('| distributed init (rank {}): 
{}, gpu {}'.format( 245 | args.rank, args.dist_url, args.gpu), flush=True) 246 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 247 | world_size=args.world_size, rank=args.rank) 248 | torch.distributed.barrier() 249 | setup_for_distributed(args.rank == 0) 250 | 251 | 252 | class NativeScalerWithGradNormCount: 253 | state_dict_key = "amp_scaler" 254 | 255 | def __init__(self): 256 | self._scaler = torch.cuda.amp.GradScaler() 257 | 258 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 259 | self._scaler.scale(loss).backward(create_graph=create_graph) 260 | if update_grad: 261 | if clip_grad is not None: 262 | assert parameters is not None 263 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 264 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 265 | else: 266 | self._scaler.unscale_(optimizer) 267 | norm = get_grad_norm_(parameters) 268 | self._scaler.step(optimizer) 269 | self._scaler.update() 270 | else: 271 | norm = None 272 | return norm 273 | 274 | def state_dict(self): 275 | return self._scaler.state_dict() 276 | 277 | def load_state_dict(self, state_dict): 278 | self._scaler.load_state_dict(state_dict) 279 | 280 | 281 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 282 | if isinstance(parameters, torch.Tensor): 283 | parameters = [parameters] 284 | parameters = [p for p in parameters if p.grad is not None] 285 | norm_type = float(norm_type) 286 | if len(parameters) == 0: 287 | return torch.tensor(0.) 
288 | device = parameters[0].grad.device 289 | if norm_type == inf: 290 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 291 | else: 292 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 293 | return total_norm 294 | 295 | 296 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 297 | try: 298 | os.mkdir(str(args.save_dir) + "/" + "ckpt") 299 | # os.mkdir("/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(iii)) 300 | except: 301 | pass 302 | ckpt_dir = Path(str(args.save_dir) + "/" + "ckpt") 303 | # output_dir = Path(args.output_dir) 304 | epoch_name = str(epoch) 305 | if loss_scaler is not None: 306 | checkpoint_paths = [ckpt_dir / ('checkpoint-%s.pth' % epoch_name)] 307 | for checkpoint_path in checkpoint_paths: 308 | to_save = { 309 | 'model': model_without_ddp.state_dict(), 310 | 'optimizer': optimizer.state_dict(), 311 | 'epoch': epoch, 312 | 'scaler': loss_scaler.state_dict(), 313 | 'args': args, 314 | } 315 | 316 | save_on_master(to_save, checkpoint_path) 317 | else: 318 | client_state = {'epoch': epoch} 319 | model.save_checkpoint(save_dir=ckpt_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 320 | 321 | def load_model_v1(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 322 | if os.path.isfile(ckpt_path): 323 | if ckpt_path.startswith('https'): 324 | checkpoint = torch.hub.load_state_dict_from_url( 325 | ckpt_path, map_location='cpu', check_hash=True) 326 | else: 327 | try: 328 | checkpoint = torch.load(ckpt_path, map_location='cpu') 329 | except: 330 | pass 331 | # new_state_dict = OrderedDict() 332 | # for key, val in checkpoint['model'].items(): 333 | # # name = key[7:] 334 | # name = key 335 | # a = name.split('_', 1) 336 | # if a[0] == 'mask' or a[0] == 'decoder': 337 | # continue 338 | # else: 339 | # new_state_dict[name] = val 340 | 341 | 
# model_without_ddp.load_state_dict(new_state_dict, strict=False) 342 | model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 343 | print("Resume checkpoint %s" % ckpt_path) 344 | if args.resuming: 345 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 346 | optimizer.load_state_dict(checkpoint['optimizer']) 347 | args.start_epoch = checkpoint['epoch'] + 1 348 | if 'scaler' in checkpoint: 349 | loss_scaler.load_state_dict(checkpoint['scaler']) 350 | print("With optim & sched!") 351 | 352 | def load_model_v2_keep_decoder(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 353 | if os.path.isfile(ckpt_path): 354 | if ckpt_path.startswith('https'): 355 | checkpoint = torch.hub.load_state_dict_from_url( 356 | ckpt_path, map_location='cpu', check_hash=True) 357 | else: 358 | checkpoint = torch.load(ckpt_path, map_location='cpu') 359 | # new_state_dict = OrderedDict() 360 | # for key, val in checkpoint['model'].items(): 361 | # # name = key[7:] 362 | # name = key 363 | # a = name.split('_', 1) 364 | # if a[0] == 'mask' or a[0] == 'decoder': 365 | # continue 366 | # else: 367 | # new_state_dict[name] = val 368 | 369 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 370 | # model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 371 | print("Resume checkpoint %s" % ckpt_path) 372 | if args.resuming: 373 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 374 | optimizer.load_state_dict(checkpoint['optimizer']) 375 | args.start_epoch = checkpoint['epoch'] + 1 376 | if 'scaler' in checkpoint: 377 | loss_scaler.load_state_dict(checkpoint['scaler']) 378 | print("With optim & sched!") 379 | 380 | 381 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 382 | if args.resuming: 383 | # todo manxin changed 384 | if args.resuming.startswith('https'): 385 | checkpoint = torch.hub.load_state_dict_from_url( 
386 | args.resuming, map_location='cpu', check_hash=True) 387 | else: 388 | checkpoint = torch.load(args.resuming, map_location='cpu') 389 | model_without_ddp.load_state_dict(checkpoint['model']) 390 | print("Resume checkpoint %s" % args.resuming) 391 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 392 | optimizer.load_state_dict(checkpoint['optimizer']) 393 | args.start_epoch = checkpoint['epoch'] + 1 394 | if 'scaler' in checkpoint: 395 | loss_scaler.load_state_dict(checkpoint['scaler']) 396 | print("With optim & sched!") 397 | 398 | 399 | def all_reduce_mean(x): 400 | world_size = get_world_size() 401 | if world_size > 1: 402 | x_reduce = torch.tensor(x).cuda() 403 | dist.all_reduce(x_reduce) 404 | x_reduce /= world_size 405 | return x_reduce.item() 406 | else: 407 | return x -------------------------------------------------------------------------------- /util/misc_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | from collections import OrderedDict 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 
28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 
| type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / 
it)'.format( 168 | header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = 
int(os.environ['LOCAL_RANK']) 231 | elif 'SLURM_PROCID' in os.environ: 232 | args.rank = int(os.environ['SLURM_PROCID']) 233 | args.gpu = args.rank % torch.cuda.device_count() 234 | else: 235 | print('Not using distributed mode') 236 | setup_for_distributed(is_master=True) # hack 237 | args.distributed = False 238 | return 239 | 240 | args.distributed = True 241 | 242 | torch.cuda.set_device(args.gpu) 243 | args.dist_backend = 'nccl' 244 | print('| distributed init (rank {}): {}, gpu {}'.format( 245 | args.rank, args.dist_url, args.gpu), flush=True) 246 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 247 | world_size=args.world_size, rank=args.rank) 248 | torch.distributed.barrier() 249 | setup_for_distributed(args.rank == 0) 250 | 251 | 252 | class NativeScalerWithGradNormCount: 253 | state_dict_key = "amp_scaler" 254 | 255 | def __init__(self): 256 | self._scaler = torch.cuda.amp.GradScaler() 257 | 258 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 259 | self._scaler.scale(loss).backward(create_graph=create_graph) 260 | if update_grad: 261 | if clip_grad is not None: 262 | assert parameters is not None 263 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 264 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 265 | else: 266 | self._scaler.unscale_(optimizer) 267 | norm = get_grad_norm_(parameters) 268 | self._scaler.step(optimizer) 269 | self._scaler.update() 270 | else: 271 | norm = None 272 | return norm 273 | 274 | def state_dict(self): 275 | return self._scaler.state_dict() 276 | 277 | def load_state_dict(self, state_dict): 278 | self._scaler.load_state_dict(state_dict) 279 | 280 | 281 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 282 | if isinstance(parameters, torch.Tensor): 283 | parameters = [parameters] 284 | parameters = [p for p in parameters if p.grad is 
not None] 285 | norm_type = float(norm_type) 286 | if len(parameters) == 0: 287 | return torch.tensor(0.) 288 | device = parameters[0].grad.device 289 | if norm_type == inf: 290 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 291 | else: 292 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 293 | return total_norm 294 | 295 | 296 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, best_loss_yn=False): 297 | try: 298 | os.mkdir(str(args.save_dir) + "/" + "ckpt") 299 | # os.mkdir("/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(iii)) 300 | except: 301 | pass 302 | ckpt_dir = Path(str(args.save_dir) + "/" + "ckpt") 303 | # output_dir = Path(args.output_dir) 304 | 305 | if best_loss_yn == True: 306 | epoch_name = str('best') 307 | elif epoch % 1 == 0: 308 | epoch_name = str(epoch) 309 | else: 310 | epoch_name = str('cur') 311 | if loss_scaler is not None: 312 | checkpoint_paths = [ckpt_dir / ('checkpoint-%s.pth' % epoch_name)] 313 | for checkpoint_path in checkpoint_paths: 314 | to_save = { 315 | 'model': model_without_ddp.state_dict(), 316 | 'optimizer': optimizer.state_dict(), 317 | 'epoch': epoch, 318 | 'scaler': loss_scaler.state_dict(), 319 | 'args': args, 320 | } 321 | 322 | save_on_master(to_save, checkpoint_path) 323 | else: 324 | client_state = {'epoch': epoch} 325 | model.save_checkpoint(save_dir=ckpt_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 326 | 327 | def load_model_v1(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 328 | if os.path.isfile(ckpt_path): 329 | if ckpt_path.startswith('https'): 330 | checkpoint = torch.hub.load_state_dict_from_url( 331 | ckpt_path, map_location='cpu', check_hash=True) 332 | else: 333 | try: 334 | checkpoint = torch.load(ckpt_path, map_location='cpu') 335 | except: 336 | pass 337 | 
model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 338 | print("Resume checkpoint %s" % ckpt_path) 339 | if args.resuming: 340 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 341 | optimizer.load_state_dict(checkpoint['optimizer']) 342 | args.start_epoch = checkpoint['epoch'] + 1 343 | if 'scaler' in checkpoint: 344 | loss_scaler.load_state_dict(checkpoint['scaler']) 345 | print("With optim & sched!") 346 | 347 | def load_model_v2_keep_decoder(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 348 | if os.path.isfile(ckpt_path): 349 | if ckpt_path.startswith('https'): 350 | checkpoint = torch.hub.load_state_dict_from_url( 351 | ckpt_path, map_location='cpu', check_hash=True) 352 | else: 353 | checkpoint = torch.load(ckpt_path, map_location='cpu') 354 | # new_state_dict = OrderedDict() 355 | # for key, val in checkpoint['model'].items(): 356 | # # name = key[7:] 357 | # name = key 358 | # a = name.split('_', 1) 359 | # if a[0] == 'mask' or a[0] == 'decoder': 360 | # continue 361 | # else: 362 | # new_state_dict[name] = val 363 | 364 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 365 | # model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 366 | print("Resume checkpoint %s" % ckpt_path) 367 | if args.resuming: 368 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 369 | optimizer.load_state_dict(checkpoint['optimizer']) 370 | args.start_epoch = checkpoint['epoch'] + 1 371 | if 'scaler' in checkpoint: 372 | loss_scaler.load_state_dict(checkpoint['scaler']) 373 | print("With optim & sched!") 374 | 375 | 376 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 377 | if args.resuming: 378 | # todo manxin changed 379 | if args.resuming.startswith('https'): 380 | checkpoint = torch.hub.load_state_dict_from_url( 381 | args.resuming, map_location='cpu', check_hash=True) 382 | else: 
383 | checkpoint = torch.load(args.resuming, map_location='cpu') 384 | model_without_ddp.load_state_dict(checkpoint['model']) 385 | print("Resume checkpoint %s" % args.resuming) 386 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 387 | optimizer.load_state_dict(checkpoint['optimizer']) 388 | args.start_epoch = checkpoint['epoch'] + 1 389 | if 'scaler' in checkpoint: 390 | loss_scaler.load_state_dict(checkpoint['scaler']) 391 | print("With optim & sched!") 392 | 393 | 394 | def all_reduce_mean(x): 395 | world_size = get_world_size() 396 | if world_size > 1: 397 | x_reduce = torch.tensor(x).cuda() 398 | dist.all_reduce(x_reduce) 399 | x_reduce /= world_size 400 | return x_reduce.item() 401 | else: 402 | return x -------------------------------------------------------------------------------- /util/misc_pre.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | from collections import OrderedDict 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 
28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 
| type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / 
it)'.format( 168 | header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = 
int(os.environ['LOCAL_RANK']) 231 | elif 'SLURM_PROCID' in os.environ: 232 | args.rank = int(os.environ['SLURM_PROCID']) 233 | args.gpu = args.rank % torch.cuda.device_count() 234 | else: 235 | print('Not using distributed mode') 236 | setup_for_distributed(is_master=True) # hack 237 | args.distributed = False 238 | return 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | class NativeScalerWithGradNormCount_precip: 280 | state_dict_key = "amp_scaler" 281 | 282 | def __init__(self): 283 | self._scaler = torch.cuda.amp.GradScaler() 284 | 285 | def __call__(self, loss, optimizer, 
clip_grad=None, parameters=None, create_graph=False, update_grad=True): 286 | self._scaler.scale(loss).backward(create_graph=create_graph) 287 | if update_grad: 288 | if clip_grad is not None: 289 | assert parameters is not None 290 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 291 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 292 | else: 293 | self._scaler.unscale_(optimizer) 294 | norm = get_grad_norm_(parameters) 295 | self._scaler.step(optimizer) 296 | self._scaler.update() 297 | else: 298 | norm = None 299 | return norm 300 | 301 | def state_dict(self): 302 | return self._scaler.state_dict() 303 | 304 | def load_state_dict(self, state_dict): 305 | self._scaler.load_state_dict(state_dict) 306 | 307 | 308 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 309 | if isinstance(parameters, torch.Tensor): 310 | parameters = [parameters] 311 | parameters = [p for p in parameters if p.grad is not None] 312 | norm_type = float(norm_type) 313 | if len(parameters) == 0: 314 | return torch.tensor(0.) 
315 | device = parameters[0].grad.device 316 | if norm_type == inf: 317 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 318 | else: 319 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 320 | return total_norm 321 | 322 | 323 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, best_loss_yn=False): 324 | try: 325 | os.mkdir(str(args.save_dir) + "/" + "ckpt") 326 | # os.mkdir("/home/manxin/codes/FourCastNet/NVLabs-FourCastNet/FourCastNet-master/mae/mae-main/output_dir" + "/" + str(iii)) 327 | except: 328 | pass 329 | ckpt_dir = Path(str(args.save_dir) + "/" + "ckpt") 330 | # output_dir = Path(args.output_dir) 331 | if epoch % 1 == 0: 332 | epoch_name = str(epoch) 333 | else: 334 | epoch_name = str('cur') 335 | if loss_scaler is not None: 336 | checkpoint_paths = [ckpt_dir / ('checkpoint-%s.pth' % epoch_name)] 337 | for checkpoint_path in checkpoint_paths: 338 | to_save = { 339 | 'model': model_without_ddp.state_dict(), 340 | 'optimizer': optimizer.state_dict(), 341 | 'epoch': epoch, 342 | 'scaler': loss_scaler.state_dict(), 343 | 'args': args, 344 | } 345 | 346 | save_on_master(to_save, checkpoint_path) 347 | else: 348 | client_state = {'epoch': epoch} 349 | model.save_checkpoint(save_dir=ckpt_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 350 | 351 | def load_model_v1(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 352 | if os.path.isfile(ckpt_path): 353 | if ckpt_path.startswith('https'): 354 | checkpoint = torch.hub.load_state_dict_from_url( 355 | ckpt_path, map_location='cpu', check_hash=True) 356 | else: 357 | try: 358 | checkpoint = torch.load(ckpt_path, map_location='cpu') 359 | except: 360 | pass 361 | # new_state_dict = OrderedDict() 362 | # for key, val in checkpoint['model'].items(): 363 | # # name = key[7:] 364 | # name = key 365 | # a = name.split('_', 1) 366 | # if a[0] == 'mask' or a[0] == 
'decoder': 367 | # continue 368 | # else: 369 | # new_state_dict[name] = val 370 | 371 | # model_without_ddp.load_state_dict(new_state_dict, strict=False) 372 | model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 373 | print("Resume checkpoint %s" % ckpt_path) 374 | if args.resuming: 375 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 376 | optimizer.load_state_dict(checkpoint['optimizer']) 377 | args.start_epoch = checkpoint['epoch'] + 1 378 | if 'scaler' in checkpoint: 379 | loss_scaler.load_state_dict(checkpoint['scaler']) 380 | print("With optim & sched!") 381 | 382 | def load_model_v2_keep_decoder(args, model_without_ddp, optimizer, loss_scaler, ckpt_path): 383 | if os.path.isfile(ckpt_path): 384 | if ckpt_path.startswith('https'): 385 | checkpoint = torch.hub.load_state_dict_from_url( 386 | ckpt_path, map_location='cpu', check_hash=True) 387 | else: 388 | checkpoint = torch.load(ckpt_path, map_location='cpu') 389 | # new_state_dict = OrderedDict() 390 | # for key, val in checkpoint['model'].items(): 391 | # # name = key[7:] 392 | # name = key 393 | # a = name.split('_', 1) 394 | # if a[0] == 'mask' or a[0] == 'decoder': 395 | # continue 396 | # else: 397 | # new_state_dict[name] = val 398 | 399 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 400 | # model_without_ddp.load_state_dict(checkpoint['model'], strict = False) 401 | print("Resume checkpoint %s" % ckpt_path) 402 | if args.resuming: 403 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 404 | optimizer.load_state_dict(checkpoint['optimizer']) 405 | args.start_epoch = checkpoint['epoch'] + 1 406 | if 'scaler' in checkpoint: 407 | loss_scaler.load_state_dict(checkpoint['scaler']) 408 | print("With optim & sched!") 409 | 410 | 411 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 412 | if args.resuming: 413 | # todo manxin changed 414 | if 
args.resuming.startswith('https'): 415 | checkpoint = torch.hub.load_state_dict_from_url( 416 | args.resuming, map_location='cpu', check_hash=True) 417 | else: 418 | checkpoint = torch.load(args.resuming, map_location='cpu') 419 | model_without_ddp.load_state_dict(checkpoint['model']) 420 | print("Resume checkpoint %s" % args.resuming) 421 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 422 | optimizer.load_state_dict(checkpoint['optimizer']) 423 | args.start_epoch = checkpoint['epoch'] + 1 424 | if 'scaler' in checkpoint: 425 | loss_scaler.load_state_dict(checkpoint['scaler']) 426 | print("With optim & sched!") 427 | 428 | 429 | def all_reduce_mean(x): 430 | world_size = get_world_size() 431 | if world_size > 1: 432 | x_reduce = torch.tensor(x).cuda() 433 | dist.all_reduce(x_reduce) 434 | x_reduce /= world_size 435 | return x_reduce.item() 436 | else: 437 | return x -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | def get_2d_sincos_pos_embed_v1(embed_dim, list, cls_token=False): 38 | """ 39 | grid_size: int of the grid height and width 40 | return: 41 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 42 | """ 43 | grid_h = np.arange(list[0][0]//list[1][0], dtype=np.float32) 44 | grid_w = np.arange(list[0][1]//list[1][1], dtype=np.float32) 45 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 46 | grid = np.stack(grid, axis=0) 47 | 48 | grid = grid.reshape([2, 1, int(list[0][0]//list[1][0]), int(list[0][1]//list[1][1])]) 49 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 50 | if cls_token: 51 | pos_embed = 
np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 52 | return pos_embed 53 | 54 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 55 | assert embed_dim % 2 == 0 56 | 57 | # use half of dimensions to encode grid_h 58 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 59 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 60 | 61 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 62 | return emb 63 | 64 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 65 | """ 66 | embed_dim: output dimension for each position 67 | pos: a list of positions to be encoded: size (M,) 68 | out: (M, D) 69 | """ 70 | assert embed_dim % 2 == 0 71 | omega = np.arange(embed_dim // 2, dtype=np.float) 72 | omega /= embed_dim / 2. 73 | omega = 1. / 10000**omega # (D/2,) 74 | 75 | pos = pos.reshape(-1) # (M,) 76 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 77 | 78 | emb_sin = np.sin(out) # (M, D/2) 79 | emb_cos = np.cos(out) # (M, D/2) 80 | 81 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 82 | return emb 83 | 84 | 85 | # -------------------------------------------------------- 86 | # Interpolate position embeddings for high-resolution 87 | # References: 88 | # DeiT: https://github.com/facebookresearch/deit 89 | # -------------------------------------------------------- 90 | def interpolate_pos_embed(model, checkpoint_model): 91 | if 'pos_embed' in checkpoint_model: 92 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 93 | embedding_size = pos_embed_checkpoint.shape[-1] 94 | num_patches = model.patch_embed.num_patches 95 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 96 | # height (== width) for the checkpoint position embedding 97 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 98 | # height (== width) for the new position embedding 99 | new_size = int(num_patches ** 0.5) 100 | # class_token and dist_token are kept 
unchanged 101 | if orig_size != new_size: 102 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 103 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 104 | # only the position tokens are interpolated 105 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 106 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 107 | pos_tokens = torch.nn.functional.interpolate( 108 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 109 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 110 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 111 | checkpoint_model['pos_embed'] = new_pos_embed 112 | -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | class YParams(): 5 | """ Yaml file parser """ 6 | def __init__(self, yaml_filename, config_name, print_params=False): 7 | self._yaml_filename = yaml_filename 8 | self._config_name = config_name 9 | self.params = {} 10 | 11 | if print_params: 12 | print("------------------ Configuration ------------------") 13 | 14 | with open(yaml_filename) as _file: 15 | 16 | for key, val in YAML().load(_file)[config_name].items(): 17 | if print_params: print(key, val) 18 | if val =='None': val = None 19 | 20 | self.params[key] = val 21 | self.__setattr__(key, val) 22 | 23 | if print_params: 24 | print("---------------------------------------------------") 25 | 26 | def __getitem__(self, key): 27 | return self.params[key] 28 | 29 | def __setitem__(self, key, val): 30 | self.params[key] = val 31 | self.__setattr__(key, val) 32 | 33 | def __contains__(self, key): 34 | return (key in self.params) 35 | 36 | def update_params(self, config): 37 | for key, val in config.items(): 38 | self.params[key] = val 39 | 
self.__setattr__(key, val) 40 | 41 | def log(self): 42 | logging.info("------------------ Configuration ------------------") 43 | logging.info("Configuration file: "+str(self._yaml_filename)) 44 | logging.info("Configuration name: "+str(self._config_name)) 45 | for key, val in self.params.items(): 46 | logging.info(str(key) + ' ' + str(val)) 47 | logging.info("---------------------------------------------------") 48 | -------------------------------------------------------------------------------- /utils/darcy_loss.py: -------------------------------------------------------------------------------- 1 | #MIT License 2 | # 3 | #Copyright (c) 2020 Zongyi Li 4 | # 5 | #Permission is hereby granted, free of charge, to any person obtaining a copy 6 | #of this software and associated documentation files (the "Software"), to deal 7 | #in the Software without restriction, including without limitation the rights 8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | #copies of the Software, and to permit persons to whom the Software is 10 | #furnished to do so, subject to the following conditions: 11 | # 12 | #The above copyright notice and this permission notice shall be included in all 13 | #copies or substantial portions of the Software. 14 | # 15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | #SOFTWARE. 
22 | 23 | import torch 24 | import numpy as np 25 | import scipy.io 26 | import h5py 27 | import torch.nn as nn 28 | 29 | 30 | ################################################# 31 | # 32 | # Utilities 33 | # 34 | ################################################# 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | 37 | # reading data 38 | class MatReader(object): 39 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 40 | super(MatReader, self).__init__() 41 | 42 | self.to_torch = to_torch 43 | self.to_cuda = to_cuda 44 | self.to_float = to_float 45 | 46 | self.file_path = file_path 47 | 48 | self.data = None 49 | self.old_mat = None 50 | self._load_file() 51 | 52 | def _load_file(self): 53 | try: 54 | self.data = scipy.io.loadmat(self.file_path) 55 | self.old_mat = True 56 | except: 57 | self.data = h5py.File(self.file_path) 58 | self.old_mat = False 59 | 60 | def load_file(self, file_path): 61 | self.file_path = file_path 62 | self._load_file() 63 | 64 | def read_field(self, field): 65 | x = self.data[field] 66 | 67 | if not self.old_mat: 68 | x = x[()] 69 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 70 | 71 | if self.to_float: 72 | x = x.astype(np.float32) 73 | 74 | if self.to_torch: 75 | x = torch.from_numpy(x) 76 | 77 | if self.to_cuda: 78 | x = x.cuda() 79 | 80 | return x 81 | 82 | def set_cuda(self, to_cuda): 83 | self.to_cuda = to_cuda 84 | 85 | def set_torch(self, to_torch): 86 | self.to_torch = to_torch 87 | 88 | def set_float(self, to_float): 89 | self.to_float = to_float 90 | 91 | # normalization, pointwise gaussian 92 | class UnitGaussianNormalizer(object): 93 | def __init__(self, x, eps=0.00001): 94 | super(UnitGaussianNormalizer, self).__init__() 95 | 96 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 97 | self.mean = torch.mean(x, 0) 98 | self.std = torch.std(x, 0) 99 | self.eps = eps 100 | 101 | def encode(self, x): 102 | x = (x - self.mean) / (self.std + self.eps) 
103 | return x.float() 104 | 105 | def decode(self, x, sample_idx=None): 106 | if sample_idx is None: 107 | std = self.std + self.eps # n 108 | mean = self.mean 109 | else: 110 | if len(self.mean.shape) == len(sample_idx[0].shape): 111 | std = self.std[sample_idx] + self.eps # batch*n 112 | mean = self.mean[sample_idx] 113 | if len(self.mean.shape) > len(sample_idx[0].shape): 114 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 115 | mean = self.mean[:,sample_idx] 116 | 117 | # x is in shape of batch*n or T*batch*n 118 | x = (x * std) + mean 119 | return x.float() 120 | 121 | def cuda(self): 122 | self.mean = self.mean.cuda() 123 | self.std = self.std.cuda() 124 | 125 | def cpu(self): 126 | self.mean = self.mean.cpu() 127 | self.std = self.std.cpu() 128 | 129 | # normalization, Gaussian 130 | class GaussianNormalizer(object): 131 | def __init__(self, x, eps=0.00001): 132 | super(GaussianNormalizer, self).__init__() 133 | 134 | self.mean = torch.mean(x) 135 | self.std = torch.std(x) 136 | self.eps = eps 137 | 138 | def encode(self, x): 139 | x = (x - self.mean) / (self.std + self.eps) 140 | return x 141 | 142 | def decode(self, x, sample_idx=None): 143 | x = (x * (self.std + self.eps)) + self.mean 144 | return x 145 | 146 | def cuda(self): 147 | self.mean = self.mean.cuda() 148 | self.std = self.std.cuda() 149 | 150 | def cpu(self): 151 | self.mean = self.mean.cpu() 152 | self.std = self.std.cpu() 153 | 154 | 155 | # normalization, scaling by range 156 | class RangeNormalizer(object): 157 | def __init__(self, x, low=0.0, high=1.0): 158 | super(RangeNormalizer, self).__init__() 159 | mymin = torch.min(x, 0)[0].view(-1) 160 | mymax = torch.max(x, 0)[0].view(-1) 161 | 162 | self.a = (high - low)/(mymax - mymin) 163 | self.b = -self.a*mymax + high 164 | 165 | def encode(self, x): 166 | s = x.size() 167 | x = x.view(s[0], -1) 168 | x = self.a*x + self.b 169 | x = x.view(s) 170 | return x 171 | 172 | def decode(self, x): 173 | s = x.size() 174 | x = x.view(s[0], -1) 
175 | x = (x - self.b)/self.a 176 | x = x.view(s) 177 | return x 178 | 179 | #loss function with rel/abs Lp loss 180 | class LpLoss(object): 181 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 182 | super(LpLoss, self).__init__() 183 | 184 | #Dimension and Lp-norm type are postive 185 | assert d > 0 and p > 0 186 | 187 | self.d = d 188 | self.p = p 189 | self.reduction = reduction 190 | self.size_average = size_average 191 | 192 | def abs(self, x, y): 193 | num_examples = x.size()[0] 194 | 195 | #Assume uniform mesh 196 | h = 1.0 / (x.size()[1] - 1.0) 197 | 198 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 199 | 200 | if self.reduction: 201 | if self.size_average: 202 | return torch.mean(all_norms) 203 | else: 204 | return torch.sum(all_norms) 205 | 206 | return all_norms 207 | 208 | def rel(self, x, y): 209 | num_examples = x.size()[0] 210 | 211 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 212 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 213 | 214 | if self.reduction: 215 | if self.size_average: 216 | return torch.mean(diff_norms/y_norms) 217 | else: 218 | return torch.sum(diff_norms/y_norms) 219 | 220 | return diff_norms/y_norms 221 | 222 | def __call__(self, x, y): 223 | return self.rel(x, y) 224 | 225 | # Sobolev norm (HS norm) 226 | # where we also compare the numerical derivatives between the output and target 227 | class HsLoss(object): 228 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 229 | super(HsLoss, self).__init__() 230 | 231 | #Dimension and Lp-norm type are postive 232 | assert d > 0 and p > 0 233 | 234 | self.d = d 235 | self.p = p 236 | self.k = k 237 | self.balanced = group 238 | self.reduction = reduction 239 | self.size_average = size_average 240 | 241 | if a == None: 242 | a = [1,] * k 243 | self.a = a 244 | 245 | def rel(self, x, y): 246 | num_examples 
= x.size()[0] 247 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 248 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 249 | if self.reduction: 250 | if self.size_average: 251 | return torch.mean(diff_norms/y_norms) 252 | else: 253 | return torch.sum(diff_norms/y_norms) 254 | return diff_norms/y_norms 255 | 256 | def __call__(self, x, y, a=None): 257 | nx = x.size()[1] 258 | ny = x.size()[2] 259 | k = self.k 260 | balanced = self.balanced 261 | a = self.a 262 | x = x.view(x.shape[0], nx, ny, -1) 263 | y = y.view(y.shape[0], nx, ny, -1) 264 | 265 | k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 266 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 267 | k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 268 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 269 | 270 | x = torch.fft.fftn(x, dim=[1, 2]) 271 | y = torch.fft.fftn(y, dim=[1, 2]) 272 | 273 | if balanced==False: 274 | weight = 1 275 | if k >= 1: 276 | weight += a[0]**2 * (k_x**2 + k_y**2) 277 | if k >= 2: 278 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 279 | weight = torch.sqrt(weight) 280 | loss = self.rel(x*weight, y*weight) 281 | else: 282 | loss = self.rel(x, y) 283 | if k >= 1: 284 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 285 | loss += self.rel(x*weight, y*weight) 286 | if k >= 2: 287 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 288 | loss += self.rel(x*weight, y*weight) 289 | loss = loss / (k+1) 290 | 291 | return loss 292 | 293 | # A simple feedforward neural network 294 | class DenseNet(torch.nn.Module): 295 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 296 | super(DenseNet, self).__init__() 297 | 298 | self.n_layers = len(layers) - 1 299 | 300 | assert self.n_layers >= 1 301 | 302 | 
self.layers = nn.ModuleList() 303 | 304 | for j in range(self.n_layers): 305 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 306 | 307 | if j != self.n_layers - 1: 308 | if normalize: 309 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 310 | 311 | self.layers.append(nonlinearity()) 312 | 313 | if out_nonlinearity is not None: 314 | self.layers.append(out_nonlinearity()) 315 | 316 | def forward(self, x): 317 | for _, l in enumerate(self.layers): 318 | x = l(x) 319 | 320 | return x 321 | -------------------------------------------------------------------------------- /utils/data_loader_multifiles.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import Tensor 9 | import h5py 10 | import math 11 | #import cv2 12 | from utils.img_utils import reshape_fields, reshape_precip 13 | 14 | 15 | def get_data_loader(params, files_pattern, distributed, train): 16 | 17 | dataset = GetDataset(params, files_pattern, train) 18 | sampler = DistributedSampler(dataset, shuffle=train) if distributed else None 19 | 20 | 21 | dataloader = DataLoader(dataset, 22 | batch_size=int(params.batch_size), 23 | num_workers=params.num_workers, 24 | shuffle=False, #(sampler is None), 25 | sampler=sampler if train else None, 26 | drop_last=True, 27 | pin_memory=torch.cuda.is_available()) 28 | 29 | if train: 30 | return dataloader, dataset, sampler 31 | else: 32 | return dataloader, dataset 33 | 34 | class GetDataset(Dataset): 35 | def __init__(self, params, location, train): 36 | self.params = params 37 | self.location = location 38 | self.train = train 39 | self.dt = params.dt 40 | self.n_history = params.n_history 41 | self.in_channels = np.array(params.in_channels) 42 | self.out_channels = np.array(params.out_channels) 43 | 
self.n_in_channels = len(self.in_channels) 44 | self.n_out_channels = len(self.out_channels) 45 | self.crop_size_x = params.crop_size_x 46 | self.crop_size_y = params.crop_size_y 47 | self.roll = params.roll 48 | self._get_files_stats() 49 | self.two_step_training = params.two_step_training 50 | self.orography = params.orography 51 | self.precip = True if "precip" in params else False 52 | # self.add_noise = params.add_noise if train else False 53 | self.add_noise = False 54 | 55 | 56 | if self.precip: 57 | path = params.precip+'/train' if train else params.precip+'/test' 58 | self.precip_paths = glob.glob(path + "/*.h5") 59 | self.precip_paths.sort() 60 | 61 | try: 62 | # print("???????????????self.normalize") 63 | # print(self.normalize) 64 | self.normalize = params.normalize 65 | 66 | except: 67 | self.normalize = True #by default turn on normalization if not specified in config 68 | 69 | if self.orography: 70 | self.orography_path = params.orography_path 71 | 72 | def _get_files_stats(self): 73 | self.files_paths = glob.glob(self.location + "/*.h5") 74 | self.files_paths.sort() 75 | self.n_years = len(self.files_paths) 76 | with h5py.File(self.files_paths[0], 'r') as _f: 77 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 78 | self.n_samples_per_year = _f['fields'].shape[0] 79 | #original image shape (before padding) 80 | self.img_shape_x = _f['fields'].shape[2] -1#just get rid of one of the pixels 81 | self.img_shape_y = _f['fields'].shape[3] 82 | 83 | self.n_samples_total = self.n_years * self.n_samples_per_year 84 | self.files = [None for _ in range(self.n_years)] 85 | self.precip_files = [None for _ in range(self.n_years)] 86 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 87 | logging.info("Found data at path {}. Number of examples: {}. 
Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels)) 88 | logging.info("Delta t: {} hours".format(6*self.dt)) 89 | logging.info("Including {} hours of past history in training at a frequency of {} hours".format(6*self.dt*self.n_history, 6*self.dt)) 90 | 91 | 92 | def _open_file(self, year_idx): 93 | _file = h5py.File(self.files_paths[year_idx], 'r') 94 | self.files[year_idx] = _file['fields'] 95 | if self.orography: 96 | _orog_file = h5py.File(self.orography_path, 'r') 97 | self.orography_field = _orog_file['orog'] 98 | if self.precip: 99 | self.precip_files[year_idx] = h5py.File(self.precip_paths[year_idx], 'r')['tp'] 100 | 101 | 102 | def __len__(self): 103 | return self.n_samples_total 104 | 105 | 106 | def __getitem__(self, global_idx): 107 | year_idx = int(global_idx/self.n_samples_per_year) #which year we are on 108 | local_idx = int(global_idx%self.n_samples_per_year) #which sample in that year we are on - determines indices for centering 109 | 110 | y_roll = np.random.randint(0, 1440) if self.train else 0#roll image in y direction 111 | 112 | #open image file 113 | if self.files[year_idx] is None: 114 | self._open_file(year_idx) 115 | 116 | if not self.precip: 117 | #if we are not at least self.dt*n_history timesteps into the prediction 118 | if local_idx < self.dt*self.n_history: 119 | local_idx += self.dt*self.n_history 120 | 121 | #if we are on the last image in a year predict identity, else predict next timestep 122 | step = 0 if local_idx >= self.n_samples_per_year-self.dt else self.dt 123 | else: 124 | inp_local_idx = local_idx 125 | tar_local_idx = local_idx 126 | #if we are on the last image in a year predict identity, else predict next timestep 127 | step = 0 if tar_local_idx >= self.n_samples_per_year-self.dt else self.dt 128 | # first year has 2 missing samples in precip (they are first two time points) 129 | if year_idx == 0: 130 | lim = 1458 131 | local_idx = 
local_idx%lim 132 | inp_local_idx = local_idx + 2 133 | tar_local_idx = local_idx 134 | step = 0 if tar_local_idx >= lim-self.dt else self.dt 135 | 136 | #if two_step_training flag is true then ensure that local_idx is not the last or last but one sample in a year 137 | if self.two_step_training: 138 | if local_idx >= self.n_samples_per_year - 2*self.dt: 139 | #set local_idx to last possible sample in a year that allows taking two steps forward 140 | local_idx = self.n_samples_per_year - 3*self.dt 141 | 142 | if self.train and self.roll: 143 | y_roll = random.randint(0, self.img_shape_y) 144 | else: 145 | y_roll = 0 146 | 147 | if self.orography: 148 | orog = self.orography_field[0:720] 149 | else: 150 | orog = None 151 | 152 | if self.train and (self.crop_size_x or self.crop_size_y): 153 | rnd_x = random.randint(0, self.img_shape_x-self.crop_size_x) 154 | rnd_y = random.randint(0, self.img_shape_y-self.crop_size_y) 155 | else: 156 | rnd_x = 0 157 | rnd_y = 0 158 | 159 | if self.precip: 160 | return reshape_fields(self.files[year_idx][inp_local_idx, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train), \ 161 | reshape_precip(self.precip_files[year_idx][tar_local_idx+step], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train) 162 | else: 163 | if self.two_step_training: 164 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 165 | reshape_fields(self.files[year_idx][local_idx + step:local_idx + step + 2, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 166 | else: 167 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', 
self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 168 | reshape_fields(self.files[year_idx][local_idx + step, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /utils/data_loader_multifiles_precip.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import Tensor 9 | import h5py 10 | import math 11 | #import cv2 12 | from utils.img_utils import reshape_fields, reshape_precip 13 | 14 | precip = True 15 | def get_data_loader(params, files_pattern, distributed, train): 16 | 17 | dataset = GetDataset(params, files_pattern, train) 18 | sampler = DistributedSampler(dataset, shuffle=train) if distributed else None 19 | 20 | 21 | dataloader = DataLoader(dataset, 22 | batch_size=int(params.batch_size), 23 | num_workers=params.num_workers, 24 | shuffle=False, #(sampler is None), 25 | sampler=sampler if train else None, 26 | drop_last=True, 27 | pin_memory=torch.cuda.is_available()) 28 | 29 | if train: 30 | return dataloader, dataset, sampler 31 | else: 32 | return dataloader, dataset 33 | 34 | class GetDataset(Dataset): 35 | def __init__(self, params, location, train): 36 | self.params = params 37 | self.location = location 38 | self.train = train 39 | self.dt = params.dt 40 | self.n_history = params.n_history 41 | self.in_channels = np.array(params.in_channels) 42 | self.out_channels = np.array(params.out_channels) 43 | self.n_in_channels = len(self.in_channels) 44 | self.n_out_channels = len(self.out_channels) 45 | 
self.crop_size_x = params.crop_size_x 46 | self.crop_size_y = params.crop_size_y 47 | self.roll = params.roll 48 | self._get_files_stats() 49 | self.two_step_training = params.two_step_training 50 | self.orography = params.orography 51 | self.precip = True if "precip" in params else False 52 | # self.add_noise = params.add_noise if train else False 53 | self.add_noise = False 54 | 55 | 56 | if self.precip: 57 | path = params.precip+'/train' if train else params.precip+'/test' 58 | self.precip_paths = glob.glob(path + "/*.h5") 59 | self.precip_paths.sort() 60 | 61 | try: 62 | # print("???????????????self.normalize") 63 | # print(self.normalize) 64 | self.normalize = params.normalize 65 | 66 | except: 67 | self.normalize = True #by default turn on normalization if not specified in config 68 | 69 | if self.orography: 70 | self.orography_path = params.orography_path 71 | 72 | def _get_files_stats(self): 73 | self.files_paths = glob.glob(self.location + "/*.h5") 74 | self.files_paths.sort() 75 | self.n_years = len(self.files_paths) 76 | with h5py.File(self.files_paths[0], 'r') as _f: 77 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 78 | self.n_samples_per_year = _f['fields'].shape[0] 79 | #original image shape (before padding) 80 | self.img_shape_x = _f['fields'].shape[2] -1#just get rid of one of the pixels 81 | self.img_shape_y = _f['fields'].shape[3] 82 | 83 | self.n_samples_total = self.n_years * self.n_samples_per_year 84 | self.files = [None for _ in range(self.n_years)] 85 | self.precip_files = [None for _ in range(self.n_years)] 86 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 87 | logging.info("Found data at path {}. Number of examples: {}. 
Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels)) 88 | logging.info("Delta t: {} hours".format(6*self.dt)) 89 | logging.info("Including {} hours of past history in training at a frequency of {} hours".format(6*self.dt*self.n_history, 6*self.dt)) 90 | 91 | 92 | def _open_file(self, year_idx): 93 | _file = h5py.File(self.files_paths[year_idx], 'r') 94 | self.files[year_idx] = _file['fields'] 95 | if self.orography: 96 | _orog_file = h5py.File(self.orography_path, 'r') 97 | self.orography_field = _orog_file['orog'] 98 | if self.precip: 99 | self.precip_files[year_idx] = h5py.File(self.precip_paths[year_idx], 'r')['tp'] 100 | 101 | 102 | def __len__(self): 103 | return self.n_samples_total 104 | 105 | 106 | def __getitem__(self, global_idx): 107 | year_idx = int(global_idx/self.n_samples_per_year) #which year we are on 108 | local_idx = int(global_idx%self.n_samples_per_year) #which sample in that year we are on - determines indices for centering 109 | 110 | y_roll = np.random.randint(0, 1440) if self.train else 0#roll image in y direction 111 | 112 | #open image file 113 | if self.files[year_idx] is None: 114 | self._open_file(year_idx) 115 | 116 | if not self.precip: 117 | #if we are not at least self.dt*n_history timesteps into the prediction 118 | if local_idx < self.dt*self.n_history: 119 | local_idx += self.dt*self.n_history 120 | 121 | #if we are on the last image in a year predict identity, else predict next timestep 122 | step = 0 if local_idx >= self.n_samples_per_year-self.dt else self.dt 123 | else: 124 | inp_local_idx = local_idx 125 | tar_local_idx = local_idx 126 | #if we are on the last image in a year predict identity, else predict next timestep 127 | step = 0 if tar_local_idx >= self.n_samples_per_year-self.dt else self.dt 128 | # first year has 2 missing samples in precip (they are first two time points) 129 | if year_idx == 0: 130 | lim = 1458 131 | local_idx = 
local_idx%lim 132 | inp_local_idx = local_idx + 2 133 | tar_local_idx = local_idx 134 | step = 0 if tar_local_idx >= lim-self.dt else self.dt 135 | 136 | #if two_step_training flag is true then ensure that local_idx is not the last or last but one sample in a year 137 | if self.two_step_training: 138 | if local_idx >= self.n_samples_per_year - 2*self.dt: 139 | #set local_idx to last possible sample in a year that allows taking two steps forward 140 | local_idx = self.n_samples_per_year - 3*self.dt 141 | 142 | if self.train and self.roll: 143 | y_roll = random.randint(0, self.img_shape_y) 144 | else: 145 | y_roll = 0 146 | 147 | if self.orography: 148 | orog = self.orography_field[0:720] 149 | else: 150 | orog = None 151 | 152 | if self.train and (self.crop_size_x or self.crop_size_y): 153 | rnd_x = random.randint(0, self.img_shape_x-self.crop_size_x) 154 | rnd_y = random.randint(0, self.img_shape_y-self.crop_size_y) 155 | else: 156 | rnd_x = 0 157 | rnd_y = 0 158 | 159 | if self.precip: 160 | return reshape_fields(self.files[year_idx][inp_local_idx, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train), \ 161 | reshape_precip(self.precip_files[year_idx][tar_local_idx+step], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train) 162 | else: 163 | if self.two_step_training: 164 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 165 | reshape_fields(self.files[year_idx][local_idx + step:local_idx + step + 2, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 166 | else: 167 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', 
self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 168 | reshape_fields(self.files[year_idx][local_idx + step, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /utils/data_loader_multifiles_twoStep.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.utils.data import DataLoader, Dataset 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import Tensor 9 | import h5py 10 | import math 11 | #import cv2 12 | from utils.img_utils import reshape_fields, reshape_precip 13 | 14 | two_step_training = True 15 | def get_data_loader(params, files_pattern, distributed, train): 16 | 17 | dataset = GetDataset(params, files_pattern, train) 18 | sampler = DistributedSampler(dataset, shuffle=train) if distributed else None 19 | 20 | 21 | dataloader = DataLoader(dataset, 22 | batch_size=int(params.batch_size), 23 | num_workers=params.num_workers, 24 | shuffle=False, #(sampler is None), 25 | sampler=sampler if train else None, 26 | drop_last=True, 27 | pin_memory=torch.cuda.is_available()) 28 | 29 | if train: 30 | return dataloader, dataset, sampler 31 | else: 32 | return dataloader, dataset 33 | 34 | class GetDataset(Dataset): 35 | def __init__(self, params, location, train): 36 | self.params = params 37 | self.location = location 38 | self.train = train 39 | self.dt = params.dt 40 | self.n_history = params.n_history 41 | self.in_channels = np.array(params.in_channels) 42 | self.out_channels = np.array(params.out_channels) 43 | self.n_in_channels = len(self.in_channels) 44 | self.n_out_channels = 
len(self.out_channels) 45 | self.crop_size_x = params.crop_size_x 46 | self.crop_size_y = params.crop_size_y 47 | self.roll = params.roll 48 | self._get_files_stats() 49 | # self.two_step_training = params.two_step_training 50 | self.two_step_training = two_step_training 51 | 52 | self.orography = params.orography 53 | self.precip = True if "precip" in params else False 54 | # self.add_noise = params.add_noise if train else False 55 | self.add_noise = False 56 | 57 | 58 | if self.precip: 59 | path = params.precip+'/train' if train else params.precip+'/test' 60 | self.precip_paths = glob.glob(path + "/*.h5") 61 | self.precip_paths.sort() 62 | 63 | try: 64 | # print("???????????????self.normalize") 65 | # print(self.normalize) 66 | self.normalize = params.normalize 67 | 68 | except: 69 | self.normalize = True #by default turn on normalization if not specified in config 70 | 71 | if self.orography: 72 | self.orography_path = params.orography_path 73 | 74 | def _get_files_stats(self): 75 | self.files_paths = glob.glob(self.location + "/*.h5") 76 | self.files_paths.sort() 77 | self.n_years = len(self.files_paths) 78 | with h5py.File(self.files_paths[0], 'r') as _f: 79 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 80 | self.n_samples_per_year = _f['fields'].shape[0] 81 | #original image shape (before padding) 82 | self.img_shape_x = _f['fields'].shape[2] -1#just get rid of one of the pixels 83 | self.img_shape_y = _f['fields'].shape[3] 84 | 85 | self.n_samples_total = self.n_years * self.n_samples_per_year 86 | self.files = [None for _ in range(self.n_years)] 87 | self.precip_files = [None for _ in range(self.n_years)] 88 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 89 | logging.info("Found data at path {}. Number of examples: {}. 
Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels)) 90 | logging.info("Delta t: {} hours".format(6*self.dt)) 91 | logging.info("Including {} hours of past history in training at a frequency of {} hours".format(6*self.dt*self.n_history, 6*self.dt)) 92 | 93 | 94 | def _open_file(self, year_idx): 95 | _file = h5py.File(self.files_paths[year_idx], 'r') 96 | self.files[year_idx] = _file['fields'] 97 | if self.orography: 98 | _orog_file = h5py.File(self.orography_path, 'r') 99 | self.orography_field = _orog_file['orog'] 100 | if self.precip: 101 | self.precip_files[year_idx] = h5py.File(self.precip_paths[year_idx], 'r')['tp'] 102 | 103 | 104 | def __len__(self): 105 | return self.n_samples_total 106 | 107 | 108 | def __getitem__(self, global_idx): 109 | year_idx = int(global_idx/self.n_samples_per_year) #which year we are on 110 | local_idx = int(global_idx%self.n_samples_per_year) #which sample in that year we are on - determines indices for centering 111 | 112 | y_roll = np.random.randint(0, 1440) if self.train else 0#roll image in y direction 113 | 114 | #open image file 115 | if self.files[year_idx] is None: 116 | self._open_file(year_idx) 117 | 118 | if not self.precip: 119 | #if we are not at least self.dt*n_history timesteps into the prediction 120 | if local_idx < self.dt*self.n_history: 121 | local_idx += self.dt*self.n_history 122 | 123 | #if we are on the last image in a year predict identity, else predict next timestep 124 | step = 0 if local_idx >= self.n_samples_per_year-self.dt else self.dt 125 | else: 126 | inp_local_idx = local_idx 127 | tar_local_idx = local_idx 128 | #if we are on the last image in a year predict identity, else predict next timestep 129 | step = 0 if tar_local_idx >= self.n_samples_per_year-self.dt else self.dt 130 | # first year has 2 missing samples in precip (they are first two time points) 131 | if year_idx == 0: 132 | lim = 1458 133 | local_idx = 
local_idx%lim 134 | inp_local_idx = local_idx + 2 135 | tar_local_idx = local_idx 136 | step = 0 if tar_local_idx >= lim-self.dt else self.dt 137 | 138 | #if two_step_training flag is true then ensure that local_idx is not the last or last but one sample in a year 139 | if self.two_step_training: 140 | if local_idx >= self.n_samples_per_year - 2*self.dt: 141 | #set local_idx to last possible sample in a year that allows taking two steps forward 142 | local_idx = self.n_samples_per_year - 3*self.dt 143 | 144 | if self.train and self.roll: 145 | y_roll = random.randint(0, self.img_shape_y) 146 | else: 147 | y_roll = 0 148 | 149 | if self.orography: 150 | orog = self.orography_field[0:720] 151 | else: 152 | orog = None 153 | 154 | if self.train and (self.crop_size_x or self.crop_size_y): 155 | rnd_x = random.randint(0, self.img_shape_x-self.crop_size_x) 156 | rnd_y = random.randint(0, self.img_shape_y-self.crop_size_y) 157 | else: 158 | rnd_x = 0 159 | rnd_y = 0 160 | 161 | if self.precip: 162 | return reshape_fields(self.files[year_idx][inp_local_idx, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train), \ 163 | reshape_precip(self.precip_files[year_idx][tar_local_idx+step], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train) 164 | else: 165 | if self.two_step_training: 166 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 167 | reshape_fields(self.files[year_idx][local_idx + step:local_idx + step + 2, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 168 | else: 169 | return reshape_fields(self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.in_channels], 'inp', 
self.crop_size_x, self.crop_size_y, rnd_x, rnd_y,self.params, y_roll, self.train, self.normalize, orog, self.add_noise), \ 170 | reshape_fields(self.files[year_idx][local_idx + step, self.out_channels], 'tar', self.crop_size_x, self.crop_size_y, rnd_x, rnd_y, self.params, y_roll, self.train, self.normalize, orog) 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /utils/date_time_to_hours.py: -------------------------------------------------------------------------------- 1 | #BSD 3-Clause License 2 | # 3 | #Copyright (c) 2022, FourCastNet authors 4 | #All rights reserved. 5 | # 6 | #Redistribution and use in source and binary forms, with or without 7 | #modification, are permitted provided that the following conditions are met: 8 | # 9 | #1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | #2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | #3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | #THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | #AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | #IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | #DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | #FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | #DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | #SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | #CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | #OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | #OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | #The code was authored by the following people: 32 | # 33 | #Jaideep Pathak - NVIDIA Corporation 34 | #Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory 35 | #Peter Harrington - NERSC, Lawrence Berkeley National Laboratory 36 | #Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory 37 | #Ashesh Chattopadhyay - Rice University 38 | #Morteza Mardani - NVIDIA Corporation 39 | #Thorsten Kurth - NVIDIA Corporation 40 | #David Hall - NVIDIA Corporation 41 | #Zongyi Li - California Institute of Technology, NVIDIA Corporation 42 | #Kamyar Azizzadenesheli - Purdue University 43 | #Pedram Hassanzadeh - Rice University 44 | #Karthik Kashinath - NVIDIA Corporation 45 | #Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation 46 | 47 | import numpy as np 48 | from datetime import datetime 49 | 50 | 51 | #day_of_year = datetime.now().timetuple().tm_yday # returns 1 for January 1st 52 | #time_tuple = datetime.now().timetuple() 53 | date_strings = ["2016-01-01 00:00:00", "2016-09-13 00:00:00", "2016-09-17 00:00:00", "2016-09-21 00:00:00", "2016-09-25 00:00:00", "2016-09-29 00:00:00", "2016-10-03 00:00:00", "2016-10-07 00:00:00"] 54 | 55 | ics = [] 56 | 57 | for date_ in date_strings: 58 | date_obj = datetime.strptime(date_, '%Y-%m-%d %H:%M:%S') #datetime.fromisoformat(date_) 59 | print(date_obj.timetuple()) 60 | day_of_year = date_obj.timetuple().tm_yday - 1 61 | hour_of_day = 
date_obj.timetuple().tm_hour 62 | hours_since_jan_01_epoch = 24*day_of_year + hour_of_day 63 | ics.append(int(hours_since_jan_01_epoch/6)) 64 | print(day_of_year, hour_of_day) 65 | print("hours = ", hours_since_jan_01_epoch ) 66 | print("steps = ", hours_since_jan_01_epoch/6) 67 | 68 | 69 | print(ics) 70 | 71 | ics = [] 72 | for date_ in date_strings: 73 | date_obj = datetime.fromisoformat(date_) #datetime.strptime(date_, '%Y-%m-%d %H:%M:%S') #datetime.fromisoformat(date_) 74 | print(date_obj.timetuple()) 75 | day_of_year = date_obj.timetuple().tm_yday - 1 76 | hour_of_day = date_obj.timetuple().tm_hour 77 | hours_since_jan_01_epoch = 24*day_of_year + hour_of_day 78 | ics.append(int(hours_since_jan_01_epoch/6)) 79 | print(day_of_year, hour_of_day) 80 | print("hours = ", hours_since_jan_01_epoch ) 81 | print("steps = ", hours_since_jan_01_epoch/6) 82 | 83 | 84 | print(ics) 85 | 86 | -------------------------------------------------------------------------------- /utils/img_utils.py: -------------------------------------------------------------------------------- 1 | #BSD 3-Clause License 2 | # 3 | #Copyright (c) 2022, FourCastNet authors 4 | #All rights reserved. 5 | # 6 | #Redistribution and use in source and binary forms, with or without 7 | #modification, are permitted provided that the following conditions are met: 8 | # 9 | #1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | #2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | #3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 
19 | # 20 | #THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | #AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | #IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | #DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | #FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | #DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | #SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | #CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | #OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | #OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | #The code was authored by the following people: 32 | # 33 | #Jaideep Pathak - NVIDIA Corporation 34 | #Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory 35 | #Peter Harrington - NERSC, Lawrence Berkeley National Laboratory 36 | #Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory 37 | #Ashesh Chattopadhyay - Rice University 38 | #Morteza Mardani - NVIDIA Corporation 39 | #Thorsten Kurth - NVIDIA Corporation 40 | #David Hall - NVIDIA Corporation 41 | #Zongyi Li - California Institute of Technology, NVIDIA Corporation 42 | #Kamyar Azizzadenesheli - Purdue University 43 | #Pedram Hassanzadeh - Rice University 44 | #Karthik Kashinath - NVIDIA Corporation 45 | #Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation 46 | 47 | import logging 48 | import glob 49 | from types import new_class 50 | import torch 51 | import torch.nn as nn 52 | import torch.nn.functional as F 53 | import random 54 | import numpy as np 55 | import torch 56 | from torch.utils.data import DataLoader, Dataset 57 | from torch.utils.data.distributed import DistributedSampler 58 | from torch import Tensor 59 | import h5py 60 | 
import math 61 | import torchvision.transforms.functional as TF 62 | import matplotlib 63 | import matplotlib.pyplot as plt 64 | 65 | class PeriodicPad2d(nn.Module): 66 | """ 67 | pad longitudinal (left-right) circular 68 | and pad latitude (top-bottom) with zeros 69 | """ 70 | def __init__(self, pad_width): 71 | super(PeriodicPad2d, self).__init__() 72 | self.pad_width = pad_width 73 | 74 | def forward(self, x): 75 | # pad left and right circular 76 | out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular") 77 | # pad top and bottom zeros 78 | out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0) 79 | return out 80 | 81 | def reshape_fields(img, inp_or_tar, crop_size_x, crop_size_y,rnd_x, rnd_y, params, y_roll, train, normalize=True, orog=None, add_noise=False): 82 | #Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of size ((n_channels*(n_history+1), crop_size_x, crop_size_y) 83 | 84 | if len(np.shape(img)) ==3: 85 | img = np.expand_dims(img, 0) 86 | 87 | 88 | img = img[:, :, 0:720] #remove last pi 89 | # xel 90 | n_history = np.shape(img)[0] - 1 91 | img_shape_x = np.shape(img)[-2] 92 | img_shape_y = np.shape(img)[-1] 93 | n_channels = np.shape(img)[1] #this will either be N_in_channels or N_out_channels 94 | channels = params.in_channels if inp_or_tar =='inp' else params.out_channels 95 | means = np.load(params.global_means_path)[:, channels] 96 | stds = np.load(params.global_stds_path)[:, channels] 97 | if crop_size_x == None: 98 | crop_size_x = img_shape_x 99 | if crop_size_y == None: 100 | crop_size_y = img_shape_y 101 | 102 | if normalize: 103 | if params.normalization == 'minmax': 104 | raise Exception("minmax not supported. 
Use zscore") 105 | elif params.normalization == 'zscore': 106 | img -=means 107 | img /=stds 108 | 109 | if params.add_grid: 110 | if inp_or_tar == 'inp': 111 | if params.gridtype == 'linear': 112 | assert params.N_grid_channels == 2, "N_grid_channels must be set to 2 for gridtype linear" 113 | x = np.meshgrid(np.linspace(-1, 1, img_shape_x)) 114 | y = np.meshgrid(np.linspace(-1, 1, img_shape_y)) 115 | grid_x, grid_y = np.meshgrid(y, x) 116 | grid = np.stack((grid_x, grid_y), axis = 0) 117 | elif params.gridtype == 'sinusoidal': 118 | assert params.N_grid_channels == 4, "N_grid_channels must be set to 4 for gridtype sinusoidal" 119 | x1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_x))) 120 | x2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_x))) 121 | y1 = np.meshgrid(np.sin(np.linspace(0, 2*np.pi, img_shape_y))) 122 | y2 = np.meshgrid(np.cos(np.linspace(0, 2*np.pi, img_shape_y))) 123 | grid_x1, grid_y1 = np.meshgrid(y1, x1) 124 | grid_x2, grid_y2 = np.meshgrid(y2, x2) 125 | grid = np.expand_dims(np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis = 0), axis = 0) 126 | img = np.concatenate((img, grid), axis = 1 ) 127 | 128 | if params.orography and inp_or_tar == 'inp': 129 | img = np.concatenate((img, np.expand_dims(orog, axis = (0,1) )), axis = 1) 130 | n_channels += 1 131 | 132 | if params.roll: 133 | img = np.roll(img, y_roll, axis = -1) 134 | 135 | if train and (crop_size_x or crop_size_y): 136 | img = img[:,:,rnd_x:rnd_x+crop_size_x, rnd_y:rnd_y+crop_size_y] 137 | 138 | if inp_or_tar == 'inp': 139 | img = np.reshape(img, (n_channels*(n_history+1), crop_size_x, crop_size_y)) 140 | elif inp_or_tar == 'tar': 141 | if params.two_step_training: 142 | img = np.reshape(img, (n_channels*2, crop_size_x, crop_size_y)) 143 | else: 144 | img = np.reshape(img, (n_channels, crop_size_x, crop_size_y)) 145 | 146 | if add_noise: 147 | img = img + np.random.normal(0, scale=params.noise_std, size=img.shape) 148 | 149 | return torch.as_tensor(img) 150 | 151 | 
def reshape_precip(img, inp_or_tar, crop_size_x, crop_size_y, rnd_x, rnd_y, params, y_roll, train, normalize=True):
    """Prepare one precipitation field as a torch tensor.

    Takes a numpy array of shape (h, w) or (1, h, w). Returns a tensor of shape
    (n_channels, crop_size_x, crop_size_y). ``n_channels`` is 1, plus the grid
    channels for inputs when ``params.add_grid`` is set.

    Args:
        img: numpy field; only the first 720 rows of axis -2 are kept.
        inp_or_tar: ``'inp'`` or ``'tar'``; grid channels are added to inputs only.
        crop_size_x, crop_size_y: crop extent; ``None`` means the full extent.
        rnd_x, rnd_y: crop offsets, applied only when ``train`` is true.
        params: config object read for ``precip_eps``, the grid settings and ``roll``.
        y_roll: shift along the last axis, applied when ``params.roll`` is set.
        train: enables cropping.
        normalize: apply ``log1p(img / params.precip_eps)``.

    Raises:
        AssertionError: ``params.N_grid_channels`` does not match ``params.gridtype``.
        ValueError: ``params.gridtype`` is neither ``'linear'`` nor ``'sinusoidal'``.
    """
    if len(np.shape(img)) == 2:
        img = np.expand_dims(img, 0)

    img = img[:, :720, :]  # keep the first 720 rows of axis -2
    img_shape_x = img.shape[-2]
    img_shape_y = img.shape[-1]
    n_channels = 1
    if crop_size_x is None:
        crop_size_x = img_shape_x
    if crop_size_y is None:
        crop_size_y = img_shape_y

    if normalize:
        eps = params.precip_eps
        img = np.log1p(img / eps)

    if params.add_grid and inp_or_tar == 'inp':
        # Static grid of shape (n_grid, h, w). img is 3-D here, so the grid goes
        # on axis 0, and n_channels must count it for the final reshape to fit.
        if params.gridtype == 'linear':
            assert params.N_grid_channels == 2, "N_grid_channels must be set to 2 for gridtype linear"
            x = np.linspace(-1, 1, img_shape_x)
            y = np.linspace(-1, 1, img_shape_y)
            grid_x, grid_y = np.meshgrid(y, x)
            grid = np.stack((grid_x, grid_y), axis=0)
        elif params.gridtype == 'sinusoidal':
            assert params.N_grid_channels == 4, "N_grid_channels must be set to 4 for gridtype sinusoidal"
            angle_x = np.linspace(0, 2 * np.pi, img_shape_x)
            angle_y = np.linspace(0, 2 * np.pi, img_shape_y)
            grid_x1, grid_y1 = np.meshgrid(np.sin(angle_y), np.sin(angle_x))
            grid_x2, grid_y2 = np.meshgrid(np.cos(angle_y), np.cos(angle_x))
            grid = np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0)
        else:
            raise ValueError(f"unknown gridtype {params.gridtype!r}; options are 'linear' or 'sinusoidal'")
        img = np.concatenate((img, grid), axis=0)
        n_channels += grid.shape[0]

    if params.roll:
        img = np.roll(img, y_roll, axis=-1)

    if train and (crop_size_x or crop_size_y):
        img = img[:, rnd_x:rnd_x + crop_size_x, rnd_y:rnd_y + crop_size_y]

    img = np.reshape(img, (n_channels, crop_size_x, crop_size_y))
    return torch.as_tensor(img)


def vis_precip(fields):
    """Plot a (prediction, target) pair side by side and return the matplotlib figure."""
    pred, tar = fields
    fig, ax = plt.subplots(1, 2, figsize=(24, 12))
    ax[0].imshow(pred, cmap="coolwarm")
    ax[0].set_title("tp pred")
    ax[1].imshow(tar, cmap="coolwarm")
    ax[1].set_title("tp tar")
    fig.tight_layout()
    return fig


# --------------------------------------------------------------------------------
# /utils/logging_utils.py:
# --------------------------------------------------------------------------------
import os
import logging

_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'


def config_logger(log_level=logging.INFO):
    """Configure the root logger with the shared message format and ``log_level``."""
    logging.basicConfig(format=_format, level=log_level)


def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
    """Attach a FileHandler writing to ``log_filename`` to a logger.

    Args:
        logger_name: logger to attach to; ``None`` selects the root logger.
        log_level: minimum level written by the new handler.
        log_filename: target file. Its parent directory is created if missing.
            A bare filename, such as the default, is written to the current
            working directory.
    """
    log_dir = os.path.dirname(log_filename)
    # dirname() is '' for a bare filename, and os.makedirs('') raises,
    # so only create a real directory component.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    log = logging.getLogger(logger_name)  # getLogger(None) is the root logger

    fh = logging.FileHandler(log_filename)
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(_format))
    log.addHandler(fh)


def log_versions():
    """Log the installed torch version between two banner lines."""
    import torch

    logging.info('--------------- Versions ---------------')
    logging.info('Torch: ' + str(torch.__version__))
    logging.info('----------------------------------------')


# --------------------------------------------------------------------------------
# /utils/weighted_acc_rmse.py:
# --------------------------------------------------------------------------------
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
5 | # 6 | #Redistribution and use in source and binary forms, with or without 7 | #modification, are permitted provided that the following conditions are met: 8 | # 9 | #1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | #2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | #3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | #THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | #AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | #IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | #DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | #FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | #DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | #SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | #CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | #OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | #OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory
#Ashesh Chattopadhyay - Rice University
#Morteza Mardani - NVIDIA Corporation
#Thorsten Kurth - NVIDIA Corporation
#David Hall - NVIDIA Corporation
#Zongyi Li - California Institute of Technology, NVIDIA Corporation
#Kamyar Azizzadenesheli - Purdue University
#Pedram Hassanzadeh - Rice University
#Karthik Kashinath - NVIDIA Corporation
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation

import os
import time
import numpy as np
import argparse
import h5py
#from netCDF4 import Dataset as DS
from collections import OrderedDict
from utils import logging_utils
logging_utils.config_logger()
#from utils.YParams import YParams
from utils.data_loader_multifiles import get_data_loader
import wandb
import torch
import warnings


def unlog_tp(x, eps=1E-5):
    """Invert the ``log1p(x / eps)`` transform: return ``eps * (exp(x) - 1)``.

    ``np.expm1`` is used instead of ``np.exp(x) - 1`` because it keeps full
    precision for ``x`` near zero, where the subtraction would cancel.
    """
    return eps * np.expm1(x)


def unlog_tp_torch(x, eps=1E-5):
    """Torch version of ``unlog_tp``: return ``eps * (exp(x) - 1)`` via ``torch.expm1``."""
    return eps * torch.expm1(x)


def mean(x, axis=None):
    """Arithmetic mean of ``x`` along ``axis``, or over all elements when ``axis`` is None."""
    y = np.sum(x, axis) / np.size(x, axis)
    return y


def lat_np(j, num_lat):
    """Latitude in degrees of row ``j`` on an even grid: 90 at j=0 down to -90 at j=num_lat-1."""
    return 90 - j * 180/(num_lat-1)


def weighted_acc(pred, target, weighted=True):
    """Latitude-weighted anomaly correlation coefficient (numpy).

    Takes arrays of shape [1, num_lat, num_long]. 2-D [num_lat, num_long]
    inputs are promoted by adding a leading axis. No mean is removed here, so
    the fields are correlated as given. NOTE(review): callers presumably pass
    anomalies; confirm. With ``weighted=False`` every row gets weight 1.
    """
    if len(pred.shape) == 2:
        pred = np.expand_dims(pred, 0)
    if len(target.shape) == 2:
        target = np.expand_dims(target, 0)

    num_lat = np.shape(pred)[1]
    s = np.sum(np.cos(np.pi/180 * lat_np(np.arange(0, num_lat), num_lat)))
    weight = np.expand_dims(latitude_weighting_factor(np.arange(0, num_lat), num_lat, s), -1) if weighted else 1
    r = (weight*pred*target).sum() / np.sqrt((weight*pred*pred).sum() * (weight*target*target).sum())
    return r


def weighted_acc_masked(pred, target, weighted=True, maskarray=1):
    """Latitude-weighted ACC with every term multiplied by ``maskarray`` (numpy).

    Unlike ``weighted_acc``, the overall mean of each field is subtracted first.
    A ``maskarray`` value of 0 removes a point from all three sums. The default
    of 1 keeps every point. 2-D inputs are promoted to [1, num_lat, num_long].
    The caller's arrays are not modified.
    """
    if len(pred.shape) == 2:
        pred = np.expand_dims(pred, 0)
    if len(target.shape) == 2:
        target = np.expand_dims(target, 0)

    num_lat = np.shape(pred)[1]
    # Out of place: the previous in-place `-=` wrote through to the caller's arrays.
    pred = pred - mean(pred)
    target = target - mean(target)
    # lat_np, not the TorchScript `lat` below, which only accepts tensors.
    s = np.sum(np.cos(np.pi/180 * lat_np(np.arange(0, num_lat), num_lat)))
    weight = np.expand_dims(latitude_weighting_factor(np.arange(0, num_lat), num_lat, s), -1) if weighted else 1
    r = (maskarray*weight*pred*target).sum() / np.sqrt((maskarray*weight*pred*pred).sum() * (maskarray*weight*target*target).sum())
    return r


def weighted_rmse(pred, target):
    """Latitude-weighted RMSE of two [1, h, w] arrays; 2-D inputs are promoted."""
    if len(pred.shape) == 2:
        pred = np.expand_dims(pred, 0)
    if len(target.shape) == 2:
        target = np.expand_dims(target, 0)
    num_lat = np.shape(pred)[1]
    num_long = np.shape(target)[2]
    s = np.sum(np.cos(np.pi/180 * lat_np(np.arange(0, num_lat), num_lat)))
    weight = np.expand_dims(latitude_weighting_factor(np.arange(0, num_lat), num_lat, s), -1)
    return np.sqrt(1/num_lat * 1/num_long * np.sum(np.dot(weight.T, (pred[0] - target[0])**2)))


def latitude_weighting_factor(j, num_lat, s):
    """cos(latitude) weight of row ``j``, scaled by ``num_lat / s``.

    When ``s`` is the sum of cos(latitude) over all rows, as every caller here
    computes it, the weights over all rows sum to ``num_lat``.
    """
    return num_lat*np.cos(np.pi/180. * lat_np(j, num_lat))/s


def top_quantiles_error(pred, target):
    """Difference between upper quantiles of ``pred`` and ``target`` (numpy).

    Uses 100 quantile levels ``1 - 10**k`` with k log-spaced from -5 to -0.1,
    computed over axes (1, 2). The differences are averaged over the levels,
    giving one value per entry of the leading axis. 2-D inputs are promoted.
    """
    if len(pred.shape) == 2:
        pred = np.expand_dims(pred, 0)
    if len(target.shape) == 2:
        target = np.expand_dims(target, 0)
    qs = 100
    qlim = 5
    qcut = 0.1
    qtile = 1. - np.logspace(-qlim, -qcut, num=qs)
    P_tar = np.quantile(target, q=qtile, axis=(1, 2))
    P_pred = np.quantile(pred, q=qtile, axis=(1, 2))
    return np.mean(P_pred - P_tar, axis=0)


# torch version for rmse comp
@torch.jit.script
def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor:
    # Latitude in degrees of row j: 90 at j=0 down to -90 at j=num_lat-1. Tensors only.
    return 90. - j * 180./float(num_lat-1)


@torch.jit.script
def latitude_weighting_factor_torch(j: torch.Tensor, num_lat: int, s: torch.Tensor) -> torch.Tensor:
    # Torch version of latitude_weighting_factor.
    # NOTE(review): the torch metrics approximate pi as 3.1416, while the numpy
    # ones use np.pi; left unchanged so reported values do not shift.
    return num_lat * torch.cos(3.1416/180. * lat(j, num_lat))/s


@torch.jit.script
def weighted_rmse_torch_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Takes tensors of shape [n, c, h, w]; returns the latitude-weighted RMSE
    # for each sample and channel, shape [n, c].
    num_lat = pred.shape[2]
    lat_t = torch.arange(start=0, end=num_lat, device=pred.device)

    s = torch.sum(torch.cos(3.1416/180. * lat(lat_t, num_lat)))
    weight = torch.reshape(latitude_weighting_factor_torch(lat_t, num_lat, s), (1, 1, -1, 1))
    result = torch.sqrt(torch.mean(weight * (pred - target)**2., dim=(-1, -2)))
    return result


@torch.jit.script
def weighted_rmse_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-channel weighted RMSE averaged over the batch dimension.
    result = weighted_rmse_torch_channels(pred, target)
    return torch.mean(result, dim=0)


@torch.jit.script
def weighted_acc_masked_torch_channels(pred: torch.Tensor, target: torch.Tensor, maskarray: torch.Tensor) -> torch.Tensor:
    # Takes tensors of shape [n, c, h, w]; returns the latitude-weighted ACC per
    # sample and channel with every term multiplied by maskarray.
    num_lat = pred.shape[2]
    lat_t = torch.arange(start=0, end=num_lat, device=pred.device)
    s = torch.sum(torch.cos(3.1416/180. * lat(lat_t, num_lat)))
    weight = torch.reshape(latitude_weighting_factor_torch(lat_t, num_lat, s), (1, 1, -1, 1))
    result = torch.sum(maskarray * weight * pred * target, dim=(-1, -2)) / torch.sqrt(torch.sum(maskarray * weight * pred * pred, dim=(-1, -2)) * torch.sum(maskarray * weight * target * target, dim=(-1, -2)))
    return result


@torch.jit.script
def weighted_acc_torch_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Takes tensors of shape [n, c, h, w]; returns the latitude-weighted ACC per
    # sample and channel. No mean is removed here.
    num_lat = pred.shape[2]
    lat_t = torch.arange(start=0, end=num_lat, device=pred.device)
    s = torch.sum(torch.cos(3.1416/180. * lat(lat_t, num_lat)))
    weight = torch.reshape(latitude_weighting_factor_torch(lat_t, num_lat, s), (1, 1, -1, 1))
    result = torch.sum(weight * pred * target, dim=(-1, -2)) / torch.sqrt(
        torch.sum(weight * pred * pred, dim=(-1, -2)) * torch.sum(weight * target * target, dim=(-1, -2)))
    return result


@torch.jit.script
def weighted_acc_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-channel weighted ACC averaged over the batch dimension.
    result = weighted_acc_torch_channels(pred, target)
    return torch.mean(result, dim=0)


@torch.jit.script
def unweighted_acc_torch_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # ACC per sample and channel over the last two dimensions, without latitude weights.
    result = torch.sum(pred * target, dim=(-1, -2)) / torch.sqrt(
        torch.sum(pred * pred, dim=(-1, -2)) * torch.sum(target * target, dim=(-1, -2)))
    return result


@torch.jit.script
def unweighted_acc_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-channel unweighted ACC averaged over the batch dimension.
    result = unweighted_acc_torch_channels(pred, target)
    return torch.mean(result, dim=0)


@torch.jit.script
def top_quantiles_error_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Torch version of top_quantiles_error for [n, c, h, w] tensors, with k
    # log-spaced from -3 (not -5) to -0.1. Quantiles are taken over the
    # flattened h*w values; the differences are averaged over the quantile levels.
    qs = 100
    qlim = 3
    qcut = 0.1
    n, c, h, w = pred.size()
    qtile = 1. - torch.logspace(-qlim, -qcut, steps=qs, device=pred.device)
    P_tar = torch.quantile(target.view(n, c, h*w), q=qtile, dim=-1)
    P_pred = torch.quantile(pred.view(n, c, h*w), q=qtile, dim=-1)
    return torch.mean(P_pred - P_tar, dim=0)