├── README.md
├── configs
│   ├── Adas
│   │   └── training_options.yaml
│   ├── ConvCNP
│   │   └── training_options.yaml
│   └── FNP
│       └── training_options.yaml
├── data
│   └── .gitkeep
├── datasets
│   ├── __init__.py
│   ├── era5_npy_f32.py
│   ├── mean_std.json
│   └── mean_std_single.json
├── inference.py
├── models
│   ├── Adas.py
│   ├── ConvCNP.py
│   └── FNP.py
├── modules
│   ├── __init__.py
│   ├── cnn.py
│   ├── encoders.py
│   ├── helpers.py
│   ├── initialization.py
│   ├── losses.py
│   ├── mlp.py
│   └── transformer.py
├── train.py
└── utils
    ├── __init__.py
    ├── builder.py
    ├── logger.py
    ├── metrics.py
    └── misc.py
/README.md:
--------------------------------------------------------------------------------
1 | # FNP: Fourier Neural Processes for Arbitrary-Resolution Data Assimilation
2 |
3 | This repo contains the official PyTorch implementation of FNP. The paper was accepted at NeurIPS 2024.
4 |
5 | ## Codebase Structure
6 |
7 | - `configs` contains all the experiment configurations.
8 | - `Adas` contains the configuration to train the Adas model.
9 | - `ConvCNP` contains the configuration to train the ConvCNP model.
10 | - `FNP` contains the configuration to train the FNP model.
11 | - `data` contains the ERA5 data.
12 | - `datasets` contains the dataset and the mean and standard deviation values of ERA5 data.
13 | - `models` contains the data assimilation models and the forecast model FengWu (ONNX version).
14 | - `modules` contains the basic modules used for all the data assimilation models.
15 | - `utils` contains utility modules (config builder, logger, metrics, and miscellaneous helpers).
16 | - `train.py` and `inference.py` provide training and testing pipelines.
17 |
18 | We provide the ONNX version of the FengWu forecast model at 128×256 resolution for generating the background forecasts. The ERA5 data can be downloaded from the official Climate Data Store website.
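
The forecast model is consumed through what appears to be an ONNX Runtime session (the models call `forecast_model.run(None, {'input': inp_data})`). As a minimal sketch of loading FengWu and rolling one 6-hour step, assuming `onnxruntime` and the file placed under `models/` (both assumptions, not prescribed by the repo):

```python
import numpy as np
import onnxruntime as ort

# Assumed location of the downloaded forecast model.
session = ort.InferenceSession("./models/FengWu.onnx",
                               providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

# process_data() in models/*.py feeds two consecutive frames stacked along channels
# (69 variables each: 4 single-level + 5 pressure-level variables x 13 levels).
frames = np.random.randn(1, 2 * 69, 128, 256).astype(np.float32)
step = session.run(None, {"input": frames})[0][:, :69]  # keep the first 69 channels: the 6-hour forecast
```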
19 |
20 | ## Setup
21 |
22 | First, download and set up the repo:
23 |
24 | ```
25 | git clone https://github.com/OpenEarthLab/FNP.git
26 | cd FNP
27 | ```
28 |
29 | Then, download the ERA5 data and the forecast model `FengWu.onnx`, and place them in the corresponding locations according to the codebase structure above.
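
The README does not spell out the on-disk layout, but `datasets/era5_npy_f32.py` builds paths of the form `{data_dir}/{year}/{date}/{time}-{variable}-{level}.0.npy` for pressure-level fields and `{data_dir}/single/{year}/{date}/{time}-{variable}.npy` for single-level fields, so a sample under `data/` would look roughly like this (illustrative, not authoritative):

```
data/
├── 1979/
│   └── 1979-01-01/
│       ├── 00:00:00-z-500.0.npy
│       └── ...
└── single/
    └── 1979/
        └── 1979-01-01/
            ├── 00:00:00-u10.npy
            └── ...
```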
30 |
31 | Set up the environment with the versions given below:
32 |
33 | ```
34 | python version 3.8.18
35 | torch==1.13.1+cu117
36 | ```
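
No requirements file is shipped. Besides the Python and PyTorch versions above, the imports across the codebase (`numpy`, `pandas`, `pyyaml`, `einops`, `timm`, plus an ONNX runtime for `FengWu.onnx`) suggest an environment roughly like the following; the unpinned package versions and the choice of `onnxruntime-gpu` are assumptions:

```
conda create -n fnp python=3.8.18
conda activate fnp
pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install numpy pandas pyyaml einops timm onnxruntime-gpu
```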
37 |
38 | ## Training
39 |
40 | We support multi-node and multi-GPU training. You can freely adjust the number of nodes and GPUs in the following commands.
41 |
42 | To train the FNP model with the default configuration of `<24h lead time background, 10% observations with 128×256 resolution>`, just run
43 |
44 | ```
45 | torchrun --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=29500 train.py
46 | ```
47 |
48 | You can freely choose the experiment you want to perform by changing the command parameters. For example, if you want to train the `ConvCNP` model with the configuration of `<48h lead time background, 1% observations with 256×512 resolution>`, you can run
49 |
50 | ```
51 | torchrun --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=29500 train.py --lead_time=48 --ratio=0.99 --resolution=256 --rundir='./configs/ConvCNP'
52 | ```
53 |
54 | Please make sure that the parameter `--lead_time` is an integer multiple of 6, because the forecast model has a single-step forecast interval of six hours.
55 |
56 | **The resolution and ratio of the observations used for data assimilation can be arbitrary (the original resolution of ERA5 data is 721×1440), which are not limited to the settings given in our paper.**
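
Concretely, `--ratio` is the fraction of grid points that are masked out: `process_data()` in `models/*.py` keeps a grid point as an observation with probability `1 - ratio`, so `--ratio=0.9` corresponds to roughly 10% observations and `--ratio=0.99` to roughly 1%. A minimal illustration of that masking logic:

```python
import torch

ratio = 0.9                                        # value passed via --ratio
truth = torch.randn(1, 69, 128, 256)               # normalized ground-truth fields
mask = (torch.rand_like(truth) >= ratio).float()   # ~10% of grid points survive as observations
observation = truth * mask
print(mask.mean().item())                          # approximately 1 - ratio
```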
57 |
58 | ## Evaluation
59 |
60 | The commands for testing are the same as for training.
61 |
62 | For example, you can use 1 GPU on 1 node to evaluate the performance of the `Adas` model with the configuration of `<24h lead time background, 10% observations with 721×1440 resolution>` by running
63 |
64 | ```
65 | torchrun --nnodes=1 --nproc_per_node=1 --node_rank=0 --master_port=29500 inference.py --resolution=721 --rundir='./configs/Adas'
66 | ```
67 |
68 | The best checkpoint saved during training will be loaded to evaluate the MSE, MAE, and WRMSE metrics for all variables on the test set.
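
`utils/metrics.py` is not reproduced above. The WRMSE reported here is the latitude-weighted RMSE commonly used for ERA5 evaluation, computed on de-normalized fields (which is why the models pass the dataset's standard deviation into `WRMSE(pred, target, std)`). A minimal sketch consistent with that call signature, assuming an equiangular latitude grid; it is not the repo's exact implementation:

```python
import torch

def wrmse(pred, target, std):
    """Latitude-weighted RMSE per variable (sketch only).

    pred, target: [batch, channels, H, W] normalized fields; std: [channels] de-normalization scale.
    """
    num_lat = pred.shape[-2]
    lat = torch.linspace(90.0, -90.0, num_lat)                  # assumed equiangular latitudes
    weights = torch.cos(torch.deg2rad(lat))
    weights = weights / weights.mean()                          # normalize so the weights average to 1
    err2 = ((pred - target) * std[None, :, None, None]) ** 2    # de-normalize, then square the error
    err2 = err2 * weights[None, None, :, None]                  # weight along the latitude (H) dimension
    return torch.sqrt(err2.mean(dim=(0, 2, 3)))                 # RMSE for each channel
```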
69 |
70 | ## BibTeX
71 | ```bibtex
72 | @article{chen2025fnp,
73 | title={{FNP}: Fourier neural processes for arbitrary-resolution data assimilation},
74 | author={Chen, Kun and Ye, Peng and Chen, Hao and Han, Tao and Ouyang, Wanli and Chen, Tao and Bai, Lei and others},
75 | journal={Advances in Neural Information Processing Systems},
76 | volume={37},
77 | pages={137847--137872},
78 | year={2025}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/configs/Adas/training_options.yaml:
--------------------------------------------------------------------------------
1 | vnames: &id001
2 | single_level_vnames:
3 | - u10
4 | - v10
5 | - t2m
6 | - msl
7 | multi_level_vnames:
8 | - z
9 | - q
10 | - u
11 | - v
12 | - t
13 | hight_level_list:
14 | - 50
15 | - 100
16 | - 150
17 | - 200
18 | - 250
19 | - 300
20 | - 400
21 | - 500
22 | - 600
23 | - 700
24 | - 850
25 | - 925
26 | - 1000
27 | dataset:
28 | train:
29 | type: era5_npy_f32
30 | data_dir: ./data
31 | train_stride: 6
32 | file_stride: 6
33 | sample_stride: 1
34 | vnames: *id001
35 | valid:
36 | type: era5_npy_f32
37 | data_dir: ./data
38 | train_stride: 6
39 | file_stride: 6
40 | sample_stride: 1
41 | vnames: *id001
42 | dataloader:
43 | num_workers: 4
44 | pin_memory: true
45 | prefetch_factor: 2
46 | persistent_workers: true
47 | model:
48 | type: Adas
49 | params:
50 | img_size:
51 | - 69
52 | - 128
53 | - 256
54 | dim: 96
55 | patch_size:
56 | - 1
57 | - 2
58 | - 2
59 | window_size:
60 | - 2
61 | - 4
62 | - 8
63 | ape: True
64 | criterion: UnifyMAE
65 | optimizer:
66 | type: AdamW
67 | params:
68 | lr: 1.0e-04
69 | betas:
70 | - 0.9
71 | - 0.9
72 | weight_decay: 0.01
73 | lr_scheduler:
74 | type: OneCycleLR
75 | params:
76 | max_lr: 1.0e-4
77 | pct_start: 0.1
78 | anneal_strategy: cos
79 | div_factor: 100
80 | final_div_factor: 1000
81 |
--------------------------------------------------------------------------------
/configs/ConvCNP/training_options.yaml:
--------------------------------------------------------------------------------
1 | vnames: &id001
2 | single_level_vnames:
3 | - u10
4 | - v10
5 | - t2m
6 | - msl
7 | multi_level_vnames:
8 | - z
9 | - q
10 | - u
11 | - v
12 | - t
13 | hight_level_list:
14 | - 50
15 | - 100
16 | - 150
17 | - 200
18 | - 250
19 | - 300
20 | - 400
21 | - 500
22 | - 600
23 | - 700
24 | - 850
25 | - 925
26 | - 1000
27 | dataset:
28 | train:
29 | type: era5_npy_f32
30 | data_dir: ./data
31 | train_stride: 6
32 | file_stride: 6
33 | sample_stride: 1
34 | vnames: *id001
35 | valid:
36 | type: era5_npy_f32
37 | data_dir: ./data
38 | train_stride: 6
39 | file_stride: 6
40 | sample_stride: 1
41 | vnames: *id001
42 | dataloader:
43 | num_workers: 4
44 | pin_memory: true
45 | prefetch_factor: 2
46 | persistent_workers: true
47 | model:
48 | type: ConvCNP
49 | params:
50 | x_dim: 69
51 | y_dim: 69
52 | r_dim: 512 # 512 for 128×256 & 256×512 resolution, 128 for 721×1440 resolution
53 | criterion: CNPFLoss
54 | optimizer:
55 | type: AdamW
56 | params:
57 | lr: 1.0e-04
58 | betas:
59 | - 0.9
60 | - 0.9
61 | weight_decay: 0.01
62 | lr_scheduler:
63 | type: OneCycleLR
64 | params:
65 | max_lr: 1.0e-4
66 | pct_start: 0.1
67 | anneal_strategy: cos
68 | div_factor: 100
69 | final_div_factor: 1000
70 |
--------------------------------------------------------------------------------
/configs/FNP/training_options.yaml:
--------------------------------------------------------------------------------
1 | vnames: &id001
2 | single_level_vnames:
3 | - u10
4 | - v10
5 | - t2m
6 | - msl
7 | multi_level_vnames:
8 | - z
9 | - q
10 | - u
11 | - v
12 | - t
13 | hight_level_list:
14 | - 50
15 | - 100
16 | - 150
17 | - 200
18 | - 250
19 | - 300
20 | - 400
21 | - 500
22 | - 600
23 | - 700
24 | - 850
25 | - 925
26 | - 1000
27 | dataset:
28 | train:
29 | type: era5_npy_f32
30 | data_dir: ./data
31 | train_stride: 6
32 | file_stride: 6
33 | sample_stride: 1
34 | vnames: *id001
35 | valid:
36 | type: era5_npy_f32
37 | data_dir: ./data
38 | train_stride: 6
39 | file_stride: 6
40 | sample_stride: 1
41 | vnames: *id001
42 | dataloader:
43 | num_workers: 4
44 | pin_memory: true
45 | prefetch_factor: 2
46 | persistent_workers: true
47 | model:
48 | type: FNP
49 | params:
50 | n_channels:
51 | - 4
52 | - 13
53 | - 13
54 | - 13
55 | - 13
56 | - 13
57 | r_dim: 128 # 128 for 128×256 & 256×512 resolution, 64 for 721×1440 resolution
58 | use_nfl: true
59 | use_dam: true
60 | criterion: CNPFLoss
61 | optimizer:
62 | type: AdamW
63 | params:
64 | lr: 1.0e-04
65 | betas:
66 | - 0.9
67 | - 0.9
68 | weight_decay: 0.01
69 | lr_scheduler:
70 | type: OneCycleLR
71 | params:
72 | max_lr: 1.0e-4
73 | pct_start: 0.1
74 | anneal_strategy: cos
75 | div_factor: 100
76 | final_div_factor: 1000
77 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/data/.gitkeep
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/era5_npy_f32.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import numpy as np
3 | import io
4 | import json
5 | import pandas as pd
6 | import os
7 | from multiprocessing import shared_memory
8 | import multiprocessing
9 | import copy
10 | import queue
11 | import torch
12 |
13 |
14 | Years = {
15 | 'train': ['1979-01-01 00:00:00', '2015-12-31 23:00:00'],
16 | 'valid': ['2016-01-01 00:00:00', '2017-12-31 23:00:00'],
17 | 'test': ['2018-01-01 00:00:00', '2018-12-31 23:00:00'],
18 | 'all': ['1979-01-01 00:00:00', '2020-12-31 23:00:00']
19 | }
20 |
21 | multi_level_vnames = [
22 | "z", "t", "q", "r", "u", "v", "vo", "pv",
23 | ]
24 | single_level_vnames = [
25 | "t2m", "u10", "v10", "tcc", "tp", "tisr",
26 | ]
27 | long2shortname_dict = {"geopotential": "z", "temperature": "t", "specific_humidity": "q", "relative_humidity": "r", "u_component_of_wind": "u", "v_component_of_wind": "v", "vorticity": "vo", "potential_vorticity": "pv", \
28 | "2m_temperature": "t2m", "10m_u_component_of_wind": "u10", "10m_v_component_of_wind": "v10", "total_cloud_cover": "tcc", "total_precipitation": "tp", "toa_incident_solar_radiation": "tisr"}
29 |
30 | height_level = [1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300, 350, 400, 450, \
31 | 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900, 925, 950, 975, 1000]
32 | # height_level = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
33 |
34 |
35 | def standardization(data):
36 | mu = np.mean(data)
37 | sigma = np.std(data)
38 | return (data - mu) / sigma
39 |
40 |
41 | class era5_npy_f32(Dataset):
42 | def __init__(self, data_dir='./data', split='train', **kwargs) -> None:
43 | super().__init__()
44 |
45 | self.length = kwargs.get('length', 1)
46 | self.file_stride = kwargs.get('file_stride', 6)
47 | self.sample_stride = kwargs.get('sample_stride', 1)
48 | self.output_meanstd = kwargs.get("output_meanstd", False)
49 | self.use_diff_pos = kwargs.get("use_diff_pos", False)
50 | self.rm_equator = kwargs.get("rm_equator", False)
51 | Years_dict = kwargs.get('years', Years)
52 |
53 | self.pred_length = kwargs.get("pred_length", 0)
54 | self.inference_stride = kwargs.get("inference_stride", 6)
55 | self.train_stride = kwargs.get("train_stride", 6)
56 | self.use_gt = kwargs.get("use_gt", True)
57 | self.data_save_dir = kwargs.get("data_save_dir", None)
58 |
59 | self.save_single_level_names = kwargs.get("save_single_level_names", [])
60 | self.save_multi_level_names = kwargs.get("save_multi_level_names", [])
61 |
62 | vnames_type = kwargs.get("vnames", {})
63 | self.single_level_vnames = vnames_type.get('single_level_vnames', [])
64 | self.multi_level_vnames = vnames_type.get('multi_level_vnames', ['z','q', 'u', 'v', 't'])
65 | self.height_level_list = vnames_type.get('hight_level_list', [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000])
66 | self.height_level_indexes = [height_level.index(j) for j in self.height_level_list]
67 |
68 | self.select_row = [i for i in range(721)]
69 | if self.rm_equator:
70 | del self.select_row[360]
71 | self.split = split
72 | self.data_dir = data_dir
73 | years = Years_dict[split]
74 | self.init_file_list(years)
75 |
76 | self._get_meanstd()
77 | self.mean, self.std = self.get_meanstd()
78 | self.data_element_num = len(self.single_level_vnames) + len(self.multi_level_vnames) * len(self.height_level_list)
79 | dim = len(self.single_level_vnames) + len(self.multi_level_vnames) * len(self.height_level_list)
80 |
81 | self.index_dict1 = {}
82 | self.index_dict2 = {}
83 | i = 0
84 | for vname in self.single_level_vnames:
85 | self.index_dict1[(vname, 0)] = i
86 | i += 1
87 | for vname in self.multi_level_vnames:
88 | for height in self.height_level_list:
89 | self.index_dict1[(vname, height)] = i
90 | i += 1
91 |
92 | self.index_queue = multiprocessing.Queue()
93 | self.unit_data_queue = multiprocessing.Queue()
94 |
95 | self.index_queue.cancel_join_thread()
96 | self.unit_data_queue.cancel_join_thread()
97 |
98 | self.compound_data_queue = []
99 | self.sharedmemory_list = []
100 | self.compound_data_queue_dict = {}
101 | self.sharedmemory_dict = {}
102 |
103 | self.compound_data_queue_num = 8
104 |
105 | self.lock = multiprocessing.Lock()
106 | if self.rm_equator:
107 | self.a = np.zeros((dim, 720, 1440), dtype=np.float32)
108 | else:
109 | self.a = np.zeros((dim, 721, 1440), dtype=np.float32)
110 |
111 | for _ in range(self.compound_data_queue_num):
112 | self.compound_data_queue.append(multiprocessing.Queue())
113 | shm = shared_memory.SharedMemory(create=True, size=self.a.nbytes)
114 | shm.unlink()
115 | self.sharedmemory_list.append(shm)
116 |
117 | self.arr = multiprocessing.Array('i', range(self.compound_data_queue_num))
118 |
119 | self._workers = []
120 |
121 | for _ in range(40):
122 | w = multiprocessing.Process(
123 | target=self.load_data_process)
124 | w.daemon = True
125 | # NB: Process.start() actually take some time as it needs to
126 | # start a process and pass the arguments over via a pipe.
127 | # Therefore, we only add a worker to self._workers list after
128 | # it started, so that we do not call .join() if program dies
129 | # before it starts, and __del__ tries to join but will get:
130 | # AssertionError: can only join a started process.
131 | w.start()
132 | self._workers.append(w)
133 | w = multiprocessing.Process(target=self.data_compound_process)
134 | w.daemon = True
135 | w.start()
136 | self._workers.append(w)
137 |
138 | def init_file_list(self, years):
139 | time_sequence = pd.date_range(years[0],years[1],freq=str(self.file_stride)+'H') #pd.date_range(start='2019-1-09',periods=24,freq='H')
140 | self.file_list= [os.path.join(str(time_stamp.year), str(time_stamp.to_datetime64()).split('.')[0]).replace('T', '/')
141 | for time_stamp in time_sequence]
142 | self.single_file_list= [os.path.join('single/'+str(time_stamp.year), str(time_stamp.to_datetime64()).split('.')[0]).replace('T', '/')
143 | for time_stamp in time_sequence]
144 |
145 | def _get_meanstd(self):
146 | with open('./datasets/mean_std.json',mode='r') as f:
147 | multi_level_mean_std = json.load(f)
148 | with open('./datasets/mean_std_single.json',mode='r') as f:
149 | single_level_mean_std = json.load(f)
150 | self.mean_std = {}
151 | multi_level_mean_std['mean'].update(single_level_mean_std['mean'])
152 | multi_level_mean_std['std'].update(single_level_mean_std['std'])
153 | self.mean_std['mean'] = multi_level_mean_std['mean']
154 | self.mean_std['std'] = multi_level_mean_std['std']
155 | for vname in self.single_level_vnames:
156 | self.mean_std['mean'][vname] = np.array(self.mean_std['mean'][vname])[::-1][:,np.newaxis,np.newaxis]
157 | self.mean_std['std'][vname] = np.array(self.mean_std['std'][vname])[::-1][:,np.newaxis,np.newaxis]
158 | for vname in self.multi_level_vnames:
159 | self.mean_std['mean'][vname] = np.array(self.mean_std['mean'][vname])[::-1][:,np.newaxis,np.newaxis]
160 | self.mean_std['std'][vname] = np.array(self.mean_std['std'][vname])[::-1][:,np.newaxis,np.newaxis]
161 |
162 | def data_compound_process(self):
163 | recorder_dict = {}
164 | while True:
165 | job_pid, idx, vname, height = self.unit_data_queue.get()
166 | if job_pid not in self.compound_data_queue_dict:
167 | try:
168 | self.lock.acquire()
169 | for i in range(self.compound_data_queue_num):
170 | if job_pid == self.arr[i]:
171 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i]
172 | break
173 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]:
174 | print("error", job_pid, self.arr)
175 | except Exception as err:
176 | raise err
177 | finally:
178 | self.lock.release()
179 |
180 | if (job_pid, idx) in recorder_dict:
181 | # recorder_dict[(job_pid, idx)][(vname, height)] = 1
182 | recorder_dict[(job_pid, idx)] += 1
183 | else:
184 | recorder_dict[(job_pid, idx)] = 1
185 | if recorder_dict[(job_pid, idx)] == self.data_element_num:
186 | del recorder_dict[(job_pid, idx)]
187 | self.compound_data_queue_dict[job_pid].put((idx))
188 |
189 | def get_data(self, idxes):
190 | job_pid = os.getpid()
191 | if job_pid not in self.compound_data_queue_dict:
192 | try:
193 | self.lock.acquire()
194 | for i in range(self.compound_data_queue_num):
195 | if i == self.arr[i]:
196 | self.arr[i] = job_pid
197 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i]
198 | self.sharedmemory_dict[job_pid] = self.sharedmemory_list[i]
199 | break
200 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]:
201 | print("error", job_pid, self.arr)
202 |
203 | except Exception as err:
204 | raise err
205 | finally:
206 | self.lock.release()
207 |
208 |         try:
209 |             idx = self.compound_data_queue_dict[job_pid].get(False)  # sanity check: the per-process result queue should be empty before new load requests are issued
210 |             raise ValueError  # a leftover item would mean a stale result from an earlier request
211 |         except queue.Empty:
212 |             pass
213 |         except Exception as err:
214 |             raise err
215 |
216 | b = np.ndarray(self.a.shape, dtype=self.a.dtype, buffer=self.sharedmemory_dict[job_pid].buf)
217 | return_data = []
218 | for idx in idxes:
219 | for vname in self.single_level_vnames:
220 | self.index_queue.put((job_pid, idx, vname, 0))
221 | for vname in self.multi_level_vnames:
222 | for height in self.height_level_list:
223 | self.index_queue.put((job_pid, idx, vname, height))
224 | idx = self.compound_data_queue_dict[job_pid].get()
225 | b -= self.mean.numpy()[:, np.newaxis, np.newaxis]
226 | b /= self.std.numpy()[:, np.newaxis, np.newaxis]
227 | return_data.append(copy.deepcopy(b))
228 |
229 | return return_data
230 |
231 | def load_data_process(self):
232 | while True:
233 | job_pid, idx, vname, height = self.index_queue.get()
234 | if job_pid not in self.compound_data_queue_dict:
235 | try:
236 | self.lock.acquire()
237 | for i in range(self.compound_data_queue_num):
238 | if job_pid == self.arr[i]:
239 | self.compound_data_queue_dict[job_pid] = self.compound_data_queue[i]
240 | self.sharedmemory_dict[job_pid] = self.sharedmemory_list[i]
241 | break
242 | if (i == self.compound_data_queue_num - 1) and job_pid != self.arr[i]:
243 | print("error", job_pid, self.arr)
244 | except Exception as err:
245 | raise err
246 | finally:
247 | self.lock.release()
248 |
249 | if vname in self.single_level_vnames:
250 | file = self.single_file_list[idx]
251 | url = f"{self.data_dir}/{file}-{vname}.npy"
252 | elif vname in self.multi_level_vnames:
253 | file = self.file_list[idx]
254 | url = f"{self.data_dir}/{file}-{vname}-{height}.0.npy"
255 | b = np.ndarray(self.a.shape, dtype=self.a.dtype, buffer=self.sharedmemory_dict[job_pid].buf)
256 | unit_data = np.load(url)
257 | # unit_data = unit_data[np.newaxis, :, :]
258 | if self.rm_equator:
259 | b[self.index_dict1[(vname, height)], :360] = unit_data[:360]
260 | b[self.index_dict1[(vname, height)], 360:] = unit_data[361:]
261 | else:
262 | b[self.index_dict1[(vname, height)], :] = unit_data[:]
263 | del unit_data
264 | self.unit_data_queue.put((job_pid, idx, vname, height))
265 |
266 | def __len__(self):
267 |
268 | if self.split != "test":
269 | data_len = (len(self.file_list) - (self.length - 1) * self.sample_stride) // (self.train_stride // self.sample_stride // self.file_stride)
270 | elif self.use_gt:
271 | data_len = len(self.file_list) - (self.length - 1) * self.sample_stride
272 | data_len -= self.pred_length * self.sample_stride + 1
273 | data_len = (data_len + max(self.inference_stride // self.sample_stride // self.file_stride, 1) - 1) // max(self.inference_stride // self.sample_stride // self.file_stride, 1)
274 | else:
275 | data_len = len(self.file_list) - (self.length - 1) * self.sample_stride
276 | data_len = (data_len + max(self.inference_stride // self.sample_stride // self.file_stride, 1) - 1) // max(self.inference_stride // self.sample_stride // self.file_stride, 1)
277 |
278 | return data_len
279 |
280 | def get_meanstd(self):
281 | return_data_mean = []
282 | return_data_std = []
283 |
284 | for vname in self.single_level_vnames:
285 | return_data_mean.append(self.mean_std['mean'][vname])
286 | return_data_std.append(self.mean_std['std'][vname])
287 | for vname in self.multi_level_vnames:
288 | return_data_mean.append(self.mean_std['mean'][vname][self.height_level_indexes])
289 | return_data_std.append(self.mean_std['std'][vname][self.height_level_indexes])
290 |
291 | return torch.from_numpy(np.concatenate(return_data_mean, axis=0)[:, 0, 0]), torch.from_numpy(np.concatenate(return_data_std, axis=0)[:, 0, 0])
292 |
293 | def __getitem__(self, index):
294 | index = min(index, len(self.file_list) - (self.length-1) * self.sample_stride - 1)
295 | if self.split == "test":
296 | index = index * max(self.inference_stride // self.sample_stride // self.file_stride, 1)
297 | else:
298 | index = index * (self.train_stride // self.sample_stride // self.file_stride)
299 | array_seq = self.get_data([index, index + self.sample_stride, index + (self.length-1) * self.sample_stride])
300 | tar_idx = np.array([index + self.sample_stride * (self.length - 1)])
301 | return array_seq, tar_idx
302 |
--------------------------------------------------------------------------------
/datasets/mean_std.json:
--------------------------------------------------------------------------------
1 | {"mean": {"z_overall": 125173.77185546875, "z": [777.5631800842285, 2806.6911419677767, 4880.769941406251, 7002.870792236329, 9175.857954101568, 11402.84267333983, 13687.02337158203, 16031.525583496094, 18439.37236816406, 20913.51230468748, 23457.279204101567, 28769.254521484374, 34409.91857421876, 40425.96180664062, 46874.99372070313, 53826.54542968748, 61368.859130859375, 69620.0044531249, 78745.945078125, 88999.83613281258, 100822.13164062506, 107551.17681640621, 115011.55044921875, 123395.13201171876, 132973.8087890624, 144168.34273437515, 157706.1917968749, 179254.20781250007, 199832.31609374992, 231648.1046875, 257369.633828125, 302391.6971875003, 326148.29640625, 349188.804375, 385898.7467187499, 416444.811875, 470007.90359375], "q_overall": 0.0016609070883714593, "q": [0.006659231842495503, 0.006508124030660836, 0.006180181269301101, 0.005698622210184111, 0.005229148901998995, 0.004779181429184972, 0.004328316930914292, 0.003895292008528486, 0.003504167633363977, 0.003151517739752307, 0.002823496932396666, 0.002245682779466732, 0.0017783852928550915, 0.0014082587775192225, 0.0010869937593815861, 0.0007819174298492726, 0.0005401131762482684, 0.0003586592462670523, 0.00022093375020631356, 0.00012106754767955863, 5.37612270545651e-05, 3.1880958731562706e-05, 1.7365863168379306e-05, 8.887264439181295e-06, 4.689598504228342e-06, 3.032497714912096e-06, 2.557213611567022e-06, 2.7101390543293752e-06, 2.8248029025235157e-06, 2.8785681178078434e-06, 2.918656103929608e-06, 3.0804213855617494e-06, 3.2166743511652386e-06, 3.3617744981029316e-06, 3.545218272620331e-06, 3.651859217939091e-06, 3.8337135470101216e-06], "u_overall": 5.067544518113135, "u": [-0.18604825597634778, 0.014575419906759663, 0.1831986801231688, 0.3961004569224316, 0.6612499056267552, 0.9589599340630222, 1.2827875773236155, 1.6256907974928618, 1.9794725690782076, 2.3418910357356073, 2.7121086287498475, 3.4704639464616776, 4.26914108872414, 5.110536641478544, 6.038327068090436, 7.07798705458641, 8.268079032897964, 9.628437678813944, 11.183755590915675, 12.895844810009004, 14.48938421010971, 15.056942412853239, 15.298378415107727, 15.109850177764892, 14.321160042285918, 12.678406665325161, 10.046632840633391, 6.356182322502137, 4.44909584343433, 0.7572945548904316, -2.5366748879943004, -4.663122029609662, -3.6527538601704896, -1.2890704965137414, 1.792431366236417, 3.4357510804757476, 5.936705157756806], "v_overall": 0.028403457459244247, "v": [-0.15797925526494516, -0.152254779834766, -0.11334734256146482, -0.062323193841984835, -0.02435781713542383, -0.0020319510312947365, 0.014846984493779027, 0.026666699802153734, 0.03010514150775634, 0.028390460408554644, 0.024946943822433242, 0.0075308467486320295, -0.011855637595126613, -0.022026948712546065, -0.017326252191560337, 0.0007164070030103152, -0.008457765014013602, -0.04722135595180817, -0.059056078250941904, 0.03872977272505523, 0.21001753183547414, 0.26851208267849863, 0.2956721917196408, 0.29721465504262573, 0.2678451650420902, 0.1749421462348983, 0.1010729405652091, 0.035401583576604036, 0.012106836824341376, 0.020043842778977715, 0.00465771341478103, -0.03084933991518482, -0.06255946054356175, -0.09668692563471275, -0.07074970662535636, -0.008126834754220965, 0.13870985066001587], "t_overall": 247.24390762329102, "t": [278.5929747772214, 277.2997010040281, 276.21999847412104, 275.3001181793211, 274.4359161376951, 273.6013989257812, 272.77368919372566, 271.89837799072285, 270.9205000305179, 269.8229228210445, 268.63149719238277, 265.99485839843743, 263.0967428588868, 
259.84156120300344, 256.09259799957306, 251.74072200775146, 246.77826755523682, 241.16466262817383, 234.9638207244873, 228.6614455413818, 223.53410289764412, 221.58114109039306, 219.73181056976318, 217.6849419784546, 215.23375904083258, 212.61522850036621, 210.3573041915893, 211.44547527313233, 214.66564151763913, 219.08330230712897, 222.90260936737062, 229.70379344940184, 234.8047568511963, 242.76109550476073, 257.88357688903875, 266.1968101501466, 270.00755813598647], "r_overall": 40.394069585800146, "r": [80.17475992202763, 81.89339225769044, 82.05619462966922, 80.0935426521301, 77.31068056106565, 74.06146982192983, 70.40090473175046, 66.74169216156014, 63.65694343566894, 61.238844032287595, 59.20706973075867, 55.94804849624633, 53.78221703529358, 52.604153957366925, 51.98551220893857, 51.816392965316794, 52.520214595794705, 54.109504508972165, 55.61435362815857, 53.635069522857655, 42.4717398929596, 35.659886183738735, 29.037648782730102, 23.85700461387634, 21.362764282226564, 21.92401524066924, 23.441656923294072, 11.165672419071193, 3.7584114503860486, 1.6763820819556718, 0.9495417647063722, 0.2976933006197218, 0.10302423514425761, 0.02450558412820101, 0.0026089991186745458, 0.0007830015913350504, 0.0002408742079387594], "w_overall": 0.004383492262859363, "w": [0.022094985805451873, 0.0216572742210701, 0.019905026827473193, 0.017870978915598242, 0.01588542067212984, 0.013868597075343127, 0.011905601065373044, 0.010102826043148527, 0.008454037891642651, 0.006905023376311874, 0.005377083962594045, 0.002986208360084676, 0.0012120617383675368, 0.0008260514670288897, 0.000692131803097027, 0.000629533462382596, 0.0005991482839004908, 0.0005541121261796886, 0.0004358308141377164, 0.0002701366509023725, 0.00013300039120110795, 0.00010626948278826373, 9.627545935728681e-05, 8.954200330800081e-05, 6.651518481056939e-05, 2.8853250523752642e-05, -1.511444107139435e-05, -5.620325117561725e-05, -7.897719135826269e-05, -9.763219824087352e-05, -9.472476952721467e-05, -7.022762941403469e-05, -5.8008067545154444e-05, -4.455210098275765e-05, -2.651977333063105e-05, -1.5596725875788307e-05, -5.785629261385584e-06]}, "std": {"z_overall": 129782.78276610909, "z": [1098.9952409939283, 1112.6027805778692, 1142.3473460665198, 1187.5419349409494, 1246.4561469025161, 1317.5100178215614, 1399.3410970050247, 1490.8825928657998, 1590.9548416558155, 1698.5379482618182, 1812.9974354157323, 2060.9872289877453, 2334.319060376128, 2629.7201995715513, 2949.675806600318, 3299.702929930327, 3683.8414755168196, 4104.233456672573, 4557.44791664416, 5020.194961603476, 5405.73040397736, 5516.599935297351, 5540.73074484052, 5460.536820725648, 5253.301115477269, 4886.876297159263, 4357.588191568988, 3756.1336388685854, 3755.2810557402927, 4396.22650565561, 5144.265045354613, 6650.516303503411, 7495.09074134178, 8295.69333798942, 9517.614470978018, 10519.593543281839, 12104.888878805084], "q_overall": 0.0034514906703841854, "q": [0.00611814321149996, 0.005991895878203082, 0.005660814318820206, 0.005201345241925773, 0.004817240293151429, 0.004489990425884411, 0.004184742037434761, 0.0038904524158168458, 0.0036034287117598773, 0.0033171467057458016, 0.0030338747907811383, 0.0024928432426471183, 0.0020354637234126743, 0.0016946778969914426, 0.0013932648476959186, 0.001023028433607086, 0.0007179488022534093, 0.00048331132466880735, 0.0002993454787338255, 0.00016131853114827985, 7.02963683704546e-05, 4.0895619708336846e-05, 2.093742795871515e-05, 9.154817617587747e-06, 3.1627283344500357e-06, 7.369263717657513e-07, 
4.2315237954921815e-07, 2.105491580661734e-07, 1.1555282996146702e-07, 1.6058042115095855e-07, 2.133802768772255e-07, 2.9865161188496304e-07, 3.153152311774654e-07, 3.0755590168313815e-07, 2.862195110630825e-07, 2.547970205742875e-07, 1.620701293822839e-07], "u_overall": 16.97319982237474, "u": [6.126056325796786, 7.016566034258269, 7.643058800982493, 7.93264239491015, 8.043800524363755, 8.105094821089457, 8.179197305830433, 8.287887909395936, 8.433526267376907, 8.60957088282445, 8.810868039599674, 9.286354460633103, 9.850636072885855, 10.503871726997783, 11.285900449345242, 12.215305549952125, 13.290793204085057, 14.509924979003983, 15.840437203647705, 17.078782531259524, 17.720698660431694, 17.667616741159385, 17.286773058038722, 16.542030679458666, 15.407016747306344, 13.807092388343909, 11.884088705628045, 10.334116775694541, 11.557361639969054, 14.832938252653848, 19.184593851862232, 25.245251895546826, 26.307161489246376, 26.068781364502108, 27.582285843981804, 29.507017272843985, 32.21467925596289], "v_overall": 10.225347045502723, "v": [5.23175612906023, 5.924618269023251, 6.285772466913804, 6.345356147017278, 6.279593292801223, 6.209277530023286, 6.186618416862026, 6.213000051694466, 6.279984926208441, 6.380234677983426, 6.510405300881028, 6.858187372121642, 7.312336974076069, 7.835396470389893, 8.498212257679501, 9.321096224035244, 10.312266990636417, 11.44656476066317, 12.61520885057924, 13.474533447403218, 13.360381609448558, 12.771109679285894, 11.896325029659364, 10.936078921724906, 9.998695230009567, 9.05105707564941, 8.178248048405905, 7.7733598444254, 8.417864770061094, 10.031378889054897, 11.362479417901184, 13.548648819446889, 14.409961339641928, 14.843368077149982, 15.355139749619095, 15.671183984838587, 15.620566981759925], "t_overall": 27.05388027783469, "t": [18.59201174892538, 17.962722435516273, 17.354568742491193, 16.88523210573196, 16.50532064048834, 16.214283284136293, 16.00275340237577, 15.829774257559896, 15.664667994261846, 15.503739090214387, 15.35464052689256, 15.021590351519892, 14.443361586551287, 13.858620163486986, 13.612509319936459, 13.459313552230206, 13.222745493030287, 12.773099916244078, 11.921086612374525, 10.212310312072752, 7.389004707914384, 6.217593563425525, 5.933385737657316, 6.957734476103884, 9.090666595626503, 11.748988502455319, 13.738672642636256, 12.222179433839193, 9.495652698988557, 8.418363124927735, 8.71613081797916, 9.979142288488994, 10.868110600253445, 11.453708697281687, 12.03482655512657, 12.18683909538171, 10.27148161649973], "r_overall": 38.27834275264974, "r": [18.065236846880957, 18.884182946667348, 19.9694447391712, 21.201659553304538, 22.682946621307828, 24.32165800612575, 26.05935412034893, 27.633995796855135, 28.837421755290475, 29.762150422980042, 30.498343554542576, 31.567549251190137, 32.358281367806725, 32.994946188446285, 33.54817198803172, 33.83420782522547, 34.1020571373513, 34.15576061007992, 33.97040860449722, 33.27175094990293, 33.824094547943524, 33.94965762113503, 32.2278297456812, 30.217155336090716, 29.454141244285715, 31.103837052430183, 33.23741657963707, 13.81056340487671, 5.50629572264854, 5.140879239893525, 3.7486899733037617, 1.3911355900931046, 0.45882004791501424, 0.08512110306079632, 0.004894727016083146, 0.0012718724357374124, 0.00047595219931748637], "w_overall": 0.15754408345291773, "w": [0.13979567674915916, 0.15479160882205267, 0.17909532539790496, 0.2025255148068513, 0.21983137560176763, 0.23023882117421315, 0.23583433097628145, 0.23861934853804015, 0.23947601319694647, 0.2388890682549996, 
0.23720531494087146, 0.2326563138856663, 0.22800938309402644, 0.22319334450015696, 0.21895367863698434, 0.21374423887468189, 0.20502274496551173, 0.1897976164287798, 0.16794512621224, 0.14051455042850863, 0.11142703377201577, 0.09759972206579746, 0.0839118257220142, 0.0699250713532834, 0.05505537455154692, 0.0399342745096499, 0.0266884524850588, 0.016159587930447316, 0.012212181977916497, 0.008866289960038784, 0.007027148644031218, 0.004681425194685117, 0.0036326590952213142, 0.00276010917035495, 0.002059059572163848, 0.0016134326105307103, 0.0008927991973868609]}, "count": 400}
--------------------------------------------------------------------------------
/datasets/mean_std_single.json:
--------------------------------------------------------------------------------
1 | {"mean": {"v10": [0.22575792335029873], "u10": [-0.14186215714480854], "t2m": [278.7854495405721], "sp": [96672.96548579562], "msl": [100980.83590625007],"tp6h": [0.5938965440938924]}, "std": {"v10": [4.798220612223473], "u10": [5.610453475051704], "t2m": [21.32010786700973], "sp": [9580.331095986492], "msl": [1336.2115992274876], "tp6h": [1.5731802126254537]}, "count": 2750}
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from utils.builder import ConfigBuilder
5 | import utils.misc as utils
6 | import yaml
7 | from utils.logger import get_logger
8 | import copy
9 |
10 |
11 |
12 | def subprocess_fn(args):
13 | utils.setup_seed(args.seed * args.world_size + args.rank)
14 |
15 | logger = get_logger("test", args.rundir, utils.get_rank(), filename='infer.log')
16 | args.cfg_params["logger"] = logger
17 |
18 | # build config
19 | logger.info('Building config ...')
20 | builder = ConfigBuilder(**args.cfg_params)
21 |
22 | # build model
23 | logger.info('Building models ...')
24 | model = builder.get_model()
25 | checkpoint_dict = torch.load(os.path.join(args.rundir, 'best_model.pth'), map_location=torch.device('cpu'))
26 | model.kernel.load_state_dict(checkpoint_dict)
27 | model.kernel = utils.DistributedParallel_Model(model.kernel, args.local_rank)
28 |
29 | # build forecast model
30 | logger.info('Building forecast models ...')
31 | args.forecast_model = builder.get_forecast(args.local_rank)
32 |
33 | # build dataset
34 | logger.info('Building dataloaders ...')
35 | dataset_params = args.cfg_params['dataset']
36 | test_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='test', batch_size=args.batch_size)
37 |
38 | # inference
39 | logger.info('begin testing ...')
40 | model.test(test_dataloader, logger, args)
41 | logger.info('testing end ...')
42 |
43 |
44 | def main(args):
45 | if args.world_size > 1:
46 | utils.init_distributed_mode(args)
47 | else:
48 | args.rank = 0
49 | args.local_rank = 0
50 | args.distributed = False
51 | args.gpu = 0
52 | torch.cuda.set_device(args.gpu)
53 |
54 | args.rundir = os.path.join(args.rundir, f'mask{args.ratio}_lead{args.lead_time}h_res{args.resolution}')
55 | args.cfg = os.path.join(args.rundir, 'train.yaml')
56 | with open(args.cfg, 'r') as cfg_file:
57 | cfg_params = yaml.load(cfg_file, Loader = yaml.FullLoader)['cfg_params']
58 |
59 | cfg_params['dataloader']['num_workers'] = args.per_cpus
60 | cfg_params['dataset']['test'] = copy.deepcopy(cfg_params['dataset']['train'])
61 | args.cfg_params = cfg_params
62 |
63 | subprocess_fn(args)
64 |
65 |
66 | if __name__ == "__main__":
67 |
68 | parser = argparse.ArgumentParser()
69 |
70 | parser.add_argument('--seed', type = int, default = 0, help = 'seed')
71 | parser.add_argument('--cuda', type = int, default = 0, help = 'cuda id')
72 | parser.add_argument('--world_size', type = int, default = 1, help = 'number of processes')
73 | parser.add_argument('--per_cpus', type = int, default = 4, help = 'number of CPU dataloader workers per process')
74 | parser.add_argument('--batch_size', type = int, default = 1, help = "batch size")
75 | parser.add_argument('--lead_time', type = int, default = 24, help = "lead time (h) for background")
76 | parser.add_argument('--ratio', type = float, default = 0.9, help = "mask ratio")
77 | parser.add_argument('--resolution', type = int, default = 128, help = "observation resolution")
78 | parser.add_argument('--init_method', type = str, default = 'tcp://127.0.0.1:19111', help = 'multi process init method')
79 | parser.add_argument('--rundir', type = str, default = './configs/FNP', help = 'where to save the results')
80 |
81 | args = parser.parse_args()
82 |
83 | main(args)
84 |
85 |
--------------------------------------------------------------------------------
/models/Adas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import time
5 | import numpy as np
6 | from timm.models.layers import trunc_normal_
7 | import utils.misc as utils
8 | from utils.metrics import WRMSE
9 | from functools import partial
10 | from modules import AllPatchEmbed, PatchRecover, BasicLayer, SwinTransformerLayer
11 | from utils.builder import get_optimizer, get_lr_scheduler
12 |
13 |
14 |
15 | class Adas_model(nn.Module):
16 | def __init__(self, img_size=(69,128,256), dim=96, patch_size=(1,2,2), window_size=(2,4,8), depth=8, num_heads=4,
17 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, ape=True, use_checkpoint=False):
18 | super().__init__()
19 |
20 | self.patchembed = AllPatchEmbed(img_size=img_size, embed_dim=dim, patch_size=patch_size, norm_layer=nn.LayerNorm) # b,c,14,180,360
21 | self.patchunembed = PatchRecover(img_size=img_size, embed_dim=dim, patch_size=patch_size)
22 | self.patch_resolution = self.patchembed.patch_resolution
23 |
24 | self.layer1 = BasicLayer(dim, kernel=(3,5,7), padding=(1,2,3), num_heads=num_heads, window_size=window_size, use_checkpoint=use_checkpoint) # s1
25 | self.layer2 = BasicLayer(dim*2, kernel=(3,3,5), padding=(1,1,2), num_heads=num_heads, window_size=window_size, sample='down', use_checkpoint=use_checkpoint) # s2
26 | self.layer3 = BasicLayer(dim*4, kernel=3, padding=1, num_heads=num_heads, window_size=window_size, sample='down', use_checkpoint=use_checkpoint) # s3
27 | self.layer4 = BasicLayer(dim*2, kernel=(3,3,5), padding=(1,1,2), num_heads=num_heads, window_size=window_size, sample='up', use_checkpoint=use_checkpoint) # s2
28 | self.layer5 = BasicLayer(dim, kernel=(3,5,7), padding=(1,2,3), num_heads=num_heads, window_size=window_size, sample='up', use_checkpoint=use_checkpoint) # s1
29 |
30 | self.fusion = nn.Conv3d(dim*3, dim, kernel_size=(3,5,7), stride=1, padding=(1,2,3))
31 |
32 | # absolute position embedding
33 | self.ape = ape
34 | if self.ape:
35 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, dim, self.patch_resolution[0], self.patch_resolution[1], self.patch_resolution[2]))
36 | trunc_normal_(self.absolute_pos_embed, std=.02)
37 |
38 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
39 | self.decoder = SwinTransformerLayer(dim=dim, depth=depth, num_heads=num_heads, window_size=window_size, qkv_bias=True,
40 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, use_checkpoint=use_checkpoint)
41 |
42 | # initial weights
43 | self.apply(self._init_weights)
44 |
45 | def _init_weights(self, m):
46 |
47 | if isinstance(m, nn.Linear):
48 | trunc_normal_(m.weight, std=.02)
49 | if isinstance(m, nn.Linear) and m.bias is not None:
50 | nn.init.constant_(m.bias, 0)
51 | elif isinstance(m, nn.LayerNorm):
52 | nn.init.constant_(m.bias, 0)
53 | nn.init.constant_(m.weight, 1.0)
54 |
55 | def encoder_forward(self, x):
56 |
57 | x1 = self.layer1(x)
58 | x2 = self.layer2(x1)
59 | x = self.layer3(x2)
60 | x = self.layer4(x, x2)
61 | x = self.layer5(x, x1)
62 |
63 | return x
64 |
65 | def forward(self, background, observation, mask):
66 |
67 | x = self.patchembed(background, observation, mask)
68 | if self.ape:
69 | x = [ x[i] + self.absolute_pos_embed for i in range(3) ]
70 |
71 | x = self.encoder_forward(x)
72 | x = self.fusion(torch.cat(x, dim=1))
73 | x = self.decoder(x)
74 |
75 | x = self.patchunembed(x)
76 | return x
77 |
78 |
79 | class Adas(object):
80 |
81 | def __init__(self, **model_params) -> None:
82 | super().__init__()
83 |
84 | params = model_params.get('params', {})
85 | criterion = model_params.get('criterion', 'CNPFLoss')
86 | self.optimizer_params = model_params.get('optimizer', {})
87 | self.scheduler_params = model_params.get('lr_scheduler', {})
88 |
89 | self.kernel = Adas_model(**params)
90 | self.best_loss = 9999
91 | self.criterion = self.get_criterion(criterion)
92 | self.criterion_mae = nn.L1Loss()
93 | self.criterion_mse = nn.MSELoss()
94 |
95 | if utils.is_dist_avail_and_initialized():
96 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
97 | if self.device == torch.device('cpu'):
98 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
99 | else:
100 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
101 |
102 | def get_criterion(self, loss_type):
103 | if loss_type == 'UnifyMAE':
104 | return partial(self.unify_losses, criterion=nn.L1Loss())
105 | elif loss_type == 'UnifyMSE':
106 | return partial(self.unify_losses, criterion=nn.MSELoss())
107 | else:
108 | raise NotImplementedError('Invalid loss type.')
109 |
110 | def unify_losses(self, pred, target, criterion):
111 | loss_sum = 0
112 | unify_loss = criterion(pred[:,0,:,:], target[:,0,:,:])
113 | for i in range(1, len(pred[0])):
114 | loss = criterion(pred[:,i,:,:], target[:,i,:,:])
115 | loss_sum += loss / (loss/unify_loss).detach()
116 | return (loss_sum + unify_loss) / len(pred[0])
117 |
118 | def process_data(self, batch_data, args):
119 |
120 | inp_data = torch.cat([batch_data[0], batch_data[1]], dim=1)
121 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
122 | truth = batch_data[-1].to(self.device, non_blocking=True) # 69
123 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear')
124 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear')
125 |
126 | for _ in range(args.lead_time // 6):
127 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]]
128 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1)
129 |
130 | background = torch.from_numpy(predict_data).to(self.device, non_blocking=True)
131 | mask = (torch.rand(truth.shape, device=self.device) >= args.ratio).float()
132 | observation = truth * mask
133 | mask = F.interpolate(mask, size=(128,256), mode='bilinear')
134 | observation = F.interpolate(observation, size=(128,256), mode='bilinear')
135 | observation = torch.where(mask==0, 0., observation/mask).to(self.device, non_blocking=True)
136 | mask = torch.where(mask==0, 0., 1.).to(self.device, non_blocking=True)
137 |
138 | return [background, observation, mask], truth_down
139 |
140 | def train(self, train_data_loader, valid_data_loader, logger, args):
141 |
142 | train_step = len(train_data_loader)
143 | valid_step = len(valid_data_loader)
144 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params)
145 | self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch)
146 |
147 | for epoch in range(args.max_epoch):
148 | begin_time = time.time()
149 | self.kernel.train()
150 |
151 | for step, batch_data in enumerate(train_data_loader):
152 |
153 | input_list, y_target = self.process_data(batch_data[0], args)
154 | self.optimizer.zero_grad()
155 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2])
156 | loss = self.criterion(y_pred, y_target)
157 | loss.backward()
158 | self.optimizer.step()
159 | self.scheduler.step()
160 |
161 | if ((step + 1) % 100 == 0) | (step+1 == train_step):
162 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]')
163 |
164 | self.kernel.eval()
165 | with torch.no_grad():
166 | total_loss = 0
167 |
168 | for step, batch_data in enumerate(valid_data_loader):
169 | input_list, y_target = self.process_data(batch_data[0], args)
170 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2])
171 | loss = self.criterion(y_pred, y_target).item()
172 | total_loss += loss
173 |
174 | if ((step + 1) % 100 == 0) | (step+1 == valid_step):
175 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]')
176 |
177 | if (total_loss/valid_step) < self.best_loss:
178 | if utils.get_world_size() > 1 and utils.get_rank() == 0:
179 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth')
180 | elif utils.get_world_size() == 1:
181 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth')
182 | logger.info(f'New best model appears in epoch {epoch+1}.')
183 | self.best_loss = total_loss/valid_step
184 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]')
185 |
186 | def test(self, test_data_loader, logger, args):
187 |
188 | test_step = len(test_data_loader)
189 | data_mean, data_std = test_data_loader.dataset.get_meanstd()
190 | self.data_std = data_std.to(self.device)
191 |
192 | self.kernel.eval()
193 | with torch.no_grad():
194 | total_loss = 0
195 | total_mae = 0
196 | total_mse = 0
197 | total_rmse = 0
198 |
199 | for step, batch_data in enumerate(test_data_loader):
200 |
201 | input_list, y_target = self.process_data(batch_data[0], args)
202 | y_pred = self.kernel(input_list[0], input_list[1], input_list[2])
203 | loss = self.criterion(y_pred, y_target).item()
204 | mae = self.criterion_mae(y_pred, y_target).item()
205 | mse = self.criterion_mse(y_pred, y_target).item()
206 | rmse = WRMSE(y_pred, y_target, self.data_std)
207 |
208 | total_loss += loss
209 | total_mae += mae
210 | total_mse += mse
211 | total_rmse += rmse
212 | if ((step + 1) % 100 == 0) | (step+1 == test_step):
213 | logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]')
214 |
215 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
216 | logger.info(f'Average RMSE:[{total_rmse/test_step}]')
217 |
--------------------------------------------------------------------------------
/models/ConvCNP.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils import clip_grad_norm_
7 | from functools import partial
8 | from einops import rearrange
9 | import utils.misc as utils
10 | from utils.metrics import WRMSE
11 | from utils.builder import get_optimizer, get_lr_scheduler
12 | import modules
13 |
14 |
15 | class ConvCNP_model(nn.Module):
16 |
17 | def __init__(
18 | self,
19 | x_dim=69,
20 | y_dim=69,
21 | r_dim=512,
22 | XEncoder=nn.Identity,
23 | Conv=lambda y_dim: modules.make_abs_conv(nn.Conv2d)(
24 | y_dim,
25 | y_dim,
26 | groups=y_dim,
27 | kernel_size=11,
28 | padding=11 // 2,
29 | bias=False,
30 | ),
31 | CNN=partial(
32 | modules.CNN,
33 | ConvBlock=modules.ResConvBlock,
34 | Conv=nn.Conv2d,
35 | n_blocks=12,
36 | Normalization=nn.BatchNorm2d,
37 | is_chan_last=True,
38 | kernel_size=9,
39 | n_conv_layers=2,
40 | ),
41 | PredictiveDistribution=modules.MultivariateNormalDiag,
42 | p_y_loc_transformer=nn.Identity(),
43 | p_y_scale_transformer=lambda y_scale: 0.01 + 0.99 * F.softplus(y_scale),
44 | ):
45 | super().__init__()
46 |
47 | self.x_dim = x_dim
48 | self.y_dim = y_dim
49 | self.r_dim = r_dim
50 |
51 | self.conv = nn.ModuleList([Conv(y_dim), Conv(y_dim)])
52 | self.resizer = nn.ModuleList([nn.Linear(self.y_dim * 2, self.r_dim), nn.Linear(self.y_dim * 2, self.r_dim)])
53 | self.induced_to_induced = nn.ModuleList([CNN(self.r_dim), CNN(self.r_dim)])
54 | self.fusion = nn.Linear(self.r_dim * 2, self.r_dim)
55 | self.x_encoder = XEncoder()
56 |
57 | Decoder=modules.discard_ith_arg(partial(modules.MLP, n_hidden_layers=4, hidden_size=self.r_dim), i=0)
58 | # times 2 out because loc and scale (mean and var for gaussian)
59 | self.decoder = Decoder(self.x_dim, self.r_dim, self.y_dim * 2)
60 |
61 | self.PredictiveDistribution = PredictiveDistribution
62 | self.p_y_loc_transformer = p_y_loc_transformer
63 | self.p_y_scale_transformer = p_y_scale_transformer
64 |
65 | self.reset_parameters()
66 |
67 | def reset_parameters(self):
68 | modules.weights_init(self)
69 |
70 | def forward(self, input_list):
71 |
72 | X_cntxt, Y_cntxt, Xb_cntxt, Yb_cntxt, X_trgt = input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]
73 | X_cntxt = self.x_encoder(X_cntxt) # b,h,w,c
74 | X_trgt = self.x_encoder(X_trgt)
75 | Xb_cntxt = self.x_encoder(Xb_cntxt) # b,h,w,c
76 |
77 | # {R^u}_u
78 | # size = [batch_size, *n_rep, r_dim] for n_channels list
79 | R = self.encode_globally(X_cntxt, Y_cntxt, Xb_cntxt, Yb_cntxt)
80 |
81 | z_samples, q_zCc, q_zCct = None, None, None
82 |
83 | # size = [n_z_samples, batch_size, *n_trgt, r_dim]
84 | R_trgt = self.trgt_dependent_representation(X_cntxt, z_samples, R, X_trgt)
85 |
86 | # p(y|cntxt,trgt)
87 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
88 | p_yCc = self.decode(X_trgt, R_trgt)
89 |
90 | return p_yCc, z_samples, q_zCc, q_zCct
91 |
92 | def cntxt_to_induced(self, mask_cntxt, X, index):
93 | """Infer the missing values and compute a density channel."""
94 |
95 | # channels have to be in second dimension for convolution
96 | # size = [batch_size, y_dim, *grid_shape]
97 | X = modules.channels_to_2nd_dim(X)
98 | # size = [batch_size, x_dim, *grid_shape]
99 | mask_cntxt = modules.channels_to_2nd_dim(mask_cntxt).float()
100 |
101 | # size = [batch_size, y_dim, *grid_shape]
102 | X_cntxt = X * mask_cntxt
103 | signal = self.conv[index](X_cntxt)
104 | density = self.conv[index](mask_cntxt.expand_as(X))
105 |
106 | # normalize
107 | out = signal / torch.clamp(density, min=1e-5)
108 |
109 | # size = [batch_size, y_dim * 2, *grid_shape]
110 | out = torch.cat([out, density], dim=1)
111 |
112 | # size = [batch_size, *grid_shape, y_dim * 2]
113 | out = modules.channels_to_last_dim(out)
114 |
115 | # size = [batch_size, *grid_shape, r_dim]
116 | out = self.resizer[index](out)
117 |
118 | return out
119 |
120 | def encode_globally(self, mask_cntxt, X, mask_cntxtb, Xb):
121 |
122 | # size = [batch_size, *grid_shape, r_dim] for each single channel
123 | R_induced = self.cntxt_to_induced(mask_cntxt, X, index=0)
124 | R_induced = self.induced_to_induced[0](R_induced)
125 |
126 | Rb_induced = self.cntxt_to_induced(mask_cntxtb, Xb, index=1)
127 | Rb_induced = self.induced_to_induced[1](Rb_induced)
128 |
129 | R_induced = rearrange(R_induced, 'b h w c -> b c h w')
130 | R_induced = F.interpolate(R_induced, size=Rb_induced.shape[1:3], mode='bilinear')
131 | R_induced = rearrange(R_induced, 'b c h w -> b h w c')
132 | R_fusion = self.fusion(torch.cat([R_induced, Rb_induced], dim=-1))
133 |
134 | return R_fusion
135 |
136 | def trgt_dependent_representation(self, _, __, R_induced, ___):
137 |
138 | # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim]
139 | return R_induced.unsqueeze(0)
140 |
141 | def decode(self, X_trgt, R_trgt):
142 |
143 | # size = [n_z_samples, batch_size, *n_trgt, y_dim*2]
144 | p_y_suffstat = self.decoder(X_trgt, R_trgt)
145 |
146 | # size = [n_z_samples, batch_size, *n_trgt, y_dim]
147 | p_y_loc, p_y_scale = p_y_suffstat.split(self.y_dim, dim=-1)
148 |
149 | p_y_loc = self.p_y_loc_transformer(p_y_loc)
150 | p_y_scale = self.p_y_scale_transformer(p_y_scale)
151 |
152 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
153 | p_yCc = self.PredictiveDistribution(p_y_loc, p_y_scale)
154 |
155 | return p_yCc
156 |
157 |
158 | class ConvCNP(object):
159 |
160 | def __init__(self, **model_params) -> None:
161 | super().__init__()
162 |
163 | params = model_params.get('params', {})
164 | criterion = model_params.get('criterion', 'CNPFLoss')
165 | self.optimizer_params = model_params.get('optimizer', {})
166 | self.scheduler_params = model_params.get('lr_scheduler', {})
167 |
168 | self.kernel = ConvCNP_model(**params)
169 | self.best_loss = 9999
170 | self.criterion = self.get_criterion(criterion)
171 | self.criterion_mae = nn.L1Loss()
172 | self.criterion_mse = nn.MSELoss()
173 |
174 | if utils.is_dist_avail_and_initialized():
175 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
176 | if self.device == torch.device('cpu'):
177 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
178 | else:
179 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
180 |
181 | def get_criterion(self, loss_type):
182 | if loss_type == 'CNPFLoss':
183 | return modules.CNPFLoss()
184 | elif loss_type == 'NLLLossLNPF':
185 | return modules.NLLLossLNPF()
186 | elif loss_type == 'ELBOLossLNPF':
187 | return modules.ELBOLossLNPF()
188 | elif loss_type == 'SUMOLossLNPF':
189 | return modules.SUMOLossLNPF()
190 | else:
191 | raise NotImplementedError('Invalid loss type.')
192 |
193 | def process_data(self, batch_data, args):
194 |
195 | inp_data = torch.cat([batch_data[0], batch_data[1]], dim=1)
196 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
197 | truth = batch_data[-1].to(self.device, non_blocking=True) # 69
198 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear')
199 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear')
200 |
201 | for _ in range(args.lead_time // 6):
202 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]]
203 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1)
204 |
205 | xb_context = rearrange(torch.rand(predict_data.shape, device=self.device) >= 0, 'b c h w -> b h w c')
206 | x_context = rearrange(torch.rand(truth.shape, device=self.device) >= args.ratio, 'b c h w -> b h w c')
207 | x_target = rearrange(torch.rand(truth_down.shape, device=self.device) >= 0, 'b c h w -> b h w c')
208 | yb_context = rearrange(torch.from_numpy(predict_data).to(self.device, non_blocking=True), 'b c h w -> b h w c')
209 | y_context = rearrange(truth, 'b c h w -> b h w c')
210 | y_target = rearrange(truth_down, 'b c h w -> b h w c')
211 |
212 | return [x_context, y_context, xb_context, yb_context, x_target], y_target
213 |
214 | def train(self, train_data_loader, valid_data_loader, logger, args):
215 |
216 | train_step = len(train_data_loader)
217 | valid_step = len(valid_data_loader)
218 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params)
219 | self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch)
220 |
221 | for epoch in range(args.max_epoch):
222 | begin_time = time.time()
223 | self.kernel.train()
224 |
225 | for step, batch_data in enumerate(train_data_loader):
226 |
227 | input_list, y_target = self.process_data(batch_data[0], args)
228 | self.optimizer.zero_grad()
229 | y_pred = self.kernel(input_list)
230 | if isinstance(self.criterion, torch.nn.Module):
231 | self.criterion.train()
232 | loss = self.criterion(y_pred, y_target)
233 | loss.backward()
234 | clip_grad_norm_(self.kernel.parameters(), max_norm=1)
235 | self.optimizer.step()
236 | self.scheduler.step()
237 |
238 | if ((step + 1) % 100 == 0) | (step+1 == train_step):
239 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]')
240 |
241 | self.kernel.eval()
242 | with torch.no_grad():
243 | total_loss = 0
244 |
245 | for step, batch_data in enumerate(valid_data_loader):
246 | input_list, y_target = self.process_data(batch_data[0], args)
247 | y_pred = self.kernel(input_list)
248 | if isinstance(self.criterion, torch.nn.Module):
249 | self.criterion.eval()
250 | loss = self.criterion(y_pred, y_target).item()
251 | total_loss += loss
252 |
253 | if ((step + 1) % 100 == 0) | (step+1 == valid_step):
254 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]')
255 |
256 | if (total_loss/valid_step) < self.best_loss:
257 | if utils.get_world_size() > 1 and utils.get_rank() == 0:
258 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth')
259 | elif utils.get_world_size() == 1:
260 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth')
261 | logger.info(f'New best model appears in epoch {epoch+1}.')
262 | self.best_loss = total_loss/valid_step
263 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]')
264 |
265 | def test(self, test_data_loader, logger, args):
266 |
267 | test_step = len(test_data_loader)
268 | data_mean, data_std = test_data_loader.dataset.get_meanstd()
269 | self.data_std = data_std.to(self.device)
270 |
271 | self.kernel.eval()
272 | with torch.no_grad():
273 | total_loss = 0
274 | total_mae = 0
275 | total_mse = 0
276 | total_rmse = 0
277 |
278 | for step, batch_data in enumerate(test_data_loader):
279 |
280 | input_list, y_target = self.process_data(batch_data[0], args)
281 | y_pred = self.kernel(input_list)
282 | if isinstance(self.criterion, torch.nn.Module):
283 | self.criterion.eval()
284 | loss = self.criterion(y_pred, y_target).item()
285 |
286 | y_pred = rearrange(y_pred[0].mean[0], 'b h w c -> b c h w')
287 | y_target = rearrange(y_target, 'b h w c -> b c h w')
288 | mae = self.criterion_mae(y_pred, y_target).item()
289 | mse = self.criterion_mse(y_pred, y_target).item()
290 | rmse = WRMSE(y_pred, y_target, self.data_std)
291 |
292 | total_loss += loss
293 | total_mae += mae
294 | total_mse += mse
295 | total_rmse += rmse
296 | if ((step + 1) % 100 == 0) | (step+1 == test_step):
297 |                     logger.info(f'Test step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]')
298 |
299 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
300 | logger.info(f'Average RMSE:[{total_rmse/test_step}]')
301 |
--------------------------------------------------------------------------------
/models/FNP.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils import clip_grad_norm_
7 | from functools import partial
8 | from einops import rearrange
9 | import utils.misc as utils
10 | from utils.metrics import WRMSE
11 | from utils.builder import get_optimizer, get_lr_scheduler
12 | import modules
13 |
14 |
15 | class Encoder(nn.Module):
16 |
17 | def __init__(
18 | self,
19 | n_channels=[4,13,13,13,13,13],
20 | r_dim=64,
21 | XEncoder=nn.Identity,
22 | Conv=lambda y_dim: modules.make_abs_conv(nn.Conv2d)(
23 | y_dim,
24 | y_dim,
25 | groups=y_dim,
26 | kernel_size=11,
27 | padding=11 // 2,
28 | bias=False,
29 | ),
30 | CNN=partial(
31 | modules.CNN,
32 | ConvBlock=modules.ResConvBlock,
33 | Conv=nn.Conv2d,
34 | n_blocks=12,
35 | Normalization=nn.BatchNorm2d,
36 | activation=nn.SiLU(),
37 | is_chan_last=True,
38 | kernel_size=9,
39 | n_conv_layers=2,
40 | )):
41 | super().__init__()
42 |
43 | self.r_dim = r_dim
44 | self.n_channels = n_channels
45 | self.x_encoder = XEncoder()
46 |
47 | # components for encode_globally
48 | self.conv = [Conv(y_dim) for y_dim in n_channels] # for each single channel
49 | self.conv.append(modules.make_abs_conv(nn.Conv2d)(
50 | in_channels=sum(n_channels),
51 | out_channels=sum(n_channels),
52 | groups=1,
53 | kernel_size=11,
54 | padding=11 // 2,
55 | bias=False,
56 | )) # for all channels
57 | self.conv = nn.ModuleList(self.conv)
58 |
59 |         self.resizer = [nn.Linear(y_dim * 2, self.r_dim) for y_dim in n_channels] # times 2 because the density/confidence channel is appended
60 | self.resizer.append(nn.Linear(sum(n_channels) * 2, self.r_dim))
61 | self.resizer = nn.ModuleList(self.resizer)
62 |
63 | self.induced_to_induced = nn.ModuleList([CNN(self.r_dim) for _ in range(len(n_channels)+1)])
64 |
65 | def forward(self, X_cntxt, Y_cntxt, X_trgt):
66 |
67 | X_cntxt = self.x_encoder(X_cntxt) # b,h,w,c
68 | X_trgt = self.x_encoder(X_trgt)
69 |
70 | # {R^u}_u
71 | # size = [batch_size, *n_rep, r_dim] for n_channels list
72 | R_trgt = self.encode_globally(X_cntxt, Y_cntxt)
73 |
74 | return R_trgt
75 |
76 | def cntxt_to_induced(self, mask_cntxt, X, index):
77 | """Infer the missing values and compute a density channel."""
78 |
79 | # channels have to be in second dimension for convolution
80 | # size = [batch_size, y_dim, *grid_shape]
81 | X = modules.channels_to_2nd_dim(X)
82 | # size = [batch_size, x_dim, *grid_shape]
83 | mask_cntxt = modules.channels_to_2nd_dim(mask_cntxt).float()
84 |
85 | # size = [batch_size, y_dim, *grid_shape]
86 | X_cntxt = X * mask_cntxt
87 | signal = self.conv[index](X_cntxt)
88 | density = self.conv[index](mask_cntxt.expand_as(X))
89 |
90 | # normalize
91 | out = signal / torch.clamp(density, min=1e-5)
92 |
93 | # size = [batch_size, y_dim * 2, *grid_shape]
94 | out = torch.cat([out, density], dim=1)
95 |
96 | # size = [batch_size, *grid_shape, y_dim * 2]
97 | out = modules.channels_to_last_dim(out)
98 |
99 | # size = [batch_size, *grid_shape, r_dim]
100 | out = self.resizer[index](out)
101 |
102 | return out
103 |
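    # cntxt_to_induced is a normalized ("set") convolution: the masked signal and
    # the context mask are passed through the same positive-weight convolution, and
    # their ratio gives a density-weighted interpolation of the observed values;
    # the density map itself is kept and concatenated as a confidence channel.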
104 | def encode_globally(self, mask_cntxt, X):
105 |
106 | # size = [batch_size, *grid_shape, r_dim] for each single channel
107 | R_induced_all = []
108 | for i in range(len(self.n_channels)):
109 | R_induced = self.cntxt_to_induced(mask_cntxt[...,sum(self.n_channels[:i]):sum(self.n_channels[:i+1])],
110 | X[...,sum(self.n_channels[:i]):sum(self.n_channels[:i+1])], i)
111 | R_induced = self.induced_to_induced[i](R_induced)
112 | R_induced_all.append(R_induced)
113 | # the last for all channels
114 | R_induced = self.cntxt_to_induced(mask_cntxt, X, len(self.n_channels))
115 | R_induced = self.induced_to_induced[len(self.n_channels)](R_induced)
116 | R_induced_all.append(R_induced)
117 |
118 | return R_induced_all
119 |
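    # encode_globally encodes each variable group in n_channels separately (by
    # default [4, 13, 13, 13, 13, 13], i.e. 69 channels in total) and adds one
    # joint encoding over all channels, returning len(n_channels) + 1 feature maps.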
120 |
121 | class FNP_model(nn.Module):
122 |
123 | def __init__(
124 | self,
125 | n_channels=[4,13,13,13,13,13],
126 | r_dim=128,
127 | use_nfl=True,
128 | use_dam=True,
129 | PredictiveDistribution=modules.MultivariateNormalDiag,
130 | p_y_loc_transformer=nn.Identity(),
131 | p_y_scale_transformer=lambda y_scale: 0.01 + 0.99 * F.softplus(y_scale),
132 | ):
133 | super().__init__()
134 |
135 | self.r_dim = r_dim
136 | self.y_dim = sum(n_channels)
137 | self.n_channels = n_channels
138 | self.use_dam = use_dam
139 |
140 | if use_nfl:
141 | EnCNN = partial(
142 | modules.FCNN,
143 | ConvBlock=modules.ResConvBlock,
144 | Conv=nn.Conv2d,
145 | n_blocks=4,
146 | Normalization=nn.BatchNorm2d,
147 | activation=nn.SiLU(),
148 | is_chan_last=True,
149 | kernel_size=9,
150 | n_conv_layers=2)
151 | else:
152 | EnCNN = partial(
153 | modules.CNN,
154 | ConvBlock=modules.ResConvBlock,
155 | Conv=nn.Conv2d,
156 | n_blocks=12,
157 | Normalization=nn.BatchNorm2d,
158 | activation=nn.SiLU(),
159 | is_chan_last=True,
160 | kernel_size=9,
161 | n_conv_layers=2)
162 |
163 | Decoder=modules.discard_ith_arg(partial(modules.MLP, n_hidden_layers=4, hidden_size=self.r_dim), i=0)
164 | self.obs_encoder = Encoder(n_channels=self.n_channels, r_dim=self.r_dim, CNN=EnCNN)
165 | self.back_encoder = Encoder(n_channels=self.n_channels, r_dim=self.r_dim, CNN=EnCNN)
166 | self.fusion = nn.ModuleList([nn.Linear(self.r_dim * 2, self.r_dim) for _ in range(len(n_channels)+1)])
167 |
168 |         # times 2 out because loc and scale (mean and std for a Gaussian)
169 | self.decoder = nn.ModuleList([Decoder(y_dim, self.r_dim * 2, y_dim * 2) for y_dim in n_channels])
170 | if self.use_dam:
171 | self.smooth = nn.ModuleList([nn.Conv2d(self.r_dim * 2, self.r_dim, 9, padding=4) for _ in range(len(n_channels)+1)])
172 |
173 | self.PredictiveDistribution = PredictiveDistribution
174 | self.p_y_loc_transformer = p_y_loc_transformer
175 | self.p_y_scale_transformer = p_y_scale_transformer
176 |
177 | self.reset_parameters()
178 |
179 | def reset_parameters(self):
180 | modules.weights_init(self)
181 |
182 | def forward(self, input_list):
183 |
184 | Xo_cntxt, Yo_cntxt, Xb_cntxt, Yb_cntxt, X_trgt = input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]
185 |
186 | # {R^u}_u
187 | # size = [batch_size, *n_rep, r_dim] for n_channels list
188 | Ro_trgt = self.obs_encoder(Xo_cntxt, Yo_cntxt, X_trgt)
189 | Rb_trgt = self.back_encoder(Xb_cntxt, Yb_cntxt, X_trgt)
190 |
191 | z_samples, q_zCc, q_zCct = None, None, None
192 |
193 | # interpolate
194 | Ro_trgt = [rearrange(Ro_trgt[i], 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)]
195 | Rb_trgt = [rearrange(Rb_trgt[i], 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)]
196 | Ro_trgt = [F.interpolate(Ro_trgt[i], size=Rb_trgt[i].shape[2:], mode='bilinear') for i in range(len(self.n_channels)+1)]
197 | Ro_trgt = [rearrange(Ro_trgt[i], 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)]
198 | Rb_trgt = [rearrange(Rb_trgt[i], 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)]
199 |
200 | # representation fusion
201 | R_fusion = [self.fusion[i](torch.cat([Ro_trgt[i], Rb_trgt[i]], dim=-1)) for i in range(len(self.n_channels)+1)]
202 | if self.use_dam:
203 | R_similar = [rearrange(self.similarity(R_fusion[i], Rb_trgt[i], Ro_trgt[i]), 'b h w c -> b c h w') for i in range(len(self.n_channels)+1)]
204 | R_fusion = [rearrange(self.smooth[i](R_similar[i]), 'b c h w -> b h w c') for i in range(len(self.n_channels)+1)]
205 |
206 | # size = [n_z_samples, batch_size, *n_trgt, r_dim]
207 | R_trgt = [self.trgt_dependent_representation(Xo_cntxt, Xb_cntxt, z_samples, R_fusion[i], X_trgt) for i in range(len(self.n_channels)+1)]
208 |
209 | # p(y|cntxt,trgt)
210 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
211 | p_yCc = self.decode(X_trgt, R_trgt, Yb_cntxt)
212 |
213 | return p_yCc, z_samples, q_zCc, q_zCct
214 |
215 | def similarity(self, R, Rb, Ro):
216 |
217 | distb = torch.sqrt(torch.sum((R-Rb)**2, dim=-1, keepdim=True))
218 | disto = torch.sqrt(torch.sum((R-Ro)**2, dim=-1, keepdim=True))
219 | mask = (disto > distb).float()
220 | R = torch.cat([Ro * mask + Rb * (1-mask), R], dim=-1)
221 |
222 | return R
223 |
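    # similarity compares the fused representation R with the background (Rb) and
    # observation (Ro) branches: per-location Euclidean distances decide which
    # branch to keep at each location, and the selected values are concatenated
    # with R (the doubled channel dimension is reduced again by `self.smooth`
    # in forward).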
224 | def trgt_dependent_representation(self, _, __, ___, R_induced, ____):
225 |
226 | # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim]
227 | return R_induced.unsqueeze(0)
228 |
229 | def decode(self, X_trgt, R_trgt, Yb_cntxt):
230 |
231 | locs = []
232 | scales = []
233 |
234 | for i in range(len(self.n_channels)):
235 | R_trgt_single = torch.cat([R_trgt[i], R_trgt[-1]], dim=-1)
236 |
237 | # size = [n_z_samples, batch_size, *n_trgt, y_dim*2]
238 | p_y_suffstat = self.decoder[i](X_trgt, R_trgt_single)
239 |
240 | # size = [n_z_samples, batch_size, *n_trgt, y_dim]
241 | p_y_loc, p_y_scale = p_y_suffstat.split(self.n_channels[i], dim=-1)
242 |
243 | p_y_loc = self.p_y_loc_transformer(p_y_loc)
244 | p_y_scale = self.p_y_scale_transformer(p_y_scale)
245 |
246 | locs.append(p_y_loc)
247 | scales.append(p_y_scale)
248 |
249 | locs = torch.cat(locs, dim=-1) + Yb_cntxt
250 | scales = torch.cat(scales, dim=-1)
251 | # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
252 | p_yCc = self.PredictiveDistribution(locs, scales)
253 |
254 | return p_yCc
255 |
256 |
257 | class FNP(object):
258 |
259 | def __init__(self, **model_params) -> None:
260 | super().__init__()
261 |
262 | params = model_params.get('params', {})
263 | criterion = model_params.get('criterion', 'CNPFLoss')
264 | self.optimizer_params = model_params.get('optimizer', {})
265 | self.scheduler_params = model_params.get('lr_scheduler', {})
266 |
267 | self.kernel = FNP_model(**params)
268 | self.best_loss = 9999
269 | self.criterion = self.get_criterion(criterion)
270 | self.criterion_mae = nn.L1Loss()
271 | self.criterion_mse = nn.MSELoss()
272 |
273 | if utils.is_dist_avail_and_initialized():
274 | self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
275 | if self.device == torch.device('cpu'):
276 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
277 | else:
278 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
279 |
280 | def get_criterion(self, loss_type):
281 | if loss_type == 'CNPFLoss':
282 | return modules.CNPFLoss()
283 | elif loss_type == 'NLLLossLNPF':
284 | return modules.NLLLossLNPF()
285 | elif loss_type == 'ELBOLossLNPF':
286 | return modules.ELBOLossLNPF()
287 | elif loss_type == 'SUMOLossLNPF':
288 | return modules.SUMOLossLNPF()
289 | else:
290 | raise NotImplementedError('Invalid loss type.')
291 |
292 | def process_data(self, batch_data, args):
293 |
294 | inp_data = torch.cat([batch_data[0], batch_data[1]], dim=1)
295 | inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
296 |         truth = batch_data[-1].to(self.device, non_blocking=True) # 69 channels
297 | truth = F.interpolate(truth, size=(args.resolution,args.resolution//2*4), mode='bilinear')
298 | truth_down = F.interpolate(truth, size=(128,256), mode='bilinear')
299 |
300 | for _ in range(args.lead_time // 6):
301 | predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth.shape[1]]
302 | inp_data = np.concatenate([inp_data[:,-truth.shape[1]:], predict_data], axis=1)
303 |
304 | xb_context = rearrange(torch.rand(predict_data.shape, device=self.device) >= 0, 'b c h w -> b h w c')
305 | x_context = rearrange(torch.rand(truth.shape, device=self.device) >= args.ratio, 'b c h w -> b h w c')
306 | x_target = rearrange(torch.rand(truth_down.shape, device=self.device) >= 0, 'b c h w -> b h w c')
307 | yb_context = rearrange(torch.from_numpy(predict_data).to(self.device, non_blocking=True), 'b c h w -> b h w c')
308 | y_context = rearrange(truth, 'b c h w -> b h w c')
309 | y_target = rearrange(truth_down, 'b c h w -> b h w c')
310 |
311 | return [x_context, y_context, xb_context, yb_context, x_target], y_target
312 |
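    # The boolean masks built above act as the x-inputs of the neural process:
    # `xb_context` and `x_target` are all True (the whole background field is
    # context and every 128x256 grid point is a target), while `x_context` keeps
    # roughly a fraction (1 - args.ratio) of the truth grid as simulated sparse
    # observations (e.g. args.ratio=0.9 leaves about 10% of the points).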
313 | def train(self, train_data_loader, valid_data_loader, logger, args):
314 |
315 | train_step = len(train_data_loader)
316 | valid_step = len(valid_data_loader)
317 | self.optimizer = get_optimizer(self.kernel, self.optimizer_params)
318 | self.scheduler = get_lr_scheduler(self.optimizer, self.scheduler_params, total_steps=train_step*args.max_epoch)
319 |
320 | for epoch in range(args.max_epoch):
321 | begin_time = time.time()
322 | self.kernel.train()
323 |
324 | for step, batch_data in enumerate(train_data_loader):
325 |
326 | input_list, y_target = self.process_data(batch_data[0], args)
327 | self.optimizer.zero_grad()
328 | y_pred = self.kernel(input_list)
329 | if isinstance(self.criterion, torch.nn.Module):
330 | self.criterion.train()
331 | loss = self.criterion(y_pred, y_target)
332 | loss.backward()
333 | clip_grad_norm_(self.kernel.parameters(), max_norm=1)
334 | self.optimizer.step()
335 | self.scheduler.step()
336 |
337 | if ((step + 1) % 100 == 0) | (step+1 == train_step):
338 | logger.info(f'Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.scheduler.get_last_lr()[0]}], loss:[{loss.item()}]')
339 |
340 | self.kernel.eval()
341 | with torch.no_grad():
342 | total_loss = 0
343 |
344 | for step, batch_data in enumerate(valid_data_loader):
345 | input_list, y_target = self.process_data(batch_data[0], args)
346 | y_pred = self.kernel(input_list)
347 | if isinstance(self.criterion, torch.nn.Module):
348 | self.criterion.eval()
349 | loss = self.criterion(y_pred, y_target).item()
350 | total_loss += loss
351 |
352 | if ((step + 1) % 100 == 0) | (step+1 == valid_step):
353 | logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}]')
354 |
355 | if (total_loss/valid_step) < self.best_loss:
356 | if utils.get_world_size() > 1 and utils.get_rank() == 0:
357 | torch.save(self.kernel.module.state_dict(), f'{args.rundir}/best_model.pth')
358 | elif utils.get_world_size() == 1:
359 | torch.save(self.kernel.state_dict(), f'{args.rundir}/best_model.pth')
360 | logger.info(f'New best model appears in epoch {epoch+1}.')
361 | self.best_loss = total_loss/valid_step
362 | logger.info(f'Epoch {epoch+1} average loss:[{total_loss/valid_step}], time:[{time.time()-begin_time}]')
363 |
364 | def test(self, test_data_loader, logger, args):
365 |
366 | test_step = len(test_data_loader)
367 | data_mean, data_std = test_data_loader.dataset.get_meanstd()
368 | self.data_std = data_std.to(self.device)
369 |
370 | self.kernel.eval()
371 | with torch.no_grad():
372 | total_loss = 0
373 | total_mae = 0
374 | total_mse = 0
375 | total_rmse = 0
376 |
377 | for step, batch_data in enumerate(test_data_loader):
378 |
379 | input_list, y_target = self.process_data(batch_data[0], args)
380 | y_pred = self.kernel(input_list)
381 | if isinstance(self.criterion, torch.nn.Module):
382 | self.criterion.eval()
383 | loss = self.criterion(y_pred, y_target).item()
384 |
385 | y_pred = rearrange(y_pred[0].mean[0], 'b h w c -> b c h w')
386 | y_target = rearrange(y_target, 'b h w c -> b c h w')
387 | mae = self.criterion_mae(y_pred, y_target).item()
388 | mse = self.criterion_mse(y_pred, y_target).item()
389 | rmse = WRMSE(y_pred, y_target, self.data_std)
390 |
391 | total_loss += loss
392 | total_mae += mae
393 | total_mse += mse
394 | total_rmse += rmse
395 | if ((step + 1) % 100 == 0) | (step+1 == test_step):
396 |                     logger.info(f'Test step:[{step+1}/{test_step}], loss:[{loss}], MAE:[{mae}], MSE:[{mse}]')
397 |
398 | logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
399 | logger.info(f'Average RMSE:[{total_rmse/test_step}]')
400 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
2 | from .cnn import *
3 | from .encoders import *
4 | from .helpers import *
5 | from .initialization import *
6 | from .mlp import *
7 | from .transformer import *
--------------------------------------------------------------------------------
/modules/cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from .initialization import init_param_, weights_init
5 | from .helpers import (
6 | channels_to_2nd_dim,
7 | channels_to_last_dim,
8 | make_depth_sep_conv,
9 | )
10 |
11 |
12 | __all__ = [
13 | "GaussianConv2d",
14 | "ConvBlock",
15 | "ResNormalizedConvBlock",
16 | "ResConvBlock",
17 | "CNN",
18 | "UnetCNN",
19 | "FCNN",
20 | ]
21 |
22 |
23 | class GaussianConv2d(nn.Module):
24 | def __init__(self, kernel_size=5, **kwargs):
25 | super().__init__()
26 | self.kwargs = kwargs
27 | assert kernel_size % 2 == 1
28 | self.kernel_sizes = (kernel_size, kernel_size)
29 | self.exponent = -(
30 | (torch.arange(0, kernel_size).view(-1, 1).float() - kernel_size // 2) ** 2
31 | )
32 |
33 | self.reset_parameters()
34 |
35 | def reset_parameters(self):
36 | self.weights_x = nn.Parameter(torch.tensor([1.0]))
37 | self.weights_y = nn.Parameter(torch.tensor([1.0]))
38 |
39 | def forward(self, X):
40 | # only switch first time to device
41 | self.exponent = self.exponent.to(X.device)
42 |
43 | marginal_x = torch.softmax(self.exponent * self.weights_x, dim=0)
44 | marginal_y = torch.softmax(self.exponent * self.weights_y, dim=0).T
45 |
46 | in_chan = X.size(1)
47 | filters = marginal_x @ marginal_y
48 | filters = filters.view(1, 1, *self.kernel_sizes).expand(
49 | in_chan, 1, *self.kernel_sizes
50 | )
51 |
52 | return F.conv2d(X, filters, groups=in_chan, **self.kwargs)
53 |
54 |
55 | class ConvBlock(nn.Module):
56 | """Simple convolutional block with a single layer.
57 |
58 | Parameters
59 | ----------
60 | in_chan : int
61 | Number of input channels.
62 |
63 | out_chan : int
64 | Number of output channels.
65 |
66 | Conv : nn.Module
67 |         Convolutional layer (uninitialized). E.g. `nn.Conv1d`.
68 |
69 | kernel_size : int or tuple, optional
70 | Size of the convolving kernel.
71 |
72 | dilation : int or tuple, optional
73 | Spacing between kernel elements.
74 |
75 | activation: callable, optional
76 | Activation object. E.g. `nn.ReLU`.
77 |
78 | Normalization : nn.Module, optional
79 |         Normalization layer (uninitialized). E.g. `nn.BatchNorm1d`.
80 |
81 | kwargs :
82 | Additional arguments to `Conv`.
83 |
84 | References
85 | ----------
86 | [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings
87 | in deep residual networks. In European conference on computer vision
88 | (pp. 630-645). Springer, Cham.
89 |
90 | [2] Chollet, F. (2017). Xception: Deep learning with depthwise separable
91 | convolutions. In Proceedings of the IEEE conference on computer vision
92 | and pattern recognition (pp. 1251-1258).
93 | """
94 |
95 | def __init__(
96 | self,
97 | in_chan,
98 | out_chan,
99 | Conv,
100 | kernel_size=5,
101 | dilation=1,
102 | activation=nn.ReLU(),
103 | Normalization=nn.Identity,
104 | **kwargs
105 | ):
106 | super().__init__()
107 | self.activation = activation
108 |
109 | padding = kernel_size // 2
110 |
111 | Conv = make_depth_sep_conv(Conv)
112 |
113 | self.conv = Conv(in_chan, out_chan, kernel_size, padding=padding, **kwargs)
114 | self.norm = Normalization(in_chan)
115 |
116 | self.reset_parameters()
117 |
118 | def reset_parameters(self):
119 | weights_init(self)
120 |
121 | def forward(self, X):
122 | return self.conv(self.activation(self.norm(X)))
123 |
124 |
125 | class ResConvBlock(nn.Module):
126 | """Convolutional block inspired by the pre-activation Resnet [1]
127 | and depthwise separable convolutions [2].
128 |
129 | Parameters
130 | ----------
131 | in_chan : int
132 | Number of input channels.
133 |
134 | out_chan : int
135 | Number of output channels.
136 |
137 | Conv : nn.Module
138 |         Convolutional layer (uninitialized). E.g. `nn.Conv1d`.
139 |
140 | kernel_size : int or tuple, optional
141 | Size of the convolving kernel. Should be odd to keep the same size.
142 |
143 | activation: callable, optional
144 |         Activation object. E.g. `nn.ReLU()`.
145 |
146 | Normalization : nn.Module, optional
147 |         Normalization layer (uninitialized). E.g. `nn.BatchNorm1d`.
148 |
149 | n_conv_layers : int, optional
150 | Number of convolutional layers, can be 1 or 2.
151 |
152 | is_bias : bool, optional
153 | Whether to use a bias.
154 |
155 | References
156 | ----------
157 | [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016, October). Identity mappings
158 | in deep residual networks. In European conference on computer vision
159 | (pp. 630-645). Springer, Cham.
160 |
161 | [2] Chollet, F. (2017). Xception: Deep learning with depthwise separable
162 | convolutions. In Proceedings of the IEEE conference on computer vision
163 | and pattern recognition (pp. 1251-1258).
164 | """
165 |
166 | def __init__(
167 | self,
168 | in_chan,
169 | out_chan,
170 | Conv,
171 | kernel_size=5,
172 | activation=nn.ReLU(),
173 | Normalization=nn.Identity,
174 | is_bias=True,
175 | n_conv_layers=1,
176 | ):
177 | super().__init__()
178 | self.activation = activation
179 | self.n_conv_layers = n_conv_layers
180 | assert self.n_conv_layers in [1, 2]
181 |
182 | if kernel_size % 2 == 0:
183 | raise ValueError("`kernel_size={}`, but should be odd.".format(kernel_size))
184 |
185 | padding = kernel_size // 2
186 |
187 | if self.n_conv_layers == 2:
188 | self.norm1 = Normalization(in_chan)
189 | self.conv1 = make_depth_sep_conv(Conv)(
190 | in_chan, in_chan, kernel_size, padding=padding, bias=is_bias
191 | )
192 | self.norm2 = Normalization(in_chan)
193 | self.conv2_depthwise = Conv(
194 | in_chan, in_chan, kernel_size, padding=padding, groups=in_chan, bias=is_bias
195 | )
196 | self.conv2_pointwise = Conv(in_chan, out_chan, 1, bias=is_bias)
197 |
198 | self.reset_parameters()
199 |
200 | def reset_parameters(self):
201 | weights_init(self)
202 |
203 | def forward(self, X):
204 |
205 | if self.n_conv_layers == 2:
206 | out = self.conv1(self.activation(self.norm1(X)))
207 | else:
208 | out = X
209 |
210 | out = self.conv2_depthwise(self.activation(self.norm2(out)))
211 | # adds residual before point wise => output can change number of channels
212 | out = out + X
213 | out = self.conv2_pointwise(out.contiguous()) # for some reason need contiguous
214 | return out
215 |
216 |
217 | class ResNormalizedConvBlock(ResConvBlock):
218 | """Modification of `ResConvBlock` to use normalized convolutions [1].
219 |
220 | Parameters
221 | ----------
222 | in_chan : int
223 | Number of input channels.
224 |
225 | out_chan : int
226 | Number of output channels.
227 |
228 | Conv : nn.Module
229 |         Convolutional layer (uninitialized). E.g. `nn.Conv1d`.
230 |
231 | kernel_size : int or tuple, optional
232 | Size of the convolving kernel. Should be odd to keep the same size.
233 |
234 | activation: nn.Module, optional
235 |         Activation object. E.g. `nn.ReLU()`.
236 |
237 | is_bias : bool, optional
238 | Whether to use a bias.
239 |
240 | References
241 | ----------
242 | [1] Knutsson, H., & Westin, C. F. (1993, June). Normalized and differential
243 | convolution. In Proceedings of IEEE Conference on Computer Vision and
244 | Pattern Recognition (pp. 515-523). IEEE.
245 | """
246 |
247 | def __init__(
248 | self,
249 | in_chan,
250 | out_chan,
251 | Conv,
252 | kernel_size=5,
253 | activation=nn.ReLU(),
254 | is_bias=True,
255 | **kwargs
256 | ):
257 | super().__init__(
258 | in_chan,
259 | out_chan,
260 | Conv,
261 | kernel_size=kernel_size,
262 | activation=activation,
263 | is_bias=is_bias,
264 | Normalization=nn.Identity,
265 | **kwargs
266 | ) # make sure no normalization
267 |
268 | def reset_parameters(self):
269 | weights_init(self)
270 | self.bias = nn.Parameter(torch.tensor([0.0]))
271 |
272 | self.temperature = nn.Parameter(torch.tensor([0.0]))
273 | init_param_(self.temperature)
274 |
275 | def forward(self, X):
276 | """
277 | Apply a normalized convolution. X should contain 2*in_chan channels.
278 |         First half is the signal, second half the corresponding confidence channels.
279 | """
280 |
281 | signal, conf_1 = X.chunk(2, dim=1)
282 |         # make sure confidence is in [0, 1] (might not be due to the pointwise transform)
283 | conf_1 = conf_1.clamp(min=0, max=1)
284 | X = signal * conf_1
285 |
286 | numerator = self.conv1(self.activation(X))
287 | numerator = self.conv2_depthwise(self.activation(numerator))
288 | density = self.conv2_depthwise(self.conv1(conf_1))
289 | out = numerator / torch.clamp(density, min=1e-5)
290 |
291 | # adds residual before point wise => output can change number of channels
292 |
293 | # make sure that confidence cannot decrease and cannot be greater than 1
294 | conf_2 = conf_1 + torch.sigmoid(
295 | density * F.softplus(self.temperature) + self.bias
296 | )
297 | conf_2 = conf_2.clamp(max=1)
298 | out = out + X
299 |
300 | out = self.conv2_pointwise(out)
301 | conf_2 = self.conv2_pointwise(conf_2)
302 |
303 | return torch.cat([out, conf_2], dim=1)
304 |
305 |
306 |
307 | class SpectralConv2d_fast(nn.Module):
308 | def __init__(self, in_channels, out_channels, modes):
309 | super().__init__()
310 |
311 | """
312 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT.
313 | """
314 | self.in_channels = in_channels
315 | self.out_channels = out_channels
316 | self.modes = modes #Number of Fourier modes to multiply, at most floor(N/2) + 1
317 |
318 | self.scale = (1 / (in_channels * out_channels))
319 | self.weights1 = nn.Parameter(self.scale * torch.rand(2, in_channels, out_channels, self.modes, self.modes))
320 | self.weights2 = nn.Parameter(self.scale * torch.rand(2, in_channels, out_channels, self.modes, self.modes))
321 |
322 | def compl_mul2d(self, input, weights):
323 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
324 | return torch.einsum("bixy,ioxy->boxy", input, weights)
325 |
326 | def forward(self, x):
327 | batchsize, dtype = x.shape[0], x.dtype
328 |         # Compute Fourier coefficients (up to a constant factor)
329 | x_ft = torch.fft.rfft2(x)
330 |
331 | # Multiply relevant Fourier modes
332 | out_ft_real = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, device=x.device)
333 | out_ft_imag = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, device=x.device)
334 |
335 | out_ft_real[:, :, :self.modes, :self.modes] = \
336 | self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].real, self.weights1[0]) - self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].imag, self.weights1[1])
337 | out_ft_real[:, :, -self.modes:, :self.modes] = \
338 | self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].real, self.weights2[0]) - self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].imag, self.weights2[1])
339 |
340 | out_ft_imag[:, :, :self.modes, :self.modes] = \
341 | self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].imag, self.weights1[0]) + self.compl_mul2d(x_ft[:, :, :self.modes, :self.modes].real, self.weights1[1])
342 | out_ft_imag[:, :, -self.modes:, :self.modes] = \
343 | self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].imag, self.weights2[0]) + self.compl_mul2d(x_ft[:, :, -self.modes:, :self.modes].real, self.weights2[1])
344 |
345 | # Return to physical space
346 | out_ft = torch.stack([out_ft_real, out_ft_imag], dim=-1)
347 | out_ft = torch.view_as_complex(out_ft)
348 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
349 | x = x.type(dtype)
350 |
351 | return x
352 |
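# Illustrative shapes for SpectralConv2d_fast, assuming modes=12 and a 128x256 grid:
#   x    -> [B, in_channels, 128, 256]
#   x_ft -> [B, in_channels, 128, 129]   (rfft2 keeps W//2 + 1 width frequencies)
# Only the lowest `modes` width frequencies combined with the lowest/highest `modes`
# height frequencies are mixed by the learned weights; every other mode is left at
# zero before irfft2 maps the result back to [B, out_channels, 128, 256].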
353 |
354 |
355 | class CNN(nn.Module):
356 | """Simple multilayer CNN.
357 |
358 | Parameters
359 | ----------
360 | n_channels : int or list
361 | Number of channels, same for input and output. If list then needs to be
362 | of size `n_blocks - 1`, e.g. [16, 32, 64] means that you will have a
363 |         of size `n_blocks + 1`, e.g. [16, 32, 64] means that you will have a
364 |
365 | ConvBlock : nn.Module
366 | Convolutional block (unitialized). Needs to take as input `Should be
367 |         Convolutional block (uninitialized). Should be initialized with
368 |         `ConvBlock(in_chan, out_chan)`.
369 | n_blocks : int, optional
370 | Number of convolutional blocks.
371 |
372 | is_chan_last : bool, optional
373 | Whether the channels are on the last dimension of the input.
374 |
375 | kwargs :
376 | Additional arguments to `ConvBlock`.
377 | """
378 |
379 | def __init__(self, n_channels, ConvBlock, n_blocks=3, is_chan_last=False, **kwargs):
380 |
381 | super().__init__()
382 | self.n_blocks = n_blocks
383 | self.is_chan_last = is_chan_last
384 | self.in_out_channels = self._get_in_out_channels(n_channels, n_blocks)
385 | self.conv_blocks = nn.ModuleList(
386 | [
387 | ConvBlock(in_chan, out_chan, **kwargs)
388 | for in_chan, out_chan in self.in_out_channels
389 | ]
390 | )
391 | self.is_return_rep = False # never return representation for vanilla conv
392 |
393 | self.reset_parameters()
394 |
395 | def reset_parameters(self):
396 | weights_init(self)
397 |
398 | def _get_in_out_channels(self, n_channels, n_blocks):
399 | """Return a list of tuple of input and output channels."""
400 | if isinstance(n_channels, int):
401 | channel_list = [n_channels] * (n_blocks + 1)
402 | else:
403 | channel_list = list(n_channels)
404 |
405 | assert len(channel_list) == (n_blocks + 1), "{} != {}".format(
406 | len(channel_list), n_blocks + 1
407 | )
408 |
409 | return list(zip(channel_list, channel_list[1:]))
410 |
411 | def forward(self, X):
412 | if self.is_chan_last:
413 | X = channels_to_2nd_dim(X)
414 |
415 | X, representation = self.apply_convs(X)
416 |
417 | if self.is_chan_last:
418 | X = channels_to_last_dim(X)
419 |
420 | if self.is_return_rep:
421 | return X, representation
422 |
423 | return X
424 |
425 | def apply_convs(self, X):
426 | for conv_block in self.conv_blocks:
427 | X = conv_block(X)
428 | return X, None
429 |
430 |
431 | class FCNN(nn.Module):
432 |     """Multilayer CNN with an additional Fourier (spectral convolution) branch in each block.
433 |
434 | Parameters
435 | ----------
436 | n_channels : int or list
437 | Number of channels, same for input and output. If list then needs to be
438 | of size `n_blocks - 1`, e.g. [16, 32, 64] means that you will have a
439 |         of size `n_blocks + 1`, e.g. [16, 32, 64] means that you will have a
440 |
441 | ConvBlock : nn.Module
442 | Convolutional block (unitialized). Needs to take as input `Should be
443 |         Convolutional block (uninitialized). Should be initialized with
444 |         `ConvBlock(in_chan, out_chan)`.
445 | n_blocks : int, optional
446 | Number of convolutional blocks.
447 |
448 | is_chan_last : bool, optional
449 | Whether the channels are on the last dimension of the input.
450 |
451 | kwargs :
452 | Additional arguments to `ConvBlock`.
453 | """
454 |
455 | def __init__(self, n_channels, ConvBlock, n_blocks=3, is_chan_last=False, **kwargs):
456 |
457 | super().__init__()
458 | self.n_blocks = n_blocks
459 | self.is_chan_last = is_chan_last
460 | self.in_out_channels = self._get_in_out_channels(n_channels, n_blocks)
461 | self.conv_blocks = nn.ModuleList(
462 | [
463 | ConvBlock(in_chan, out_chan, **kwargs)
464 | for in_chan, out_chan in self.in_out_channels
465 | ]
466 | )
467 | self.fno_blocks = nn.ModuleList(
468 | [
469 | SpectralConv2d_fast(in_chan, out_chan, modes=12)
470 | for in_chan, out_chan in self.in_out_channels
471 | ]
472 | )
473 | self.is_return_rep = False # never return representation for vanilla conv
474 |
475 | self.reset_parameters()
476 |
477 | def reset_parameters(self):
478 | weights_init(self)
479 |
480 | def _get_in_out_channels(self, n_channels, n_blocks):
481 | """Return a list of tuple of input and output channels."""
482 | if isinstance(n_channels, int):
483 | channel_list = [n_channels] * (n_blocks + 1)
484 | else:
485 | channel_list = list(n_channels)
486 |
487 | assert len(channel_list) == (n_blocks + 1), "{} != {}".format(
488 | len(channel_list), n_blocks + 1
489 | )
490 |
491 | return list(zip(channel_list, channel_list[1:]))
492 |
493 | def forward(self, X):
494 | if self.is_chan_last:
495 | X = channels_to_2nd_dim(X)
496 |
497 | X, representation = self.apply_convs(X)
498 |
499 | if self.is_chan_last:
500 | X = channels_to_last_dim(X)
501 |
502 | if self.is_return_rep:
503 | return X, representation
504 |
505 | return X
506 |
507 | def apply_convs(self, X):
508 | for i in range(self.n_blocks):
509 | # for conv_block in self.conv_blocks:
510 | X = self.conv_blocks[i](X) + self.fno_blocks[i](X) + X
511 | return X, None
512 |
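    # Each block above computes X + conv_block(X) + fno_block(X): a residual sum of
    # a local convolutional branch and a global spectral (Fourier) branch, which is
    # what distinguishes FCNN from the plain CNN defined earlier in this file.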
513 |
514 | class UnetCNN(CNN):
515 | """Unet [1].
516 |
517 | Parameters
518 | ----------
519 | n_channels : int or list
520 | Number of channels, same for input and output. If list then needs to be
521 |         of size `n_blocks + 1`, e.g. [16, 32, 64] means that you will have a
522 | `[ConvBlock(16,32), ConvBlock(32, 64)]`.
523 |
524 | ConvBlock : nn.Module
525 |         Convolutional block (uninitialized). Should be initialized with
526 |         `ConvBlock(in_chan, out_chan)`.
527 |
528 | Pool : nn.Module
529 | Pooling layer (unitialized). E.g. torch.nn.MaxPool1d.
530 |
531 |     upsample_mode : {'nearest', 'linear', 'bilinear', 'bicubic', 'trilinear'}
532 | The upsampling algorithm: nearest, linear (1D-only), bilinear, bicubic
533 | (2D-only), trilinear (3D-only).
534 |
535 | max_nchannels : int, optional
536 | Bounds the maximum number of channels instead of always doubling them at
537 | downsampling block.
538 |
539 | pooling_size : int or tuple, optional
540 | Size of the pooling filter.
541 |
542 | is_force_same_bottleneck : bool, optional
543 | Whether to use the average bottleneck for the same functions sampled at
544 | different context and target. If `True` the first and second halves
545 | of a batch should contain different samples of the same functions (in order).
546 |
547 | is_return_rep : bool, optional
548 | Whether to return a summary representation, that corresponds to the
549 | bottleneck + global mean pooling.
550 |
551 | kwargs :
552 | Additional arguments to `CNN` and `ConvBlock`.
553 |
554 | References
555 | ----------
556 | [1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional
557 | networks for biomedical image segmentation." International Conference on
558 | Medical image computing and computer-assisted intervention. Springer, Cham, 2015.
559 | """
560 |
561 | def __init__(
562 | self,
563 | n_channels,
564 | ConvBlock,
565 | Pool,
566 | upsample_mode,
567 | max_nchannels=256,
568 | pooling_size=2,
569 | is_force_same_bottleneck=False,
570 | is_return_rep=False,
571 | **kwargs
572 | ):
573 |
574 | self.max_nchannels = max_nchannels
575 | super().__init__(n_channels, ConvBlock, **kwargs)
576 | self.pooling_size = pooling_size
577 | self.pooling = Pool(self.pooling_size)
578 | self.upsample_mode = upsample_mode
579 | self.is_force_same_bottleneck = is_force_same_bottleneck
580 | self.is_return_rep = is_return_rep
581 |
582 | def apply_convs(self, X):
583 | n_down_blocks = self.n_blocks // 2
584 | residuals = [None] * n_down_blocks
585 |
586 | # Down
587 | for i in range(n_down_blocks):
588 | X = self.conv_blocks[i](X)
589 | residuals[i] = X
590 | X = self.pooling(X)
591 |
592 | # Bottleneck
593 | X = self.conv_blocks[n_down_blocks](X)
594 | # Representation before forcing same bottleneck
595 | representation = X.view(*X.shape[:2], -1).mean(-1)
596 |
597 | if self.is_force_same_bottleneck and self.training:
598 | # forces the u-net to use the bottleneck by giving additional information
599 |             # there. I.e. taking the average of the bottlenecks of different samples
600 | # of the same functions. Because bottleneck should be a global representation
601 | # => should not depend on the sample you chose
602 | batch_size = X.size(0)
603 | batch_1 = X[: batch_size // 2, ...]
604 | batch_2 = X[batch_size // 2 :, ...]
605 | X_mean = (batch_1 + batch_2) / 2
606 | X = torch.cat([X_mean, X_mean], dim=0)
607 |
608 | # Up
609 | for i in range(n_down_blocks + 1, self.n_blocks):
610 | X = F.interpolate(
611 | X,
612 | mode=self.upsample_mode,
613 | scale_factor=self.pooling_size,
614 | align_corners=True,
615 | )
616 | X = torch.cat(
617 | (X, residuals[n_down_blocks - i]), dim=1
618 | ) # concat on channels
619 | X = self.conv_blocks[i](X)
620 |
621 | return X, representation
622 |
623 | def _get_in_out_channels(self, n_channels, n_blocks):
624 | """Return a list of tuple of input and output channels for a Unet."""
625 | # doubles at every down layer, as in vanilla U-net
626 | factor_chan = 2
627 |
628 | assert n_blocks % 2 == 1, "n_blocks={} not odd".format(n_blocks)
629 | # e.g. if n_channels=16, n_blocks=5: [16, 32, 64]
630 | channel_list = [factor_chan ** i * n_channels for i in range(n_blocks // 2 + 1)]
631 | # e.g.: [16, 32, 64, 64, 32, 16]
632 | channel_list = channel_list + channel_list[::-1]
633 | # bound max number of channels by self.max_nchannels (besides first and
634 |         # last dim as this is input / output and should not be changed)
635 | channel_list = (
636 | channel_list[:1]
637 | + [min(c, self.max_nchannels) for c in channel_list[1:-1]]
638 | + channel_list[-1:]
639 | )
640 | # e.g.: [(16, 32), (32,64), (64, 64), (64, 32), (32, 16)]
641 | in_out_channels = super()._get_in_out_channels(channel_list, n_blocks)
642 | # e.g.: [(16, 32), (32,64), (64, 64), (128, 32), (64, 16)] due to concat
643 | idcs = slice(len(in_out_channels) // 2 + 1, len(in_out_channels))
644 | in_out_channels[idcs] = [
645 | (in_chan * 2, out_chan) for in_chan, out_chan in in_out_channels[idcs]
646 | ]
647 | return in_out_channels
648 |
649 |
650 |
651 |
--------------------------------------------------------------------------------
/modules/encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .initialization import weights_init
4 | from .mlp import MLP
5 |
6 | __all__ = [
7 | "RelativeSinusoidalEncodings",
8 | "SinusoidalEncodings",
9 | "merge_flat_input",
10 | "discard_ith_arg",
11 | ]
12 |
13 |
14 | class SinusoidalEncodings(nn.Module):
15 | """
16 | Converts a batch of N-dimensional spatial input X with values between `[-1,1]`
17 |     to a batch of flat vectors split into N subvectors that encode the position via
18 | sinusoidal encodings.
19 |
20 | Parameters
21 | ----------
22 | x_dim : int
23 | Number of spatial inputs.
24 |
25 | out_dim : int
26 | size of output encoding. Each x_dim will have an encoding of size
27 | `out_dim//x_dim`.
28 | """
29 |
30 | def __init__(self, x_dim, out_dim):
31 | super().__init__()
32 | self.x_dim = x_dim
33 |         # dimension of encoding for each x dimension
34 | self.sub_dim = out_dim // self.x_dim
35 | # in "attention is all you need" used 10000 but 512 dim, try to keep the
36 | # same ratio regardless of dim
37 | self._C = 10000 * (self.sub_dim / 512) ** 2
38 |
39 | if out_dim % x_dim != 0:
40 | raise ValueError(
41 |                 "out_dim={} has to be divisible by x_dim={}.".format(out_dim, x_dim)
42 | )
43 | if self.sub_dim % 2 != 0:
44 | raise ValueError(
45 |                 "sub_dim=out_dim/x_dim={} has to be divisible by 2.".format(
46 | self.sub_dim
47 | )
48 | )
49 |
50 | self._precompute_denom()
51 |
52 | def _precompute_denom(self):
53 | two_i_d = torch.arange(0, self.sub_dim, 2, dtype=torch.float) / self.sub_dim
54 | denom = torch.pow(self._C, two_i_d)
55 | denom = torch.repeat_interleave(denom, 2).unsqueeze(0)
56 | self.denom = denom.expand(1, self.x_dim, self.sub_dim)
57 |
58 | def forward(self, x):
59 | shape = x.shape
60 | # flatten besides last dim
61 | x = x.view(-1, shape[-1])
62 | # will only be passed once to GPU because precomputed
63 | self.denom = self.denom.to(x.device)
64 | # put x in a range which is similar to positions in NLP [1,51]
65 | x = (x.unsqueeze(-1) + 1) * 25 + 1
66 | out = x / self.denom
67 | out[:, :, 0::2] = torch.sin(out[:, :, 0::2])
68 | out[:, :, 1::2] = torch.cos(out[:, :, 1::2])
69 | # concatenate all different sinusoidal encodings for each x_dim
70 | # and unflatten
71 | out = out.view(*shape[:-1], self.sub_dim * self.x_dim)
72 | return out
73 |
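# Illustrative usage of SinusoidalEncodings (hypothetical values):
#   enc = SinusoidalEncodings(x_dim=2, out_dim=64)
#   out = enc(torch.rand(8, 100, 2) * 2 - 1)   # coordinates in [-1, 1] -> [8, 100, 64]
# Each of the 2 coordinates gets its own 32-dimensional interleaved sin/cos encoding.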
74 |
75 | class RelativeSinusoidalEncodings(nn.Module):
76 | """Return relative positions of inputs between [-1,1]."""
77 |
78 | def __init__(self, x_dim, out_dim, window_size=2):
79 | super().__init__()
80 | self.pos_encoder = SinusoidalEncodings(x_dim, out_dim)
81 | self.weight = nn.Linear(out_dim, out_dim, bias=False)
82 | self.window_size = window_size
83 | self.out_dim = out_dim
84 |
85 | def forward(self, keys_pos, queries_pos):
86 | # size=[batch_size, n_queries, n_keys, x_dim]
87 | diff = (keys_pos.unsqueeze(1) - queries_pos.unsqueeze(2)).abs()
88 |
89 |         # the abs differences will be between 0 and self.window_size
90 |         # we multiply by 2/self.window_size then subtract 1 to be in [-1,1], which is
91 | # the range for `SinusoidalEncodings`
92 | scaled_diff = diff * 2 / self.window_size - 1
93 | out = self.weight(self.pos_encoder(scaled_diff))
94 |
95 |         # set to 0 points that are further than the window (for extrapolation)
96 | out = out * (diff < self.window_size).float()
97 |
98 | return out
99 |
100 |
101 | # META ENCODERS
102 | class DiscardIthArg(nn.Module):
103 | """
104 |     Helper module which discards the i^th argument of the constructor and forward,
105 | before being given to `To`.
106 | """
107 |
108 | def __init__(self, *args, i=0, To=nn.Identity, **kwargs):
109 | super().__init__()
110 | self.i = i
111 | self.destination = To(*self.filter_args(*args), **kwargs)
112 |
113 | def filter_args(self, *args):
114 | return [arg for i, arg in enumerate(args) if i != self.i]
115 |
116 | def forward(self, *args, **kwargs):
117 | return self.destination(*self.filter_args(*args), **kwargs)
118 |
119 |
120 | def discard_ith_arg(module, i, **kwargs):
121 | def discarded_arg(*args, **kwargs2):
122 | return DiscardIthArg(*args, i=i, To=module, **kwargs, **kwargs2)
123 |
124 | return discarded_arg
125 |
126 |
127 | class MergeFlatInputs(nn.Module):
128 | """
129 |     Extend a module to take 2 flat inputs. It simply passes
130 |     the concatenation of both flat inputs to the module `module({x1; x2})`.
131 |
132 | Parameters
133 | ----------
134 | FlatModule: nn.Module
135 |         Module which takes flat inputs.
136 |
137 | x1_dim: int
138 | Dimensionality of the first flat inputs.
139 |
140 | x2_dim: int
141 | Dimensionality of the second flat inputs.
142 |
143 | n_out: int
144 |         Size of output.
145 |
146 | is_sum_merge : bool, optional
147 | Whether to transform `flat_input` by an MLP first (if need to resize),
148 | then sum to `X` (instead of concatenating): useful if the difference in
149 | dimension between both inputs is very large => don't want one layer to
150 |         depend only on a few dimensions of a large input.
151 |
152 | kwargs:
153 | Additional arguments to FlatModule.
154 | """
155 |
156 | def __init__(self, FlatModule, x1_dim, x2_dim, n_out, is_sum_merge=False, **kwargs):
157 | super().__init__()
158 | self.is_sum_merge = is_sum_merge
159 |
160 | if self.is_sum_merge:
161 | dim = x1_dim
162 | self.resizer = MLP(x2_dim, dim) # transform to be the correct size
163 | else:
164 | dim = x1_dim + x2_dim
165 |
166 | self.flat_module = FlatModule(dim, n_out, **kwargs)
167 | self.reset_parameters()
168 |
169 | def reset_parameters(self):
170 | weights_init(self)
171 |
172 | def forward(self, x1, x2):
173 | if self.is_sum_merge:
174 | x2 = self.resizer(x2)
175 |             # use an activation, otherwise 2 linear layers in a row => redundant computation
176 | out = torch.relu(x1 + x2)
177 | else:
178 | out = torch.cat((x1, x2), dim=-1)
179 |
180 | return self.flat_module(out)
181 |
182 |
183 | def merge_flat_input(module, is_sum_merge=False, **kwargs):
184 | """
185 | Extend a module to accept an additional flat input. I.e. the output should
186 | be called by `merge_flat_input(module)(x_shape, flat_dim, n_out, **kwargs)`.
187 |
188 | Notes
189 | -----
190 | - if x_shape is an integer (currently only available option), it simply returns
191 | the concatenated flat inputs to the module `module({x; flat_input})`.
192 | - if `is_sum_merge` then transform `flat_input` by an MLP first, then sum
193 | to `X` (instead of concatenating): useful if the difference in dimension
194 | between both inputs is very large => don't want one layer to depend only on
195 |     a few dimensions of a large input.
196 | """
197 |
198 | def merged_flat_input(x_shape, flat_dim, n_out, **kwargs2):
199 | assert isinstance(x_shape, int)
200 | return MergeFlatInputs(
201 | module,
202 | x_shape,
203 | flat_dim,
204 | n_out,
205 | is_sum_merge=is_sum_merge,
206 | **kwargs2,
207 | **kwargs
208 | )
209 |
210 | return merged_flat_input
211 |
--------------------------------------------------------------------------------
/modules/helpers.py:
--------------------------------------------------------------------------------
1 | import operator
2 | from functools import reduce
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from scipy.stats import rv_discrete
8 | from torch.distributions import Normal
9 | from torch.distributions.independent import Independent
10 | from .initialization import weights_init
11 |
12 |
13 | def sum_from_nth_dim(t, dim):
14 |     """Sum all dims from `dim`. E.g. sum_from_nth_dim(torch.rand(2,3,4,5), 2).shape == (2, 3)"""
15 | return t.view(*t.shape[:dim], -1).sum(-1)
16 |
17 |
18 | def logcumsumexp(x, dim):
19 |     """Numerically stable log cumsum exp. Slow workaround waiting for https://github.com/pytorch/pytorch/pull/36308"""
20 |
21 | if (dim != -1) or (dim != x.ndimension() - 1):
22 | x = x.transpose(dim, -1)
23 |
24 | out = []
25 | for i in range(1, x.size(-1) + 1):
26 | out.append(torch.logsumexp(x[..., :i], dim=-1, keepdim=True))
27 | out = torch.cat(out, dim=-1)
28 |
29 | if (dim != -1) or (dim != x.ndimension() - 1):
30 | out = out.transpose(-1, dim)
31 | return out
32 |
33 |
34 | class LightTailPareto(rv_discrete):
35 | def _cdf(self, k, alpha):
36 | # alpha is factor like in SUMO paper
37 | # m is minimum number of samples
38 | m = self.a # lower bound of support
39 |
40 |         # in the paper they use P(K >= k) but cdf is P(K <= k) = 1 - P(K > k) = 1 - P(K >= k + 1)
41 | k = k + 1
42 |
43 | # make sure has at least m samples
44 | k = np.clip(k - m, a_min=1, a_max=None) # makes sure no division by 0
45 | alpha = alpha - m
46 |
47 | # sample using pmf 1/k but with finite expectation
48 | cdf = 1 - np.where(k < alpha, 1 / k, (1 / alpha) * (0.9) ** (k - alpha))
49 |
50 | return cdf
51 |
52 |
53 | def isin_range(x, valid_range):
54 | """Check if array / tensor is in a given range elementwise."""
55 | return ((x >= valid_range[0]) & (x <= valid_range[1])).all()
56 |
57 |
58 | def channels_to_2nd_dim(X):
59 | """
60 | Takes a signal with channels on the last dimension (for most operations) and
61 | returns it with channels on the second dimension (for convolutions).
62 | """
63 | return X.permute(*([0, X.dim() - 1] + list(range(1, X.dim() - 1)))).contiguous()
64 |
65 |
66 | def channels_to_last_dim(X):
67 | """
68 | Takes a signal with channels on the second dimension (for convolutions) and
69 | returns it with channels on the last dimension (for most operations).
70 | """
71 | return X.permute(*([0] + list(range(2, X.dim())) + [1])).contiguous()
72 |
73 |
74 | def mask_and_apply(x, mask, f):
75 |     """Applies a callable on a masked version of an input."""
76 |     transformed_selected = f(x.masked_select(mask))
77 |     return x.masked_scatter(mask, transformed_selected)
78 |
79 |
80 | def indep_shuffle_(a, axis=-1):
81 | """
82 | Shuffle `a` in-place along the given axis.
83 |
84 | Apply `numpy.random.shuffle` to the given axis of `a`.
85 | Each one-dimensional slice is shuffled independently.
86 |
87 | Credits : https://github.com/numpy/numpy/issues/5173
88 | """
89 | b = a.swapaxes(axis, -1)
90 | # Shuffle `b` in-place along the last axis. `b` is a view of `a`,
91 | # so `a` is shuffled in place, too.
92 | shp = b.shape[:-1]
93 | for ndx in np.ndindex(shp):
94 | np.random.shuffle(b[ndx])
95 |
96 |
97 | def ratio_to_int(percentage, max_val):
98 |     """Converts a ratio in [0, 1) to an integer out of `max_val`; values in [1, max_val] are passed through as ints."""
99 | if 1 <= percentage <= max_val:
100 | out = percentage
101 | elif 0 <= percentage < 1:
102 | out = percentage * max_val
103 | else:
104 | raise ValueError("percentage={} outside of [0,{}].".format(percentage, max_val))
105 |
106 | return int(out)
107 |
108 |
109 | def prod(iterable):
110 | """Compute the product of all elements in an iterable."""
111 | return reduce(operator.mul, iterable, 1)
112 |
113 |
114 | def rescale_range(X, old_range, new_range):
115 | """Rescale X linearly to be in `new_range` rather than `old_range`."""
116 | old_min = old_range[0]
117 | new_min = new_range[0]
118 | old_delta = old_range[1] - old_min
119 | new_delta = new_range[1] - new_min
120 | return (((X - old_min) * new_delta) / old_delta) + new_min
121 |
122 |
123 | def MultivariateNormalDiag(loc, scale_diag):
124 |     """Multivariate Gaussian with a diagonal covariance matrix (over the last dimension)."""
125 | if loc.dim() < 1:
126 | raise ValueError("loc must be at least one-dimensional.")
127 | return Independent(Normal(loc, scale_diag), 1)
128 |
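# For example (illustrative), loc and scale_diag of shape [1, B, H, W, C] give an
# Independent Normal with batch_shape [1, B, H, W] and event_shape [C], so that
# log_prob sums over the channel dimension -- the "event shape=[y_dim]" convention
# referenced by the decoders in models/FNP.py.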
129 |
130 | def clamp(
131 | x,
132 | minimum=-float("Inf"),
133 | maximum=float("Inf"),
134 | is_leaky=False,
135 | negative_slope=0.01,
136 | hard_min=None,
137 | hard_max=None,
138 | ):
139 | """
140 | Clamps a tensor to the given [minimum, maximum] (leaky) bound, with
141 | an optional hard clamping.
142 | """
143 | lower_bound = (
144 | (minimum + negative_slope * (x - minimum))
145 | if is_leaky
146 | else torch.zeros_like(x) + minimum
147 | )
148 | upper_bound = (
149 | (maximum + negative_slope * (x - maximum))
150 | if is_leaky
151 | else torch.zeros_like(x) + maximum
152 | )
153 | clamped = torch.max(lower_bound, torch.min(x, upper_bound))
154 |
155 | if hard_min is not None or hard_max is not None:
156 | if hard_min is None:
157 | hard_min = -float("Inf")
158 | elif hard_max is None:
159 | hard_max = float("Inf")
160 | clamped = clamp(x, minimum=hard_min, maximum=hard_max, is_leaky=False)
161 |
162 | return clamped
163 |
164 |
165 | class ProbabilityConverter(nn.Module):
166 |     """Maps floats to probabilities (between 0 and 1), element-wise.
167 |
168 | Parameters
169 | ----------
170 | min_p : float, optional
171 | Minimum probability, can be useful to set greater than 0 in order to keep
172 | gradient flowing if the probability is used for convex combinations of
173 | different parts of the model. Note that maximum probability is `1-min_p`.
174 |
175 | activation : {"sigmoid", "hard-sigmoid", "leaky-hard-sigmoid"}, optional
176 | name of the activation to use to generate the probabilities. `sigmoid`
177 | has the advantage of being smooth and never exactly 0 or 1, which helps
178 | gradient flows. `hard-sigmoid` has the advantage of making all values
179 | between min_p and max_p equiprobable.
180 |
181 | is_train_temperature : bool, optional
182 |         Whether to train the parameter controlling the steepness of the activation.
183 |         This is useful when x is used for multiple tasks, and you don't want to
184 |         constrain its magnitude.
185 |
186 | is_train_bias : bool, optional
187 | Whether to train the bias to shift the activation. This is useful when x is
188 |         used for multiple tasks, and you don't want to constrain its scale.
189 |
190 | trainable_dim : int, optional
191 |         Size of the trainable bias and temperature. If `1`, uses the same value
192 |         across all dimensions; otherwise it should be equal to the number of input
193 |         dimensions to get different trainable parameters for each dimension. Note
194 |         that the initial value will still be the same for all dimensions.
195 |
196 | initial_temperature : int, optional
197 |         Initial temperature, a higher temperature makes the activation steeper.
198 |
199 | initial_probability : float, optional
200 | Initial probability you want to start with.
201 |
202 | initial_x : float, optional
203 | First value that will be given to the function, important to make
204 | `initial_probability` work correctly.
205 |
206 | bias_transformer : callable, optional
207 | Transformer function of the bias. This function should only take care of
208 | the boundaries (e.g. leaky relu or relu).
209 |
210 | temperature_transformer : callable, optional
211 | Transformer function of the temperature. This function should only take
212 | care of the boundaries (e.g. leaky relu or relu).
213 | """
214 |
215 | def __init__(
216 | self,
217 | min_p=0.0,
218 | activation="sigmoid",
219 | is_train_temperature=False,
220 | is_train_bias=False,
221 | trainable_dim=1,
222 | initial_temperature=1.0,
223 | initial_probability=0.5,
224 | initial_x=0,
225 | bias_transformer=nn.Identity(),
226 | temperature_transformer=nn.Identity(),
227 | ):
228 |
229 | super().__init__()
230 | self.min_p = min_p
231 | self.activation = activation
232 | self.is_train_temperature = is_train_temperature
233 | self.is_train_bias = is_train_bias
234 | self.trainable_dim = trainable_dim
235 | self.initial_temperature = initial_temperature
236 | self.initial_probability = initial_probability
237 | self.initial_x = initial_x
238 | self.bias_transformer = bias_transformer
239 | self.temperature_transformer = temperature_transformer
240 |
241 | self.reset_parameters()
242 |
243 | def reset_parameters(self):
244 | self.temperature = torch.tensor([self.initial_temperature] * self.trainable_dim)
245 | if self.is_train_temperature:
246 | self.temperature = nn.Parameter(self.temperature)
247 |
248 | initial_bias = self._probability_to_bias(
249 | self.initial_probability, initial_x=self.initial_x
250 | )
251 |
252 | self.bias = torch.tensor([initial_bias] * self.trainable_dim)
253 | if self.is_train_bias:
254 | self.bias = nn.Parameter(self.bias)
255 |
256 | def forward(self, x):
257 |         self.temperature = self.temperature.to(x.device)
258 |         self.bias = self.bias.to(x.device)
259 |
260 | temperature = self.temperature_transformer(self.temperature)
261 | bias = self.bias_transformer(self.bias)
262 |
263 | if self.activation == "sigmoid":
264 | full_p = torch.sigmoid((x + bias) * temperature)
265 |
266 | elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]:
267 | # uses 0.2 and 0.5 to be similar to sigmoid
268 | y = 0.2 * ((x + bias) * temperature) + 0.5
269 |
270 | if self.activation == "leaky-hard-sigmoid":
271 | full_p = clamp(
272 | y,
273 | minimum=0.1,
274 | maximum=0.9,
275 | is_leaky=True,
276 | negative_slope=0.01,
277 | hard_min=0,
278 | hard_max=0,
279 | )
280 | elif self.activation == "hard-sigmoid":
281 | full_p = clamp(y, minimum=0.0, maximum=1.0, is_leaky=False)
282 |
283 | else:
284 | raise ValueError("Unkown activation : {}".format(self.activation))
285 |
286 | p = rescale_range(full_p, (0, 1), (self.min_p, 1 - self.min_p))
287 |
288 | return p
289 |
290 | def _probability_to_bias(self, p, initial_x=0):
291 | """Compute the bias to use to satisfy the constraints."""
292 | assert p > self.min_p and p < 1 - self.min_p
293 | range_p = 1 - self.min_p * 2
294 | p = (p - self.min_p) / range_p
295 | p = torch.tensor(p, dtype=torch.float)
296 |
297 | if self.activation == "sigmoid":
298 | bias = -(torch.log((1 - p) / p) / self.initial_temperature + initial_x)
299 |
300 | elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]:
301 | bias = ((p - 0.5) / 0.2) / self.initial_temperature - initial_x
302 |
303 | return bias
304 |
305 |
306 | def dist_to_device(dist, device):
307 | """Set a distirbution to a given device."""
308 | if dist is None:
309 | return
310 | dist.base_dist.loc = dist.base_dist.loc.to(device)
311 | dist.base_dist.scale = dist.base_dist.scale.to(device)
312 |
313 |
314 | def make_abs_conv(Conv):
315 | """Make a convolution have only positive parameters."""
316 |
317 | class AbsConv(Conv):
318 | def forward(self, input):
319 | return F.conv2d(
320 | input,
321 | self.weight.abs(),
322 | self.bias,
323 | self.stride,
324 | self.padding,
325 | self.dilation,
326 | self.groups,
327 | )
328 |
329 | return AbsConv
330 |
331 |
332 | def make_padded_conv(Conv, Padder):
333 | """Make a convolution have any possible padding."""
334 |
335 | class PaddedConv(Conv):
336 | def __init__(self, *args, Padder=Padder, padding=0, **kwargs):
337 | old_padding = 0
338 | if Padder is None:
339 | Padder = nn.Identity
340 | old_padding = padding
341 |
342 | super().__init__(*args, padding=old_padding, **kwargs)
343 | self.padder = Padder(padding)
344 |
345 | def forward(self, X):
346 | X = self.padder(X)
347 | return super().forward(X)
348 |
349 | return PaddedConv
350 |
351 |
352 | def make_depth_sep_conv(Conv):
353 | """Make a convolution module depth separable."""
354 |
355 | class DepthSepConv(nn.Module):
356 | """Make a convolution depth separable.
357 |
358 | Parameters
359 | ----------
360 | in_channels : int
361 | Number of input channels.
362 |
363 | out_channels : int
364 | Number of output channels.
365 |
366 | kernel_size : int
367 |
368 | **kwargs :
369 | Additional arguments to `Conv`
370 | """
371 |
372 | def __init__(
373 | self,
374 | in_channels,
375 | out_channels,
376 | kernel_size,
377 | confidence=False,
378 | bias=True,
379 | **kwargs
380 | ):
381 | super().__init__()
382 | self.depthwise = Conv(
383 | in_channels,
384 | in_channels,
385 | kernel_size,
386 | groups=in_channels,
387 | bias=bias,
388 | **kwargs
389 | )
390 | self.pointwise = Conv(in_channels, out_channels, 1, bias=bias)
391 | self.reset_parameters()
392 |
393 | def forward(self, x):
394 | out = self.depthwise(x)
395 | out = self.pointwise(out)
396 | return out
397 |
398 | def reset_parameters(self):
399 | weights_init(self)
400 |
401 | return DepthSepConv
402 |
403 |
404 | class CircularPad2d(nn.Module):
405 | """Implements a 2d circular padding."""
406 |
407 | def __init__(self, padding):
408 | super().__init__()
409 | self.padding = padding
410 |
411 | def forward(self, x):
412 | return F.pad(x, (self.padding,) * 4, mode="circular")
413 |
414 |
415 | class BackwardPDB(torch.autograd.Function):
416 | """Run PDB in the backward pass."""
417 |
418 | @staticmethod
419 | def forward(ctx, input, name="debugger"):
420 | ctx.name = name
421 | ctx.save_for_backward(input)
422 | return input
423 |
424 | @staticmethod
425 | def backward(ctx, grad_output):
426 | (input,) = ctx.saved_tensors
427 | if not torch.isfinite(grad_output).all() or not torch.isfinite(input).all():
428 | import pdb
429 |
430 | pdb.set_trace()
431 | return grad_output, None # 2 args so return None for `name`
432 |
433 |
434 | backward_pdb = BackwardPDB.apply
435 |
--------------------------------------------------------------------------------
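As a quick illustration of how the convolution factories above compose (a sketch for orientation, not code from the repository; it assumes the helpers are importable as `modules.helpers` with the repo root on `PYTHONPATH`):

```
import torch
from torch import nn
from modules.helpers import CircularPad2d, make_depth_sep_conv, make_padded_conv

# A depth-separable Conv2d whose padding uses circular (wrap-around) boundaries.
PaddedConv2d = make_padded_conv(nn.Conv2d, CircularPad2d)
DepthSepConv2d = make_depth_sep_conv(PaddedConv2d)

conv = DepthSepConv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(2, 8, 32, 64)   # (batch, channels, lat, lon)
y = conv(x)                     # spatial size preserved: (2, 16, 32, 64)
```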
/modules/initialization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | __all__ = ["weights_init"]
5 |
6 |
7 | def weights_init(module, **kwargs):
8 | """Initialize a module and all its descendents.
9 |
10 | Parameters
11 | ----------
12 | module : nn.Module
13 | module to initialize.
14 | """
15 | module.is_resetted = True
16 | for m in module.modules():
17 | try:
18 | if hasattr(module, "reset_parameters") and module.is_resetted:
19 | # don't reset if resetted already (might want special)
20 | continue
21 | except AttributeError:
22 | pass
23 |
24 | if isinstance(m, torch.nn.modules.conv._ConvNd):
25 | # used in https://github.com/brain-research/realistic-ssl-evaluation/
26 | nn.init.kaiming_normal_(m.weight, mode="fan_out", **kwargs)
27 | elif isinstance(m, nn.Linear):
28 | linear_init(m, **kwargs)
29 | elif isinstance(m, nn.BatchNorm2d):
30 | m.weight.data.fill_(1)
31 | m.bias.data.zero_()
32 |
33 |
34 | def get_activation_name(activation):
35 | """Given a string or a `torch.nn.modules.activation` return the name of the activation."""
36 | if isinstance(activation, str):
37 | return activation
38 |
39 | mapper = {
40 | nn.LeakyReLU: "leaky_relu",
41 | nn.ReLU: "relu",
42 | nn.Tanh: "tanh",
43 | nn.Sigmoid: "sigmoid",
44 | nn.Softmax: "sigmoid",
45 | }
46 | for k, v in mapper.items():
47 | if isinstance(activation, k):
48 | return v
49 |
50 | raise ValueError("Unknown activation type : {}".format(activation))
51 |
52 |
53 | def get_gain(activation):
54 | """Given an object of `torch.nn.modules.activation` or an activation name
55 | return the correct gain."""
56 | if activation is None:
57 | return 1
58 |
59 | activation_name = get_activation_name(activation)
60 |
61 | param = None if activation_name != "leaky_relu" else activation.negative_slope
62 | gain = nn.init.calculate_gain(activation_name, param)
63 |
64 | return gain
65 |
66 |
67 | def linear_init(module, activation="relu"):
68 | """Initialize a linear layer.
69 |
70 | Parameters
71 | ----------
72 | module : nn.Module
73 | module to initialize.
74 |
75 | activation : `torch.nn.modules.activation` or str, optional
76 | Activation that will be used on the `module`.
77 | """
78 | x = module.weight
79 |
80 | if module.bias is not None:
81 | module.bias.data.zero_()
82 |
83 | if activation is None:
84 | return nn.init.xavier_uniform_(x)
85 |
86 | activation_name = get_activation_name(activation)
87 |
88 | if activation_name == "leaky_relu":
89 | a = 0 if isinstance(activation, str) else activation.negative_slope
90 | return nn.init.kaiming_uniform_(x, a=a, nonlinearity="leaky_relu")
91 | elif activation_name == "relu":
92 | return nn.init.kaiming_uniform_(x, nonlinearity="relu")
93 | elif activation_name in ["sigmoid", "tanh"]:
94 | return nn.init.xavier_uniform_(x, gain=get_gain(activation))
95 |
96 |
97 | def init_param_(param, activation=None, is_positive=False, bound=0.05, shift=0):
98 | """Initialize inplace some parameters of the model that are not part of a
99 | children module.
100 |
101 | Parameters
102 | ----------
103 | param : nn.Parameter:
104 | Parameters to initialize.
105 |
106 | activation : torch.nn.modules.activation or str, optional
107 | Activation that will be used on the `param`.
108 |
109 | is_positive : bool, optional
110 | Whether to initialize only with positive values.
111 |
112 | bound : float, optional
113 | Maximum absolute value of the initialized values. By default `0.05`, which
114 | is the Keras default uniform bound.
115 |
116 | shift : int, optional
117 | Shift the initialisation by a certain value (same as adding a value after init).
118 | """
119 | gain = get_gain(activation)
120 | if is_positive:
121 | nn.init.uniform_(param, 1e-5 + shift, bound * gain + shift)
122 | return
123 |
124 | nn.init.uniform_(param, -bound * gain + shift, bound * gain + shift)
125 |
--------------------------------------------------------------------------------
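For orientation, a minimal sketch of how these initializers can be applied (illustrative only; it assumes `modules.initialization` is importable):

```
import torch.nn as nn
from modules.initialization import linear_init, weights_init

# weights_init walks all sub-modules: convolutions get Kaiming-normal init,
# linear layers go through linear_init, and BatchNorm is set to weight=1, bias=0.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
weights_init(model)

# linear_init chooses the scheme from the activation that follows the layer.
head = nn.Linear(16, 4)
linear_init(head, activation=nn.Tanh())   # Xavier-uniform with the tanh gain
```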
/modules/losses.py:
--------------------------------------------------------------------------------
1 | """Module for all the loss of Neural Process Family."""
2 | import abc
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from .helpers import (
7 | LightTailPareto,
8 | dist_to_device,
9 | logcumsumexp,
10 | sum_from_nth_dim,
11 | )
12 | from torch.distributions.kl import kl_divergence
13 |
14 |
15 | __all__ = ["CNPFLoss", "ELBOLossLNPF", "SUMOLossLNPF", "NLLLossLNPF"]
16 |
17 |
18 | def sum_log_prob(prob, sample):
19 | """Compute log probability then sum all but the z_samples and batch."""
20 | # size = [n_z_samples, batch_size, *]
21 | log_p = prob.log_prob(sample)
22 | # size = [n_z_samples, batch_size]
23 | sum_log_p = sum_from_nth_dim(log_p, 2)
24 | return sum_log_p
25 |
26 |
27 | class BaseLossNPF(nn.Module, abc.ABC):
28 | """
29 | Compute the negative log likelihood loss for members of the conditional neural process (sub-)family.
30 |
31 | Parameters
32 | ----------
33 | reduction : {None,"mean","sum"}, optional
34 | Batch wise reduction.
35 |
36 | is_force_mle_eval : bool, optional
37 | Whether to force maximum-likelihood evaluation even if it has access to q_zCct.
38 | """
39 |
40 | def __init__(self, reduction="mean", is_force_mle_eval=True):
41 | super().__init__()
42 | self.reduction = reduction
43 | self.is_force_mle_eval = is_force_mle_eval
44 |
45 | def forward(self, pred_outputs, Y_trgt):
46 | """Compute the Neural Process Loss.
47 |
48 | Parameters
49 | ----------
50 | pred_outputs : tuple
51 | Output of `NeuralProcessFamily`.
52 |
53 | Y_trgt : torch.Tensor, size=[batch_size, *n_trgt, y_dim]
54 | Set of all target values {y_t}.
55 |
56 | Return
57 | ------
58 | loss : torch.Tensor
59 | size=[batch_size] if `reduction=None` else [1].
60 | """
61 | p_yCc, z_samples, q_zCc, q_zCct = pred_outputs
62 |
63 | if self.training:
64 | loss = self.get_loss(p_yCc, z_samples, q_zCc, q_zCct, Y_trgt)
65 | else:
66 | # always uses NPML for evaluation
67 | if self.is_force_mle_eval:
68 | q_zCct = None
69 | loss = NLLLossLNPF.get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt)
70 |
71 | if self.reduction is None:
72 | # size = [batch_size]
73 | return loss
74 | elif self.reduction == "mean":
75 | # size = [1]
76 | return loss.mean(0)
77 | elif self.reduction == "sum":
78 | # size = [1]
79 | return loss.sum(0)
80 | else:
81 | raise ValueError(f"Unknown {self.reduction}")
82 |
83 | @abc.abstractmethod
84 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt):
85 | """Compute the Neural Process Loss
86 |
87 | Parameters
88 | ------
89 | p_yCc: torch.distributions.Distribution, batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
90 | Posterior distribution for target values {p(Y^t|y_c; x_c, x_t)}_t
91 |
92 | z_samples: torch.Tensor, size=[n_z_samples, batch_size, *n_lat, z_dim]
93 | Sampled latents. `None` if `encoded_path==deterministic`.
94 |
95 | q_zCc: torch.distributions.Distribution, batch shape=[batch_size, *n_lat] ; event shape=[z_dim]
96 | Latent distribution for the context points. `None` if `encoded_path==deterministic`.
97 |
98 | q_zCct: torch.distributions.Distribution, batch shape=[batch_size, *n_lat] ; event shape=[z_dim]
99 | Latent distribution for the targets. `None` if `encoded_path==deterministic`
100 | or not training or not `is_q_zCct`.
101 |
102 | Y_trgt: torch.Tensor, size=[batch_size, *n_trgt, y_dim]
103 | Set of all target values {y_t}.
104 |
105 | Return
106 | ------
107 | loss : torch.Tensor, size=[1].
108 | """
109 | pass
110 |
111 |
112 | class CNPFLoss(BaseLossNPF):
113 | """Losss for conditional neural process (suf-)family [1]."""
114 |
115 | def get_loss(self, p_yCc, _, q_zCc, ___, Y_trgt):
116 | assert q_zCc is None
117 | # \sum_t log p(y^t|z)
118 | # \sum_t log p(y^t|z). size = [z_samples, batch_size]
119 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt)
120 |
121 | # size = [batch_size]
122 | nll = -sum_log_p_yCz.squeeze(0)
123 | return nll
124 |
125 |
126 | class ELBOLossLNPF(BaseLossNPF):
127 | """Approximate conditional ELBO [1].
128 |
129 | References
130 | ----------
131 | [1] Garnelo, Marta, et al. "Neural processes." arXiv preprint
132 | arXiv:1807.01622 (2018).
133 | """
134 |
135 | def get_loss(self, p_yCc, _, q_zCc, q_zCct, Y_trgt):
136 |
137 | # first term in loss is E_{q(z|y_cntxt,y_trgt)}[\sum_t log p(y^t|z)]
138 | # \sum_t log p(y^t|z). size = [z_samples, batch_size]
139 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt)
140 |
141 | # E_{q(z|y_cntxt,y_trgt)}[...] . size = [batch_size]
142 | E_z_sum_log_p_yCz = sum_log_p_yCz.mean(0)
143 |
144 | # second term in loss is \sum_l KL[q(z^l|y_cntxt,y_trgt)||q(z^l|y_cntxt)]
145 | # KL[q(z^l|y_cntxt,y_trgt)||q(z^l|y_cntxt)]. size = [batch_size, *n_lat]
146 | kl_z = kl_divergence(q_zCct, q_zCc)
147 | # \sum_l ... . size = [batch_size]
148 | E_z_kl = sum_from_nth_dim(kl_z, 1)
149 |
150 | return -(E_z_sum_log_p_yCz - E_z_kl)
151 |
152 |
153 | class NLLLossLNPF(BaseLossNPF):
154 | """
155 | Compute the approximate negative log likelihood for Neural Process family[?].
156 |
157 | Notes
158 | -----
159 | - might be high variance
160 | - biased
161 | - approximate because expectation over q(z|cntxt) instead of p(z|cntxt)
162 | - approximate because expectation over q(z|cntxt) instead of p(z|cntxt)
163 | - if q_zCct is not None then uses importance sampling (i.e. assumes that z was sampled from it).
163 |
164 | References
165 | ----------
166 | [?]
167 | """
168 |
169 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt):
170 |
171 | n_z_samples, batch_size, *n_trgt = p_yCc.batch_shape
172 |
173 | # computes approximate LL in a numerically stable way
174 | # LL = E_{q(z|y_cntxt)}[ \prod_t p(y^t|z)]
175 | # LL MC = log ( mean_z ( \prod_t p(y^t|z)) )
176 | # = log [ sum_z ( \prod_t p(y^t|z)) ] - log(n_z_samples)
177 | # = log [ sum_z ( exp \sum_t log p(y^t|z)) ] - log(n_z_samples)
178 | # = log_sum_exp_z ( \sum_t log p(y^t|z)) - log(n_z_samples)
179 |
180 | # \sum_t log p(y^t|z). size = [n_z_samples, batch_size]
181 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt)
182 |
183 | # uses importance sampling weights if necessary
184 | if q_zCct is not None:
185 |
186 | # All latents are treated as independent. size = [n_z_samples, batch_size]
187 | sum_log_q_zCc = sum_log_prob(q_zCc, z_samples)
188 | sum_log_q_zCct = sum_log_prob(q_zCct, z_samples)
189 |
190 | # importance sampling : multiply \prod_t p(y^t|z)) by q(z|y_cntxt) / q(z|y_cntxt, y_trgt)
191 | # i.e. add log q(z|y_cntxt) - log q(z|y_cntxt, y_trgt)
192 | sum_log_w_k = sum_log_p_yCz + sum_log_q_zCc - sum_log_q_zCct
193 | else:
194 | sum_log_w_k = sum_log_p_yCz
195 |
196 | # log_sum_exp_z ... . size = [batch_size]
197 | log_S_z_sum_p_yCz = torch.logsumexp(sum_log_w_k, 0)
198 |
199 | # - log(n_z_samples)
200 | log_E_z_sum_p_yCz = log_S_z_sum_p_yCz - math.log(n_z_samples)
201 |
202 | # NEGATIVE log likelihood
203 | return -log_E_z_sum_p_yCz
204 |
205 |
206 | #! might need gradient clipping as in their paper
207 | class SUMOLossLNPF(BaseLossNPF):
208 | """
209 | Estimate negative log likelihood for Neural Process family using SUMO [1].
210 |
211 | Notes
212 | -----
213 | - approximate because expectation over q(z|cntxt) instead of p(z|cntxt)
214 | - if q_zCct is not None then uses importance sampling (i.e. assumes that z was sampled from it).
215 |
216 | Parameters
217 | ----------
218 | p_n_z_samples : scipy.stats.rv_frozen, optional
219 | Distribution for the number of z_samples to take.
220 |
221 | References
222 | ----------
223 | [1] Luo, Yucen, et al. "SUMO: Unbiased Estimation of Log Marginal Probability for Latent
224 | Variable Models." arXiv preprint arXiv:2004.00353 (2020)
225 | """
226 |
227 | def __init__(
228 | self,
229 | p_n_z_samples=LightTailPareto(a=5).freeze(85),
230 | **kwargs,
231 | ):
232 | super().__init__(**kwargs)
233 | self.p_n_z_samples = p_n_z_samples
234 |
235 | def get_loss(self, p_yCc, z_samples, q_zCc, q_zCct, Y_trgt):
236 |
237 | n_z_samples, batch_size, *n_trgt = p_yCc.batch_shape
238 |
239 | # \sum_t log p(y^t|z). size = [n_z_samples, batch_size]
240 | sum_log_p_yCz = sum_log_prob(p_yCc, Y_trgt)
241 |
242 | # uses importance sampling weights if necessary
243 | if q_zCct is not None:
244 | # All latents are treated as independent. size = [n_z_samples, batch_size]
245 | sum_log_q_zCc = sum_log_prob(q_zCc, z_samples)
246 | sum_log_q_zCct = sum_log_prob(q_zCct, z_samples)
247 |
248 | #! It should be p(y^t,z|cntxt) but we are using q(z|cntxt) instead of p(z|cntxt)
249 | # \sum_t log (q(y^t,z|cntxt) / q(z|cntxt,trgt)) . size = [n_z_samples, batch_size]
250 | sum_log_w_k = sum_log_p_yCz + sum_log_q_zCc - sum_log_q_zCct
251 | else:
252 | sum_log_w_k = sum_log_p_yCz
253 |
254 | # size = [n_z_samples, 1]
255 | ks = (torch.arange(n_z_samples) + 1).unsqueeze(-1)
256 | #! slow to always put on GPU
257 | log_ks = ks.float().log().to(sum_log_w_k.device)
258 |
259 | #! the algorithm in the paper is not correct on ks[:k+1] and forgot inv_weights[m:]
260 | # size = [n_z_samples, batch_size]
261 | cum_iwae = logcumsumexp(sum_log_w_k, 0) - log_ks
262 |
263 | #! slow to always put on GPU
264 | # you want reverse_cdf which is P(K >= k ) = 1 - P(K < k) = 1 - P(K <= k-1) = 1 - CDF(k-1)
265 | inv_weights = torch.from_numpy(1 - self.p_n_z_samples.cdf(ks - 1)).to(
266 | sum_log_w_k.device
267 | )
268 |
269 | m = self.p_n_z_samples.support()[0]
270 | # size = [batch_size]
271 | sumo = cum_iwae[m - 1] + (
272 | inv_weights[m:] * (cum_iwae[m:] - cum_iwae[m - 1 : -1])
273 | ).sum(0)
274 |
275 | nll = -sumo
276 | return nll
277 |
--------------------------------------------------------------------------------
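To make the expected tensor shapes concrete, here is a small self-contained sketch of calling `CNPFLoss` on a dummy predictive distribution (illustrative shapes only, not the repository's training loop; it assumes `modules.losses` is importable):

```
import torch
from torch.distributions import Independent, Normal
from modules.losses import CNPFLoss

n_z_samples, batch_size, n_trgt, y_dim = 1, 2, 5, 3

# p(y^t|z): batch shape [n_z_samples, batch_size, n_trgt], event shape [y_dim],
# matching the convention documented in BaseLossNPF.get_loss.
loc = torch.randn(n_z_samples, batch_size, n_trgt, y_dim)
scale = 0.1 + torch.rand_like(loc)
p_yCc = Independent(Normal(loc, scale), 1)

Y_trgt = torch.randn(batch_size, n_trgt, y_dim)

# Conditional (deterministic) NPs carry no latent, so the remaining entries are None.
loss_fn = CNPFLoss(reduction="mean")
loss = loss_fn((p_yCc, None, None, None), Y_trgt)   # scalar negative log likelihood
```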
/modules/mlp.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch.nn as nn
3 | from .initialization import linear_init
4 |
5 |
6 | __all__ = ["MLP"]
7 |
8 |
9 | class MLP(nn.Module):
10 | """General MLP class.
11 |
12 | Parameters
13 | ----------
14 | input_size: int
15 |
16 | output_size: int
17 |
18 | hidden_size: int, optional
19 | Number of hidden neurons.
20 |
21 | n_hidden_layers: int, optional
22 | Number of hidden layers.
23 |
24 | activation: callable, optional
25 | Activation function, e.g. `nn.ReLU()`.
26 |
27 | is_bias: bool, optional
28 | Whether to use biases in the hidden layers.
29 |
30 | dropout: float, optional
31 | Dropout rate.
32 |
33 | is_force_hid_smaller : bool, optional
34 | Whether to force the hidden size to be at most max(input_size, output_size).
35 | If not, the hidden size is only forced to be at least min(input_size, output_size).
36 |
37 | is_res : bool, optional
38 | Whether to use residual connections.
39 | """
40 |
41 | def __init__(
42 | self,
43 | input_size,
44 | output_size,
45 | hidden_size=32,
46 | n_hidden_layers=1,
47 | activation=nn.ReLU(),
48 | is_bias=True,
49 | dropout=0,
50 | is_force_hid_smaller=False,
51 | is_res=False,
52 | ):
53 | super().__init__()
54 |
55 | self.input_size = input_size
56 | self.output_size = output_size
57 | self.hidden_size = hidden_size
58 | self.n_hidden_layers = n_hidden_layers
59 | self.is_res = is_res
60 |
61 | if is_force_hid_smaller and self.hidden_size > max(
62 | self.output_size, self.input_size
63 | ):
64 | self.hidden_size = max(self.output_size, self.input_size)
65 | txt = "hidden_size={} larger than output={} and input={}. Setting it to {}."
66 | warnings.warn(
67 | txt.format(hidden_size, output_size, input_size, self.hidden_size)
68 | )
69 | elif self.hidden_size < min(self.output_size, self.input_size):
70 | self.hidden_size = min(self.output_size, self.input_size)
71 | txt = (
72 | "hidden_size={} smaller than output={} and input={}. Setting it to {}."
73 | )
74 | warnings.warn(
75 | txt.format(hidden_size, output_size, input_size, self.hidden_size)
76 | )
77 |
78 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
79 | self.activation = activation
80 |
81 | self.to_hidden = nn.Linear(self.input_size, self.hidden_size, bias=is_bias)
82 | self.linears = nn.ModuleList(
83 | [
84 | nn.Linear(self.hidden_size, self.hidden_size, bias=is_bias)
85 | for _ in range(self.n_hidden_layers - 1)
86 | ]
87 | )
88 | self.out = nn.Linear(self.hidden_size, self.output_size, bias=is_bias)
89 |
90 | self.reset_parameters()
91 |
92 | def forward(self, x):
93 | out = self.to_hidden(x)
94 | out = self.activation(out)
95 | x = self.dropout(out)
96 |
97 | for linear in self.linears:
98 | out = linear(x)
99 | out = self.activation(out)
100 | if self.is_res:
101 | out = out + x
102 | out = self.dropout(out)
103 | x = out
104 |
105 | out = self.out(x)
106 | return out
107 |
108 | def reset_parameters(self):
109 | linear_init(self.to_hidden, activation=self.activation)
110 | for lin in self.linears:
111 | linear_init(lin, activation=self.activation)
112 | linear_init(self.out)
113 |
--------------------------------------------------------------------------------
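A minimal usage sketch of the `MLP` block above (illustrative, not from the repository):

```
import torch
import torch.nn as nn
from modules.mlp import MLP

# Residual MLP 16 -> 64 -> 64 -> 64 -> 8 with ReLU and 10% dropout.
mlp = MLP(
    input_size=16,
    output_size=8,
    hidden_size=64,
    n_hidden_layers=3,
    activation=nn.ReLU(),
    dropout=0.1,
    is_res=True,
)
x = torch.randn(32, 16)
y = mlp(x)   # shape (32, 8)
```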
/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from timm.models.layers import DropPath, trunc_normal_
6 | import torch.utils.checkpoint as checkpoint
7 | from functools import reduce, lru_cache
8 | from operator import mul
9 | from einops import rearrange
10 |
11 |
12 | __all__ = ["AllPatchEmbed", "PatchRecover", "BasicLayer", "SwinTransformerLayer"]
13 |
14 |
15 | class Mlp(nn.Module):
16 | """ Multilayer perceptron."""
17 |
18 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19 | super().__init__()
20 | out_features = out_features or in_features
21 | hidden_features = hidden_features or in_features
22 | self.fc1 = nn.Linear(in_features, hidden_features)
23 | self.act = act_layer()
24 | self.fc2 = nn.Linear(hidden_features, out_features)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 |
36 | def swin_window_partition(x, window_size):
37 | """
38 | Args:
39 | x: (B, D, H, W, C)
40 | window_size (tuple[int]): window size
41 |
42 | Returns:
43 | windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
44 | """
45 | B, D, H, W, C = x.shape
46 | x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
47 | windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
48 | return windows
49 |
50 |
51 | def swin_window_reverse(windows, window_size, B, D, H, W):
52 | """
53 | Args:
54 | windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
55 | window_size (tuple[int]): Window size
56 | H (int): Height of image
57 | W (int): Width of image
58 |
59 | Returns:
60 | x: (B, D, H, W, C)
61 | """
62 | x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
63 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
64 | return x
65 |
66 |
67 | def get_window_size(x_size, window_size, shift_size=None):
68 | use_window_size = list(window_size)
69 | if shift_size is not None:
70 | use_shift_size = list(shift_size)
71 | for i in range(len(x_size)):
72 | if x_size[i] <= window_size[i]:
73 | use_window_size[i] = x_size[i]
74 | if shift_size is not None:
75 | use_shift_size[i] = 0
76 |
77 | if shift_size is None:
78 | return tuple(use_window_size)
79 | else:
80 | return tuple(use_window_size), tuple(use_shift_size)
81 |
82 |
83 | class WindowAttention3D(nn.Module):
84 | """ Window based multi-head self attention (W-MSA) module with relative position bias.
85 | It supports both shifted and non-shifted windows.
86 | Args:
87 | dim (int): Number of input channels.
88 | window_size (tuple[int]): The temporal length, height and width of the window.
89 | num_heads (int): Number of attention heads.
90 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
91 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
92 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
93 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
94 | """
95 |
96 | def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., cross=False):
97 |
98 | super().__init__()
99 | self.dim = dim
100 | self.window_size = window_size # Wd, Wh, Ww
101 | self.num_heads = num_heads
102 | head_dim = dim // num_heads
103 | self.scale = qk_scale or head_dim ** -0.5
104 | self.cross = cross
105 |
106 | # define a parameter table of relative position bias
107 | self.relative_position_bias_table = nn.Parameter(
108 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
109 |
110 | # get pair-wise relative position index for each token inside the window
111 | coords_d = torch.arange(self.window_size[0])
112 | coords_h = torch.arange(self.window_size[1])
113 | coords_w = torch.arange(self.window_size[2])
114 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
115 | coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
116 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
117 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
118 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
119 | relative_coords[:, :, 1] += self.window_size[1] - 1
120 | relative_coords[:, :, 2] += self.window_size[2] - 1
121 |
122 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
123 | relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
124 | relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
125 | self.register_buffer("relative_position_index", relative_position_index)
126 |
127 | if self.cross:
128 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
129 | self.k = nn.Linear(dim, dim, bias=qkv_bias)
130 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
131 | else:
132 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
133 |
134 | self.attn_drop = nn.Dropout(attn_drop)
135 | self.proj = nn.Linear(dim, dim)
136 | self.proj_drop = nn.Dropout(proj_drop)
137 |
138 | trunc_normal_(self.relative_position_bias_table, std=.02)
139 | self.softmax = nn.Softmax(dim=-1)
140 |
141 | def forward(self, x, mask=None, condition=None):
142 | """ Forward function.
143 | Args:
144 | x: input features with shape of (num_windows*B, N, C)
145 | mask: (0/-inf) mask with shape of (num_windows, N, N) or None
146 | """
147 | B_, N, C = x.shape
148 | if self.cross:
149 | q = self.q(x).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2)
150 | k = self.k(condition).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2)
151 | v = self.v(condition).reshape(B_, N, self.num_heads, C // self.num_heads).transpose(1,2)
152 | else:
153 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
154 | q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C
155 |
156 | q = q * self.scale
157 | attn = q @ k.transpose(-2, -1)
158 |
159 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
160 | N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH
161 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww
162 | attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
163 |
164 | if mask is not None:
165 | nW = mask.shape[0]
166 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
167 | attn = attn.view(-1, self.num_heads, N, N)
168 | attn = self.softmax(attn)
169 | else:
170 | attn = self.softmax(attn)
171 |
172 | attn = self.attn_drop(attn)
173 |
174 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
175 | x = self.proj(x)
176 | x = self.proj_drop(x)
177 | return x
178 |
179 |
180 | class SwinTransformerBlock3D(nn.Module):
181 | """ Swin Transformer Block.
182 |
183 | Args:
184 | dim (int): Number of input channels.
185 | num_heads (int): Number of attention heads.
186 | window_size (tuple[int]): Window size.
187 | shift_size (tuple[int]): Shift size for SW-MSA.
188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
189 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
190 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
191 | drop (float, optional): Dropout rate. Default: 0.0
192 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
193 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
194 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
195 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
196 | """
197 |
198 | def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
199 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
200 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, cross=False, use_checkpoint=False):
201 | super().__init__()
202 | self.dim = dim
203 | self.num_heads = num_heads
204 | self.window_size = window_size
205 | self.shift_size = shift_size
206 | self.mlp_ratio = mlp_ratio
207 | self.use_checkpoint=use_checkpoint
208 | self.cross = cross
209 |
210 | assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
211 | assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
212 | assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"
213 |
214 | self.norm1 = norm_layer(dim)
215 | self.attn = WindowAttention3D(
216 | dim, window_size=self.window_size, num_heads=num_heads,
217 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, cross=cross)
218 |
219 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
220 | self.norm2 = norm_layer(dim)
221 | mlp_hidden_dim = int(dim * mlp_ratio)
222 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
223 |
224 | if self.cross:
225 | self.norm_condition = norm_layer(dim)
226 |
227 | def forward_part1(self, x, mask_matrix, condition):
228 | B, D, H, W, C = x.shape
229 | window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
230 |
231 | x = self.norm1(x)
232 | # pad feature maps to multiples of window size
233 | pad_l = pad_t = pad_d0 = 0
234 | pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
235 | pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
236 | pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
237 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
238 | _, Dp, Hp, Wp, _ = x.shape
239 | # cyclic shift
240 | if any(i > 0 for i in shift_size):
241 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
242 | attn_mask = mask_matrix
243 | else:
244 | shifted_x = x
245 | attn_mask = None
246 | # partition windows
247 | x_windows = swin_window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C
248 | condition_windows = condition
249 |
250 | if self.cross:
251 | condition = self.norm_condition(condition)
252 | condition = F.pad(condition, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
253 | if any(i > 0 for i in shift_size):
254 | shifted_condition = torch.roll(condition, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
255 | else:
256 | shifted_condition = condition
257 | condition_windows = swin_window_partition(shifted_condition, window_size)
258 |
259 | # W-MSA/SW-MSA
260 | attn_windows = self.attn(x_windows, mask=attn_mask, condition=condition_windows) # B*nW, Wd*Wh*Ww, C
261 | # merge windows
262 | attn_windows = attn_windows.view(-1, *(window_size+(C,)))
263 | shifted_x = swin_window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C
264 | # reverse cyclic shift
265 | if any(i > 0 for i in shift_size):
266 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
267 | else:
268 | x = shifted_x
269 |
270 | if pad_d1 >0 or pad_r > 0 or pad_b > 0:
271 | x = x[:, :D, :H, :W, :].contiguous()
272 | return x
273 |
274 | def forward_part2(self, x):
275 | return self.drop_path(self.mlp(self.norm2(x)))
276 |
277 | def forward(self, x, mask_matrix, condition=None):
278 | """ Forward function.
279 |
280 | Args:
281 | x: Input feature, tensor size (B, D, H, W, C).
282 | mask_matrix: Attention mask for cyclic shift.
283 | """
284 |
285 | shortcut = x
286 | if self.use_checkpoint:
287 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, condition)
288 | else:
289 | x = self.forward_part1(x, mask_matrix, condition)
290 | x = shortcut + self.drop_path(x)
291 |
292 | if self.use_checkpoint:
293 | x = x + checkpoint.checkpoint(self.forward_part2, x)
294 | else:
295 | x = x + self.forward_part2(x)
296 |
297 | return x
298 |
299 |
300 | # cache each stage results
301 | @lru_cache()
302 | def compute_mask(D, H, W, window_size, shift_size, device):
303 | img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
304 | cnt = 0
305 | for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):
306 | for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):
307 | for w in slice(-window_size[2]), slice(-window_size[2], 0), slice(0,None):
308 | img_mask[:, d, h, w, :] = cnt
309 | cnt += 1
310 | mask_windows = swin_window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1
311 | mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]
312 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
313 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
314 | return attn_mask
315 |
316 |
317 | class GatedCrossAttention(nn.Module):
318 | """ A basic Swin Transformer layer for one stage.
319 |
320 | Args:
321 | dim (int): Number of feature channels
322 | depth (int): Depths of this stage.
323 | num_heads (int): Number of attention head.
324 | window_size (tuple[int]): Local window size. Default: (1,7,7).
325 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
326 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
327 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
328 | drop (float, optional): Dropout rate. Default: 0.0
329 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
330 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
331 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
332 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
333 | """
334 |
335 | def __init__(self, dim, num_heads, window_size=(1,7,7), mlp_ratio=4., qkv_bias=True, qk_scale=None,
336 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
337 | super().__init__()
338 | self.window_size = window_size
339 | self.shift_size = tuple(i // 2 for i in window_size)
340 | self.use_checkpoint = use_checkpoint
341 |
342 | # build blocks
343 | self.back1 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=(0,0,0),
344 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
345 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint)
346 | self.back2 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=self.shift_size,
347 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
348 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint)
349 | self.obser1 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=(0,0,0),
350 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
351 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint)
352 | self.obser2 = SwinTransformerBlock3D(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=self.shift_size,
353 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
354 | drop_path=drop_path, norm_layer=norm_layer, cross=True, use_checkpoint=use_checkpoint)
355 | self.gate = nn.Conv3d(dim, dim, 1)
356 |
357 | def forward(self, x1, x2, gate):
358 | """ Forward function.
359 |
360 | Args:
361 | x: Input feature, tensor size (B, C, D, H, W).
362 | """
363 | # calculate attention mask for SW-MSA
364 | B, C, D, H, W = x1.shape
365 | window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
366 | x1 = rearrange(x1, 'b c d h w -> b d h w c')
367 | x2 = rearrange(x2, 'b c d h w -> b d h w c')
368 | gate = rearrange(gate, 'b c d h w -> b d h w c')
369 | Dp = int(np.ceil(D / window_size[0])) * window_size[0]
370 | Hp = int(np.ceil(H / window_size[1])) * window_size[1]
371 | Wp = int(np.ceil(W / window_size[2])) * window_size[2]
372 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x1.device)
373 |
374 | x1_mid = x1 * (1 - gate) + self.back1(x1 * gate, attn_mask, x2)
375 | x2_mid = self.obser1(x2, attn_mask, x1)
376 | x1 = x1_mid * (1 - gate) + self.back2(x1_mid * gate, attn_mask, x2_mid)
377 | x2 = self.obser2(x2_mid, attn_mask, x1_mid)
378 |
379 | x1 = rearrange(x1, 'b d h w c -> b c d h w')
380 | x2 = rearrange(x2, 'b d h w c -> b c d h w')
381 | gate = rearrange(gate, 'b d h w c -> b c d h w')
382 | gate = torch.sigmoid(self.gate(gate))
383 | return [x1, x2, gate]
384 |
385 |
386 | class SwinTransformerLayer(nn.Module):
387 | """ A basic Swin Transformer layer for one stage.
388 |
389 | Args:
390 | dim (int): Number of feature channels
391 | depth (int): Depths of this stage.
392 | num_heads (int): Number of attention head.
393 | window_size (tuple[int]): Local window size. Default: (1,7,7).
394 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
395 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
396 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
397 | drop (float, optional): Dropout rate. Default: 0.0
398 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
399 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
400 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
401 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
402 | """
403 |
404 | def __init__(self,
405 | dim,
406 | depth,
407 | num_heads,
408 | window_size=(1,7,7),
409 | mlp_ratio=4.,
410 | qkv_bias=True,
411 | qk_scale=None,
412 | drop=0.,
413 | attn_drop=0.,
414 | drop_path=0.,
415 | norm_layer=nn.LayerNorm,
416 | use_checkpoint=False):
417 | super().__init__()
418 | self.window_size = window_size
419 | self.shift_size = tuple(i // 2 for i in window_size)
420 | self.depth = depth
421 | self.use_checkpoint = use_checkpoint
422 |
423 | # build blocks
424 | self.blocks = nn.ModuleList([
425 | SwinTransformerBlock3D(
426 | dim=dim,
427 | num_heads=num_heads,
428 | window_size=window_size,
429 | shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,
430 | mlp_ratio=mlp_ratio,
431 | qkv_bias=qkv_bias,
432 | qk_scale=qk_scale,
433 | drop=drop,
434 | attn_drop=attn_drop,
435 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
436 | norm_layer=norm_layer,
437 | use_checkpoint=use_checkpoint,
438 | )
439 | for i in range(depth)])
440 |
441 | def forward(self, x):
442 | """ Forward function.
443 |
444 | Args:
445 | x: Input feature, tensor size (B, C, D, H, W).
446 | """
447 | # calculate attention mask for SW-MSA
448 | B, C, D, H, W = x.shape
449 | window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
450 | x = rearrange(x, 'b c d h w -> b d h w c')
451 | Dp = int(np.ceil(D / window_size[0])) * window_size[0]
452 | Hp = int(np.ceil(H / window_size[1])) * window_size[1]
453 | Wp = int(np.ceil(W / window_size[2])) * window_size[2]
454 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
455 | for blk in self.blocks:
456 | x = blk(x, attn_mask)
457 | x = x.view(B, D, H, W, -1)
458 | x = rearrange(x, 'b d h w c -> b c d h w')
459 | return x
460 |
461 |
462 | class PatchMerging(nn.Module):
463 | """ Patch Merging Layer
464 |
465 | Args:
466 | dim (int): Number of input channels.
467 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468 | """
469 | def __init__(self, dim, norm_layer=nn.LayerNorm):
470 | super().__init__()
471 | self.dim = dim
472 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
473 | self.norm = norm_layer(4 * dim)
474 |
475 | def forward(self, x):
476 | """ Forward function.
477 |
478 | Args:
479 | x: Input feature, tensor size (B, D, H, W, C).
480 | """
481 | B, D, H, W, C = x.shape
482 |
483 | # padding
484 | pad_input = (H % 2 == 1) or (W % 2 == 1)
485 | if pad_input:
486 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
487 |
488 | x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C
489 | x1 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C
490 | x2 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C
491 | x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C
492 | x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C
493 |
494 | x = self.norm(x)
495 | x = self.reduction(x)
496 |
497 | return x
498 |
499 |
500 | class AllPatchMerging(nn.Module):
501 | """ Patch Merging Layer
502 |
503 | Args:
504 | dim (int): Number of input channels.
505 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
506 | """
507 | def __init__(self, dim, norm_layer=nn.LayerNorm):
508 | super().__init__()
509 | self.dim = dim
510 | self.merge1 = PatchMerging(dim, norm_layer)
511 | self.merge2 = PatchMerging(dim, norm_layer)
512 | self.merge3 = PatchMerging(dim, norm_layer)
513 |
514 | def forward(self, x):
515 | """ Forward function.
516 |
517 | Args:
518 | x: Input feature, tensor size (B, C, D, H, W).
519 | """
520 | x1 = rearrange(x[0], 'b c d h w -> b d h w c')
521 | x2 = rearrange(x[1], 'b c d h w -> b d h w c')
522 | gate = rearrange(x[2], 'b c d h w -> b d h w c')
523 | x1 = self.merge1(x1)
524 | x2 = self.merge2(x2)
525 | gate = torch.sigmoid(self.merge3(gate))
526 | x1 = rearrange(x1, 'b d h w c -> b c d h w')
527 | x2 = rearrange(x2, 'b d h w c -> b c d h w')
528 | gate = rearrange(gate, 'b d h w c -> b c d h w')
529 | return [x1, x2, gate]
530 |
531 |
532 | class PatchExpand(nn.Module):
533 | """ Patch Merging Layer
534 |
535 | Args:
536 | dim (int): Number of input channels.
537 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
538 | """
539 | def __init__(self, dim, norm_layer=nn.LayerNorm):
540 | super().__init__()
541 | self.dim = dim
542 | self.expansion = nn.Linear(dim, 2 * dim, bias=False)
543 | self.norm = norm_layer(dim)
544 |
545 | def forward(self, x):
546 | """ Forward function.
547 |
548 | Args:
549 | x: Input feature, tensor size (B, D, H, W, C).
550 | """
551 | B, D, H, W, C = x.shape
552 |
553 | # padding
554 | pad_input = (H % 2 == 1) or (W % 2 == 1)
555 | if pad_input:
556 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
557 |
558 | x = self.norm(x)
559 | x = self.expansion(x)
560 |
561 | x = x.reshape(B, D, H, W, 2, 2, C//2)
562 | x = rearrange(x, 'b d h w h1 w1 c -> b d h h1 w w1 c')
563 | x = x.reshape(B, D, 2*H, 2*W, C//2)
564 |
565 | return x
566 |
567 |
568 | class AllPatchExpand(nn.Module):
569 | """ Patch Merging Layer
570 |
571 | Args:
572 | dim (int): Number of input channels.
573 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
574 | """
575 | def __init__(self, dim, norm_layer=nn.LayerNorm):
576 | super().__init__()
577 | self.dim = dim
578 | self.expand1 = PatchExpand(dim, norm_layer)
579 | self.expand2 = PatchExpand(dim, norm_layer)
580 | self.expand3 = PatchExpand(dim, norm_layer)
581 |
582 | def forward(self, x):
583 | """ Forward function.
584 |
585 | Args:
586 | x: Input feature, tensor size (B, C, D, H, W).
587 | """
588 | x1 = rearrange(x[0], 'b c d h w -> b d h w c')
589 | x2 = rearrange(x[1], 'b c d h w -> b d h w c')
590 | gate = rearrange(x[2], 'b c d h w -> b d h w c')
591 | x1 = self.expand1(x1)
592 | x2 = self.expand2(x2)
593 | gate = torch.sigmoid(self.expand3(gate))
594 | x1 = rearrange(x1, 'b d h w c -> b c d h w')
595 | x2 = rearrange(x2, 'b d h w c -> b c d h w')
596 | gate = rearrange(gate, 'b d h w c -> b c d h w')
597 | return [x1, x2, gate]
598 |
599 |
600 | class PatchEmbed(nn.Module):
601 |
602 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4), norm_layer=None):
603 | super().__init__()
604 |
605 | if img_size[1] % 2:
606 | self.proj3d = nn.Conv3d(5, embed_dim, kernel_size=(patch_size[0], patch_size[1]+1, patch_size[2]), stride=patch_size)
607 | self.proj2d = nn.Conv2d(4, embed_dim, kernel_size=(patch_size[1]+1, patch_size[2]), stride=(patch_size[1], patch_size[2]))
608 | else:
609 | self.proj3d = nn.Conv3d(5, embed_dim, kernel_size=patch_size, stride=patch_size)
610 | self.proj2d = nn.Conv2d(4, embed_dim, kernel_size=patch_size[1:], stride=patch_size[1:])
611 |
612 | self.embed_dim = embed_dim
613 | if norm_layer is not None:
614 | self.norm = norm_layer(embed_dim)
615 | else:
616 | self.norm = None
617 |
618 | def forward(self, x):
619 | """Forward function."""
620 |
621 | B, C, H, W = x.shape
622 | x2d = x[:,:4,:,:] # b,4,721,1440
623 | x3d = x[:,4:,:,:].reshape(B, 5, C//5, H, W) # b,5,13,721,1440
624 | x2d = self.proj2d(x2d).unsqueeze(2) # b,c,1,180,360
625 | x3d = self.proj3d(x3d) # b,c,13,180,360
626 | x = torch.cat([x3d, x2d], dim=2) # b,c,14,180,360
627 |
628 | if self.norm is not None:
629 | D, H, W = x.size(2), x.size(3), x.size(4)
630 | x = x.flatten(2).transpose(1, 2)
631 | x = self.norm(x)
632 | x = x.transpose(1, 2).view(B, self.embed_dim, D, H, W)
633 |
634 | return x
635 |
636 |
637 | class AllPatchEmbed(nn.Module):
638 |
639 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4), norm_layer=None):
640 | super().__init__()
641 |
642 | self.patch1 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer)
643 | self.patch2 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer)
644 | self.patch3 = PatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer)
645 |
646 | self.patch_resolution = (img_size[0]//5+1, img_size[1]//patch_size[1], img_size[2]//patch_size[2])
647 |
648 | def forward(self, x1, x2, mask):
649 | """Forward function."""
650 |
651 | x1 = self.patch1(x1)
652 | x2 = self.patch2(x2)
653 | mask = torch.sigmoid(self.patch3(mask))
654 |
655 | return [x1, x2, mask]
656 |
657 |
658 | class PatchRecover(nn.Module):
659 |
660 | def __init__(self, img_size=(69,721,1440), embed_dim=96, patch_size=(1,4,4)):
661 | super().__init__()
662 |
663 | if img_size[1] % 2:
664 | self.proj3d = nn.ConvTranspose3d(embed_dim, 5, kernel_size=(patch_size[0], patch_size[1]+1, patch_size[2]), stride=patch_size)
665 | self.proj2d = nn.ConvTranspose2d(embed_dim, 4, kernel_size=(patch_size[1]+1, patch_size[2]), stride=(patch_size[1], patch_size[2]))
666 | else:
667 | self.proj3d = nn.ConvTranspose3d(embed_dim, 5, kernel_size=patch_size, stride=patch_size)
668 | self.proj2d = nn.ConvTranspose2d(embed_dim, 4, kernel_size=patch_size[1:], stride=patch_size[1:])
669 |
670 | def forward(self, x):
671 | """Forward function."""
672 |
673 | x2d = x[:,:,-1:,:,:].squeeze(2) # b,c,180,360
674 | x3d = x[:,:,:-1,:,:] # b,c,13,180,360
675 | x2d = self.proj2d(x2d) # b,4,721,1440
676 | x3d = self.proj3d(x3d).flatten(1,2) # b,65,721,1440
677 | x = torch.cat([x2d, x3d], dim=1)
678 |
679 | return x
680 |
681 |
682 | class BasicLayer(nn.Module):
683 |
684 | def __init__(self, dim, kernel, padding, num_heads, window_size, sample=None, use_checkpoint=False):
685 | super().__init__()
686 | self.sample = sample
687 | inchans = dim * 2 if self.sample == 'up' else dim
688 | self.conv = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding)
689 | self.gateconv = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding)
690 | self.gate = nn.Conv3d(inchans, dim, kernel_size=kernel, padding=padding)
691 | self.crosslayer = GatedCrossAttention(dim=dim, num_heads=num_heads, window_size=window_size, qkv_bias=True, use_checkpoint=use_checkpoint)
692 | if self.sample == 'down':
693 | self.samplelayer = AllPatchMerging(dim//2)
694 | elif self.sample == 'up':
695 | self.samplelayer = AllPatchExpand(dim*2)
696 |
697 | def concate(self, x, prev):
698 |
699 | x1 = torch.cat([x[0], prev[0]], dim=1)
700 | x2 = torch.cat([x[1], prev[1]], dim=1)
701 | gate = torch.cat([x[2], prev[2]], dim=1)
702 |
703 | return x1, x2, gate
704 |
705 | def forward(self, x, prev=None):
706 | if self.sample is not None:
707 | x = self.samplelayer(x)
708 | x1, x2, gate = x[0], x[1], x[2]
709 | if prev is not None:
710 | x1, x2, gate = self.concate(x, prev)
711 |
712 | x1 = F.silu(self.conv(x1))
713 | gate = torch.sigmoid(self.gate(gate))
714 | x2 = F.silu(self.gateconv(x2))
715 | x2 = x2 * gate
716 | out = self.crosslayer(x1, x2, gate)
717 | out = [out[0]+x[0], out[1]+x[1], out[2]+x[2]]
718 | return out
719 |
--------------------------------------------------------------------------------
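As a sanity check on the window utilities above, partitioning a feature map into 3D windows and then reversing the operation recovers the input exactly (a sketch, not repository code; it assumes `modules.transformer` is importable):

```
import torch
from modules.transformer import swin_window_partition, swin_window_reverse

B, D, H, W, C = 1, 2, 8, 8, 4
window_size = (2, 4, 4)   # (Wd, Wh, Ww); each dimension divides evenly here
x = torch.randn(B, D, H, W, C)

windows = swin_window_partition(x, window_size)   # (B*num_windows, Wd*Wh*Ww, C) = (4, 32, 4)
x_rec = swin_window_reverse(windows, window_size, B, D, H, W)
assert torch.equal(x_rec, x)
```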
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from utils.builder import ConfigBuilder
5 | import utils.misc as utils
6 | import yaml
7 | from utils.logger import get_logger
8 |
9 |
10 |
11 | def subprocess_fn(args):
12 | utils.setup_seed(args.seed * args.world_size + args.rank)
13 |
14 | logger = get_logger("train", args.rundir, utils.get_rank(), filename='iter.log')
15 | args.cfg_params["logger"] = logger
16 |
17 | # build config
18 | logger.info('Building config ...')
19 | builder = ConfigBuilder(**args.cfg_params)
20 |
21 | # build model
22 | logger.info('Building models ...')
23 | model = builder.get_model()
24 | model.kernel = utils.DistributedParallel_Model(model.kernel, args.local_rank)
25 |
26 | # build forecast model
27 | logger.info('Building forecast models ...')
28 | args.forecast_model = builder.get_forecast(args.local_rank)
29 |
30 | # build dataset
31 | logger.info('Building dataloaders ...')
32 | dataset_params = args.cfg_params['dataset']
33 | train_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='train', batch_size=args.batch_size)
34 | valid_dataloader = builder.get_dataloader(dataset_params=dataset_params, split='valid', batch_size=args.batch_size)
35 | # logger.info(f'dataloader length {len(train_dataloader), len(valid_dataloader)}')
36 |
37 | # train
38 | logger.info('begin training ...')
39 | model.train(train_dataloader, valid_dataloader, logger, args)
40 | logger.info('training end ...')
41 |
42 |
43 | def main(args):
44 | if args.world_size > 1:
45 | utils.init_distributed_mode(args)
46 | else:
47 | args.rank = 0
48 | args.local_rank = 0
49 | args.distributed = False
50 | args.gpu = 0
51 | torch.cuda.set_device(args.gpu)
52 |
53 | args.cfg = os.path.join(args.rundir, 'training_options.yaml')
54 | with open(args.cfg, 'r') as cfg_file:
55 | cfg_params = yaml.load(cfg_file, Loader = yaml.FullLoader)
56 |
57 | cfg_params['dataloader']['num_workers'] = args.per_cpus
58 | cfg_params['dataset']['train']['length'] = args.lead_time // 6 + 2
59 | cfg_params['dataset']['valid']['length'] = args.lead_time // 6 + 2
60 | args.cfg_params = cfg_params
61 |
62 | args.rundir = os.path.join(args.rundir, f'mask{args.ratio}_lead{args.lead_time}h_res{args.resolution}')
63 | os.makedirs(args.rundir, exist_ok=True)
64 |
65 | if args.rank == 0:
66 | with open(os.path.join(args.rundir, 'train.yaml'), 'wt') as f:
67 | yaml.dump(vars(args), f, indent=2, sort_keys=False)
68 | # yaml.dump(cfg_params, f, indent=2, sort_keys=False)
69 |
70 | subprocess_fn(args)
71 |
72 |
73 | if __name__ == "__main__":
74 |
75 | parser = argparse.ArgumentParser()
76 |
77 | parser.add_argument('--seed', type = int, default = 0, help = 'seed')
78 | parser.add_argument('--cuda', type = int, default = 0, help = 'cuda id')
79 | parser.add_argument('--world_size', type = int, default = 4, help = 'number of processes')
80 | parser.add_argument('--per_cpus', type = int, default = 4, help = 'number of dataloader workers per process')
81 | parser.add_argument('--max_epoch', type = int, default = 20, help = "maximum training epochs")
82 | parser.add_argument('--batch_size', type = int, default = 1, help = "batch size")
83 | parser.add_argument('--lead_time', type = int, default = 24, help = "lead time (h) for background")
84 | parser.add_argument('--ratio', type = float, default = 0.9, help = "mask ratio")
85 | parser.add_argument('--resolution', type = int, default = 128, help = "observation resolution")
86 | parser.add_argument('--init_method', type = str, default = 'tcp://127.0.0.1:19111', help = 'multi process init method')
87 | parser.add_argument('--rundir', type = str, default = './configs/FNP', help = 'where to save the results')
88 |
89 | args = parser.parse_args()
90 |
91 | main(args)
92 |
93 |
--------------------------------------------------------------------------------
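Note that `main` derives the dataset sequence length from the background lead time as `lead_time // 6 + 2`; with the default `--lead_time=24`, for example, this sets `length` to 24 // 6 + 2 = 6 for both the train and valid splits.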
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenEarthLab/FNP/624e624be481cfa6a149613bd8a08f5df318cb10/utils/__init__.py
--------------------------------------------------------------------------------
/utils/builder.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.distributed import DistributedSampler
2 | from utils.misc import get_rank, get_world_size, is_dist_avail_and_initialized
3 | import onnxruntime as ort
4 |
5 |
6 | class ConfigBuilder(object):
7 | """
8 | Configuration Builder.
9 |
10 | """
11 | def __init__(self, **params):
12 | """
13 | Set the default configuration for the configuration builder.
14 |
15 | Parameters
16 | ----------
17 |
18 | params: the configuration parameters.
19 | """
20 | super(ConfigBuilder, self).__init__()
21 | self.model_params = params.get('model', {})
22 | self.dataset_params = params.get('dataset', {'data_dir': 'data'})
23 | self.dataloader_params = params.get('dataloader', {})
24 |
25 | self.logger = params.get('logger', None)
26 |
27 | def get_model(self, model_params = None):
28 | """
29 | Get the model from configuration.
30 |
31 | Parameters
32 | ----------
33 |
34 | model_params: dict, optional, default: None. If model_params is provided, then use the parameters specified in the model_params to build the model. Otherwise, the model parameters in the self.params will be used to build the model.
35 | """
36 | from models.FNP import FNP
37 | from models.ConvCNP import ConvCNP
38 | from models.Adas import Adas
39 |
40 | if model_params is None:
41 | model_params = self.model_params
42 | type = model_params.get('type', 'FNP')
43 |
44 | if type == 'FNP':
45 | model = FNP(**model_params)
46 | elif type == 'ConvCNP':
47 | model = ConvCNP(**model_params)
48 | elif type == 'Adas':
49 | model = Adas(**model_params)
50 | else:
51 | raise NotImplementedError('Invalid model type.')
52 |
53 | return model
54 |
55 | def get_forecast(self, local_rank):
56 |
57 | # Set the behavior of onnxruntime
58 | options = ort.SessionOptions()
59 | options.enable_cpu_mem_arena=False
60 | options.enable_mem_pattern = False
61 | options.enable_mem_reuse = False
62 | # Increase the number for faster inference and more memory consumption
63 | options.intra_op_num_threads = 1
64 |
65 | # Set the behavior of the CUDA provider
66 | cuda_provider_options = {'device_id': local_rank, 'arena_extend_strategy':'kSameAsRequested',}
67 |
68 | # Initialize the onnxruntime session for the FengWu forecast model
69 | ort_session = ort.InferenceSession('./models/FengWu.onnx', sess_options=options, providers=[('CUDAExecutionProvider', cuda_provider_options)])
70 |
71 | return ort_session
72 |
73 | def get_dataset(self, dataset_params = None, split = 'train'):
74 | """
75 | Get the dataset from configuration.
76 |
77 | Parameters
78 | ----------
79 |
80 | dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset;
81 |
82 |         split: str in ['train', 'test'], optional, default: 'train', which split of the dataset to load.
83 |
84 | Returns
85 | -------
86 |
87 | A torch.utils.data.Dataset item.
88 | """
89 | from datasets.era5_npy_f32 import era5_npy_f32
90 | if dataset_params is None:
91 | dataset_params = self.dataset_params
92 | dataset_params = dataset_params.get(split, None)
93 | if dataset_params is None:
94 | return None
95 | dataset = era5_npy_f32(split = split, **dataset_params)
96 | return dataset
97 |
98 | def get_sampler(self, dataset, split = 'train', drop_last=False):
99 | if split == 'train':
100 | shuffle = True
101 | else:
102 | shuffle = False
103 |
104 | if is_dist_avail_and_initialized():
105 | rank = get_rank()
106 | num_gpus = get_world_size()
107 | else:
108 | rank = 0
109 | num_gpus = 1
110 | sampler = DistributedSampler(dataset, rank=rank, shuffle=shuffle, num_replicas=num_gpus, drop_last=drop_last)
111 |
112 | return sampler
113 |
114 |
115 | def get_dataloader(self, dataset_params = None, split = 'train', batch_size = 1, dataloader_params = None, drop_last = True):
116 | """
117 | Get the dataloader from configuration.
118 |
119 | Parameters
120 | ----------
121 |
122 | dataset_params: dict, optional, default: None. If dataset_params is provided, then use the parameters specified in the dataset_params to build the dataset. Otherwise, the dataset parameters in the self.params will be used to build the dataset;
123 |
124 |         split: str in ['train', 'test'], optional, default: 'train', which split of the dataset to load;
125 |
126 |         batch_size: int, optional, default: 1, the batch size used by the dataloader;
127 |
128 | dataloader_params: dict, optional, default: None. If dataloader_params is provided, then use the parameters specified in the dataloader_params to get the dataloader. Otherwise, the dataloader parameters in the self.params will be used to get the dataloader.
129 |
130 | Returns
131 | -------
132 |
133 | A torch.utils.data.DataLoader item.
134 | """
135 | from torch.utils.data import DataLoader
136 |
137 | # if split != "train":
138 | # drop_last = True
139 | if dataloader_params is None:
140 | dataloader_params = self.dataloader_params
141 | dataset = self.get_dataset(dataset_params, split)
142 | if dataset is None:
143 | return None
144 | sampler = self.get_sampler(dataset, split, drop_last=drop_last)
145 |
146 | return DataLoader(
147 | dataset,
148 | batch_size = batch_size,
149 | sampler=sampler,
150 | drop_last=drop_last,
151 | **dataloader_params
152 | )
153 |
154 |
155 | def get_optimizer(model, optimizer_params = None, resume = False, resume_lr = None):
156 | """
157 | Get the optimizer from configuration.
158 |
159 | Parameters
160 | ----------
161 |
162 | model: a torch.nn.Module object, the model.
163 |
164 |     optimizer_params: dict, the optimizer configuration; expects the keys 'type' (one of 'SGD', 'Adam', 'AdamW') and 'params' (keyword arguments passed to the optimizer);
165 |
166 | resume: bool, optional, default: False, whether to resume training from an existing checkpoint;
167 |
168 | resume_lr: float, optional, default: None, the resume learning rate.
169 |
170 | Returns
171 | -------
172 |
173 | An optimizer for the given model.
174 | """
175 | from torch.optim import SGD, Adam, AdamW
176 | type = optimizer_params.get('type', 'AdamW')
177 | params = optimizer_params.get('params', {})
178 |
179 | if resume:
180 | network_params = [{'params': model.parameters(), 'initial_lr': resume_lr}]
181 | params.update(lr = resume_lr)
182 | else:
183 | network_params = model.parameters()
184 | if type == 'SGD':
185 | optimizer = SGD(network_params, **params)
186 | elif type == 'Adam':
187 | optimizer = Adam(network_params, **params)
188 | elif type == 'AdamW':
189 | optimizer = AdamW(network_params, **params)
190 | else:
191 | raise NotImplementedError('Invalid optimizer type.')
192 | return optimizer
193 |
194 | def get_lr_scheduler(optimizer, lr_scheduler_params = None, resume = False, resume_epoch = None, total_steps = None):
195 | """
196 | Get the learning rate scheduler from configuration.
197 |
198 | Parameters
199 | ----------
200 |
201 | optimizer: an optimizer;
202 |
203 |     lr_scheduler_params: dict, the learning rate scheduler configuration; expects the keys 'type' and 'params'. An empty 'type' means no scheduler is used;
204 |
205 | resume: bool, optional, default: False, whether to resume training from an existing checkpoint;
206 |
207 | resume_epoch: int, optional, default: None, the epoch of the checkpoint.
208 |
209 | Returns
210 | -------
211 |
212 | A learning rate scheduler for the given optimizer.
213 | """
214 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CyclicLR, CosineAnnealingLR, StepLR, OneCycleLR
215 | type = lr_scheduler_params.get('type', '')
216 | params = lr_scheduler_params.get('params', {})
217 | if resume:
218 | params.update(last_epoch = resume_epoch)
219 | if type == 'MultiStepLR':
220 | scheduler = MultiStepLR(optimizer, **params)
221 | elif type == 'ExponentialLR':
222 | scheduler = ExponentialLR(optimizer, **params)
223 | elif type == 'CyclicLR':
224 | scheduler = CyclicLR(optimizer, **params)
225 | elif type == 'CosineAnnealingLR':
226 | scheduler = CosineAnnealingLR(optimizer, **params)
227 | elif type == 'StepLR':
228 | scheduler = StepLR(optimizer, **params)
229 | elif type == 'OneCycleLR':
230 | scheduler = OneCycleLR(optimizer, total_steps=total_steps, **params)
231 | elif type == '':
232 | scheduler = None
233 | else:
234 | raise NotImplementedError('Invalid learning rate scheduler type.')
235 | return scheduler
236 |
--------------------------------------------------------------------------------
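A minimal sketch of how `ConfigBuilder` and the module-level helpers in `utils/builder.py` fit together. The dictionaries below are illustrative only; the real keyword arguments for the model, dataset and dataloader come from the YAML files under `configs/` and must match what `models/FNP.py` and `datasets/era5_npy_f32.py` actually expect:

```
import torch
from utils.builder import ConfigBuilder, get_optimizer, get_lr_scheduler

params = {
    'dataset': {'train': {'data_dir': 'data'}},          # hypothetical dataset section
    'dataloader': {'num_workers': 2, 'pin_memory': True},
}
builder = ConfigBuilder(**params)
train_loader = builder.get_dataloader(split='train', batch_size=1)

# builder.get_model() would normally supply the assimilation model (FNP by default);
# a tiny stand-in module is used here so the optimizer/scheduler helpers can be shown in isolation.
model = torch.nn.Linear(4, 4)
optimizer = get_optimizer(model, {'type': 'AdamW', 'params': {'lr': 1e-4}})
scheduler = get_lr_scheduler(optimizer, {'type': 'CosineAnnealingLR', 'params': {'T_max': 100}})
```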
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 |
5 | logger_initialized = {}
6 |
7 | def get_logger(name, save_dir, distributed_rank, filename="log.log", resume=False):
8 | logger = logging.getLogger(name)
9 | if name in logger_initialized:
10 | return logger
11 |
12 | logger.propagate = False
13 | # don't log results for the non-master process
14 | if distributed_rank > 0:
15 |         # suppress everything below ERROR on non-master ranks
16 |         logger.setLevel(logging.ERROR)
17 | return logger
18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
19 |
20 | ch = logging.StreamHandler()
21 | ch.setLevel(logging.INFO)
22 | ch.setFormatter(formatter)
23 | logger.addHandler(ch)
24 |
25 | if save_dir:
26 | if resume:
27 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='a')
28 | else:
29 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode='w')
30 | fh.setLevel(logging.INFO)
31 | fh.setFormatter(formatter)
32 | logger.addHandler(fh)
33 |
34 | logger.setLevel(logging.INFO)
35 |
36 | logger_initialized[name] = True
37 |
38 | return logger
39 |
--------------------------------------------------------------------------------
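A small usage sketch for `get_logger`; the logger name and experiment directory are placeholders:

```
from utils.logger import get_logger
from utils.misc import get_rank

logger = get_logger('FNP', save_dir='./configs/FNP', distributed_rank=get_rank())
logger.info('training started')   # printed and written to ./configs/FNP/log.log on rank 0;
                                  # non-master ranks only emit ERROR-level messages
```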
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.jit.script
5 | def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor:
6 | return 90. - j * 180./float(num_lat-1)
7 |
8 | @torch.jit.script
9 | def latitude_weighting_factor_torch(j: torch.Tensor, num_lat: int, s: torch.Tensor) -> torch.Tensor:
10 | return num_lat * torch.cos(torch.pi/180. * lat(j, num_lat)) / s
11 |
12 | @torch.jit.script
13 | def weighted_rmse_torch_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
14 |     # takes in arrays of size [n, c, h, w] and returns the latitude-weighted RMSE for each channel
15 | num_lat = pred.shape[2]
16 | #num_long = target.shape[2]
17 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device)
18 |
19 | s = torch.sum(torch.cos(torch.pi/180. * lat(lat_t, num_lat)))
20 | weight = torch.reshape(latitude_weighting_factor_torch(lat_t, num_lat, s), (1, 1, -1, 1))
21 | result = torch.sqrt(torch.mean(weight * (pred - target)**2., dim=(-1,-2)))
22 | return result
23 |
24 | @torch.jit.script
25 | def weighted_rmse_torch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
26 | result = weighted_rmse_torch_channels(pred, target)
27 | return torch.mean(result, dim=0)
28 |
29 | def WRMSE(pred, gt, data_std):
30 | return weighted_rmse_torch(pred, gt) * data_std
31 |
--------------------------------------------------------------------------------
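An illustrative call to `WRMSE` on random tensors; the channel count and grid size below are placeholders, not values prescribed by the codebase:

```
import torch
from utils.metrics import WRMSE

pred = torch.randn(2, 69, 128, 256)    # [batch, channel, lat, lon]
target = torch.randn(2, 69, 128, 256)
data_std = torch.ones(69)              # per-channel standard deviation used to de-normalize the error
rmse = WRMSE(pred, target, data_std)   # latitude-weighted RMSE per channel, shape [69]
```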
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | import torch
3 | import os
4 | import numpy as np
5 | import random
6 | from typing import Any
7 | import re
8 |
9 |
10 | def reduce_dict(input_dict, average=True):
11 | """
12 | Args:
13 | input_dict (dict): all the values will be reduced
14 | average (bool): whether to do average or sum
15 | Reduce the values in the dictionary from all processes so that all processes
16 | have the averaged results. Returns a dict with the same fields as
17 | input_dict, after reduction.
18 | """
19 | world_size = get_world_size()
20 | if world_size < 2:
21 | return input_dict
22 | with torch.no_grad():
23 | names = []
24 | values = []
25 | # sort the keys so that they are consistent across processes
26 | for k in sorted(input_dict.keys()):
27 | names.append(k)
28 | values.append(input_dict[k])
29 | values = torch.stack(values, dim=0)
30 | dist.all_reduce(values)
31 | if average:
32 | values /= world_size
33 | reduced_dict = {k: v for k, v in zip(names, values)}
34 | return reduced_dict
35 |
36 |
37 | def setup_for_distributed(is_master):
38 | """
39 |     This function disables printing when not in the master process
40 | """
41 | import builtins as __builtin__
42 | builtin_print = __builtin__.print
43 |
44 | def print(*args, **kwargs):
45 | force = kwargs.pop('force', False)
46 | if is_master or force:
47 | builtin_print(*args, **kwargs)
48 |
49 | __builtin__.print = print
50 |
51 |
52 | def is_dist_avail_and_initialized():
53 | if not dist.is_available():
54 | return False
55 | if not dist.is_initialized():
56 | return False
57 | return True
58 |
59 |
60 | def get_world_size():
61 | if not is_dist_avail_and_initialized():
62 | return 1
63 | return dist.get_world_size()
64 |
65 |
66 | def get_rank():
67 | if not is_dist_avail_and_initialized():
68 | return 0
69 | return dist.get_rank()
70 |
71 |
72 | def is_main_process():
73 | return get_rank() == 0
74 |
75 |
76 | def save_on_master(*args, **kwargs):
77 | if is_main_process():
78 | torch.save(*args, **kwargs)
79 |
80 |
81 | def get_ip(ip_list):
82 | if "," in ip_list:
83 | ip_list = ip_list.split(',')[0]
84 | if "[" in ip_list:
85 | ipbefore_4, ip4 = ip_list.split('[')
86 | ip4 = re.findall(r"\d+", ip4)[0]
87 | ip1, ip2, ip3 = ipbefore_4.split('-')[-4:-1]
88 | else:
89 | ip1,ip2,ip3,ip4 = ip_list.split('-')[-4:]
90 | ip_addr = "tcp://" + ".".join([ip1, ip2, ip3, ip4]) + ":"
91 | return ip_addr
92 |
93 |
94 |
95 | def init_distributed_mode(args):
96 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
97 | args.rank = int(os.environ["RANK"])
98 | args.world_size = int(os.environ['WORLD_SIZE'])
99 | args.local_rank = int(os.environ['LOCAL_RANK'])
100 | elif 'SLURM_PROCID' in os.environ:
101 | args.rank = int(os.environ['SLURM_PROCID'])
102 | args.local_rank = int(os.environ['SLURM_LOCALID'])
103 | args.world_size = int(os.environ['SLURM_NTASKS'])
104 | ip_addr = get_ip(os.environ['SLURM_STEP_NODELIST'])
105 | port = int(os.environ['SLURM_SRUN_COMM_PORT'])
106 | # args.init_method = ip_addr + str(port)
107 | args.init_method = ip_addr + args.init_method.split(":")[-1]
108 | else:
109 | print('Not using distributed mode')
110 | args.distributed = False
111 | return
112 |
113 | args.distributed = True
114 |
115 | torch.cuda.set_device(args.local_rank)
116 | args.dist_backend = 'nccl'
117 | print('| distributed init (rank {}, local_rank {}): {}'.format(
118 | args.rank, args.local_rank, args.init_method), flush=True)
119 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.init_method,
120 | world_size=args.world_size, rank=args.rank)
121 | torch.distributed.barrier()
122 | setup_for_distributed(args.rank == 0)
123 |
124 |
125 | def DistributedParallel_Model(model, gpu_num, find_unused_parameters=False):
126 | if is_dist_avail_and_initialized():
127 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
128 | if device == torch.device('cpu'):
129 | raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
130 | model.to(device)
131 | # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
132 | # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")
133 | ddp_sub_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_num], find_unused_parameters=find_unused_parameters)
134 | model = ddp_sub_model
135 |
136 | # model.to(device)
137 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_num])
138 | # model_without_ddp = model.module
139 | else:
140 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
141 | model.to(device)
142 | # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
143 | # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")
144 | # for key in model.model:
145 | # model.model[key].to(device)
146 |
147 | return model
148 |
149 |
150 | class Dict(dict):
151 | def __getattr__(self, name: str) -> Any:
152 | try:
153 | return self[name]
154 | except KeyError:
155 | raise AttributeError(name)
156 |
157 | def __setattr__(self, name: str, value: Any) -> None:
158 | self[name] = value
159 |
160 | def __delattr__(self, name: str) -> None:
161 | del self[name]
162 | # __setattr__ = dict.__setitem__
163 | # __getattr__ = dict.__getitem__
164 |
165 | def dictToObj(dictObj):
166 | if not isinstance(dictObj, dict):
167 | return dictObj
168 | d = Dict()
169 | for k, v in dictObj.items():
170 | d[k] = dictToObj(v)
171 | return d
172 |
173 |
174 | def setup_seed(seed):
175 | torch.manual_seed(seed)
176 | torch.cuda.manual_seed_all(seed)
177 | np.random.seed(seed)
178 | random.seed(seed)
179 | # torch.backends.cudnn.deterministic = True
180 |
181 |
182 | def named_params_and_buffers(module):
183 | assert isinstance(module, torch.nn.Module)
184 | return list(module.named_parameters()) + list(module.named_buffers())
185 |
186 | def check_ddp_consistency(module, ignore_regex=None):
187 | assert isinstance(module, torch.nn.Module)
188 | for name, tensor in named_params_and_buffers(module):
189 | fullname = type(module).__name__ + '.' + name
190 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
191 | continue
192 | tensor = tensor.detach()
193 | # if tensor.is_floating_point():
194 | # tensor = nan_to_num(tensor)
195 | other = tensor.clone()
196 | torch.distributed.broadcast(tensor=other, src=0)
197 | # print(fullname, tensor.sum(), other.sum())
198 | assert (tensor == other).all(), fullname
199 |
--------------------------------------------------------------------------------
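A minimal sketch of how the helpers in `utils/misc.py` are typically combined in an entry script; the seed value and flag default are placeholders mirroring the CLI arguments defined earlier:

```
import argparse
from utils.misc import setup_seed, init_distributed_mode, DistributedParallel_Model

parser = argparse.ArgumentParser()
parser.add_argument('--init_method', type=str, default='tcp://127.0.0.1:19111')
args = parser.parse_args()

setup_seed(0)                  # seed torch / numpy / random for reproducibility
init_distributed_mode(args)    # reads RANK / WORLD_SIZE / LOCAL_RANK (or the SLURM equivalents)
# model = DistributedParallel_Model(model, gpu_num=args.local_rank)   # wrap an existing nn.Module in DDP
```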