├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── data
│   └── ILINet.csv
├── environment.yml
├── environment_11.yml
├── environment_distcal.yml
├── environment_distcal_cpu.yml
├── models
│   ├── fnpmodels.py
│   └── utils.py
├── run.py
├── test_ili.py
├── test_regress_ili.py
├── train_ili.py
└── transform_pred.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pdf
3 | *.png
4 | *.pkl
5 | *.pth
6 | envs/*
7 | .vscode/*
8 | env/*
9 | env11/*
10 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "DistCal"]
2 | path = DistCal
3 | url = https://github.com/kage08/DistCal.git
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 AdityaLab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # When in Doubt: Neural Non-Parametric Uncertainty Quantification for Epidemic Forecasting
2 | 
3 | **Paper Link:** https://arxiv.org/abs/2106.03904
4 | 
5 | ## Setup
6 | 
7 | First install Anaconda. The dependencies are listed in the `environment.yml` file. Adjust the `cudatoolkit` version there to match your CUDA installation if needed.
8 | 
9 | Then run the following commands:
10 | 
11 | ```bash
12 | conda env create --prefix ./envs/epifnp --file environment.yml
13 | source activate ./envs/epifnp
14 | ```
15 | 
16 | ## Directory structure
17 | 
18 | ```
19 | - data
20 |   - ILINet.csv -> wILI values for seasons 2003 to 2020 collected from flusight
21 | - model_chkp -> stores intermediate model parameters while training
22 | - models/fnpmodels.py -> implementation of EpiFNP modules
23 | - plots -> plots of predictions
24 | - saves -> saves predictions for models as pkl files
25 | - train_ili.py -> training script for EpiFNP
26 | - test_ili.py -> inference with a trained model
27 | - test_regress_ili.py -> autoregressive inference using a trained model
28 | ```
29 | 
30 | ## Training
31 | 
32 | Run:
33 | 
34 | ```bash
35 | python train_ili.py -y <test_year> -w <weeks_ahead> -a trans -n <model_name> -e <epochs>
36 | ```
37 | 
38 | Or run `run.py` to run all experiments.
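For example, the following call mirrors one of the configurations swept by `run.py` (held-out 2014/15 test season, 2-week-ahead horizon, a checkpoint/plot tag, and 3000 epochs):

```bash
python train_ili.py -y 2014 -w 2 -a trans -n epifnp_2014_2 -e 3000
```

Here `-y` is the test season, `-w` the forecast horizon in weeks, `-a` the attention type (`trans` for transformer-style attention), `-n` the tag used in checkpoint and plot filenames, and `-e` the training epoch budget (`run.py` passes 3000).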
39 | 
40 | Prediction plots will be saved in `plots/Test.png` and the trained model in the `model_chkp` folder.
41 | 
42 | ## Inference
43 | 
44 | Run:
45 | 
46 | ```bash
47 | python test_ili.py -y <test_year> -w <weeks_ahead> -a trans -n <model_name>
48 | ```
49 | 
50 | for normal inference.
51 | 
52 | Run:
53 | 
54 | ```bash
55 | python test_regress_ili.py -y <test_year> -w <weeks_ahead> -a trans -n <model_name>
56 | ```
57 | 
58 | for auto-regressive inference. Note: train and use a 1-week-ahead model for AR inference.
59 | 
60 | The predictions and plots are saved in `saves` and `plots` respectively.
61 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - defaults
4 | dependencies:
5 | - python>=3.6
6 | - cudatoolkit=10.2
7 | - pytorch=1.8
8 | - torchvision
9 | - torchaudio
10 | - numpy
11 | - scipy
12 | - scikit-learn
13 | - pandas
14 | - matplotlib
15 | - networkx
16 | - pip
17 | - tqdm
18 | # For development
19 | - ipykernel
20 | - jupyter
21 | - black
22 | - pylint
23 | - flake8
24 | 
--------------------------------------------------------------------------------
/environment_11.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - defaults
4 | dependencies:
5 | - python>=3.6
6 | - cudatoolkit>=11.3
7 | - pytorch>=1.8
8 | - torchvision
9 | - torchaudio
10 | - numpy
11 | - scipy
12 | - scikit-learn
13 | - pandas
14 | - matplotlib
15 | - networkx
16 | - pip
17 | - tqdm
18 | # For development
19 | - ipykernel
20 | - jupyter
21 | - black
22 | - pylint
23 | - flake8
24 | 
--------------------------------------------------------------------------------
/environment_distcal.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | dependencies:
4 | - python>=3.6
5 | - cudatoolkit
6 | - tensorflow-gpu>=2.0
7 | - numpy
8 | - scipy
9 | - scikit-learn
10 | - pandas
11 | - matplotlib
12 | - pip
13 | - tqdm
14 | - joblib
15 | # For development
16 | - ipykernel
17 | - jupyter
18 | - black
19 | - pylint
20 | - flake8
21 | 
--------------------------------------------------------------------------------
/environment_distcal_cpu.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | dependencies:
4 | - python>=3.6
5 | - tensorflow>=2.0
6 | - numpy
7 | - scipy
8 | - scikit-learn
9 | - pandas
10 | - matplotlib
11 | - pip
12 | - tqdm
13 | - joblib
14 | # For development
15 | - ipykernel
16 | - jupyter
17 | - black
18 | - pylint
19 | - flake8
20 | 
--------------------------------------------------------------------------------
/models/fnpmodels.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | import torch.nn.functional as F
5 | from models.utils import (
6 |     Normal,
7 |     float_tensor,
8 |     logitexp,
9 |     sample_DAG,
10 |     sample_Clique,
11 |     sample_bipartite,
12 |     Flatten,
13 |     one_hot,
14 | )
15 | from torch.distributions import Categorical
16 | 
17 | 
18 | class TransformerAttn(nn.Module):
19 |     """
20 |     Module that calculates self-attention weights using transformer-like attention
21 |     """
22 | 
23 |     def __init__(self, dim_in=40, value_dim=40, key_dim=40) -> None:
24 |         """
25 |         param dim_in: Dimensionality of input sequence
26 |         param value_dim: Dimension of value transform
27 |         param key_dim: Dimension of key transform
28 |         """
29 |         super(TransformerAttn, 
self).__init__() 30 | self.value_layer = nn.Linear(dim_in, value_dim) 31 | self.query_layer = nn.Linear(dim_in, value_dim) 32 | self.key_layer = nn.Linear(dim_in, key_dim) 33 | 34 | def forward(self, seq): 35 | """ 36 | param seq: Sequence in dimension [Seq len, Batch, Hidden size] 37 | """ 38 | seq_in = seq.transpose(0, 1) 39 | value = self.value_layer(seq_in) 40 | query = self.query_layer(seq_in) 41 | keys = self.key_layer(seq_in) 42 | weights = (value @ query.transpose(1, 2)) / math.sqrt(seq.shape[-1]) 43 | weights = torch.softmax(weights, -1) 44 | return (weights @ keys).transpose(1, 0) 45 | 46 | def forward_mask(self, seq, mask): 47 | """ 48 | param seq: Sequence in dimension [Seq len, Batch, Hidden size] 49 | """ 50 | seq_in = seq.transpose(0, 1) 51 | value = self.value_layer(seq_in) 52 | query = self.query_layer(seq_in) 53 | keys = self.key_layer(seq_in) 54 | weights = (value @ query.transpose(1, 2)) / math.sqrt(seq.shape[-1]) 55 | weights = torch.exp(weights) 56 | weights = (weights.transpose(1, 2) * mask.transpose(1, 0)).transpose(1, 2) 57 | weights = weights / (weights.sum(-1, keepdim=True)) 58 | return (weights @ keys).transpose(1, 0) * mask 59 | 60 | 61 | class TanhAttn(nn.Module): 62 | """ 63 | Module that calculates self-attention weights as done in Epideep 64 | """ 65 | 66 | def __init__(self, dim_in=40, value_dim=40, key_dim=40) -> None: 67 | """ 68 | param dim_in: Dimensionality of input sequence 69 | param value_dim: Dimension of value transform 70 | param key_dim: Dimension of key transform 71 | """ 72 | super(TanhAttn, self).__init__() 73 | self.value_layer = nn.Linear(dim_in, value_dim) 74 | self.query_layer = nn.Linear(dim_in, value_dim) 75 | self.key_layer = nn.Linear(dim_in, key_dim) 76 | 77 | def forward(self, seq): 78 | """ 79 | param seq: Sequence in dimension [Seq len, Batch, Hidden size] 80 | """ 81 | seq_in = seq.transpose(0, 1) 82 | value = self.value_layer(seq_in) 83 | value = torch.tanh(value) 84 | query = value 85 | keys = seq_in 86 | weights = value @ query.transpose(1, 2) 87 | weights = torch.softmax(weights, -1) 88 | return (weights @ keys).transpose(1, 0) 89 | 90 | def forward_mask(self, seq, mask): 91 | """ 92 | param seq: Sequence in dimension [Seq len, Batch, Hidden size] 93 | """ 94 | seq_in = seq.transpose(0, 1) 95 | value = self.value_layer(seq_in) 96 | value = torch.tanh(value) 97 | query = value 98 | keys = seq_in 99 | weights = value @ query.transpose(1, 2) 100 | weights = torch.exp(weights) 101 | weights = (weights.transpose(1, 2) * mask.transpose(1, 0)).transpose(1, 2) 102 | weights = weights / (weights.sum(-1, keepdim=True)) 103 | return (weights @ keys).transpose(1, 0) 104 | 105 | 106 | class LatentAtten(nn.Module): 107 | """ 108 | Attention on latent representation 109 | """ 110 | 111 | def __init__(self, h_dim, key_dim=None) -> None: 112 | super(LatentAtten, self).__init__() 113 | if key_dim is None: 114 | key_dim = h_dim 115 | self.key_dim = key_dim 116 | self.key_layer = nn.Linear(h_dim, key_dim) 117 | self.query_layer = nn.Linear(h_dim, key_dim) 118 | 119 | def forward(self, h_M, h_R): 120 | key = self.key_layer(h_M) 121 | query = self.query_layer(h_R) 122 | atten = (key @ query.transpose(0, 1)) / math.sqrt(self.key_dim) 123 | atten = torch.softmax(atten, 1) 124 | return atten 125 | 126 | 127 | class EmbedAttenSeq(nn.Module): 128 | """ 129 | Module to embed a sequence. 
Adds Attention module to 130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim_seq_in: int = 5, 135 | dim_metadata: int = 3, 136 | rnn_out: int = 40, 137 | dim_out: int = 50, 138 | n_layers: int = 1, 139 | bidirectional: bool = False, 140 | attn=TransformerAttn, 141 | dropout=0.0, 142 | ) -> None: 143 | """ 144 | param dim_seq_in: Dimensionality of input vector (no. of age groups) 145 | param dim_out: Dimensionality of output vector 146 | param dim_metadata: Dimensions of metadata for all sequences 147 | param rnn_out: output dimension for rnn 148 | """ 149 | super(EmbedAttenSeq, self).__init__() 150 | 151 | self.dim_seq_in = dim_seq_in 152 | self.dim_metadata = dim_metadata 153 | self.rnn_out = rnn_out 154 | self.dim_out = dim_out 155 | self.bidirectional = bidirectional 156 | 157 | self.rnn = nn.GRU( 158 | input_size=self.dim_seq_in, 159 | hidden_size=self.rnn_out // 2 if self.bidirectional else self.rnn_out, 160 | bidirectional=bidirectional, 161 | num_layers=n_layers, 162 | dropout=dropout, 163 | ) 164 | self.attn_layer = attn(self.rnn_out, self.rnn_out, self.rnn_out) 165 | self.out_layer = [ 166 | nn.Linear( 167 | in_features=self.rnn_out + self.dim_metadata, out_features=self.dim_out 168 | ), 169 | nn.Tanh(), 170 | nn.Dropout(dropout), 171 | ] 172 | self.out_layer = nn.Sequential(*self.out_layer) 173 | 174 | def forward_mask(self, seqs, metadata, mask): 175 | # Take last output from GRU 176 | latent_seqs = self.rnn(seqs)[0] 177 | latent_seqs = latent_seqs 178 | latent_seqs = self.attn_layer.forward_mask(latent_seqs, mask) 179 | latent_seqs = latent_seqs.sum(0) 180 | out = self.out_layer(torch.cat([latent_seqs, metadata], dim=1)) 181 | return out 182 | 183 | def forward(self, seqs, metadata): 184 | # Take last output from GRU 185 | latent_seqs = self.rnn(seqs)[0] 186 | latent_seqs = self.attn_layer(latent_seqs).sum(0) 187 | out = self.out_layer(torch.cat([latent_seqs, metadata], dim=1)) 188 | return out 189 | 190 | 191 | class EmbedAttenSeq2(nn.Module): 192 | """ 193 | Module to embed a sequence. Adds Attention module to 194 | """ 195 | 196 | def __init__( 197 | self, 198 | dim_seq_in: int = 5, 199 | dim_label_in: int = 1, 200 | dim_metadata: int = 3, 201 | rnn_out: int = 40, 202 | dim_out: int = 50, 203 | n_layers: int = 1, 204 | bidirectional: bool = False, 205 | attn=TransformerAttn, 206 | dropout=0.0, 207 | ) -> None: 208 | """ 209 | param dim_seq_in: Dimensionality of input vector (no. 
of age groups) 210 | param dim_out: Dimensionality of output vector 211 | param dim_metadata: Dimensions of metadata for all sequences 212 | param rnn_out: output dimension for rnn 213 | """ 214 | super(EmbedAttenSeq2, self).__init__() 215 | 216 | self.dim_seq_in = dim_seq_in 217 | self.dim_label_in = dim_label_in 218 | self.dim_metadata = dim_metadata 219 | self.rnn_out = rnn_out 220 | self.dim_out = dim_out 221 | self.bidirectional = bidirectional 222 | 223 | self.rnn = nn.GRU( 224 | input_size=self.dim_seq_in, 225 | hidden_size=self.rnn_out // 2 if self.bidirectional else self.rnn_out, 226 | bidirectional=bidirectional, 227 | num_layers=n_layers, 228 | dropout=dropout, 229 | ) 230 | self.attn_layer = attn(self.rnn_out, self.rnn_out, self.rnn_out) 231 | self.out_layer = [ 232 | nn.Linear( 233 | in_features=self.rnn_out + self.dim_metadata + self.dim_label_in, 234 | out_features=self.dim_out, 235 | ), 236 | nn.Tanh(), 237 | nn.Dropout(dropout), 238 | ] 239 | self.out_layer = nn.Sequential(*self.out_layer) 240 | 241 | def forward_mask(self, seqs, metadata, mask, labels): 242 | # Take last output from GRU 243 | latent_seqs = self.rnn(seqs)[0] 244 | latent_seqs = latent_seqs 245 | latent_seqs = self.attn_layer.forward_mask(latent_seqs, mask) 246 | latent_seqs = latent_seqs.sum(0) 247 | out = self.out_layer(torch.cat([latent_seqs, metadata, labels], dim=1)) 248 | return out 249 | 250 | def forward(self, seqs, metadata, labels): 251 | # Take last output from GRU 252 | latent_seqs = self.rnn(seqs)[0] 253 | latent_seqs = self.attn_layer(latent_seqs).sum(0) 254 | out = self.out_layer(torch.cat([latent_seqs, metadata, labels], dim=1)) 255 | return out 256 | 257 | 258 | class EmbedSeq(nn.Module): 259 | """ 260 | Module to embed a sequence 261 | """ 262 | 263 | def __init__( 264 | self, 265 | dim_seq_in: int = 5, 266 | dim_metadata: int = 3, 267 | rnn_out: int = 40, 268 | dim_out: int = 50, 269 | n_layers: int = 1, 270 | bidirectional: bool = False, 271 | ) -> None: 272 | """ 273 | param dim_seq_in: Dimensionality of input vector (no. 
of age groups) 274 | param dim_out: Dimensionality of output vector 275 | param dim_metadata: Dimensions of metadata for all sequences 276 | param rnn_out: output dimension for rnn 277 | """ 278 | super(EmbedSeq, self).__init__() 279 | 280 | self.dim_seq_in = dim_seq_in 281 | self.dim_metadata = dim_metadata 282 | self.rnn_out = rnn_out 283 | self.dim_out = dim_out 284 | self.bidirectional = bidirectional 285 | 286 | self.rnn = nn.GRU( 287 | input_size=self.dim_seq_in, 288 | hidden_size=self.rnn_out // 2 if self.bidirectional else self.rnn_out, 289 | bidirectional=bidirectional, 290 | num_layers=n_layers, 291 | ) 292 | self.out_layer = [ 293 | nn.Linear( 294 | in_features=self.rnn_out + self.dim_metadata, out_features=self.dim_out 295 | ), 296 | nn.Tanh(), 297 | ] 298 | self.out_layer = nn.Sequential(*self.out_layer) 299 | 300 | def forward_mask(self, seqs, metadata, mask): 301 | # Take last output from GRU 302 | latent_seqs = self.rnn(seqs)[0] 303 | latent_seqs = latent_seqs * mask 304 | latent_seqs = latent_seqs.sum(0) 305 | 306 | out = self.out_layer(torch.cat([latent_seqs, metadata], dim=1)) 307 | return out 308 | 309 | def forward(self, seqs, metadata): 310 | # Take last output from GRU 311 | latent_seqs = self.rnn(seqs)[0][0] 312 | 313 | out = self.out_layer(torch.cat([latent_seqs, metadata], dim=1)) 314 | return out 315 | 316 | 317 | class EmbedSeq2(nn.Module): 318 | """ 319 | Module to embed a sequence 320 | """ 321 | 322 | def __init__( 323 | self, 324 | dim_seq_in: int = 5, 325 | dim_label_in: int = 1, 326 | dim_metadata: int = 3, 327 | rnn_out: int = 40, 328 | dim_out: int = 50, 329 | n_layers: int = 1, 330 | bidirectional: bool = False, 331 | ) -> None: 332 | """ 333 | param dim_seq_in: Dimensionality of input vector 334 | param dim_label_in: Dimensions of label vector (usually 1) 335 | param dim_out: Dimensionality of output vector 336 | param dim_metadata: Dimensions of metadata for all sequences 337 | param rnn_out: output dimension for rnn 338 | """ 339 | super(EmbedSeq2, self).__init__() 340 | 341 | self.dim_seq_in = dim_seq_in 342 | self.dim_label_in = dim_label_in 343 | self.dim_metadata = dim_metadata 344 | self.rnn_out = rnn_out 345 | self.dim_out = dim_out 346 | self.bidirectional = bidirectional 347 | 348 | self.rnn = nn.GRU( 349 | input_size=self.dim_seq_in, 350 | hidden_size=self.rnn_out // 2 if self.bidirectional else self.rnn_out, 351 | bidirectional=bidirectional, 352 | num_layers=n_layers, 353 | ) 354 | self.out_layer = [ 355 | nn.Linear( 356 | in_features=self.rnn_out + self.dim_metadata + self.dim_label_in, 357 | out_features=self.dim_out, 358 | ), 359 | nn.Tanh(), 360 | ] 361 | self.out_layer = nn.Sequential(*self.out_layer) 362 | 363 | def forward_mask(self, seqs, metadata, mask, labels): 364 | # Take last output from GRU 365 | latent_seqs = self.rnn(seqs)[0] 366 | latent_seqs = latent_seqs * mask 367 | latent_seqs = latent_seqs.sum(0) 368 | 369 | out = self.out_layer(torch.cat([latent_seqs, metadata, labels], dim=1)) 370 | return out 371 | 372 | def forward(self, seqs, metadata, labels): 373 | # Take last output from GRU 374 | latent_seqs = self.rnn(seqs)[0][0] 375 | 376 | out = self.out_layer(torch.cat([latent_seqs, metadata, labels], dim=1)) 377 | return out 378 | 379 | 380 | class RegressionFNP(nn.Module): 381 | """ 382 | Functional Neural Process for regression 383 | """ 384 | 385 | def __init__( 386 | self, 387 | dim_x=1, 388 | dim_y=1, 389 | dim_h=50, 390 | transf_y=None, 391 | n_layers=1, 392 | use_plus=True, 393 | num_M=100, 394 | dim_u=1, 395 | 
dim_z=1, 396 | fb_z=0.0, 397 | use_ref_labels=True, 398 | use_DAG=True, 399 | add_atten=False, 400 | ): 401 | """ 402 | :param dim_x: Dimensionality of the input 403 | :param dim_y: Dimensionality of the output 404 | :param dim_h: Dimensionality of the hidden layers 405 | :param transf_y: Transformation of the output (e.g. standardization) 406 | :param n_layers: How many hidden layers to use 407 | :param use_plus: Whether to use the FNP+ 408 | :param num_M: How many points exist in the training set that are not part of the reference set 409 | :param dim_u: Dimensionality of the latents in the embedding space 410 | :param dim_z: Dimensionality of the latents that summarize the parents 411 | :param fb_z: How many free bits do we allow for the latent variable z 412 | """ 413 | super(RegressionFNP, self).__init__() 414 | 415 | self.num_M = num_M 416 | self.dim_x = dim_x 417 | self.dim_y = dim_y 418 | self.dim_h = dim_h 419 | self.dim_u = dim_u 420 | self.dim_z = dim_z 421 | self.use_plus = use_plus 422 | self.fb_z = fb_z 423 | self.transf_y = transf_y 424 | self.use_ref_labels = use_ref_labels 425 | self.use_DAG = use_DAG 426 | self.add_atten = add_atten 427 | # normalizes the graph such that inner products correspond to averages of the parents 428 | self.norm_graph = lambda x: x / (torch.sum(x, 1, keepdim=True) + 1e-8) 429 | 430 | self.register_buffer("lambda_z", float_tensor(1).fill_(1e-8)) 431 | 432 | # function that assigns the edge probabilities in the graph 433 | self.pairwise_g_logscale = nn.Parameter( 434 | float_tensor(1).fill_(math.log(math.sqrt(self.dim_u))) 435 | ) 436 | self.pairwise_g = lambda x: logitexp( 437 | -0.5 438 | * torch.sum( 439 | torch.pow(x[:, self.dim_u :] - x[:, 0 : self.dim_u], 2), 1, keepdim=True 440 | ) 441 | / self.pairwise_g_logscale.exp() 442 | ).view(x.size(0), 1) 443 | # transformation of the input 444 | 445 | init = [nn.Linear(dim_x, self.dim_h), nn.ReLU()] 446 | for i in range(n_layers - 1): 447 | init += [nn.Linear(self.dim_h, self.dim_h), nn.ReLU()] 448 | self.cond_trans = nn.Sequential(*init) 449 | # p(u|x) 450 | self.p_u = nn.Linear(self.dim_h, 2 * self.dim_u) 451 | # q(z|x) 452 | self.q_z = nn.Linear(self.dim_h, 2 * self.dim_z) 453 | # for p(z|A, XR, yR) 454 | if use_ref_labels: 455 | self.trans_cond_y = nn.Linear(self.dim_y, 2 * self.dim_z) 456 | 457 | # p(y|z) or p(y|z, u) 458 | self.output = nn.Sequential( 459 | nn.Linear( 460 | self.dim_z if not self.use_plus else self.dim_z + self.dim_u, self.dim_h 461 | ), 462 | nn.ReLU(), 463 | nn.Linear(self.dim_h, 2 * dim_y), 464 | ) 465 | if self.add_atten: 466 | self.atten_layer = LatentAtten(self.dim_h) 467 | 468 | def forward(self, XR, yR, XM, yM, kl_anneal=1.0): 469 | X_all = torch.cat([XR, XM], dim=0) 470 | H_all = self.cond_trans(X_all) 471 | 472 | # get U 473 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 474 | pu = Normal(pu_mean_all, pu_logscale_all) 475 | u = pu.rsample() 476 | 477 | # get G 478 | if self.use_DAG: 479 | G = sample_DAG(u[0 : XR.size(0)], self.pairwise_g, training=self.training) 480 | else: 481 | G = sample_Clique( 482 | u[0 : XR.size(0)], self.pairwise_g, training=self.training 483 | ) 484 | 485 | # get A 486 | A = sample_bipartite( 487 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=self.training 488 | ) 489 | if self.add_atten: 490 | HR, HM = H_all[0 : XR.size(0)], H_all[XR.size(0) :] 491 | atten = self.atten_layer(HM, HR) 492 | A = A * atten 493 | 494 | # get Z 495 | qz_mean_all, qz_logscale_all = torch.split(self.q_z(H_all), 
self.dim_z, 1) 496 | qz = Normal(qz_mean_all, qz_logscale_all) 497 | z = qz.rsample() 498 | if self.use_ref_labels: 499 | cond_y_mean, cond_y_logscale = torch.split( 500 | self.trans_cond_y(yR), self.dim_z, 1 501 | ) 502 | pz_mean_all = torch.mm( 503 | self.norm_graph(torch.cat([G, A], dim=0)), 504 | cond_y_mean + qz_mean_all[0 : XR.size(0)], 505 | ) 506 | pz_logscale_all = torch.mm( 507 | self.norm_graph(torch.cat([G, A], dim=0)), 508 | cond_y_logscale + qz_logscale_all[0 : XR.size(0)], 509 | ) 510 | else: 511 | pz_mean_all = torch.mm( 512 | self.norm_graph(torch.cat([G, A], dim=0)), qz_mean_all[0 : XR.size(0)], 513 | ) 514 | pz_logscale_all = torch.mm( 515 | self.norm_graph(torch.cat([G, A], dim=0)), 516 | qz_logscale_all[0 : XR.size(0)], 517 | ) 518 | 519 | pz = Normal(pz_mean_all, pz_logscale_all) 520 | 521 | pqz_all = pz.log_prob(z) - qz.log_prob(z) 522 | 523 | # apply free bits for the latent z 524 | if self.fb_z > 0: 525 | log_qpz = -torch.sum(pqz_all) 526 | 527 | if self.training: 528 | if log_qpz.item() > self.fb_z * z.size(0) * z.size(1) * (1 + 0.05): 529 | self.lambda_z = torch.clamp( 530 | self.lambda_z * (1 + 0.1), min=1e-8, max=1.0 531 | ) 532 | elif log_qpz.item() < self.fb_z * z.size(0) * z.size(1): 533 | self.lambda_z = torch.clamp( 534 | self.lambda_z * (1 - 0.1), min=1e-8, max=1.0 535 | ) 536 | 537 | log_pqz_R = self.lambda_z * torch.sum(pqz_all[0 : XR.size(0)]) 538 | log_pqz_M = self.lambda_z * torch.sum(pqz_all[XR.size(0) :]) 539 | 540 | else: 541 | log_pqz_R = torch.sum(pqz_all[0 : XR.size(0)]) 542 | log_pqz_M = torch.sum(pqz_all[XR.size(0) :]) 543 | 544 | final_rep = z if not self.use_plus else torch.cat([z, u], dim=1) 545 | 546 | mean_y, logstd_y = torch.split(self.output(final_rep), 1, dim=1) 547 | logstd_y = torch.log(0.1 + 0.9 * F.softplus(logstd_y)) 548 | 549 | mean_yR, mean_yM = mean_y[0 : XR.size(0)], mean_y[XR.size(0) :] 550 | logstd_yR, logstd_yM = logstd_y[0 : XR.size(0)], logstd_y[XR.size(0) :] 551 | 552 | # logp(R) 553 | pyR = Normal(mean_yR, logstd_yR) 554 | log_pyR = torch.sum(pyR.log_prob(yR)) 555 | 556 | # logp(M|S) 557 | pyM = Normal(mean_yM, logstd_yM) 558 | log_pyM = torch.sum(pyM.log_prob(yM)) 559 | 560 | obj_R = (log_pyR + log_pqz_R) / float(self.num_M) 561 | obj_M = (log_pyM + log_pqz_M) / float(XM.size(0)) 562 | 563 | if self.use_ref_labels: 564 | obj = obj_R + obj_M 565 | else: 566 | obj = obj_M 567 | 568 | loss = -obj 569 | 570 | return loss, mean_y, logstd_y 571 | 572 | def predict(self, x_new, XR, yR, sample=True): 573 | 574 | H_all = self.cond_trans(torch.cat([XR, x_new], 0)) 575 | 576 | # get U 577 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 578 | pu = Normal(pu_mean_all, pu_logscale_all) 579 | u = pu.rsample() 580 | 581 | A = sample_bipartite( 582 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=False 583 | ) 584 | 585 | if self.add_atten: 586 | HR, HM = H_all[0 : XR.size(0)], H_all[XR.size(0) :] 587 | atten = self.atten_layer(HM, HR) 588 | A = A * atten 589 | 590 | pz_mean_all, pz_logscale_all = torch.split( 591 | self.q_z(H_all[0 : XR.size(0)]), self.dim_z, 1 592 | ) 593 | if self.use_ref_labels: 594 | cond_y_mean, cond_y_logscale = torch.split( 595 | self.trans_cond_y(yR), self.dim_z, 1 596 | ) 597 | pz_mean_all = torch.mm(self.norm_graph(A), cond_y_mean + pz_mean_all) 598 | pz_logscale_all = torch.mm( 599 | self.norm_graph(A), cond_y_logscale + pz_logscale_all 600 | ) 601 | else: 602 | pz_mean_all = torch.mm(self.norm_graph(A), pz_mean_all) 603 | pz_logscale_all = 
torch.mm(self.norm_graph(A), pz_logscale_all) 604 | pz = Normal(pz_mean_all, pz_logscale_all) 605 | 606 | z = pz.rsample() 607 | final_rep = z if not self.use_plus else torch.cat([z, u[XR.size(0) :]], dim=1) 608 | 609 | mean_y, logstd_y = torch.split(self.output(final_rep), 1, dim=1) 610 | logstd_y = torch.log(0.1 + 0.9 * F.softplus(logstd_y)) 611 | 612 | init_y = Normal(mean_y, logstd_y) 613 | if sample: 614 | y_new_i = init_y.sample() 615 | else: 616 | y_new_i = mean_y 617 | 618 | y_pred = y_new_i 619 | 620 | if self.transf_y is not None: 621 | if torch.cuda.is_available(): 622 | y_pred = self.transf_y.inverse_transform(y_pred.cpu().data.numpy()) 623 | else: 624 | y_pred = self.transf_y.inverse_transform(y_pred.data.numpy()) 625 | 626 | return y_pred, mean_y, logstd_y, u[XR.size(0) :], u[: XR.size(0)], init_y, A 627 | 628 | 629 | class SelfAttention(nn.Module): 630 | """ 631 | Simple attention layer 632 | """ 633 | 634 | def __init__(self, hidden_dim, n_heads=8): 635 | super(SelfAttention, self).__init__() 636 | self._W_k = nn.ModuleList( 637 | [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_heads)] 638 | ) 639 | self._W_v = nn.ModuleList( 640 | [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_heads)] 641 | ) 642 | self._W_q = nn.ModuleList( 643 | [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_heads)] 644 | ) 645 | self._W = nn.Linear(n_heads * hidden_dim, hidden_dim) 646 | self.n_heads = n_heads 647 | 648 | def forward(self, x): 649 | outs = [] 650 | for i in range(self.n_heads): 651 | k_ = self._W_k[i](x) 652 | v_ = self._W_v[i](x) 653 | q_ = self._W_q[i](x) 654 | wts = torch.softmax(v_ @ q_.T, dim=-1) 655 | out = wts @ k_ 656 | outs.append(out) 657 | outs = torch.cat(outs, dim=-1) 658 | outs = self._W(outs) 659 | return outs 660 | 661 | 662 | class RegressionFNP2(nn.Module): 663 | """ 664 | Functional Neural Process for regression 665 | """ 666 | 667 | def __init__( 668 | self, 669 | dim_x=1, 670 | dim_y=1, 671 | dim_h=50, 672 | transf_y=None, 673 | n_layers=1, 674 | use_plus=True, 675 | num_M=100, 676 | dim_u=1, 677 | dim_z=1, 678 | fb_z=0.0, 679 | use_ref_labels=True, 680 | use_DAG=True, 681 | add_atten=False, 682 | ): 683 | """ 684 | :param dim_x: Dimensionality of the input 685 | :param dim_y: Dimensionality of the output 686 | :param dim_h: Dimensionality of the hidden layers 687 | :param transf_y: Transformation of the output (e.g. 
standardization) 688 | :param n_layers: How many hidden layers to use 689 | :param use_plus: Whether to use the FNP+ 690 | :param num_M: How many points exist in the training set that are not part of the reference set 691 | :param dim_u: Dimensionality of the latents in the embedding space 692 | :param dim_z: Dimensionality of the latents that summarize the parents 693 | :param fb_z: How many free bits do we allow for the latent variable z 694 | """ 695 | super(RegressionFNP2, self).__init__() 696 | 697 | self.num_M = num_M 698 | self.dim_x = dim_x 699 | self.dim_y = dim_y 700 | self.dim_h = dim_h 701 | self.dim_u = dim_u 702 | self.dim_z = dim_z 703 | self.use_plus = use_plus 704 | self.fb_z = fb_z 705 | self.transf_y = transf_y 706 | self.use_ref_labels = use_ref_labels 707 | self.use_DAG = use_DAG 708 | self.add_atten = add_atten 709 | # normalizes the graph such that inner products correspond to averages of the parents 710 | self.norm_graph = lambda x: x / (torch.sum(x, 1, keepdim=True) + 1e-8) 711 | 712 | self.register_buffer("lambda_z", float_tensor(1).fill_(1e-8)) 713 | 714 | # function that assigns the edge probabilities in the graph 715 | self.pairwise_g_logscale = nn.Parameter( 716 | float_tensor(1).fill_(math.log(math.sqrt(self.dim_u))) 717 | ) 718 | self.pairwise_g = lambda x: logitexp( 719 | -0.5 720 | * torch.sum( 721 | torch.pow(x[:, self.dim_u :] - x[:, 0 : self.dim_u], 2), 1, keepdim=True 722 | ) 723 | / self.pairwise_g_logscale.exp() 724 | ).view(x.size(0), 1) 725 | # transformation of the input 726 | 727 | init = [nn.Linear(dim_x, self.dim_h), nn.ReLU()] 728 | for i in range(n_layers - 1): 729 | init += [nn.Linear(self.dim_h, self.dim_h), nn.ReLU()] 730 | self.cond_trans = nn.Sequential(*init) 731 | # p(u|x) 732 | self.p_u = nn.Linear(self.dim_h, 2 * self.dim_u) 733 | # q(z|x) 734 | self.q_z = nn.Linear(self.dim_h, 2 * self.dim_z) 735 | # for p(z|A, XR, yR) 736 | if use_ref_labels: 737 | self.trans_cond_y = nn.Linear(self.dim_y, 2 * self.dim_z) 738 | 739 | # p(y|z) or p(y|z, u) 740 | # TODO: Add for sR input 741 | self.atten_ref = SelfAttention(self.dim_x) 742 | self.output = nn.Sequential( 743 | nn.Linear( 744 | self.dim_z + self.dim_x 745 | if not self.use_plus 746 | else self.dim_z + self.dim_u + self.dim_x, 747 | self.dim_h, 748 | ), 749 | nn.ReLU(), 750 | nn.Linear(self.dim_h, 2 * dim_y), 751 | ) 752 | if self.add_atten: 753 | self.atten_layer = LatentAtten(self.dim_h) 754 | 755 | def forward(self, XR, yR, XM, yM, kl_anneal=1.0): 756 | # sR = self.atten_ref(XR).mean(dim=0) 757 | sR = XR.mean(dim=0) 758 | X_all = torch.cat([XR, XM], dim=0) 759 | H_all = self.cond_trans(X_all) 760 | 761 | # get U 762 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 763 | pu = Normal(pu_mean_all, pu_logscale_all) 764 | u = pu.rsample() 765 | 766 | # get G 767 | if self.use_DAG: 768 | G = sample_DAG(u[0 : XR.size(0)], self.pairwise_g, training=self.training) 769 | else: 770 | G = sample_Clique( 771 | u[0 : XR.size(0)], self.pairwise_g, training=self.training 772 | ) 773 | 774 | # get A 775 | A = sample_bipartite( 776 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=self.training 777 | ) 778 | if self.add_atten: 779 | HR, HM = H_all[0 : XR.size(0)], H_all[XR.size(0) :] 780 | atten = self.atten_layer(HM, HR) 781 | A = A * atten 782 | 783 | # get Z 784 | qz_mean_all, qz_logscale_all = torch.split(self.q_z(H_all), self.dim_z, 1) 785 | qz = Normal(qz_mean_all, qz_logscale_all) 786 | z = qz.rsample() 787 | if self.use_ref_labels: 788 | 
cond_y_mean, cond_y_logscale = torch.split( 789 | self.trans_cond_y(yR), self.dim_z, 1 790 | ) 791 | pz_mean_all = torch.mm( 792 | self.norm_graph(torch.cat([G, A], dim=0)), 793 | cond_y_mean + qz_mean_all[0 : XR.size(0)], 794 | ) 795 | pz_logscale_all = torch.mm( 796 | self.norm_graph(torch.cat([G, A], dim=0)), 797 | cond_y_logscale + qz_logscale_all[0 : XR.size(0)], 798 | ) 799 | else: 800 | pz_mean_all = torch.mm( 801 | self.norm_graph(torch.cat([G, A], dim=0)), qz_mean_all[0 : XR.size(0)], 802 | ) 803 | pz_logscale_all = torch.mm( 804 | self.norm_graph(torch.cat([G, A], dim=0)), 805 | qz_logscale_all[0 : XR.size(0)], 806 | ) 807 | 808 | pz = Normal(pz_mean_all, pz_logscale_all) 809 | 810 | pqz_all = pz.log_prob(z) - qz.log_prob(z) 811 | 812 | # apply free bits for the latent z 813 | if self.fb_z > 0: 814 | log_qpz = -torch.sum(pqz_all) 815 | 816 | if self.training: 817 | if log_qpz.item() > self.fb_z * z.size(0) * z.size(1) * (1 + 0.05): 818 | self.lambda_z = torch.clamp( 819 | self.lambda_z * (1 + 0.1), min=1e-8, max=1.0 820 | ) 821 | elif log_qpz.item() < self.fb_z * z.size(0) * z.size(1): 822 | self.lambda_z = torch.clamp( 823 | self.lambda_z * (1 - 0.1), min=1e-8, max=1.0 824 | ) 825 | 826 | log_pqz_R = self.lambda_z * torch.sum(pqz_all[0 : XR.size(0)]) 827 | log_pqz_M = self.lambda_z * torch.sum(pqz_all[XR.size(0) :]) 828 | 829 | else: 830 | log_pqz_R = torch.sum(pqz_all[0 : XR.size(0)]) 831 | log_pqz_M = torch.sum(pqz_all[XR.size(0) :]) 832 | 833 | final_rep = z if not self.use_plus else torch.cat([z, u], dim=1) 834 | sR = sR.repeat(final_rep.shape[0], 1) 835 | final_rep = torch.cat([sR, final_rep], dim=-1) 836 | 837 | mean_y, logstd_y = torch.split(self.output(final_rep), 1, dim=1) 838 | logstd_y = torch.log(0.1 + 0.9 * F.softplus(logstd_y)) 839 | 840 | mean_yR, mean_yM = mean_y[0 : XR.size(0)], mean_y[XR.size(0) :] 841 | logstd_yR, logstd_yM = logstd_y[0 : XR.size(0)], logstd_y[XR.size(0) :] 842 | 843 | # logp(R) 844 | pyR = Normal(mean_yR, logstd_yR) 845 | log_pyR = torch.sum(pyR.log_prob(yR)) 846 | 847 | # logp(M|S) 848 | pyM = Normal(mean_yM, logstd_yM) 849 | log_pyM = torch.sum(pyM.log_prob(yM)) 850 | 851 | obj_R = (log_pyR + log_pqz_R) / float(self.num_M) 852 | obj_M = (log_pyM + log_pqz_M) / float(XM.size(0)) 853 | 854 | if self.use_ref_labels: 855 | obj = obj_R + obj_M 856 | else: 857 | obj = obj_M 858 | 859 | loss = -obj 860 | 861 | return loss, mean_y, logstd_y 862 | 863 | def predict(self, x_new, XR, yR, sample=True): 864 | # sR = self.atten_ref(XR).mean(dim=0) 865 | sR = XR.mean(dim=0) 866 | H_all = self.cond_trans(torch.cat([XR, x_new], 0)) 867 | 868 | # get U 869 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 870 | pu = Normal(pu_mean_all, pu_logscale_all) 871 | u = pu.rsample() 872 | 873 | A = sample_bipartite( 874 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=False 875 | ) 876 | 877 | if self.add_atten: 878 | HR, HM = H_all[0 : XR.size(0)], H_all[XR.size(0) :] 879 | atten = self.atten_layer(HM, HR) 880 | A = A * atten 881 | 882 | pz_mean_all, pz_logscale_all = torch.split( 883 | self.q_z(H_all[0 : XR.size(0)]), self.dim_z, 1 884 | ) 885 | if self.use_ref_labels: 886 | cond_y_mean, cond_y_logscale = torch.split( 887 | self.trans_cond_y(yR), self.dim_z, 1 888 | ) 889 | pz_mean_all = torch.mm(self.norm_graph(A), cond_y_mean + pz_mean_all) 890 | pz_logscale_all = torch.mm( 891 | self.norm_graph(A), cond_y_logscale + pz_logscale_all 892 | ) 893 | else: 894 | pz_mean_all = torch.mm(self.norm_graph(A), pz_mean_all) 895 
| pz_logscale_all = torch.mm(self.norm_graph(A), pz_logscale_all) 896 | pz = Normal(pz_mean_all, pz_logscale_all) 897 | 898 | z = pz.rsample() 899 | final_rep = z if not self.use_plus else torch.cat([z, u[XR.size(0) :]], dim=1) 900 | sR = sR.repeat(final_rep.shape[0], 1) 901 | final_rep = torch.cat([sR, final_rep], dim=-1) 902 | 903 | mean_y, logstd_y = torch.split(self.output(final_rep), 1, dim=1) 904 | logstd_y = torch.log(0.1 + 0.9 * F.softplus(logstd_y)) 905 | 906 | init_y = Normal(mean_y, logstd_y) 907 | if sample: 908 | y_new_i = init_y.sample() 909 | else: 910 | y_new_i = mean_y 911 | 912 | y_pred = y_new_i 913 | 914 | if self.transf_y is not None: 915 | if torch.cuda.is_available(): 916 | y_pred = self.transf_y.inverse_transform(y_pred.cpu().data.numpy()) 917 | else: 918 | y_pred = self.transf_y.inverse_transform(y_pred.data.numpy()) 919 | 920 | return y_pred, mean_y, logstd_y, u[XR.size(0) :], u[: XR.size(0)], init_y, A 921 | 922 | 923 | class ClassificationFNP(nn.Module): 924 | """ 925 | Functional Neural Process for classification with the LeNet-5 architecture 926 | """ 927 | 928 | def __init__( 929 | self, 930 | dim_x=(1, 28, 28), 931 | dim_y=10, 932 | use_plus=True, 933 | num_M=1, 934 | dim_u=32, 935 | dim_z=64, 936 | fb_z=1.0, 937 | ): 938 | """ 939 | :param dim_x: Dimensionality of the input 940 | :param dim_y: Dimensionality of the output 941 | :param use_plus: Whether to use the FNP+ 942 | :param num_M: How many points exist in the training set that are not part of the reference set 943 | :param dim_u: Dimensionality of the latents in the embedding space 944 | :param dim_z: Dimensionality of the latents that summarize the parents 945 | :param fb_z: How many free bits do we allow for the latent variable z 946 | """ 947 | super(ClassificationFNP, self).__init__() 948 | 949 | self.num_M = num_M 950 | self.dim_x = dim_x 951 | self.dim_y = dim_y 952 | self.dim_u = dim_u 953 | self.dim_z = dim_z 954 | self.use_plus = use_plus 955 | self.fb_z = fb_z 956 | # normalizes the graph such that inner products correspond to averages of the parents 957 | self.norm_graph = lambda x: x / (torch.sum(x, 1, keepdim=True) + 1e-8) 958 | 959 | self.register_buffer("lambda_z", float_tensor(1).fill_(1e-8)) 960 | 961 | # function that assigns the edge probabilities in the graph 962 | self.pairwise_g_logscale = nn.Parameter( 963 | float_tensor(1).fill_(math.log(math.sqrt(self.dim_u))) 964 | ) 965 | self.pairwise_g = lambda x: logitexp( 966 | -0.5 967 | * torch.sum( 968 | torch.pow(x[:, self.dim_u :] - x[:, 0 : self.dim_u], 2), 1, keepdim=True 969 | ) 970 | / self.pairwise_g_logscale.exp() 971 | ).view(x.size(0), 1) 972 | 973 | # transformation of the input 974 | self.cond_trans = nn.Sequential( 975 | nn.Conv2d(self.dim_x[0], 20, 5), 976 | nn.ReLU(), 977 | nn.MaxPool2d(2), 978 | nn.Conv2d(20, 50, 5), 979 | nn.ReLU(), 980 | nn.MaxPool2d(2), 981 | Flatten(), 982 | nn.Linear(800, 500), 983 | ) 984 | 985 | # p(u|x) 986 | self.p_u = nn.Sequential(nn.ReLU(), nn.Linear(500, 2 * self.dim_u)) 987 | # q(z|x) 988 | self.q_z = nn.Sequential(nn.ReLU(), nn.Linear(500, 2 * self.dim_z)) 989 | # for p(z|A, XR, yR) 990 | self.trans_cond_y = nn.Linear(self.dim_y, 2 * self.dim_z) 991 | 992 | # p(y|z) or p(y|z, u) 993 | self.output = nn.Sequential( 994 | nn.ReLU(), 995 | nn.Linear( 996 | self.dim_z if not self.use_plus else self.dim_z + self.dim_u, dim_y 997 | ), 998 | ) 999 | 1000 | def forward(self, XM, yM, XR, yR, kl_anneal=1.0): 1001 | X_all = torch.cat([XR, XM], dim=0) 1002 | H_all = self.cond_trans(X_all) 1003 | 
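# --- Added commentary (not part of the original source) -----------------------
# The steps below follow the same FNP recipe as the regression variants above:
#   1. sample per-point latent embeddings u ~ p(u|x) for the reference set R and
#      the remaining training points M,
#   2. sample a DAG G among the reference points and a bipartite graph A from M
#      to R, with edge probabilities given by the distance-based self.pairwise_g,
#   3. form q(z|x) per point and, using the reference labels yR, average the
#      parents' statistics through self.norm_graph to get p(z | A, XR, yR),
#   4. apply the free-bits correction controlled by self.fb_z / self.lambda_z to
#      the KL-like term, and
#   5. score yR and yM with Categorical likelihoods on self.output(final_rep).
# -------------------------------------------------------------------------------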
1004 | # get U 1005 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 1006 | pu = Normal(pu_mean_all, pu_logscale_all) 1007 | u = pu.rsample() 1008 | 1009 | # get G 1010 | G = sample_DAG(u[0 : XR.size(0)], self.pairwise_g, training=self.training) 1011 | 1012 | # get A 1013 | A = sample_bipartite( 1014 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=self.training 1015 | ) 1016 | 1017 | # get Z 1018 | qz_mean_all, qz_logscale_all = torch.split(self.q_z(H_all), self.dim_z, 1) 1019 | qz = Normal(qz_mean_all, qz_logscale_all) 1020 | z = qz.rsample() 1021 | 1022 | cond_y_mean, cond_y_logscale = torch.split( 1023 | self.trans_cond_y(one_hot(yR, n_classes=self.dim_y)), self.dim_z, 1 1024 | ) 1025 | pz_mean_all = torch.mm( 1026 | self.norm_graph(torch.cat([G, A], dim=0)), 1027 | cond_y_mean + qz_mean_all[0 : XR.size(0)], 1028 | ) 1029 | pz_logscale_all = torch.mm( 1030 | self.norm_graph(torch.cat([G, A], dim=0)), 1031 | cond_y_logscale + qz_logscale_all[0 : XR.size(0)], 1032 | ) 1033 | 1034 | pz = Normal(pz_mean_all, pz_logscale_all) 1035 | 1036 | pqz_all = pz.log_prob(z) - qz.log_prob(z) 1037 | 1038 | # apply free bits for the latent z 1039 | if self.fb_z > 0: 1040 | log_qpz = -torch.sum(pqz_all) 1041 | 1042 | if self.training: 1043 | if log_qpz.item() > self.fb_z * z.size(0) * z.size(1) * (1 + 0.05): 1044 | self.lambda_z = torch.clamp( 1045 | self.lambda_z * (1 + 0.1), min=1e-8, max=1.0 1046 | ) 1047 | elif log_qpz.item() < self.fb_z * z.size(0) * z.size(1): 1048 | self.lambda_z = torch.clamp( 1049 | self.lambda_z * (1 - 0.1), min=1e-8, max=1.0 1050 | ) 1051 | 1052 | log_pqz_R = self.lambda_z * torch.sum(pqz_all[0 : XR.size(0)]) 1053 | log_pqz_M = self.lambda_z * torch.sum(pqz_all[XR.size(0) :]) 1054 | 1055 | else: 1056 | log_pqz_R = torch.sum(pqz_all[0 : XR.size(0)]) 1057 | log_pqz_M = torch.sum(pqz_all[XR.size(0) :]) 1058 | 1059 | final_rep = z if not self.use_plus else torch.cat([z, u], dim=1) 1060 | 1061 | logits_all = self.output(final_rep) 1062 | 1063 | pyR = Categorical(logits=logits_all[0 : XR.size(0)]) 1064 | log_pyR = torch.sum(pyR.log_prob(yR)) 1065 | 1066 | pyM = Categorical(logits=logits_all[XR.size(0) :]) 1067 | log_pyM = torch.sum(pyM.log_prob(yM)) 1068 | 1069 | obj_R = (log_pyR + log_pqz_R) / float(self.num_M) 1070 | obj_M = (log_pyM + log_pqz_M) / float(XM.size(0)) 1071 | 1072 | obj = obj_R + obj_M 1073 | 1074 | loss = -obj 1075 | 1076 | return loss 1077 | 1078 | def get_pred_logits(self, x_new, XR, yR, n_samples=100): 1079 | H_all = self.cond_trans(torch.cat([XR, x_new], 0)) 1080 | 1081 | # get U 1082 | pu_mean_all, pu_logscale_all = torch.split(self.p_u(H_all), self.dim_u, dim=1) 1083 | pu = Normal(pu_mean_all, pu_logscale_all) 1084 | 1085 | qz_mean_R, qz_logscale_R = torch.split( 1086 | self.q_z(H_all[0 : XR.size(0)]), self.dim_z, 1 1087 | ) 1088 | 1089 | logits = float_tensor(x_new.size(0), self.dim_y, n_samples) 1090 | for i in range(n_samples): 1091 | u = pu.rsample() 1092 | 1093 | A = sample_bipartite( 1094 | u[XR.size(0) :], u[0 : XR.size(0)], self.pairwise_g, training=False 1095 | ) 1096 | 1097 | cond_y_mean, cond_y_logscale = torch.split( 1098 | self.trans_cond_y(one_hot(yR, n_classes=self.dim_y)), self.dim_z, 1 1099 | ) 1100 | pz_mean_M = torch.mm(self.norm_graph(A), cond_y_mean + qz_mean_R) 1101 | pz_logscale_M = torch.mm( 1102 | self.norm_graph(A), cond_y_logscale + qz_logscale_R 1103 | ) 1104 | pz = Normal(pz_mean_M, pz_logscale_M) 1105 | 1106 | z = pz.rsample() 1107 | 1108 | final_rep = ( 1109 | z if not self.use_plus 
else torch.cat([z, u[XR.size(0) :]], dim=1) 1110 | ) 1111 | 1112 | logits[:, :, i] = F.log_softmax(self.output(final_rep), 1) 1113 | 1114 | logits = torch.logsumexp(logits, 2) - math.log(n_samples) 1115 | 1116 | return logits 1117 | 1118 | def predict(self, x_new, XR, yR, n_samples=100): 1119 | logits = self.get_pred_logits(x_new, XR, yR, n_samples=n_samples) 1120 | return torch.argmax(logits, 1) 1121 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.distributions import Bernoulli 6 | from itertools import product 7 | 8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | float_tensor = ( 10 | torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor 11 | ) 12 | 13 | 14 | def logitexp(logp): 15 | # https://github.com/pytorch/pytorch/issues/4007 16 | pos = torch.clamp(logp, min=-0.69314718056) 17 | neg = torch.clamp(logp, max=-0.69314718056) 18 | neg_val = neg - torch.log(1 - torch.exp(neg)) 19 | pos_val = -torch.log(torch.clamp(torch.expm1(-pos), min=1e-20)) 20 | return pos_val + neg_val 21 | 22 | 23 | def one_hot(x: torch.Tensor, n_classes=10): 24 | x_onehot = float_tensor(x.size(0), n_classes).zero_() 25 | x_onehot.scatter_(1, x[:, None], 1) 26 | 27 | return x_onehot 28 | 29 | 30 | class LogitRelaxedBernoulli(object): 31 | def __init__(self, logits, temperature=0.3, **kwargs): 32 | self.logits = logits 33 | self.temperature = temperature 34 | 35 | def rsample(self): 36 | eps = torch.clamp( 37 | torch.rand( 38 | self.logits.size(), dtype=self.logits.dtype, device=self.logits.device 39 | ), 40 | min=1e-6, 41 | max=1 - 1e-6, 42 | ) 43 | y = (self.logits + torch.log(eps) - torch.log(1.0 - eps)) / self.temperature 44 | return y 45 | 46 | def log_prob(self, value): 47 | return ( 48 | math.log(self.temperature) 49 | - self.temperature * value 50 | + self.logits 51 | - 2 * F.softplus(-self.temperature * value + self.logits) 52 | ) 53 | 54 | 55 | class Normal(object): 56 | def __init__(self, means, logscales, **kwargs): 57 | self.means = means 58 | self.logscales = logscales 59 | 60 | def log_prob(self, value): 61 | log_prob = torch.pow(value - self.means, 2) 62 | log_prob *= -(1 / (2.0 * self.logscales.mul(2.0).exp())) 63 | log_prob -= self.logscales + 0.5 * math.log(2.0 * math.pi) 64 | return log_prob 65 | 66 | def sample(self, **kwargs): 67 | eps = torch.normal( 68 | float_tensor(self.means.size()).zero_(), 69 | float_tensor(self.means.size()).fill_(1), 70 | ) 71 | return self.means + self.logscales.exp() * eps 72 | 73 | def rsample(self, **kwargs): 74 | return self.sample(**kwargs) 75 | 76 | 77 | def order_z(z): 78 | # scalar ordering function 79 | if z.size(1) == 1: 80 | return z 81 | log_cdf = torch.sum( 82 | torch.log(0.5 + 0.5 * torch.erf(z / math.sqrt(2))), dim=1, keepdim=True 83 | ) 84 | return log_cdf 85 | 86 | 87 | def sample_DAG(Z, g, training=True, temperature=0.3): 88 | # get the indices of an upper triangular adjacency matrix that represents the DAG 89 | idx_utr = np.triu_indices(Z.size(0), 1) 90 | 91 | # get the ordering 92 | ordering = order_z(Z) 93 | # sort the latents according to the ordering 94 | sort_idx = torch.sort(torch.squeeze(ordering), 0)[1] 95 | Y = Z[sort_idx, :] 96 | # form the latent pairs for the edges 97 | Z_pairs = torch.cat([Y[idx_utr[0]], Y[idx_utr[1]]], 1) 98 | # get the logits for 
the edges in the DAG 99 | logits = g(Z_pairs) 100 | 101 | if training: 102 | p_edges = LogitRelaxedBernoulli(logits=logits, temperature=temperature) 103 | G = torch.sigmoid(p_edges.rsample()) 104 | else: 105 | p_edges = Bernoulli(logits=logits) 106 | G = p_edges.sample() 107 | 108 | # embed the upper triangular to the adjacency matrix 109 | unsorted_G = float_tensor(Z.size(0), Z.size(0)).zero_() 110 | unsorted_G[idx_utr[0], idx_utr[1]] = G.squeeze() 111 | # unsort the dag to conform to the data order 112 | original_idx = torch.sort(sort_idx)[1] 113 | unsorted_G = unsorted_G[original_idx, :][:, original_idx] 114 | 115 | return unsorted_G 116 | 117 | 118 | def sample_Clique(Z, g, training=True, temperature=0.3): 119 | # get the indices of an upper triangular adjacency matrix that represents the DAG 120 | # idx_utr = np.triu_indices(Z.size(0), 1) 121 | idx_utr = np.triu_indices(Z.size(0), 1) 122 | idx_ltr = np.triu_indices(Z.size(0), 1) 123 | idx_ltr = idx_ltr[1], idx_ltr[0] 124 | idx_utr = ( 125 | np.concatenate([idx_utr[0], idx_ltr[0]]), 126 | np.concatenate([idx_utr[1], idx_ltr[1]]), 127 | ) 128 | 129 | # get the ordering 130 | ordering = order_z(Z) 131 | # sort the latents according to the ordering 132 | sort_idx = torch.sort(torch.squeeze(ordering), 0)[1] 133 | Y = Z[sort_idx, :] 134 | # form the latent pairs for the edges 135 | Z_pairs = torch.cat([Y[idx_utr[0]], Y[idx_utr[1]]], 1) 136 | # get the logits for the edges in the DAG 137 | logits = g(Z_pairs) 138 | 139 | if training: 140 | p_edges = LogitRelaxedBernoulli(logits=logits, temperature=temperature) 141 | G = torch.sigmoid(p_edges.rsample()) 142 | else: 143 | p_edges = Bernoulli(logits=logits) 144 | G = p_edges.sample() 145 | 146 | # embed the upper triangular to the adjacency matrix 147 | unsorted_G = float_tensor(Z.size(0), Z.size(0)).zero_() 148 | unsorted_G[idx_utr[0], idx_utr[1]] = G.squeeze() 149 | # unsort the dag to conform to the data order 150 | original_idx = torch.sort(sort_idx)[1] 151 | unsorted_G = unsorted_G[original_idx, :][:, original_idx] 152 | 153 | return unsorted_G 154 | 155 | 156 | def sample_bipartite(Z1, Z2, g, training=True, temperature=0.3): 157 | indices = [] 158 | for element in product(range(Z1.size(0)), range(Z2.size(0))): 159 | indices.append(element) 160 | indices = np.array(indices) 161 | Z_pairs = torch.cat([Z1[indices[:, 0]], Z2[indices[:, 1]]], 1) 162 | 163 | logits = g(Z_pairs) 164 | if training: 165 | p_edges = LogitRelaxedBernoulli(logits=logits, temperature=temperature) 166 | A_vals = torch.sigmoid(p_edges.rsample()) 167 | else: 168 | p_edges = Bernoulli(logits=logits) 169 | A_vals = p_edges.sample() 170 | 171 | # embed the values to the adjacency matrix 172 | A = float_tensor(Z1.size(0), Z2.size(0)).zero_() 173 | A[indices[:, 0], indices[:, 1]] = A_vals.squeeze() 174 | 175 | return A 176 | 177 | 178 | class Flatten(torch.nn.Module): 179 | def __init__(self): 180 | super(Flatten, self).__init__() 181 | 182 | def forward(self, x): 183 | assert len(x.shape) > 1 184 | 185 | return x.view(x.shape[0], -1) 186 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | 5 | for test_years in [2014, 2015, 2016, 2017, 2018, 2019]: 6 | for week in [2, 3, 4]: 7 | for atten in [ 8 | "trans", 9 | ]: 10 | model_name = f"epifnp_{test_years}_{week}" 11 | print(model_name) 12 | r = subprocess.call( 13 | [ 14 | "python", 15 | 
"train_ili.py", 16 | "-y", 17 | str(test_years), 18 | "-w", 19 | str(week), 20 | "-a", 21 | atten, 22 | "-n", 23 | model_name, 24 | "-e", 25 | str(3000), 26 | ] 27 | ) 28 | if r != 0: 29 | raise Exception(f"{model_name} process encountered error") 30 | -------------------------------------------------------------------------------- /test_ili.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | from models.utils import float_tensor, device 5 | import pickle 6 | from models.fnpmodels import EmbedAttenSeq, RegressionFNP, EmbedSeq, RegressionFNP2 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | from optparse import OptionParser 10 | 11 | np.random.seed(10) 12 | 13 | city_idx = {f"Region {i}": i for i in range(1, 11)} 14 | city_idx["X"] = 0 15 | 16 | df = pd.read_csv("./data/ILINet.csv") 17 | df = df[["REGION", "YEAR", "WEEK", "% WEIGHTED ILI"]] 18 | df = df[(df["YEAR"] >= 2004) | ((df["YEAR"] == 2003) & (df["WEEK"] >= 20))] 19 | 20 | 21 | def get_dataset(year: int, region: str, df=df): 22 | ans = df[ 23 | ((df["YEAR"] == year) & (df["WEEK"] >= 20)) 24 | | ((df["YEAR"] == year + 1) & (df["WEEK"] <= 20)) 25 | ] 26 | return ans[ans["REGION"] == region]["% WEIGHTED ILI"] 27 | 28 | 29 | parser = OptionParser() 30 | parser.add_option("-y", "--year", dest="testyear", type="int") 31 | parser.add_option("-w", "--week", dest="week_ahead", type="int") 32 | parser.add_option("-a", "--atten", dest="atten", type="string") 33 | parser.add_option("-n", "--num", dest="num", type="string") 34 | (options, args) = parser.parse_args() 35 | 36 | train_seasons = list(range(2003, options.testyear)) 37 | test_seasons = [options.testyear] 38 | # train_seasons = list(range(2003, 2019)) 39 | # test_seasons = [2019] 40 | print(train_seasons, test_seasons) 41 | 42 | # train_seasons = [2003, 2004, 2005, 2006, 2007, 2008, 2009] 43 | # test_seasons = [2010] 44 | regions = ["X"] 45 | # regions = [f"Region {i}" for i in range(1,11)] 46 | 47 | week_ahead = options.week_ahead 48 | val_frac = 5 49 | attn = options.atten 50 | model_num = options.num 51 | # model_num = 22 52 | print(week_ahead, attn) 53 | 54 | 55 | def one_hot(idx, dim=len(city_idx)): 56 | ans = np.zeros(dim, dtype="float32") 57 | ans[idx] = 1.0 58 | return ans 59 | 60 | 61 | def save_data(obj, filepath): 62 | with open(filepath, "wb") as fl: 63 | pickle.dump(obj, fl) 64 | 65 | 66 | full_x = np.array( 67 | [ 68 | np.array(get_dataset(s, r), dtype="float32")[-53:] 69 | for s in train_seasons 70 | for r in regions 71 | ] 72 | ) 73 | full_meta = np.array([one_hot(city_idx[r]) for s in train_seasons for r in regions]) 74 | full_y = full_x.argmax(-1) 75 | full_x = full_x[:, :, None] 76 | 77 | full_x_test = np.array( 78 | [ 79 | np.array(get_dataset(s, r), dtype="float32")[-53:] 80 | for s in test_seasons 81 | for r in regions 82 | ] 83 | ) 84 | full_meta_test = np.array([one_hot(city_idx[r]) for s in test_seasons for r in regions]) 85 | full_y_test = full_x_test.argmax(-1) 86 | full_x_test = full_x_test[:, :, None] 87 | 88 | 89 | def create_dataset(full_meta, full_x, week_ahead=week_ahead): 90 | metas, seqs, y = [], [], [] 91 | for meta, seq in zip(full_meta, full_x): 92 | for i in range(20, full_x.shape[1]): 93 | metas.append(meta) 94 | seqs.append(seq[: i - week_ahead + 1]) 95 | y.append(seq[i]) 96 | return np.array(metas, dtype="float32"), seqs, np.array(y, dtype="float32") 97 | 98 | 99 | train_meta, train_x, train_y = create_dataset(full_meta, 
full_x) 100 | test_meta, test_x, test_y = create_dataset(full_meta_test, full_x_test) 101 | 102 | 103 | def create_tensors(metas, seqs, ys): 104 | metas = float_tensor(metas) 105 | ys = float_tensor(ys) 106 | max_len = max([len(s) for s in seqs]) 107 | out_seqs = np.zeros((len(seqs), max_len, seqs[0].shape[-1]), dtype="float32") 108 | lens = np.zeros(len(seqs), dtype="int32") 109 | for i, s in enumerate(seqs): 110 | out_seqs[i, : len(s), :] = s 111 | lens[i] = len(s) 112 | out_seqs = float_tensor(out_seqs) 113 | return metas, out_seqs, ys, lens 114 | 115 | 116 | def create_mask1(lens, out_dim=1): 117 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 118 | for i, j in enumerate(lens): 119 | ans[j - 1, i, :] = 1.0 120 | return float_tensor(ans) 121 | 122 | 123 | def create_mask(lens, out_dim=1): 124 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 125 | for i, j in enumerate(lens): 126 | ans[:j, i, :] = 1.0 127 | return float_tensor(ans) 128 | 129 | 130 | if attn == "trans": 131 | emb_model = EmbedAttenSeq( 132 | dim_seq_in=1, 133 | dim_metadata=len(city_idx), 134 | dim_out=50, 135 | n_layers=2, 136 | bidirectional=True, 137 | ).cuda() 138 | emb_model_full = EmbedAttenSeq( 139 | dim_seq_in=1, 140 | dim_metadata=len(city_idx), 141 | dim_out=50, 142 | n_layers=2, 143 | bidirectional=True, 144 | ).cuda() 145 | else: 146 | emb_model = EmbedSeq( 147 | dim_seq_in=1, 148 | dim_metadata=len(city_idx), 149 | dim_out=50, 150 | n_layers=2, 151 | bidirectional=True, 152 | ).cuda() 153 | emb_model_full = EmbedSeq( 154 | dim_seq_in=1, 155 | dim_metadata=len(city_idx), 156 | dim_out=50, 157 | n_layers=2, 158 | bidirectional=True, 159 | ).cuda() 160 | fnp_model = RegressionFNP2( 161 | dim_x=50, 162 | dim_y=1, 163 | dim_h=100, 164 | n_layers=3, 165 | num_M=train_meta.shape[0], 166 | dim_u=50, 167 | dim_z=50, 168 | fb_z=0.0, 169 | use_ref_labels=False, 170 | use_DAG=False, 171 | add_atten=False, 172 | ).cuda() 173 | optimizer = optim.Adam( 174 | list(emb_model.parameters()) 175 | + list(fnp_model.parameters()) 176 | + list(emb_model_full.parameters()), 177 | lr=1e-3, 178 | ) 179 | 180 | # emb_model_full = emb_model 181 | 182 | train_meta_, train_x_, train_y_, train_lens_ = create_tensors( 183 | train_meta, train_x, train_y 184 | ) 185 | 186 | test_meta, test_x, test_y, test_lens = create_tensors(test_meta, test_x, test_y) 187 | 188 | full_x_chunks = np.zeros((full_x.shape[0] * 4, full_x.shape[1], full_x.shape[2])) 189 | full_meta_chunks = np.zeros((full_meta.shape[0] * 4, full_meta.shape[1])) 190 | for i, s in enumerate(full_x): 191 | full_x_chunks[i * 4, -20:] = s[:20] 192 | full_x_chunks[i * 4 + 1, -30:] = s[:30] 193 | full_x_chunks[i * 4 + 2, -40:] = s[:40] 194 | full_x_chunks[i * 4 + 3, :] = s 195 | full_meta_chunks[i * 4 : i * 4 + 4] = full_meta[i] 196 | 197 | full_x = float_tensor(full_x) 198 | full_meta = float_tensor(full_meta) 199 | full_y = float_tensor(full_y) 200 | 201 | train_mask_, test_mask = ( 202 | create_mask(train_lens_), 203 | create_mask(test_lens), 204 | ) 205 | 206 | perm = np.random.permutation(train_meta_.shape[0]) 207 | val_perm = perm[: train_meta_.shape[0] // val_frac] 208 | train_perm = perm[train_meta_.shape[0] // val_frac :] 209 | 210 | train_meta, train_x, train_y, train_lens, train_mask = ( 211 | train_meta_[train_perm], 212 | train_x_[train_perm], 213 | train_y_[train_perm], 214 | train_lens_[train_perm], 215 | train_mask_[:, train_perm, :], 216 | ) 217 | val_meta, val_x, val_y, val_lens, val_mask = ( 218 | train_meta_[val_perm], 219 | 
train_x_[val_perm], 220 | train_y_[val_perm], 221 | train_lens_[val_perm], 222 | train_mask_[:, val_perm, :], 223 | ) 224 | 225 | 226 | def save_model(file_prefix: str): 227 | torch.save(emb_model.state_dict(), file_prefix + "_emb_model.pth") 228 | torch.save(emb_model_full.state_dict(), file_prefix + "_emb_model_full.pth") 229 | torch.save(fnp_model.state_dict(), file_prefix + "_fnp_model.pth") 230 | 231 | 232 | def load_model(file_prefix: str): 233 | emb_model.load_state_dict(torch.load(file_prefix + "_emb_model.pth")) 234 | emb_model_full.load_state_dict(torch.load(file_prefix + "_emb_model_full.pth")) 235 | fnp_model.load_state_dict(torch.load(file_prefix + "_fnp_model.pth")) 236 | 237 | 238 | def evaluate(sample=True, dtype="test"): 239 | with torch.no_grad(): 240 | emb_model.eval() 241 | emb_model_full.eval() 242 | fnp_model.eval() 243 | full_embeds = emb_model_full(full_x.transpose(1, 0), full_meta) 244 | if dtype == "val": 245 | x_embeds = emb_model.forward_mask(val_x.transpose(1, 0), val_meta, val_mask) 246 | elif dtype == "test": 247 | x_embeds = emb_model.forward_mask( 248 | test_x.transpose(1, 0), test_meta, test_mask 249 | ) 250 | elif dtype == "train": 251 | x_embeds = emb_model.forward_mask( 252 | train_x.transpose(1, 0), train_meta, train_mask 253 | ) 254 | elif dtype == "all": 255 | x_embeds = emb_model.forward_mask( 256 | train_x_.transpose(1, 0), train_meta_, train_mask_ 257 | ) 258 | else: 259 | raise ValueError("Incorrect dtype") 260 | y_pred, _, vars, _, _, _, _ = fnp_model.predict( 261 | x_embeds, full_embeds, full_y, sample=sample 262 | ) 263 | labels_dict = {"val": val_y, "test": test_y, "train": train_y, "all": train_y_} 264 | labels = labels_dict[dtype] 265 | mse_error = torch.pow(y_pred - labels, 2).mean().sqrt().detach().cpu().numpy() 266 | return ( 267 | mse_error, 268 | y_pred.detach().cpu().numpy().ravel(), 269 | labels.detach().cpu().numpy().ravel(), 270 | vars.mean().detach().cpu().numpy().ravel(), 271 | full_embeds.detach().cpu().numpy(), 272 | x_embeds.detach().cpu().numpy(), 273 | ) 274 | 275 | 276 | load_model(f"model_chkp/model{model_num}") 277 | 278 | e, yp, yt, vars, fem, tem = evaluate(True) 279 | yp = np.array([evaluate(True)[1] for _ in range(1000)]) 280 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 281 | e = np.mean((yp - yt) ** 2) 282 | dev = np.sqrt(vars) * 1.95 283 | plt.figure(4) 284 | plt.plot(yp, label="Predicted 95%", color="blue") 285 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 286 | plt.plot(yt, label="True Value", color="green") 287 | plt.legend() 288 | plt.title(f"RMSE: {e}") 289 | plt.savefig(f"plots/Test{model_num}.png") 290 | dt = { 291 | "rmse": e, 292 | "target": yt, 293 | "pred": yp, 294 | "vars": vars, 295 | "fem": fem, 296 | "tem": tem, 297 | } 298 | save_data(dt, f"./saves/{model_num}_test.pkl") 299 | 300 | e, yp, yt, vars, _, _ = evaluate(True, dtype="val") 301 | yp = np.array([evaluate(True, dtype="val")[1] for _ in range(1000)]) 302 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 303 | e = np.mean((yp - yt) ** 2) 304 | dev = np.sqrt(vars) * 1.95 305 | plt.figure(5) 306 | plt.plot(yp, label="Predicted 95%", color="blue") 307 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 308 | plt.plot(yt, label="True Value", color="green") 309 | plt.legend() 310 | plt.title(f"RMSE: {e}") 311 | plt.savefig(f"plots/Val{model_num}.png") 312 | 313 | e, yp, yt, vars, fem, tem = evaluate(True, dtype="all") 314 | yp = np.array([evaluate(True, dtype="all")[1] for _ in range(40)]) 
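# --- Added commentary (not part of the original source) -----------------------
# Each evaluate(True, ...) call draws one stochastic sample from the predictive
# distribution, so stacking repeated calls (1000 for the test/val splits above,
# 40 here for the full training split) yields a Monte Carlo estimate of the
# predictive mean and variance. dev = 1.95 * sqrt(vars) below is then plotted as
# an approximate 95% band. Note that `e` is computed as the mean squared error
# even though the plot titles label it "RMSE".
# -------------------------------------------------------------------------------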
315 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 316 | e = np.mean((yp - yt) ** 2) 317 | dev = np.sqrt(vars) * 1.95 318 | plt.figure(6) 319 | plt.plot(yp, label="Predicted 95%", color="blue") 320 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 321 | plt.plot(yt, label="True Value", color="green") 322 | plt.legend() 323 | plt.title(f"RMSE: {e}") 324 | plt.savefig(f"plots/Train{model_num}.png") 325 | dt = { 326 | "rmse": e, 327 | "target": yt, 328 | "pred": yp, 329 | "vars": vars, 330 | "fem": fem, 331 | "tem": tem, 332 | } 333 | save_data(dt, f"./saves/{model_num}_train.pkl") 334 | -------------------------------------------------------------------------------- /test_regress_ili.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | from models.utils import float_tensor, device 6 | import pickle 7 | from models.fnpmodels import EmbedAttenSeq, RegressionFNP, EmbedSeq, RegressionFNP2 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | from optparse import OptionParser 11 | from tqdm import tqdm 12 | 13 | np.random.seed(10) 14 | 15 | city_idx = {f"Region {i}": i for i in range(1, 11)} 16 | city_idx["X"] = 0 17 | 18 | df = pd.read_csv("./data/ILINet.csv") 19 | df = df[["REGION", "YEAR", "WEEK", "% WEIGHTED ILI"]] 20 | df = df[(df["YEAR"] >= 2004) | ((df["YEAR"] == 2003) & (df["WEEK"] >= 20))] 21 | 22 | 23 | def get_dataset(year: int, region: str, df=df): 24 | ans = df[ 25 | ((df["YEAR"] == year) & (df["WEEK"] >= 20)) 26 | | ((df["YEAR"] == year + 1) & (df["WEEK"] <= 20)) 27 | ] 28 | return ans[ans["REGION"] == region]["% WEIGHTED ILI"] 29 | 30 | 31 | parser = OptionParser() 32 | parser.add_option("-y", "--year", dest="testyear", type="int") 33 | parser.add_option("-w", "--week", dest="week_ahead", type="int") 34 | parser.add_option("-a", "--atten", dest="atten", type="string") 35 | parser.add_option("-n", "--num", dest="num", type="string") 36 | (options, args) = parser.parse_args() 37 | 38 | train_seasons = list(range(2003, options.testyear)) 39 | test_seasons = [options.testyear] 40 | # train_seasons = list(range(2003, 2019)) 41 | # test_seasons = [2019] 42 | print(train_seasons, test_seasons) 43 | 44 | # train_seasons = [2003, 2004, 2005, 2006, 2007, 2008, 2009] 45 | # test_seasons = [2010] 46 | regions = ["X"] 47 | # regions = [f"Region {i}" for i in range(1,11)] + ["X"] 48 | 49 | week_ahead = options.week_ahead 50 | val_frac = 5 51 | attn = options.atten 52 | model_num = options.num 53 | # model_num = 22 54 | print(week_ahead, attn) 55 | 56 | 57 | def one_hot(idx, dim=len(city_idx)): 58 | ans = np.zeros(dim, dtype="float32") 59 | ans[idx] = 1.0 60 | return ans 61 | 62 | 63 | def save_data(obj, filepath): 64 | with open(filepath, "wb") as fl: 65 | pickle.dump(obj, fl) 66 | 67 | 68 | full_x = np.array( 69 | [ 70 | np.array(get_dataset(s, r), dtype="float32")[-53:] 71 | for s in train_seasons 72 | for r in regions 73 | ] 74 | ) 75 | full_meta = np.array([one_hot(city_idx[r]) for s in train_seasons for r in regions]) 76 | full_y = full_x.argmax(-1) 77 | full_x = full_x[:, :, None] 78 | 79 | full_x_test = np.array( 80 | [ 81 | np.array(get_dataset(s, r), dtype="float32")[-53:] 82 | for s in test_seasons 83 | for r in regions 84 | ] 85 | ) 86 | full_meta_test = np.array([one_hot(city_idx[r]) for s in test_seasons for r in regions]) 87 | full_y_test = full_x_test.argmax(-1) 88 | full_x_test = full_x_test[:, :, None] 89 | 90 | 91 | 
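# Descriptive note: create_dataset below expands each (region, season) "% WEIGHTED ILI"
# series into sliding-window samples. For every target week index i from 20 to the end of
# the season it keeps seq[: i - week_ahead + 1] (everything up to and including week
# i - week_ahead) as the input sequence and seq[i] as the label, so each sample asks the
# model to predict exactly week_ahead steps past its last observed point. For example,
# with week_ahead = 2 and i = 20, weeks 0-18 form the input and week 20 is the target.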
def create_dataset(full_meta, full_x, week_ahead=week_ahead): 92 | metas, seqs, y = [], [], [] 93 | for meta, seq in zip(full_meta, full_x): 94 | for i in range(20, full_x.shape[1]): 95 | metas.append(meta) 96 | seqs.append(seq[: i - week_ahead + 1]) 97 | y.append(seq[i]) 98 | return np.array(metas, dtype="float32"), seqs, np.array(y, dtype="float32") 99 | 100 | 101 | train_meta, train_x, train_y = create_dataset(full_meta, full_x) 102 | test_meta, test_x, test_y = create_dataset(full_meta_test, full_x_test) 103 | 104 | 105 | def create_tensors(metas, seqs, ys): 106 | metas = float_tensor(metas) 107 | ys = float_tensor(ys) 108 | max_len = max([len(s) for s in seqs]) 109 | out_seqs = np.zeros((len(seqs), max_len, seqs[0].shape[-1]), dtype="float32") 110 | lens = np.zeros(len(seqs), dtype="int32") 111 | for i, s in enumerate(seqs): 112 | out_seqs[i, : len(s), :] = s 113 | lens[i] = len(s) 114 | out_seqs = float_tensor(out_seqs) 115 | return metas, out_seqs, ys, lens 116 | 117 | 118 | def create_mask1(lens, out_dim=1): 119 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 120 | for i, j in enumerate(lens): 121 | ans[j - 1, i, :] = 1.0 122 | return float_tensor(ans) 123 | 124 | 125 | def create_mask(lens, out_dim=1): 126 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 127 | for i, j in enumerate(lens): 128 | ans[:j, i, :] = 1.0 129 | return float_tensor(ans) 130 | 131 | 132 | if attn == "trans": 133 | emb_model = EmbedAttenSeq( 134 | dim_seq_in=1, 135 | dim_metadata=len(city_idx), 136 | dim_out=50, 137 | n_layers=2, 138 | bidirectional=True, 139 | ).cuda() 140 | emb_model_full = EmbedAttenSeq( 141 | dim_seq_in=1, 142 | dim_metadata=len(city_idx), 143 | dim_out=50, 144 | n_layers=2, 145 | bidirectional=True, 146 | ).cuda() 147 | else: 148 | emb_model = EmbedSeq( 149 | dim_seq_in=1, 150 | dim_metadata=len(city_idx), 151 | dim_out=50, 152 | n_layers=2, 153 | bidirectional=True, 154 | ).cuda() 155 | emb_model_full = EmbedSeq( 156 | dim_seq_in=1, 157 | dim_metadata=len(city_idx), 158 | dim_out=50, 159 | n_layers=2, 160 | bidirectional=True, 161 | ).cuda() 162 | fnp_model = RegressionFNP2( 163 | dim_x=50, 164 | dim_y=1, 165 | dim_h=100, 166 | n_layers=3, 167 | num_M=train_meta.shape[0], 168 | dim_u=50, 169 | dim_z=50, 170 | fb_z=0.0, 171 | use_ref_labels=False, 172 | use_DAG=False, 173 | add_atten=False, 174 | ).cuda() 175 | optimizer = optim.Adam( 176 | list(emb_model.parameters()) 177 | + list(fnp_model.parameters()) 178 | + list(emb_model_full.parameters()), 179 | lr=1e-3, 180 | ) 181 | 182 | # emb_model_full = emb_model 183 | 184 | train_meta_, train_x_, train_y_, train_lens_ = create_tensors( 185 | train_meta, train_x, train_y 186 | ) 187 | 188 | test_meta, test_x, test_y, test_lens = create_tensors(test_meta, test_x, test_y) 189 | 190 | full_x = float_tensor(full_x) 191 | full_meta = float_tensor(full_meta) 192 | full_y = float_tensor(full_y) 193 | 194 | train_mask_, test_mask = ( 195 | create_mask(train_lens_), 196 | create_mask(test_lens), 197 | ) 198 | 199 | perm = np.random.permutation(train_meta_.shape[0]) 200 | val_perm = perm[: train_meta_.shape[0] // val_frac] 201 | train_perm = perm[train_meta_.shape[0] // val_frac :] 202 | 203 | train_meta, train_x, train_y, train_lens, train_mask = ( 204 | train_meta_[train_perm], 205 | train_x_[train_perm], 206 | train_y_[train_perm], 207 | train_lens_[train_perm], 208 | train_mask_[:, train_perm, :], 209 | ) 210 | val_meta, val_x, val_y, val_lens, val_mask = ( 211 | train_meta_[val_perm], 212 | 
train_x_[val_perm], 213 | train_y_[val_perm], 214 | train_lens_[val_perm], 215 | train_mask_[:, val_perm, :], 216 | ) 217 | 218 | 219 | def save_model(file_prefix: str): 220 | torch.save(emb_model.state_dict(), file_prefix + "_emb_model.pth") 221 | torch.save(emb_model_full.state_dict(), file_prefix + "_emb_model_full.pth") 222 | torch.save(fnp_model.state_dict(), file_prefix + "_fnp_model.pth") 223 | 224 | 225 | def load_model(file_prefix: str): 226 | emb_model.load_state_dict(torch.load(file_prefix + "_emb_model.pth")) 227 | emb_model_full.load_state_dict(torch.load(file_prefix + "_emb_model_full.pth")) 228 | fnp_model.load_state_dict(torch.load(file_prefix + "_fnp_model.pth")) 229 | 230 | 231 | load_model(f"model_chkp/model{model_num}") 232 | 233 | 234 | def evaluate(sample=True, dtype="test", week_ahead=week_ahead): 235 | with torch.no_grad(): 236 | emb_model.eval() 237 | emb_model_full.eval() 238 | fnp_model.eval() 239 | full_embeds = emb_model_full(full_x.transpose(1, 0), full_meta) 240 | if dtype == "val": 241 | curr_x, curr_meta, curr_lens = val_x, val_meta, val_lens.copy() 242 | elif dtype == "test": 243 | curr_x, curr_meta, curr_lens = test_x, test_meta, test_lens.copy() 244 | elif dtype == "train": 245 | curr_x, curr_meta, curr_lens = train_x, train_meta, train_lens.copy() 246 | elif dtype == "all": 247 | curr_x, curr_meta, curr_lens = train_x_, train_meta_, train_lens_.copy() 248 | else: 249 | raise ValueError("Incorrect dtype") 250 | for _ in range(week_ahead): 251 | curr_mask = create_mask(curr_lens) 252 | x_embeds = emb_model.forward_mask( 253 | curr_x.transpose(1, 0), curr_meta, curr_mask 254 | ) 255 | y_pred, _, vars, _, _, _, _ = fnp_model.predict( 256 | x_embeds, full_embeds, full_y, sample=sample 257 | ) 258 | curr_z = float_tensor(np.zeros((curr_x.shape[0], 1, 1))) 259 | curr_x = torch.cat([curr_x, curr_z], 1) 260 | curr_x[np.arange(curr_x.shape[0]), curr_lens] = y_pred 261 | curr_lens += 1 262 | labels_dict = {"val": val_y, "test": test_y, "train": train_y, "all": train_y_} 263 | labels = labels_dict[dtype] 264 | mse_error = torch.pow(y_pred - labels, 2).mean().sqrt().detach().cpu().numpy() 265 | return ( 266 | mse_error, 267 | y_pred.detach().cpu().numpy().ravel(), 268 | labels.detach().cpu().numpy().ravel(), 269 | vars.mean().detach().cpu().numpy().ravel(), 270 | full_embeds.detach().cpu().numpy(), 271 | x_embeds.detach().cpu().numpy(), 272 | ) 273 | 274 | 275 | e, yp, yt, vars, fem, tem = evaluate(True) 276 | yp = np.array([evaluate(True)[1] for _ in tqdm(range(10000))]) 277 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 278 | e = np.mean((yp - yt) ** 2) 279 | dev = np.sqrt(vars) * 1.95 280 | plt.figure(4) 281 | plt.plot(yp, label="Predicted 95%", color="blue") 282 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 283 | plt.plot(yt, label="True Value", color="green") 284 | plt.legend() 285 | plt.title(f"RMSE: {e}") 286 | plt.savefig(f"plots/Test_regress_{week_ahead}_{model_num}.png") 287 | dt = { 288 | "rmse": e, 289 | "target": yt, 290 | "pred": yp, 291 | "vars": vars, 292 | "fem": fem, 293 | "tem": tem, 294 | } 295 | save_data(dt, f"./saves/regress_{week_ahead}_{model_num}_test.pkl") 296 | 297 | e, yp, yt, vars, _, _ = evaluate(True, dtype="val") 298 | yp = np.array([evaluate(True, dtype="val")[1] for _ in tqdm(range(1000))]) 299 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 300 | e = np.mean((yp - yt) ** 2) 301 | dev = np.sqrt(vars) * 1.95 302 | plt.figure(5) 303 | plt.plot(yp, label="Predicted 95%", color="blue") 304 | 
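# Descriptive note: unlike test_ili.py, the predictions plotted in this script come from
# the autoregressive rollout inside evaluate() above: the loop there repeats week_ahead
# times, each pass rebuilding the padding mask with create_mask, making a one-step
# prediction with fnp_model.predict, writing that prediction into curr_x at position
# curr_lens and incrementing curr_lens, so for horizons beyond one week the model consumes
# its own earlier forecasts. The fill_between call below shades the +/- 1.95 * std band
# around the resulting Monte Carlo mean.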
plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 305 | plt.plot(yt, label="True Value", color="green") 306 | plt.legend() 307 | plt.title(f"RMSE: {e}") 308 | plt.savefig(f"plots/Val_regress_{week_ahead}_{model_num}.png") 309 | 310 | e, yp, yt, vars, fem, tem = evaluate(True, dtype="all") 311 | yp = np.array([evaluate(True, dtype="all")[1] for _ in tqdm(range(40))]) 312 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 313 | e = np.mean((yp - yt) ** 2) 314 | dev = np.sqrt(vars) * 1.95 315 | plt.figure(6) 316 | plt.plot(yp, label="Predicted 95%", color="blue") 317 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 318 | plt.plot(yt, label="True Value", color="green") 319 | plt.legend() 320 | plt.title(f"RMSE: {e}") 321 | plt.savefig(f"plots/Train{model_num}.png") 322 | dt = { 323 | "rmse": e, 324 | "target": yt, 325 | "pred": yp, 326 | "vars": vars, 327 | "fem": fem, 328 | "tem": tem, 329 | } 330 | save_data(dt, f"./saves/regress2_{week_ahead}_{model_num}_train.pkl") 331 | -------------------------------------------------------------------------------- /train_ili.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | from models.utils import float_tensor, device 5 | import pickle 6 | from models.fnpmodels import EmbedAttenSeq, RegressionFNP, EmbedSeq, RegressionFNP2 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | from optparse import OptionParser 10 | import os 11 | 12 | 13 | for d in ["model_chkp", "plots", "saves"]: 14 | if not os.path.exists(d): 15 | os.mkdir(d) 16 | 17 | 18 | np.random.seed(10) 19 | 20 | city_idx = {f"Region {i}": i for i in range(1, 11)} 21 | city_idx["X"] = 0 22 | 23 | df = pd.read_csv("./data/ILINet.csv") 24 | df = df[["REGION", "YEAR", "WEEK", "% WEIGHTED ILI"]] 25 | df = df[(df["YEAR"] >= 2004) | ((df["YEAR"] == 2003) & (df["WEEK"] >= 20))] 26 | 27 | 28 | def get_dataset(year: int, region: str, df=df): 29 | ans = df[ 30 | ((df["YEAR"] == year) & (df["WEEK"] >= 20)) 31 | | ((df["YEAR"] == year + 1) & (df["WEEK"] <= 20)) 32 | ] 33 | return ans[ans["REGION"] == region]["% WEIGHTED ILI"] 34 | 35 | 36 | parser = OptionParser() 37 | parser.add_option("-y", "--year", dest="testyear", type="int") 38 | parser.add_option("-w", "--week", dest="week_ahead", type="int") 39 | parser.add_option("-a", "--atten", dest="atten", type="string") 40 | parser.add_option("-n", "--num", dest="num", type="string") 41 | parser.add_option("-e", "--epoch", dest="epochs", type="int") 42 | (options, args) = parser.parse_args() 43 | 44 | train_seasons = list(range(2003, options.testyear)) 45 | test_seasons = [options.testyear] 46 | # train_seasons = list(range(2003, 2019)) 47 | # test_seasons = [2019] 48 | print(train_seasons, test_seasons) 49 | 50 | # train_seasons = [2003, 2004, 2005, 2006, 2007, 2008, 2009] 51 | # test_seasons = [2010] 52 | regions = ["X"] 53 | # regions = [f"Region {i}" for i in range(1,11)] 54 | 55 | week_ahead = options.week_ahead 56 | val_frac = 5 57 | attn = options.atten 58 | model_num = options.num 59 | # model_num = 22 60 | EPOCHS = options.epochs 61 | print(week_ahead, attn, EPOCHS) 62 | 63 | 64 | def one_hot(idx, dim=len(city_idx)): 65 | ans = np.zeros(dim, dtype="float32") 66 | ans[idx] = 1.0 67 | return ans 68 | 69 | 70 | def save_data(obj, filepath): 71 | with open(filepath, "wb") as fl: 72 | pickle.dump(obj, fl) 73 | 74 | 75 | full_x = np.array( 76 | [ 77 | np.array(get_dataset(s, r), 
dtype="float32")[-53:] 78 | for s in train_seasons 79 | for r in regions 80 | ] 81 | ) 82 | full_meta = np.array([one_hot(city_idx[r]) for s in train_seasons for r in regions]) 83 | full_y = full_x.argmax(-1) 84 | full_x = full_x[:, :, None] 85 | 86 | full_x_test = np.array( 87 | [ 88 | np.array(get_dataset(s, r), dtype="float32")[-53:] 89 | for s in test_seasons 90 | for r in regions 91 | ] 92 | ) 93 | full_meta_test = np.array([one_hot(city_idx[r]) for s in test_seasons for r in regions]) 94 | full_y_test = full_x_test.argmax(-1) 95 | full_x_test = full_x_test[:, :, None] 96 | 97 | 98 | def create_dataset(full_meta, full_x, week_ahead=week_ahead): 99 | metas, seqs, y = [], [], [] 100 | for meta, seq in zip(full_meta, full_x): 101 | for i in range(20, full_x.shape[1]): 102 | metas.append(meta) 103 | seqs.append(seq[: i - week_ahead + 1]) 104 | y.append(seq[i]) 105 | return np.array(metas, dtype="float32"), seqs, np.array(y, dtype="float32") 106 | 107 | 108 | train_meta, train_x, train_y = create_dataset(full_meta, full_x) 109 | test_meta, test_x, test_y = create_dataset(full_meta_test, full_x_test) 110 | 111 | 112 | def create_tensors(metas, seqs, ys): 113 | metas = float_tensor(metas) 114 | ys = float_tensor(ys) 115 | max_len = max([len(s) for s in seqs]) 116 | out_seqs = np.zeros((len(seqs), max_len, seqs[0].shape[-1]), dtype="float32") 117 | lens = np.zeros(len(seqs), dtype="int32") 118 | for i, s in enumerate(seqs): 119 | out_seqs[i, : len(s), :] = s 120 | lens[i] = len(s) 121 | out_seqs = float_tensor(out_seqs) 122 | return metas, out_seqs, ys, lens 123 | 124 | 125 | def create_mask1(lens, out_dim=1): 126 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 127 | for i, j in enumerate(lens): 128 | ans[j - 1, i, :] = 1.0 129 | return float_tensor(ans) 130 | 131 | 132 | def create_mask(lens, out_dim=1): 133 | ans = np.zeros((max(lens), len(lens), out_dim), dtype="float32") 134 | for i, j in enumerate(lens): 135 | ans[:j, i, :] = 1.0 136 | return float_tensor(ans) 137 | 138 | 139 | if attn == "trans": 140 | emb_model = EmbedAttenSeq( 141 | dim_seq_in=1, 142 | dim_metadata=len(city_idx), 143 | dim_out=50, 144 | n_layers=2, 145 | bidirectional=True, 146 | ).cuda() 147 | emb_model_full = EmbedAttenSeq( 148 | dim_seq_in=1, 149 | dim_metadata=len(city_idx), 150 | dim_out=50, 151 | n_layers=2, 152 | bidirectional=True, 153 | ).cuda() 154 | else: 155 | emb_model = EmbedSeq( 156 | dim_seq_in=1, 157 | dim_metadata=len(city_idx), 158 | dim_out=50, 159 | n_layers=2, 160 | bidirectional=True, 161 | ).cuda() 162 | emb_model_full = EmbedSeq( 163 | dim_seq_in=1, 164 | dim_metadata=len(city_idx), 165 | dim_out=50, 166 | n_layers=2, 167 | bidirectional=True, 168 | ).cuda() 169 | fnp_model = RegressionFNP2( 170 | dim_x=50, 171 | dim_y=1, 172 | dim_h=100, 173 | n_layers=3, 174 | num_M=train_meta.shape[0], 175 | dim_u=50, 176 | dim_z=50, 177 | fb_z=0.0, 178 | use_ref_labels=False, 179 | use_DAG=False, 180 | add_atten=False, 181 | ).cuda() 182 | optimizer = optim.Adam( 183 | list(emb_model.parameters()) 184 | + list(fnp_model.parameters()) 185 | + list(emb_model_full.parameters()), 186 | lr=1e-3, 187 | ) 188 | 189 | # emb_model_full = emb_model 190 | 191 | train_meta_, train_x_, train_y_, train_lens_ = create_tensors( 192 | train_meta, train_x, train_y 193 | ) 194 | 195 | test_meta, test_x, test_y, test_lens = create_tensors(test_meta, test_x, test_y) 196 | 197 | full_x_chunks = np.zeros((full_x.shape[0] * 4, full_x.shape[1], full_x.shape[2])) 198 | full_meta_chunks = 
np.zeros((full_meta.shape[0] * 4, full_meta.shape[1])) 199 | for i, s in enumerate(full_x): 200 | full_x_chunks[i * 4, -20:] = s[:20] 201 | full_x_chunks[i * 4 + 1, -30:] = s[:30] 202 | full_x_chunks[i * 4 + 2, -40:] = s[:40] 203 | full_x_chunks[i * 4 + 3, :] = s 204 | full_meta_chunks[i * 4 : i * 4 + 4] = full_meta[i] 205 | 206 | full_x = float_tensor(full_x) 207 | full_meta = float_tensor(full_meta) 208 | full_y = float_tensor(full_y) 209 | 210 | train_mask_, test_mask = ( 211 | create_mask(train_lens_), 212 | create_mask(test_lens), 213 | ) 214 | 215 | perm = np.random.permutation(train_meta_.shape[0]) 216 | val_perm = perm[: train_meta_.shape[0] // val_frac] 217 | train_perm = perm[train_meta_.shape[0] // val_frac :] 218 | 219 | train_meta, train_x, train_y, train_lens, train_mask = ( 220 | train_meta_[train_perm], 221 | train_x_[train_perm], 222 | train_y_[train_perm], 223 | train_lens_[train_perm], 224 | train_mask_[:, train_perm, :], 225 | ) 226 | val_meta, val_x, val_y, val_lens, val_mask = ( 227 | train_meta_[val_perm], 228 | train_x_[val_perm], 229 | train_y_[val_perm], 230 | train_lens_[val_perm], 231 | train_mask_[:, val_perm, :], 232 | ) 233 | 234 | 235 | def save_model(file_prefix: str): 236 | torch.save(emb_model.state_dict(), file_prefix + "_emb_model.pth") 237 | torch.save(emb_model_full.state_dict(), file_prefix + "_emb_model_full.pth") 238 | torch.save(fnp_model.state_dict(), file_prefix + "_fnp_model.pth") 239 | 240 | 241 | def load_model(file_prefix: str): 242 | emb_model.load_state_dict(torch.load(file_prefix + "_emb_model.pth")) 243 | emb_model_full.load_state_dict(torch.load(file_prefix + "_emb_model_full.pth")) 244 | fnp_model.load_state_dict(torch.load(file_prefix + "_fnp_model.pth")) 245 | 246 | 247 | def evaluate(sample=True, dtype="test"): 248 | with torch.no_grad(): 249 | emb_model.eval() 250 | emb_model_full.eval() 251 | fnp_model.eval() 252 | full_embeds = emb_model_full(full_x.transpose(1, 0), full_meta) 253 | if dtype == "val": 254 | x_embeds = emb_model.forward_mask(val_x.transpose(1, 0), val_meta, val_mask) 255 | elif dtype == "test": 256 | x_embeds = emb_model.forward_mask( 257 | test_x.transpose(1, 0), test_meta, test_mask 258 | ) 259 | elif dtype == "train": 260 | x_embeds = emb_model.forward_mask( 261 | train_x.transpose(1, 0), train_meta, train_mask 262 | ) 263 | elif dtype == "all": 264 | x_embeds = emb_model.forward_mask( 265 | train_x_.transpose(1, 0), train_meta_, train_mask_ 266 | ) 267 | else: 268 | raise ValueError("Incorrect dtype") 269 | y_pred, _, vars, _, _, _, _ = fnp_model.predict( 270 | x_embeds, full_embeds, full_y, sample=sample 271 | ) 272 | labels_dict = {"val": val_y, "test": test_y, "train": train_y, "all": train_y_} 273 | labels = labels_dict[dtype] 274 | mse_error = torch.pow(y_pred - labels, 2).mean().sqrt().detach().cpu().numpy() 275 | return ( 276 | mse_error, 277 | y_pred.detach().cpu().numpy().ravel(), 278 | labels.detach().cpu().numpy().ravel(), 279 | vars.mean().detach().cpu().numpy().ravel(), 280 | full_embeds.detach().cpu().numpy(), 281 | x_embeds.detach().cpu().numpy(), 282 | ) 283 | 284 | 285 | error = 100.0 286 | losses = [] 287 | errors = [] 288 | train_errors = [] 289 | variances = [] 290 | best_ep = 0 291 | 292 | for ep in range(EPOCHS): 293 | emb_model.train() 294 | emb_model_full.train() 295 | fnp_model.train() 296 | print(f"Epoch: {ep+1}") 297 | optimizer.zero_grad() 298 | x_embeds = emb_model.forward_mask(train_x.transpose(1, 0), train_meta, train_mask) 299 | full_embeds = emb_model_full(full_x.transpose(1, 
0), full_meta) 300 | loss, yp, _ = fnp_model.forward(full_embeds, full_y, x_embeds, train_y) 301 | loss.backward() 302 | optimizer.step() 303 | losses.append(loss.detach().cpu().numpy()) 304 | train_errors.append( 305 | torch.pow(yp[full_x.shape[0] :] - train_y, 2) 306 | .mean() 307 | .sqrt() 308 | .detach() 309 | .cpu() 310 | .numpy() 311 | ) 312 | 313 | e, yp, yt, _, _, _ = evaluate(False) 314 | e = np.mean([evaluate(True, dtype="val")[0] for _ in range(40)]) 315 | vars = np.mean([evaluate(True, dtype="val")[3] for _ in range(40)]) 316 | errors.append(e) 317 | variances.append(vars) 318 | idxs = np.random.randint(yp.shape[0], size=10) 319 | print("Loss:", loss.detach().cpu().numpy()) 320 | print(f"Val RMSE: {e:.3f}, Train RMSE: {train_errors[-1]:.3f}") 321 | # print(f"MSE: {e}") 322 | if ep > 100 and min(errors[-100:]) > error + 0.1: 323 | errors = errors[: best_ep + 1] 324 | losses = losses[: best_ep + 1] 325 | print(f"Done in {ep+1} epochs") 326 | break 327 | if e < error: 328 | save_model(f"model_chkp/model{model_num}") 329 | error = e 330 | best_ep = ep + 1 331 | 332 | 333 | print(f"Val MSE error: {error}") 334 | plt.figure(1) 335 | plt.plot(losses) 336 | plt.savefig(f"plots/losses{model_num}.png") 337 | plt.figure(2) 338 | plt.plot(errors) 339 | plt.plot(train_errors) 340 | plt.savefig(f"plots/errors{model_num}.png") 341 | plt.figure(3) 342 | plt.plot(variances) 343 | plt.savefig(f"plots/vars{model_num}.png") 344 | 345 | load_model(f"model_chkp/model{model_num}") 346 | 347 | e, yp, yt, vars, fem, tem = evaluate(True) 348 | yp = np.array([evaluate(True)[1] for _ in range(1000)]) 349 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 350 | e = np.mean((yp - yt) ** 2) 351 | dev = np.sqrt(vars) * 1.95 352 | plt.figure(4) 353 | plt.plot(yp, label="Predicted 95%", color="blue") 354 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 355 | plt.plot(yt, label="True Value", color="green") 356 | plt.legend() 357 | plt.title(f"RMSE: {e}") 358 | plt.savefig(f"plots/Test{model_num}.png") 359 | dt = { 360 | "rmse": e, 361 | "target": yt, 362 | "pred": yp, 363 | "vars": vars, 364 | "fem": fem, 365 | "tem": tem, 366 | } 367 | save_data(dt, f"./saves/{model_num}_test.pkl") 368 | 369 | e, yp, yt, vars, _, _ = evaluate(True, dtype="val") 370 | yp = np.array([evaluate(True, dtype="val")[1] for _ in range(1000)]) 371 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 372 | e = np.mean((yp - yt) ** 2) 373 | dev = np.sqrt(vars) * 1.95 374 | plt.figure(5) 375 | plt.plot(yp, label="Predicted 95%", color="blue") 376 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 377 | plt.plot(yt, label="True Value", color="green") 378 | plt.legend() 379 | plt.title(f"RMSE: {e}") 380 | plt.savefig(f"plots/Val{model_num}.png") 381 | 382 | e, yp, yt, vars, fem, tem = evaluate(True, dtype="all") 383 | yp = np.array([evaluate(True, dtype="all")[1] for _ in range(40)]) 384 | yp, vars = np.mean(yp, 0), np.var(yp, 0) 385 | e = np.mean((yp - yt) ** 2) 386 | dev = np.sqrt(vars) * 1.95 387 | plt.figure(6) 388 | plt.plot(yp, label="Predicted 95%", color="blue") 389 | plt.fill_between(np.arange(len(yp)), yp + dev, yp - dev, color="blue", alpha=0.2) 390 | plt.plot(yt, label="True Value", color="green") 391 | plt.legend() 392 | plt.title(f"RMSE: {e}") 393 | plt.savefig(f"plots/Train{model_num}.png") 394 | dt = { 395 | "rmse": e, 396 | "target": yt, 397 | "pred": yp, 398 | "vars": vars, 399 | "fem": fem, 400 | "tem": tem, 401 | } 402 | save_data(dt, f"./saves/{model_num}_train.pkl") 403 
| -------------------------------------------------------------------------------- /transform_pred.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from DistCal.GP_Beta_cal import GP_Beta 4 | import DistCal.utils as utils 5 | from sklearn.isotonic import IsotonicRegression 6 | 7 | 8 | def build_GP(mu, sigma, target, n_u=8): 9 | model = GP_Beta() 10 | model.fit(target, mu, sigma, n_u) 11 | print("Done training") 12 | return model 13 | 14 | 15 | def build_iso(mu, sigma, target): 16 | iso_q, iso_q_hat = utils.get_iso_cal_table(target, mu, sigma) 17 | model = IsotonicRegression(out_of_bounds="clip") 18 | model.fit(iso_q, iso_q_hat) 19 | return model 20 | 21 | 22 | def predict_GP(model, mu, sigma, y_range): 23 | pdf, cdf = model.predict(y_range, mu, sigma) 24 | preds = utils.get_y_hat(y_range, pdf) 25 | var = utils.get_y_var(y_range, pdf) 26 | return preds, var, pdf, cdf 27 | 28 | 29 | def predict_iso(model, mu, sigma, y_range): 30 | pdf1, cdf1 = utils.get_norm_q(mu.ravel(), sigma.ravel(), y_range.ravel()) 31 | cdf = model.predict(pdf1.ravel()).reshape(pdf1.shape) 32 | pdf = np.diff(cdf, axis=1) / (y_range[0, 1:] - y_range[0, :-1]).ravel().reshape( 33 | 1, -1 34 | ).repeat(len(mu), axis=0) 35 | preds = utils.get_y_hat(y_range.ravel(), pdf) 36 | var = utils.get_y_var(y_range.ravel(), pdf) 37 | return preds, var, pdf, cdf 38 | 39 | 40 | def load_data(path): 41 | with open(path, "rb") as fl: 42 | dt = pickle.load(fl) 43 | return dt 44 | 45 | 46 | model = "ili2_2017_2" 47 | test_data = load_data(f"./saves/{model}_test.pkl") 48 | train_data = load_data(f"./saves/{model}_train.pkl") 49 | 50 | 51 | def transform_gp(test_data, train_data): 52 | 53 | gp_model = build_GP( 54 | train_data["pred"].astype(np.float64).reshape(-1, 1), 55 | np.sqrt(train_data["vars"]).astype(np.float64).reshape(-1, 1), 56 | train_data["target"].astype(np.float64).reshape(-1, 1), 57 | ) 58 | y_range = np.linspace( 59 | train_data["pred"].min() - train_data["vars"].max(), 60 | train_data["pred"].max() + train_data["vars"].max(), 61 | 100, 62 | ).reshape(1, -1) 63 | preds_gp, var_gp, pdf_gp, cdf_gp = predict_GP( 64 | gp_model, 65 | test_data["pred"].astype(np.float64).reshape(-1, 1), 66 | np.sqrt(test_data["vars"]).astype(np.float64).reshape(-1, 1), 67 | y_range, 68 | ) 69 | 70 | return preds_gp, var_gp 71 | 72 | 73 | def transform_iso(test_data, train_data): 74 | iso_model = build_iso( 75 | train_data["pred"].astype(np.float64).reshape(-1, 1), 76 | np.sqrt(train_data["vars"]).astype(np.float64).reshape(-1, 1), 77 | train_data["target"].astype(np.float64).reshape(-1, 1), 78 | ) 79 | y_range = np.linspace( 80 | train_data["pred"].min() - train_data["vars"].max(), 81 | train_data["pred"].max() + train_data["vars"].max(), 82 | 100, 83 | ).reshape(1, -1) 84 | preds_iso, var_iso, pdf_iso, cdf_iso = predict_iso( 85 | iso_model, 86 | test_data["pred"].astype(np.float64).reshape(-1, 1), 87 | np.sqrt(test_data["vars"]).astype(np.float64).reshape(-1, 1), 88 | y_range, 89 | ) 90 | 91 | return preds_iso, var_iso 92 | --------------------------------------------------------------------------------
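
transform_pred.py above loads the saved train/test pickles for a hard-coded model prefix and defines two post-hoc recalibration paths (GP-Beta calibration through the DistCal submodule and isotonic regression over the predictive CDF), but, in the portion shown, never actually invokes transform_gp or transform_iso. The snippet below is a minimal, hypothetical usage sketch rather than code from the repository: it shows one way those two helpers might be run and compared if it were appended to the end of transform_pred.py, assuming the DistCal submodule is importable and that ./saves/ already contains the pickles written by the training and inference scripts.

```python
# Hypothetical usage sketch, not part of the repository. If appended to the end of
# transform_pred.py it reuses only names already defined there (test_data, train_data,
# transform_gp, transform_iso) plus numpy, which that file imports as np.

# GP-Beta recalibration via DistCal's GP_Beta (see build_GP / predict_GP above).
preds_gp, var_gp = transform_gp(test_data, train_data)

# Isotonic recalibration of the predictive CDF (see build_iso / predict_iso above).
preds_iso, var_iso = transform_iso(test_data, train_data)

# Compare the raw EpiFNP point forecasts with the recalibrated ones on the test targets.
target = test_data["target"].ravel()
candidates = {
    "raw": test_data["pred"].ravel(),
    "gp-beta": np.asarray(preds_gp).ravel(),
    "isotonic": np.asarray(preds_iso).ravel(),
}
for name, pred in candidates.items():
    rmse = np.sqrt(np.mean((pred - target) ** 2))
    print(f"{name} test RMSE: {rmse:.3f}")
```

Both paths adjust only the saved predictive distributions; the EpiFNP model itself is untouched, so recalibration can be re-run on any new set of saved predictions without retraining.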