├── .gitignore ├── Adam.py ├── CustomDataset.py ├── DATA_LOADER_DICT.pth ├── FNO4D.py ├── LICENSE ├── README.mD ├── UnitGaussianNormalizer.py ├── config_meta_data.py ├── config_py_utility.py ├── data_config ├── PERF_DICT.json ├── config_utility.py ├── file_config.sh ├── file_config_GLOBAL_dP.py ├── file_config_LGR1_SG.py ├── file_config_LGR1_dP.py ├── file_config_LGR2_SG.py ├── file_config_LGR2_dP.py ├── file_config_LGR3_SG.py ├── file_config_LGR3_dP.py ├── file_config_LGR4_SG.py └── file_config_LGR4_dP.py ├── eval_sequential_prediction_dp.ipynb ├── eval_sequential_prediction_sg.ipynb ├── finetune_FNO4D_DP_LGR.ipynb ├── hetero_logs └── events.out.tfevents.1667870912.sh02-16n09.int.68858.0 ├── lploss.py ├── meta_data_to_input_dict.py ├── normalizer ├── input_normalizer_GLOBAL_DP_val.pickle ├── input_normalizer_LGR1_DP_val.pickle ├── input_normalizer_LGR2_DP_val.pickle ├── input_normalizer_LGR3_DP_val.pickle ├── input_normalizer_LGR4_DP_val.pickle ├── output_normalizer_GLOBAL_DP_val.pickle ├── output_normalizer_LGR1_DP_val.pickle ├── output_normalizer_LGR2_DP_val.pickle ├── output_normalizer_LGR3_DP_val.pickle └── output_normalizer_LGR4_DP_val.pickle ├── predict_full_sg.py ├── save_data_loader.py ├── train_FNO4D_DP_GLOBAL.py ├── train_FNO4D_DP_LGR.py ├── train_FNO4D_SG_LGR.py ├── utility.py └── visulization_compare.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | ECLIPSE/meta_data/*.npy 4 | *.pt 5 | logs/* 6 | .ipynb_checkpoints 7 | eval_sequential_prediction_dp_custom_var.ipynb 8 | eval_sequential_prediction_sg_custom_var.ipynb 9 | vonk3d.py 10 | eval_sequential_prediction_dp_backup.ipynb 11 | data_config/*.ipynb 12 | data_config/PERF_DICT_backup.json 13 | -------------------------------------------------------------------------------- /Adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Optional 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | def adam(params: List[Tensor], 9 | grads: List[Tensor], 10 | exp_avgs: List[Tensor], 11 | exp_avg_sqs: List[Tensor], 12 | max_exp_avg_sqs: List[Tensor], 13 | state_steps: List[int], 14 | *, 15 | amsgrad: bool, 16 | beta1: float, 17 | beta2: float, 18 | lr: float, 19 | weight_decay: float, 20 | eps: float): 21 | r"""Functional API that performs Adam algorithm computation. 22 | See :class:`~torch.optim.Adam` for details. 23 | """ 24 | 25 | for i, param in enumerate(params): 26 | 27 | grad = grads[i] 28 | exp_avg = exp_avgs[i] 29 | exp_avg_sq = exp_avg_sqs[i] 30 | step = state_steps[i] 31 | 32 | bias_correction1 = 1 - beta1 ** step 33 | bias_correction2 = 1 - beta2 ** step 34 | 35 | if weight_decay != 0: 36 | grad = grad.add(param, alpha=weight_decay) 37 | 38 | # Decay the first and second moment running average coefficient 39 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 40 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 41 | if amsgrad: 42 | # Maintains the maximum of all 2nd moment running avg. till now 43 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 44 | # Use the max. for normalizing running avg. 
of gradient 45 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 46 | else: 47 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 48 | 49 | step_size = lr / bias_correction1 50 | 51 | param.addcdiv_(exp_avg, denom, value=-step_size) 52 | 53 | 54 | class Adam(Optimizer): 55 | r"""Implements Adam algorithm. 56 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 57 | The implementation of the L2 penalty follows changes proposed in 58 | `Decoupled Weight Decay Regularization`_. 59 | Args: 60 | params (iterable): iterable of parameters to optimize or dicts defining 61 | parameter groups 62 | lr (float, optional): learning rate (default: 1e-3) 63 | betas (Tuple[float, float], optional): coefficients used for computing 64 | running averages of gradient and its square (default: (0.9, 0.999)) 65 | eps (float, optional): term added to the denominator to improve 66 | numerical stability (default: 1e-8) 67 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 68 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 69 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 70 | (default: False) 71 | .. _Adam\: A Method for Stochastic Optimization: 72 | https://arxiv.org/abs/1412.6980 73 | .. _Decoupled Weight Decay Regularization: 74 | https://arxiv.org/abs/1711.05101 75 | .. _On the Convergence of Adam and Beyond: 76 | https://openreview.net/forum?id=ryQu7f-RZ 77 | """ 78 | 79 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 80 | weight_decay=0, amsgrad=False): 81 | if not 0.0 <= lr: 82 | raise ValueError("Invalid learning rate: {}".format(lr)) 83 | if not 0.0 <= eps: 84 | raise ValueError("Invalid epsilon value: {}".format(eps)) 85 | if not 0.0 <= betas[0] < 1.0: 86 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 87 | if not 0.0 <= betas[1] < 1.0: 88 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 89 | if not 0.0 <= weight_decay: 90 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 91 | defaults = dict(lr=lr, betas=betas, eps=eps, 92 | weight_decay=weight_decay, amsgrad=amsgrad) 93 | super(Adam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(Adam, self).__setstate__(state) 97 | for group in self.param_groups: 98 | group.setdefault('amsgrad', False) 99 | 100 | @torch.no_grad() 101 | def step(self, closure=None): 102 | """Performs a single optimization step. 103 | Args: 104 | closure (callable, optional): A closure that reevaluates the model 105 | and returns the loss. 
106 | """ 107 | loss = None 108 | if closure is not None: 109 | with torch.enable_grad(): 110 | loss = closure() 111 | 112 | for group in self.param_groups: 113 | params_with_grad = [] 114 | grads = [] 115 | exp_avgs = [] 116 | exp_avg_sqs = [] 117 | max_exp_avg_sqs = [] 118 | state_steps = [] 119 | beta1, beta2 = group['betas'] 120 | 121 | for p in group['params']: 122 | if p.grad is not None: 123 | params_with_grad.append(p) 124 | if p.grad.is_sparse: 125 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 126 | grads.append(p.grad) 127 | 128 | state = self.state[p] 129 | # Lazy state initialization 130 | if len(state) == 0: 131 | state['step'] = 0 132 | # Exponential moving average of gradient values 133 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 134 | # Exponential moving average of squared gradient values 135 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | if group['amsgrad']: 137 | # Maintains max of all exp. moving avg. of sq. grad. values 138 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 139 | 140 | exp_avgs.append(state['exp_avg']) 141 | exp_avg_sqs.append(state['exp_avg_sq']) 142 | 143 | if group['amsgrad']: 144 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 145 | 146 | # update the steps for each param group update 147 | state['step'] += 1 148 | # record the step after step update 149 | state_steps.append(state['step']) 150 | 151 | adam(params_with_grad, 152 | grads, 153 | exp_avgs, 154 | exp_avg_sqs, 155 | max_exp_avg_sqs, 156 | state_steps, 157 | amsgrad=group['amsgrad'], 158 | beta1=beta1, 159 | beta2=beta2, 160 | lr=group['lr'], 161 | weight_decay=group['weight_decay'], 162 | eps=group['eps']) 163 | return loss 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /CustomDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import random 3 | import os 4 | import torch 5 | 6 | def GLOBAL_to_LGR_path(global_lists, key, names, var): 7 | lgr_list = [] 8 | for path in global_lists: 9 | case = path.split('/')[-1] 10 | slope = case[:7] 11 | idx = case.split('_')[2] 12 | for nwell in range(1,5): 13 | if var == 'dP': 14 | string = f'{slope}_{idx}_{key}_WELL{nwell}_DP.pt' 15 | if string in names: 16 | home_path = f'/dP_{key}/' 17 | lgr_list.append(home_path + string) 18 | elif var == 'SG': 19 | string = f'{slope}_{idx}_{key}_WELL{nwell}_SG.pt' 20 | if string in names: 21 | home_path = f'/SG_{key}/' 22 | lgr_list.append(home_path + string) 23 | 24 | return lgr_list 25 | 26 | class CustomDataset(Dataset): 27 | def __init__(self, root_path, names): 28 | self.names = names 29 | self.root_path = root_path 30 | 31 | def __len__(self): 32 | return len(self.names) 33 | 34 | def __getitem__(self, idx): 35 | path = self.names[idx] 36 | data = torch.load(self.root_path+path) 37 | 38 | name = path.split('/')[-1] 39 | slope, idx, well = name[:7], name.split('_')[2], name.split('_')[-2] 40 | 41 | x = data['input'].permute(0,4,1,2,3,5)[0,...] 
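# NOTE (added annotation; axis labels are an assumption, not asserted by the original code):
# permute(0,4,1,2,3,5) mechanically reorders a 6-axis tensor (a0, a1, a2, a3, a4, a5) into
# (a0, a4, a1, a2, a3, a5), i.e. the fifth axis is moved to sit directly after the batch axis,
# and the [0, ...] indexing then drops that leading batch axis. The same reordering is applied
# to the output below, where the trailing [..., :1] additionally keeps only the first channel.
# Which of a1..a4 correspond to x, y, z, and time depends on how the .pt files were written.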
42 | y = data['output'].permute(0,4,1,2,3,5)[0,...,:1] 43 | 44 | D = {'x': x, 45 | 'y': y, 46 | 'path': [slope, idx, well]} 47 | return D -------------------------------------------------------------------------------- /DATA_LOADER_DICT.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/DATA_LOADER_DICT.pth -------------------------------------------------------------------------------- /FNO4D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import operator 7 | from functools import reduce 8 | from functools import partial 9 | 10 | from timeit import default_timer 11 | 12 | torch.manual_seed(0) 13 | np.random.seed(0) 14 | 15 | class SpectralConv4d(nn.Module): 16 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3, modes4): 17 | super(SpectralConv4d, self).__init__() 18 | 19 | """ 20 | 4D Fourier layer. It does FFT, linear transform, and Inverse FFT. 21 | """ 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 26 | self.modes2 = modes2 27 | self.modes3 = modes3 28 | self.modes4 = modes4 29 | 30 | self.scale = (1 / (in_channels * out_channels)) 31 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 32 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 33 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 34 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 35 | self.weights5 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 36 | self.weights6 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 37 | self.weights7 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 38 | self.weights8 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat)) 39 | 40 | # Complex multiplication 41 | def compl_mul4d(self, input, weights): 42 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 43 | return torch.einsum("bixyzt,ioxyzt->boxyzt", input, weights) 44 | 45 | def forward(self, x): 46 | batchsize = x.shape[0] 47 | #Compute Fourier coeffcients up to factor of e^(- something constant) 48 | x_ft = torch.fft.rfftn(x, dim=[-4,-3,-2,-1]) 49 | 50 | # Multiply relevant Fourier modes 51 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-4), x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 52 | 53 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3, :self.modes4], self.weights1) 54 | out_ft[:, :, 
-self.modes1:, :self.modes2, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3, :self.modes4], self.weights2) 55 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3, :self.modes4], self.weights3) 56 | out_ft[:, :, :self.modes1, :self.modes2, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, :self.modes2, -self.modes3:, :self.modes4], self.weights4) 57 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3, :self.modes4], self.weights5) 58 | out_ft[:, :, -self.modes1:, :self.modes2, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, :self.modes2, -self.modes3:, :self.modes4], self.weights6) 59 | out_ft[:, :, :self.modes1, -self.modes2:, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, -self.modes2:, -self.modes3:, :self.modes4], self.weights7) 60 | out_ft[:, :, -self.modes1:, -self.modes2:, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, -self.modes2:, -self.modes3:, :self.modes4], self.weights8) 61 | 62 | #Return to physical space 63 | x = torch.fft.irfftn(out_ft, s=(x.size(-4), x.size(-3), x.size(-2), x.size(-1))) 64 | return x 65 | 66 | class Block4d(nn.Module): 67 | def __init__(self, width, width2, modes1, modes2, modes3, modes4, out_dim): 68 | super(Block4d, self).__init__() 69 | self.modes1 = modes1 70 | self.modes2 = modes2 71 | self.modes3 = modes3 72 | self.modes4 = modes4 73 | 74 | self.width = width 75 | self.width2 = width2 76 | self.out_dim = out_dim 77 | self.padding = 8 78 | 79 | # channel 80 | self.conv0 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4) 81 | self.conv1 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4) 82 | self.conv2 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4) 83 | self.conv3 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4) 84 | self.w0 = nn.Conv1d(self.width, self.width, 1) 85 | self.w1 = nn.Conv1d(self.width, self.width, 1) 86 | self.w2 = nn.Conv1d(self.width, self.width, 1) 87 | self.w3 = nn.Conv1d(self.width, self.width, 1) 88 | self.fc1 = nn.Linear(self.width, self.width2) 89 | self.fc2 = nn.Linear(self.width2, self.out_dim) 90 | 91 | def forward(self, x): 92 | batchsize = x.shape[0] 93 | size_x, size_y, size_z, size_t = x.shape[2], x.shape[3], x.shape[4], x.shape[5] 94 | # print(size_x, size_y, size_z, size_t) 95 | # channel 96 | # print(x.shape) 97 | x1 = self.conv0(x) 98 | # print(x1.shape) 99 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t) 100 | x = x1 + x2 101 | x = F.gelu(x) 102 | 103 | x1 = self.conv1(x) 104 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t) 105 | x = x1 + x2 106 | x = F.gelu(x) 107 | 108 | x1 = self.conv2(x) 109 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t) 110 | x = x1 + x2 111 | x = F.gelu(x) 112 | 113 | x1 = self.conv3(x) 114 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t) 115 | x = x1 + x2 116 | 117 | x = x[:, :, self.padding:-self.padding, self.padding*2:-self.padding*2, 118 | 
self.padding*2:-self.padding*2, self.padding:-self.padding] 119 | 120 | x = x.permute(0, 2, 3, 4, 5, 1) # pad the domain if input is non-periodic 121 | x1 = self.fc1(x) 122 | x = F.gelu(x1) 123 | x = self.fc2(x) 124 | 125 | return x 126 | 127 | class FNO4d(nn.Module): 128 | def __init__(self, modes1, modes2, modes3, modes4, width, in_dim): 129 | super(FNO4d, self).__init__() 130 | 131 | self.modes1 = modes1 132 | self.modes2 = modes2 133 | self.modes3 = modes3 134 | self.modes4 = modes4 135 | self.width = width 136 | self.width2 = width*4 137 | self.in_dim = in_dim 138 | self.out_dim = 1 139 | self.padding = 8 # pad the domain if input is non-periodic 140 | 141 | self.fc0 = nn.Linear(self.in_dim, self.width) 142 | self.conv = Block4d(self.width, self.width2, 143 | self.modes1, self.modes2, self.modes3, self.modes4, self.out_dim) 144 | 145 | def forward(self, x, gradient=False): 146 | x = self.fc0(x) 147 | x = x.permute(0, 5, 1, 2, 3, 4) 148 | x = F.pad(x, [self.padding, self.padding, self.padding*2, self.padding*2, self.padding*2, 149 | self.padding*2, self.padding, self.padding]) 150 | 151 | x = self.conv(x) 152 | 153 | return x 154 | 155 | 156 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-NoDerivatives 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 58 | International Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-NoDerivatives 4.0 International Public 63 | License ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Copyright and Similar Rights means copyright and/or similar rights 84 | closely related to copyright including, without limitation, 85 | performance, broadcast, sound recording, and Sui Generis Database 86 | Rights, without regard to how the rights are labeled or 87 | categorized. For purposes of this Public License, the rights 88 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 89 | Rights. 90 | 91 | c. Effective Technological Measures means those measures that, in the 92 | absence of proper authority, may not be circumvented under laws 93 | fulfilling obligations under Article 11 of the WIPO Copyright 94 | Treaty adopted on December 20, 1996, and/or similar international 95 | agreements. 96 | 97 | d. Exceptions and Limitations means fair use, fair dealing, and/or 98 | any other exception or limitation to Copyright and Similar Rights 99 | that applies to Your use of the Licensed Material. 100 | 101 | e. 
Licensed Material means the artistic or literary work, database, 102 | or other material to which the Licensor applied this Public 103 | License. 104 | 105 | f. Licensed Rights means the rights granted to You subject to the 106 | terms and conditions of this Public License, which are limited to 107 | all Copyright and Similar Rights that apply to Your use of the 108 | Licensed Material and that the Licensor has authority to license. 109 | 110 | g. Licensor means the individual(s) or entity(ies) granting rights 111 | under this Public License. 112 | 113 | h. NonCommercial means not primarily intended for or directed towards 114 | commercial advantage or monetary compensation. For purposes of 115 | this Public License, the exchange of the Licensed Material for 116 | other material subject to Copyright and Similar Rights by digital 117 | file-sharing or similar means is NonCommercial provided there is 118 | no payment of monetary compensation in connection with the 119 | exchange. 120 | 121 | i. Share means to provide material to the public by any means or 122 | process that requires permission under the Licensed Rights, such 123 | as reproduction, public display, public performance, distribution, 124 | dissemination, communication, or importation, and to make material 125 | available to the public including in ways that members of the 126 | public may access the material from a place and at a time 127 | individually chosen by them. 128 | 129 | j. Sui Generis Database Rights means rights other than copyright 130 | resulting from Directive 96/9/EC of the European Parliament and of 131 | the Council of 11 March 1996 on the legal protection of databases, 132 | as amended and/or succeeded, as well as other essentially 133 | equivalent rights anywhere in the world. 134 | 135 | k. You means the individual or entity exercising the Licensed Rights 136 | under this Public License. Your has a corresponding meaning. 137 | 138 | 139 | Section 2 -- Scope. 140 | 141 | a. License grant. 142 | 143 | 1. Subject to the terms and conditions of this Public License, 144 | the Licensor hereby grants You a worldwide, royalty-free, 145 | non-sublicensable, non-exclusive, irrevocable license to 146 | exercise the Licensed Rights in the Licensed Material to: 147 | 148 | a. reproduce and Share the Licensed Material, in whole or 149 | in part, for NonCommercial purposes only; and 150 | 151 | b. produce and reproduce, but not Share, Adapted Material 152 | for NonCommercial purposes only. 153 | 154 | 2. Exceptions and Limitations. For the avoidance of doubt, where 155 | Exceptions and Limitations apply to Your use, this Public 156 | License does not apply, and You do not need to comply with 157 | its terms and conditions. 158 | 159 | 3. Term. The term of this Public License is specified in Section 160 | 6(a). 161 | 162 | 4. Media and formats; technical modifications allowed. The 163 | Licensor authorizes You to exercise the Licensed Rights in 164 | all media and formats whether now known or hereafter created, 165 | and to make technical modifications necessary to do so. The 166 | Licensor waives and/or agrees not to assert any right or 167 | authority to forbid You from making technical modifications 168 | necessary to exercise the Licensed Rights, including 169 | technical modifications necessary to circumvent Effective 170 | Technological Measures. For purposes of this Public License, 171 | simply making modifications authorized by this Section 2(a) 172 | (4) never produces Adapted Material. 173 | 174 | 5. 
Downstream recipients. 175 | 176 | a. Offer from the Licensor -- Licensed Material. Every 177 | recipient of the Licensed Material automatically 178 | receives an offer from the Licensor to exercise the 179 | Licensed Rights under the terms and conditions of this 180 | Public License. 181 | 182 | b. No downstream restrictions. You may not offer or impose 183 | any additional or different terms or conditions on, or 184 | apply any Effective Technological Measures to, the 185 | Licensed Material if doing so restricts exercise of the 186 | Licensed Rights by any recipient of the Licensed 187 | Material. 188 | 189 | 6. No endorsement. Nothing in this Public License constitutes or 190 | may be construed as permission to assert or imply that You 191 | are, or that Your use of the Licensed Material is, connected 192 | with, or sponsored, endorsed, or granted official status by, 193 | the Licensor or others designated to receive attribution as 194 | provided in Section 3(a)(1)(A)(i). 195 | 196 | b. Other rights. 197 | 198 | 1. Moral rights, such as the right of integrity, are not 199 | licensed under this Public License, nor are publicity, 200 | privacy, and/or other similar personality rights; however, to 201 | the extent possible, the Licensor waives and/or agrees not to 202 | assert any such rights held by the Licensor to the limited 203 | extent necessary to allow You to exercise the Licensed 204 | Rights, but not otherwise. 205 | 206 | 2. Patent and trademark rights are not licensed under this 207 | Public License. 208 | 209 | 3. To the extent possible, the Licensor waives any right to 210 | collect royalties from You for the exercise of the Licensed 211 | Rights, whether directly or through a collecting society 212 | under any voluntary or waivable statutory or compulsory 213 | licensing scheme. In all other cases the Licensor expressly 214 | reserves any right to collect such royalties, including when 215 | the Licensed Material is used other than for NonCommercial 216 | purposes. 217 | 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material, You must: 227 | 228 | a. retain the following if it is supplied by the Licensor 229 | with the Licensed Material: 230 | 231 | i. identification of the creator(s) of the Licensed 232 | Material and any others designated to receive 233 | attribution, in any reasonable manner requested by 234 | the Licensor (including by pseudonym if 235 | designated); 236 | 237 | ii. a copyright notice; 238 | 239 | iii. a notice that refers to this Public License; 240 | 241 | iv. a notice that refers to the disclaimer of 242 | warranties; 243 | 244 | v. a URI or hyperlink to the Licensed Material to the 245 | extent reasonably practicable; 246 | 247 | b. indicate if You modified the Licensed Material and 248 | retain an indication of any previous modifications; and 249 | 250 | c. indicate the Licensed Material is licensed under this 251 | Public License, and include the text of, or the URI or 252 | hyperlink to, this Public License. 253 | 254 | For the avoidance of doubt, You do not have permission under 255 | this Public License to Share Adapted Material. 256 | 257 | 2. You may satisfy the conditions in Section 3(a)(1) in any 258 | reasonable manner based on the medium, means, and context in 259 | which You Share the Licensed Material. 
For example, it may be 260 | reasonable to satisfy the conditions by providing a URI or 261 | hyperlink to a resource that includes the required 262 | information. 263 | 264 | 3. If requested by the Licensor, You must remove any of the 265 | information required by Section 3(a)(1)(A) to the extent 266 | reasonably practicable. 267 | 268 | 269 | Section 4 -- Sui Generis Database Rights. 270 | 271 | Where the Licensed Rights include Sui Generis Database Rights that 272 | apply to Your use of the Licensed Material: 273 | 274 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 275 | to extract, reuse, reproduce, and Share all or a substantial 276 | portion of the contents of the database for NonCommercial purposes 277 | only and provided You do not Share Adapted Material; 278 | 279 | b. if You include all or a substantial portion of the database 280 | contents in a database in which You have Sui Generis Database 281 | Rights, then the database in which You have Sui Generis Database 282 | Rights (but not its individual contents) is Adapted Material; and 283 | 284 | c. You must comply with the conditions in Section 3(a) if You Share 285 | all or a substantial portion of the contents of the database. 286 | 287 | For the avoidance of doubt, this Section 4 supplements and does not 288 | replace Your obligations under this Public License where the Licensed 289 | Rights include other Copyright and Similar Rights. 290 | 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. 
automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | 350 | Section 7 -- Other Terms and Conditions. 351 | 352 | a. The Licensor shall not be bound by any additional or different 353 | terms or conditions communicated by You unless expressly agreed. 354 | 355 | b. Any arrangements, understandings, or agreements regarding the 356 | Licensed Material not stated herein are separate from and 357 | independent of the terms and conditions of this Public License. 358 | 359 | 360 | Section 8 -- Interpretation. 361 | 362 | a. For the avoidance of doubt, this Public License does not, and 363 | shall not be interpreted to, reduce, limit, restrict, or impose 364 | conditions on any use of the Licensed Material that could lawfully 365 | be made without permission under this Public License. 366 | 367 | b. To the extent possible, if any provision of this Public License is 368 | deemed unenforceable, it shall be automatically reformed to the 369 | minimum extent necessary to make it enforceable. If the provision 370 | cannot be reformed, it shall be severed from this Public License 371 | without affecting the enforceability of the remaining terms and 372 | conditions. 373 | 374 | c. No term or condition of this Public License will be waived and no 375 | failure to comply consented to unless expressly agreed to by the 376 | Licensor. 377 | 378 | d. Nothing in this Public License constitutes or may be interpreted 379 | as a limitation upon, or waiver of, any privileges and immunities 380 | that apply to the Licensor or You, including from the legal 381 | processes of any jurisdiction or authority. 382 | 383 | ======================================================================= 384 | 385 | Creative Commons is not a party to its public 386 | licenses. Notwithstanding, Creative Commons may elect to apply one of 387 | its public licenses to material it publishes and in those instances 388 | will be considered the “Licensor.” The text of the Creative Commons 389 | public licenses is dedicated to the public domain under the CC0 Public 390 | Domain Dedication. Except for the limited purpose of indicating that 391 | material is shared under a Creative Commons public license or as 392 | otherwise permitted by the Creative Commons policies published at 393 | creativecommons.org/policies, Creative Commons does not authorize the 394 | use of the trademark "Creative Commons" or any other trademark or logo 395 | of Creative Commons without its prior written consent including, 396 | without limitation, in connection with any unauthorized modifications 397 | to any of its public licenses or any other arrangements, 398 | understandings, or agreements concerning use of licensed material. For 399 | the avoidance of doubt, this paragraph does not form part of the 400 | public licenses. 
401 | 402 | Creative Commons may be contacted at creativecommons.org. 403 | -------------------------------------------------------------------------------- /README.mD: -------------------------------------------------------------------------------- 1 | ## Make dataloader 2 | - step 1: download meta data from [google drive](https://drive.google.com/drive/u/1/folders/1gElIBiZW6NayuEWxgDn8cv94_4e_LF4-) and put them into `Nested_FNO/ECLIPSE/meta_data` 3 | - step 2: run the following code to convert `.npy` files into `.pt` files in the `dataset` folder 4 | ``` 5 | cd data_config 6 | bash file_config.sh 7 | cd .. 8 | ``` 9 | - step 3: run `python3 save_data_loader.py` to create `DATA_LOADER_DICT.pth` 10 | 11 | ## Training procedure 12 | - step 1: train each model separately using the following code. Each model requires an NVIDIA A100 GPU. 13 | ``` 14 | python3 train_FNO4D_DP_GLOBAL.py 15 | python3 train_FNO4D_DP_LGR.py LGR1 16 | python3 train_FNO4D_DP_LGR.py LGR2 17 | python3 train_FNO4D_DP_LGR.py LGR3 18 | python3 train_FNO4D_DP_LGR.py LGR4 19 | python3 train_FNO4D_SG_LGR.py LGR1 20 | python3 train_FNO4D_SG_LGR.py LGR2 21 | python3 train_FNO4D_SG_LGR.py LGR3 22 | python3 train_FNO4D_SG_LGR.py LGR4 23 | ``` 24 | - step 2: monitor training and validation loss with TensorBoard 25 | ``` 26 | tensorboard --logdir=logs --port=6007 --host=xxxxxx 27 | ``` 28 | 29 | ## Finetune procedure 30 | As discussed in the paper, we finetuned the `dP_LGR1`, `dP_LGR4`, `SG_LGR1`, `SG_LGR1` models with a random instance of pre-generated error. 31 | -------------------------------------------------------------------------------- /UnitGaussianNormalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # normalization, pointwise gaussian 4 | class UnitGaussianNormalizer(object): 5 | def __init__(self, x, eps=0.00001): 6 | super(UnitGaussianNormalizer, self).__init__() 7 | 8 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 9 | self.mean = torch.mean(x, 0) 10 | self.std = torch.std(x, 0) 11 | self.eps = eps 12 | 13 | def encode(self, x): 14 | x = (x - self.mean) / (self.std + self.eps) 15 | return x 16 | 17 | def decode(self, x, sample_idx=None): 18 | if sample_idx is None: 19 | std = self.std + self.eps # n 20 | mean = self.mean 21 | else: 22 | if len(self.mean.shape) == len(sample_idx[0].shape): 23 | std = self.std[sample_idx] + self.eps # batch*n 24 | mean = self.mean[sample_idx] 25 | if len(self.mean.shape) > len(sample_idx[0].shape): 26 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 27 | mean = self.mean[:,sample_idx] 28 | 29 | # x is in shape of batch*n or T*batch*n 30 | x = (x * std) + mean 31 | return x 32 | 33 | def cuda(self): 34 | self.mean = self.mean.cuda() 35 | self.std = self.std.cuda() 36 | 37 | def cpu(self): 38 | self.mean = self.mean.cpu() 39 | self.std = self.std.cpu() 40 | 41 | # normalization, pointwise gaussian 42 | class TimeGaussianNormalizer(object): 43 | def __init__(self, x): 44 | super(TimeGaussianNormalizer, self).__init__() 45 | 46 | self.mean = torch.mean(x,(0,2,3,4,5))[None,:,None,None,None,None] 47 | self.std = torch.std(x,(0,2,3,4,5))[None,:,None,None,None,None] 48 | 49 | def encode(self, x): 50 | x = (x - self.mean)/self.std 51 | return x 52 | 53 | def decode(self, x): 54 | x = (x * self.std) + self.mean 55 | return x 56 | 57 | def cuda(self): 58 | self.mean = self.mean.cuda() 59 | self.std = self.std.cuda() 60 | 61 | def cpu(self): 62 | self.mean = self.mean.cpu() 63 | self.std = self.std.cpu() 64 | 65 | def
param(self): 66 | return self.mean.reshape(24,), self.std.reshape(24,) -------------------------------------------------------------------------------- /config_meta_data.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import numpy as np 4 | import json 5 | import copy 6 | from vonk3d import * 7 | 8 | def user_input_to_meta_data(ax_m, ay_m, az_m, mu, std, pressure, temp, dip, inj_loc, inj_rate, PERF_DICT): 9 | n_well = len(inj_rate) 10 | WELL_LIST = [f'WELL{i_well+1}' for i_well in range(n_well)] 11 | 12 | coord_well = [] 13 | for well in WELL_LIST: 14 | coord_well.append([int(inj_loc[well][0]//1600-2), int(inj_loc[well][1]//1600)]) 15 | 16 | kmap3d = vonk3d(rseed=0, 17 | dx=1,dy=1,dz=1, 18 | ax=ax_m/400,ay=ay_m/400,az=az_m/2, 19 | ix=400,iy=400,iz=50, 20 | pop=1,med=3,nu=1) 21 | 22 | kmap3d = kmap3d*std + mu 23 | kmap3d = np.exp(kmap3d) 24 | 25 | PERM_DICT = return_PERM_DICT(kmap3d, coord_well) 26 | GRID_IDX_DICT = grid_idx_dict(coord_well) 27 | GRID_CENTER_DICT = grid_dict(WELL_LIST, GRID_IDX_DICT) 28 | LGR_LIST = list(GRID_IDX_DICT['WELL1'].keys()) 29 | PCOLOR_GRID_DICT = pcolor_grid_dict(WELL_LIST, GRID_IDX_DICT) 30 | INJ_LOCATION_DICT = inj_location_dict(PCOLOR_GRID_DICT, WELL_LIST) 31 | INPUT_DICT = input_dict(pressure, temp, inj_rate) 32 | TOPS_DICT = return_tops_dict(pressure, dip, GRID_CENTER_DICT, WELL_LIST, LGR_LIST) 33 | 34 | meta_data = { 35 | 'GRID_IDX_DICT':GRID_IDX_DICT, 36 | 'GRID_CENTER_DICT':GRID_CENTER_DICT, 37 | 'PCOLOR_GRID_DICT':PCOLOR_GRID_DICT, 38 | 'LGR_LIST':LGR_LIST, 39 | 'WELL_LIST':WELL_LIST, 40 | 'PERM_DICT':PERM_DICT, 41 | 'INPUT_DICT':INPUT_DICT, 42 | 'INJ_LOCATION_DICT': INJ_LOCATION_DICT, 43 | 'PERF_DICT':PERF_DICT, 44 | 'TOPS_DICT':TOPS_DICT 45 | } 46 | 47 | return meta_data 48 | 49 | 50 | depth_func = lambda a : (a - 1.01325)/9.8*100 51 | pressure_func = lambda a: a/100*9.8 + 1.01325 52 | 53 | def return_tops(depth, slope, grid_x, grid_z): 54 | nz = grid_z.shape[-1] 55 | adj = np.tan(np.deg2rad(slope)) * (grid_x - 160000/2) 56 | return adj + depth + grid_z 57 | 58 | def return_tops_dict(pressure, dip, GRID_CENTER_DICT, WELL_LIST, LGR_LIST): 59 | TOPS_DICT = {} 60 | TOPS_DICT['GLOBAL'] = return_tops(depth_func(pressure), dip, 61 | GRID_CENTER_DICT['GLOBAL']['grid_x'], GRID_CENTER_DICT['GLOBAL']['grid_z'])[None,...] 62 | for well in WELL_LIST: 63 | d = {} 64 | for lgr in LGR_LIST: 65 | tops = return_tops(depth_func(pressure), dip, 66 | GRID_CENTER_DICT[well][lgr]['grid_x'], 67 | GRID_CENTER_DICT[well][lgr]['grid_z'])[None,...] 
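# NOTE (added annotation): the call above combines depth_func and return_tops defined earlier in this
# file, i.e. tops = tan(radians(dip)) * (grid_x - 160000/2) + (pressure - 1.01325)/9.8*100 + grid_z,
# so the dip tilts the formation about the middle of the 160000-unit domain while the initial pressure
# sets the reference depth. A small illustrative check (values hypothetical, not taken from the
# repository): at grid_x = 80000 the dip term vanishes, and with pressure = 100 the reference depth is
# (100 - 1.01325)/9.8*100 ≈ 1010.1, so tops ≈ 1010.1 + grid_z regardless of dip.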
68 | d[lgr] = tops 69 | TOPS_DICT[well] = d 70 | return TOPS_DICT 71 | 72 | 73 | def grid_idx_dict(coord_well): 74 | with open('GRID_IDX_DICT.json') as f: 75 | GRID_IDX_DICT = json.load(f) 76 | 77 | n_well = len(coord_well) 78 | for i_well in range(n_well-1): 79 | GRID_IDX_DICT[f'WELL{i_well+2}'] = copy.deepcopy(GRID_IDX_DICT['WELL1']) 80 | 81 | for i_well in range(n_well): 82 | well_x_start, well_x_end = coord_well[i_well][0] - 4, coord_well[i_well][0] + 5 83 | well_y_start, well_y_end = coord_well[i_well][1] - 4, coord_well[i_well][1] + 5 84 | GRID_IDX_DICT[f'WELL{i_well+1}']['LGR1']['I1'] = well_x_start 85 | GRID_IDX_DICT[f'WELL{i_well+1}']['LGR1']['I2'] = well_x_end 86 | GRID_IDX_DICT[f'WELL{i_well+1}']['LGR1']['J1'] = well_y_start 87 | GRID_IDX_DICT[f'WELL{i_well+1}']['LGR1']['J2'] = well_y_end 88 | return GRID_IDX_DICT 89 | 90 | def torch_regrid(x, size): 91 | return F.interpolate(torch.from_numpy(x)[None, None,...], 92 | size=size, mode='trilinear', align_corners=False)[0,0,...].numpy() 93 | 94 | def return_PERM_DICT(kmap, coord_well): 95 | DICT = {} 96 | DICT['GLOBAL'] = torch_regrid(kmap, [100, 100, 5]) 97 | 98 | for i_well in range(len(coord_well)): 99 | d = {} 100 | I1, I2 = (coord_well[i_well][0] - 5)*4, (coord_well[i_well][0] + 5)*4 101 | J1, J2 = (coord_well[i_well][1] - 5)*4, (coord_well[i_well][1] + 5)*4 102 | k_LGR1 = kmap[I1:I2, J1:J2, :] 103 | d['LGR1'] = torch_regrid(k_LGR1, [40, 40, 25]) 104 | 105 | I1, I2 = 18, 38 106 | J1, J2 = 10, 30 107 | d['LGR2'] = torch_regrid(k_LGR1[I1:I2, J1:J2, :], [40, 40, 50]) 108 | 109 | I1, I2 = 10, 30 110 | J1, J2 = 10, 30 111 | d['LGR3'] = torch_regrid(d['LGR2'][I1:I2, J1:J2, :], [40, 40, 50]) 112 | 113 | I1, I2 = 16, 24 114 | J1, J2 = 16, 24 115 | d['LGR4'] = torch_regrid(d['LGR3'][I1:I2, J1:J2, :], [40, 40, 50]) 116 | 117 | d['LGR1'] = d['LGR1'][None,...] 118 | d['LGR2'] = d['LGR2'][None,...] 119 | d['LGR3'] = d['LGR3'][None,...] 120 | d['LGR4'] = d['LGR4'][None,...] 121 | DICT[f'WELL{int(i_well+1)}'] = d 122 | DICT['GLOBAL'] = DICT['GLOBAL'][None,...] 
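# NOTE (added annotation): the cropping/regridding above builds the nested permeability maps returned
# below, one entry per well and refinement level: the full 400x400x50 kmap is coarsened to 100x100x5
# for GLOBAL, a 40x40x50 window around each well is regridded to 40x40x25 for LGR1, and LGR2-LGR4
# successively crop smaller windows and regrid them to 40x40x50, all via torch_regrid (trilinear
# F.interpolate). A minimal usage sketch of torch_regrid (array shape arbitrary, for illustration only):
#   k = np.random.rand(80, 80, 50)              # hypothetical permeability block
#   k_coarse = torch_regrid(k, [40, 40, 25])    # -> numpy array of shape (40, 40, 25)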
123 | return DICT 124 | 125 | def inj_location_dict(pcolor_grid_dict, WELL_LIST): 126 | INJ_LOCATION_DICT = {} 127 | for well in WELL_LIST: 128 | well_x = pcolor_grid_dict[well]['LGR4']['grid_x'][20,20,0] 129 | well_y = pcolor_grid_dict[well]['LGR4']['grid_y'][20,20,0] 130 | INJ_LOCATION_DICT[well]=well_x, well_y 131 | return INJ_LOCATION_DICT 132 | 133 | def grid_dict(wells, grid_dict): 134 | lgrs = list(grid_dict[wells[0]].keys()) 135 | parents = parents = ['GLOBAL'] + lgrs[:-1] 136 | 137 | # GLOBAL grid 138 | grid_x = np.linspace(grid_dict['GLOBAL']['DX']/2, 139 | 160000 - grid_dict['GLOBAL']['DX']/2, 140 | grid_dict['GLOBAL']['NX']) 141 | grid_y = np.linspace(grid_dict['GLOBAL']['DY']/2, 142 | 160000-grid_dict['GLOBAL']['DY']/2, 143 | grid_dict['GLOBAL']['NY']) 144 | grid_z = np.linspace(grid_dict['GLOBAL']['DZ']/2, 145 | 100-grid_dict['GLOBAL']['DZ']/2, 146 | grid_dict['GLOBAL']['NZ']) 147 | 148 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 149 | 150 | GRID = {} 151 | GRID['GLOBAL'] = {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z } 152 | 153 | ######################### grid for LGR1 ######################### 154 | for well in wells: 155 | lgr = 'LGR1' 156 | parent = 'GLOBAL' 157 | x_start = GRID[parent]['grid_x'][grid_dict[well][lgr]['I1']-1,0,0] - grid_dict[parent]['DX']/2 158 | x_end = GRID[parent]['grid_x'][grid_dict[well][lgr]['I2']-1,0,0] + grid_dict[parent]['DX']/2 159 | 160 | grid_x = np.linspace(x_start+grid_dict[well][lgr]['DX']/2, 161 | x_end-grid_dict[well][lgr]['DX']/2, 162 | grid_dict[well][lgr]['NX']) 163 | 164 | y_start = GRID[parent]['grid_y'][0,grid_dict[well][lgr]['J1']-1,0] - grid_dict[parent]['DY']/2 165 | y_end = GRID[parent]['grid_y'][0,grid_dict[well][lgr]['J2']-1,0] + grid_dict[parent]['DY']/2 166 | 167 | grid_y = np.linspace(y_start+grid_dict[well][lgr]['DY']/2, 168 | y_end-grid_dict[well][lgr]['DY']/2, 169 | grid_dict[well][lgr]['NY']) 170 | 171 | z_start = GRID[parent]['grid_z'][0,0,grid_dict[well][lgr]['K1']-1] - grid_dict[parent]['DZ']/2 172 | z_end = GRID[parent]['grid_z'][0,0,grid_dict[well][lgr]['K2']-1] + grid_dict[parent]['DZ']/2 173 | 174 | grid_z = np.linspace(z_start+grid_dict[well][lgr]['DZ']/2, 175 | z_end-grid_dict[well][lgr]['DZ']/2, 176 | grid_dict[well][lgr]['NZ']) 177 | 178 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 179 | GRID[well] = {lgr: {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z }} 180 | 181 | ######################### grid for LGR2 and up ######################### 182 | 183 | lgrs = lgrs[1:] 184 | parents = parents[1:] 185 | 186 | for well in wells: 187 | for i in range(len(lgrs)): 188 | lgr = lgrs[i] 189 | parent = parents[i] 190 | 191 | x_start = GRID[well][parent]['grid_x'][grid_dict[well][lgr]['I1']-1,0,0] - grid_dict[well][parent]['DX']/2 192 | x_end = GRID[well][parent]['grid_x'][grid_dict[well][lgr]['I2']-1,0,0] + grid_dict[well][parent]['DX']/2 193 | 194 | grid_x = np.linspace(x_start+grid_dict[well][lgr]['DX']/2, 195 | x_end-grid_dict[well][lgr]['DX']/2, 196 | grid_dict[well][lgr]['NX']) 197 | 198 | y_start = GRID[well][parent]['grid_y'][0,grid_dict[well][lgr]['J1']-1,0] - grid_dict[well][parent]['DY']/2 199 | y_end = GRID[well][parent]['grid_y'][0,grid_dict[well][lgr]['J2']-1,0] + grid_dict[well][parent]['DY']/2 200 | 201 | grid_y = np.linspace(y_start+grid_dict[well][lgr]['DY']/2, 202 | y_end-grid_dict[well][lgr]['DY']/2, 203 | grid_dict[well][lgr]['NY']) 204 | 205 | z_start = GRID[well][parent]['grid_z'][0,0,grid_dict[well][lgr]['K1']-1] - 
grid_dict[well][parent]['DZ']/2 206 | z_end = GRID[well][parent]['grid_z'][0,0,grid_dict[well][lgr]['K2']-1] + grid_dict[well][parent]['DZ']/2 207 | 208 | grid_z = np.linspace(z_start+grid_dict[well][lgr]['DZ']/2, 209 | z_end-grid_dict[well][lgr]['DZ']/2, 210 | grid_dict[well][lgr]['NZ']) 211 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 212 | GRID[well][lgr] = {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z } 213 | 214 | return GRID 215 | 216 | def pcolor_grid_dict(wells, grid_dict): 217 | lgrs = list(grid_dict[wells[0]].keys()) 218 | parents = parents = ['GLOBAL'] + lgrs[:-1] 219 | 220 | # GLOBAL grid 221 | grid_x = np.linspace(0, 160000, grid_dict['GLOBAL']['NX']+1) 222 | grid_y = np.linspace(0, 160000, grid_dict['GLOBAL']['NY']+1) 223 | grid_z = np.linspace(0, 100, grid_dict['GLOBAL']['NZ']+1) 224 | 225 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 226 | 227 | GRID = {} 228 | GRID['GLOBAL'] = {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z } 229 | 230 | ######################### grid for LGR1 ######################### 231 | for well in wells: 232 | lgr = 'LGR1' 233 | parent = 'GLOBAL' 234 | x_start = GRID[parent]['grid_x'][grid_dict[well][lgr]['I1']-1,0,0] 235 | x_end = GRID[parent]['grid_x'][grid_dict[well][lgr]['I2'],0,0] 236 | 237 | grid_x = np.linspace(x_start, x_end, grid_dict[well][lgr]['NX']+1) 238 | 239 | y_start = GRID[parent]['grid_y'][0,grid_dict[well][lgr]['J1']-1,0] 240 | y_end = GRID[parent]['grid_y'][0,grid_dict[well][lgr]['J2'],0] 241 | 242 | grid_y = np.linspace(y_start, y_end, grid_dict[well][lgr]['NY']+1) 243 | 244 | z_start = GRID[parent]['grid_z'][0,0,grid_dict[well][lgr]['K1']-1] 245 | z_end = GRID[parent]['grid_z'][0,0,grid_dict[well][lgr]['K2']] 246 | 247 | grid_z = np.linspace(z_start, z_end, grid_dict[well][lgr]['NZ']+1) 248 | 249 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 250 | GRID[well] = {lgr: {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z }} 251 | 252 | ######################### grid for LGR2 and up ######################### 253 | 254 | lgrs = lgrs[1:] 255 | parents = parents[1:] 256 | 257 | for well in wells: 258 | for i in range(len(lgrs)): 259 | lgr = lgrs[i] 260 | parent = parents[i] 261 | 262 | x_start = GRID[well][parent]['grid_x'][grid_dict[well][lgr]['I1']-1,0,0] 263 | x_end = GRID[well][parent]['grid_x'][grid_dict[well][lgr]['I2'],0,0] 264 | 265 | grid_x = np.linspace(x_start, x_end, grid_dict[well][lgr]['NX']+1) 266 | 267 | y_start = GRID[well][parent]['grid_y'][0,grid_dict[well][lgr]['J1']-1,0] 268 | y_end = GRID[well][parent]['grid_y'][0,grid_dict[well][lgr]['J2'],0] 269 | 270 | grid_y = np.linspace(y_start, y_end, grid_dict[well][lgr]['NY']+1) 271 | 272 | z_start = GRID[well][parent]['grid_z'][0,0,grid_dict[well][lgr]['K1']-1] 273 | z_end = GRID[well][parent]['grid_z'][0,0,grid_dict[well][lgr]['K2']] 274 | 275 | grid_z = np.linspace(z_start, z_end, grid_dict[well][lgr]['NZ']+1) 276 | grid_x, grid_y, grid_z = np.meshgrid(grid_x, grid_y, grid_z,indexing='ij') 277 | GRID[well][lgr] = {'grid_x': grid_x, 'grid_y': grid_y, 'grid_z': grid_z } 278 | 279 | return GRID 280 | 281 | def input_dict(p, temp, inj_rate_dict): 282 | INPUT_PARAM = {'temp': temp, 283 | 'p': p, 284 | 'inj': inj_rate_dict} 285 | return INPUT_PARAM -------------------------------------------------------------------------------- /config_py_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
torch.nn.functional as F 3 | import torch 4 | 5 | def return_upsample_dict(OUTPUT_DICT, WELL_LIST, GRID_IDX_DICT): 6 | OUTPUT_UPSAMPLE_DICT = {} 7 | 8 | LGR_BEFORE = ['LGR3', 'LGR2', 'LGR1'] 9 | LGR_AFTER = ['LGR4', 'LGR3', 'LGR2'] 10 | 11 | for well in WELL_LIST: 12 | OUTPUT_UPSAMPLE_DICT[well] = {'LGR4': OUTPUT_DICT[well]['LGR4']} 13 | for iii in range(3): 14 | lgr_before = LGR_BEFORE[iii] 15 | lgr_after = LGR_AFTER[iii] 16 | 17 | upsampled = np.copy(OUTPUT_DICT[well][lgr_before][-1,:,:,:]) 18 | nx_new = GRID_IDX_DICT[well][lgr_after]['I2'] - GRID_IDX_DICT[well][lgr_after]['I1'] + 1 19 | ny_new = GRID_IDX_DICT[well][lgr_after]['J2'] - GRID_IDX_DICT[well][lgr_after]['J1'] + 1 20 | nz_new = GRID_IDX_DICT[well][lgr_after]['K2'] - GRID_IDX_DICT[well][lgr_after]['K1'] + 1 21 | 22 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well][lgr_after][-1,:,:,:])[None, None,...], 23 | size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 24 | 25 | upsampled[GRID_IDX_DICT[well][lgr_after]['I1']-1:GRID_IDX_DICT[well][lgr_after]['I2'], 26 | GRID_IDX_DICT[well][lgr_after]['J1']-1:GRID_IDX_DICT[well][lgr_after]['J2'],:] = A 27 | 28 | if well in OUTPUT_UPSAMPLE_DICT: 29 | OUTPUT_UPSAMPLE_DICT[well].update({lgr_before: upsampled[None,...]}) 30 | else: 31 | OUTPUT_UPSAMPLE_DICT[well]={lgr_before: upsampled[None,...]} 32 | 33 | upsampled = np.copy(OUTPUT_DICT['GLOBAL'][-1,:,:,:]) 34 | for well in WELL_LIST: 35 | nx_new = GRID_IDX_DICT[well]['LGR1']['I2'] - GRID_IDX_DICT[well]['LGR1']['I1'] + 1 36 | ny_new = GRID_IDX_DICT[well]['LGR1']['J2'] - GRID_IDX_DICT[well]['LGR1']['J1'] + 1 37 | nz_new = GRID_IDX_DICT[well]['LGR1']['K2'] - GRID_IDX_DICT[well]['LGR1']['K1'] + 1 38 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well]['LGR1'][-1,:,:,:])[None, None,...], 39 | size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 40 | upsampled[GRID_IDX_DICT[well]['LGR1']['I1']-1:GRID_IDX_DICT[well]['LGR1']['I2'], 41 | GRID_IDX_DICT[well]['LGR1']['J1']-1:GRID_IDX_DICT[well]['LGR1']['J2'],:] = A 42 | OUTPUT_UPSAMPLE_DICT['GLOBAL'] = upsampled 43 | 44 | return OUTPUT_UPSAMPLE_DICT 45 | 46 | def return_OUTPUT_DICT(meta_data, case_name): 47 | nt = list(meta_data[case_name]['data'].keys()) 48 | OUT = {} 49 | GRID_IDX_DICT = meta_data[case_name]['GRID_IDX_DICT'] 50 | WELL_LIST = meta_data[case_name]['WELL_LIST'] 51 | LGR_LIST = meta_data[case_name]['LGR_LIST'] 52 | 53 | for name in [ 'BGSAT', 'BPR']: 54 | out = {} 55 | lname = f'L{name}' 56 | for t in nt: 57 | data = meta_data[case_name]['data'][t] 58 | output_dict = {} 59 | 60 | output_dict['GLOBAL'] = data[name].reshape((-1, GRID_IDX_DICT['GLOBAL']['NX'], 61 | GRID_IDX_DICT['GLOBAL']['NY'], 62 | GRID_IDX_DICT['GLOBAL']['NZ'])) 63 | N_LIST = [0] 64 | idx = 0 65 | for well in WELL_LIST: 66 | for lgr in LGR_LIST: 67 | n_prev = N_LIST[idx] 68 | idx += 1 69 | n_cur = n_prev+GRID_IDX_DICT[well][lgr]['NX'] * GRID_IDX_DICT[well][lgr]['NY'] * GRID_IDX_DICT[well][lgr]['NZ'] 70 | N_LIST.append(n_cur) 71 | 72 | if well in output_dict: 73 | output_dict[well].update({lgr: data[lname][:,n_prev: n_cur].reshape(-1, 74 | GRID_IDX_DICT[well][lgr]['NX'], 75 | GRID_IDX_DICT[well][lgr]['NY'], 76 | GRID_IDX_DICT[well][lgr]['NZ']) }) 77 | else: 78 | output_dict[well] = {lgr: data[lname][:,n_prev: n_cur].reshape(-1, 79 | GRID_IDX_DICT[well][lgr]['NX'], 80 | GRID_IDX_DICT[well][lgr]['NY'], 81 | GRID_IDX_DICT[well][lgr]['NZ']) } 82 | out[t] = output_dict 83 | OUT[name] = out 84 | 85 | out = {} 86 | for t in nt: 87 | 
output_dict = {} 88 | output_dict['GLOBAL'] = OUT['BPR'][t]['GLOBAL'] - OUT['BPR'][0]['GLOBAL'] 89 | 90 | for well in WELL_LIST: 91 | for lgr in LGR_LIST: 92 | if well in output_dict: 93 | output_dict[well].update({lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]}) 94 | else: 95 | output_dict[well] = {lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]} 96 | out[t] = output_dict 97 | OUT['dP'] = out 98 | 99 | out = {} 100 | for t in nt: 101 | output_dict = {} 102 | output_dict['GLOBAL'] = OUT['dP'][t]['GLOBAL'] > 0.1 103 | 104 | for well in WELL_LIST: 105 | for lgr in LGR_LIST: 106 | if well in output_dict: 107 | output_dict[well].update({lgr: OUT['dP'][t][well][lgr] > 0.1}) 108 | else: 109 | output_dict[well] = {lgr: OUT['dP'][t][well][lgr] > 0.1 } 110 | out[t] = output_dict 111 | OUT['P_influence'] = out 112 | return OUT 113 | 114 | 115 | def return_upsample_all_time(OUTPUT_DICT, name, WELL_LIST, GRID_IDX_DICT,LGR_LIST): 116 | OUT = {} 117 | OUT['GLOBAL'] = np.zeros((1, 100, 100, 5, 24)) 118 | for well in WELL_LIST: 119 | for lgr in LGR_LIST: 120 | nx, ny, nz = GRID_IDX_DICT[well][lgr]['NX'], GRID_IDX_DICT[well][lgr]['NY'], GRID_IDX_DICT[well][lgr]['NZ'] 121 | if well in OUT: 122 | OUT[well].update({lgr: np.zeros((1,nx,ny,nz,24))}) 123 | else: 124 | OUT[well]={lgr: np.zeros((1,nx,ny,nz,24))} 125 | 126 | for t in range(1,25): 127 | up_sample_dict = return_upsample_dict(OUTPUT_DICT, t, name, WELL_LIST, GRID_IDX_DICT) 128 | OUT['GLOBAL'][0,:,:,:,t-1] = up_sample_dict['GLOBAL'] 129 | 130 | for well in WELL_LIST: 131 | for lgr in LGR_LIST: 132 | OUT[well][lgr][0,:,:,:,t-1] = up_sample_dict[well][lgr] 133 | 134 | return OUT 135 | 136 | def return_inj_map_dict(well_list,rate_dict,inj_loc_dict,center_dict, LGR_LIST): 137 | inj_norm = lambda x: (x)/(2942777.68785957) 138 | 139 | INJ_MAP_DICT = {} 140 | 141 | inj_map = np.zeros(center_dict['GLOBAL']['grid_x'].shape) 142 | for well in well_list: 143 | well_x, well_y = inj_loc_dict[well] 144 | xidx = (np.abs(center_dict['GLOBAL']['grid_x'][:,0,0] - well_x)).argmin() 145 | yidx = (np.abs(center_dict['GLOBAL']['grid_y'][0,:,0] - well_y)).argmin() 146 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 147 | INJ_MAP_DICT['GLOBAL'] = inj_map 148 | 149 | for well in well_list: 150 | well_x, well_y = inj_loc_dict[well] 151 | for lgr in LGR_LIST: 152 | inj_map = np.zeros(center_dict[well][lgr]['grid_x'].shape) 153 | xidx = (np.abs(center_dict[well][lgr]['grid_x'][:,0,0] - well_x)).argmin() 154 | yidx = (np.abs(center_dict[well][lgr]['grid_y'][0,:,0] - well_y)).argmin() 155 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 156 | if well in INJ_MAP_DICT: 157 | INJ_MAP_DICT[well].update({lgr: inj_map}) 158 | else: 159 | INJ_MAP_DICT[well]={lgr: inj_map} 160 | 161 | return INJ_MAP_DICT -------------------------------------------------------------------------------- /data_config/config_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | def return_OUTPUT_DICT(meta_data, case_name): 6 | nt = list(meta_data[case_name]['data'].keys()) 7 | OUT = {} 8 | GRID_IDX_DICT = meta_data[case_name]['GRID_IDX_DICT'] 9 | WELL_LIST = meta_data[case_name]['WELL_LIST'] 10 | LGR_LIST = meta_data[case_name]['LGR_LIST'] 11 | 12 | for name in [ 'BGSAT', 'BPR']: 13 | out = {} 14 | lname = f'L{name}' 15 | for t in nt: 16 | data = meta_data[case_name]['data'][t] 17 | output_dict = {} 18 | 19 | output_dict['GLOBAL'] = data[name].reshape((-1, 
GRID_IDX_DICT['GLOBAL']['NX'], 20 | GRID_IDX_DICT['GLOBAL']['NY'], 21 | GRID_IDX_DICT['GLOBAL']['NZ'])) 22 | N_LIST = [0] 23 | idx = 0 24 | for well in WELL_LIST: 25 | for lgr in LGR_LIST: 26 | n_prev = N_LIST[idx] 27 | idx += 1 28 | n_cur = n_prev+GRID_IDX_DICT[well][lgr]['NX'] * GRID_IDX_DICT[well][lgr]['NY'] * GRID_IDX_DICT[well][lgr]['NZ'] 29 | N_LIST.append(n_cur) 30 | 31 | if well in output_dict: 32 | output_dict[well].update({lgr: data[lname][:,n_prev: n_cur].reshape(-1, 33 | GRID_IDX_DICT[well][lgr]['NX'], 34 | GRID_IDX_DICT[well][lgr]['NY'], 35 | GRID_IDX_DICT[well][lgr]['NZ']) }) 36 | else: 37 | output_dict[well] = {lgr: data[lname][:,n_prev: n_cur].reshape(-1, 38 | GRID_IDX_DICT[well][lgr]['NX'], 39 | GRID_IDX_DICT[well][lgr]['NY'], 40 | GRID_IDX_DICT[well][lgr]['NZ']) } 41 | out[t] = output_dict 42 | OUT[name] = out 43 | 44 | out = {} 45 | for t in nt: 46 | output_dict = {} 47 | output_dict['GLOBAL'] = OUT['BPR'][t]['GLOBAL'] - OUT['BPR'][0]['GLOBAL'] 48 | 49 | for well in WELL_LIST: 50 | for lgr in LGR_LIST: 51 | if well in output_dict: 52 | output_dict[well].update({lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]}) 53 | else: 54 | output_dict[well] = {lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]} 55 | out[t] = output_dict 56 | OUT['dP'] = out 57 | 58 | out = {} 59 | for t in nt: 60 | output_dict = {} 61 | output_dict['GLOBAL'] = OUT['dP'][t]['GLOBAL'] > 0.1 62 | 63 | for well in WELL_LIST: 64 | for lgr in LGR_LIST: 65 | if well in output_dict: 66 | output_dict[well].update({lgr: OUT['dP'][t][well][lgr] > 0.1}) 67 | else: 68 | output_dict[well] = {lgr: OUT['dP'][t][well][lgr] > 0.1 } 69 | out[t] = output_dict 70 | OUT['P_influence'] = out 71 | return OUT 72 | 73 | def return_upsample_dict(OUTPUT_DICT, t, name, WELL_LIST, GRID_IDX_DICT): 74 | OUTPUT_UPSAMPLE_DICT = {} 75 | 76 | LGR_BEFORE = ['LGR3', 'LGR2', 'LGR1'] 77 | LGR_AFTER = ['LGR4', 'LGR3', 'LGR2'] 78 | 79 | for well in WELL_LIST: 80 | OUTPUT_UPSAMPLE_DICT[well] = {'LGR4': OUTPUT_DICT[name][t][well]['LGR4']} 81 | for iii in range(3): 82 | lgr_before = LGR_BEFORE[iii] 83 | lgr_after = LGR_AFTER[iii] 84 | 85 | upsampled = np.copy(OUTPUT_DICT[name][t][well][lgr_before][-1,:,:,:]) 86 | nx_new = GRID_IDX_DICT[well][lgr_after]['I2'] - GRID_IDX_DICT[well][lgr_after]['I1'] + 1 87 | ny_new = GRID_IDX_DICT[well][lgr_after]['J2'] - GRID_IDX_DICT[well][lgr_after]['J1'] + 1 88 | nz_new = GRID_IDX_DICT[well][lgr_after]['K2'] - GRID_IDX_DICT[well][lgr_after]['K1'] + 1 89 | 90 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well][lgr_after][-1,:,:,:])[None, None,...], 91 | size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 92 | 93 | upsampled[GRID_IDX_DICT[well][lgr_after]['I1']-1:GRID_IDX_DICT[well][lgr_after]['I2'], 94 | GRID_IDX_DICT[well][lgr_after]['J1']-1:GRID_IDX_DICT[well][lgr_after]['J2'],:] = A 95 | 96 | if well in OUTPUT_UPSAMPLE_DICT: 97 | OUTPUT_UPSAMPLE_DICT[well].update({lgr_before: upsampled[None,...]}) 98 | else: 99 | OUTPUT_UPSAMPLE_DICT[well]={lgr_before: upsampled[None,...]} 100 | 101 | upsampled = np.copy(OUTPUT_DICT[name][t]['GLOBAL'][-1,:,:,:]) 102 | for well in WELL_LIST: 103 | nx_new = GRID_IDX_DICT[well]['LGR1']['I2'] - GRID_IDX_DICT[well]['LGR1']['I1'] + 1 104 | ny_new = GRID_IDX_DICT[well]['LGR1']['J2'] - GRID_IDX_DICT[well]['LGR1']['J1'] + 1 105 | nz_new = GRID_IDX_DICT[well]['LGR1']['K2'] - GRID_IDX_DICT[well]['LGR1']['K1'] + 1 106 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well]['LGR1'][-1,:,:,:])[None, None,...], 107 | 
size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 108 | upsampled[GRID_IDX_DICT[well]['LGR1']['I1']-1:GRID_IDX_DICT[well]['LGR1']['I2'], 109 | GRID_IDX_DICT[well]['LGR1']['J1']-1:GRID_IDX_DICT[well]['LGR1']['J2'],:] = A 110 | OUTPUT_UPSAMPLE_DICT['GLOBAL'] = upsampled 111 | 112 | return OUTPUT_UPSAMPLE_DICT 113 | 114 | def return_upsample_all_time(OUTPUT_DICT, name, WELL_LIST, GRID_IDX_DICT,LGR_LIST): 115 | OUT = {} 116 | OUT['GLOBAL'] = np.zeros((1, 100, 100, 5, 24)) 117 | for well in WELL_LIST: 118 | for lgr in LGR_LIST: 119 | nx, ny, nz = GRID_IDX_DICT[well][lgr]['NX'], GRID_IDX_DICT[well][lgr]['NY'], GRID_IDX_DICT[well][lgr]['NZ'] 120 | if well in OUT: 121 | OUT[well].update({lgr: np.zeros((1,nx,ny,nz,24))}) 122 | else: 123 | OUT[well]={lgr: np.zeros((1,nx,ny,nz,24))} 124 | 125 | for t in range(1,25): 126 | up_sample_dict = return_upsample_dict(OUTPUT_DICT, t, name, WELL_LIST, GRID_IDX_DICT) 127 | OUT['GLOBAL'][0,:,:,:,t-1] = up_sample_dict['GLOBAL'] 128 | 129 | for well in WELL_LIST: 130 | for lgr in LGR_LIST: 131 | OUT[well][lgr][0,:,:,:,t-1] = up_sample_dict[well][lgr] 132 | 133 | return OUT 134 | 135 | def return_inj_map_dict(well_list,rate_dict,inj_loc_dict,center_dict, LGR_LIST): 136 | inj_norm = lambda x: (x)/(2942777.68785957) 137 | 138 | INJ_MAP_DICT = {} 139 | 140 | inj_map = np.zeros(center_dict['GLOBAL']['grid_x'].shape) 141 | for well in well_list: 142 | well_x, well_y = inj_loc_dict[well] 143 | xidx = (np.abs(center_dict['GLOBAL']['grid_x'][:,0,0] - well_x)).argmin() 144 | yidx = (np.abs(center_dict['GLOBAL']['grid_y'][0,:,0] - well_y)).argmin() 145 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 146 | INJ_MAP_DICT['GLOBAL'] = inj_map 147 | 148 | for well in well_list: 149 | well_x, well_y = inj_loc_dict[well] 150 | for lgr in LGR_LIST: 151 | inj_map = np.zeros(center_dict[well][lgr]['grid_x'].shape) 152 | xidx = (np.abs(center_dict[well][lgr]['grid_x'][:,0,0] - well_x)).argmin() 153 | yidx = (np.abs(center_dict[well][lgr]['grid_y'][0,:,0] - well_y)).argmin() 154 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 155 | if well in INJ_MAP_DICT: 156 | INJ_MAP_DICT[well].update({lgr: inj_map}) 157 | else: 158 | INJ_MAP_DICT[well]={lgr: inj_map} 159 | 160 | return INJ_MAP_DICT -------------------------------------------------------------------------------- /data_config/file_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | echo file_config_GLOBAL_dP.py 4 | python3 file_config_GLOBAL_dP.py 5 | echo file_config_LGR1_dP.py 6 | python3 file_config_LGR1_dP.py 7 | echo file_config_LGR1_SG.py 8 | python3 file_config_LGR1_SG.py 9 | echo file_config_LGR2_dP.py 10 | python3 file_config_LGR2_dP.py 11 | echo file_config_LGR2_SG.py 12 | python3 file_config_LGR2_SG.py 13 | echo file_config_LGR3_dP.py 14 | python3 file_config_LGR3_dP.py 15 | echo file_config_LGR3_SG.py 16 | python3 file_config_LGR3_SG.py 17 | echo file_config_LGR4_dP.py 18 | python3 file_config_LGR4_dP.py 19 | echo file_config_LGR4_SG.py 20 | python3 file_config_LGR4_SG.py -------------------------------------------------------------------------------- /data_config/file_config_GLOBAL_dP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 
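# Overview: this script assembles the coarse GLOBAL-level pressure-buildup (dP) training
# samples. For each ECLIPSE meta-data .npy case that has not been processed yet, it builds
# an 8-channel input (grid x, grid y, grid z, time, injection map, initial pressure,
# temperature, permeability) on the 100 x 100 x 5 global grid over 24 time steps and saves
# the input/output pair as a .pt file under ../dataset/dP_GLOBAL/.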
11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | META_PATH = '..' 22 | meta_files = glob.glob(f'{META_PATH}/ECLIPSE/meta_data/*.npy') 23 | print('meta data:', len(meta_files)) 24 | 25 | PT_PATH = f'../dataset/dP_GLOBAL/' 26 | if not os.path.exists(PT_PATH): 27 | os.mkdir(PT_PATH) 28 | 29 | pt_files = os.listdir(PT_PATH) 30 | print('done collected:', len(pt_files)) 31 | 32 | collect_index = [] 33 | for name in meta_files: 34 | name = name.split('/')[-1][:-4] 35 | if f'{name}_GLOBAL_DP.pt' not in pt_files: 36 | collect_index.append([name[:7], int(name.split('_')[2])]) 37 | print('to collect', len(collect_index)) 38 | 39 | NX, NY, NZ, NT = 100, 100, 5, 24 40 | 41 | for tup in collect_index: 42 | case_path, idx = tup 43 | case_name = f'case_{idx}' 44 | meta_data = np.load(f'{META_PATH}/ECLIPSE/meta_data/{case_path}_{idx}.npy', 45 | allow_pickle=True).tolist() 46 | 47 | for k, v in meta_data[case_name].items(): 48 | globals()[k]=v 49 | 50 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 51 | 52 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 53 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT,LGR_LIST) 54 | 55 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT['GLOBAL']['grid_x'])[...,None,None], 24, axis=-2) 56 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT['GLOBAL']['grid_y'])[...,None,None], 24, axis=-2) 57 | gridz = np.repeat(z_norm(TOPS_DICT['GLOBAL'][0,...,None,None]), 24, axis=-2) 58 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 59 | 60 | inj = np.repeat(INJ_MAP_DICT['GLOBAL'][...,None,None], 24, axis=-2) 61 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 62 | WELL_LIST, GRID_IDX_DICT)['GLOBAL'][...,None,None]), 63 | 24, axis=-2) 64 | temp = t_norm(t) * np.ones(inj.shape) 65 | perm = np.repeat(k_norm(PERM_DICT['GLOBAL'])[0,...,None,None], 24, axis=-2) 66 | 67 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, GRID_IDX_DICT, LGR_LIST) 68 | 69 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm], 70 | axis=-1)[None,...] 
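    # Shape note: each feature above should be (100, 100, 5, 24, 1), so the channel-wise
    # concatenation gives x_DP of shape (1, 100, 100, 5, 24, 8); the matching target y_DP
    # built below carries a single dP channel, i.e. (1, 100, 100, 5, 24, 1).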
71 | y_DP = DICT['GLOBAL'][...,None] 72 | 73 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 74 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 75 | 76 | data = {} 77 | data['input'] = x_DP 78 | data['output'] = y_DP 79 | 80 | torch.save(data, f'../dataset/dP_GLOBAL/{case_path}_{idx}_GLOBAL_DP.pt') 81 | print(f'{case_path}_{idx}_GLOBAL_DP done') -------------------------------------------------------------------------------- /data_config/file_config_LGR1_SG.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | 13 | xy_norm = lambda x: (x)/160000 14 | z_norm = lambda x: (x-2000)/2000 15 | p_norm = lambda x: (x)/172 16 | t_norm = lambda x: (x)/70 17 | k_norm = lambda x: (x)/100 18 | 19 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 20 | times = times/ 10950 21 | 22 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 23 | pt_files = os.listdir(PT_GLOBAL_PATH) 24 | print('done collected:', len(pt_files)) 25 | 26 | GLOBAL_names = [] 27 | for file in pt_files: 28 | l = file.split('_') 29 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 30 | print(len(GLOBAL_names)) 31 | 32 | # find reservoirs that has not been collected 33 | path = f'../dataset/SG_LGR1/' 34 | if not os.path.exists(path): 35 | os.mkdir(path) 36 | files = os.listdir(path) 37 | 38 | collected_names = [] 39 | for file in files: 40 | l = file.split('_') 41 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 42 | print(len(collected_names)) 43 | 44 | to_load_names = [] 45 | for elem in GLOBAL_names: 46 | if elem not in collected_names: 47 | to_load_names.append(elem) 48 | 49 | print(len(to_load_names)) 50 | 51 | NX, NY, NZ, NT = 40, 40, 25, 24 52 | ROOT_PATH = '..' 
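# Overview: the LGR1 gas-saturation (SG) samples sit one level down the nested-grid
# hierarchy. In the loop below, the ninth input channel ("coarse") is the GLOBAL dP field
# on the 40 x 40 window around each well (I1-1-15 : I2+15, J1-1-15 : J2+15), repeated 5x
# along z to reach the 40 x 40 x 25 LGR1 grid; the target is the LGR1 saturation field
# taken from return_upsample_all_time(..., 'BGSAT', ...).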
53 | 54 | for names in to_load_names: 55 | slope_name, idx = names 56 | case_name = f'case_{idx}' 57 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 58 | 59 | for k, v in meta_data[case_name].items(): 60 | globals()[k]=v 61 | 62 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 63 | 64 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 65 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 66 | print(idx) 67 | 68 | for well in WELL_LIST: 69 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_x'])[...,None,None], 24, axis=-2) 70 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_y'])[...,None,None], 24, axis=-2) 71 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR1'][0,...,None,None]), 24, axis=-2) 72 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 73 | 74 | inj = np.repeat(INJ_MAP_DICT[well]['LGR1'][...,None,None], 24, axis=-2) 75 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 76 | WELL_LIST, GRID_IDX_DICT)[well]['LGR1'][0,...,None,None]), 24, axis=-2) 77 | temp = t_norm(t) * np.ones(inj.shape) 78 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR1'])[0,...,None,None], 24, axis=-2) 79 | 80 | 81 | I1, I2 = GRID_IDX_DICT[well]['LGR1']['I1']-1-15, GRID_IDX_DICT[well]['LGR1']['I2']+15 82 | J1, J2 = GRID_IDX_DICT[well]['LGR1']['J1']-1-15, GRID_IDX_DICT[well]['LGR1']['J2']+15 83 | 84 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, 85 | GRID_IDX_DICT, LGR_LIST) 86 | coarse = np.repeat(DICT['GLOBAL'][0,I1:I2,J1:J2,:,:,None],5,axis=2) 87 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 88 | 89 | DICT = return_upsample_all_time(OUTPUT_DICT, 'BGSAT', WELL_LIST, GRID_IDX_DICT, LGR_LIST) 90 | y_DP = DICT[well]['LGR1'][...,None] 91 | 92 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 93 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 94 | 95 | data = {} 96 | data['input'] = x_DP 97 | data['output'] = y_DP 98 | print(f'{slope_name}_{idx}_LGR1_{well}_SG.pt') 99 | 100 | torch.save(data, f'../dataset/SG_LGR1/{slope_name}_{idx}_LGR1_{well}_SG.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR1_dP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 22 | pt_files = os.listdir(PT_GLOBAL_PATH) 23 | print('done collected:', len(pt_files)) 24 | 25 | GLOBAL_names = [] 26 | for file in pt_files: 27 | l = file.split('_') 28 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 29 | print(len(GLOBAL_names)) 30 | 31 | # find reservoirs that has not been collected 32 | path = f'../dataset/dP_LGR1/' 33 | if not os.path.exists(path): 34 | os.mkdir(path) 35 | files = os.listdir(path) 36 | 37 | collected_names = [] 38 | for file in files: 39 | l = file.split('_') 40 | 
collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 41 | print(len(collected_names)) 42 | 43 | to_load_names = [] 44 | for elem in GLOBAL_names: 45 | if elem not in collected_names: 46 | to_load_names.append(elem) 47 | 48 | print(len(to_load_names)) 49 | 50 | NX, NY, NZ, NT = 40, 40, 25, 24 51 | ROOT_PATH = '..' 52 | 53 | for names in to_load_names: 54 | slope_name, idx = names 55 | case_name = f'case_{idx}' 56 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 57 | 58 | for k, v in meta_data[case_name].items(): 59 | globals()[k]=v 60 | 61 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 62 | 63 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 64 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 65 | print(idx) 66 | 67 | for well in WELL_LIST: 68 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_x'])[...,None,None], 24, axis=-2) 69 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_y'])[...,None,None], 24, axis=-2) 70 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR1'][0,...,None,None]), 24, axis=-2) 71 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 72 | 73 | inj = np.repeat(INJ_MAP_DICT[well]['LGR1'][...,None,None], 24, axis=-2) 74 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 75 | WELL_LIST, GRID_IDX_DICT)[well]['LGR1'][0,...,None,None]), 24, axis=-2) 76 | temp = t_norm(t) * np.ones(inj.shape) 77 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR1'])[0,...,None,None], 24, axis=-2) 78 | 79 | 80 | I1, I2 = GRID_IDX_DICT[well]['LGR1']['I1']-1-15, GRID_IDX_DICT[well]['LGR1']['I2']+15 81 | J1, J2 = GRID_IDX_DICT[well]['LGR1']['J1']-1-15, GRID_IDX_DICT[well]['LGR1']['J2']+15 82 | 83 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, GRID_IDX_DICT,LGR_LIST) 84 | 85 | coarse = np.repeat(DICT['GLOBAL'][0,I1:I2,J1:J2,:,:,None],5,axis=2) 86 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 
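        # x_DP now has nine channels (the eight GLOBAL-style features plus the coarse dP
        # field), which matches the first layer of the LGR-level FNO4d models shown in
        # finetune_FNO4D_DP_LGR.ipynb (fc0: Linear(in_features=9, out_features=28)).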
87 | y_DP = DICT[well]['LGR1'][...,None] 88 | 89 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 90 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 91 | 92 | data = {} 93 | data['input'] = x_DP 94 | data['output'] = y_DP 95 | print(f'{slope_name}_{idx}_LGR1_{well}_DP.pt') 96 | torch.save(data, f'../dataset/dP_LGR1/{slope_name}_{idx}_LGR1_{well}_DP.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR2_SG.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 22 | pt_files = os.listdir(PT_GLOBAL_PATH) 23 | print('done collected:', len(pt_files)) 24 | 25 | GLOBAL_names = [] 26 | for file in pt_files: 27 | l = file.split('_') 28 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 29 | print(len(GLOBAL_names)) 30 | 31 | # find reservoirs that has not been collected 32 | path = f'../dataset/SG_LGR2/' 33 | if not os.path.exists(path): 34 | os.mkdir(path) 35 | files = os.listdir(path) 36 | 37 | 38 | collected_names = [] 39 | for file in files: 40 | l = file.split('_') 41 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 42 | print(len(collected_names)) 43 | 44 | to_load_names = [] 45 | for elem in GLOBAL_names: 46 | if elem not in collected_names: 47 | to_load_names.append(elem) 48 | 49 | print(len(to_load_names)) 50 | 51 | 52 | NX, NY, NZ, NT = 40, 40, 50, 24 53 | ROOT_PATH = '..' 
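# Overview: at the LGR2 level the array shape becomes 40 x 40 x 50, so the coarse channel
# below is the full LGR1 saturation field (same 40 x 40 lateral shape) repeated 2x along z
# to match NZ = 50; the target is the LGR2 saturation field.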
54 | 55 | for names in to_load_names: 56 | slope_name, idx = names 57 | case_name = f'case_{idx}' 58 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 59 | 60 | for k, v in meta_data[case_name].items(): 61 | globals()[k]=v 62 | 63 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 64 | 65 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 66 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 67 | print(idx) 68 | 69 | for well in WELL_LIST: 70 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR2']['grid_x'])[...,None,None], 24, axis=-2) 71 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR2']['grid_y'])[...,None,None], 24, axis=-2) 72 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR2'][0,...,None,None]), 24, axis=-2) 73 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 74 | 75 | inj = np.repeat(INJ_MAP_DICT[well]['LGR2'][...,None,None], 24, axis=-2) 76 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 77 | WELL_LIST, GRID_IDX_DICT)[well]['LGR2'][0,...,None,None]), 24, axis=-2) 78 | temp = t_norm(t) * np.ones(inj.shape) 79 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR2'])[0,...,None,None], 24, axis=-2) 80 | 81 | 82 | DICT = return_upsample_all_time(OUTPUT_DICT, 'BGSAT', WELL_LIST, 83 | GRID_IDX_DICT, LGR_LIST) 84 | 85 | coarse = np.repeat(DICT[well]['LGR1'][0,:,:,:,:,None],2,axis=2) 86 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 87 | y_DP = DICT[well]['LGR2'][...,None] 88 | 89 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 90 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 91 | 92 | data = {} 93 | data['input'] = x_DP 94 | data['output'] = y_DP 95 | print(f'{slope_name}_{idx}_LGR2_{well}_SG.pt') 96 | 97 | torch.save(data, f'../dataset/SG_LGR2/{slope_name}_{idx}_LGR2_{well}_SG.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR2_dP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | 22 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 23 | pt_files = os.listdir(PT_GLOBAL_PATH) 24 | print('done collected:', len(pt_files)) 25 | 26 | GLOBAL_names = [] 27 | for file in pt_files: 28 | l = file.split('_') 29 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 30 | print(len(GLOBAL_names)) 31 | 32 | # find reservoirs that has not been collected 33 | path = f'../dataset/dP_LGR2/' 34 | if not os.path.exists(path): 35 | os.mkdir(path) 36 | files = os.listdir(path) 37 | 38 | collected_names = [] 39 | for file in files: 40 | l = file.split('_') 41 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 42 | print(len(collected_names)) 43 | 44 | to_load_names = [] 45 | for elem in GLOBAL_names: 46 | if elem not in collected_names: 47 | to_load_names.append(elem) 48 | 49 | print(len(to_load_names)) 50 | 51 | ROOT_PATH = '..' 
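# Note: the dP and SG scripts at each level are structurally identical; only the field used
# for the coarse/target tensors ('dP' vs 'BGSAT') and the output directory change. Files are
# saved as {slope_name}_{idx}_LGR2_{well}_DP.pt, e.g. (illustrative name only)
# ../dataset/dP_LGR2/slope1_0_LGR2_well1_DP.pt.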
52 | NX, NY, NZ, NT = 40, 40, 50, 24 53 | 54 | for names in to_load_names: 55 | slope_name, idx = names 56 | case_name = f'case_{idx}' 57 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 58 | 59 | for k, v in meta_data[case_name].items(): 60 | globals()[k]=v 61 | 62 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 63 | 64 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 65 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 66 | print(idx) 67 | 68 | for well in WELL_LIST: 69 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR2']['grid_x'])[...,None,None], 24, axis=-2) 70 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR2']['grid_y'])[...,None,None], 24, axis=-2) 71 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR2'][0,...,None,None]), 24, axis=-2) 72 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 73 | 74 | inj = np.repeat(INJ_MAP_DICT[well]['LGR2'][...,None,None], 24, axis=-2) 75 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 76 | WELL_LIST, GRID_IDX_DICT)[well]['LGR2'][0,...,None,None]), 24, axis=-2) 77 | temp = t_norm(t) * np.ones(inj.shape) 78 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR2'])[0,...,None,None], 24, axis=-2) 79 | 80 | 81 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, 82 | GRID_IDX_DICT, LGR_LIST) 83 | 84 | coarse = np.repeat(DICT[well]['LGR1'][0,:,:,:,:,None],2,axis=2) 85 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 86 | y_DP = DICT[well]['LGR2'][...,None] 87 | 88 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 89 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 90 | 91 | data = {} 92 | data['input'] = x_DP 93 | data['output'] = y_DP 94 | print(f'{slope_name}_{idx}_LGR2_{well}_DP.pt') 95 | 96 | torch.save(data, f'../dataset/dP_LGR2/{slope_name}_{idx}_LGR2_{well}_DP.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR3_SG.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 22 | pt_files = os.listdir(PT_GLOBAL_PATH) 23 | print('done collected:', len(pt_files)) 24 | 25 | GLOBAL_names = [] 26 | for file in pt_files: 27 | l = file.split('_') 28 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 29 | print(len(GLOBAL_names)) 30 | 31 | # find reservoirs that has not been collected 32 | path = f'../dataset/SG_LGR3/' 33 | if not os.path.exists(path): 34 | os.mkdir(path) 35 | files = os.listdir(path) 36 | 37 | 38 | collected_names = [] 39 | for file in files: 40 | l = file.split('_') 41 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 42 | print(len(collected_names)) 43 | 44 | to_load_names = [] 45 | for elem in GLOBAL_names: 46 | if elem not in collected_names: 47 | to_load_names.append(elem) 48 | 49 | print(len(to_load_names)) 50 | 
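# Overview: from LGR3 onward the array shape stays 40 x 40 x 50 x 24, so the coarse channel
# below is the LGR2 saturation field attached directly, without any repetition along z.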
51 | NX, NY, NZ, NT = 40, 40, 50, 24 52 | ROOT_PATH = '..' 53 | 54 | for names in to_load_names: 55 | slope_name, idx = names 56 | case_name = f'case_{idx}' 57 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 58 | 59 | for k, v in meta_data[case_name].items(): 60 | globals()[k]=v 61 | 62 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 63 | 64 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 65 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 66 | print(idx) 67 | 68 | for well in WELL_LIST: 69 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR3']['grid_x'])[...,None,None], 24, axis=-2) 70 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR3']['grid_y'])[...,None,None], 24, axis=-2) 71 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR3'][0,...,None,None]), 24, axis=-2) 72 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 73 | 74 | inj = np.repeat(INJ_MAP_DICT[well]['LGR3'][...,None,None], 24, axis=-2) 75 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 76 | WELL_LIST, GRID_IDX_DICT)[well]['LGR3'][0,...,None,None]), 24, axis=-2) 77 | temp = t_norm(t) * np.ones(inj.shape) 78 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR3'])[0,...,None,None], 24, axis=-2) 79 | 80 | 81 | DICT = return_upsample_all_time(OUTPUT_DICT, 'BGSAT', WELL_LIST, 82 | GRID_IDX_DICT, LGR_LIST) 83 | 84 | coarse = DICT[well]['LGR2'][0,:,:,:,:,None] 85 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 86 | y_DP = DICT[well]['LGR3'][...,None] 87 | 88 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 89 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 90 | 91 | data = {} 92 | data['input'] = x_DP 93 | data['output'] = y_DP 94 | print(f'{slope_name}_{idx}_LGR3_{well}_SG.pt') 95 | 96 | torch.save(data, f'../dataset/SG_LGR3/{slope_name}_{idx}_LGR3_{well}_SG.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR3_dP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | xy_norm = lambda x: (x)/160000 13 | z_norm = lambda x: (x-2000)/2000 14 | p_norm = lambda x: (x)/172 15 | t_norm = lambda x: (x)/70 16 | k_norm = lambda x: (x)/100 17 | 18 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 19 | times = times/ 10950 20 | 21 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 22 | pt_files = os.listdir(PT_GLOBAL_PATH) 23 | print('done collected:', len(pt_files)) 24 | 25 | GLOBAL_names = [] 26 | for file in pt_files: 27 | l = file.split('_') 28 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 29 | print(len(GLOBAL_names)) 30 | 31 | # find reservoirs that has not been collected 32 | path = f'../dataset/dP_LGR3/' 33 | if not os.path.exists(path): 34 | os.mkdir(path) 35 | files = os.listdir(path) 36 | 37 | collected_names = [] 38 | for file in files: 39 | l = file.split('_') 40 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 41 | print(len(collected_names)) 42 | 43 | to_load_names = [] 44 | for elem in GLOBAL_names: 45 | if elem not in collected_names: 46 | to_load_names.append(elem) 47 | 48 | print(len(to_load_names)) 49 | 
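# Usage sketch (illustrative only, not part of the generation loop): once a case has been
# written, the saved tensors can be spot-checked as below; the file name is a placeholder.
#   d = torch.load('../dataset/dP_LGR3/<slope_name>_<idx>_LGR3_<well>_DP.pt')
#   print(d['input'].shape)    # expected: torch.Size([1, 40, 40, 50, 24, 9])
#   print(d['output'].shape)   # expected: torch.Size([1, 40, 40, 50, 24, 1])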
50 | NX, NY, NZ, NT = 40, 40, 50, 24 51 | ROOT_PATH = '..' 52 | 53 | for names in to_load_names: 54 | slope_name, idx = names 55 | case_name = f'case_{idx}' 56 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 57 | 58 | for k, v in meta_data[case_name].items(): 59 | globals()[k]=v 60 | 61 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 62 | 63 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 64 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 65 | print(idx) 66 | 67 | for well in WELL_LIST: 68 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR3']['grid_x'])[...,None,None], 24, axis=-2) 69 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR3']['grid_y'])[...,None,None], 24, axis=-2) 70 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR3'][0,...,None,None]), 24, axis=-2) 71 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 72 | 73 | inj = np.repeat(INJ_MAP_DICT[well]['LGR3'][...,None,None], 24, axis=-2) 74 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 75 | WELL_LIST, GRID_IDX_DICT)[well]['LGR3'][0,...,None,None]), 24, axis=-2) 76 | temp = t_norm(t) * np.ones(inj.shape) 77 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR3'])[0,...,None,None], 24, axis=-2) 78 | 79 | 80 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, 81 | GRID_IDX_DICT, LGR_LIST) 82 | 83 | coarse = DICT[well]['LGR2'][0,:,:,:,:,None] 84 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 85 | y_DP = DICT[well]['LGR3'][...,None] 86 | 87 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 88 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 89 | 90 | data = {} 91 | data['input'] = x_DP 92 | data['output'] = y_DP 93 | print(f'{slope_name}_{idx}_LGR3_{well}_DP.pt') 94 | 95 | torch.save(data, f'../dataset/dP_LGR3/{slope_name}_{idx}_LGR3_{well}_DP.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR4_SG.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | import json 12 | 13 | xy_norm = lambda x: (x)/160000 14 | z_norm = lambda x: (x-2000)/2000 15 | p_norm = lambda x: (x)/172 16 | t_norm = lambda x: (x)/70 17 | k_norm = lambda x: (x)/100 18 | 19 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 20 | times = times/ 10950 21 | 22 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 23 | pt_files = os.listdir(PT_GLOBAL_PATH) 24 | print('done collected:', len(pt_files)) 25 | 26 | GLOBAL_names = [] 27 | for file in pt_files: 28 | l = file.split('_') 29 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 30 | print(len(GLOBAL_names)) 31 | 32 | # find reservoirs that has not been collected 33 | path = f'../dataset/SG_LGR4/' 34 | if not os.path.exists(path): 35 | os.mkdir(path) 36 | files = os.listdir(path) 37 | 38 | 39 | 40 | collected_names = [] 41 | for file in files: 42 | l = file.split('_') 43 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 44 | print(len(collected_names)) 45 | 46 | to_load_names = [] 47 | for elem in GLOBAL_names: 48 | if elem not in collected_names: 49 | to_load_names.append(elem) 50 | 51 | 
print(len(to_load_names)) 52 | 53 | 54 | f = open("PERF_DICT.json","r") 55 | PERF_DICT = json.load(f) 56 | f.close() 57 | perf_names = list(PERF_DICT.keys()) 58 | 59 | 60 | NX, NY, NZ, NT = 40, 40, 50, 24 61 | ROOT_PATH = '..' 62 | 63 | for names in to_load_names: 64 | slope_name, idx = names 65 | case_name = f'case_{idx}' 66 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', allow_pickle=True).tolist() 67 | 68 | for k, v in meta_data[case_name].items(): 69 | globals()[k]=v 70 | 71 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 72 | 73 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 74 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 75 | print(idx) 76 | 77 | for well in WELL_LIST: 78 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR4']['grid_x'])[...,None,None], 24, axis=-2) 79 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR4']['grid_y'])[...,None,None], 24, axis=-2) 80 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR4'][0,...,None,None]), 24, axis=-2) 81 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 82 | 83 | inj = np.repeat(INJ_MAP_DICT[well]['LGR4'][...,None,None], 24, axis=-2) 84 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 85 | WELL_LIST, GRID_IDX_DICT)[well]['LGR4'][0,...,None,None]), 24, axis=-2) 86 | temp = t_norm(t) * np.ones(inj.shape) 87 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR4'])[0,...,None,None], 24, axis=-2) 88 | 89 | 90 | DICT = return_upsample_all_time(OUTPUT_DICT, 'BGSAT', WELL_LIST, GRID_IDX_DICT, LGR_LIST) 91 | 92 | coarse = DICT[well]['LGR3'][0,:,:,:,:,None] 93 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 
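        # Channel index 4 of x_DP is the injection map; for LGR4 the block a few lines below
        # masks it at grid index [19, 19] (the well column) so that only the perforated
        # layers listed in PERF_DICT keep a non-zero rate.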
94 | y_DP = DICT[well]['LGR4'][...,None] 95 | 96 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 97 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 98 | 99 | perf = PERF_DICT[f'{slope_name}_{idx}'][f'INJ{well[-1]}'] 100 | old_inj = torch.clone(x_DP[0,19,19,:,:,4]) 101 | new_inj = torch.zeros(old_inj.shape) 102 | new_inj[perf[0]-1:perf[1],:] = old_inj[perf[0]-1:perf[1],:] 103 | x_DP[0,19,19,:,:,4] = new_inj 104 | 105 | data = {} 106 | data['input'] = x_DP 107 | data['output'] = y_DP 108 | print(f'{slope_name}_{idx}_LGR4_{well}_SG.pt') 109 | torch.save(data, f'../dataset/SG_LGR4/{slope_name}_{idx}_LGR4_{well}_SG.pt') -------------------------------------------------------------------------------- /data_config/file_config_LGR4_dP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | import glob 6 | import os 7 | from config_utility import * 8 | 9 | import torch.nn.functional as F 10 | import torch 11 | import json 12 | 13 | xy_norm = lambda x: (x)/160000 14 | z_norm = lambda x: (x-2000)/2000 15 | p_norm = lambda x: (x)/172 16 | t_norm = lambda x: (x)/70 17 | k_norm = lambda x: (x)/100 18 | 19 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 20 | times = times/ 10950 21 | 22 | 23 | PT_GLOBAL_PATH = f'../dataset/dP_GLOBAL/' 24 | pt_files = os.listdir(PT_GLOBAL_PATH) 25 | print('done collected:', len(pt_files)) 26 | 27 | GLOBAL_names = [] 28 | for file in pt_files: 29 | l = file.split('_') 30 | GLOBAL_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 31 | print(len(GLOBAL_names)) 32 | 33 | 34 | # find reservoirs that has not been collected 35 | path = f'../dataset/dP_LGR4/' 36 | if not os.path.exists(path): 37 | os.mkdir(path) 38 | files = os.listdir(path) 39 | 40 | 41 | collected_names = [] 42 | for file in files: 43 | l = file.split('_') 44 | collected_names.append((f'{l[0]}_{l[1]}', int(l[2]))) 45 | print(len(collected_names)) 46 | 47 | to_load_names = [] 48 | for elem in GLOBAL_names: 49 | if elem not in collected_names: 50 | to_load_names.append(elem) 51 | 52 | print(len(to_load_names)) 53 | 54 | f = open("PERF_DICT.json","r") 55 | PERF_DICT = json.load(f) 56 | f.close() 57 | perf_names = list(PERF_DICT.keys()) 58 | 59 | NX, NY, NZ, NT = 40, 40, 50, 24 60 | ROOT_PATH = '..' 
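# PERF_DICT.json maps each case to the perforated layer range of every injector. From the
# lookups below (PERF_DICT[f'{slope_name}_{idx}'][f'INJ{well[-1]}'] and
# new_inj[perf[0]-1:perf[1],:]), each entry appears to be a [top_layer, bottom_layer] pair
# in 1-based layer indices, e.g. (assumed structure, illustrative values only):
#   { "slope1_0": { "INJ1": [10, 35], "INJ2": [5, 48] } }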
61 | 62 | for names in to_load_names: 63 | slope_name, idx = names 64 | case_name = f'case_{idx}' 65 | meta_data = np.load(f'{ROOT_PATH}/ECLIPSE/meta_data/{slope_name}_{idx}.npy', 66 | allow_pickle=True).tolist() 67 | 68 | for k, v in meta_data[case_name].items(): 69 | globals()[k]=v 70 | 71 | OUTPUT_DICT = return_OUTPUT_DICT(meta_data, case_name) 72 | 73 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 74 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT, LGR_LIST) 75 | print(idx) 76 | 77 | for well in WELL_LIST: 78 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR4']['grid_x'])[...,None,None], 24, axis=-2) 79 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR4']['grid_y'])[...,None,None], 24, axis=-2) 80 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR4'][0,...,None,None]), 24, axis=-2) 81 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 82 | 83 | inj = np.repeat(INJ_MAP_DICT[well]['LGR4'][...,None,None], 24, axis=-2) 84 | pressure = np.repeat(p_norm(return_upsample_dict(OUTPUT_DICT, 0, 'BPR', 85 | WELL_LIST, GRID_IDX_DICT)[well]['LGR4'][0,...,None,None]), 24, axis=-2) 86 | temp = t_norm(t) * np.ones(inj.shape) 87 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR4'])[0,...,None,None], 24, axis=-2) 88 | 89 | 90 | DICT = return_upsample_all_time(OUTPUT_DICT, 'dP', WELL_LIST, GRID_IDX_DICT, LGR_LIST) 91 | 92 | coarse = DICT[well]['LGR3'][0,:,:,:,:,None] 93 | x_DP = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 94 | y_DP = DICT[well]['LGR4'][...,None] 95 | 96 | x_DP = torch.from_numpy(x_DP.astype(np.float32)) 97 | y_DP = torch.from_numpy(y_DP.astype(np.float32)) 98 | 99 | perf = PERF_DICT[f'{slope_name}_{idx}'][f'INJ{well[-1]}'] 100 | old_inj = torch.clone(x_DP[0,19,19,:,:,4]) 101 | new_inj = torch.zeros(old_inj.shape) 102 | new_inj[perf[0]-1:perf[1],:] = old_inj[perf[0]-1:perf[1],:] 103 | x_DP[0,19,19,:,:,4] = new_inj 104 | 105 | data = {} 106 | data['input'] = x_DP 107 | data['output'] = y_DP 108 | print(f'{slope_name}_{idx}_LGR4_{well}_DP.pt') 109 | 110 | torch.save(data, f'../dataset/dP_LGR4/{slope_name}_{idx}_LGR4_{well}_DP.pt') -------------------------------------------------------------------------------- /finetune_FNO4D_DP_LGR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4cba1655", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "
" 13 | ] 14 | }, 15 | "metadata": {}, 16 | "output_type": "display_data" 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import matplotlib as mpl\n", 23 | "import matplotlib.patches as patches\n", 24 | "import glob\n", 25 | "import os\n", 26 | "from utility import *\n", 27 | "from UnitGaussianNormalizer import *\n", 28 | "import torch.nn.functional as F\n", 29 | "import torch\n", 30 | "from CustomDataset import *\n", 31 | "import pickle\n", 32 | "from FNO4D import *\n", 33 | "import json\n", 34 | "from visulization_compare import *\n", 35 | "plt.jet()" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "6a13b872", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "dict_keys(['input', 'output'])\n", 49 | "dict_keys(['GLOBAL', 'LGR1', 'LGR2', 'LGR3', 'LGR4']) dict_keys(['GLOBAL', 'LGR1', 'LGR2', 'LGR3', 'LGR4'])\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "NORMALIZER_DICT = {}\n", 55 | "d_in, d_out = {}, {}\n", 56 | "for key in ['GLOBAL', 'LGR1', 'LGR2', 'LGR3', 'LGR4']:\n", 57 | " with open(f\"normalizer/input_normalizer_{key}_DP_val.pickle\", 'rb') as f:\n", 58 | " input_normalizer = pickle.load(f)\n", 59 | " with open(f\"normalizer/output_normalizer_{key}_DP_val.pickle\", 'rb') as f:\n", 60 | " output_normalizer = pickle.load(f)\n", 61 | " input_normalizer.cuda()\n", 62 | " output_normalizer.cuda()\n", 63 | " d_in[key] = input_normalizer\n", 64 | " d_out[key] = output_normalizer\n", 65 | "NORMALIZER_DICT['input'] = d_in\n", 66 | "NORMALIZER_DICT['output'] = d_out\n", 67 | "\n", 68 | "print(NORMALIZER_DICT.keys())\n", 69 | "print(NORMALIZER_DICT['input'].keys(), NORMALIZER_DICT['output'].keys())" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "id": "24678fef", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "dict_keys(['GLOBAL', 'LGR1', 'LGR2', 'LGR3', 'LGR4'])\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "# Change these models to your trained weights\n", 88 | "PATH = {}\n", 89 | "PATH['GLOBAL'] = \"pre_trained_models/FNO4D-GLOBAL-DP.pt\"\n", 90 | "PATH['LGR1'] = \"pre_trained_models/FNO4D-LGR1-DP.pt\"\n", 91 | "PATH['LGR2'] = \"pre_trained_models/FNO4D-LGR2-DP.pt\"\n", 92 | "PATH['LGR3'] = \"pre_trained_models/FNO4D-LGR3-DP.pt\"\n", 93 | "PATH['LGR4'] = \"pre_trained_models/FNO4D-LGR4-DP.pt\"\n", 94 | "device = torch.device('cuda')\n", 95 | "\n", 96 | "MODEL_DICT = {}\n", 97 | "for key in ['GLOBAL', 'LGR1', 'LGR2', 'LGR3', 'LGR4']:\n", 98 | " model = torch.load(PATH[key])\n", 99 | " model.to(device)\n", 100 | " model.eval()\n", 101 | " MODEL_DICT[key] = model\n", 102 | " \n", 103 | "print(MODEL_DICT.keys())" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "id": "4d667b14", 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "16 2 2\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "DATA_LOADER_DICT = torch.load('DATA_LOADER_DICT.pth')\n", 122 | "train_loader = DATA_LOADER_DICT['GLOBAL']['train']\n", 123 | "val_loader = DATA_LOADER_DICT['GLOBAL']['val']\n", 124 | "test_loader = DATA_LOADER_DICT['GLOBAL']['test']\n", 125 | "n_train = len(train_loader)\n", 126 | "n_val = len(val_loader)\n", 127 | "n_test = len(test_loader)\n", 128 | "print(n_train, n_val, n_test)" 129 | ] 130 | }, 131 | { 132 | "cell_type": 
"code", 133 | "execution_count": 31, 134 | "id": "2334ea1a", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "it = iter(train_loader)\n", 139 | "ERR_LIST = []\n", 140 | "\n", 141 | "for counter in range(len(train_loader)):\n", 142 | " with torch.no_grad():\n", 143 | " PRED, TRUE = {}, {}\n", 144 | " data = next(it)\n", 145 | " x, y, path = data['x'], data['y'], data['path']\n", 146 | " x, y = x[None,...].to(device), y[None,...]\n", 147 | " x[...,-1:] = NORMALIZER_DICT['input']['GLOBAL'].encode(x.to(device)[...,-1:])\n", 148 | " pred = NORMALIZER_DICT['output']['GLOBAL'].decode(MODEL_DICT['GLOBAL'](x)).cpu()\n", 149 | " PRED['GLOBAL'] = pred\n", 150 | " TRUE['GLOBAL'] = y\n", 151 | " slope, idx, well = path[0], path[1], path[2]\n", 152 | "\n", 153 | " meta_data = np.load(f'ECLIPSE/meta_data/{slope}_{idx}.npy', allow_pickle=True).tolist()\n", 154 | " WELL_LIST = meta_data[f'case_{idx}']['WELL_LIST']\n", 155 | " GRID_IDX_DICT = meta_data[f'case_{idx}']['GRID_IDX_DICT']\n", 156 | "\n", 157 | " for well in WELL_LIST:\n", 158 | " lgr_dict, true_dict = {}, {}\n", 159 | "\n", 160 | " data_LGR1 = torch.load(f'dataset/dP_LGR1/{slope}_{idx}_LGR1_{well}_DP.pt')\n", 161 | " I1, I2 = GRID_IDX_DICT[well]['LGR1']['I1']-1-15, GRID_IDX_DICT[well]['LGR1']['I2']+15\n", 162 | " J1, J2 = GRID_IDX_DICT[well]['LGR1']['J1']-1-15, GRID_IDX_DICT[well]['LGR1']['J2']+15\n", 163 | " coarse = np.repeat(PRED['GLOBAL'][0,...][:,I1:I2,J1:J2,:,:],5,axis=-2).permute(-1,1,2,3,0)[...,None]\n", 164 | " x_LGR1 = torch.cat((data_LGR1['input'][...,:-1],coarse),axis=-1)\n", 165 | " x_LGR1 = x_LGR1.permute(0,4,1,2,3,5).to(device)\n", 166 | " x_LGR1[...,-1:] = NORMALIZER_DICT['input']['LGR1'].encode(x_LGR1.to(device)[...,-1:])\n", 167 | " pred = NORMALIZER_DICT['output']['LGR1'].decode(MODEL_DICT['LGR1'](x_LGR1)).cpu()\n", 168 | " lgr_dict['LGR1'] = pred\n", 169 | " y = data_LGR1['output'][...,:1].permute(0,4,1,2,3,5)\n", 170 | " true_dict['LGR1'] = y\n", 171 | "\n", 172 | " data_LGR2 = torch.load(f'dataset/dP_LGR2/{slope}_{idx}_LGR2_{well}_DP.pt')\n", 173 | " coarse = np.repeat(lgr_dict['LGR1'][0,...],2,axis=-2).permute(-1,1,2,3,0)[...,None]\n", 174 | " x_LGR2 = torch.cat((data_LGR2['input'][...,:-1],coarse),axis=-1)\n", 175 | " x_LGR2 = x_LGR2.permute(0,4,1,2,3,5).to(device)\n", 176 | " x_LGR2[...,-1:] = NORMALIZER_DICT['input']['LGR2'].encode(x_LGR2.to(device)[...,-1:])\n", 177 | " pred = NORMALIZER_DICT['output']['LGR2'].decode(MODEL_DICT['LGR2'](x_LGR2)).cpu()\n", 178 | " lgr_dict['LGR2'] = pred\n", 179 | " y = data_LGR2['output'][...,:1].permute(0,4,1,2,3,5)\n", 180 | " true_dict['LGR2'] = y\n", 181 | "\n", 182 | " data_LGR3 = torch.load(f'dataset/dP_LGR3/{slope}_{idx}_LGR3_{well}_DP.pt')\n", 183 | " coarse = lgr_dict['LGR2'][0,...].permute(-1,1,2,3,0)[...,None]\n", 184 | " x_LGR3 = torch.cat((data_LGR3['input'][...,:-1],coarse),axis=-1)\n", 185 | " x_LGR3 = x_LGR3.permute(0,4,1,2,3,5).to(device)\n", 186 | " x_LGR3[...,-1:] = NORMALIZER_DICT['input']['LGR3'].encode(x_LGR3.to(device)[...,-1:])\n", 187 | " pred = NORMALIZER_DICT['output']['LGR3'].decode(MODEL_DICT['LGR3'](x_LGR3)).cpu()\n", 188 | " lgr_dict['LGR3'] = pred\n", 189 | " y = data_LGR3['output'][...,:1].permute(0,4,1,2,3,5)\n", 190 | " true_dict['LGR3'] = y\n", 191 | "\n", 192 | " data_LGR4 = torch.load(f'dataset/dP_LGR4/{slope}_{idx}_LGR4_{well}_DP.pt')\n", 193 | " coarse = lgr_dict['LGR3'][0,...].permute(-1,1,2,3,0)[...,None]\n", 194 | " x = data_LGR4['input']\n", 195 | " x_LGR4 = torch.cat((x[...,:-1],coarse),axis=-1)\n", 196 | " x_LGR4 = 
x_LGR4.permute(0,4,1,2,3,5).to(device)\n", 197 | " x_LGR4[...,-1:] = NORMALIZER_DICT['input']['LGR4'].encode(x_LGR4.to(device)[...,-1:])\n", 198 | " pred = NORMALIZER_DICT['output']['LGR4'].decode(MODEL_DICT['LGR4'](x_LGR4)).cpu()\n", 199 | " lgr_dict['LGR4'] = pred\n", 200 | " y = data_LGR4['output'][...,:1].permute(0,4,1,2,3,5)\n", 201 | " true_dict['LGR4'] = y\n", 202 | "\n", 203 | " err = pred - y\n", 204 | " ERR_LIST.append(err.numpy())" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "0c198b65", 210 | "metadata": {}, 211 | "source": [ 212 | "# Save error in matrix" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 32, 218 | "id": "27254451", 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "(43, 24, 40, 40, 50, 1)\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "err = np.array(ERR_LIST).reshape(-1, 24, 40, 40, 50, 1)\n", 231 | "print(err.shape)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "id": "dda34ef2", 237 | "metadata": {}, 238 | "source": [ 239 | "# Reload from trained weights" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 24, 245 | "id": "26e50dde", 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "text/plain": [ 251 | "FNO4d(\n", 252 | " (fc0): Linear(in_features=9, out_features=28, bias=True)\n", 253 | " (conv): Block4d(\n", 254 | " (conv0): SpectralConv4d()\n", 255 | " (conv1): SpectralConv4d()\n", 256 | " (conv2): SpectralConv4d()\n", 257 | " (conv3): SpectralConv4d()\n", 258 | " (w0): Conv1d(28, 28, kernel_size=(1,), stride=(1,))\n", 259 | " (w1): Conv1d(28, 28, kernel_size=(1,), stride=(1,))\n", 260 | " (w2): Conv1d(28, 28, kernel_size=(1,), stride=(1,))\n", 261 | " (w3): Conv1d(28, 28, kernel_size=(1,), stride=(1,))\n", 262 | " (fc1): Linear(in_features=28, out_features=112, bias=True)\n", 263 | " (fc2): Linear(in_features=112, out_features=1, bias=True)\n", 264 | " )\n", 265 | ")" 266 | ] 267 | }, 268 | "execution_count": 24, 269 | "metadata": {}, 270 | "output_type": "execute_result" 271 | } 272 | ], 273 | "source": [ 274 | "from lploss import *\n", 275 | "LPloss = LpLoss(size_average=True)\n", 276 | "\n", 277 | "from FNO4D import *\n", 278 | "device = torch.device('cuda')\n", 279 | "width = 28\n", 280 | "model = torch.load( f\"pre_trained_models/FNO4D-LGR4-DP.pt\")\n", 281 | "model.to(device)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 25, 287 | "id": "89f6eaf1", 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "name": "stdout", 292 | "output_type": "stream", 293 | "text": [ 294 | "FNO4D-LGR4-DP-1107-1730-finetune16\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "from datetime import datetime\n", 300 | "from datetime import date\n", 301 | "\n", 302 | "now = datetime.now()\n", 303 | "today = date.today()\n", 304 | "key = 'LGR4' \n", 305 | "day = today.strftime(\"%m%d\")\n", 306 | "current_time = now.strftime(\"%H%M\")\n", 307 | "specs = f'FNO4D-{key}-DP'\n", 308 | "model_str = f'{day}-{current_time}-finetune{n_train}'\n", 309 | "print(f'{specs}-{model_str}')" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 26, 315 | "id": "d70d4f15", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "with open(f\"normalizer/input_normalizer_{key}_DP_val.pickle\", 'rb') as f:\n", 320 | " input_normalizer = pickle.load(f)\n", 321 | " input_normalizer.cuda()\n", 322 | " \n", 323 | "with 
open(f\"normalizer/output_normalizer_{key}_DP_val.pickle\", 'rb') as f:\n", 324 |     "    output_normalizer = pickle.load(f)\n", 325 |     "    output_normalizer.cuda()" 326 |    ] 327 |   }, 328 |   { 329 |    "cell_type": "code", 330 |    "execution_count": 27, 331 |    "id": "2ed2fc16", 332 |    "metadata": {}, 333 |    "outputs": [], 334 |    "source": [ 335 |     "from Adam import Adam\n", 336 |     "scheduler_step = 2\n", 337 |     "scheduler_gamma = 0.85\n", 338 |     "learning_rate = 1e-3\n", 339 |     "\n", 340 |     "optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=0)\n", 341 |     "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, \n", 342 |     "                                            step_size=scheduler_step, \n", 343 |     "                                            gamma=scheduler_gamma)" 344 |    ] 345 |   }, 346 |   { 347 |    "cell_type": "code", 348 |    "execution_count": 28, 349 |    "id": "f6f92ae1", 350 |    "metadata": {}, 351 |    "outputs": [], 352 |    "source": [ 353 |     "from torch.utils.tensorboard import SummaryWriter\n", 354 |     "writer = SummaryWriter(f'logs/')" 355 |    ] 356 |   }, 357 |   { 358 |    "cell_type": "markdown", 359 |    "id": "6aabaa3d", 360 |    "metadata": {}, 361 |    "source": [ 362 |     "# Finetune" 363 |    ] 364 |   }, 365 |   { 366 |    "cell_type": "code", 367 |    "execution_count": null, 368 |    "id": "5c90b8af", 369 |    "metadata": {}, 370 |    "outputs": [], 371 |    "source": [ 372 |     "for ep in range(51,60):\n", 373 |     "    model.train()\n", 374 |     "    train_lp = 0\n", 375 |     "    c = 0\n", 376 |     "    \n", 377 |     "    for data in train_loader:\n", 378 |     "        x, y, path = data['x'], data['y'], data['path']\n", 379 |     "        rand_idx = np.random.randint(err.shape[0]) # change this to the size of your saved error\n", 380 |     "        x[...,-1:] += torch.from_numpy(err[rand_idx,...])\n", 381 |     "        x, y = x[None,...].to(device), y[None,...][...,:1].to(device)\n", 382 |     "        \n", 383 |     "        optimizer.zero_grad()\n", 384 |     "        x[...,-1:] = input_normalizer.encode(x[...,-1:])\n", 385 |     "        pred = model(x)\n", 386 |     "        pred = output_normalizer.decode(pred)\n", 387 |     "        \n", 388 |     "        loss = LPloss(pred.reshape(1, -1), y.reshape(1, -1))\n", 389 |     "        train_lp += loss.item()\n", 390 |     "        \n", 391 |     "        loss.backward()\n", 392 |     "        optimizer.step()\n", 393 |     "        c += 1\n", 394 |     "        \n", 395 |     "        if c%100 ==0:\n", 396 |     "            writer.add_scalars('dp LPloss', {f'{model_str}_{specs}_train': loss.item()}, ep*n_train+c)\n", 397 |     "            print(f'ep: {ep}, iter: {c}, train lp: {loss.item():.4f}')\n", 398 |     "\n", 399 |     "    scheduler.step()\n", 400 |     "    print('----------------------------------------------------------------------')\n", 401 |     "\n", 402 |     "    torch.save(model, f'saved_models/{model_str}-{specs}-ep{ep}.pt')" 403 |    ] 404 |   } 405 |  ], 406 |  "metadata": { 407 |   "kernelspec": { 408 |    "display_name": "Python 3", 409 |    "language": "python", 410 |    "name": "python3" 411 |   }, 412 |   "language_info": { 413 |    "codemirror_mode": { 414 |     "name": "ipython", 415 |     "version": 3 416 |    }, 417 |    "file_extension": ".py", 418 |    "mimetype": "text/x-python", 419 |    "name": "python", 420 |    "nbconvert_exporter": "python", 421 |    "pygments_lexer": "ipython3", 422 |    "version": "3.9.0" 423 |   } 424 |  }, 425 |  "nbformat": 4, 426 |  "nbformat_minor": 5 427 | } 428 | -------------------------------------------------------------------------------- /hetero_logs/events.out.tfevents.1667870912.sh02-16n09.int.68858.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/hetero_logs/events.out.tfevents.1667870912.sh02-16n09.int.68858.0
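The fine-tuning cell above scores each prediction with the relative Lp loss implemented in lploss.py below. A minimal, self-contained usage sketch (the tensor sizes are illustrative only):

import torch
from lploss import LpLoss

loss_fn = LpLoss(size_average=True)       # defaults: d=2, p=2, relative norm
pred = torch.rand(1, 24 * 40 * 40 * 50)   # flattened prediction, batch size 1
true = torch.rand(1, 24 * 40 * 40 * 50)   # flattened reference
# __call__ with LGR=False computes ||pred - true||_2 / ||true||_2 per sample, then averages
print(loss_fn(pred, true).item())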
-------------------------------------------------------------------------------- /lploss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | #loss function with rel/abs Lp loss 4 | class LpLoss(object): 5 | def __init__(self, d=2, p=2, size_average=True, reduction=True, LGR=False): 6 | super(LpLoss, self).__init__() 7 | 8 | #Dimension and Lp-norm type are postive 9 | assert d > 0 and p > 0 10 | 11 | self.d = d 12 | self.p = p 13 | self.reduction = reduction 14 | self.size_average = size_average 15 | self.LGR = LGR 16 | 17 | def abs(self, x, y): 18 | num_examples = x.size()[0] 19 | 20 | #Assume uniform mesh 21 | h = 1.0 / (x.size()[1] - 1.0) 22 | 23 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 24 | 25 | if self.reduction: 26 | if self.size_average: 27 | return torch.mean(all_norms) 28 | else: 29 | return torch.sum(all_norms) 30 | 31 | return all_norms 32 | 33 | def rel(self, x, y): 34 | num_examples = x.size()[0] 35 | 36 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 37 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 38 | 39 | if self.reduction: 40 | if self.size_average: 41 | return torch.mean(diff_norms/y_norms) 42 | else: 43 | return torch.sum(diff_norms/y_norms) 44 | 45 | return diff_norms/y_norms 46 | 47 | def global_loss(self, x, y, glob): 48 | num_examples = x.size()[0] 49 | 50 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 51 | 52 | if self.reduction: 53 | if self.size_average: 54 | return torch.mean(diff_norms/glob) 55 | else: 56 | return torch.sum(diff_norms/glob) 57 | 58 | return diff_norms/glob 59 | 60 | def __call__(self, x, y, glob=None): 61 | if self.LGR is False: 62 | return self.rel(x, y) 63 | else: 64 | return self.global_loss(x, y, glob) 65 | 66 | -------------------------------------------------------------------------------- /meta_data_to_input_dict.py: -------------------------------------------------------------------------------- 1 | from iapws import IAPWS97 2 | import torch 3 | import numpy as np 4 | # from config_pt_utility import * 5 | 6 | 7 | xy_norm = lambda x: (x)/160000 8 | z_norm = lambda x: (x-2000)/2000 9 | p_norm = lambda x: (x)/172 10 | t_norm = lambda x: (x)/70 11 | k_norm = lambda x: (x)/100 12 | z_dnorm = lambda x: x*2000+2000 13 | p_dnorm = lambda x: 172*x 14 | depth_func = lambda a : (a - 1.01325)/9.8*100 15 | 16 | times = np.cumsum(10*np.array(np.power(1.2531,np.arange(1,25,1)), dtype=int)) 17 | times = times/ 10950 18 | 19 | def calculate_p_initial(DEPTH_datum, P_datum, T_datum, tops, thickness): 20 | rho_datum = IAPWS97(T=T_datum+273.15, P=P_datum/10).Liquid.rho 21 | Z_cell_center = tops+thickness/2 22 | return P_datum + (Z_cell_center-DEPTH_datum)* 9.8*rho_datum/100000 23 | 24 | def create_P_INIT_DICT(TOPS_DICT, INPUT_DICT, WELL_LIST, LGR_LIST, GRID_IDX_DICT): 25 | DEPTH_datum = depth_func(INPUT_DICT['p']) 26 | P_datum = INPUT_DICT['p'] 27 | T_datum = INPUT_DICT['temp'] 28 | 29 | P_INIT_DICT = {} 30 | P_INIT_DICT['GLOBAL'] = calculate_p_initial(DEPTH_datum, P_datum, T_datum, 31 | tops=TOPS_DICT['GLOBAL'], 32 | thickness=GRID_IDX_DICT['GLOBAL']['DZ']) 33 | for well in WELL_LIST: 34 | d = {} 35 | for lgr in LGR_LIST: 36 | d[lgr] = calculate_p_initial(DEPTH_datum, P_datum, T_datum, 37 | tops=TOPS_DICT[well][lgr], 38 | thickness=GRID_IDX_DICT[well][lgr]['DZ']) 39 | P_INIT_DICT[well] = d 40 | return P_INIT_DICT 41 | 42 | def 
meta_data_to_input_dict(meta_data): 43 | for k, v in meta_data.items(): 44 | globals()[k]=v 45 | 46 | p, t, rate = INPUT_DICT['p'], INPUT_DICT['temp'], INPUT_DICT['inj'] 47 | P_INIT_DICT = create_P_INIT_DICT(TOPS_DICT, INPUT_DICT, WELL_LIST, LGR_LIST, GRID_IDX_DICT) 48 | INJ_MAP_DICT = return_inj_map_dict(WELL_LIST,rate,INJ_LOCATION_DICT,GRID_CENTER_DICT,LGR_LIST) 49 | pressure_upsampled_dict = return_upsample_dict(P_INIT_DICT, WELL_LIST, GRID_IDX_DICT) 50 | 51 | ml_input_dict = {} 52 | 53 | # GLOBAL 54 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT['GLOBAL']['grid_x'])[...,None,None], 24, axis=-2) 55 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT['GLOBAL']['grid_y'])[...,None,None], 24, axis=-2) 56 | gridz = np.repeat(z_norm(TOPS_DICT['GLOBAL'][0,...,None,None]), 24, axis=-2) 57 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 58 | inj = np.repeat(INJ_MAP_DICT['GLOBAL'][...,None,None], 24, axis=-2) 59 | temp = t_norm(t) * np.ones(inj.shape) 60 | perm = np.repeat(k_norm(PERM_DICT['GLOBAL'])[0,...,None,None], 24, axis=-2) 61 | pressure = np.repeat(p_norm(pressure_upsampled_dict['GLOBAL'][...,None,None]), 24, axis=-2) 62 | x = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm], axis=-1)[None,...].transpose(0,4,1,2,3,5) 63 | x = torch.from_numpy(x.astype(np.float32)) 64 | ml_input_dict['GLOBAL'] = x 65 | 66 | for well in WELL_LIST: 67 | well_input_dict = {} 68 | # LGR1 69 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_x'])[...,None,None], 24, axis=-2) 70 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well]['LGR1']['grid_y'])[...,None,None], 24, axis=-2) 71 | gridz = np.repeat(z_norm(TOPS_DICT[well]['LGR1'][0,...,None,None]), 24, axis=-2) 72 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 73 | inj = np.repeat(INJ_MAP_DICT[well]['LGR1'][...,None,None], 24, axis=-2) 74 | pressure = np.repeat(p_norm(pressure_upsampled_dict[well]['LGR1'][0,...,None,None]), 24, axis=-2) 75 | temp = t_norm(t) * np.ones(inj.shape) 76 | perm = np.repeat(k_norm(PERM_DICT[well]['LGR1'])[0,...,None,None], 24, axis=-2) 77 | I1, I2 = GRID_IDX_DICT[well]['LGR1']['I1']-1-15, GRID_IDX_DICT[well]['LGR1']['I2']+15 78 | J1, J2 = GRID_IDX_DICT[well]['LGR1']['J1']-1-15, GRID_IDX_DICT[well]['LGR1']['J2']+15 79 | coarse = np.zeros(gridx.shape) 80 | x = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, coarse], axis=-1)[None,...] 81 | x = torch.from_numpy(x.astype(np.float32)) 82 | well_input_dict['LGR1'] = x 83 | 84 | # LGR2-4 85 | for lgr in ['LGR2', 'LGR3', 'LGR4']: 86 | gridx = np.repeat(xy_norm(GRID_CENTER_DICT[well][lgr]['grid_x'])[...,None,None], 24, axis=-2) 87 | gridy = np.repeat(xy_norm(GRID_CENTER_DICT[well][lgr]['grid_y'])[...,None,None], 24, axis=-2) 88 | gridz = np.repeat(z_norm(TOPS_DICT[well][lgr][0,...,None,None]), 24, axis=-2) 89 | gridt = (np.ones(gridz.shape)* times[None,None,None,:,None]) 90 | inj = np.repeat(INJ_MAP_DICT[well][lgr][...,None,None], 24, axis=-2) 91 | pressure = np.repeat(p_norm(pressure_upsampled_dict[well][lgr][0,...,None,None]), 24, axis=-2) 92 | temp = t_norm(t) * np.ones(inj.shape) 93 | perm = np.repeat(k_norm(PERM_DICT[well][lgr])[0,...,None,None], 24, axis=-2) 94 | x = np.concatenate([gridx, gridy, gridz, gridt, inj, pressure, temp, perm, np.zeros(gridx.shape)], axis=-1)[None,...] 
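# NOTE (descriptive comment, inferred from the concatenation above): each LGR-level input stacks
# nine channels -- normalized x/y grid centres, normalized tops (depth), normalized report time,
# injection-rate map, initial pressure, temperature, permeability, and a zero placeholder --
# each repeated over the 24 report times, which matches in_dim=9 used by the LGR models in
# train_FNO4D_DP_LGR.py / train_FNO4D_SG_LGR.py. The GLOBAL input built earlier omits the
# placeholder channel (in_dim=8). At prediction time (predict_full_sg.py) the placeholder is
# swapped for the coarser level's prediction via torch.cat((x[..., :-1], coarse), axis=-1).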
95 | x = torch.from_numpy(x.astype(np.float32)) 96 | if lgr == 'LGR4': 97 | perf = PERF_DICT[well] 98 | old_inj = torch.clone(x[0,19,19,:,:,4]) 99 | new_inj = torch.zeros(old_inj.shape) 100 | new_inj[perf[0]-1:perf[1],:] = old_inj[perf[0]-1:perf[1],:] 101 | x[0,19,19,:,:,4] = new_inj 102 | well_input_dict[lgr] = x 103 | ml_input_dict[well] = well_input_dict 104 | return ml_input_dict -------------------------------------------------------------------------------- /normalizer/input_normalizer_GLOBAL_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/input_normalizer_GLOBAL_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/input_normalizer_LGR1_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/input_normalizer_LGR1_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/input_normalizer_LGR2_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/input_normalizer_LGR2_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/input_normalizer_LGR3_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/input_normalizer_LGR3_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/input_normalizer_LGR4_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/input_normalizer_LGR4_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/output_normalizer_GLOBAL_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/output_normalizer_GLOBAL_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/output_normalizer_LGR1_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/output_normalizer_LGR1_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/output_normalizer_LGR2_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/output_normalizer_LGR2_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/output_normalizer_LGR3_DP_val.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/output_normalizer_LGR3_DP_val.pickle -------------------------------------------------------------------------------- /normalizer/output_normalizer_LGR4_DP_val.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gegewen/nested-fno/eba03af76fabef2e7fac104fab5171537d48ae65/normalizer/output_normalizer_LGR4_DP_val.pickle -------------------------------------------------------------------------------- /predict_full_sg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def predict_full_sg(input_dict, MODEL_DICT, NORMALIZER_DICT, device): 5 | with torch.no_grad(): 6 | PRED = {} 7 | x = input_dict['GLOBAL'].to(device) 8 | x[...,-1:] = NORMALIZER_DICT['input']['GLOBAL'].encode(x.to(device)[...,-1:]) 9 | pred = NORMALIZER_DICT['output']['GLOBAL'].decode(MODEL_DICT['GLOBAL'](x)).cpu() 10 | PRED['GLOBAL'] = pred 11 | 12 | WELL_LIST = list(input_dict.keys()) 13 | WELL_LIST.remove('GLOBAL') 14 | 15 | for well in WELL_LIST: 16 | lgr_dict = {} 17 | x_lgr1 = input_dict[well]['LGR1'] 18 | a = np.abs(input_dict['GLOBAL'][0,0,:,:,0,0][:,0].numpy()-input_dict[well]['LGR1'][0,:,:,0,0,0][:,0].numpy()[0]) 19 | I1 = np.unravel_index(np.argmin(a, axis=None), a.shape)[0] - 15 20 | a = np.abs(input_dict['GLOBAL'][0,0,:,:,0,0][:,0].numpy()-input_dict[well]['LGR1'][0,:,:,0,0,0][:,0].numpy()[-1]) 21 | I2 = np.unravel_index(np.argmin(a, axis=None), a.shape)[0] + 16 22 | a = np.abs(input_dict['GLOBAL'][0,0,:,:,0,1][0,:].numpy()-input_dict[well]['LGR1'][0,:,:,0,0,1][0,:].numpy()[0]) 23 | J1 = np.unravel_index(np.argmin(a, axis=None), a.shape)[0] - 15 24 | a = np.abs(input_dict['GLOBAL'][0,0,:,:,0,1][0,:].numpy()-input_dict[well]['LGR1'][0,:,:,0,0,1][0,:].numpy()[-1]) 25 | J2 = np.unravel_index(np.argmin(a, axis=None), a.shape)[0] + 16 26 | coarse = np.repeat(PRED['GLOBAL'][0,...][:,I1:I2,J1:J2,:,:],5,axis=-2).permute(-1,1,2,3,0)[...,None] 27 | x_LGR1 = torch.cat((x_lgr1[...,:-1],coarse),axis=-1) 28 | x_LGR1 = x_LGR1.permute(0,4,1,2,3,5).to(device) 29 | x_LGR1[...,-1:] = NORMALIZER_DICT['input']['LGR1'].encode(x_LGR1.to(device)[...,-1:]) 30 | pred = MODEL_DICT['LGR1'](x_LGR1).cpu() 31 | lgr_dict['LGR1'] = pred 32 | 33 | x_lgr2 = input_dict[well]['LGR2'] 34 | coarse = np.repeat(lgr_dict['LGR1'][0,...],2,axis=-2).permute(-1,1,2,3,0)[...,None] 35 | x_LGR2 = torch.cat((x_lgr2[...,:-1],coarse),axis=-1) 36 | x_LGR2 = x_LGR2.permute(0,4,1,2,3,5).to(device) 37 | pred = MODEL_DICT['LGR2'](x_LGR2).cpu() 38 | lgr_dict['LGR2'] = pred 39 | 40 | x_lgr3 = input_dict[well]['LGR3'] 41 | coarse = lgr_dict['LGR2'][0,...].permute(-1,1,2,3,0)[...,None] 42 | x_LGR3 = torch.cat((x_lgr3[...,:-1],coarse),axis=-1) 43 | x_LGR3 = x_LGR3.permute(0,4,1,2,3,5).to(device) 44 | pred = MODEL_DICT['LGR3'](x_LGR3).cpu() 45 | lgr_dict['LGR3'] = pred 46 | 47 | x_lgr4 = input_dict[well]['LGR4'] 48 | coarse = lgr_dict['LGR3'][0,...].permute(-1,1,2,3,0)[...,None] 49 | x_LGR4 = torch.cat((x_lgr4[...,:-1],coarse),axis=-1) 50 | x_LGR4 = x_LGR4.permute(0,4,1,2,3,5).to(device) 51 | pred = MODEL_DICT['LGR4'](x_LGR4).cpu() 52 | lgr_dict['LGR4'] = pred 53 | 54 | PRED[well] = lgr_dict 55 | 56 | PRED['GLOBAL'] *= 0 57 | return PRED -------------------------------------------------------------------------------- /save_data_loader.py: -------------------------------------------------------------------------------- 1 | 
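# Builds the train/val/test datasets once and caches them in DATA_LOADER_DICT.pth so the
# training scripts can simply torch.load() the same split. With random.seed(0), the case names
# found in ECLIPSE/meta_data/ are shuffled into 16 training cases, 2 validation cases, and the
# remainder for test. The resulting dictionary is indexed as, e.g. (illustrative only):
#   DATA_LOADER_DICT['GLOBAL']['train']      # CustomDataset of GLOBAL pressure-change (dP) cases
#   DATA_LOADER_DICT['LGR1']['dP']['val']    # CustomDataset of LGR1 pressure-change cases
#   DATA_LOADER_DICT['LGR4']['SG']['test']   # CustomDataset of LGR4 gas-saturation cases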
import json 2 | from torch.utils.data import Dataset 3 | import random 4 | import os 5 | import torch 6 | import glob 7 | from CustomDataset import * 8 | random.seed(0) 9 | 10 | files = os.listdir('ECLIPSE/meta_data/') 11 | names = [] 12 | for file in files: 13 | n = file.split('.')[0] 14 | names.append(f'{n}_GLOBAL_DP.pt') 15 | 16 | random.shuffle(names) 17 | train_names = names[:16] 18 | val_names = names[16:18] 19 | test_names = names[18:] 20 | 21 | ROOT_PATH = 'dataset/' 22 | DATA_LOADER_DICT = {} 23 | DATA_LOADER_DICT['GLOBAL'] = {'train': CustomDataset(ROOT_PATH+'dP_GLOBAL/', train_names), 24 | 'val': CustomDataset(ROOT_PATH+'dP_GLOBAL/', val_names), 25 | 'test': CustomDataset(ROOT_PATH+'dP_GLOBAL/', test_names)} 26 | 27 | for key in ['LGR1', 'LGR2', 'LGR3', 'LGR4']: 28 | LGR_ROOT_PATH = os.listdir(f'{ROOT_PATH}dP_{key}/') 29 | train_lgr_lists_dP = GLOBAL_to_LGR_path(train_names, key, LGR_ROOT_PATH, 'dP') 30 | val_lgr_lists_dP = GLOBAL_to_LGR_path(val_names, key, LGR_ROOT_PATH, 'dP') 31 | test_lgr_lists_dP = GLOBAL_to_LGR_path(test_names, key, LGR_ROOT_PATH, 'dP') 32 | LGR_ROOT_PATH = os.listdir(f'{ROOT_PATH}SG_{key}/') 33 | train_lgr_lists_SG = GLOBAL_to_LGR_path(train_names, key, LGR_ROOT_PATH, 'SG') 34 | val_lgr_lists_SG = GLOBAL_to_LGR_path(val_names, key, LGR_ROOT_PATH, 'SG') 35 | test_lgr_lists_SG = GLOBAL_to_LGR_path(test_names, key, LGR_ROOT_PATH, 'SG') 36 | 37 | DATA_LOADER_DICT[key] = {'dP': {'train': CustomDataset(ROOT_PATH, train_lgr_lists_dP), 38 | 'val': CustomDataset(ROOT_PATH, val_lgr_lists_dP), 39 | 'test': CustomDataset(ROOT_PATH, test_lgr_lists_dP)}, 40 | 'SG': {'train': CustomDataset(ROOT_PATH, train_lgr_lists_SG), 41 | 'val': CustomDataset(ROOT_PATH, val_lgr_lists_SG), 42 | 'test': CustomDataset(ROOT_PATH, test_lgr_lists_SG)}} 43 | 44 | torch.save(DATA_LOADER_DICT, 'DATA_LOADER_DICT.pth') 45 | print('data loader done') -------------------------------------------------------------------------------- /train_FNO4D_DP_GLOBAL.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import random 3 | import os 4 | import torch 5 | from CustomDataset import * 6 | 7 | 8 | DATA_LOADER_DICT = torch.load('DATA_LOADER_DICT.pth') 9 | train_loader = DATA_LOADER_DICT['GLOBAL']['train'] 10 | val_loader = DATA_LOADER_DICT['GLOBAL']['val'] 11 | n_train = len(train_loader) 12 | n_val = len(val_loader) 13 | print(n_train, n_val) 14 | 15 | 16 | 17 | from lploss import * 18 | LPloss = LpLoss(size_average=True) 19 | 20 | from FNO4D import * 21 | device = torch.device('cuda') 22 | width = 28 23 | mode1, mode2, mode3, mode4 = 4, 20, 20, 2 24 | model_global = FNO4d(mode1, mode2, mode3, mode4, width, in_dim=8) 25 | model_global.to(device) 26 | 27 | 28 | from datetime import datetime 29 | from datetime import date 30 | 31 | 32 | now = datetime.now() 33 | today = date.today() 34 | 35 | day = today.strftime("%m%d") 36 | current_time = now.strftime("%H%M") 37 | specs = f'FNO4D-GLOBAL-DP' 38 | model_str = f'{day}-{current_time}-train{n_train}' 39 | print(f'{specs}-{model_str}') 40 | 41 | 42 | from Adam import Adam 43 | scheduler_step = 5 44 | scheduler_gamma = 0.85 45 | learning_rate = 1e-3 46 | 47 | optimizer = Adam(model_global.parameters(), lr=learning_rate, weight_decay=0) 48 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 49 | step_size=scheduler_step, 50 | gamma=scheduler_gamma) 51 | 52 | 53 | from torch.utils.tensorboard import SummaryWriter 54 | writer = SummaryWriter(f'logs/') 55 | 56 | 57 | import pickle 
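# The pickled UnitGaussianNormalizer objects loaded below carry the normalization statistics
# (fitted offline; the *_val.pickle naming suggests they were computed on the validation split).
# In the training loop the input normalizer standardizes only the last input channel
# (x[..., -1:]) before the forward pass, and the output normalizer decodes the network output
# back to physical dP units so the Lp loss is computed against the unnormalized targets.
# .cuda() simply moves the stored statistics onto the GPU alongside the model.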
58 | from UnitGaussianNormalizer import * 59 | 60 | with open("normalizer/input_normalizer_GLOBAL_DP_val.pickle", 'rb') as f: 61 | input_normalizer_GLOBAL = pickle.load(f) 62 | input_normalizer_GLOBAL.cuda() 63 | 64 | with open("normalizer/output_normalizer_GLOBAL_DP_val.pickle", 'rb') as f: 65 | output_normalizer_GLOBAL = pickle.load(f) 66 | output_normalizer_GLOBAL.cuda() 67 | 68 | 69 | for ep in range(51): 70 | model_global.train() 71 | train_lp = 0 72 | c = 0 73 | 74 | for data in train_loader: 75 | x, y, path = data['x'], data['y'], data['path'] 76 | slope, idx = path[0], path[1] 77 | path = f'{slope}_{idx}' 78 | 79 | x, y = x[None,...].to(device), y[None,...].to(device) 80 | optimizer.zero_grad() 81 | 82 | x[...,-1:] = input_normalizer_GLOBAL.encode(x[...,-1:]) 83 | pred = model_global(x) 84 | pred = output_normalizer_GLOBAL.decode(pred) 85 | loss = LPloss(pred.reshape(1, -1), y[...,:1].reshape(1, -1)) 86 | train_lp += loss.item() 87 | 88 | loss.backward() 89 | optimizer.step() 90 | c += 1 91 | if c%10 ==0: 92 | writer.add_scalars('dP LPloss', {f'{model_str}_{specs}_train': loss.item()}, 93 | ep*n_train+c) 94 | print(f'ep: {ep}, iter: {c}, train lp: {loss.item():.4f}') 95 | 96 | scheduler.step() 97 | 98 | model_global.eval() 99 | val_lp = 0 100 | val_mre = 0 101 | with torch.no_grad(): 102 | for data in val_loader: 103 | x, y, path = data['x'], data['y'], data['path'] 104 | slope, idx = path[0], path[1] 105 | path = f'{slope}_{idx}' 106 | 107 | x, y = x[None,...].to(device), y[None,...].to(device) 108 | x[...,-1:] = input_normalizer_GLOBAL.encode(x[...,-1:]) 109 | pred = model_global(x) 110 | pred = output_normalizer_GLOBAL.decode(pred) 111 | loss = LPloss(pred.reshape(1, -1), y[...,:1].reshape(1, -1)) 112 | val_lp += loss.item() 113 | 114 | writer.add_scalars('dP LPloss', {f'{model_str}_{specs}_train': train_lp/n_train, 115 | f'{model_str}_{specs}_val': val_lp/n_val}, ep*n_train+c) 116 | 117 | print(f'epoch: {ep} summary') 118 | print(f'train loss: {train_lp/n_train:.4f}, val loss: {val_lp/n_val:.4f}') 119 | print('----------------------------------------------------------------------') 120 | 121 | torch.save(model_global, f'saved_models/{model_str}-{specs}-ep{ep}.pt') -------------------------------------------------------------------------------- /train_FNO4D_DP_LGR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | from torch.utils.data import Dataset 4 | import random 5 | import os 6 | import torch 7 | from CustomDataset import * 8 | 9 | key = sys.argv[1] 10 | var = 'dP' 11 | 12 | DATA_LOADER_DICT = torch.load('DATA_LOADER_DICT.pth') 13 | train_loader = DATA_LOADER_DICT[key][var]['train'] 14 | val_loader = DATA_LOADER_DICT[key][var]['val'] 15 | n_train = len(train_loader) 16 | n_val = len(val_loader) 17 | print(n_train, n_val) 18 | 19 | from lploss import * 20 | LPloss = LpLoss(size_average=True) 21 | 22 | from FNO4D import * 23 | device = torch.device('cuda') 24 | width = 28 25 | mode1, mode2, mode3, mode4 = 6, 10, 10, 10 26 | model = FNO4d(mode1, mode2, mode3, mode4, width, in_dim=9) 27 | model.to(device) 28 | 29 | from datetime import datetime 30 | from datetime import date 31 | 32 | now = datetime.now() 33 | today = date.today() 34 | 35 | day = today.strftime("%m%d") 36 | current_time = now.strftime("%H%M") 37 | specs = f'FNO4D-{key}-{var}' 38 | model_str = f'{day}-{current_time}-train{n_train}' 39 | print(f'{specs}-{model_str}') 40 | 41 | 42 | from torch.utils.tensorboard import SummaryWriter 43 | 
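# This script trains one refinement level at a time; the level is passed on the command line,
# e.g. (illustrative invocation): python train_FNO4D_DP_LGR.py LGR1
# Loss curves are written to logs/ and can be inspected with: tensorboard --logdir logs/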
writer = SummaryWriter(f'logs/') 44 | 45 | 46 | import pickle 47 | from UnitGaussianNormalizer import * 48 | 49 | with open(f"normalizer/input_normalizer_{key}_{var.upper()}_val.pickle", 'rb') as f: 50 | input_normalizer = pickle.load(f) 51 | input_normalizer.cuda() 52 | 53 | with open(f"normalizer/output_normalizer_{key}_{var.upper()}_val.pickle", 'rb') as f: 54 | output_normalizer = pickle.load(f) 55 | output_normalizer.cuda() 56 | 57 | 58 | from Adam import Adam 59 | scheduler_step = 5 60 | scheduler_gamma = 0.85 61 | learning_rate = 1e-3 62 | 63 | optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=0) 64 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 65 | step_size=scheduler_step, 66 | gamma=scheduler_gamma) 67 | 68 | for ep in range(11): 69 | model.train() 70 | train_lp = 0 71 | c = 0 72 | 73 | for data in train_loader: 74 | x, y, path = data['x'], data['y'], data['path'] 75 | slope, idx = path[0], path[1] 76 | path = f'{slope}_{idx}' 77 | 78 | x, y = x[None,...].to(device), y[None,...][...,:1].to(device) 79 | 80 | optimizer.zero_grad() 81 | x[...,-1:] = input_normalizer.encode(x[...,-1:]) 82 | pred = model(x) 83 | pred = output_normalizer.decode(pred) 84 | loss = LPloss(pred.reshape(1, -1), y.reshape(1, -1)) 85 | train_lp += loss.item() 86 | 87 | loss.backward() 88 | optimizer.step() 89 | c += 1 90 | 91 | if c%10 ==0: 92 | writer.add_scalars(f'{var} LPloss', {f'{model_str}_{specs}_train': loss.item()}, 93 | ep*n_train+c) 94 | print(f'ep: {ep}, iter: {c}, train lp: {loss.item():.4f}') 95 | 96 | scheduler.step() 97 | 98 | model.eval() 99 | val_lp = 0 100 | val_mre = 0 101 | with torch.no_grad(): 102 | for data in val_loader: 103 | x, y, path = data['x'], data['y'], data['path'] 104 | slope, idx = path[0], path[1] 105 | path = f'{slope}_{idx}' 106 | 107 | x, y = x[None,...].to(device), y[None,...][...,:1].to(device) 108 | x[...,-1:] = input_normalizer.encode(x[...,-1:]) 109 | pred = model(x) 110 | pred = output_normalizer.decode(pred) 111 | loss = LPloss(pred.reshape(1, -1), y.reshape(1, -1)) 112 | val_lp += loss.item() 113 | 114 | writer.add_scalars(f'{var} LPloss', {f'{model_str}_{specs}_train': train_lp/n_train, 115 | f'{model_str}_{specs}_val': val_lp/n_val}, ep*n_train+c) 116 | 117 | print(f'epoch: {ep} summary') 118 | print(f'train loss: {train_lp/n_train:.4f}, val loss: {val_lp/n_val:.4f}') 119 | print('----------------------------------------------------------------------') 120 | 121 | torch.save(model, f'saved_models/{model_str}-{specs}-ep{ep}.pt') -------------------------------------------------------------------------------- /train_FNO4D_SG_LGR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | from torch.utils.data import Dataset 4 | import random 5 | import os 6 | import torch 7 | from CustomDataset import * 8 | 9 | key = sys.argv[1] 10 | var = 'SG' 11 | 12 | DATA_LOADER_DICT = torch.load('DATA_LOADER_DICT.pth') 13 | train_loader = DATA_LOADER_DICT[key][var]['train'] 14 | val_loader = DATA_LOADER_DICT[key][var]['val'] 15 | n_train = len(train_loader) 16 | n_val = len(val_loader) 17 | print(n_train, n_val) 18 | 19 | from lploss import * 20 | LPloss = LpLoss(size_average=True) 21 | 22 | from FNO4D import * 23 | device = torch.device('cuda') 24 | width = 28 25 | mode1, mode2, mode3, mode4 = 6, 10, 10, 10 26 | model = FNO4d(mode1, mode2, mode3, mode4, width, in_dim=9) 27 | model.to(device) 28 | 29 | 30 | from datetime import datetime 31 | from datetime import date 32 | 33 | now 
= datetime.now() 34 | today = date.today() 35 | 36 | day = today.strftime("%m%d") 37 | current_time = now.strftime("%H%M") 38 | specs = f'FNO4D-{key}-SG' 39 | model_str = f'{day}-{current_time}-train{n_train}' 40 | print(f'{specs}-{model_str}') 41 | 42 | 43 | from torch.utils.tensorboard import SummaryWriter 44 | writer = SummaryWriter(f'logs/') 45 | 46 | import pickle 47 | from UnitGaussianNormalizer import * 48 | 49 | if key == 'LGR1': 50 | with open(f"normalizer/input_normalizer_{key}_DP_val.pickle", 'rb') as f: 51 | input_normalizer = pickle.load(f) 52 | input_normalizer.cuda() 53 | 54 | 55 | from Adam import Adam 56 | scheduler_step = 5 57 | scheduler_gamma = 0.85 58 | learning_rate = 1e-3 59 | 60 | optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=0) 61 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 62 | step_size=scheduler_step, 63 | gamma=scheduler_gamma) 64 | 65 | for ep in range(51): 66 | model.train() 67 | train_lp = 0 68 | c = 0 69 | 70 | for data in train_loader: 71 | x, y, path = data['x'], data['y'], data['path'] 72 | slope, idx = path[0], path[1] 73 | path = f'{slope}_{idx}' 74 | 75 | x, y = x[None,...].to(device), y[None,...][...,:1].to(device) 76 | 77 | optimizer.zero_grad() 78 | if key == 'LGR1': 79 | x[...,-1:] = input_normalizer.encode(x[...,-1:]) 80 | pred = model(x) 81 | 82 | loss = LPloss(pred.reshape(1, -1), y.reshape(1, -1)) 83 | 84 | train_lp += loss.item() 85 | 86 | loss.backward() 87 | optimizer.step() 88 | c += 1 89 | 90 | if c%10 ==0: 91 | writer.add_scalars('SG LPloss', {f'{model_str}_{specs}_train': loss.item()}, ep*n_train+c) 92 | 93 | print(f'ep: {ep}, iter: {c}, train lp: {loss.item():.4f}') 94 | 95 | scheduler.step() 96 | 97 | model.eval() 98 | val_lp = 0 99 | with torch.no_grad(): 100 | for data in val_loader: 101 | x, y, path = data['x'], data['y'], data['path'] 102 | slope, idx = path[0], path[1] 103 | path = f'{slope}_{idx}' 104 | 105 | x, y = x[None,...].to(device), y[None,...][...,:1].to(device) 106 | if key == 'LGR1': 107 | x[...,-1:] = input_normalizer.encode(x[...,-1:]) 108 | pred = model(x) 109 | 110 | loss = LPloss(pred.reshape(1, -1), y.reshape(1, -1)) 111 | 112 | val_lp += loss.item() 113 | 114 | writer.add_scalars('SG LPloss', {f'{model_str}_{specs}_train': train_lp/n_train, 115 | f'{model_str}_{specs}_val': val_lp/n_val}, ep*n_train+c) 116 | 117 | 118 | print(f'epoch: {ep} summary') 119 | print(f'train loss: {train_lp/n_train:.4f}, val loss: {val_lp/n_val:.4f}') 120 | print('----------------------------------------------------------------------') 121 | 122 | torch.save(model, f'saved_models/{model_str}-{specs}-ep{ep}.pt') -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | def return_OUTPUT_DICT(meta_data, case_name): 6 | nt = list(meta_data[case_name]['data'].keys()) 7 | OUT = {} 8 | GRID_IDX_DICT = meta_data[case_name]['GRID_IDX_DICT'] 9 | WELL_LIST = meta_data[case_name]['WELL_LIST'] 10 | LGR_LIST = meta_data[case_name]['LGR_LIST'] 11 | 12 | for name in [ 'BGSAT', 'BPR']: 13 | out = {} 14 | lname = f'L{name}' 15 | for t in nt: 16 | data = meta_data[case_name]['data'][t] 17 | output_dict = {} 18 | 19 | output_dict['GLOBAL'] = data[name].reshape((-1, GRID_IDX_DICT['GLOBAL']['NX'], 20 | GRID_IDX_DICT['GLOBAL']['NY'], 21 | GRID_IDX_DICT['GLOBAL']['NZ'])) 22 | N_LIST = [0] 23 | idx = 0 24 | for well in 
WELL_LIST: 25 | for lgr in LGR_LIST: 26 | n_prev = N_LIST[idx] 27 | idx += 1 28 | n_cur = n_prev+GRID_IDX_DICT[well][lgr]['NX'] * GRID_IDX_DICT[well][lgr]['NY'] * GRID_IDX_DICT[well][lgr]['NZ'] 29 | N_LIST.append(n_cur) 30 | 31 | if well in output_dict: 32 | output_dict[well].update({lgr: data[lname][:,n_prev: n_cur].reshape(-1, 33 | GRID_IDX_DICT[well][lgr]['NX'], 34 | GRID_IDX_DICT[well][lgr]['NY'], 35 | GRID_IDX_DICT[well][lgr]['NZ']) }) 36 | else: 37 | output_dict[well] = {lgr: data[lname][:,n_prev: n_cur].reshape(-1, 38 | GRID_IDX_DICT[well][lgr]['NX'], 39 | GRID_IDX_DICT[well][lgr]['NY'], 40 | GRID_IDX_DICT[well][lgr]['NZ']) } 41 | out[t] = output_dict 42 | OUT[name] = out 43 | 44 | out = {} 45 | for t in nt: 46 | output_dict = {} 47 | output_dict['GLOBAL'] = OUT['BPR'][t]['GLOBAL'] - OUT['BPR'][0]['GLOBAL'] 48 | 49 | for well in WELL_LIST: 50 | for lgr in LGR_LIST: 51 | if well in output_dict: 52 | output_dict[well].update({lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]}) 53 | else: 54 | output_dict[well] = {lgr: OUT['BPR'][t][well][lgr] - OUT['BPR'][0][well][lgr]} 55 | out[t] = output_dict 56 | OUT['dP'] = out 57 | 58 | out = {} 59 | for t in nt: 60 | output_dict = {} 61 | output_dict['GLOBAL'] = OUT['dP'][t]['GLOBAL'] > 0.1 62 | 63 | for well in WELL_LIST: 64 | for lgr in LGR_LIST: 65 | if well in output_dict: 66 | output_dict[well].update({lgr: OUT['dP'][t][well][lgr] > 0.1}) 67 | else: 68 | output_dict[well] = {lgr: OUT['dP'][t][well][lgr] > 0.1 } 69 | out[t] = output_dict 70 | OUT['P_influence'] = out 71 | return OUT 72 | 73 | def return_upsample_dict(OUTPUT_DICT, t, name, WELL_LIST, GRID_IDX_DICT): 74 | OUTPUT_UPSAMPLE_DICT = {} 75 | 76 | LGR_BEFORE = ['LGR3', 'LGR2', 'LGR1'] 77 | LGR_AFTER = ['LGR4', 'LGR3', 'LGR2'] 78 | 79 | for well in WELL_LIST: 80 | OUTPUT_UPSAMPLE_DICT[well] = {'LGR4': OUTPUT_DICT[name][t][well]['LGR4']} 81 | for iii in range(3): 82 | lgr_before = LGR_BEFORE[iii] 83 | lgr_after = LGR_AFTER[iii] 84 | 85 | upsampled = np.copy(OUTPUT_DICT[name][t][well][lgr_before][-1,:,:,:]) 86 | nx_new = GRID_IDX_DICT[well][lgr_after]['I2'] - GRID_IDX_DICT[well][lgr_after]['I1'] + 1 87 | ny_new = GRID_IDX_DICT[well][lgr_after]['J2'] - GRID_IDX_DICT[well][lgr_after]['J1'] + 1 88 | nz_new = GRID_IDX_DICT[well][lgr_after]['K2'] - GRID_IDX_DICT[well][lgr_after]['K1'] + 1 89 | 90 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well][lgr_after][-1,:,:,:])[None, None,...], 91 | size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 92 | 93 | upsampled[GRID_IDX_DICT[well][lgr_after]['I1']-1:GRID_IDX_DICT[well][lgr_after]['I2'], 94 | GRID_IDX_DICT[well][lgr_after]['J1']-1:GRID_IDX_DICT[well][lgr_after]['J2'],:] = A 95 | 96 | if well in OUTPUT_UPSAMPLE_DICT: 97 | OUTPUT_UPSAMPLE_DICT[well].update({lgr_before: upsampled[None,...]}) 98 | else: 99 | OUTPUT_UPSAMPLE_DICT[well]={lgr_before: upsampled[None,...]} 100 | 101 | upsampled = np.copy(OUTPUT_DICT[name][t]['GLOBAL'][-1,:,:,:]) 102 | for well in WELL_LIST: 103 | nx_new = GRID_IDX_DICT[well]['LGR1']['I2'] - GRID_IDX_DICT[well]['LGR1']['I1'] + 1 104 | ny_new = GRID_IDX_DICT[well]['LGR1']['J2'] - GRID_IDX_DICT[well]['LGR1']['J1'] + 1 105 | nz_new = GRID_IDX_DICT[well]['LGR1']['K2'] - GRID_IDX_DICT[well]['LGR1']['K1'] + 1 106 | A = F.interpolate(torch.from_numpy(OUTPUT_UPSAMPLE_DICT[well]['LGR1'][-1,:,:,:])[None, None,...], 107 | size=[nx_new,ny_new,nz_new], mode='trilinear', align_corners=False)[0,0,...].numpy() 108 | 
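# Paste the trilinearly interpolated LGR1 field back into the parent (GLOBAL) array over the
# window it refines; I1/I2 and J1/J2 appear to be 1-based grid indices (ECLIPSE convention),
# hence the -1 offset when converting to 0-based numpy slices below.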
upsampled[GRID_IDX_DICT[well]['LGR1']['I1']-1:GRID_IDX_DICT[well]['LGR1']['I2'], 109 | GRID_IDX_DICT[well]['LGR1']['J1']-1:GRID_IDX_DICT[well]['LGR1']['J2'],:] = A 110 | OUTPUT_UPSAMPLE_DICT['GLOBAL'] = upsampled 111 | return OUTPUT_UPSAMPLE_DICT 112 | 113 | def load_perm(file): 114 | with open(file,'r') as f: 115 | lines = f.readlines() 116 | perm = [] 117 | for line in lines[1:]: 118 | perm.append(float(line.split('*')[-1][:-2])) 119 | return np.array(perm) 120 | 121 | def tops_dict(parent_name, folder_name, grid_idx_dict, well_list, lgr_list): 122 | TOPS_DICT = {} 123 | nx = grid_idx_dict['GLOBAL']['NX'] 124 | ny = grid_idx_dict['GLOBAL']['NY'] 125 | nz = grid_idx_dict['GLOBAL']['NZ'] 126 | TOPS_DICT['GLOBAL'] = load_perm(f'../ECLIPSE/{parent_name}/{folder_name}/TOPS.IN').reshape(1, 127 | nx, 128 | ny, 129 | nz, 130 | order='F') 131 | for well in well_list: 132 | for lgr in lgr_list: 133 | nx = grid_idx_dict[well][lgr]['NX'] 134 | ny = grid_idx_dict[well][lgr]['NY'] 135 | nz = grid_idx_dict[well][lgr]['NZ'] 136 | if well in TOPS_DICT: 137 | TOPS_DICT[well].update({lgr: load_perm(f'../ECLIPSE/{parent_name}/{folder_name}/TOPS_{well}_{lgr}.IN').reshape(1,nx, ny, nz, order='F')}) 138 | else: 139 | TOPS_DICT[well] = {lgr: load_perm(f'../ECLIPSE/{parent_name}/{folder_name}/TOPS_{well}_{lgr}.IN').reshape(1,nx, ny, nz, order='F')} 140 | return TOPS_DICT 141 | 142 | 143 | 144 | def return_inj_map_dict(well_list,rate_dict,inj_loc_dict,center_dict, LGR_LIST): 145 | inj_norm = lambda x: (x)/(2942777.68785957) 146 | 147 | INJ_MAP_DICT = {} 148 | 149 | inj_map = np.zeros(center_dict['GLOBAL']['grid_x'].shape) 150 | for well in well_list: 151 | well_x, well_y = inj_loc_dict[well] 152 | xidx = (np.abs(center_dict['GLOBAL']['grid_x'][:,0,0] - well_x)).argmin() 153 | yidx = (np.abs(center_dict['GLOBAL']['grid_y'][0,:,0] - well_y)).argmin() 154 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 155 | INJ_MAP_DICT['GLOBAL'] = inj_map 156 | 157 | for well in well_list: 158 | well_x, well_y = inj_loc_dict[well] 159 | for lgr in LGR_LIST: 160 | inj_map = np.zeros(center_dict[well][lgr]['grid_x'].shape) 161 | xidx = (np.abs(center_dict[well][lgr]['grid_x'][:,0,0] - well_x)).argmin() 162 | yidx = (np.abs(center_dict[well][lgr]['grid_y'][0,:,0] - well_y)).argmin() 163 | inj_map[xidx, yidx, :] = inj_norm(rate_dict[well]) 164 | if well in INJ_MAP_DICT: 165 | INJ_MAP_DICT[well].update({lgr: inj_map}) 166 | else: 167 | INJ_MAP_DICT[well]={lgr: inj_map} 168 | 169 | return INJ_MAP_DICT 170 | 171 | 172 | def dict_convert_torch_to_numpy(torch_dict): 173 | WELL_LIST = list(torch_dict.keys()) 174 | WELL_LIST.remove('GLOBAL') 175 | 176 | numpy_dict = {} 177 | numpy_dict['GLOBAL'] = torch_dict['GLOBAL'][0,...,0].numpy() 178 | for well in WELL_LIST: 179 | d = {} 180 | for lgr in ['LGR1','LGR2','LGR3','LGR4']: 181 | d[lgr] = torch_dict[well][lgr][0,...,0].numpy() 182 | numpy_dict[well] = d 183 | return numpy_dict 184 | -------------------------------------------------------------------------------- /visulization_compare.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import matplotlib.patches as patches 5 | from matplotlib.patches import Rectangle 6 | 7 | def plot_x_slice(x,WELL_LIST, LGR_LIST, PCOLOR_GRID_DICT, GRID_IDX_DICT,OUTPUT_DICT,cmin,cmax, 8 | xmin=None,xmax=None,ymin=None,ymax=None, figsize=None, title=None, boundary_on=True, grid_width=0): 9 | xidx = int(x/1600) 10 | 
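# xidx maps the physical x coordinate (in metres) onto a GLOBAL cell index; the divisor 1600
# implies a 1,600 m global cell size in x. The pcolormesh below draws the y-z cross-section of
# the GLOBAL field at that index for the last report time (index -1), and the loop that follows
# appears to overlay the refined LGR patches whose x-extent contains this slice.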
plt.pcolormesh(PCOLOR_GRID_DICT['GLOBAL']['grid_y'][xidx,:,:], 11 | PCOLOR_GRID_DICT['GLOBAL']['grid_z'][xidx,:,:], 12 | OUTPUT_DICT['GLOBAL'][-1,xidx,:,:], 13 | shading='flat', edgecolor='k',linewidth=grid_width) 14 | plt.clim([cmin, cmax]) 15 | for well in WELL_LIST: 16 | for lgr in LGR_LIST: 17 | lgr_start = np.min(PCOLOR_GRID_DICT[well][lgr]['grid_x']) 18 | lgr_end = np.max(PCOLOR_GRID_DICT[well][lgr]['grid_x']) 19 | if (x>lgr_start) and (x=lgr_start) and (x1 183 | plt.plot(GRID_DICT[well][lgr]['grid_y'][xidx,:,z][active], 184 | OUTPUT_DICT[well][lgr][-1,xidx,:,z][active], 185 | '.',label=well+lgr) 186 | plt.title(f'x = {x} m') 187 | plt.xlim([xmin, xmax]) 188 | plt.ylim([190,220]) 189 | plt.legend() 190 | plt.show() 191 | 192 | def plot_zglobal_slice(z,WELL_LIST, LGR_LIST, GRID_IDX_DICT,OUTPUT_DICT,cmin,cmax, 193 | xmin=None,xmax=None,ymin=None,ymax=None,grid_on=False): 194 | plt.figure(figsize=(10,9)) 195 | 196 | PCOLOR_GRID_DICT = pcolor_grid_dict(WELL_LIST, GRID_IDX_DICT) 197 | if grid_on is True: 198 | plt.pcolormesh(PCOLOR_GRID_DICT['GLOBAL']['grid_x'][:,:,z], 199 | PCOLOR_GRID_DICT['GLOBAL']['grid_y'][:,:,z], 200 | OUTPUT_DICT['GLOBAL'][-1,:,:,z], 201 | shading='flat', edgecolor='k') 202 | else: 203 | plt.pcolormesh(PCOLOR_GRID_DICT['GLOBAL']['grid_x'][:,:,z], 204 | PCOLOR_GRID_DICT['GLOBAL']['grid_y'][:,:,z], 205 | OUTPUT_DICT['GLOBAL'][-1,:,:,z], 206 | shading='flat') 207 | 208 | plt.clim([cmin, cmax]) 209 | if xmin is not None: 210 | plt.xlim([xmin, xmax]) 211 | plt.ylim([ymin, ymax]) 212 | 213 | plt.colorbar(fraction=0.01) 214 | 215 | def plot_y_slice(y,WELL_LIST, LGR_LIST, GRID_IDX_DICT,OUTPUT_DICT,cmin,cmax, 216 | xmin=None,xmax=None,ymin=None,ymax=None, title=None): 217 | yidx = int(y/1000) 218 | print(yidx) 219 | PCOLOR_GRID_DICT = pcolor_grid_dict(WELL_LIST, GRID_IDX_DICT) 220 | plt.pcolormesh(PCOLOR_GRID_DICT['GLOBAL']['grid_x'][:,yidx,:], 221 | PCOLOR_GRID_DICT['GLOBAL']['grid_z'][:,yidx,:], 222 | OUTPUT_DICT['GLOBAL'][-1,:,yidx,:], 223 | shading='flat') 224 | plt.clim([cmin, cmax]) 225 | 226 | for well in WELL_LIST: 227 | for lgr in LGR_LIST: 228 | lgr_start = np.min(PCOLOR_GRID_DICT[well][lgr]['grid_y']) 229 | lgr_end = np.max(PCOLOR_GRID_DICT[well][lgr]['grid_y']) 230 | if (y>lgr_start) and (y