├── README.md
├── configs
│   ├── Adas
│   │   └── training_options.yaml
│   ├── ConvCNP
│   │   └── training_options.yaml
│   └── FNP
│       └── training_options.yaml
├── data
│   └── .gitkeep
├── datasets
│   ├── __init__.py
│   ├── era5_npy_f32.py
│   ├── mean_std.json
│   └── mean_std_single.json
├── inference.py
├── models
│   ├── Adas.py
│   ├── ConvCNP.py
│   └── FNP.py
├── modules
│   ├── __init__.py
│   ├── cnn.py
│   ├── encoders.py
│   ├── helpers.py
│   ├── initialization.py
│   ├── losses.py
│   ├── mlp.py
│   └── transformer.py
├── train.py
└── utils
    ├── __init__.py
    ├── builder.py
    ├── logger.py
    ├── metrics.py
    └── misc.py

/README.md:
--------------------------------------------------------------------------------
1 | # FNP: Fourier Neural Processes for Arbitrary-Resolution Data Assimilation
2 | 
3 | This repo contains the official PyTorch codebase of FNP. Our paper has been accepted by NeurIPS 2024.
4 | 
5 | ## Codebase Structure
6 | 
7 | - `configs` contains all the experiment configurations.
8 |     - `Adas` contains the configuration to train the Adas model.
9 |     - `ConvCNP` contains the configuration to train the ConvCNP model.
10 |     - `FNP` contains the configuration to train the FNP model.
11 | - `data` contains the ERA5 data.
12 | - `datasets` contains the dataset implementation and the mean and standard deviation values of the ERA5 data.
13 | - `models` contains the data assimilation models and the forecast model FengWu (ONNX version).
14 | - `modules` contains the basic modules used by all the data assimilation models.
15 | - `utils` contains supporting utilities (config builder, logger, metrics, and misc helpers).
16 | - `train.py` and `inference.py` provide the training and testing pipelines.
17 | 
18 | We provide the ONNX model of FengWu at 128×256 resolution for making forecasts. The ERA5 data can be downloaded from the official website of the Climate Data Store.
19 | 
20 | ## Setup
21 | 
22 | First, download and set up the repo:
23 | 
24 | ```
25 | git clone https://github.com/OpenEarthLab/FNP.git
26 | cd FNP
27 | ```
28 | 
29 | Then, download the ERA5 data and the forecast model `FengWu.onnx` and put them into the corresponding positions according to the codebase structure.
30 | 
31 | Set up the environment given below:
32 | 
33 | ```
34 | python version 3.8.18
35 | torch==1.13.1+cu117
36 | ```
37 | 
38 | ## Training
39 | 
40 | We support multi-node and multi-GPU training. You can freely adjust the number of nodes and GPUs in the following commands.
41 | 
42 | To train the FNP model with the default configuration of `<24h lead time background, 10% observations with 128×256 resolution>`, run
43 | 
44 | ```
45 | torchrun --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=29500 train.py
46 | ```
47 | 
48 | You can freely choose the experiment you want to perform by changing the command-line parameters. For example, to train the `ConvCNP` model with the configuration of `<48h lead time background, 1% observations with 256×512 resolution>`, run
49 | 
50 | ```
51 | torchrun --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=29500 train.py --lead_time=48 --ratio=0.99 --resolution=256 --rundir='./configs/ConvCNP'
52 | ```
53 | 
54 | Please make sure that the parameter `--lead_time` is an integer multiple of 6, because the forecast model has a single-step forecast interval of six hours.
55 | 
56 | **The resolution and ratio of the observations used for data assimilation can be arbitrary (the original resolution of ERA5 data is 721×1440) and are not limited to the settings given in our paper.**
57 | 
58 | ## Evaluation
59 | 
60 | The commands for testing are the same as for training.
61 | 
62 | For example, you can use 1 GPU on 1 node to evaluate the performance of the `Adas` model with the configuration of `<24h lead time background, 10% observations with 721×1440 resolution>` by running
63 | 
64 | ```
65 | torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --master_port=29500 inference.py --resolution=721 --rundir='./configs/Adas'
66 | ```
67 | 
68 | The best checkpoint saved during training will be loaded to evaluate the MSE, MAE, and WRMSE metrics for all variables on the test set.
69 | 
70 | ## BibTeX
71 | ```bibtex
72 | @article{chen2025fnp,
73 |   title={FNP: Fourier neural processes for arbitrary-resolution data assimilation},
74 |   author={Chen, Kun and Ye, Peng and Chen, Hao and Han, Tao and Ouyang, Wanli and Chen, Tao and Bai, Lei and others},
75 |   journal={Advances in Neural Information Processing Systems},
76 |   volume={37},
77 |   pages={137847--137872},
78 |   year={2025}
79 | }
80 | ```
81 | 
--------------------------------------------------------------------------------
/configs/Adas/training_options.yaml:
--------------------------------------------------------------------------------
1 | vnames: &id001
2 |   single_level_vnames:
3 |   - u10
4 |   - v10
5 |   - t2m
6 |   - msl
7 |   multi_level_vnames:
8 |   - z
9 |   - q
10 |   - u
11 |   - v
12 |   - t
13 |   hight_level_list:
14 |   - 50
15 |   - 100
16 |   - 150
17 |   - 200
18 |   - 250
19 |   - 300
20 |   - 400
21 |   - 500
22 |   - 600
23 |   - 700
24 |   - 850
25 |   - 925
26 |   - 1000
27 | dataset:
28 |   train:
29 |     type: era5_npy_f32
30 |     data_dir: ./data
31 |     train_stride: 6
32 |     file_stride: 6
33 |     sample_stride: 1
34 |     vnames: *id001
35 |   valid:
36 |     type: era5_npy_f32
37 |     data_dir: ./data
38 |     train_stride: 6
39 |     file_stride: 6
40 |     sample_stride: 1
41 |     vnames: *id001
42 | dataloader:
43 |   num_workers: 4
44 |   pin_memory: true
45 |   prefetch_factor: 2
46 |   persistent_workers: true
47 | model:
48 |   type: Adas
49 |   params:
50 |     img_size:
51 |     - 69
52 |     - 128
53 |     - 256
54 |     dim: 96
55 |     patch_size:
56 |     - 1
57 |     - 2
58 |     - 2
59 |     window_size:
60 |     - 2
61 |     - 4
62 |     - 8
63 |     ape: True
64 |   criterion: UnifyMAE
65 |   optimizer:
66 |     type: AdamW
67 |     params:
68 |       lr: 1.0e-04
69 |       betas:
70 |       - 0.9
71 |       - 0.9
72 |       weight_decay: 0.01
73 |   lr_scheduler:
74 |     type: OneCycleLR
75 |     params:
76 |       max_lr: 1.0e-4
77 |       pct_start: 0.1
78 |       anneal_strategy: cos
79 |       div_factor: 100
80 |       final_div_factor: 1000
81 | 
--------------------------------------------------------------------------------
/configs/ConvCNP/training_options.yaml:
--------------------------------------------------------------------------------
1 | vnames: &id001
2 |   single_level_vnames:
3 |   - u10
4 |   - v10
5 |   - t2m
6 |   - msl
7 |   multi_level_vnames:
8 |   - z
9 |   - q
10 |   - u
11 |   - v
12 |   - t
13 |   hight_level_list:
14 |   - 50
15 |   - 100
16 |   - 150
17 |   - 200
18 |   - 250
19 |   - 300
20 |   - 400
21 |   - 500
22 |   - 600
23 |   - 700
24 |   - 850
25 |   - 925
26 |   - 1000
27 | dataset:
28 |   train:
29 |     type: era5_npy_f32
30 |     data_dir: ./data
31 |     train_stride: 6
32 |     file_stride: 6
33 |     sample_stride: 1
34 |     vnames: *id001
35 |   valid:
36 |     type: era5_npy_f32
37 |     data_dir: ./data
38 |     train_stride: 6
39 |     file_stride: 6
40 |     sample_stride: 1
41 |     vnames: *id001
42 | dataloader:
43 |   num_workers: 4
44 |   pin_memory: true
45 |   prefetch_factor: 2
46 |   persistent_workers: true
47 | model:
48 |   type: ConvCNP
49 |   params:
50 |     x_dim: 69
51 |     y_dim: 69
52 |     r_dim: 512 # 512 for 128×256 & 256×512 resolution, 128 for 721×1440 resolution
53 |   criterion: CNPFLoss
54 |   optimizer:
55 |     type: AdamW
56 |     params:
57 |       lr: 1.0e-04
58 | 
betas: 59 | - 0.9 60 | - 0.9 61 | weight_decay: 0.01 62 | lr_scheduler: 63 | type: OneCycleLR 64 | params: 65 | max_lr: 1.0e-4 66 | pct_start: 0.1 67 | anneal_strategy: cos 68 | div_factor: 100 69 | final_div_factor: 1000 70 | -------------------------------------------------------------------------------- /configs/FNP/training_options.yaml: -------------------------------------------------------------------------------- 1 | vnames: &id001 2 | single_level_vnames: 3 | - u10 4 | - v10 5 | - t2m 6 | - msl 7 | multi_level_vnames: 8 | - z 9 | - q 10 | - u 11 | - v 12 | - t 13 | hight_level_list: 14 | - 50 15 | - 100 16 | - 150 17 | - 200 18 | - 250 19 | - 300 20 | - 400 21 | - 500 22 | - 600 23 | - 700 24 | - 850 25 | - 925 26 | - 1000 27 | dataset: 28 | train: 29 | type: era5_npy_f32 30 | data_dir: ./data 31 | train_stride: 6 32 | file_stride: 6 33 | sample_stride: 1 34 | vnames: *id001 35 | valid: 36 | type: era5_npy_f32 37 | data_dir: ./data 38 | train_stride: 6 39 | file_stride: 6 40 | sample_stride: 1 41 | vnames: *id001 42 | dataloader: 43 | num_workers: 4 44 | pin_memory: true 45 | prefetch_factor: 2 46 | persistent_workers: true 47 | model: 48 | type: FNP 49 | params: 50 | n_channels: 51 | - 4 52 | - 13 53 | - 13 54 | - 13 55 | - 13 56 | - 13 57 | r_dim: 128 # 128 for 128×256 & 256×512 resolution, 64 for 721×1440 resolution 58 | use_nfl: true 59 | use_dam: true 60 | criterion: CNPFLoss 61 | optimizer: 62 | type: AdamW 63 | params: 64 | lr: 1.0e-04 65 | betas: 66 | - 0.9 67 | - 0.9 68 | weight_decay: 0.01 69 | lr_scheduler: 70 | type: OneCycleLR 71 | params: 72 | max_lr: 1.0e-4 73 | pct_start: 0.1 74 | anneal_strategy: cos 75 | div_factor: 100 76 | final_div_factor: 1000 77 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/data/.gitkeep -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/era5_npy_f32.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import io 4 | import json 5 | import pandas as pd 6 | import os 7 | from multiprocessing import shared_memory 8 | import multiprocessing 9 | import copy 10 | import queue 11 | import torch 12 | 13 | 14 | Years = { 15 | 'train': ['1979-01-01 00:00:00', '2015-12-31 23:00:00'], 16 | 'valid': ['2016-01-01 00:00:00', '2017-12-31 23:00:00'], 17 | 'test': ['2018-01-01 00:00:00', '2018-12-31 23:00:00'], 18 | 'all': ['1979-01-01 00:00:00', '2020-12-31 23:00:00'] 19 | } 20 | 21 | multi_level_vnames = [ 22 | "z", "t", "q", "r", "u", "v", "vo", "pv", 23 | ] 24 | single_level_vnames = [ 25 | "t2m", "u10", "v10", "tcc", "tp", "tisr", 26 | ] 27 | long2shortname_dict = {"geopotential": "z", "temperature": "t", "specific_humidity": "q", "relative_humidity": "r", "u_component_of_wind": "u", "v_component_of_wind": "v", "vorticity": "vo", "potential_vorticity": "pv", \ 28 | "2m_temperature": "t2m", "10m_u_component_of_wind": "u10", "10m_v_component_of_wind": "v10", "total_cloud_cover": "tcc", 
"total_precipitation": "tp", "toa_incident_solar_radiation": "tisr"} 29 | 30 | height_level = [1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300, 350, 400, 450, \ 31 | 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900, 925, 950, 975, 1000] 32 | # height_level = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] 33 | 34 | 35 | def standardization(data): 36 | mu = np.mean(data) 37 | sigma = np.std(data) 38 | return (data - mu) / sigma 39 | 40 | 41 | class era5_npy_f32(Dataset): 42 | def __init__(self, data_dir='./data', split='train', **kwargs) -> None: 43 | super().__init__() 44 | 45 | self.length = kwargs.get('length', 1) 46 | self.file_stride = kwargs.get('file_stride', 6) 47 | self.sample_stride = kwargs.get('sample_stride', 1) 48 | self.output_meanstd = kwargs.get("output_meanstd", False) 49 | self.use_diff_pos = kwargs.get("use_diff_pos", False) 50 | self.rm_equator = kwargs.get("rm_equator", False) 51 | Years_dict = kwargs.get('years', Years) 52 | 53 | self.pred_length = kwargs.get("pred_length", 0) 54 | self.inference_stride = kwargs.get("inference_stride", 6) 55 | self.train_stride = kwargs.get("train_stride", 6) 56 | self.use_gt = kwargs.get("use_gt", True) 57 | self.data_save_dir = kwargs.get("data_save_dir", None) 58 | 59 | self.save_single_level_names = kwargs.get("save_single_level_names", []) 60 | self.save_multi_level_names = kwargs.get("save_multi_level_names", []) 61 | 62 | vnames_type = kwargs.get("vnames", {}) 63 | self.single_level_vnames = vnames_type.get('single_level_vnames', []) 64 | self.multi_level_vnames = vnames_type.get('multi_level_vnames', ['z','q', 'u', 'v', 't']) 65 | self.height_level_list = vnames_type.get('hight_level_list', [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]) 66 | self.height_level_indexes = [height_level.index(j) for j in self.height_level_list] 67 | 68 | self.select_row = [i for i in range(721)] 69 | if self.rm_equator: 70 | del self.select_row[360] 71 | self.split = split 72 | self.data_dir = data_dir 73 | years = Years_dict[split] 74 | self.init_file_list(years) 75 | 76 | self._get_meanstd() 77 | self.mean, self.std = self.get_meanstd() 78 | self.data_element_num = len(self.single_level_vnames) + len(self.multi_level_vnames) * len(self.height_level_list) 79 | dim = len(self.single_level_vnames) + len(self.multi_level_vnames) * len(self.height_level_list) 80 | 81 | self.index_dict1 = {} 82 | self.index_dict2 = {} 83 | i = 0 84 | for vname in self.single_level_vnames: 85 | self.index_dict1[(vname, 0)] = i 86 | i += 1 87 | for vname in self.multi_level_vnames: 88 | for height in self.height_level_list: 89 | self.index_dict1[(vname, height)] = i 90 | i += 1 91 | 92 | self.index_queue = multiprocessing.Queue() 93 | self.unit_data_queue = multiprocessing.Queue() 94 | 95 | self.index_queue.cancel_join_thread() 96 | self.unit_data_queue.cancel_join_thread() 97 | 98 | self.compound_data_queue = [] 99 | self.sharedmemory_list = [] 100 | self.compound_data_queue_dict = {} 101 | self.sharedmemory_dict = {} 102 | 103 | self.compound_data_queue_num = 8 104 | 105 | self.lock = multiprocessing.Lock() 106 | if self.rm_equator: 107 | self.a = np.zeros((dim, 720, 1440), dtype=np.float32) 108 | else: 109 | self.a = np.zeros((dim, 721, 1440), dtype=np.float32) 110 | 111 | for _ in range(self.compound_data_queue_num): 112 | self.compound_data_queue.append(multiprocessing.Queue()) 113 | shm = shared_memory.SharedMemory(create=True, size=self.a.nbytes) 114 | shm.unlink() 115 | 
self.sharedmemory_list.append(shm) 116 | 117 | self.arr = multiprocessing.Array('i', range(self.compound_data_queue_num)) 118 | 119 | self._workers = [] 120 | 121 | for _ in range(40): 122 | w = multiprocessing.Process( 123 | target=self.load_data_process) 124 | w.daemon = True 125 | # NB: Process.start() actually take some time as it needs to 126 | # start a process and pass the arguments over via a pipe. 127 | # Therefore, we only add a worker to self._workers list after 128 | # it started, so that we do not call .join() if program dies 129 | # before it starts, and __del__ tries to join but will get: 130 | # AssertionError: can only join a started process. 131 | w.start() 132 | self._workers.append(w) 133 | w = multiprocessing.Process(target=self.data_compound_process) 134 | w.daemon = True 135 | w.start() 136 | self._workers.append(w) 137 | 138 | def init_file_list(self, years): 139 | time_sequence = pd.date_range(years[0],years[1],freq=str(self.file_stride)+'H') #pd.date_range(start='2019-1-09',periods=24,freq='H') 140 | self.file_list= [os.path.join(str(time_stamp.year), str(time_stamp.to_datetime64()).split('.')[0]).replace('T', '/') 141 | for time_stamp in time_sequence] 142 | self.single_file_list= [os.path.join('single/'+str(time_stamp.year), str(time_stamp.to_datetime64()).split('.')[0]).replace('T', '/') 143 | for time_stamp in time_sequence] 144 | 145 | def _get_meanstd(self): 146 | with open('./datasets/mean_std.json',mode='r') as f: 147 | multi_level_mean_std = json.load(f) 148 | with open('./datasets/mean_std_single.json',mode='r') as f: 149 | single_level_mean_std = json.load(f) 150 | self.mean_std = {} 151 | multi_level_mean_std['mean'].update(single_level_mean_std['mean']) 152 | multi_level_mean_std['std'].update(single_level_mean_std['std']) 153 | self.mean_std['mean'] = multi_level_mean_std['mean'] 154 | self.mean_std['std'] = multi_level_mean_std['std'] 155 | for vname in self.single_level_vnames: 156 | self.mean_std['mean'][vname] = np.array(self.mean_std['mean'][vname])[::-1][:,np.newaxis,np.newaxis] 157 | self.mean_std['std'][vname] = np.array(self.mean_std['std'][vname])[::-1][:,np.newaxis,np.newaxis] 158 | for vname in self.multi_level_vnames: 159 | self.mean_std['mean'][vname] = np.array(self.mean_std['mean'][vname])[::-1][:,np.newaxis,np.newaxis] 160 | self.mean_std['std'][vname] = np.array(self.mean_std['std'][vname])[::-1][:,np.newaxis,np.newaxis] 161 | 162 | def data_compound_process(self): 163 | recorder_dict = {} 164 | while True: 165 | job_pid, idx, vname, height = self.unit_data_queue.get() 166 | if job_pid not in self.compound_data_queue_dict: 167 | try: 168 | self.lock.acquire() 169 | for i in range(self.compound_data_queue_num): 170 | if job_pid == self.arr[i]: 171 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i] 172 | break 173 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]: 174 | print("error", job_pid, self.arr) 175 | except Exception as err: 176 | raise err 177 | finally: 178 | self.lock.release() 179 | 180 | if (job_pid, idx) in recorder_dict: 181 | # recorder_dict[(job_pid, idx)][(vname, height)] = 1 182 | recorder_dict[(job_pid, idx)] += 1 183 | else: 184 | recorder_dict[(job_pid, idx)] = 1 185 | if recorder_dict[(job_pid, idx)] == self.data_element_num: 186 | del recorder_dict[(job_pid, idx)] 187 | self.compound_data_queue_dict[job_pid].put((idx)) 188 | 189 | def get_data(self, idxes): 190 | job_pid = os.getpid() 191 | if job_pid not in self.compound_data_queue_dict: 192 | try: 193 | 
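                # Claim a free slot in the shared pid array under the lock: slot i is
                # still unclaimed while self.arr[i] keeps its initial value i. The
                # claimed slot's compound queue and shared-memory block then belong to
                # this worker process.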
self.lock.acquire() 194 | for i in range(self.compound_data_queue_num): 195 | if i == self.arr[i]: 196 | self.arr[i] = job_pid 197 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i] 198 | self.sharedmemory_dict[job_pid] = self.sharedmemory_list[i] 199 | break 200 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]: 201 | print("error", job_pid, self.arr) 202 | 203 | except Exception as err: 204 | raise err 205 | finally: 206 | self.lock.release() 207 | 208 | try: 209 | idx = self.compound_data_queue_dict[job_pid].get(False) 210 | raise ValueError 211 | except queue.Empty: 212 | pass 213 | except Exception as err: 214 | raise err 215 | 216 | b = np.ndarray(self.a.shape, dtype=self.a.dtype, buffer=self.sharedmemory_dict[job_pid].buf) 217 | return_data = [] 218 | for idx in idxes: 219 | for vname in self.single_level_vnames: 220 | self.index_queue.put((job_pid, idx, vname, 0)) 221 | for vname in self.multi_level_vnames: 222 | for height in self.height_level_list: 223 | self.index_queue.put((job_pid, idx, vname, height)) 224 | idx = self.compound_data_queue_dict[job_pid].get() 225 | b -= self.mean.numpy()[:, np.newaxis, np.newaxis] 226 | b /= self.std.numpy()[:, np.newaxis, np.newaxis] 227 | return_data.append(copy.deepcopy(b)) 228 | 229 | return return_data 230 | 231 | def load_data_process(self): 232 | while True: 233 | job_pid, idx, vname, height = self.index_queue.get() 234 | if job_pid not in self.compound_data_queue_dict: 235 | try: 236 | self.lock.acquire() 237 | for i in range(self.compound_data_queue_num): 238 | if job_pid == self.arr[i]: 239 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i] 240 | self.sharedmemory_dict[job_pid] = self.sharedmemory_list[i] 241 | break 242 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]: 243 | print("error", job_pid, self.arr) 244 | except Exception as err: 245 | raise err 246 | finally: 247 | self.lock.release() 248 | 249 | if vname in self.single_level_vnames: 250 | file = self.single_file_list[idx] 251 | url = f"{self.data_dir}/{file}-{vname}.npy" 252 | elif vname in self.multi_level_vnames: 253 | file = self.file_list[idx] 254 | url = f"{self.data_dir}/{file}-{vname}-{height}.0.npy" 255 | b = np.ndarray(self.a.shape, dtype=self.a.dtype, buffer=self.sharedmemory_dict[job_pid].buf) 256 | unit_data = np.load(url) 257 | # unit_data = unit_data[np.newaxis, :, :] 258 | if self.rm_equator: 259 | b[self.index_dict1[(vname, height)], :360] = unit_data[:360] 260 | b[self.index_dict1[(vname, height)], 360:] = unit_data[361:] 261 | else: 262 | b[self.index_dict1[(vname, height)], :] = unit_data[:] 263 | del unit_data 264 | self.unit_data_queue.put((job_pid, idx, vname, height)) 265 | 266 | def __len__(self): 267 | 268 | if self.split != "test": 269 | data_len = (len(self.file_list) - (self.length - 1) * self.sample_stride) // (self.train_stride // self.sample_stride // self.file_stride) 270 | elif self.use_gt: 271 | data_len = len(self.file_list) - (self.length - 1) * self.sample_stride 272 | data_len -= self.pred_length * self.sample_stride + 1 273 | data_len = (data_len + max(self.inference_stride // self.sample_stride // self.file_stride, 1) - 1) // max(self.inference_stride // self.sample_stride // self.file_stride, 1) 274 | else: 275 | data_len = len(self.file_list) - (self.length - 1) * self.sample_stride 276 | data_len = (data_len + max(self.inference_stride // self.sample_stride // self.file_stride, 1) - 1) // max(self.inference_stride // self.sample_stride // 
self.file_stride, 1) 277 | 278 | return data_len 279 | 280 | def get_meanstd(self): 281 | return_data_mean = [] 282 | return_data_std = [] 283 | 284 | for vname in self.single_level_vnames: 285 | return_data_mean.append(self.mean_std['mean'][vname]) 286 | return_data_std.append(self.mean_std['std'][vname]) 287 | for vname in self.multi_level_vnames: 288 | return_data_mean.append(self.mean_std['mean'][vname][self.height_level_indexes]) 289 | return_data_std.append(self.mean_std['std'][vname][self.height_level_indexes]) 290 | 291 | return torch.from_numpy(np.concatenate(return_data_mean, axis=0)[:, 0, 0]), torch.from_numpy(np.concatenate(return_data_std, axis=0)[:, 0, 0]) 292 | 293 | def __getitem__(self, index): 294 | index = min(index, len(self.file_list) - (self.length-1) * self.sample_stride - 1) 295 | if self.split == "test": 296 | index = index * max(self.inference_stride // self.sample_stride // self.file_stride, 1) 297 | else: 298 | index = index * (self.train_stride // self.sample_stride // self.file_stride) 299 | array_seq = self.get_data([index, index + self.sample_stride, index + (self.length-1) * self.sample_stride]) 300 | tar_idx = np.array([index + self.sample_stride * (self.length - 1)]) 301 | return array_seq, tar_idx 302 | -------------------------------------------------------------------------------- /datasets/mean_std.json: -------------------------------------------------------------------------------- 1 | {"mean": {"z_overall": 125173.77185546875, "z": [777.5631800842285, 2806.6911419677767, 4880.769941406251, 7002.870792236329, 9175.857954101568, 11402.84267333983, 13687.02337158203, 16031.525583496094, 18439.37236816406, 20913.51230468748, 23457.279204101567, 28769.254521484374, 34409.91857421876, 40425.96180664062, 46874.99372070313, 53826.54542968748, 61368.859130859375, 69620.0044531249, 78745.945078125, 88999.83613281258, 100822.13164062506, 107551.17681640621, 115011.55044921875, 123395.13201171876, 132973.8087890624, 144168.34273437515, 157706.1917968749, 179254.20781250007, 199832.31609374992, 231648.1046875, 257369.633828125, 302391.6971875003, 326148.29640625, 349188.804375, 385898.7467187499, 416444.811875, 470007.90359375], "q_overall": 0.0016609070883714593, "q": [0.006659231842495503, 0.006508124030660836, 0.006180181269301101, 0.005698622210184111, 0.005229148901998995, 0.004779181429184972, 0.004328316930914292, 0.003895292008528486, 0.003504167633363977, 0.003151517739752307, 0.002823496932396666, 0.002245682779466732, 0.0017783852928550915, 0.0014082587775192225, 0.0010869937593815861, 0.0007819174298492726, 0.0005401131762482684, 0.0003586592462670523, 0.00022093375020631356, 0.00012106754767955863, 5.37612270545651e-05, 3.1880958731562706e-05, 1.7365863168379306e-05, 8.887264439181295e-06, 4.689598504228342e-06, 3.032497714912096e-06, 2.557213611567022e-06, 2.7101390543293752e-06, 2.8248029025235157e-06, 2.8785681178078434e-06, 2.918656103929608e-06, 3.0804213855617494e-06, 3.2166743511652386e-06, 3.3617744981029316e-06, 3.545218272620331e-06, 3.651859217939091e-06, 3.8337135470101216e-06], "u_overall": 5.067544518113135, "u": [-0.18604825597634778, 0.014575419906759663, 0.1831986801231688, 0.3961004569224316, 0.6612499056267552, 0.9589599340630222, 1.2827875773236155, 1.6256907974928618, 1.9794725690782076, 2.3418910357356073, 2.7121086287498475, 3.4704639464616776, 4.26914108872414, 5.110536641478544, 6.038327068090436, 7.07798705458641, 8.268079032897964, 9.628437678813944, 11.183755590915675, 12.895844810009004, 14.48938421010971, 
15.056942412853239, 15.298378415107727, 15.109850177764892, 14.321160042285918, 12.678406665325161, 10.046632840633391, 6.356182322502137, 4.44909584343433, 0.7572945548904316, -2.5366748879943004, -4.663122029609662, -3.6527538601704896, -1.2890704965137414, 1.792431366236417, 3.4357510804757476, 5.936705157756806], "v_overall": 0.028403457459244247, "v": [-0.15797925526494516, -0.152254779834766, -0.11334734256146482, -0.062323193841984835, -0.02435781713542383, -0.0020319510312947365, 0.014846984493779027, 0.026666699802153734, 0.03010514150775634, 0.028390460408554644, 0.024946943822433242, 0.0075308467486320295, -0.011855637595126613, -0.022026948712546065, -0.017326252191560337, 0.0007164070030103152, -0.008457765014013602, -0.04722135595180817, -0.059056078250941904, 0.03872977272505523, 0.21001753183547414, 0.26851208267849863, 0.2956721917196408, 0.29721465504262573, 0.2678451650420902, 0.1749421462348983, 0.1010729405652091, 0.035401583576604036, 0.012106836824341376, 0.020043842778977715, 0.00465771341478103, -0.03084933991518482, -0.06255946054356175, -0.09668692563471275, -0.07074970662535636, -0.008126834754220965, 0.13870985066001587], "t_overall": 247.24390762329102, "t": [278.5929747772214, 277.2997010040281, 276.21999847412104, 275.3001181793211, 274.4359161376951, 273.6013989257812, 272.77368919372566, 271.89837799072285, 270.9205000305179, 269.8229228210445, 268.63149719238277, 265.99485839843743, 263.0967428588868, 259.84156120300344, 256.09259799957306, 251.74072200775146, 246.77826755523682, 241.16466262817383, 234.9638207244873, 228.6614455413818, 223.53410289764412, 221.58114109039306, 219.73181056976318, 217.6849419784546, 215.23375904083258, 212.61522850036621, 210.3573041915893, 211.44547527313233, 214.66564151763913, 219.08330230712897, 222.90260936737062, 229.70379344940184, 234.8047568511963, 242.76109550476073, 257.88357688903875, 266.1968101501466, 270.00755813598647], "r_overall": 40.394069585800146, "r": [80.17475992202763, 81.89339225769044, 82.05619462966922, 80.0935426521301, 77.31068056106565, 74.06146982192983, 70.40090473175046, 66.74169216156014, 63.65694343566894, 61.238844032287595, 59.20706973075867, 55.94804849624633, 53.78221703529358, 52.604153957366925, 51.98551220893857, 51.816392965316794, 52.520214595794705, 54.109504508972165, 55.61435362815857, 53.635069522857655, 42.4717398929596, 35.659886183738735, 29.037648782730102, 23.85700461387634, 21.362764282226564, 21.92401524066924, 23.441656923294072, 11.165672419071193, 3.7584114503860486, 1.6763820819556718, 0.9495417647063722, 0.2976933006197218, 0.10302423514425761, 0.02450558412820101, 0.0026089991186745458, 0.0007830015913350504, 0.0002408742079387594], "w_overall": 0.004383492262859363, "w": [0.022094985805451873, 0.0216572742210701, 0.019905026827473193, 0.017870978915598242, 0.01588542067212984, 0.013868597075343127, 0.011905601065373044, 0.010102826043148527, 0.008454037891642651, 0.006905023376311874, 0.005377083962594045, 0.002986208360084676, 0.0012120617383675368, 0.0008260514670288897, 0.000692131803097027, 0.000629533462382596, 0.0005991482839004908, 0.0005541121261796886, 0.0004358308141377164, 0.0002701366509023725, 0.00013300039120110795, 0.00010626948278826373, 9.627545935728681e-05, 8.954200330800081e-05, 6.651518481056939e-05, 2.8853250523752642e-05, -1.511444107139435e-05, -5.620325117561725e-05, -7.897719135826269e-05, -9.763219824087352e-05, -9.472476952721467e-05, -7.022762941403469e-05, -5.8008067545154444e-05, -4.455210098275765e-05, -2.651977333063105e-05, 
-1.5596725875788307e-05, -5.785629261385584e-06]}, "std": {"z_overall": 129782.78276610909, "z": [1098.9952409939283, 1112.6027805778692, 1142.3473460665198, 1187.5419349409494, 1246.4561469025161, 1317.5100178215614, 1399.3410970050247, 1490.8825928657998, 1590.9548416558155, 1698.5379482618182, 1812.9974354157323, 2060.9872289877453, 2334.319060376128, 2629.7201995715513, 2949.675806600318, 3299.702929930327, 3683.8414755168196, 4104.233456672573, 4557.44791664416, 5020.194961603476, 5405.73040397736, 5516.599935297351, 5540.73074484052, 5460.536820725648, 5253.301115477269, 4886.876297159263, 4357.588191568988, 3756.1336388685854, 3755.2810557402927, 4396.22650565561, 5144.265045354613, 6650.516303503411, 7495.09074134178, 8295.69333798942, 9517.614470978018, 10519.593543281839, 12104.888878805084], "q_overall": 0.0034514906703841854, "q": [0.00611814321149996, 0.005991895878203082, 0.005660814318820206, 0.005201345241925773, 0.004817240293151429, 0.004489990425884411, 0.004184742037434761, 0.0038904524158168458, 0.0036034287117598773, 0.0033171467057458016, 0.0030338747907811383, 0.0024928432426471183, 0.0020354637234126743, 0.0016946778969914426, 0.0013932648476959186, 0.001023028433607086, 0.0007179488022534093, 0.00048331132466880735, 0.0002993454787338255, 0.00016131853114827985, 7.02963683704546e-05, 4.0895619708336846e-05, 2.093742795871515e-05, 9.154817617587747e-06, 3.1627283344500357e-06, 7.369263717657513e-07, 4.2315237954921815e-07, 2.105491580661734e-07, 1.1555282996146702e-07, 1.6058042115095855e-07, 2.133802768772255e-07, 2.9865161188496304e-07, 3.153152311774654e-07, 3.0755590168313815e-07, 2.862195110630825e-07, 2.547970205742875e-07, 1.620701293822839e-07], "u_overall": 16.97319982237474, "u": [6.126056325796786, 7.016566034258269, 7.643058800982493, 7.93264239491015, 8.043800524363755, 8.105094821089457, 8.179197305830433, 8.287887909395936, 8.433526267376907, 8.60957088282445, 8.810868039599674, 9.286354460633103, 9.850636072885855, 10.503871726997783, 11.285900449345242, 12.215305549952125, 13.290793204085057, 14.509924979003983, 15.840437203647705, 17.078782531259524, 17.720698660431694, 17.667616741159385, 17.286773058038722, 16.542030679458666, 15.407016747306344, 13.807092388343909, 11.884088705628045, 10.334116775694541, 11.557361639969054, 14.832938252653848, 19.184593851862232, 25.245251895546826, 26.307161489246376, 26.068781364502108, 27.582285843981804, 29.507017272843985, 32.21467925596289], "v_overall": 10.225347045502723, "v": [5.23175612906023, 5.924618269023251, 6.285772466913804, 6.345356147017278, 6.279593292801223, 6.209277530023286, 6.186618416862026, 6.213000051694466, 6.279984926208441, 6.380234677983426, 6.510405300881028, 6.858187372121642, 7.312336974076069, 7.835396470389893, 8.498212257679501, 9.321096224035244, 10.312266990636417, 11.44656476066317, 12.61520885057924, 13.474533447403218, 13.360381609448558, 12.771109679285894, 11.896325029659364, 10.936078921724906, 9.998695230009567, 9.05105707564941, 8.178248048405905, 7.7733598444254, 8.417864770061094, 10.031378889054897, 11.362479417901184, 13.548648819446889, 14.409961339641928, 14.843368077149982, 15.355139749619095, 15.671183984838587, 15.620566981759925], "t_overall": 27.05388027783469, "t": [18.59201174892538, 17.962722435516273, 17.354568742491193, 16.88523210573196, 16.50532064048834, 16.214283284136293, 16.00275340237577, 15.829774257559896, 15.664667994261846, 15.503739090214387, 15.35464052689256, 15.021590351519892, 14.443361586551287, 13.858620163486986, 
13.612509319936459, 13.459313552230206, 13.222745493030287, 12.773099916244078, 11.921086612374525, 10.212310312072752, 7.389004707914384, 6.217593563425525, 5.933385737657316, 6.957734476103884, 9.090666595626503, 11.748988502455319, 13.738672642636256, 12.222179433839193, 9.495652698988557, 8.418363124927735, 8.71613081797916, 9.979142288488994, 10.868110600253445, 11.453708697281687, 12.03482655512657, 12.18683909538171, 10.27148161649973], "r_overall": 38.27834275264974, "r": [18.065236846880957, 18.884182946667348, 19.9694447391712, 21.201659553304538, 22.682946621307828, 24.32165800612575, 26.05935412034893, 27.633995796855135, 28.837421755290475, 29.762150422980042, 30.498343554542576, 31.567549251190137, 32.358281367806725, 32.994946188446285, 33.54817198803172, 33.83420782522547, 34.1020571373513, 34.15576061007992, 33.97040860449722, 33.27175094990293, 33.824094547943524, 33.94965762113503, 32.2278297456812, 30.217155336090716, 29.454141244285715, 31.103837052430183, 33.23741657963707, 13.81056340487671, 5.50629572264854, 5.140879239893525, 3.7486899733037617, 1.3911355900931046, 0.45882004791501424, 0.08512110306079632, 0.004894727016083146, 0.0012718724357374124, 0.00047595219931748637], "w_overall": 0.15754408345291773, "w": [0.13979567674915916, 0.15479160882205267, 0.17909532539790496, 0.2025255148068513, 0.21983137560176763, 0.23023882117421315, 0.23583433097628145, 0.23861934853804015, 0.23947601319694647, 0.2388890682549996, 0.23720531494087146, 0.2326563138856663, 0.22800938309402644, 0.22319334450015696, 0.21895367863698434, 0.21374423887468189, 0.20502274496551173, 0.1897976164287798, 0.16794512621224, 0.14051455042850863, 0.11142703377201577, 0.09759972206579746, 0.0839118257220142, 0.0699250713532834, 0.05505537455154692, 0.0399342745096499, 0.0266884524850588, 0.016159587930447316, 0.012212181977916497, 0.008866289960038784, 0.007027148644031218, 0.004681425194685117, 0.0036326590952213142, 0.00276010917035495, 0.002059059572163848, 0.0016134326105307103, 0.0008927991973868609]}, "count": 400} -------------------------------------------------------------------------------- /datasets/mean_std_single.json: -------------------------------------------------------------------------------- 1 | {"mean": {"v10": [0.22575792335029873], "u10": [-0.14186215714480854], "t2m": [278.7854495405721], "sp": [96672.96548579562], "msl": [100980.83590625007],"tp6h": [0.5938965440938924]}, "std": {"v10": [4.798220612223473], "u10": [5.610453475051704], "t2m": [21.32010786700973], "sp": [9580.331095986492], "msl": [1336.2115992274876], "tp6h": [1.5731802126254537]}, "count": 2750} -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from utils.builder import ConfigBuilder 5 | import utils.misc as utils 6 | import yaml 7 | from utils.logger import get_logger 8 | import copy 9 | 10 | 11 | 12 | def subprocess_fn(args): 13 | utils.setup_seed(args.seed * args.world_size + args.rank) 14 | 15 | logger = get_logger("test", args.rundir, utils.get_rank(), filename='infer.log') 16 | args.cfg_params["logger"] = logger 17 | 18 | # build config 19 | logger.info('Building config ...') 20 | builder = ConfigBuilder(**args.cfg_params) 21 | 22 | # build model 23 | logger.info('Building models ...') 24 | model = builder.get_model() 25 | checkpoint_dict = torch.load(os.path.join(args.rundir, 'best_model.pth'), 
map_location=torch.device('cpu')) 26 | model.kernel.load_state_dict(checkpoint_dict) 27 | model.kernel = utils.DistributedParallel_Model(model.kernel, args.local_rank) 28 | 29 | # build forecast model 30 | logger.info('Building forecast models ...') 31 | args.forecast_model = builder.get_forecast(args.local_rank) 32 | 33 | # build dataset 34 | logger.info('Building dataloaders ...') 35 | dataset_params = args.cfg_params['dataset'] 36 | test_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='test', batch_size=args.batch_size) 37 | 38 | # inference 39 | logger.info('begin testing ...') 40 | model.test(test_dataloader, logger, args) 41 | logger.info('testing end ...') 42 | 43 | 44 | def main(args): 45 | if args.world_size > 1: 46 | utils.init_distributed_mode(args) 47 | else: 48 | args.rank = 0 49 | args.local_rank = 0 50 | args.distributed = False 51 | args.gpu = 0 52 | torch.cuda.set_device(args.gpu) 53 | 54 | args.rundir = os.path.join(args.rundir, f'mask{args.ratio}_lead{args.lead_time}h_res{args.resolution}') 55 | args.cfg = os.path.join(args.rundir, 'train.yaml') 56 | with open(args.cfg, 'r') as cfg_file: 57 | cfg_params = yaml.load(cfg_file, Loader = yaml.FullLoader)['cfg_params'] 58 | 59 | cfg_params['dataloader']['num_workers'] = args.per_cpus 60 | cfg_params['dataset']['test'] = copy.deepcopy(cfg_params['dataset']['train']) 61 | args.cfg_params = cfg_params 62 | 63 | subprocess_fn(args) 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument('--seed', type = int, default = 0, help = 'seed') 71 | parser.add_argument('--cuda', type = int, default = 0, help = 'cuda id') 72 | parser.add_argument('--world_size', type = int, default = 1, help = 'number of progress') 73 | parser.add_argument('--per_cpus', type = int, default = 4, help = 'number of perCPUs to use') 74 | parser.add_argument('--batch_size', type = int, default = 1, help = "batch size") 75 | parser.add_argument('--lead_time', type = int, default = 24, help = "lead time (h) for background") 76 | parser.add_argument('--ratio', type = float, default = 0.9, help = "mask ratio") 77 | parser.add_argument('--resolution', type = int, default = 128, help = "observation resolution") 78 | parser.add_argument('--init_method', type = str, default = 'tcp://127.0.0.1:19111', help = 'multi process init method') 79 | parser.add_argument('--rundir', type = str, default = './configs/FNP', help = 'where to save the results') 80 | 81 | args = parser.parse_args() 82 | 83 | main(args) 84 | 85 | -------------------------------------------------------------------------------- /models/Adas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | import numpy as np 6 | from timm.models.layers import trunc_normal_ 7 | import utils.misc as utils 8 | from utils.metrics import WRMSE 9 | from functools import partial 10 | from modules import AllPatchEmbed, PatchRecover, BasicLayer, SwinTransformerLayer 11 | from utils.builder import get_optimizer, get_lr_scheduler 12 | 13 | 14 | 15 | class Adas_model(nn.Module): 16 | def __init__(self, img_size=(69,128,256), dim=96, patch_size=(1,2,2), window_size=(2,4,8), depth=8, num_heads=4, 17 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, ape=True, use_checkpoint=False): 18 | super().__init__() 19 | 20 | self.patchembed = AllPatchEmbed(img_size=img_size, embed_dim=dim, patch_size=patch_size, 
norm_layer=nn.LayerNorm) # b,c,14,180,360 21 | self.patchunembed = PatchRecover(img_size=img_size, embed_dim=dim, patch_size=patch_size) 22 | self.patch_resolution = self.patchembed.patch_resolution 23 | 24 | self.layer1 = BasicLayer(dim, kernel=(3,5,7), padding=(1,2,3), num_heads=num_heads, window_size=window_size, use_checkpoint=use_checkpoint) # s1 25 | self.layer2 = BasicLayer(dim*2, kernel=(3,3,5), padding=(1,1,2), num_heads=num_heads, window_size=window_size, sample='down', use_checkpoint=use_checkpoint) # s2 26 | self.layer3 = BasicLayer(dim*4, kernel=3, padding=1, num_heads=num_heads, window_size=window_size, sample='down', use_checkpoint=use_checkpoint) # s3 27 | self.layer4 = BasicLayer(dim*2, kernel=(3,3,5), padding=(1,1,2), num_heads=num_heads, window_size=window_size, sample='up', use_checkpoint=use_checkpoint) # s2 28 | self.layer5 = BasicLayer(dim, kernel=(3,5,7), padding=(1,2,3), num_heads=num_heads, window_size=window_size, sample='up', use_checkpoint=use_checkpoint) # s1 29 | 30 | self.fusion = nn.Conv3d(dim*3, dim, kernel_size=(3,5,7), stride=1, padding=(1,2,3)) 31 | 32 | # absolute position embedding 33 | self.ape = ape 34 | if self.ape: 35 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, dim, self.patch_resolution[0], self.patch_resolution[1], self.patch_resolution[2])) 36 | trunc_normal_(self.absolute_pos_embed, std=.02) 37 | 38 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 39 | self.decoder = SwinTransformerLayer(dim=dim, depth=depth, num_heads=num_heads, window_size=window_size, qkv_bias=True, 40 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, use_checkpoint=use_checkpoint) 41 | 42 | # initial weights 43 | self.apply(self._init_weights) 44 | 45 | def _init_weights(self, m): 46 | 47 | if isinstance(m, nn.Linear): 48 | trunc_normal_(m.weight, std=.02) 49 | if isinstance(m, nn.Linear) and m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.LayerNorm): 52 | nn.init.constant_(m.bias, 0) 53 | nn.init.constant_(m.weight, 1.0) 54 | 55 | def encoder_forward(self, x): 56 | 57 | x1 = self.layer1(x) 58 | x2 = self.layer2(x1) 59 | x = self.layer3(x2) 60 | x = self.layer4(x, x2) 61 | x = self.layer5(x, x1) 62 | 63 | return x 64 | 65 | def forward(self, background, observation, mask): 66 | 67 | x = self.patchembed(background, observation, mask) 68 | if self.ape: 69 | x = [ x[i] + self.absolute_pos_embed for i in range(3) ] 70 | 71 | x = self.encoder_forward(x) 72 | x = self.fusion(torch.cat(x, dim=1)) 73 | x = self.decoder(x) 74 | 75 | x = self.patchunembed(x) 76 | return x 77 | 78 | 79 | class Adas(object): 80 | 81 | def __init__(self, **model_params) -> None: 82 | super().__init__() 83 | 84 | params = model_params.get('params', {}) 85 | criterion = model_params.get('criterion', 'CNPFLoss') 86 | self.optimizer_params = model_params.get('optimizer', {}) 87 | self.scheduler_params = model_params.get('lr_scheduler', {}) 88 | 89 | self.kernel = Adas_model(**params) 90 | self.best_loss = 9999 91 | self.criterion = self.get_criterion(criterion) 92 | self.criterion_mae = nn.L1Loss() 93 | self.criterion_mse = nn.MSELoss() 94 | 95 | if utils.is_dist_avail_and_initialized(): 96 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 97 | if self.device == torch.device('cpu'): 98 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 99 | else: 100 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 101 | 102 | def 
get_criterion(self, loss_type): 103 | if loss_type == 'UnifyMAE': 104 | return partial(self.unify_losses, criterion=nn.L1Loss()) 105 | elif loss_type == 'UnifyMSE': 106 | return partial(self.unify_losses, criterion=nn.MSELoss()) 107 | else: 108 | raise NotImplementedError('Invalid loss type.') 109 | 110 | def unify_losses(self, pred, target, criterion): 111 | loss_sum = 0 112 | unify_loss = criterion(pred[:,0,:,:], target[:,0,:,:]) 113 | for i in range(1, len(pred[0])): 114 | loss = criterion(pred[:,i,:,:], target[:,i,:,:]) 115 | loss_sum += loss / (loss/unify_loss).detach() 116 | return (loss_sum + unify_loss) / len(pred[0]) 117 | 118 | def process_data(self, batch_data, args): 119 | 120 | inp_data = torch.cat([batch_data[0], batch_data[1]], dim=1) 121 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy() 122 | truth = batch_data[-1].to(self.device, non_blocking=True) # 69 123 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear') 124 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear') 125 | 126 | for _ in range(args.lead_time // 6): 127 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]] 128 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1) 129 | 130 | background = torch.from_numpy(predict_data).to(self.device, non_blocking=True) 131 | mask = (torch.rand(truth.shape, device=self.device) >= args.ratio).float() 132 | observation = truth * mask 133 | mask = F.interpolate(mask, size=(128,256), mode='bilinear') 134 | observation = F.interpolate(observation, size=(128,256), mode='bilinear') 135 | observation = torch.where(mask==0, 0., observation/mask).to(self.device, non_blocking=True) 136 | mask = torch.where(mask==0, 0., 1.).to(self.device, non_blocking=True) 137 | 138 | return [background, observation, mask], truth_down 139 | 140 | def train(self, train_data_loader, valid_data_loader, logger, args): 141 | 142 | train_step = len(train_data_loader) 143 | valid_step = len(valid_data_loader) 144 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params) 145 | self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch) 146 | 147 | for epoch in range(args.max_epoch): 148 | begin_time = time.time() 149 | self.kernel.train() 150 | 151 | for step, batch_data in enumerate(train_data_loader): 152 | 153 | input_list, y_target = self.process_data(batch_data[0], args) 154 | self.optimizer.zero_grad() 155 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2]) 156 | loss = self.criterion(y_pred, y_target) 157 | loss.backward() 158 | self.optimizer.step() 159 | self.scheduler.step() 160 | 161 | if ((step + 1) % 100 == 0) | (step+1 == train_step): 162 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]') 163 | 164 | self.kernel.eval() 165 | with torch.no_grad(): 166 | total_loss = 0 167 | 168 | for step, batch_data in enumerate(valid_data_loader): 169 | input_list, y_target = self.process_data(batch_data[0], args) 170 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2]) 171 | loss = self.criterion(y_pred, y_target).item() 172 | total_loss += loss 173 | 174 | if ((step + 1) % 100 == 0) | (step+1 == valid_step): 175 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]') 176 | 177 | if (total_loss/valid_step) < self.best_loss: 178 | 
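                    # Only rank 0 writes the checkpoint in the distributed case, and
                    # .module strips the DDP wrapper so the saved state_dict can be
                    # loaded into a plain (non-DDP) model in inference.py.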
if utils.get_world_size() > 1 and utils.get_rank() == 0: 179 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth') 180 | elif utils.get_world_size() == 1: 181 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth') 182 | logger.info(f'New best model appears in epoch {epoch+1}.') 183 | self.best_loss = total_loss/valid_step 184 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]') 185 | 186 | def test(self, test_data_loader, logger, args): 187 | 188 | test_step = len(test_data_loader) 189 | data_mean, data_std = test_data_loader.dataset.get_meanstd() 190 | self.data_std = data_std.to(self.device) 191 | 192 | self.kernel.eval() 193 | with torch.no_grad(): 194 | total_loss = 0 195 | total_mae = 0 196 | total_mse = 0 197 | total_rmse = 0 198 | 199 | for step, batch_data in enumerate(test_data_loader): 200 | 201 | input_list, y_target = self.process_data(batch_data[0], args) 202 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2]) 203 | loss = self.criterion(y_pred, y_target).item() 204 | mae = self.criterion_mae(y_pred, y_target).item() 205 | mse = self.criterion_mse(y_pred, y_target).item() 206 | rmse = WRMSE(y_pred, y_target, self.data_std) 207 | 208 | total_loss += loss 209 | total_mae += mae 210 | total_mse += mse 211 | total_rmse += rmse 212 | if ((step + 1) % 100 == 0) | (step+1 == test_step): 213 | logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]') 214 | 215 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]') 216 | logger.info(f'Average RMSE:[{total_rmse/test_step}]') 217 | -------------------------------------------------------------------------------- /models/ConvCNP.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils import clip_grad_norm_ 7 | from functools import partial 8 | from einops import rearrange 9 | import utils.misc as utils 10 | from utils.metrics import WRMSE 11 | from utils.builder import get_optimizer, get_lr_scheduler 12 | import modules 13 | 14 | 15 | class ConvCNP_model(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | x_dim=69, 20 | y_dim=69, 21 | r_dim=512, 22 | XEncoder=nn.Identity, 23 | Conv=lambda y_dim: modules.make_abs_conv(nn.Conv2d)( 24 | y_dim, 25 | y_dim, 26 | groups=y_dim, 27 | kernel_size=11, 28 | padding=11 // 2, 29 | bias=False, 30 | ), 31 | CNN=partial( 32 | modules.CNN, 33 | ConvBlock=modules.ResConvBlock, 34 | Conv=nn.Conv2d, 35 | n_blocks=12, 36 | Normalization=nn.BatchNorm2d, 37 | is_chan_last=True, 38 | kernel_size=9, 39 | n_conv_layers=2, 40 | ), 41 | PredictiveDistribution=modules.MultivariateNormalDiag, 42 | p_y_loc_transformer=nn.Identity(), 43 | p_y_scale_transformer=lambda y_scale: 0.01 + 0.99 * F.softplus(y_scale), 44 | ): 45 | super().__init__() 46 | 47 | self.x_dim = x_dim 48 | self.y_dim = y_dim 49 | self.r_dim = r_dim 50 | 51 | self.conv = nn.ModuleList([Conv(y_dim), Conv(y_dim)]) 52 | self.resizer = nn.ModuleList([nn.Linear(self.y_dim * 2, self.r_dim), nn.Linear(self.y_dim * 2, self.r_dim)]) 53 | self.induced_to_induced = nn.ModuleList([CNN(self.r_dim), CNN(self.r_dim)]) 54 | self.fusion = nn.Linear(self.r_dim * 2, self.r_dim) 55 | self.x_encoder = XEncoder() 56 | 57 | Decoder=modules.discard_ith_arg(partial(modules.MLP, n_hidden_layers=4, 
hidden_size=self.r_dim), i=0) 58 | # times 2 out because loc and scale (mean and var for gaussian) 59 | self.decoder = Decoder(self.x_dim, self.r_dim, self.y_dim * 2) 60 | 61 | self.PredictiveDistribution = PredictiveDistribution 62 | self.p_y_loc_transformer = p_y_loc_transformer 63 | self.p_y_scale_transformer = p_y_scale_transformer 64 | 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | modules.weights_init(self) 69 | 70 | def forward(self, input_list): 71 | 72 | X_cntxt, Y_cntxt, Xb_cntxt, Yb_cntxt, X_trgt = input_list[0], input_list[1], input_list[2], input_list[3], input_list[4] 73 | X_cntxt = self.x_encoder(X_cntxt) # b,h,w,c 74 | X_trgt = self.x_encoder(X_trgt) 75 | Xb_cntxt = self.x_encoder(Xb_cntxt) # b,h,w,c 76 | 77 | # {R^u}_u 78 | # size = [batch_size, *n_rep, r_dim] for n_channels list 79 | R = self.encode_globally(X_cntxt, Y_cntxt, Xb_cntxt, Yb_cntxt) 80 | 81 | z_samples, q_zCc, q_zCct = None, None, None 82 | 83 | # size = [n_z_samples, batch_size, *n_trgt, r_dim] 84 | R_trgt = self.trgt_dependent_representation(X_cntxt, z_samples, R, X_trgt) 85 | 86 | # p(y|cntxt,trgt) 87 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim] 88 | p_yCc = self.decode(X_trgt, R_trgt) 89 | 90 | return p_yCc, z_samples, q_zCc, q_zCct 91 | 92 | def cntxt_to_induced(self, mask_cntxt, X, index): 93 | """Infer the missing values and compute a density channel.""" 94 | 95 | # channels have to be in second dimension for convolution 96 | # size = [batch_size, y_dim, *grid_shape] 97 | X = modules.channels_to_2nd_dim(X) 98 | # size = [batch_size, x_dim, *grid_shape] 99 | mask_cntxt = modules.channels_to_2nd_dim(mask_cntxt).float() 100 | 101 | # size = [batch_size, y_dim, *grid_shape] 102 | X_cntxt = X * mask_cntxt 103 | signal = self.conv[index](X_cntxt) 104 | density = self.conv[index](mask_cntxt.expand_as(X)) 105 | 106 | # normalize 107 | out = signal / torch.clamp(density, min=1e-5) 108 | 109 | # size = [batch_size, y_dim * 2, *grid_shape] 110 | out = torch.cat([out, density], dim=1) 111 | 112 | # size = [batch_size, *grid_shape, y_dim * 2] 113 | out = modules.channels_to_last_dim(out) 114 | 115 | # size = [batch_size, *grid_shape, r_dim] 116 | out = self.resizer[index](out) 117 | 118 | return out 119 | 120 | def encode_globally(self, mask_cntxt, X, mask_cntxtb, Xb): 121 | 122 | # size = [batch_size, *grid_shape, r_dim] for each single channel 123 | R_induced = self.cntxt_to_induced(mask_cntxt, X, index=0) 124 | R_induced = self.induced_to_induced[0](R_induced) 125 | 126 | Rb_induced = self.cntxt_to_induced(mask_cntxtb, Xb, index=1) 127 | Rb_induced = self.induced_to_induced[1](Rb_induced) 128 | 129 | R_induced = rearrange(R_induced, 'b h w c -> b c h w') 130 | R_induced = F.interpolate(R_induced, size=Rb_induced.shape[1:3], mode='bilinear') 131 | R_induced = rearrange(R_induced, 'b c h w -> b h w c') 132 | R_fusion = self.fusion(torch.cat([R_induced, Rb_induced], dim=-1)) 133 | 134 | return R_fusion 135 | 136 | def trgt_dependent_representation(self, _, __, R_induced, ___): 137 | 138 | # n_z_samples=1. 
size = [1, batch_size, n_trgt, r_dim] 139 | return R_induced.unsqueeze(0) 140 | 141 | def decode(self, X_trgt, R_trgt): 142 | 143 | # size = [n_z_samples, batch_size, *n_trgt, y_dim*2] 144 | p_y_suffstat = self.decoder(X_trgt, R_trgt) 145 | 146 | # size = [n_z_samples, batch_size, *n_trgt, y_dim] 147 | p_y_loc, p_y_scale = p_y_suffstat.split(self.y_dim, dim=-1) 148 | 149 | p_y_loc = self.p_y_loc_transformer(p_y_loc) 150 | p_y_scale = self.p_y_scale_transformer(p_y_scale) 151 | 152 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim] 153 | p_yCc = self.PredictiveDistribution(p_y_loc, p_y_scale) 154 | 155 | return p_yCc 156 | 157 | 158 | class ConvCNP(object): 159 | 160 | def __init__(self, **model_params) -> None: 161 | super().__init__() 162 | 163 | params = model_params.get('params', {}) 164 | criterion = model_params.get('criterion', 'CNPFLoss') 165 | self.optimizer_params = model_params.get('optimizer', {}) 166 | self.scheduler_params = model_params.get('lr_scheduler', {}) 167 | 168 | self.kernel = ConvCNP_model(**params) 169 | self.best_loss = 9999 170 | self.criterion = self.get_criterion(criterion) 171 | self.criterion_mae = nn.L1Loss() 172 | self.criterion_mse = nn.MSELoss() 173 | 174 | if utils.is_dist_avail_and_initialized(): 175 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 176 | if self.device == torch.device('cpu'): 177 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 178 | else: 179 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 180 | 181 | def get_criterion(self, loss_type): 182 | if loss_type == 'CNPFLoss': 183 | return modules.CNPFLoss() 184 | elif loss_type == 'NLLLossLNPF': 185 | return modules.NLLLossLNPF() 186 | elif loss_type == 'ELBOLossLNPF': 187 | return modules.ELBOLossLNPF() 188 | elif loss_type == 'SUMOLossLNPF': 189 | return modules.SUMOLossLNPF() 190 | else: 191 | raise NotImplementedError('Invalid loss type.') 192 | 193 | def process_data(self, batch_data, args): 194 | 195 | inp_data = torch.cat([batch_data[0], batch_data[1]], dim=1) 196 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy() 197 | truth = batch_data[-1].to(self.device, non_blocking=True) # 69 198 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear') 199 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear') 200 | 201 | for _ in range(args.lead_time // 6): 202 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]] 203 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1) 204 | 205 | xb_context = rearrange(torch.rand(predict_data.shape, device=self.device) >= 0, 'b c h w -> b h w c') 206 | x_context = rearrange(torch.rand(truth.shape, device=self.device) >= args.ratio, 'b c h w -> b h w c') 207 | x_target = rearrange(torch.rand(truth_down.shape, device=self.device) >= 0, 'b c h w -> b h w c') 208 | yb_context = rearrange(torch.from_numpy(predict_data).to(self.device, non_blocking=True), 'b c h w -> b h w c') 209 | y_context = rearrange(truth, 'b c h w -> b h w c') 210 | y_target = rearrange(truth_down, 'b c h w -> b h w c') 211 | 212 | return [x_context, y_context, xb_context, yb_context, x_target], y_target 213 | 214 | def train(self, train_data_loader, valid_data_loader, logger, args): 215 | 216 | train_step = len(train_data_loader) 217 | valid_step = len(valid_data_loader) 218 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params) 219 | 
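        # The scheduler is stepped after every optimizer update in the loop below, so
        # OneCycleLR's total_steps must cover all updates: len(train_data_loader) * max_epoch.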
self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch) 220 | 221 | for epoch in range(args.max_epoch): 222 | begin_time = time.time() 223 | self.kernel.train() 224 | 225 | for step, batch_data in enumerate(train_data_loader): 226 | 227 | input_list, y_target = self.process_data(batch_data[0], args) 228 | self.optimizer.zero_grad() 229 | y_pred = self.kernel(input_list) 230 | if isinstance(self.criterion, torch.nn.Module): 231 | self.criterion.train() 232 | loss = self.criterion(y_pred, y_target) 233 | loss.backward() 234 | clip_grad_norm_(self.kernel.parameters(), max_norm=1) 235 | self.optimizer.step() 236 | self.scheduler.step() 237 | 238 | if ((step + 1) % 100 == 0) | (step+1 == train_step): 239 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]') 240 | 241 | self.kernel.eval() 242 | with torch.no_grad(): 243 | total_loss = 0 244 | 245 | for step, batch_data in enumerate(valid_data_loader): 246 | input_list, y_target = self.process_data(batch_data[0], args) 247 | y_pred = self.kernel(input_list) 248 | if isinstance(self.criterion, torch.nn.Module): 249 | self.criterion.eval() 250 | loss = self.criterion(y_pred, y_target).item() 251 | total_loss += loss 252 | 253 | if ((step + 1) % 100 == 0) | (step+1 == valid_step): 254 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]') 255 | 256 | if (total_loss/valid_step) < self.best_loss: 257 | if utils.get_world_size() > 1 and utils.get_rank() == 0: 258 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth') 259 | elif utils.get_world_size() == 1: 260 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth') 261 | logger.info(f'New best model appears in epoch {epoch+1}.') 262 | self.best_loss = total_loss/valid_step 263 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]') 264 | 265 | def test(self, test_data_loader, logger, args): 266 | 267 | test_step = len(test_data_loader) 268 | data_mean, data_std = test_data_loader.dataset.get_meanstd() 269 | self.data_std = data_std.to(self.device) 270 | 271 | self.kernel.eval() 272 | with torch.no_grad(): 273 | total_loss = 0 274 | total_mae = 0 275 | total_mse = 0 276 | total_rmse = 0 277 | 278 | for step, batch_data in enumerate(test_data_loader): 279 | 280 | input_list, y_target = self.process_data(batch_data[0], args) 281 | y_pred = self.kernel(input_list) 282 | if isinstance(self.criterion, torch.nn.Module): 283 | self.criterion.eval() 284 | loss = self.criterion(y_pred, y_target).item() 285 | 286 | y_pred = rearrange(y_pred[0].mean[0], 'b h w c -> b c h w') 287 | y_target = rearrange(y_target, 'b h w c -> b c h w') 288 | mae = self.criterion_mae(y_pred, y_target).item() 289 | mse = self.criterion_mse(y_pred, y_target).item() 290 | rmse = WRMSE(y_pred, y_target, self.data_std) 291 | 292 | total_loss += loss 293 | total_mae += mae 294 | total_mse += mse 295 | total_rmse += rmse 296 | if ((step + 1) % 100 == 0) | (step+1 == test_step): 297 | logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]') 298 | 299 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]') 300 | logger.info(f'Average RMSE:[{total_rmse/test_step}]') 301 | -------------------------------------------------------------------------------- /models/FNP.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils import clip_grad_norm_ 7 | from functools import partial 8 | from einops import rearrange 9 | import utils.misc as utils 10 | from utils.metrics import WRMSE 11 | from utils.builder import get_optimizer, get_lr_scheduler 12 | import modules 13 | 14 | 15 | class Encoder(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | n_channels=[4,13,13,13,13,13], 20 | r_dim=64, 21 | XEncoder=nn.Identity, 22 | Conv=lambda y_dim: modules.make_abs_conv(nn.Conv2d)( 23 | y_dim, 24 | y_dim, 25 | groups=y_dim, 26 | kernel_size=11, 27 | padding=11 // 2, 28 | bias=False, 29 | ), 30 | CNN=partial( 31 | modules.CNN, 32 | ConvBlock=modules.ResConvBlock, 33 | Conv=nn.Conv2d, 34 | n_blocks=12, 35 | Normalization=nn.BatchNorm2d, 36 | activation=nn.SiLU(), 37 | is_chan_last=True, 38 | kernel_size=9, 39 | n_conv_layers=2, 40 | )): 41 | super().__init__() 42 | 43 | self.r_dim = r_dim 44 | self.n_channels = n_channels 45 | self.x_encoder = XEncoder() 46 | 47 | # components for encode_globally 48 | self.conv = [Conv(y_dim) for y_dim in n_channels] # for each single channel 49 | self.conv.append(modules.make_abs_conv(nn.Conv2d)( 50 | in_channels=sum(n_channels), 51 | out_channels=sum(n_channels), 52 | groups=1, 53 | kernel_size=11, 54 | padding=11 // 2, 55 | bias=False, 56 | )) # for all channels 57 | self.conv = nn.ModuleList(self.conv) 58 | 59 | self.resizer = [nn.Linear(y_dim * 2, self.r_dim) for y_dim in n_channels] # 2 because also confidence channels 60 | self.resizer.append(nn.Linear(sum(n_channels) * 2, self.r_dim)) 61 | self.resizer = nn.ModuleList(self.resizer) 62 | 63 | self.induced_to_induced = nn.ModuleList([CNN(self.r_dim) for _ in range(len(n_channels)+1)]) 64 | 65 | def forward(self, X_cntxt, Y_cntxt, X_trgt): 66 | 67 | X_cntxt = self.x_encoder(X_cntxt) # b,h,w,c 68 | X_trgt = self.x_encoder(X_trgt) 69 | 70 | # {R^u}_u 71 | # size = [batch_size, *n_rep, r_dim] for n_channels list 72 | R_trgt = self.encode_globally(X_cntxt, Y_cntxt) 73 | 74 | return R_trgt 75 | 76 | def cntxt_to_induced(self, mask_cntxt, X, index): 77 | """Infer the missing values and compute a density channel.""" 78 | 79 | # channels have to be in second dimension for convolution 80 | # size = [batch_size, y_dim, *grid_shape] 81 | X = modules.channels_to_2nd_dim(X) 82 | # size = [batch_size, x_dim, *grid_shape] 83 | mask_cntxt = modules.channels_to_2nd_dim(mask_cntxt).float() 84 | 85 | # size = [batch_size, y_dim, *grid_shape] 86 | X_cntxt = X * mask_cntxt 87 | signal = self.conv[index](X_cntxt) 88 | density = self.conv[index](mask_cntxt.expand_as(X)) 89 | 90 | # normalize 91 | out = signal / torch.clamp(density, min=1e-5) 92 | 93 | # size = [batch_size, y_dim * 2, *grid_shape] 94 | out = torch.cat([out, density], dim=1) 95 | 96 | # size = [batch_size, *grid_shape, y_dim * 2] 97 | out = modules.channels_to_last_dim(out) 98 | 99 | # size = [batch_size, *grid_shape, r_dim] 100 | out = self.resizer[index](out) 101 | 102 | return out 103 | 104 | def encode_globally(self, mask_cntxt, X): 105 | 106 | # size = [batch_size, *grid_shape, r_dim] for each single channel 107 | R_induced_all = [] 108 | for i in range(len(self.n_channels)): 109 | R_induced = self.cntxt_to_induced(mask_cntxt[...,sum(self.n_channels[:i]):sum(self.n_channels[:i+1])], 110 | X[...,sum(self.n_channels[:i]):sum(self.n_channels[:i+1])], i) 111 | R_induced = 
self.induced_to_induced[i](R_induced) 112 | R_induced_all.append(R_induced) 113 | # the last for all channels 114 | R_induced = self.cntxt_to_induced(mask_cntxt, X, len(self.n_channels)) 115 | R_induced = self.induced_to_induced[len(self.n_channels)](R_induced) 116 | R_induced_all.append(R_induced) 117 | 118 | return R_induced_all 119 | 120 | 121 | class FNP_model(nn.Module): 122 | 123 | def __init__( 124 | self, 125 | n_channels=[4,13,13,13,13,13], 126 | r_dim=128, 127 | use_nfl=True, 128 | use_dam=True, 129 | PredictiveDistribution=modules.MultivariateNormalDiag, 130 | p_y_loc_transformer=nn.Identity(), 131 | p_y_scale_transformer=lambda y_scale: 0.01 + 0.99 * F.softplus(y_scale), 132 | ): 133 | super().__init__() 134 | 135 | self.r_dim = r_dim 136 | self.y_dim = sum(n_channels) 137 | self.n_channels = n_channels 138 | self.use_dam = use_dam 139 | 140 | if use_nfl: 141 | EnCNN = partial( 142 | modules.FCNN, 143 | ConvBlock=modules.ResConvBlock, 144 | Conv=nn.Conv2d, 145 | n_blocks=4, 146 | Normalization=nn.BatchNorm2d, 147 | activation=nn.SiLU(), 148 | is_chan_last=True, 149 | kernel_size=9, 150 | n_conv_layers=2) 151 | else: 152 | EnCNN = partial( 153 | modules.CNN, 154 | ConvBlock=modules.ResConvBlock, 155 | Conv=nn.Conv2d, 156 | n_blocks=12, 157 | Normalization=nn.BatchNorm2d, 158 | activation=nn.SiLU(), 159 | is_chan_last=True, 160 | kernel_size=9, 161 | n_conv_layers=2) 162 | 163 | Decoder=modules.discard_ith_arg(partial(modules.MLP, n_hidden_layers=4, hidden_size=self.r_dim), i=0) 164 | self.obs_encoder = Encoder(n_channels=self.n_channels, r_dim=self.r_dim, CNN=EnCNN) 165 | self.back_encoder = Encoder(n_channels=self.n_channels, r_dim=self.r_dim, CNN=EnCNN) 166 | self.fusion = nn.ModuleList([nn.Linear(self.r_dim * 2, self.r_dim) for _ in range(len(n_channels)+1)]) 167 | 168 | # times 2 out because loc and scale (mean and var for gaussian) 169 | self.decoder = nn.ModuleList([Decoder(y_dim, self.r_dim * 2, y_dim * 2) for y_dim in n_channels]) 170 | if self.use_dam: 171 | self.smooth = nn.ModuleList([nn.Conv2d(self.r_dim * 2, self.r_dim, 9, padding=4) for _ in range(len(n_channels)+1)]) 172 | 173 | self.PredictiveDistribution = PredictiveDistribution 174 | self.p_y_loc_transformer = p_y_loc_transformer 175 | self.p_y_scale_transformer = p_y_scale_transformer 176 | 177 | self.reset_parameters() 178 | 179 | def reset_parameters(self): 180 | modules.weights_init(self) 181 | 182 | def forward(self, input_list): 183 | 184 | Xo_cntxt, Yo_cntxt, Xb_cntxt, Yb_cntxt, X_trgt = input_list[0], input_list[1], input_list[2], input_list[3], input_list[4] 185 | 186 | # {R^u}_u 187 | # size = [batch_size, *n_rep, r_dim] for n_channels list 188 | Ro_trgt = self.obs_encoder(Xo_cntxt, Yo_cntxt, X_trgt) 189 | Rb_trgt = self.back_encoder(Xb_cntxt, Yb_cntxt, X_trgt) 190 | 191 | z_samples, q_zCc, q_zCct = None, None, None 192 | 193 | # interpolate 194 | Ro_trgt = [rearrange(Ro_trgt[i], 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)] 195 | Rb_trgt = [rearrange(Rb_trgt[i], 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)] 196 | Ro_trgt = [F.interpolate(Ro_trgt[i], size=Rb_trgt[i].shape[2:], mode='bilinear') for i in range(len(self.n_channels)+1)] 197 | Ro_trgt = [rearrange(Ro_trgt[i], 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)] 198 | Rb_trgt = [rearrange(Rb_trgt[i], 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)] 199 | 200 | # representation fusion 201 | R_fusion = [self.fusion[i](torch.cat([Ro_trgt[i], Rb_trgt[i]], dim=-1)) for i in 
range(len(self.n_channels)+1)] 202 | if self.use_dam: 203 | R_similar = [rearrange(self.similarity(R_fusion[i], Rb_trgt[i], Ro_trgt[i]), 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)] 204 | R_fusion = [rearrange(self.smooth[i](R_similar[i]), 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)] 205 | 206 | # size = [n_z_samples, batch_size, *n_trgt, r_dim] 207 | R_trgt = [self.trgt_dependent_representation(Xo_cntxt, Xb_cntxt, z_samples, R_fusion[i], X_trgt) for i in range(len(self.n_channels)+1)] 208 | 209 | # p(y|cntxt,trgt) 210 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim] 211 | p_yCc = self.decode(X_trgt, R_trgt, Yb_cntxt) 212 | 213 | return p_yCc, z_samples, q_zCc, q_zCct 214 | 215 | def similarity(self, R, Rb, Ro): 216 | 217 | distb = torch.sqrt(torch.sum((R-Rb)**2, dim=-1, keepdim=True)) 218 | disto = torch.sqrt(torch.sum((R-Ro)**2, dim=-1, keepdim=True)) 219 | mask = (disto > distb).float() 220 | R = torch.cat([Ro * mask + Rb * (1-mask), R], dim=-1) 221 | 222 | return R 223 | 224 | def trgt_dependent_representation(self, _, __, ___, R_induced, ____): 225 | 226 | # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim] 227 | return R_induced.unsqueeze(0) 228 | 229 | def decode(self, X_trgt, R_trgt, Yb_cntxt): 230 | 231 | locs = [] 232 | scales = [] 233 | 234 | for i in range(len(self.n_channels)): 235 | R_trgt_single = torch.cat([R_trgt[i], R_trgt[-1]], dim=-1) 236 | 237 | # size = [n_z_samples, batch_size, *n_trgt, y_dim*2] 238 | p_y_suffstat = self.decoder[i](X_trgt, R_trgt_single) 239 | 240 | # size = [n_z_samples, batch_size, *n_trgt, y_dim] 241 | p_y_loc, p_y_scale = p_y_suffstat.split(self.n_channels[i], dim=-1) 242 | 243 | p_y_loc = self.p_y_loc_transformer(p_y_loc) 244 | p_y_scale = self.p_y_scale_transformer(p_y_scale) 245 | 246 | locs.append(p_y_loc) 247 | scales.append(p_y_scale) 248 | 249 | locs = torch.cat(locs, dim=-1) + Yb_cntxt 250 | scales = torch.cat(scales, dim=-1) 251 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim] 252 | p_yCc = self.PredictiveDistribution(locs, scales) 253 | 254 | return p_yCc 255 | 256 | 257 | class FNP(object): 258 | 259 | def __init__(self, **model_params) -> None: 260 | super().__init__() 261 | 262 | params = model_params.get('params', {}) 263 | criterion = model_params.get('criterion', 'CNPFLoss') 264 | self.optimizer_params = model_params.get('optimizer', {}) 265 | self.scheduler_params = model_params.get('lr_scheduler', {}) 266 | 267 | self.kernel = FNP_model(**params) 268 | self.best_loss = 9999 269 | self.criterion = self.get_criterion(criterion) 270 | self.criterion_mae = nn.L1Loss() 271 | self.criterion_mse = nn.MSELoss() 272 | 273 | if utils.is_dist_avail_and_initialized(): 274 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 275 | if self.device == torch.device('cpu'): 276 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 277 | else: 278 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 279 | 280 | def get_criterion(self, loss_type): 281 | if loss_type == 'CNPFLoss': 282 | return modules.CNPFLoss() 283 | elif loss_type == 'NLLLossLNPF': 284 | return modules.NLLLossLNPF() 285 | elif loss_type == 'ELBOLossLNPF': 286 | return modules.ELBOLossLNPF() 287 | elif loss_type == 'SUMOLossLNPF': 288 | return modules.SUMOLossLNPF() 289 | else: 290 | raise NotImplementedError('Invalid loss type.') 291 | 292 | def process_data(self, batch_data, args): 293 | 294 | inp_data = 
torch.cat([batch_data[0], batch_data[1]], dim=1) 295 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy() 296 | truth = batch_data[-1].to(self.device, non_blocking=True) # 69 297 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear') 298 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear') 299 | 300 | for _ in range(args.lead_time // 6): 301 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]] 302 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1) 303 | 304 | xb_context = rearrange(torch.rand(predict_data.shape, device=self.device) >= 0, 'b c h w -> b h w c') 305 | x_context = rearrange(torch.rand(truth.shape, device=self.device) >= args.ratio, 'b c h w -> b h w c') 306 | x_target = rearrange(torch.rand(truth_down.shape, device=self.device) >= 0, 'b c h w -> b h w c') 307 | yb_context = rearrange(torch.from_numpy(predict_data).to(self.device, non_blocking=True), 'b c h w -> b h w c') 308 | y_context = rearrange(truth, 'b c h w -> b h w c') 309 | y_target = rearrange(truth_down, 'b c h w -> b h w c') 310 | 311 | return [x_context, y_context, xb_context, yb_context, x_target], y_target 312 | 313 | def train(self, train_data_loader, valid_data_loader, logger, args): 314 | 315 | train_step = len(train_data_loader) 316 | valid_step = len(valid_data_loader) 317 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params) 318 | self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch) 319 | 320 | for epoch in range(args.max_epoch): 321 | begin_time = time.time() 322 | self.kernel.train() 323 | 324 | for step, batch_data in enumerate(train_data_loader): 325 | 326 | input_list, y_target = self.process_data(batch_data[0], args) 327 | self.optimizer.zero_grad() 328 | y_pred = self.kernel(input_list) 329 | if isinstance(self.criterion, torch.nn.Module): 330 | self.criterion.train() 331 | loss = self.criterion(y_pred, y_target) 332 | loss.backward() 333 | clip_grad_norm_(self.kernel.parameters(), max_norm=1) 334 | self.optimizer.step() 335 | self.scheduler.step() 336 | 337 | if ((step + 1) % 100 == 0) | (step+1 == train_step): 338 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]') 339 | 340 | self.kernel.eval() 341 | with torch.no_grad(): 342 | total_loss = 0 343 | 344 | for step, batch_data in enumerate(valid_data_loader): 345 | input_list, y_target = self.process_data(batch_data[0], args) 346 | y_pred = self.kernel(input_list) 347 | if isinstance(self.criterion, torch.nn.Module): 348 | self.criterion.eval() 349 | loss = self.criterion(y_pred, y_target).item() 350 | total_loss += loss 351 | 352 | if ((step + 1) % 100 == 0) | (step+1 == valid_step): 353 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]') 354 | 355 | if (total_loss/valid_step) < self.best_loss: 356 | if utils.get_world_size() > 1 and utils.get_rank() == 0: 357 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth') 358 | elif utils.get_world_size() == 1: 359 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth') 360 | logger.info(f'New best model appears in epoch {epoch+1}.') 361 | self.best_loss = total_loss/valid_step 362 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]') 363 | 364 | def 
test(self, test_data_loader, logger, args): 365 | 366 | test_step = len(test_data_loader) 367 | data_mean, data_std = test_data_loader.dataset.get_meanstd() 368 | self.data_std = data_std.to(self.device) 369 | 370 | self.kernel.eval() 371 | with torch.no_grad(): 372 | total_loss = 0 373 | total_mae = 0 374 | total_mse = 0 375 | total_rmse = 0 376 | 377 | for step, batch_data in enumerate(test_data_loader): 378 | 379 | input_list, y_target = self.process_data(batch_data[0], args) 380 | y_pred = self.kernel(input_list) 381 | if isinstance(self.criterion, torch.nn.Module): 382 | self.criterion.eval() 383 | loss = self.criterion(y_pred, y_target).item() 384 | 385 | y_pred = rearrange(y_pred[0].mean[0], 'b h w c -> b c h w') 386 | y_target = rearrange(y_target, 'b h w c -> b c h w') 387 | mae = self.criterion_mae(y_pred, y_target).item() 388 | mse = self.criterion_mse(y_pred, y_target).item() 389 | rmse = WRMSE(y_pred, y_target, self.data_std) 390 | 391 | total_loss += loss 392 | total_mae += mae 393 | total_mse += mse 394 | total_rmse += rmse 395 | if ((step + 1) % 100 == 0) | (step+1 == test_step): 396 | logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]') 397 | 398 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]') 399 | logger.info(f'Average RMSE:[{total_rmse/test_step}]') 400 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | from .cnn import * 3 | from .encoders import * 4 | from .helpers import * 5 | from .initialization import * 6 | from .mlp import * 7 | from .transformer import * -------------------------------------------------------------------------------- /modules/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from .initialization import init_param_, weights_init 5 | from .helpers import ( 6 | channels_to_2nd_dim, 7 | channels_to_last_dim, 8 | make_depth_sep_conv, 9 | ) 10 | 11 | 12 | __all__ = [ 13 | "GaussianConv2d", 14 | "ConvBlock", 15 | "ResNormalizedConvBlock", 16 | "ResConvBlock", 17 | "CNN", 18 | "UnetCNN", 19 | "FCNN", 20 | ] 21 | 22 | 23 | class GaussianConv2d(nn.Module): 24 | def __init__(self, kernel_size=5, **kwargs): 25 | super().__init__() 26 | self.kwargs = kwargs 27 | assert kernel_size % 2 == 1 28 | self.kernel_sizes = (kernel_size, kernel_size) 29 | self.exponent = -( 30 | (torch.arange(0, kernel_size).view(-1, 1).float() - kernel_size // 2) ** 2 31 | ) 32 | 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | self.weights_x = nn.Parameter(torch.tensor([1.0])) 37 | self.weights_y = nn.Parameter(torch.tensor([1.0])) 38 | 39 | def forward(self, X): 40 | # only switch first time to device 41 | self.exponent = self.exponent.to(X.device) 42 | 43 | marginal_x = torch.softmax(self.exponent * self.weights_x, dim=0) 44 | marginal_y = torch.softmax(self.exponent * self.weights_y, dim=0).T 45 | 46 | in_chan = X.size(1) 47 | filters = marginal_x @ marginal_y 48 | filters = filters.view(1, 1, *self.kernel_sizes).expand( 49 | in_chan, 1, *self.kernel_sizes 50 | ) 51 | 52 | return F.conv2d(X, filters, groups=in_chan, **self.kwargs) 53 | 54 | 55 | class ConvBlock(nn.Module): 56 | """Simple convolutional block with a single layer. 
57 | 58 | Parameters 59 | ---------- 60 | in_chan : int 61 | Number of input channels. 62 | 63 | out_chan : int 64 | Number of output channels. 65 | 66 | Conv : nn.Module 67 | Convolutional layer (unitialized). E.g. `nn.Conv1d`. 68 | 69 | kernel_size : int or tuple, optional 70 | Size of the convolving kernel. 71 | 72 | dilation : int or tuple, optional 73 | Spacing between kernel elements. 74 | 75 | activation: callable, optional 76 | Activation object. E.g. `nn.ReLU`. 77 | 78 | Normalization : nn.Module, optional 79 | Normalization layer (unitialized). E.g. `nn.BatchNorm1d`. 80 | 81 | kwargs : 82 | Additional arguments to `Conv`. 83 | 84 | References 85 | ---------- 86 | [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings 87 | in deep residual networks. In European conference on computer vision 88 | (pp. 630-645). Springer, Cham. 89 | 90 | [2] Chollet, F. (2017). Xception: Deep learning with depthwise separable 91 | convolutions. In Proceedings of the IEEE conference on computer vision 92 | and pattern recognition (pp. 1251-1258). 93 | """ 94 | 95 | def __init__( 96 | self, 97 | in_chan, 98 | out_chan, 99 | Conv, 100 | kernel_size=5, 101 | dilation=1, 102 | activation=nn.ReLU(), 103 | Normalization=nn.Identity, 104 | **kwargs 105 | ): 106 | super().__init__() 107 | self.activation = activation 108 | 109 | padding = kernel_size // 2 110 | 111 | Conv = make_depth_sep_conv(Conv) 112 | 113 | self.conv = Conv(in_chan, out_chan, kernel_size, padding=padding, **kwargs) 114 | self.norm = Normalization(in_chan) 115 | 116 | self.reset_parameters() 117 | 118 | def reset_parameters(self): 119 | weights_init(self) 120 | 121 | def forward(self, X): 122 | return self.conv(self.activation(self.norm(X))) 123 | 124 | 125 | class ResConvBlock(nn.Module): 126 | """Convolutional block inspired by the pre-activation Resnet [1] 127 | and depthwise separable convolutions [2]. 128 | 129 | Parameters 130 | ---------- 131 | in_chan : int 132 | Number of input channels. 133 | 134 | out_chan : int 135 | Number of output channels. 136 | 137 | Conv : nn.Module 138 | Convolutional layer (unitialized). E.g. `nn.Conv1d`. 139 | 140 | kernel_size : int or tuple, optional 141 | Size of the convolving kernel. Should be odd to keep the same size. 142 | 143 | activation: callable, optional 144 | Activation object. E.g. `nn.RelU()`. 145 | 146 | Normalization : nn.Module, optional 147 | Normalization layer (unitialized). E.g. `nn.BatchNorm1d`. 148 | 149 | n_conv_layers : int, optional 150 | Number of convolutional layers, can be 1 or 2. 151 | 152 | is_bias : bool, optional 153 | Whether to use a bias. 154 | 155 | References 156 | ---------- 157 | [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings 158 | in deep residual networks. In European conference on computer vision 159 | (pp. 630-645). Springer, Cham. 160 | 161 | [2] Chollet, F. (2017). Xception: Deep learning with depthwise separable 162 | convolutions. In Proceedings of the IEEE conference on computer vision 163 | and pattern recognition (pp. 1251-1258). 
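    Examples
    --------
    Illustrative sketch only; the keyword values below mirror how `FNP.py`
    configures its encoder CNN (`kernel_size=9`, `n_conv_layers=2`, SiLU,
    BatchNorm), while the channel count and tensor sizes are arbitrary.

    >>> import torch
    >>> import torch.nn as nn
    >>> block = ResConvBlock(128, 128, nn.Conv2d, kernel_size=9,
    ...                      activation=nn.SiLU(), Normalization=nn.BatchNorm2d,
    ...                      n_conv_layers=2)
    >>> block(torch.randn(2, 128, 16, 32)).shape  # spatial size is preserved
    torch.Size([2, 128, 16, 32])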
164 | """ 165 | 166 | def __init__( 167 | self, 168 | in_chan, 169 | out_chan, 170 | Conv, 171 | kernel_size=5, 172 | activation=nn.ReLU(), 173 | Normalization=nn.Identity, 174 | is_bias=True, 175 | n_conv_layers=1, 176 | ): 177 | super().__init__() 178 | self.activation = activation 179 | self.n_conv_layers = n_conv_layers 180 | assert self.n_conv_layers in [1, 2] 181 | 182 | if kernel_size % 2 == 0: 183 | raise ValueError("`kernel_size={}`, but should be odd.".format(kernel_size)) 184 | 185 | padding = kernel_size // 2 186 | 187 | if self.n_conv_layers == 2: 188 | self.norm1 = Normalization(in_chan) 189 | self.conv1 = make_depth_sep_conv(Conv)( 190 | in_chan, in_chan, kernel_size, padding=padding, bias=is_bias 191 | ) 192 | self.norm2 = Normalization(in_chan) 193 | self.conv2_depthwise = Conv( 194 | in_chan, in_chan, kernel_size, padding=padding, groups=in_chan, bias=is_bias 195 | ) 196 | self.conv2_pointwise = Conv(in_chan, out_chan, 1, bias=is_bias) 197 | 198 | self.reset_parameters() 199 | 200 | def reset_parameters(self): 201 | weights_init(self) 202 | 203 | def forward(self, X): 204 | 205 | if self.n_conv_layers == 2: 206 | out = self.conv1(self.activation(self.norm1(X))) 207 | else: 208 | out = X 209 | 210 | out = self.conv2_depthwise(self.activation(self.norm2(out))) 211 | # adds residual before point wise => output can change number of channels 212 | out = out + X 213 | out = self.conv2_pointwise(out.contiguous()) # for some reason need contiguous 214 | return out 215 | 216 | 217 | class ResNormalizedConvBlock(ResConvBlock): 218 | """Modification of `ResConvBlock` to use normalized convolutions [1]. 219 | 220 | Parameters 221 | ---------- 222 | in_chan : int 223 | Number of input channels. 224 | 225 | out_chan : int 226 | Number of output channels. 227 | 228 | Conv : nn.Module 229 | Convolutional layer (unitialized). E.g. `nn.Conv1d`. 230 | 231 | kernel_size : int or tuple, optional 232 | Size of the convolving kernel. Should be odd to keep the same size. 233 | 234 | activation: nn.Module, optional 235 | Activation object. E.g. `nn.RelU()`. 236 | 237 | is_bias : bool, optional 238 | Whether to use a bias. 239 | 240 | References 241 | ---------- 242 | [1] Knutsson, H., & Westin, C. F. (1993, June). Normalized and differential 243 | convolution. In Proceedings of IEEE Conference on Computer Vision and 244 | Pattern Recognition (pp. 515-523). IEEE. 245 | """ 246 | 247 | def __init__( 248 | self, 249 | in_chan, 250 | out_chan, 251 | Conv, 252 | kernel_size=5, 253 | activation=nn.ReLU(), 254 | is_bias=True, 255 | **kwargs 256 | ): 257 | super().__init__( 258 | in_chan, 259 | out_chan, 260 | Conv, 261 | kernel_size=kernel_size, 262 | activation=activation, 263 | is_bias=is_bias, 264 | Normalization=nn.Identity, 265 | **kwargs 266 | ) # make sure no normalization 267 | 268 | def reset_parameters(self): 269 | weights_init(self) 270 | self.bias = nn.Parameter(torch.tensor([0.0])) 271 | 272 | self.temperature = nn.Parameter(torch.tensor([0.0])) 273 | init_param_(self.temperature) 274 | 275 | def forward(self, X): 276 | """ 277 | Apply a normalized convolution. X should contain 2*in_chan channels. 278 | First halves for signal, last halve for corresponding confidence channels. 
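        Illustrative example: with ``in_chan=8`` the input has 16 channels, so
        ``signal, conf = X.chunk(2, dim=1)`` yields the 8 signal channels and the
        8 corresponding confidence channels; the confidences are clamped to
        ``[0, 1]`` before the normalized convolution is applied.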
279 | """ 280 | 281 | signal, conf_1 = X.chunk(2, dim=1) 282 | # make sure confidence is in 0 1 (might not be due to the pointwise trsnf) 283 | conf_1 = conf_1.clamp(min=0, max=1) 284 | X = signal * conf_1 285 | 286 | numerator = self.conv1(self.activation(X)) 287 | numerator = self.conv2_depthwise(self.activation(numerator)) 288 | density = self.conv2_depthwise(self.conv1(conf_1)) 289 | out = numerator / torch.clamp(density, min=1e-5) 290 | 291 | # adds residual before point wise => output can change number of channels 292 | 293 | # make sure that confidence cannot decrease and cannot be greater than 1 294 | conf_2 = conf_1 + torch.sigmoid( 295 | density * F.softplus(self.temperature) + self.bias 296 | ) 297 | conf_2 = conf_2.clamp(max=1) 298 | out = out + X 299 | 300 | out = self.conv2_pointwise(out) 301 | conf_2 = self.conv2_pointwise(conf_2) 302 | 303 | return torch.cat([out, conf_2], dim=1) 304 | 305 | 306 | 307 | class SpectralConv2d_fast(nn.Module): 308 | def __init__(self, in_channels, out_channels, modes): 309 | super().__init__() 310 | 311 | """ 312 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 313 | """ 314 | self.in_channels = in_channels 315 | self.out_channels = out_channels 316 | self.modes = modes #Number of Fourier modes to multiply, at most floor(N/2) + 1 317 | 318 | self.scale = (1 / (in_channels * out_channels)) 319 | self.weights1 = nn.Parameter(self.scale * torch.rand(2, in_channels, out_channels, self.modes, self.modes)) 320 | self.weights2 = nn.Parameter(self.scale * torch.rand(2, in_channels, out_channels, self.modes, self.modes)) 321 | 322 | def compl_mul2d(self, input, weights): 323 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 324 | return torch.einsum("bixy,ioxy->boxy", input, weights) 325 | 326 | def forward(self, x): 327 | batchsize, dtype = x.shape[0], x.dtype 328 | #Compute Fourier coeffcients up to factor of e^(- something constant) 329 | x_ft = torch.fft.rfft2(x) 330 | 331 | # Multiply relevant Fourier modes 332 | out_ft_real = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, device=x.device) 333 | out_ft_imag = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, device=x.device) 334 | 335 | out_ft_real[:, :, :self.modes, :self.modes] = \ 336 | self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].real, self.weights1[0]) - self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].imag, self.weights1[1]) 337 | out_ft_real[:, :, -self.modes:, :self.modes] = \ 338 | self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].real, self.weights2[0]) - self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].imag, self.weights2[1]) 339 | 340 | out_ft_imag[:, :, :self.modes, :self.modes] = \ 341 | self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].imag, self.weights1[0]) + self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].real, self.weights1[1]) 342 | out_ft_imag[:, :, -self.modes:, :self.modes] = \ 343 | self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].imag, self.weights2[0]) + self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].real, self.weights2[1]) 344 | 345 | # Return to physical space 346 | out_ft = torch.stack([out_ft_real, out_ft_imag], dim=-1) 347 | out_ft = torch.view_as_complex(out_ft) 348 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 349 | x = x.type(dtype) 350 | 351 | return x 352 | 353 | 354 | 355 | class CNN(nn.Module): 356 | """Simple multilayer CNN. 
357 | 358 | Parameters 359 | ---------- 360 | n_channels : int or list 361 | Number of channels, same for input and output. If list then needs to be 362 | of size `n_blocks - 1`, e.g. [16, 32, 64] means that you will have a 363 | `[ConvBlock(16,32), ConvBlock(32, 64)]`. 364 | 365 | ConvBlock : nn.Module 366 | Convolutional block (unitialized). Needs to take as input `Should be 367 | initialized with `ConvBlock(in_chan, out_chan)`. 368 | 369 | n_blocks : int, optional 370 | Number of convolutional blocks. 371 | 372 | is_chan_last : bool, optional 373 | Whether the channels are on the last dimension of the input. 374 | 375 | kwargs : 376 | Additional arguments to `ConvBlock`. 377 | """ 378 | 379 | def __init__(self, n_channels, ConvBlock, n_blocks=3, is_chan_last=False, **kwargs): 380 | 381 | super().__init__() 382 | self.n_blocks = n_blocks 383 | self.is_chan_last = is_chan_last 384 | self.in_out_channels = self._get_in_out_channels(n_channels, n_blocks) 385 | self.conv_blocks = nn.ModuleList( 386 | [ 387 | ConvBlock(in_chan, out_chan, **kwargs) 388 | for in_chan, out_chan in self.in_out_channels 389 | ] 390 | ) 391 | self.is_return_rep = False # never return representation for vanilla conv 392 | 393 | self.reset_parameters() 394 | 395 | def reset_parameters(self): 396 | weights_init(self) 397 | 398 | def _get_in_out_channels(self, n_channels, n_blocks): 399 | """Return a list of tuple of input and output channels.""" 400 | if isinstance(n_channels, int): 401 | channel_list = [n_channels] * (n_blocks + 1) 402 | else: 403 | channel_list = list(n_channels) 404 | 405 | assert len(channel_list) == (n_blocks + 1), "{} != {}".format( 406 | len(channel_list), n_blocks + 1 407 | ) 408 | 409 | return list(zip(channel_list, channel_list[1:])) 410 | 411 | def forward(self, X): 412 | if self.is_chan_last: 413 | X = channels_to_2nd_dim(X) 414 | 415 | X, representation = self.apply_convs(X) 416 | 417 | if self.is_chan_last: 418 | X = channels_to_last_dim(X) 419 | 420 | if self.is_return_rep: 421 | return X, representation 422 | 423 | return X 424 | 425 | def apply_convs(self, X): 426 | for conv_block in self.conv_blocks: 427 | X = conv_block(X) 428 | return X, None 429 | 430 | 431 | class FCNN(nn.Module): 432 | """Simple multilayer CNN. 433 | 434 | Parameters 435 | ---------- 436 | n_channels : int or list 437 | Number of channels, same for input and output. If list then needs to be 438 | of size `n_blocks - 1`, e.g. [16, 32, 64] means that you will have a 439 | `[ConvBlock(16,32), ConvBlock(32, 64)]`. 440 | 441 | ConvBlock : nn.Module 442 | Convolutional block (unitialized). Needs to take as input `Should be 443 | initialized with `ConvBlock(in_chan, out_chan)`. 444 | 445 | n_blocks : int, optional 446 | Number of convolutional blocks. 447 | 448 | is_chan_last : bool, optional 449 | Whether the channels are on the last dimension of the input. 450 | 451 | kwargs : 452 | Additional arguments to `ConvBlock`. 
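    Examples
    --------
    Illustrative sketch; the configuration mirrors the `EnCNN` partial built in
    `FNP.py` when `use_nfl=True` (the name `EnCNN` comes from that file, and the
    input sizes here are arbitrary).

    >>> from functools import partial
    >>> import torch
    >>> import torch.nn as nn
    >>> EnCNN = partial(FCNN, ConvBlock=ResConvBlock, Conv=nn.Conv2d, n_blocks=4,
    ...                 Normalization=nn.BatchNorm2d, activation=nn.SiLU(),
    ...                 is_chan_last=True, kernel_size=9, n_conv_layers=2)
    >>> cnn = EnCNN(128)  # 128 = r_dim; channel count stays constant across blocks
    >>> cnn(torch.randn(2, 16, 32, 128)).shape  # is_chan_last=True -> channels last
    torch.Size([2, 16, 32, 128])

    Unlike `CNN`, each block here adds a spectral (Fourier) branch:
    ``X = conv_block(X) + fno_block(X) + X``.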
453 | """ 454 | 455 | def __init__(self, n_channels, ConvBlock, n_blocks=3, is_chan_last=False, **kwargs): 456 | 457 | super().__init__() 458 | self.n_blocks = n_blocks 459 | self.is_chan_last = is_chan_last 460 | self.in_out_channels = self._get_in_out_channels(n_channels, n_blocks) 461 | self.conv_blocks = nn.ModuleList( 462 | [ 463 | ConvBlock(in_chan, out_chan, **kwargs) 464 | for in_chan, out_chan in self.in_out_channels 465 | ] 466 | ) 467 | self.fno_blocks = nn.ModuleList( 468 | [ 469 | SpectralConv2d_fast(in_chan, out_chan, modes=12) 470 | for in_chan, out_chan in self.in_out_channels 471 | ] 472 | ) 473 | self.is_return_rep = False # never return representation for vanilla conv 474 | 475 | self.reset_parameters() 476 | 477 | def reset_parameters(self): 478 | weights_init(self) 479 | 480 | def _get_in_out_channels(self, n_channels, n_blocks): 481 | """Return a list of tuple of input and output channels.""" 482 | if isinstance(n_channels, int): 483 | channel_list = [n_channels] * (n_blocks + 1) 484 | else: 485 | channel_list = list(n_channels) 486 | 487 | assert len(channel_list) == (n_blocks + 1), "{} != {}".format( 488 | len(channel_list), n_blocks + 1 489 | ) 490 | 491 | return list(zip(channel_list, channel_list[1:])) 492 | 493 | def forward(self, X): 494 | if self.is_chan_last: 495 | X = channels_to_2nd_dim(X) 496 | 497 | X, representation = self.apply_convs(X) 498 | 499 | if self.is_chan_last: 500 | X = channels_to_last_dim(X) 501 | 502 | if self.is_return_rep: 503 | return X, representation 504 | 505 | return X 506 | 507 | def apply_convs(self, X): 508 | for i in range(self.n_blocks): 509 | # for conv_block in self.conv_blocks: 510 | X = self.conv_blocks[i](X) + self.fno_blocks[i](X) + X 511 | return X, None 512 | 513 | 514 | class UnetCNN(CNN): 515 | """Unet [1]. 516 | 517 | Parameters 518 | ---------- 519 | n_channels : int or list 520 | Number of channels, same for input and output. If list then needs to be 521 | of size `n_blocks - 1`, e.g. [16, 32, 64] means that you will have a 522 | `[ConvBlock(16,32), ConvBlock(32, 64)]`. 523 | 524 | ConvBlock : nn.Module 525 | Convolutional block (unitialized). Needs to take as input `Should be 526 | initialized with `ConvBlock(in_chan, out_chan)`. 527 | 528 | Pool : nn.Module 529 | Pooling layer (unitialized). E.g. torch.nn.MaxPool1d. 530 | 531 | upsample_mode : {'nearest', 'linear', bilinear', 'bicubic', 'trilinear'} 532 | The upsampling algorithm: nearest, linear (1D-only), bilinear, bicubic 533 | (2D-only), trilinear (3D-only). 534 | 535 | max_nchannels : int, optional 536 | Bounds the maximum number of channels instead of always doubling them at 537 | downsampling block. 538 | 539 | pooling_size : int or tuple, optional 540 | Size of the pooling filter. 541 | 542 | is_force_same_bottleneck : bool, optional 543 | Whether to use the average bottleneck for the same functions sampled at 544 | different context and target. If `True` the first and second halves 545 | of a batch should contain different samples of the same functions (in order). 546 | 547 | is_return_rep : bool, optional 548 | Whether to return a summary representation, that corresponds to the 549 | bottleneck + global mean pooling. 550 | 551 | kwargs : 552 | Additional arguments to `CNN` and `ConvBlock`. 553 | 554 | References 555 | ---------- 556 | [1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional 557 | networks for biomedical image segmentation." International Conference on 558 | Medical image computing and computer-assisted intervention. 
Springer, Cham, 2015. 559 | """ 560 | 561 | def __init__( 562 | self, 563 | n_channels, 564 | ConvBlock, 565 | Pool, 566 | upsample_mode, 567 | max_nchannels=256, 568 | pooling_size=2, 569 | is_force_same_bottleneck=False, 570 | is_return_rep=False, 571 | **kwargs 572 | ): 573 | 574 | self.max_nchannels = max_nchannels 575 | super().__init__(n_channels, ConvBlock, **kwargs) 576 | self.pooling_size = pooling_size 577 | self.pooling = Pool(self.pooling_size) 578 | self.upsample_mode = upsample_mode 579 | self.is_force_same_bottleneck = is_force_same_bottleneck 580 | self.is_return_rep = is_return_rep 581 | 582 | def apply_convs(self, X): 583 | n_down_blocks = self.n_blocks // 2 584 | residuals = [None] * n_down_blocks 585 | 586 | # Down 587 | for i in range(n_down_blocks): 588 | X = self.conv_blocks[i](X) 589 | residuals[i] = X 590 | X = self.pooling(X) 591 | 592 | # Bottleneck 593 | X = self.conv_blocks[n_down_blocks](X) 594 | # Representation before forcing same bottleneck 595 | representation = X.view(*X.shape[:2], -1).mean(-1) 596 | 597 | if self.is_force_same_bottleneck and self.training: 598 | # forces the u-net to use the bottleneck by giving additional information 599 | # there. I.e. taking average between bottleenck of different samples 600 | # of the same functions. Because bottleneck should be a global representation 601 | # => should not depend on the sample you chose 602 | batch_size = X.size(0) 603 | batch_1 = X[: batch_size // 2, ...] 604 | batch_2 = X[batch_size // 2 :, ...] 605 | X_mean = (batch_1 + batch_2) / 2 606 | X = torch.cat([X_mean, X_mean], dim=0) 607 | 608 | # Up 609 | for i in range(n_down_blocks + 1, self.n_blocks): 610 | X = F.interpolate( 611 | X, 612 | mode=self.upsample_mode, 613 | scale_factor=self.pooling_size, 614 | align_corners=True, 615 | ) 616 | X = torch.cat( 617 | (X, residuals[n_down_blocks - i]), dim=1 618 | ) # concat on channels 619 | X = self.conv_blocks[i](X) 620 | 621 | return X, representation 622 | 623 | def _get_in_out_channels(self, n_channels, n_blocks): 624 | """Return a list of tuple of input and output channels for a Unet.""" 625 | # doubles at every down layer, as in vanilla U-net 626 | factor_chan = 2 627 | 628 | assert n_blocks % 2 == 1, "n_blocks={} not odd".format(n_blocks) 629 | # e.g. 
if n_channels=16, n_blocks=5: [16, 32, 64] 630 | channel_list = [factor_chan ** i * n_channels for i in range(n_blocks // 2 + 1)] 631 | # e.g.: [16, 32, 64, 64, 32, 16] 632 | channel_list = channel_list + channel_list[::-1] 633 | # bound max number of channels by self.max_nchannels (besides first and 634 | # last dim as this is input / output cand sohould not be changed) 635 | channel_list = ( 636 | channel_list[:1] 637 | + [min(c, self.max_nchannels) for c in channel_list[1:-1]] 638 | + channel_list[-1:] 639 | ) 640 | # e.g.: [(16, 32), (32,64), (64, 64), (64, 32), (32, 16)] 641 | in_out_channels = super()._get_in_out_channels(channel_list, n_blocks) 642 | # e.g.: [(16, 32), (32,64), (64, 64), (128, 32), (64, 16)] due to concat 643 | idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels)) 644 | in_out_channels[idcs] = [ 645 | (in_chan * 2, out_chan) for in_chan, out_chan in in_out_channels[idcs] 646 | ] 647 | return in_out_channels 648 | 649 | 650 | 651 | -------------------------------------------------------------------------------- /modules/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .initialization import weights_init 4 | from .mlp import MLP 5 | 6 | __all__ = [ 7 | "RelativeSinusoidalEncodings", 8 | "SinusoidalEncodings", 9 | "merge_flat_input", 10 | "discard_ith_arg", 11 | ] 12 | 13 | 14 | class SinusoidalEncodings(nn.Module): 15 | """ 16 | Converts a batch of N-dimensional spatial input X with values between `[-1,1]` 17 | to a batch of flat in vector splitted in N subvectors that encode the position via 18 | sinusoidal encodings. 19 | 20 | Parameters 21 | ---------- 22 | x_dim : int 23 | Number of spatial inputs. 24 | 25 | out_dim : int 26 | size of output encoding. Each x_dim will have an encoding of size 27 | `out_dim//x_dim`. 
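    For example (illustrative, shapes chosen arbitrarily),
    ``SinusoidalEncodings(x_dim=2, out_dim=64)`` encodes each 2-D position in
    ``[-1, 1]^2`` with a 32-dimensional sine/cosine vector per coordinate and
    concatenates them:

    >>> import torch
    >>> enc = SinusoidalEncodings(x_dim=2, out_dim=64)
    >>> enc(torch.rand(8, 100, 2) * 2 - 1).shape
    torch.Size([8, 100, 64])

    `out_dim` must be divisible by `x_dim`, and the per-coordinate size
    `out_dim // x_dim` must be even.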
28 | """ 29 | 30 | def __init__(self, x_dim, out_dim): 31 | super().__init__() 32 | self.x_dim = x_dim 33 | # dimension of encoding for eacg x dimension 34 | self.sub_dim = out_dim // self.x_dim 35 | # in "attention is all you need" used 10000 but 512 dim, try to keep the 36 | # same ratio regardless of dim 37 | self._C = 10000 * (self.sub_dim / 512) ** 2 38 | 39 | if out_dim % x_dim != 0: 40 | raise ValueError( 41 | "out_dim={} has to be dividable by x_dim={}.".format(out_dim, x_dim) 42 | ) 43 | if self.sub_dim % 2 != 0: 44 | raise ValueError( 45 | "sum_dim=out_dim/x_dim={} has to be dividable by 2.".format( 46 | self.sub_dim 47 | ) 48 | ) 49 | 50 | self._precompute_denom() 51 | 52 | def _precompute_denom(self): 53 | two_i_d = torch.arange(0, self.sub_dim, 2, dtype=torch.float) / self.sub_dim 54 | denom = torch.pow(self._C, two_i_d) 55 | denom = torch.repeat_interleave(denom, 2).unsqueeze(0) 56 | self.denom = denom.expand(1, self.x_dim, self.sub_dim) 57 | 58 | def forward(self, x): 59 | shape = x.shape 60 | # flatten besides last dim 61 | x = x.view(-1, shape[-1]) 62 | # will only be passed once to GPU because precomputed 63 | self.denom = self.denom.to(x.device) 64 | # put x in a range which is similar to positions in NLP [1,51] 65 | x = (x.unsqueeze(-1) + 1) * 25 + 1 66 | out = x / self.denom 67 | out[:, :, 0::2] = torch.sin(out[:, :, 0::2]) 68 | out[:, :, 1::2] = torch.cos(out[:, :, 1::2]) 69 | # concatenate all different sinusoidal encodings for each x_dim 70 | # and unflatten 71 | out = out.view(*shape[:-1], self.sub_dim * self.x_dim) 72 | return out 73 | 74 | 75 | class RelativeSinusoidalEncodings(nn.Module): 76 | """Return relative positions of inputs between [-1,1].""" 77 | 78 | def __init__(self, x_dim, out_dim, window_size=2): 79 | super().__init__() 80 | self.pos_encoder = SinusoidalEncodings(x_dim, out_dim) 81 | self.weight = nn.Linear(out_dim, out_dim, bias=False) 82 | self.window_size = window_size 83 | self.out_dim = out_dim 84 | 85 | def forward(self, keys_pos, queries_pos): 86 | # size=[batch_size, n_queries, n_keys, x_dim] 87 | diff = (keys_pos.unsqueeze(1) - queries_pos.unsqueeze(2)).abs() 88 | 89 | # the abs differences will be between between 0, self.window_size 90 | # we multipl by 2/self.window_size then remove 1 to be [-1,1] which is 91 | # the range for `SinusoidalEncodings` 92 | scaled_diff = diff * 2 / self.window_size - 1 93 | out = self.weight(self.pos_encoder(scaled_diff)) 94 | 95 | # set to 0 points that are further than window for extap 96 | out = out * (diff < self.window_size).float() 97 | 98 | return out 99 | 100 | 101 | # META ENCODERS 102 | class DiscardIthArg(nn.Module): 103 | """ 104 | Helper module which discard the i^th argument of the constructor and forward, 105 | before being given to `To`. 106 | """ 107 | 108 | def __init__(self, *args, i=0, To=nn.Identity, **kwargs): 109 | super().__init__() 110 | self.i = i 111 | self.destination = To(*self.filter_args(*args), **kwargs) 112 | 113 | def filter_args(self, *args): 114 | return [arg for i, arg in enumerate(args) if i != self.i] 115 | 116 | def forward(self, *args, **kwargs): 117 | return self.destination(*self.filter_args(*args), **kwargs) 118 | 119 | 120 | def discard_ith_arg(module, i, **kwargs): 121 | def discarded_arg(*args, **kwargs2): 122 | return DiscardIthArg(*args, i=i, To=module, **kwargs, **kwargs2) 123 | 124 | return discarded_arg 125 | 126 | 127 | class MergeFlatInputs(nn.Module): 128 | """ 129 | Extend a module to take 2 flat inputs. 
It simply returns 130 | the concatenated flat inputs to the module `module({x1; x2})`. 131 | 132 | Parameters 133 | ---------- 134 | FlatModule: nn.Module 135 | Module which takes a non flat inputs. 136 | 137 | x1_dim: int 138 | Dimensionality of the first flat inputs. 139 | 140 | x2_dim: int 141 | Dimensionality of the second flat inputs. 142 | 143 | n_out: int 144 | Size of ouput. 145 | 146 | is_sum_merge : bool, optional 147 | Whether to transform `flat_input` by an MLP first (if need to resize), 148 | then sum to `X` (instead of concatenating): useful if the difference in 149 | dimension between both inputs is very large => don't want one layer to 150 | depend only on a few dimension of a large input. 151 | 152 | kwargs: 153 | Additional arguments to FlatModule. 154 | """ 155 | 156 | def __init__(self, FlatModule, x1_dim, x2_dim, n_out, is_sum_merge=False, **kwargs): 157 | super().__init__() 158 | self.is_sum_merge = is_sum_merge 159 | 160 | if self.is_sum_merge: 161 | dim = x1_dim 162 | self.resizer = MLP(x2_dim, dim) # transform to be the correct size 163 | else: 164 | dim = x1_dim + x2_dim 165 | 166 | self.flat_module = FlatModule(dim, n_out, **kwargs) 167 | self.reset_parameters() 168 | 169 | def reset_parameters(self): 170 | weights_init(self) 171 | 172 | def forward(self, x1, x2): 173 | if self.is_sum_merge: 174 | x2 = self.resizer(x2) 175 | # use activation because if not 2 linear layers in a row => useless computation 176 | out = torch.relu(x1 + x2) 177 | else: 178 | out = torch.cat((x1, x2), dim=-1) 179 | 180 | return self.flat_module(out) 181 | 182 | 183 | def merge_flat_input(module, is_sum_merge=False, **kwargs): 184 | """ 185 | Extend a module to accept an additional flat input. I.e. the output should 186 | be called by `merge_flat_input(module)(x_shape, flat_dim, n_out, **kwargs)`. 187 | 188 | Notes 189 | ----- 190 | - if x_shape is an integer (currently only available option), it simply returns 191 | the concatenated flat inputs to the module `module({x; flat_input})`. 192 | - if `is_sum_merge` then transform `flat_input` by an MLP first, then sum 193 | to `X` (instead of concatenating): useful if the difference in dimension 194 | between both inputs is very large => don't want one layer to depend only on 195 | a few dimension of a large input. 196 | """ 197 | 198 | def merged_flat_input(x_shape, flat_dim, n_out, **kwargs2): 199 | assert isinstance(x_shape, int) 200 | return MergeFlatInputs( 201 | module, 202 | x_shape, 203 | flat_dim, 204 | n_out, 205 | is_sum_merge=is_sum_merge, 206 | **kwargs2, 207 | **kwargs 208 | ) 209 | 210 | return merged_flat_input 211 | -------------------------------------------------------------------------------- /modules/helpers.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from functools import reduce 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy.stats import rv_discrete 8 | from torch.distributions import Normal 9 | from torch.distributions.independent import Independent 10 | from .initialization import weights_init 11 | 12 | 13 | def sum_from_nth_dim(t, dim): 14 | """Sum all dims from `dim`. E.g. sum_after_nth_dim(torch.rand(2,3,4,5), 2).shape = [2,3]""" 15 | return t.view(*t.shape[:dim], -1).sum(-1) 16 | 17 | 18 | def logcumsumexp(x, dim): 19 | """Numerically stable log cumsum exp. 
SLow workaround waiting for https://github.com/pytorch/pytorch/pull/36308""" 20 | 21 | if (dim != -1) or (dim != x.ndimension() - 1): 22 | x = x.transpose(dim, -1) 23 | 24 | out = [] 25 | for i in range(1, x.size(-1) + 1): 26 | out.append(torch.logsumexp(x[..., :i], dim=-1, keepdim=True)) 27 | out = torch.cat(out, dim=-1) 28 | 29 | if (dim != -1) or (dim != x.ndimension() - 1): 30 | out = out.transpose(-1, dim) 31 | return out 32 | 33 | 34 | class LightTailPareto(rv_discrete): 35 | def _cdf(self, k, alpha): 36 | # alpha is factor like in SUMO paper 37 | # m is minimum number of samples 38 | m = self.a # lower bound of support 39 | 40 | # in the paper they us P(K >= k) but cdf is P(K <= k) = 1 - P(K > k) = 1 - P(K >= k + 1) 41 | k = k + 1 42 | 43 | # make sure has at least m samples 44 | k = np.clip(k - m, a_min=1, a_max=None) # makes sure no division by 0 45 | alpha = alpha - m 46 | 47 | # sample using pmf 1/k but with finite expectation 48 | cdf = 1 - np.where(k < alpha, 1 / k, (1 / alpha) * (0.9) ** (k - alpha)) 49 | 50 | return cdf 51 | 52 | 53 | def isin_range(x, valid_range): 54 | """Check if array / tensor is in a given range elementwise.""" 55 | return ((x >= valid_range[0]) & (x <= valid_range[1])).all() 56 | 57 | 58 | def channels_to_2nd_dim(X): 59 | """ 60 | Takes a signal with channels on the last dimension (for most operations) and 61 | returns it with channels on the second dimension (for convolutions). 62 | """ 63 | return X.permute(*([0, X.dim() - 1] + list(range(1, X.dim() - 1)))).contiguous() 64 | 65 | 66 | def channels_to_last_dim(X): 67 | """ 68 | Takes a signal with channels on the second dimension (for convolutions) and 69 | returns it with channels on the last dimension (for most operations). 70 | """ 71 | return X.permute(*([0] + list(range(2, X.dim())) + [1])).contiguous() 72 | 73 | 74 | def mask_and_apply(x, mask, f): 75 | """Applies a callable on a masked version of a input.""" 76 | tranformed_selected = f(x.masked_select(mask)) 77 | return x.masked_scatter(mask, tranformed_selected) 78 | 79 | 80 | def indep_shuffle_(a, axis=-1): 81 | """ 82 | Shuffle `a` in-place along the given axis. 83 | 84 | Apply `numpy.random.shuffle` to the given axis of `a`. 85 | Each one-dimensional slice is shuffled independently. 86 | 87 | Credits : https://github.com/numpy/numpy/issues/5173 88 | """ 89 | b = a.swapaxes(axis, -1) 90 | # Shuffle `b` in-place along the last axis. `b` is a view of `a`, 91 | # so `a` is shuffled in place, too. 
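    # The loop below walks over every 1-D slice along the last axis of `b` and
    # shuffles it in place: e.g. for a 2-D array, indep_shuffle_(a, axis=-1)
    # permutes each row independently, whereas np.random.shuffle(a) would only
    # permute whole rows.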
92 | shp = b.shape[:-1] 93 | for ndx in np.ndindex(shp): 94 | np.random.shuffle(b[ndx]) 95 | 96 | 97 | def ratio_to_int(percentage, max_val): 98 | """Converts a ratio to an integer if it is smaller than 1.""" 99 | if 1 <= percentage <= max_val: 100 | out = percentage 101 | elif 0 <= percentage < 1: 102 | out = percentage * max_val 103 | else: 104 | raise ValueError("percentage={} outside of [0,{}].".format(percentage, max_val)) 105 | 106 | return int(out) 107 | 108 | 109 | def prod(iterable): 110 | """Compute the product of all elements in an iterable.""" 111 | return reduce(operator.mul, iterable, 1) 112 | 113 | 114 | def rescale_range(X, old_range, new_range): 115 | """Rescale X linearly to be in `new_range` rather than `old_range`.""" 116 | old_min = old_range[0] 117 | new_min = new_range[0] 118 | old_delta = old_range[1] - old_min 119 | new_delta = new_range[1] - new_min 120 | return (((X - old_min) * new_delta) / old_delta) + new_min 121 | 122 | 123 | def MultivariateNormalDiag(loc, scale_diag): 124 | """Multi variate Gaussian with a diagonal covariance function (on the last dimension).""" 125 | if loc.dim() < 1: 126 | raise ValueError("loc must be at least one-dimensional.") 127 | return Independent(Normal(loc, scale_diag), 1) 128 | 129 | 130 | def clamp( 131 | x, 132 | minimum=-float("Inf"), 133 | maximum=float("Inf"), 134 | is_leaky=False, 135 | negative_slope=0.01, 136 | hard_min=None, 137 | hard_max=None, 138 | ): 139 | """ 140 | Clamps a tensor to the given [minimum, maximum] (leaky) bound, with 141 | an optional hard clamping. 142 | """ 143 | lower_bound = ( 144 | (minimum + negative_slope * (x - minimum)) 145 | if is_leaky 146 | else torch.zeros_like(x) + minimum 147 | ) 148 | upper_bound = ( 149 | (maximum + negative_slope * (x - maximum)) 150 | if is_leaky 151 | else torch.zeros_like(x) + maximum 152 | ) 153 | clamped = torch.max(lower_bound, torch.min(x, upper_bound)) 154 | 155 | if hard_min is not None or hard_max is not None: 156 | if hard_min is None: 157 | hard_min = -float("Inf") 158 | elif hard_max is None: 159 | hard_max = float("Inf") 160 | clamped = clamp(x, minimum=hard_min, maximum=hard_max, is_leaky=False) 161 | 162 | return clamped 163 | 164 | 165 | class ProbabilityConverter(nn.Module): 166 | """Maps floats to probabilites (between 0 and 1), element-wise. 167 | 168 | Parameters 169 | ---------- 170 | min_p : float, optional 171 | Minimum probability, can be useful to set greater than 0 in order to keep 172 | gradient flowing if the probability is used for convex combinations of 173 | different parts of the model. Note that maximum probability is `1-min_p`. 174 | 175 | activation : {"sigmoid", "hard-sigmoid", "leaky-hard-sigmoid"}, optional 176 | name of the activation to use to generate the probabilities. `sigmoid` 177 | has the advantage of being smooth and never exactly 0 or 1, which helps 178 | gradient flows. `hard-sigmoid` has the advantage of making all values 179 | between min_p and max_p equiprobable. 180 | 181 | is_train_temperature : bool, optional 182 | Whether to train the paremeter controling the steapness of the activation. 183 | This is useful when x is used for multiple tasks, and you don't want to 184 | constraint its magnitude. 185 | 186 | is_train_bias : bool, optional 187 | Whether to train the bias to shift the activation. This is useful when x is 188 | used for multiple tasks, and you don't want to constraint it's scale. 189 | 190 | trainable_dim : int, optional 191 | Size of the trainable bias and termperature. 
If `1` uses the same vale 192 | across all dimension, if not should be equal to the number of input 193 | dimensions to different trainable aprameters for each dimension. Note 194 | that the iitial value will still be the same for all dimensions. 195 | 196 | initial_temperature : int, optional 197 | Initial temperature, a higher temperature makes the activation steaper. 198 | 199 | initial_probability : float, optional 200 | Initial probability you want to start with. 201 | 202 | initial_x : float, optional 203 | First value that will be given to the function, important to make 204 | `initial_probability` work correctly. 205 | 206 | bias_transformer : callable, optional 207 | Transformer function of the bias. This function should only take care of 208 | the boundaries (e.g. leaky relu or relu). 209 | 210 | temperature_transformer : callable, optional 211 | Transformer function of the temperature. This function should only take 212 | care of the boundaries (e.g. leaky relu or relu). 213 | """ 214 | 215 | def __init__( 216 | self, 217 | min_p=0.0, 218 | activation="sigmoid", 219 | is_train_temperature=False, 220 | is_train_bias=False, 221 | trainable_dim=1, 222 | initial_temperature=1.0, 223 | initial_probability=0.5, 224 | initial_x=0, 225 | bias_transformer=nn.Identity(), 226 | temperature_transformer=nn.Identity(), 227 | ): 228 | 229 | super().__init__() 230 | self.min_p = min_p 231 | self.activation = activation 232 | self.is_train_temperature = is_train_temperature 233 | self.is_train_bias = is_train_bias 234 | self.trainable_dim = trainable_dim 235 | self.initial_temperature = initial_temperature 236 | self.initial_probability = initial_probability 237 | self.initial_x = initial_x 238 | self.bias_transformer = bias_transformer 239 | self.temperature_transformer = temperature_transformer 240 | 241 | self.reset_parameters() 242 | 243 | def reset_parameters(self): 244 | self.temperature = torch.tensor([self.initial_temperature] * self.trainable_dim) 245 | if self.is_train_temperature: 246 | self.temperature = nn.Parameter(self.temperature) 247 | 248 | initial_bias = self._probability_to_bias( 249 | self.initial_probability, initial_x=self.initial_x 250 | ) 251 | 252 | self.bias = torch.tensor([initial_bias] * self.trainable_dim) 253 | if self.is_train_bias: 254 | self.bias = nn.Parameter(self.bias) 255 | 256 | def forward(self, x): 257 | self.temperature.to(x.device) 258 | self.bias.to(x.device) 259 | 260 | temperature = self.temperature_transformer(self.temperature) 261 | bias = self.bias_transformer(self.bias) 262 | 263 | if self.activation == "sigmoid": 264 | full_p = torch.sigmoid((x + bias) * temperature) 265 | 266 | elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]: 267 | # uses 0.2 and 0.5 to be similar to sigmoid 268 | y = 0.2 * ((x + bias) * temperature) + 0.5 269 | 270 | if self.activation == "leaky-hard-sigmoid": 271 | full_p = clamp( 272 | y, 273 | minimum=0.1, 274 | maximum=0.9, 275 | is_leaky=True, 276 | negative_slope=0.01, 277 | hard_min=0, 278 | hard_max=0, 279 | ) 280 | elif self.activation == "hard-sigmoid": 281 | full_p = clamp(y, minimum=0.0, maximum=1.0, is_leaky=False) 282 | 283 | else: 284 | raise ValueError("Unkown activation : {}".format(self.activation)) 285 | 286 | p = rescale_range(full_p, (0, 1), (self.min_p, 1 - self.min_p)) 287 | 288 | return p 289 | 290 | def _probability_to_bias(self, p, initial_x=0): 291 | """Compute the bias to use to satisfy the constraints.""" 292 | assert p > self.min_p and p < 1 - self.min_p 293 | range_p = 1 - 
self.min_p * 2 294 | p = (p - self.min_p) / range_p 295 | p = torch.tensor(p, dtype=torch.float) 296 | 297 | if self.activation == "sigmoid": 298 | bias = -(torch.log((1 - p) / p) / self.initial_temperature + initial_x) 299 | 300 | elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]: 301 | bias = ((p - 0.5) / 0.2) / self.initial_temperature - initial_x 302 | 303 | return bias 304 | 305 | 306 | def dist_to_device(dist, device): 307 | """Set a distirbution to a given device.""" 308 | if dist is None: 309 | return 310 | dist.base_dist.loc = dist.base_dist.loc.to(device) 311 | dist.base_dist.scale = dist.base_dist.loc.to(device) 312 | 313 | 314 | def make_abs_conv(Conv): 315 | """Make a convolution have only positive parameters.""" 316 | 317 | class AbsConv(Conv): 318 | def forward(self, input): 319 | return F.conv2d( 320 | input, 321 | self.weight.abs(), 322 | self.bias, 323 | self.stride, 324 | self.padding, 325 | self.dilation, 326 | self.groups, 327 | ) 328 | 329 | return AbsConv 330 | 331 | 332 | def make_padded_conv(Conv, Padder): 333 | """Make a convolution have any possible padding.""" 334 | 335 | class PaddedConv(Conv): 336 | def __init__(self, *args, Padder=Padder, padding=0, **kwargs): 337 | old_padding = 0 338 | if Padder is None: 339 | Padder = nn.Identity 340 | old_padding = padding 341 | 342 | super().__init__(*args, padding=old_padding, **kwargs) 343 | self.padder = Padder(padding) 344 | 345 | def forward(self, X): 346 | X = self.padder(X) 347 | return super().forward(X) 348 | 349 | return PaddedConv 350 | 351 | 352 | def make_depth_sep_conv(Conv): 353 | """Make a convolution module depth separable.""" 354 | 355 | class DepthSepConv(nn.Module): 356 | """Make a convolution depth separable. 357 | 358 | Parameters 359 | ---------- 360 | in_channels : int 361 | Number of input channels. 362 | 363 | out_channels : int 364 | Number of output channels. 
365 | 366 | kernel_size : int 367 | 368 | **kwargs : 369 | Additional arguments to `Conv` 370 | """ 371 | 372 | def __init__( 373 | self, 374 | in_channels, 375 | out_channels, 376 | kernel_size, 377 | confidence=False, 378 | bias=True, 379 | **kwargs 380 | ): 381 | super().__init__() 382 | self.depthwise = Conv( 383 | in_channels, 384 | in_channels, 385 | kernel_size, 386 | groups=in_channels, 387 | bias=bias, 388 | **kwargs 389 | ) 390 | self.pointwise = Conv(in_channels, out_channels, 1, bias=bias) 391 | self.reset_parameters() 392 | 393 | def forward(self, x): 394 | out = self.depthwise(x) 395 | out = self.pointwise(out) 396 | return out 397 | 398 | def reset_parameters(self): 399 | weights_init(self) 400 | 401 | return DepthSepConv 402 | 403 | 404 | class CircularPad2d(nn.Module): 405 | """Implements a 2d circular padding.""" 406 | 407 | def __init__(self, padding): 408 | super().__init__() 409 | self.padding = padding 410 | 411 | def forward(self, x): 412 | return F.pad(x, (self.padding,) * 4, mode="circular") 413 | 414 | 415 | class BackwardPDB(torch.autograd.Function): 416 | """Run PDB in the backward pass.""" 417 | 418 | @staticmethod 419 | def forward(ctx, input, name="debugger"): 420 | ctx.name = name 421 | ctx.save_for_backward(input) 422 | return input 423 | 424 | @staticmethod 425 | def backward(ctx, grad_output): 426 | (input,) = ctx.saved_tensors 427 | if not torch.isfinite(grad_output).all() or not torch.isfinite(input).all(): 428 | import pdb 429 | 430 | pdb.set_trace() 431 | return grad_output, None # 2 args so return None for `name` 432 | 433 | 434 | backward_pdb = BackwardPDB.apply 435 | -------------------------------------------------------------------------------- /modules/initialization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ["weights_init"] 5 | 6 | 7 | def weights_init(module, **kwargs): 8 | """Initialize a module and all its descendents. 9 | 10 | Parameters 11 | ---------- 12 | module : nn.Module 13 | module to initialize. 
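    Examples
    --------
    Minimal illustrative sketch (the model here is hypothetical and only shows
    how the initialization is dispatched per layer type):

    >>> import torch.nn as nn
    >>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))
    >>> weights_init(net)

    Convolutions get Kaiming-normal initialization, linear layers go through
    `linear_init`, and BatchNorm weights/biases are reset to 1/0; ``kwargs`` are
    forwarded to the convolution and linear initializers.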
14 | """ 15 | module.is_resetted = True 16 | for m in module.modules(): 17 | try: 18 | if hasattr(module, "reset_parameters") and module.is_resetted: 19 | # don't reset if resetted already (might want special) 20 | continue 21 | except AttributeError: 22 | pass 23 | 24 | if isinstance(m, torch.nn.modules.conv._ConvNd): 25 | # used in https://github.com/brain-research/realistic-ssl-evaluation/ 26 | nn.init.kaiming_normal_(m.weight, mode="fan_out", **kwargs) 27 | elif isinstance(m, nn.Linear): 28 | linear_init(m, **kwargs) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | m.weight.data.fill_(1) 31 | m.bias.data.zero_() 32 | 33 | 34 | def get_activation_name(activation): 35 | """Given a string or a `torch.nn.modules.activation` return the name of the activation.""" 36 | if isinstance(activation, str): 37 | return activation 38 | 39 | mapper = { 40 | nn.LeakyReLU: "leaky_relu", 41 | nn.ReLU: "relu", 42 | nn.Tanh: "tanh", 43 | nn.Sigmoid: "sigmoid", 44 | nn.Softmax: "sigmoid", 45 | } 46 | for k, v in mapper.items(): 47 | if isinstance(activation, k): 48 | return k 49 | 50 | raise ValueError("Unkown given activation type : {}".format(activation)) 51 | 52 | 53 | def get_gain(activation): 54 | """Given an object of `torch.nn.modules.activation` or an activation name 55 | return the correct gain.""" 56 | if activation is None: 57 | return 1 58 | 59 | activation_name = get_activation_name(activation) 60 | 61 | param = None if activation_name != "leaky_relu" else activation.negative_slope 62 | gain = nn.init.calculate_gain(activation_name, param) 63 | 64 | return gain 65 | 66 | 67 | def linear_init(module, activation="relu"): 68 | """Initialize a linear layer. 69 | 70 | Parameters 71 | ---------- 72 | module : nn.Module 73 | module to initialize. 74 | 75 | activation : `torch.nn.modules.activation` or str, optional 76 | Activation that will be used on the `module`. 77 | """ 78 | x = module.weight 79 | 80 | if module.bias is not None: 81 | module.bias.data.zero_() 82 | 83 | if activation is None: 84 | return nn.init.xavier_uniform_(x) 85 | 86 | activation_name = get_activation_name(activation) 87 | 88 | if activation_name == "leaky_relu": 89 | a = 0 if isinstance(activation, str) else activation.negative_slope 90 | return nn.init.kaiming_uniform_(x, a=a, nonlinearity="leaky_relu") 91 | elif activation_name == "relu": 92 | return nn.init.kaiming_uniform_(x, nonlinearity="relu") 93 | elif activation_name in ["sigmoid", "tanh"]: 94 | return nn.init.xavier_uniform_(x, gain=get_gain(activation)) 95 | 96 | 97 | def init_param_(param, activation=None, is_positive=False, bound=0.05, shift=0): 98 | """Initialize inplace some parameters of the model that are not part of a 99 | children module. 100 | 101 | Parameters 102 | ---------- 103 | param : nn.Parameter: 104 | Parameters to initialize. 105 | 106 | activation : torch.nn.modules.activation or str, optional) 107 | Activation that will be used on the `param`. 108 | 109 | is_positive : bool, optional 110 | Whether to initilize only with positive values. 111 | 112 | bound : float, optional 113 | Maximum absolute value of the initealized values. By default `0.05` which 114 | is keras default uniform bound. 115 | 116 | shift : int, optional 117 | Shift the initialisation by a certain value (same as adding a value after init). 
118 | """ 119 | gain = get_gain(activation) 120 | if is_positive: 121 | nn.init.uniform_(param, 1e-5 + shift, bound * gain + shift) 122 | return 123 | 124 | nn.init.uniform_(param, -bound * gain + shift, bound * gain + shift) 125 | -------------------------------------------------------------------------------- /modules/losses.py: -------------------------------------------------------------------------------- 1 | """Module for all the loss of Neural Process Family.""" 2 | import abc 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from .helpers import ( 7 | LightTailPareto, 8 | dist_to_device, 9 | logcumsumexp, 10 | sum_from_nth_dim, 11 | ) 12 | from torch.distributions.kl import kl_divergence 13 | 14 | 15 | __all__ = ["CNPFLoss", "ELBOLossLNPF", "SUMOLossLNPF", "NLLLossLNPF"] 16 | 17 | 18 | def sum_log_prob(prob, sample): 19 | """Compute log probability then sum all but the z_samples and batch.""" 20 | # size = [n_z_samples, batch_size, *] 21 | log_p = prob.log_prob(sample) 22 | # size = [n_z_samples, batch_size] 23 | sum_log_p = sum_from_nth_dim(log_p, 2) 24 | return sum_log_p 25 | 26 | 27 | class BaseLossNPF(nn.Module, abc.ABC): 28 | """ 29 | Compute the negative log likelihood loss for members of the conditional neural process (sub-)family. 30 | 31 | Parameters 32 | ---------- 33 | reduction : {None,"mean","sum"}, optional 34 | Batch wise reduction. 35 | 36 | is_force_mle_eval : bool, optional 37 | Whether to force mac likelihood eval even if has access to q_zCct 38 | """ 39 | 40 | def __init__(self, reduction="mean", is_force_mle_eval=True): 41 | super().__init__() 42 | self.reduction = reduction 43 | self.is_force_mle_eval = is_force_mle_eval 44 | 45 | def forward(self, pred_outputs, Y_trgt): 46 | """Compute the Neural Process Loss. 47 | 48 | Parameters 49 | ---------- 50 | pred_outputs : tuple 51 | Output of `NeuralProcessFamily`. 52 | 53 | Y_trgt : torch.Tensor, size=[batch_size, *n_trgt, y_dim] 54 | Set of all target values {y_t}. 55 | 56 | Return 57 | ------ 58 | loss : torch.Tensor 59 | size=[batch_size] if `reduction=None` else [1]. 60 | """ 61 | p_yCc, z_samples, q_zCc, q_zCct = pred_outputs 62 | 63 | if self.training: 64 | loss = self.get_loss(p_yCc, z_samples, q_zCc, q_zCct, Y_trgt) 65 | else: 66 | # always uses NPML for evaluation 67 | if self.is_force_mle_eval: 68 | q_zCct = None 69 | loss = NLLLossLNPF.get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt) 70 | 71 | if self.reduction is None: 72 | # size = [batch_size] 73 | return loss 74 | elif self.reduction == "mean": 75 | # size = [1] 76 | return loss.mean(0) 77 | elif self.reduction == "sum": 78 | # size = [1] 79 | return loss.sum(0) 80 | else: 81 | raise ValueError(f"Unknown {self.reduction}") 82 | 83 | @abc.abstractmethod 84 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt): 85 | """Compute the Neural Process Loss 86 | 87 | Parameters 88 | ------ 89 | p_yCc: torch.distributions.Distribution, batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim] 90 | Posterior distribution for target values {p(Y^t|y_c; x_c, x_t)}_t 91 | 92 | z_samples: torch.Tensor, size=[n_z_samples, batch_size, *n_lat, z_dim] 93 | Sampled latents. `None` if `encoded_path==deterministic`. 94 | 95 | q_zCc: torch.distributions.Distribution, batch shape=[batch_size, *n_lat] ; event shape=[z_dim] 96 | Latent distribution for the context points. `None` if `encoded_path==deterministic`. 
97 | 98 | q_zCct: torch.distributions.Distribution, batch shape=[batch_size, *n_lat] ; event shape=[z_dim] 99 | Latent distribution for the targets. `None` if `encoded_path==deterministic` 100 | or not training or not `is_q_zCct`. 101 | 102 | Y_trgt: torch.Tensor, size=[batch_size, *n_trgt, y_dim] 103 | Set of all target values {y_t}. 104 | 105 | Return 106 | ------ 107 | loss : torch.Tensor, size=[1]. 108 | """ 109 | pass 110 | 111 | 112 | class CNPFLoss(BaseLossNPF): 113 | """Losss for conditional neural process (suf-)family [1].""" 114 | 115 | def get_loss(self, p_yCc, _, q_zCc, ___, Y_trgt): 116 | assert q_zCc is None 117 | # \sum_t log p(y^t|z) 118 | # \sum_t log p(y^t|z). size = [z_samples, batch_size] 119 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt) 120 | 121 | # size = [batch_size] 122 | nll = -sum_log_p_yCz.squeeze(0) 123 | return nll 124 | 125 | 126 | class ELBOLossLNPF(BaseLossNPF): 127 | """Approximate conditional ELBO [1]. 128 | 129 | References 130 | ---------- 131 | [1] Garnelo, Marta, et al. "Neural processes." arXiv preprint 132 | arXiv:1807.01622 (2018). 133 | """ 134 | 135 | def get_loss(self, p_yCc, _, q_zCc, q_zCct, Y_trgt): 136 | 137 | # first term in loss is E_{q(z|y_cntxt,y_trgt)}[\sum_t log p(y^t|z)] 138 | # \sum_t log p(y^t|z). size = [z_samples, batch_size] 139 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt) 140 | 141 | # E_{q(z|y_cntxt,y_trgt)}[...] . size = [batch_size] 142 | E_z_sum_log_p_yCz = sum_log_p_yCz.mean(0) 143 | 144 | # second term in loss is \sum_l KL[q(z^l|y_cntxt,y_trgt)||q(z^l|y_cntxt)] 145 | # KL[q(z^l|y_cntxt,y_trgt)||q(z^l|y_cntxt)]. size = [batch_size, *n_lat] 146 | kl_z = kl_divergence(q_zCct, q_zCc) 147 | # \sum_l ... . size = [batch_size] 148 | E_z_kl = sum_from_nth_dim(kl_z, 1) 149 | 150 | return -(E_z_sum_log_p_yCz - E_z_kl) 151 | 152 | 153 | class NLLLossLNPF(BaseLossNPF): 154 | """ 155 | Compute the approximate negative log likelihood for Neural Process family[?]. 156 | 157 | Notes 158 | ----- 159 | - might be high variance 160 | - biased 161 | - approximate because expectation over q(z|cntxt) instead of p(z|cntxt) 162 | - if q_zCct is not None then uses importance sampling (i.e. assumes that sampled from it). 163 | 164 | References 165 | ---------- 166 | [?] 167 | """ 168 | 169 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt): 170 | 171 | n_z_samples, batch_size, *n_trgt = p_yCc.batch_shape 172 | 173 | # computes approximate LL in a numerically stable way 174 | # LL = E_{q(z|y_cntxt)}[ \prod_t p(y^t|z)] 175 | # LL MC = log ( mean_z ( \prod_t p(y^t|z)) ) 176 | # = log [ sum_z ( \prod_t p(y^t|z)) ] - log(n_z_samples) 177 | # = log [ sum_z ( exp \sum_t log p(y^t|z)) ] - log(n_z_samples) 178 | # = log_sum_exp_z ( \sum_t log p(y^t|z)) - log(n_z_samples) 179 | 180 | # \sum_t log p(y^t|z). size = [n_z_samples, batch_size] 181 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt) 182 | 183 | # uses importance sampling weights if necessary 184 | if q_zCct is not None: 185 | 186 | # All latents are treated as independent. size = [n_z_samples, batch_size] 187 | sum_log_q_zCc = sum_log_prob(q_zCc, z_samples) 188 | sum_log_q_zCct = sum_log_prob(q_zCct, z_samples) 189 | 190 | # importance sampling : multiply \prod_t p(y^t|z)) by q(z|y_cntxt) / q(z|y_cntxt, y_trgt) 191 | # i.e. add log q(z|y_cntxt) - log q(z|y_cntxt, y_trgt) 192 | sum_log_w_k = sum_log_p_yCz + sum_log_q_zCc - sum_log_q_zCct 193 | else: 194 | sum_log_w_k = sum_log_p_yCz 195 | 196 | # log_sum_exp_z ... . 
size = [batch_size] 197 | log_S_z_sum_p_yCz = torch.logsumexp(sum_log_w_k, 0) 198 | 199 | # - log(n_z_samples) 200 | log_E_z_sum_p_yCz = log_S_z_sum_p_yCz - math.log(n_z_samples) 201 | 202 | # NEGATIVE log likelihood 203 | return -log_E_z_sum_p_yCz 204 | 205 | 206 | #! might need gradient clipping as in their paper 207 | class SUMOLossLNPF(BaseLossNPF): 208 | """ 209 | Estimate negative log likelihood for Neural Process family using SUMO [1]. 210 | 211 | Notes 212 | ----- 213 | - approximate because expectation over q(z|cntxt) instead of p(z|cntxt) 214 | - if q_zCct is not None then uses importance sampling (i.e. assumes that sampled from it). 215 | 216 | Parameters 217 | ---------- 218 | p_n_z_samples : scipy.stats.rv_frozen, optional 219 | Distribution for the number of of z_samples to take. 220 | 221 | References 222 | ---------- 223 | [1] Luo, Yucen, et al. "SUMO: Unbiased Estimation of Log Marginal Probability for Latent 224 | Variable Models." arXiv preprint arXiv:2004.00353 (2020) 225 | """ 226 | 227 | def __init__( 228 | self, 229 | p_n_z_samples=LightTailPareto(a=5).freeze(85), 230 | **kwargs, 231 | ): 232 | super().__init__() 233 | self.p_n_z_samples = p_n_z_samples 234 | 235 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt): 236 | 237 | n_z_samples, batch_size, *n_trgt = p_yCc.batch_shape 238 | 239 | # \sum_t log p(y^t|z). size = [n_z_samples, batch_size] 240 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt) 241 | 242 | # uses importance sampling weights if necessary 243 | if q_zCct is not None: 244 | # All latents are treated as independent. size = [n_z_samples, batch_size] 245 | sum_log_q_zCc = sum_log_prob(q_zCc, z_samples) 246 | sum_log_q_zCct = sum_log_prob(q_zCct, z_samples) 247 | 248 | #! It should be p(y^t,z|cntxt) but we are using q(z|cntxt) instead of p(z|cntxt) 249 | # \sum_t log (q(y^t,z|cntxt) / q(z|cntxt,trgt)) . size = [n_z_samples, batch_size] 250 | sum_log_w_k = sum_log_p_yCz + sum_log_q_zCc - sum_log_q_zCct 251 | else: 252 | sum_log_w_k = sum_log_p_yCz 253 | 254 | # size = [n_z_samples, 1] 255 | ks = (torch.arange(n_z_samples) + 1).unsqueeze(-1) 256 | #! slow to always put on GPU 257 | log_ks = ks.float().log().to(sum_log_w_k.device) 258 | 259 | #! the algorithm in the paper is not correct on ks[:k+1] and forgot inv_weights[m:] 260 | # size = [n_z_samples, batch_size] 261 | cum_iwae = logcumsumexp(sum_log_w_k, 0) - log_ks 262 | 263 | #! slow to always put on GPU 264 | # you want reverse_cdf which is P(K >= k ) = 1 - P(K < k) = 1 - P(K <= k-1) = 1 - CDF(k-1) 265 | inv_weights = torch.from_numpy(1 - self.p_n_z_samples.cdf(ks - 1)).to( 266 | sum_log_w_k.device 267 | ) 268 | 269 | m = self.p_n_z_samples.support()[0] 270 | # size = [batch_size] 271 | sumo = cum_iwae[m - 1] + ( 272 | inv_weights[m:] * (cum_iwae[m:] - cum_iwae[m - 1 : -1]) 273 | ).sum(0) 274 | 275 | nll = -sumo 276 | return nll 277 | -------------------------------------------------------------------------------- /modules/mlp.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch.nn as nn 3 | from .initialization import linear_init 4 | 5 | 6 | __all__ = ["MLP"] 7 | 8 | 9 | class MLP(nn.Module): 10 | """General MLP class. 11 | 12 | Parameters 13 | ---------- 14 | input_size: int 15 | 16 | output_size: int 17 | 18 | hidden_size: int, optional 19 | Number of hidden neurones. 20 | 21 | n_hidden_layers: int, optional 22 | Number of hidden layers. 23 | 24 | activation: callable, optional 25 | Activation function. E.g. `nn.RelU()`. 
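Aside, returning to `modules/losses.py` for a moment: the comment block inside `NLLLossLNPF.get_loss` derives the log-mean-exp form of the Monte-Carlo likelihood estimate. The snippet below is a small numerical check of that identity with random stand-in values (illustrative only):

```
# Check that logsumexp(w, 0) - log(K) == log(mean(exp(w), 0)), the estimator
# used by NLLLossLNPF (and, cumulatively via logcumsumexp, by SUMOLossLNPF).
import math
import torch

n_z_samples, batch_size = 8, 3
sum_log_w = torch.randn(n_z_samples, batch_size)      # stand-in for sum_t log p(y^t | z_k)

stable = torch.logsumexp(sum_log_w, 0) - math.log(n_z_samples)
naive = torch.log(torch.exp(sum_log_w).mean(0))       # numerically fragile version
print(torch.allclose(stable, naive, atol=1e-6))       # True
```

The stable form matters in practice because the summed log-probabilities over many target points can be large in magnitude, and exponentiating them directly would overflow or underflow.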
26 | 27 | is_bias: bool, optional 28 | Whether to use biaises in the hidden layers. 29 | 30 | dropout: float, optional 31 | Dropout rate. 32 | 33 | is_force_hid_smaller : bool, optional 34 | Whether to force the hidden dimensions to be smaller or equal than in and out. 35 | If not, it forces the hidden dimension to be larger or equal than in or out. 36 | 37 | is_res : bool, optional 38 | Whether to use residual connections. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | input_size, 44 | output_size, 45 | hidden_size=32, 46 | n_hidden_layers=1, 47 | activation=nn.ReLU(), 48 | is_bias=True, 49 | dropout=0, 50 | is_force_hid_smaller=False, 51 | is_res=False, 52 | ): 53 | super().__init__() 54 | 55 | self.input_size = input_size 56 | self.output_size = output_size 57 | self.hidden_size = hidden_size 58 | self.n_hidden_layers = n_hidden_layers 59 | self.is_res = is_res 60 | 61 | if is_force_hid_smaller and self.hidden_size > max( 62 | self.output_size, self.input_size 63 | ): 64 | self.hidden_size = max(self.output_size, self.input_size) 65 | txt = "hidden_size={} larger than output={} and input={}. Setting it to {}." 66 | warnings.warn( 67 | txt.format(hidden_size, output_size, input_size, self.hidden_size) 68 | ) 69 | elif self.hidden_size < min(self.output_size, self.input_size): 70 | self.hidden_size = min(self.output_size, self.input_size) 71 | txt = ( 72 | "hidden_size={} smaller than output={} and input={}. Setting it to {}." 73 | ) 74 | warnings.warn( 75 | txt.format(hidden_size, output_size, input_size, self.hidden_size) 76 | ) 77 | 78 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity() 79 | self.activation = activation 80 | 81 | self.to_hidden = nn.Linear(self.input_size, self.hidden_size, bias=is_bias) 82 | self.linears = nn.ModuleList( 83 | [ 84 | nn.Linear(self.hidden_size, self.hidden_size, bias=is_bias) 85 | for _ in range(self.n_hidden_layers - 1) 86 | ] 87 | ) 88 | self.out = nn.Linear(self.hidden_size, self.output_size, bias=is_bias) 89 | 90 | self.reset_parameters() 91 | 92 | def forward(self, x): 93 | out = self.to_hidden(x) 94 | out = self.activation(out) 95 | x = self.dropout(out) 96 | 97 | for linear in self.linears: 98 | out = linear(x) 99 | out = self.activation(out) 100 | if self.is_res: 101 | out = out + x 102 | out = self.dropout(out) 103 | x = out 104 | 105 | out = self.out(x) 106 | return out 107 | 108 | def reset_parameters(self): 109 | linear_init(self.to_hidden, activation=self.activation) 110 | for lin in self.linears: 111 | linear_init(lin, activation=self.activation) 112 | linear_init(self.out) 113 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from timm.models.layers import DropPath, trunc_normal_ 6 | import torch.utils.checkpoint as checkpoint 7 | from functools import reduce, lru_cache 8 | from operator import mul 9 | from einops import rearrange 10 | 11 | 12 | __all__ = ["AllPatchEmbed", "PatchRecover", "BasicLayer", "SwinTransformerLayer"] 13 | 14 | 15 | class Mlp(nn.Module): 16 | """ Multilayer perceptron.""" 17 | 18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 | super().__init__() 20 | out_features = out_features or in_features 21 | hidden_features = hidden_features or in_features 22 | self.fc1 = nn.Linear(in_features, 
hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | 36 | def swin_window_partition(x, window_size): 37 | """ 38 | Args: 39 | x: (B, D, H, W, C) 40 | window_size (tuple[int]): window size 41 | 42 | Returns: 43 | windows: (B*num_windows, window_size*window_size, C) 44 | """ 45 | B, D, H, W, C = x.shape 46 | x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C) 47 | windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) 48 | return windows 49 | 50 | 51 | def swin_window_reverse(windows, window_size, B, D, H, W): 52 | """ 53 | Args: 54 | windows: (B*num_windows, window_size, window_size, C) 55 | window_size (tuple[int]): Window size 56 | H (int): Height of image 57 | W (int): Width of image 58 | 59 | Returns: 60 | x: (B, D, H, W, C) 61 | """ 62 | x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1) 63 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) 64 | return x 65 | 66 | 67 | def get_window_size(x_size, window_size, shift_size=None): 68 | use_window_size = list(window_size) 69 | if shift_size is not None: 70 | use_shift_size = list(shift_size) 71 | for i in range(len(x_size)): 72 | if x_size[i] <= window_size[i]: 73 | use_window_size[i] = x_size[i] 74 | if shift_size is not None: 75 | use_shift_size[i] = 0 76 | 77 | if shift_size is None: 78 | return tuple(use_window_size) 79 | else: 80 | return tuple(use_window_size), tuple(use_shift_size) 81 | 82 | 83 | class WindowAttention3D(nn.Module): 84 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 85 | It supports both of shifted and non-shifted window. 86 | Args: 87 | dim (int): Number of input channels. 88 | window_size (tuple[int]): The temporal length, height and width of the window. 89 | num_heads (int): Number of attention heads. 90 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 91 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 92 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 93 | proj_drop (float, optional): Dropout ratio of output. 
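Aside: `swin_window_partition` and `swin_window_reverse` are exact inverses when the spatial sizes are multiples of the window size. A small shape check (illustrative sizes; the functions are assumed importable from `modules.transformer`):

```
# Sketch: round trip through swin_window_partition / swin_window_reverse.
import torch
from modules.transformer import swin_window_partition, swin_window_reverse

B, D, H, W, C = 2, 4, 8, 16, 32
window_size = (2, 4, 8)

x = torch.randn(B, D, H, W, C)
windows = swin_window_partition(x, window_size)
print(windows.shape)             # (B * num_windows, 2*4*8, C) = (16, 64, 32)

x_back = swin_window_reverse(windows, window_size, B, D, H, W)
print(torch.equal(x, x_back))    # True
```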
Default: 0.0 94 | """ 95 | 96 | def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., cross=False): 97 | 98 | super().__init__() 99 | self.dim = dim 100 | self.window_size = window_size # Wd, Wh, Ww 101 | self.num_heads = num_heads 102 | head_dim = dim // num_heads 103 | self.scale = qk_scale or head_dim ** -0.5 104 | self.cross = cross 105 | 106 | # define a parameter table of relative position bias 107 | self.relative_position_bias_table = nn.Parameter( 108 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH 109 | 110 | # get pair-wise relative position index for each token inside the window 111 | coords_d = torch.arange(self.window_size[0]) 112 | coords_h = torch.arange(self.window_size[1]) 113 | coords_w = torch.arange(self.window_size[2]) 114 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww 115 | coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww 116 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww 117 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 118 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 119 | relative_coords[:, :, 1] += self.window_size[1] - 1 120 | relative_coords[:, :, 2] += self.window_size[2] - 1 121 | 122 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) 123 | relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) 124 | relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww 125 | self.register_buffer("relative_position_index", relative_position_index) 126 | 127 | if self.cross: 128 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 129 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 130 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 131 | else: 132 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 133 | 134 | self.attn_drop = nn.Dropout(attn_drop) 135 | self.proj = nn.Linear(dim, dim) 136 | self.proj_drop = nn.Dropout(proj_drop) 137 | 138 | trunc_normal_(self.relative_position_bias_table, std=.02) 139 | self.softmax = nn.Softmax(dim=-1) 140 | 141 | def forward(self, x, mask=None, condition=None): 142 | """ Forward function. 
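Aside: the relative-position-bias table defined above depends only on the window size and the number of heads. A quick size check with made-up values:

```
# One row per possible relative offset inside a (Wd, Wh, Ww) window, one column per head.
Wd, Wh, Ww = 2, 4, 8
num_heads = 4

num_relative_positions = (2 * Wd - 1) * (2 * Wh - 1) * (2 * Ww - 1)
print(num_relative_positions)               # 3 * 7 * 15 = 315
print((num_relative_positions, num_heads))  # shape of relative_position_bias_table
```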
143 | Args: 144 | x: input features with shape of (num_windows*B, N, C) 145 | mask: (0/-inf) mask with shape of (num_windows, N, N) or None 146 | """ 147 | B_, N, C = x.shape 148 | if self.cross: 149 | q = self.q(x).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2) 150 | k = self.k(condition).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2) 151 | v = self.v(condition).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2) 152 | else: 153 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 154 | q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C 155 | 156 | q = q * self.scale 157 | attn = q @ k.transpose(-2, -1) 158 | 159 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape( 160 | N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH 161 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww 162 | attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N 163 | 164 | if mask is not None: 165 | nW = mask.shape[0] 166 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 167 | attn = attn.view(-1, self.num_heads, N, N) 168 | attn = self.softmax(attn) 169 | else: 170 | attn = self.softmax(attn) 171 | 172 | attn = self.attn_drop(attn) 173 | 174 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 175 | x = self.proj(x) 176 | x = self.proj_drop(x) 177 | return x 178 | 179 | 180 | class SwinTransformerBlock3D(nn.Module): 181 | """ Swin Transformer Block. 182 | 183 | Args: 184 | dim (int): Number of input channels. 185 | num_heads (int): Number of attention heads. 186 | window_size (tuple[int]): Window size. 187 | shift_size (tuple[int]): Shift size for SW-MSA. 188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 189 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 190 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 191 | drop (float, optional): Dropout rate. Default: 0.0 192 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 193 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 194 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 195 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 196 | """ 197 | 198 | def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0), 199 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 200 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, cross=False, use_checkpoint=False): 201 | super().__init__() 202 | self.dim = dim 203 | self.num_heads = num_heads 204 | self.window_size = window_size 205 | self.shift_size = shift_size 206 | self.mlp_ratio = mlp_ratio 207 | self.use_checkpoint=use_checkpoint 208 | self.cross = cross 209 | 210 | assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" 211 | assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" 212 | assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" 213 | 214 | self.norm1 = norm_layer(dim) 215 | self.attn = WindowAttention3D( 216 | dim, window_size=self.window_size, num_heads=num_heads, 217 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, cross=cross) 218 | 219 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 220 | self.norm2 = norm_layer(dim) 221 | mlp_hidden_dim = int(dim * mlp_ratio) 222 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 223 | 224 | if self.cross: 225 | self.norm_condition = norm_layer(dim) 226 | 227 | def forward_part1(self, x, mask_matrix, condition): 228 | B, D, H, W, C = x.shape 229 | window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) 230 | 231 | x = self.norm1(x) 232 | # pad feature maps to multiples of window size 233 | pad_l = pad_t = pad_d0 = 0 234 | pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] 235 | pad_b = (window_size[1] - H % window_size[1]) % window_size[1] 236 | pad_r = (window_size[2] - W % window_size[2]) % window_size[2] 237 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 238 | _, Dp, Hp, Wp, _ = x.shape 239 | # cyclic shift 240 | if any(i > 0 for i in shift_size): 241 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) 242 | attn_mask = mask_matrix 243 | else: 244 | shifted_x = x 245 | attn_mask = None 246 | # partition windows 247 | x_windows = swin_window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C 248 | condition_windows = condition 249 | 250 | if self.cross: 251 | condition = self.norm_condition(condition) 252 | condition = F.pad(condition, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 253 | if any(i > 0 for i in shift_size): 254 | shifted_condition = torch.roll(condition, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) 255 | else: 256 | shifted_condition = condition 257 | condition_windows = swin_window_partition(shifted_condition, window_size) 258 | 259 | # W-MSA/SW-MSA 260 | attn_windows = self.attn(x_windows, mask=attn_mask, condition=condition_windows) # B*nW, Wd*Wh*Ww, C 261 | # merge windows 262 | attn_windows = attn_windows.view(-1, *(window_size+(C,))) 263 | shifted_x = swin_window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C 264 | # reverse cyclic shift 265 | if any(i > 0 for i in shift_size): 266 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) 267 | else: 268 | x = shifted_x 269 | 270 | if pad_d1 >0 or pad_r > 0 or pad_b > 0: 271 | x = x[:, :D, :H, :W, :].contiguous() 272 | return x 273 | 274 | def forward_part2(self, x): 275 | return self.drop_path(self.mlp(self.norm2(x))) 276 | 277 | def forward(self, x, mask_matrix, condition=None): 278 | """ Forward function. 279 | 280 | Args: 281 | x: Input feature, tensor size (B, D, H, W, C). 282 | mask_matrix: Attention mask for cyclic shift. 
283 | """ 284 | 285 | shortcut = x 286 | if self.use_checkpoint: 287 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, condition) 288 | else: 289 | x = self.forward_part1(x, mask_matrix, condition) 290 | x = shortcut + self.drop_path(x) 291 | 292 | if self.use_checkpoint: 293 | x = x + checkpoint.checkpoint(self.forward_part2, x) 294 | else: 295 | x = x + self.forward_part2(x) 296 | 297 | return x 298 | 299 | 300 | # cache each stage results 301 | @lru_cache() 302 | def compute_mask(D, H, W, window_size, shift_size, device): 303 | img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 304 | cnt = 0 305 | for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None): 306 | for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None): 307 | for w in slice(-window_size[2]), slice(-window_size[2], 0), slice(0,None): 308 | img_mask[:, d, h, w, :] = cnt 309 | cnt += 1 310 | mask_windows = swin_window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 311 | mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] 312 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 313 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 314 | return attn_mask 315 | 316 | 317 | class GatedCrossAttention(nn.Module): 318 | """ A basic Swin Transformer layer for one stage. 319 | 320 | Args: 321 | dim (int): Number of feature channels 322 | depth (int): Depths of this stage. 323 | num_heads (int): Number of attention head. 324 | window_size (tuple[int]): Local window size. Default: (1,7,7). 325 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 326 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 327 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 328 | drop (float, optional): Dropout rate. Default: 0.0 329 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 330 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 331 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 332 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None 333 | """ 334 | 335 | def __init__(self, dim, num_heads, window_size=(1,7,7), mlp_ratio=4., qkv_bias=True, qk_scale=None, 336 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False): 337 | super().__init__() 338 | self.window_size = window_size 339 | self.shift_size = tuple(i // 2 for i in window_size) 340 | self.use_checkpoint = use_checkpoint 341 | 342 | # build blocks 343 | self.back1 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=(0,0,0), 344 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 345 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint) 346 | self.back2 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=self.shift_size, 347 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 348 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint) 349 | self.obser1 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=(0,0,0), 350 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 351 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint) 352 | self.obser2 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=self.shift_size, 353 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 354 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint) 355 | self.gate = nn.Conv3d(dim, dim, 1) 356 | 357 | def forward(self, x1, x2, gate): 358 | """ Forward function. 359 | 360 | Args: 361 | x: Input feature, tensor size (B, C, D, H, W). 362 | """ 363 | # calculate attention mask for SW-MSA 364 | B, C, D, H, W = x1.shape 365 | window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size) 366 | x1 = rearrange(x1, 'b c d h w -> b d h w c') 367 | x2 = rearrange(x2, 'b c d h w -> b d h w c') 368 | gate = rearrange(gate, 'b c d h w -> b d h w c') 369 | Dp = int(np.ceil(D / window_size[0])) * window_size[0] 370 | Hp = int(np.ceil(H / window_size[1])) * window_size[1] 371 | Wp = int(np.ceil(W / window_size[2])) * window_size[2] 372 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x1.device) 373 | 374 | x1_mid = x1 * (1 - gate) + self.back1(x1 * gate, attn_mask, x2) 375 | x2_mid = self.obser1(x2, attn_mask, x1) 376 | x1 = x1_mid * (1 - gate) + self.back2(x1_mid * gate, attn_mask, x2_mid) 377 | x2 = self.obser2(x2_mid, attn_mask, x1_mid) 378 | 379 | x1 = rearrange(x1, 'b d h w c -> b c d h w') 380 | x2 = rearrange(x2, 'b d h w c -> b c d h w') 381 | gate = rearrange(gate, 'b d h w c -> b c d h w') 382 | gate = torch.sigmoid(self.gate(gate)) 383 | return [x1, x2, gate] 384 | 385 | 386 | class SwinTransformerLayer(nn.Module): 387 | """ A basic Swin Transformer layer for one stage. 388 | 389 | Args: 390 | dim (int): Number of feature channels 391 | depth (int): Depths of this stage. 392 | num_heads (int): Number of attention head. 393 | window_size (tuple[int]): Local window size. Default: (1,7,7). 394 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 395 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 396 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 
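Aside: a shape check for `GatedCrossAttention`, which fuses the background branch `x1`, the observation branch `x2`, and the gate, and returns all three with their shapes preserved. The sizes below are made up and the import path is an assumption:

```
# Sketch: GatedCrossAttention preserves the (B, C, D, H, W) shape of all three inputs.
import torch
from modules.transformer import GatedCrossAttention

layer = GatedCrossAttention(dim=32, num_heads=4, window_size=(2, 4, 8))
x1 = torch.randn(1, 32, 4, 8, 16)    # background features
x2 = torch.randn(1, 32, 4, 8, 16)    # observation features
gate = torch.rand(1, 32, 4, 8, 16)   # gate values in (0, 1)

out1, out2, out_gate = layer(x1, x2, gate)
print(out1.shape, out2.shape, out_gate.shape)   # all torch.Size([1, 32, 4, 8, 16])
```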
397 | drop (float, optional): Dropout rate. Default: 0.0 398 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 399 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 400 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 401 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 402 | """ 403 | 404 | def __init__(self, 405 | dim, 406 | depth, 407 | num_heads, 408 | window_size=(1,7,7), 409 | mlp_ratio=4., 410 | qkv_bias=True, 411 | qk_scale=None, 412 | drop=0., 413 | attn_drop=0., 414 | drop_path=0., 415 | norm_layer=nn.LayerNorm, 416 | use_checkpoint=False): 417 | super().__init__() 418 | self.window_size = window_size 419 | self.shift_size = tuple(i // 2 for i in window_size) 420 | self.depth = depth 421 | self.use_checkpoint = use_checkpoint 422 | 423 | # build blocks 424 | self.blocks = nn.ModuleList([ 425 | SwinTransformerBlock3D( 426 | dim=dim, 427 | num_heads=num_heads, 428 | window_size=window_size, 429 | shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size, 430 | mlp_ratio=mlp_ratio, 431 | qkv_bias=qkv_bias, 432 | qk_scale=qk_scale, 433 | drop=drop, 434 | attn_drop=attn_drop, 435 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 436 | norm_layer=norm_layer, 437 | use_checkpoint=use_checkpoint, 438 | ) 439 | for i in range(depth)]) 440 | 441 | def forward(self, x): 442 | """ Forward function. 443 | 444 | Args: 445 | x: Input feature, tensor size (B, C, D, H, W). 446 | """ 447 | # calculate attention mask for SW-MSA 448 | B, C, D, H, W = x.shape 449 | window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size) 450 | x = rearrange(x, 'b c d h w -> b d h w c') 451 | Dp = int(np.ceil(D / window_size[0])) * window_size[0] 452 | Hp = int(np.ceil(H / window_size[1])) * window_size[1] 453 | Wp = int(np.ceil(W / window_size[2])) * window_size[2] 454 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 455 | for blk in self.blocks: 456 | x = blk(x, attn_mask) 457 | x = x.view(B, D, H, W, -1) 458 | x = rearrange(x, 'b d h w c -> b c d h w') 459 | return x 460 | 461 | 462 | class PatchMerging(nn.Module): 463 | """ Patch Merging Layer 464 | 465 | Args: 466 | dim (int): Number of input channels. 467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 468 | """ 469 | def __init__(self, dim, norm_layer=nn.LayerNorm): 470 | super().__init__() 471 | self.dim = dim 472 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 473 | self.norm = norm_layer(4 * dim) 474 | 475 | def forward(self, x): 476 | """ Forward function. 477 | 478 | Args: 479 | x: Input feature, tensor size (B, D, H, W, C). 480 | """ 481 | B, D, H, W, C = x.shape 482 | 483 | # padding 484 | pad_input = (H % 2 == 1) or (W % 2 == 1) 485 | if pad_input: 486 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 487 | 488 | x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C 489 | x1 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C 490 | x2 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C 491 | x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C 492 | x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C 493 | 494 | x = self.norm(x) 495 | x = self.reduction(x) 496 | 497 | return x 498 | 499 | 500 | class AllPatchMerging(nn.Module): 501 | """ Patch Merging Layer 502 | 503 | Args: 504 | dim (int): Number of input channels. 505 | norm_layer (nn.Module, optional): Normalization layer. 
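Aside: `SwinTransformerLayer` stacks `depth` blocks that alternate between non-shifted and shifted windows, and it preserves the `(B, C, D, H, W)` shape of its input. A minimal sketch with made-up sizes:

```
# Sketch: shape check for SwinTransformerLayer (exported in __all__).
import torch
from modules.transformer import SwinTransformerLayer

layer = SwinTransformerLayer(dim=32, depth=2, num_heads=4, window_size=(2, 4, 8))
x = torch.randn(1, 32, 4, 8, 16)
print(layer(x).shape)    # torch.Size([1, 32, 4, 8, 16])
```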
Default: nn.LayerNorm 506 | """ 507 | def __init__(self, dim, norm_layer=nn.LayerNorm): 508 | super().__init__() 509 | self.dim = dim 510 | self.merge1 = PatchMerging(dim, norm_layer) 511 | self.merge2 = PatchMerging(dim, norm_layer) 512 | self.merge3 = PatchMerging(dim, norm_layer) 513 | 514 | def forward(self, x): 515 | """ Forward function. 516 | 517 | Args: 518 | x: Input feature, tensor size (B, C, D, H, W). 519 | """ 520 | x1 = rearrange(x[0], 'b c d h w -> b d h w c') 521 | x2 = rearrange(x[1], 'b c d h w -> b d h w c') 522 | gate = rearrange(x[2], 'b c d h w -> b d h w c') 523 | x1 = self.merge1(x1) 524 | x2 = self.merge2(x2) 525 | gate = torch.sigmoid(self.merge3(gate)) 526 | x1 = rearrange(x1, 'b d h w c -> b c d h w') 527 | x2 = rearrange(x2, 'b d h w c -> b c d h w') 528 | gate = rearrange(gate, 'b d h w c -> b c d h w') 529 | return [x1, x2, gate] 530 | 531 | 532 | class PatchExpand(nn.Module): 533 | """ Patch Merging Layer 534 | 535 | Args: 536 | dim (int): Number of input channels. 537 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 538 | """ 539 | def __init__(self, dim, norm_layer=nn.LayerNorm): 540 | super().__init__() 541 | self.dim = dim 542 | self.expansion = nn.Linear(dim, 2 * dim, bias=False) 543 | self.norm = norm_layer(dim) 544 | 545 | def forward(self, x): 546 | """ Forward function. 547 | 548 | Args: 549 | x: Input feature, tensor size (B, D, H, W, C). 550 | """ 551 | B, D, H, W, C = x.shape 552 | 553 | # padding 554 | pad_input = (H % 2 == 1) or (W % 2 == 1) 555 | if pad_input: 556 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 557 | 558 | x = self.norm(x) 559 | x = self.expansion(x) 560 | 561 | x = x.reshape(B, D, H, W, 2, 2, C//2) 562 | x = rearrange(x, 'b d h w h1 w1 c -> b d h h1 w w1 c') 563 | x = x.reshape(B, D, 2*H, 2*W, C//2) 564 | 565 | return x 566 | 567 | 568 | class AllPatchExpand(nn.Module): 569 | """ Patch Merging Layer 570 | 571 | Args: 572 | dim (int): Number of input channels. 573 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 574 | """ 575 | def __init__(self, dim, norm_layer=nn.LayerNorm): 576 | super().__init__() 577 | self.dim = dim 578 | self.expand1 = PatchExpand(dim, norm_layer) 579 | self.expand2 = PatchExpand(dim, norm_layer) 580 | self.expand3 = PatchExpand(dim, norm_layer) 581 | 582 | def forward(self, x): 583 | """ Forward function. 584 | 585 | Args: 586 | x: Input feature, tensor size (B, C, D, H, W). 
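Aside: `PatchMerging` halves H and W while doubling the channel dimension, and `PatchExpand` reverses that. A shape round trip with made-up sizes (imports are assumptions):

```
# Sketch: PatchMerging and PatchExpand are shape inverses of each other.
import torch
from modules.transformer import PatchMerging, PatchExpand

x = torch.randn(1, 14, 32, 64, 96)    # (B, D, H, W, C)

merged = PatchMerging(dim=96)(x)
print(merged.shape)                    # torch.Size([1, 14, 16, 32, 192])

expanded = PatchExpand(dim=192)(merged)
print(expanded.shape)                  # torch.Size([1, 14, 32, 64, 96])
```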
587 | """ 588 | x1 = rearrange(x[0], 'b c d h w -> b d h w c') 589 | x2 = rearrange(x[1], 'b c d h w -> b d h w c') 590 | gate = rearrange(x[2], 'b c d h w -> b d h w c') 591 | x1 = self.expand1(x1) 592 | x2 = self.expand2(x2) 593 | gate = torch.sigmoid(self.expand3(gate)) 594 | x1 = rearrange(x1, 'b d h w c -> b c d h w') 595 | x2 = rearrange(x2, 'b d h w c -> b c d h w') 596 | gate = rearrange(gate, 'b d h w c -> b c d h w') 597 | return [x1, x2, gate] 598 | 599 | 600 | class PatchEmbed(nn.Module): 601 | 602 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4), norm_layer=None): 603 | super().__init__() 604 | 605 | if img_size[1] % 2: 606 | self.proj3d = nn.Conv3d(5, embed_dim, kernel_size=(patch_size[0], patch_size[1]+1, patch_size[2]), stride=patch_size) 607 | self.proj2d = nn.Conv2d(4, embed_dim, kernel_size=(patch_size[1]+1, patch_size[2]), stride=(patch_size[1], patch_size[2])) 608 | else: 609 | self.proj3d = nn.Conv3d(5, embed_dim, kernel_size=patch_size, stride=patch_size) 610 | self.proj2d = nn.Conv2d(4, embed_dim, kernel_size=patch_size[1:], stride=patch_size[1:]) 611 | 612 | self.embed_dim = embed_dim 613 | if norm_layer is not None: 614 | self.norm = norm_layer(embed_dim) 615 | else: 616 | self.norm = None 617 | 618 | def forward(self, x): 619 | """Forward function.""" 620 | 621 | B, C, H, W = x.shape 622 | x2d = x[:,:4,:,:] # b,4,721,1440 623 | x3d = x[:,4:,:,:].reshape(B, 5, C//5, H, W) # b,5,13,721,1440 624 | x2d = self.proj2d(x2d).unsqueeze(2) # b,c,1,180,360 625 | x3d = self.proj3d(x3d) # b,c,13,180,360 626 | x = torch.cat([x3d, x2d], dim=2) # b,c,14,180,360 627 | 628 | if self.norm is not None: 629 | D, H, W = x.size(2), x.size(3), x.size(4) 630 | x = x.flatten(2).transpose(1, 2) 631 | x = self.norm(x) 632 | x = x.transpose(1, 2).view(B, self.embed_dim, D, H, W) 633 | 634 | return x 635 | 636 | 637 | class AllPatchEmbed(nn.Module): 638 | 639 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4), norm_layer=None): 640 | super().__init__() 641 | 642 | self.patch1 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer) 643 | self.patch2 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer) 644 | self.patch3 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer) 645 | 646 | self.patch_resolution = (img_size[0]//5+1, img_size[1]//patch_size[1], img_size[2]//patch_size[2]) 647 | 648 | def forward(self, x1, x2, mask): 649 | """Forward function.""" 650 | 651 | x1 = self.patch1(x1) 652 | x2 = self.patch2(x2) 653 | mask = torch.sigmoid(self.patch3(mask)) 654 | 655 | return [x1, x2, mask] 656 | 657 | 658 | class PatchRecover(nn.Module): 659 | 660 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4)): 661 | super().__init__() 662 | 663 | if img_size[1] % 2: 664 | self.proj3d = nn.ConvTranspose3d(embed_dim, 5, kernel_size=(patch_size[0], patch_size[1]+1, patch_size[2]), stride=patch_size) 665 | self.proj2d = nn.ConvTranspose2d(embed_dim, 4, kernel_size=(patch_size[1]+1, patch_size[2]), stride=(patch_size[1], patch_size[2])) 666 | else: 667 | self.proj3d = nn.ConvTranspose3d(embed_dim, 5, kernel_size=patch_size, stride=patch_size) 668 | self.proj2d = nn.ConvTranspose2d(embed_dim, 4, kernel_size=patch_size[1:], stride=patch_size[1:]) 669 | 670 | def forward(self, x): 671 | """Forward function.""" 672 | 673 | x2d = x[:,:,-1:,:,:].squeeze(2) # b,c,180,360 674 | x3d = x[:,:,:-1,:,:] # 
b,c,13,180,360 675 | x2d = self.proj2d(x2d) # b,4,721,1440 676 | x3d = self.proj3d(x3d).flatten(1,2) # b,65,721,1440 677 | x = torch.cat([x2d, x3d], dim=1) 678 | 679 | return x 680 | 681 | 682 | class BasicLayer(nn.Module): 683 | 684 | def __init__(self, dim, kernel, padding, num_heads, window_size, sample=None, use_checkpoint=False): 685 | super().__init__() 686 | self.sample = sample 687 | inchans = dim * 2 if self.sample == 'up' else dim 688 | self.conv = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding) 689 | self.gateconv = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding) 690 | self.gate = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding) 691 | self.crosslayer = GatedCrossAttention(dim=dim, num_heads=num_heads, window_size=window_size, qkv_bias=True, use_checkpoint=use_checkpoint) 692 | if self.sample == 'down': 693 | self.samplelayer = AllPatchMerging(dim//2) 694 | elif self.sample == 'up': 695 | self.samplelayer = AllPatchExpand(dim*2) 696 | 697 | def concate(self, x, prev): 698 | 699 | x1 = torch.cat([x[0], prev[0]], dim=1) 700 | x2 = torch.cat([x[1], prev[1]], dim=1) 701 | gate = torch.cat([x[2], prev[2]], dim=1) 702 | 703 | return x1, x2, gate 704 | 705 | def forward(self, x, prev=None): 706 | if self.sample is not None: 707 | x = self.samplelayer(x) 708 | x1, x2, gate = x[0], x[1], x[2] 709 | if prev is not None: 710 | x1, x2, gate = self.concate(x, prev) 711 | 712 | x1 = F.silu(self.conv(x1)) 713 | gate = torch.sigmoid(self.gate(gate)) 714 | x2 = F.silu(self.gateconv(x2)) 715 | x2 = x2 * gate 716 | out = self.crosslayer(x1, x2, gate) 717 | out = [out[0]+x[0], out[1]+x[1], out[2]+x[2]] 718 | return out 719 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from utils.builder import ConfigBuilder 5 | import utils.misc as utils 6 | import yaml 7 | from utils.logger import get_logger 8 | 9 | 10 | 11 | def subprocess_fn(args): 12 | utils.setup_seed(args.seed * args.world_size + args.rank) 13 | 14 | logger = get_logger("train", args.rundir, utils.get_rank(), filename='iter.log') 15 | args.cfg_params["logger"] = logger 16 | 17 | # build config 18 | logger.info('Building config ...') 19 | builder = ConfigBuilder(**args.cfg_params) 20 | 21 | # build model 22 | logger.info('Building models ...') 23 | model = builder.get_model() 24 | model.kernel = utils.DistributedParallel_Model(model.kernel, args.local_rank) 25 | 26 | # build forecast model 27 | logger.info('Building forecast models ...') 28 | args.forecast_model = builder.get_forecast(args.local_rank) 29 | 30 | # build dataset 31 | logger.info('Building dataloaders ...') 32 | dataset_params = args.cfg_params['dataset'] 33 | train_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='train', batch_size=args.batch_size) 34 | valid_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='valid', batch_size=args.batch_size) 35 | # logger.info(f'dataloader length {len(train_dataloader), len(valid_dataloader)}') 36 | 37 | # train 38 | logger.info('begin training ...') 39 | model.train(train_dataloader, valid_dataloader, logger, args) 40 | logger.info('training end ...') 41 | 42 | 43 | def main(args): 44 | if args.world_size > 1: 45 | utils.init_distributed_mode(args) 46 | else: 47 | args.rank = 0 48 | args.local_rank = 0 49 | args.distributed = False 50 | args.gpu = 0 51 | 
torch.cuda.set_device(args.gpu) 52 | 53 | args.cfg = os.path.join(args.rundir, 'training_options.yaml') 54 | with open(args.cfg, 'r') as cfg_file: 55 | cfg_params = yaml.load(cfg_file, Loader = yaml.FullLoader) 56 | 57 | cfg_params['dataloader']['num_workers'] = args.per_cpus 58 | cfg_params['dataset']['train']['length'] = args.lead_time // 6 + 2 59 | cfg_params['dataset']['valid']['length'] = args.lead_time // 6 + 2 60 | args.cfg_params = cfg_params 61 | 62 | args.rundir = os.path.join(args.rundir, f'mask{args.ratio}_lead{args.lead_time}h_res{args.resolution}') 63 | os.makedirs(args.rundir, exist_ok=True) 64 | 65 | if args.rank == 0: 66 | with open(os.path.join(args.rundir, 'train.yaml'), 'wt') as f: 67 | yaml.dump(vars(args), f, indent=2, sort_keys=False) 68 | # yaml.dump(cfg_params, f, indent=2, sort_keys=False) 69 | 70 | subprocess_fn(args) 71 | 72 | 73 | if __name__ == "__main__": 74 | 75 | parser = argparse.ArgumentParser() 76 | 77 | parser.add_argument('--seed', type = int, default = 0, help = 'seed') 78 | parser.add_argument('--cuda', type = int, default = 0, help = 'cuda id') 79 | parser.add_argument('--world_size', type = int, default = 4, help = 'number of progress') 80 | parser.add_argument('--per_cpus', type = int, default = 4, help = 'number of perCPUs to use') 81 | parser.add_argument('--max_epoch', type = int, default = 20, help = "maximum training epochs") 82 | parser.add_argument('--batch_size', type = int, default = 1, help = "batch size") 83 | parser.add_argument('--lead_time', type = int, default = 24, help = "lead time (h) for background") 84 | parser.add_argument('--ratio', type = float, default = 0.9, help = "mask ratio") 85 | parser.add_argument('--resolution', type = int, default = 128, help = "observation resolution") 86 | parser.add_argument('--init_method', type = str, default = 'tcp://127.0.0.1:19111', help = 'multi process init method') 87 | parser.add_argument('--rundir', type = str, default = './configs/FNP', help = 'where to save the results') 88 | 89 | args = parser.parse_args() 90 | 91 | main(args) 92 | 93 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/utils/__init__.py -------------------------------------------------------------------------------- /utils/builder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.distributed import DistributedSampler 2 | from utils.misc import get_rank, get_world_size, is_dist_avail_and_initialized 3 | import onnxruntime as ort 4 | 5 | 6 | class ConfigBuilder(object): 7 | """ 8 | Configuration Builder. 9 | 10 | """ 11 | def __init__(self, **params): 12 | """ 13 | Set the default configuration for the configuration builder. 14 | 15 | Parameters 16 | ---------- 17 | 18 | params: the configuration parameters. 19 | """ 20 | super(ConfigBuilder, self).__init__() 21 | self.model_params = params.get('model', {}) 22 | self.dataset_params = params.get('dataset', {'data_dir': 'data'}) 23 | self.dataloader_params = params.get('dataloader', {}) 24 | 25 | self.logger = params.get('logger', None) 26 | 27 | def get_model(self, model_params = None): 28 | """ 29 | Get the model from configuration. 30 | 31 | Parameters 32 | ---------- 33 | 34 | model_params: dict, optional, default: None. 
If model_params is provided, then use the parameters specified in the model_params to build the model. Otherwise, the model parameters in the self.params will be used to build the model. 35 | """ 36 | from models.FNP import FNP 37 | from models.ConvCNP import ConvCNP 38 | from models.Adas import Adas 39 | 40 | if model_params is None: 41 | model_params = self.model_params 42 | type = model_params.get('type', 'FNP') 43 | 44 | if type == 'FNP': 45 | model = FNP(**model_params) 46 | elif type == 'ConvCNP': 47 | model = ConvCNP(**model_params) 48 | elif type == 'Adas': 49 | model = Adas(**model_params) 50 | else: 51 | raise NotImplementedError('Invalid model type.') 52 | 53 | return model 54 | 55 | def get_forecast(self, local_rank): 56 | 57 | # Set the behavier of onnxruntime 58 | options = ort.SessionOptions() 59 | options.enable_cpu_mem_arena=False 60 | options.enable_mem_pattern = False 61 | options.enable_mem_reuse = False 62 | # Increase the number for faster inference and more memory consumption 63 | options.intra_op_num_threads = 1 64 | 65 | # Set the behavier of cuda provider 66 | cuda_provider_options = {'device_id': local_rank, 'arena_extend_strategy':'kSameAsRequested',} 67 | 68 | # Initialize onnxruntime session for Pangu-Weather Models 69 | ort_session = ort.InferenceSession('./models/FengWu.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)]) 70 | 71 | return ort_session 72 | 73 | def get_dataset(self, dataset_params = None, split = 'train'): 74 | """ 75 | Get the dataset from configuration. 76 | 77 | Parameters 78 | ---------- 79 | 80 | dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset; 81 | 82 | split: str in ['train', 'test'], optional, default: 'train', the splitted dataset. 83 | 84 | Returns 85 | ------- 86 | 87 | A torch.utils.data.Dataset item. 88 | """ 89 | from datasets.era5_npy_f32 import era5_npy_f32 90 | if dataset_params is None: 91 | dataset_params = self.dataset_params 92 | dataset_params = dataset_params.get(split, None) 93 | if dataset_params is None: 94 | return None 95 | dataset = era5_npy_f32(split = split, **dataset_params) 96 | return dataset 97 | 98 | def get_sampler(self, dataset, split = 'train', drop_last=False): 99 | if split == 'train': 100 | shuffle = True 101 | else: 102 | shuffle = False 103 | 104 | if is_dist_avail_and_initialized(): 105 | rank = get_rank() 106 | num_gpus = get_world_size() 107 | else: 108 | rank = 0 109 | num_gpus = 1 110 | sampler = DistributedSampler(dataset, rank=rank, shuffle=shuffle, num_replicas=num_gpus, drop_last=drop_last) 111 | 112 | return sampler 113 | 114 | 115 | def get_dataloader(self, dataset_params = None, split = 'train', batch_size = 1, dataloader_params = None, drop_last = True): 116 | """ 117 | Get the dataloader from configuration. 118 | 119 | Parameters 120 | ---------- 121 | 122 | dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset; 123 | 124 | split: str in ['train', 'test'], optional, default: 'train', the splitted dataset; 125 | 126 | batch_size: int, optional, default: None. 
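Aside: a condensed, illustrative sketch of how `train.py` drives `ConfigBuilder`. The config path and numeric values are placeholders, and it assumes the ERA5 data is already in place under `./data`:

```
# Sketch: building the model and dataloader from a training_options.yaml, as train.py does.
import yaml
from utils.builder import ConfigBuilder

with open("./configs/FNP/training_options.yaml") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

# train.py derives the sample length from the background lead time: lead_time // 6 + 2
cfg["dataset"]["train"]["length"] = 24 // 6 + 2
cfg["dataset"]["valid"]["length"] = 24 // 6 + 2

builder = ConfigBuilder(**cfg)
model = builder.get_model()    # FNP / ConvCNP / Adas, chosen by cfg["model"]["type"]
train_loader = builder.get_dataloader(dataset_params=cfg["dataset"], split="train", batch_size=1)
```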
If batch_size is None, then the batch size parameter in the self.params will be used to represent the batch size (If still not specified, default: 4); 127 | 128 | dataloader_params: dict, optional, default: None. If dataloader_params is provided, then use the parameters specified in the dataloader_params to get the dataloader. Otherwise, the dataloader parameters in the self.params will be used to get the dataloader. 129 | 130 | Returns 131 | ------- 132 | 133 | A torch.utils.data.DataLoader item. 134 | """ 135 | from torch.utils.data import DataLoader 136 | 137 | # if split != "train": 138 | # drop_last = True 139 | if dataloader_params is None: 140 | dataloader_params = self.dataloader_params 141 | dataset = self.get_dataset(dataset_params, split) 142 | if dataset is None: 143 | return None 144 | sampler = self.get_sampler(dataset, split, drop_last=drop_last) 145 | 146 | return DataLoader( 147 | dataset, 148 | batch_size = batch_size, 149 | sampler=sampler, 150 | drop_last=drop_last, 151 | **dataloader_params 152 | ) 153 | 154 | 155 | def get_optimizer(model, optimizer_params = None, resume = False, resume_lr = None): 156 | """ 157 | Get the optimizer from configuration. 158 | 159 | Parameters 160 | ---------- 161 | 162 | model: a torch.nn.Module object, the model. 163 | 164 | optimizer_params: dict, optional, default: None. If optimizer_params is provided, then use the parameters specified in the optimizer_params to build the optimizer. Otherwise, the optimizer parameters in the self.params will be used to build the optimizer; 165 | 166 | resume: bool, optional, default: False, whether to resume training from an existing checkpoint; 167 | 168 | resume_lr: float, optional, default: None, the resume learning rate. 169 | 170 | Returns 171 | ------- 172 | 173 | An optimizer for the given model. 174 | """ 175 | from torch.optim import SGD, Adam, AdamW 176 | type = optimizer_params.get('type', 'AdamW') 177 | params = optimizer_params.get('params', {}) 178 | 179 | if resume: 180 | network_params = [{'params': model.parameters(), 'initial_lr': resume_lr}] 181 | params.update(lr = resume_lr) 182 | else: 183 | network_params = model.parameters() 184 | if type == 'SGD': 185 | optimizer = SGD(network_params, **params) 186 | elif type == 'Adam': 187 | optimizer = Adam(network_params, **params) 188 | elif type == 'AdamW': 189 | optimizer = AdamW(network_params, **params) 190 | else: 191 | raise NotImplementedError('Invalid optimizer type.') 192 | return optimizer 193 | 194 | def get_lr_scheduler(optimizer, lr_scheduler_params = None, resume = False, resume_epoch = None, total_steps = None): 195 | """ 196 | Get the learning rate scheduler from configuration. 197 | 198 | Parameters 199 | ---------- 200 | 201 | optimizer: an optimizer; 202 | 203 | lr_scheduler_params: dict, optional, default: None. If lr_scheduler_params is provided, then use the parameters specified in the lr_scheduler_params to build the learning rate scheduler. Otherwise, the learning rate scheduler parameters in the self.params will be used to build the learning rate scheduler; 204 | 205 | resume: bool, optional, default: False, whether to resume training from an existing checkpoint; 206 | 207 | resume_epoch: int, optional, default: None, the epoch of the checkpoint. 208 | 209 | Returns 210 | ------- 211 | 212 | A learning rate scheduler for the given optimizer. 
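Aside: `get_optimizer` and `get_lr_scheduler` are used together, the scheduler wrapping the optimizer that was just built. A short sketch with placeholder hyperparameters and a stand-in model:

```
# Sketch: pairing get_optimizer with get_lr_scheduler (placeholder values).
import torch.nn as nn
from utils.builder import get_optimizer, get_lr_scheduler

model = nn.Linear(8, 8)   # stand-in for the assimilation model
optimizer_params = {"type": "AdamW", "params": {"lr": 1e-4, "weight_decay": 0.01}}
lr_scheduler_params = {"type": "OneCycleLR", "params": {"max_lr": 1e-4, "pct_start": 0.1}}

optimizer = get_optimizer(model, optimizer_params)
scheduler = get_lr_scheduler(optimizer, lr_scheduler_params, total_steps=1000)

for _ in range(1000):
    optimizer.step()      # gradients would come from the training loss in practice
    scheduler.step()
```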
213 | """ 214 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CyclicLR, CosineAnnealingLR, StepLR, OneCycleLR 215 | type = lr_scheduler_params.get('type', '') 216 | params = lr_scheduler_params.get('params', {}) 217 | if resume: 218 | params.update(last_epoch = resume_epoch) 219 | if type == 'MultiStepLR': 220 | scheduler = MultiStepLR(optimizer, **params) 221 | elif type == 'ExponentialLR': 222 | scheduler = ExponentialLR(optimizer, **params) 223 | elif type == 'CyclicLR': 224 | scheduler = CyclicLR(optimizer, **params) 225 | elif type == 'CosineAnnealingLR': 226 | scheduler = CosineAnnealingLR(optimizer, **params) 227 | elif type == 'StepLR': 228 | scheduler = StepLR(optimizer, **params) 229 | elif type == 'OneCycleLR': 230 | scheduler = OneCycleLR(optimizer, total_steps=total_steps, **params) 231 | elif type == '': 232 | scheduler = None 233 | else: 234 | raise NotImplementedError('Invalid learning rate scheduler type.') 235 | return scheduler 236 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | 5 | logger_initialized = {} 6 | 7 | def get_logger(name, save_dir, distributed_rank, filename="log.log", resume=False): 8 | logger = logging.getLogger(name) 9 | if name in logger_initialized: 10 | return logger 11 | 12 | logger.propagate = False 13 | # don't log results for the non-master process 14 | if distributed_rank > 0: 15 | logger.setLevel(logging.ERROR) 16 | logger.setLevel(logging.WARNING) 17 | return logger 18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 19 | 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.INFO) 22 | ch.setFormatter(formatter) 23 | logger.addHandler(ch) 24 | 25 | if save_dir: 26 | if resume: 27 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='a') 28 | else: 29 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='w') 30 | fh.setLevel(logging.INFO) 31 | fh.setFormatter(formatter) 32 | logger.addHandler(fh) 33 | 34 | logger.setLevel(logging.INFO) 35 | 36 | logger_initialized[name] = True 37 | 38 | return logger 39 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor: 6 | return 90. - j * 180./float(num_lat-1) 7 | 8 | @torch.jit.script 9 | def latitude_weighting_factor_torch(j: torch.Tensor, num_lat: int, s: torch.Tensor) -> torch.Tensor: 10 | return num_lat * torch.cos(torch.pi/180. * lat(j, num_lat)) / s 11 | 12 | @torch.jit.script 13 | def weighted_rmse_torch_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 14 | #takes in arrays of size [n, c, h, w] and returns latitude-weighted rmse for each channel 15 | num_lat = pred.shape[2] 16 | #num_long = target.shape[2] 17 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device) 18 | 19 | s = torch.sum(torch.cos(torch.pi/180.
* lat(lat_t, num_lat))) 20 | weight = torch.reshape(latitude_weighting_factor_torch(lat_t, num_lat, s), (1, 1, -1, 1)) 21 | result = torch.sqrt(torch.mean(weight * (pred - target)**2., dim=(-1,-2))) 22 | return result 23 | 24 | @torch.jit.script 25 | def weighted_rmse_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 26 | result = weighted_rmse_torch_channels(pred, target) 27 | return torch.mean(result, dim=0) 28 | 29 | def WRMSE(pred, gt, data_std): 30 | return weighted_rmse_torch(pred, gt) * data_std 31 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import torch 3 | import os 4 | import numpy as np 5 | import random 6 | from typing import Any 7 | import re 8 | 9 | 10 | def reduce_dict(input_dict, average=True): 11 | """ 12 | Args: 13 | input_dict (dict): all the values will be reduced 14 | average (bool): whether to do average or sum 15 | Reduce the values in the dictionary from all processes so that all processes 16 | have the averaged results. Returns a dict with the same fields as 17 | input_dict, after reduction. 18 | """ 19 | world_size = get_world_size() 20 | if world_size < 2: 21 | return input_dict 22 | with torch.no_grad(): 23 | names = [] 24 | values = [] 25 | # sort the keys so that they are consistent across processes 26 | for k in sorted(input_dict.keys()): 27 | names.append(k) 28 | values.append(input_dict[k]) 29 | values = torch.stack(values, dim=0) 30 | dist.all_reduce(values) 31 | if average: 32 | values /= world_size 33 | reduced_dict = {k: v for k, v in zip(names, values)} 34 | return reduced_dict 35 | 36 | 37 | def setup_for_distributed(is_master): 38 | """ 39 | This function disables printing when not in master process 40 | """ 41 | import builtins as __builtin__ 42 | builtin_print = __builtin__.print 43 | 44 | def print(*args, **kwargs): 45 | force = kwargs.pop('force', False) 46 | if is_master or force: 47 | builtin_print(*args, **kwargs) 48 | 49 | __builtin__.print = print 50 | 51 | 52 | def is_dist_avail_and_initialized(): 53 | if not dist.is_available(): 54 | return False 55 | if not dist.is_initialized(): 56 | return False 57 | return True 58 | 59 | 60 | def get_world_size(): 61 | if not is_dist_avail_and_initialized(): 62 | return 1 63 | return dist.get_world_size() 64 | 65 | 66 | def get_rank(): 67 | if not is_dist_avail_and_initialized(): 68 | return 0 69 | return dist.get_rank() 70 | 71 | 72 | def is_main_process(): 73 | return get_rank() == 0 74 | 75 | 76 | def save_on_master(*args, **kwargs): 77 | if is_main_process(): 78 | torch.save(*args, **kwargs) 79 | 80 | 81 | def get_ip(ip_list): 82 | if "," in ip_list: 83 | ip_list = ip_list.split(',')[0] 84 | if "[" in ip_list: 85 | ipbefore_4, ip4 = ip_list.split('[') 86 | ip4 = re.findall(r"\d+", ip4)[0] 87 | ip1, ip2, ip3 = ipbefore_4.split('-')[-4:-1] 88 | else: 89 | ip1,ip2,ip3,ip4 = ip_list.split('-')[-4:] 90 | ip_addr = "tcp://" + ".".join([ip1, ip2, ip3, ip4]) + ":" 91 | return ip_addr 92 | 93 | 94 | 95 | def init_distributed_mode(args): 96 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 97 | args.rank = int(os.environ["RANK"]) 98 | args.world_size = int(os.environ['WORLD_SIZE']) 99 | args.local_rank = int(os.environ['LOCAL_RANK']) 100 | elif 'SLURM_PROCID' in os.environ: 101 | args.rank = int(os.environ['SLURM_PROCID']) 102 | args.local_rank = int(os.environ['SLURM_LOCALID']) 103 | args.world_size = 
int(os.environ['SLURM_NTASKS']) 104 | ip_addr = get_ip(os.environ['SLURM_STEP_NODELIST']) 105 | port = int(os.environ['SLURM_SRUN_COMM_PORT']) 106 | # args.init_method = ip_addr + str(port) 107 | args.init_method = ip_addr + args.init_method.split(":")[-1] 108 | else: 109 | print('Not using distributed mode') 110 | args.distributed = False 111 | return 112 | 113 | args.distributed = True 114 | 115 | torch.cuda.set_device(args.local_rank) 116 | args.dist_backend = 'nccl' 117 | print('| distributed init (rank {}, local_rank {}): {}'.format( 118 | args.rank, args.local_rank, args.init_method), flush=True) 119 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.init_method, 120 | world_size=args.world_size, rank=args.rank) 121 | torch.distributed.barrier() 122 | setup_for_distributed(args.rank == 0) 123 | 124 | 125 | def DistributedParallel_Model(model, gpu_num, find_unused_parameters=False): 126 | if is_dist_avail_and_initialized(): 127 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 128 | if device == torch.device('cpu'): 129 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.') 130 | model.to(device) 131 | # optimizer = torch.optim.Adam(model.parameters(), lr=lr) 132 | # model, optimizer = amp.initialize(model, optimizer, opt_level="O0") 133 | ddp_sub_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_num], find_unused_parameters=find_unused_parameters) 134 | model = ddp_sub_model 135 | 136 | # model.to(device) 137 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_num]) 138 | # model_without_ddp = model.module 139 | else: 140 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 141 | model.to(device) 142 | # optimizer = torch.optim.Adam(model.parameters(), lr=lr) 143 | # model, optimizer = amp.initialize(model, optimizer, opt_level="O0") 144 | # for key in model.model: 145 | # model.model[key].to(device) 146 | 147 | return model 148 | 149 | 150 | class Dict(dict): 151 | def __getattr__(self, name: str) -> Any: 152 | try: 153 | return self[name] 154 | except KeyError: 155 | raise AttributeError(name) 156 | 157 | def __setattr__(self, name: str, value: Any) -> None: 158 | self[name] = value 159 | 160 | def __delattr__(self, name: str) -> None: 161 | del self[name] 162 | # __setattr__ = dict.__setitem__ 163 | # __getattr__ = dict.__getitem__ 164 | 165 | def dictToObj(dictObj): 166 | if not isinstance(dictObj, dict): 167 | return dictObj 168 | d = Dict() 169 | for k, v in dictObj.items(): 170 | d[k] = dictToObj(v) 171 | return d 172 | 173 | 174 | def setup_seed(seed): 175 | torch.manual_seed(seed) 176 | torch.cuda.manual_seed_all(seed) 177 | np.random.seed(seed) 178 | random.seed(seed) 179 | # torch.backends.cudnn.deterministic = True 180 | 181 | 182 | def named_params_and_buffers(module): 183 | assert isinstance(module, torch.nn.Module) 184 | return list(module.named_parameters()) + list(module.named_buffers()) 185 | 186 | def check_ddp_consistency(module, ignore_regex=None): 187 | assert isinstance(module, torch.nn.Module) 188 | for name, tensor in named_params_and_buffers(module): 189 | fullname = type(module).__name__ + '.' 
+ name 190 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 191 | continue 192 | tensor = tensor.detach() 193 | # if tensor.is_floating_point(): 194 | # tensor = nan_to_num(tensor) 195 | other = tensor.clone() 196 | torch.distributed.broadcast(tensor=other, src=0) 197 | # print(fullname, tensor.sum(), other.sum()) 198 | assert (tensor == other).all(), fullname 199 | --------------------------------------------------------------------------------
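
Usage sketch (not a file in the repository): the snippet below shows one way the helpers listed above can be wired together, assuming `get_optimizer` and `get_lr_scheduler` are importable as module-level functions from `utils/builder.py`, as their signatures suggest. The toy `Conv2d` stand-in model, the hyperparameter values, and `total_steps=1000` are illustrative placeholders; the actual pipeline builds the model, dataset, and dataloader through the builder methods above and lives in `train.py`.

```
# Minimal sketch of wiring get_optimizer / get_lr_scheduler with the misc helpers.
# The Conv2d stand-in, the hyperparameters, and total_steps=1000 are placeholders.
import torch

from utils.builder import get_optimizer, get_lr_scheduler
from utils.misc import setup_seed, dictToObj, is_main_process

setup_seed(0)  # seeds torch / numpy / random for reproducibility
model = torch.nn.Conv2d(69, 69, kernel_size=3, padding=1)  # stand-in for a data assimilation model

# dictToObj turns plain dicts into attribute-accessible Dict objects, mirroring the YAML configs
opt_cfg = dictToObj({'type': 'AdamW', 'params': {'lr': 1e-4, 'weight_decay': 0.01}})
sched_cfg = dictToObj({'type': 'OneCycleLR', 'params': {'max_lr': 1e-4, 'pct_start': 0.1}})

optimizer = get_optimizer(model, opt_cfg)
scheduler = get_lr_scheduler(optimizer, sched_cfg, total_steps=1000)  # OneCycleLR needs total_steps

for step in range(10):  # dummy optimization loop on random data shaped like 69-channel 128x256 fields
    loss = model(torch.randn(1, 69, 128, 256)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

if is_main_process():  # returns True when torch.distributed is not initialized
    print('learning rate after 10 steps:', scheduler.get_last_lr())
```

In the full training pipeline, the model returned by the builder would additionally be wrapped with `DistributedParallel_Model` before the optimizer is created, so that gradients are synchronized across GPUs when distributed training is initialized.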