├── FNN.py
├── data
│   ├── T_test1.txt
│   ├── T_train1.txt
│   ├── u_test1.txt
│   ├── u_train1.txt
│   └── x_grid.txt
├── deeponet.py
├── train.py
└── utils.py

/FNN.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from torch import nn


class FNN(nn.Module):
    """Fully-connected neural network."""

    def __init__(self, layer_sizes):
        super(FNN, self).__init__()

        # Use nn.ModuleList (not a plain Python list) so the layers are
        # registered as submodules and their parameters show up in
        # model.parameters() for the optimizer.
        self.denses = nn.ModuleList()
        for i in range(1, len(layer_sizes) - 1):
            self.denses.append(
                nn.Linear(in_features=layer_sizes[i - 1], out_features=layer_sizes[i])
            )
            self.denses.append(nn.ReLU())
        # Final linear layer without activation
        self.denses.append(
            nn.Linear(in_features=layer_sizes[-2], out_features=layer_sizes[-1])
        )

    def forward(self, inputs):
        y = inputs
        for f in self.denses:
            y = f(y)
        return y
--------------------------------------------------------------------------------
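A quick smoke test of FNN (not part of the repo; the layer sizes here are arbitrary): a list [in, hidden..., out] yields Linear+ReLU pairs for the hidden layers and a final Linear with no activation.

    import torch
    from FNN import FNN

    net = FNN([201, 128, 128, 100])   # 201 inputs -> two ReLU hidden layers -> 100 outputs
    y = net(torch.randn(32, 201))     # batch of 32 input vectors
    assert y.shape == (32, 100)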
/data/x_grid.txt:
--------------------------------------------------------------------------------
[201 grid points on [0, 1], one per line: 0.0 to 0.9 in steps of 0.009
(101 points), then 0.901 to 1.0 in steps of 0.001 (100 points); full
listing omitted.]
--------------------------------------------------------------------------------
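The grid reads as a coarse uniform grid on [0, 0.9] followed by a fine one up to 1.0; assuming that reading, it can be regenerated with:

    import numpy as np

    # 101 points with spacing 0.009, then 100 points with spacing 0.001 (201 total)
    x_grid = np.concatenate([np.linspace(0.0, 0.9, 101), np.linspace(0.901, 1.0, 100)])
    np.savetxt("x_grid.txt", x_grid)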
/deeponet.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn
from FNN import FNN


class DeepONet(nn.Module):

    # Initialize the class
    def __init__(self, layer_size_branch, layer_size_trunk):
        super(DeepONet, self).__init__()

        # initialize parameters and configuration
        self.layer_size_branch = layer_size_branch
        self.layer_size_trunk = layer_size_trunk

        self.loss_fun = self.MSE

        # initialize layers
        self.branch_net = FNN(self.layer_size_branch)
        if callable(self.layer_size_trunk[1]):
            # User-defined trunk net
            self.trunk_net = self.layer_size_trunk[1]
        else:
            self.trunk_net = FNN(self.layer_size_trunk)
        # nn.Parameter registers the trainable output bias with the module
        self.bias_last = nn.Parameter(torch.zeros(1))

    def forward(self, x_branch, x_trunk):
        # Branch net to encode the input function
        y_branch = self.branch_net(x_branch)
        # Trunk net to encode the domain of the output function
        y_trunk = self.trunk_net(x_trunk)
        # Dot product over the shared latent dimension
        if y_branch.shape[-1] != y_trunk.shape[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        Y = torch.einsum("bi,ni->bn", y_branch, y_trunk)
        # Add bias
        Y += self.bias_last
        return Y

    # mean squared error
    def MSE(self, y_true, y_pred):
        return torch.mean(torch.square(y_true - y_pred))

    # max L^infinity error over the test data set
    def Max_Linfty_Error(self, y_true, y_pred):
        return torch.max(torch.abs(y_true - y_pred))

    # mean L^infinity error over the test data set
    def Mean_Linfty_Error(self, y_true, y_pred):
        # torch.max with dim returns (values, indices); take the values
        return torch.mean(
            torch.max(torch.abs(y_true - y_pred), dim=1).values
        )

    # relative L2 error
    def relative_L2_Error(self, y_true, y_pred):
        return torch.mean(
            torch.norm(y_true - y_pred, dim=1) / torch.norm(y_true, dim=1)
        )

    def get_loss(self, identifier):
        loss_identifier = {
            "mean squared error": self.MSE,
            "MSE": self.MSE,
            "mse": self.MSE,
        }
        if isinstance(identifier, str):
            return loss_identifier[identifier]
        elif callable(identifier):
            return identifier
        else:
            raise ValueError(
                "Could not interpret loss function identifier:", identifier
            )
--------------------------------------------------------------------------------
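For orientation, a minimal sketch of the DeepONet forward pass (layer sizes are illustrative assumptions; the only hard requirement is that branch and trunk nets end in the same width p):

    import torch
    from deeponet import DeepONet

    p = 100
    model = DeepONet(layer_size_branch=[201, 128, p], layer_size_trunk=[1, 128, p])
    u = torch.randn(32, 201)  # 32 input functions, each sampled at 201 sensors
    x = torch.rand(201, 1)    # 201 query points in [0, 1]
    s = model(u, x)           # einsum("bi,ni->bn"): shape (32, 201)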
/train.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils import *


class Triple(object):
    def __init__(
        self, U_train, X_train, S_train, U_val, X_val, S_val,
    ):
        self.check_shape(U_train, X_train, S_train)
        self.check_shape(U_val, X_val, S_val)
        assert U_train.shape[1] == U_val.shape[1]
        assert X_train.shape[1] == X_val.shape[1]

        self.U_train = torch.from_numpy(U_train.astype(np.float32))
        self.X_train = torch.from_numpy(X_train.astype(np.float32))
        self.S_train = torch.from_numpy(S_train.astype(np.float32))
        self.U_val = torch.from_numpy(U_val.astype(np.float32))
        self.X_val = torch.from_numpy(X_val.astype(np.float32))
        self.S_val = torch.from_numpy(S_val.astype(np.float32))

    def print_shape(self):
        print("\n========================= Data =========================")
        print("Train: Branch input shape (#u_train, m): " + str(self.U_train.shape))
        print("       Trunk input shape (s, dim_x):     " + str(self.X_train.shape))
        print("       Output shape (#u_train, s):       " + str(self.S_train.shape))
        print("Test:  Branch input shape (#u_test, m):  " + str(self.U_val.shape))
        print("       Trunk input shape (s, dim_x):     " + str(self.X_val.shape))
        print("       Output shape (#u_test, s):        " + str(self.S_val.shape))
        print("========================================================\n")

    def check_shape(self, U, X, S):
        # branch inputs and outputs must pair up; trunk points must match output width
        assert U.shape[0] == S.shape[0]
        assert X.shape[0] == S.shape[1]

    @property
    def num_funcs_train(self):
        return self.U_train.shape[0]

    @property
    def num_funcs_val(self):
        return self.U_val.shape[0]

    @property
    def num_sensors(self):
        return self.U_train.shape[1]


class Train(object):
    def __init__(self, model_path=None, device="cpu"):
        self.model_path = model_path
        self.device = device
        self.train_log = []
        self.trainloss_best = {"epoch": 0, "loss": 1e5}
        self.valloss_best = {"epoch": 0, "loss": 1e5}

    def visualize_loss(self, save=False):
        epoch = np.array([d["epoch"] for d in self.train_log if "epoch" in d])
        train_loss_log = np.array(
            [d["train_loss"] for d in self.train_log if "train_loss" in d]
        )
        val_loss_log = np.array(
            [d["val_loss"] for d in self.train_log if "val_loss" in d]
        )
        fig, axes = plt.subplots(1, 1, figsize=(6, 4))
        axes.plot(epoch, train_loss_log, label="train loss")
        axes.plot(epoch, val_loss_log, label="test loss")
        axes.legend()
        axes.set_xlabel("epochs")
        axes.set_yscale("log")
        axes.tick_params(labelsize=8)
        if save:
            # output filename is a placeholder, not specified in the original
            fig.savefig("loss_history.png", bbox_inches="tight")


class Train_Adam(Train):
    def __init__(
        self, model, batch_size, learning_rate=1e-3, model_path=None, device="cpu",
    ):
        super(Train_Adam, self).__init__(model_path, device)
        self.batch_size = batch_size
        # the model is passed in so its parameters can be handed to Adam
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def train_one_batch_adam(self, model, U_train_batch, X_train, S_train_batch):
        self.optimizer.zero_grad()
        U_train_batch, X_train, S_train_batch = (
            U_train_batch.to(self.device),
            X_train.to(self.device),
            S_train_batch.to(self.device),
        )
        preds = model(U_train_batch, X_train)
        loss = model.loss_fun(S_train_batch, preds)
        loss.backward()
        self.optimizer.step()

    def train_adam(self, model, data, num_epochs):
        # Early stopping on validation MSE, evaluated every 1000 epochs;
        # patience counts evaluations without improvement, not epochs.
        patience = 20
        wait = 0
        best = 1e5
        for epoch in range(1, num_epochs + 1):
            model.train()
            perm = np.random.permutation(data.num_funcs_train)
            for it in range(0, data.num_funcs_train, self.batch_size):
                if it + self.batch_size < data.num_funcs_train:
                    idx = perm[np.arange(it, it + self.batch_size)]
                else:
                    idx = perm[np.arange(it, data.num_funcs_train)]
                self.train_one_batch_adam(
                    model, data.U_train[idx, :], data.X_train, data.S_train[idx, :]
                )
            if epoch == 1 or epoch % 1000 == 0:
print("--------------------------------------------------------------") 115 | print("Epoch %d:" % (epoch)) 116 | preds_train = model(data.U_train[idx, :], data.X_train) 117 | train_loss = model.loss_fun(data.S_train[idx, :], preds_train).detach().numpy() 118 | preds_val = model(data.U_val, data.X_val) 119 | val_loss = model.loss_fun(data.S_val, preds_val).detach().numpy() 120 | 121 | print("train loss: %.3e, test MSE: %.3e" % (train_loss, val_loss)) 122 | self.train_log.append( 123 | {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss} 124 | ) 125 | if train_loss < self.trainloss_best["loss"]: 126 | self.trainloss_best["epoch"] = epoch 127 | self.trainloss_best["loss"] = train_loss 128 | if val_loss < self.valloss_best["loss"]: 129 | self.valloss_best["epoch"] = epoch 130 | self.valloss_best["loss"] = val_loss 131 | if self.model_path is not None: 132 | pass 133 | wait += 1 134 | if val_loss < best: 135 | best = val_loss 136 | wait = 0 137 | if wait >= patience: 138 | print("Epoch %d: ... Early Stopping ..." % (epoch)) 139 | exit_flag = True 140 | break 141 | if exit_flag: 142 | break -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | class Normalizer: 5 | def __init__(self, data, eps=1e-8): 6 | self.mean = np.mean(data, axis=0) 7 | self.std = np.std(data, axis=0) 8 | self.eps = eps 9 | 10 | def encode(self, data): 11 | data = (data - self.mean) / (self.std + self.eps) 12 | return data 13 | 14 | def decode(self, data): 15 | data = data * (self.std + self.eps) + self.mean 16 | return data 17 | 18 | 19 | def get_errors(model, y_true, y_pred): 20 | return ( 21 | model.MSE(y_true, y_pred), 22 | model.relative_L2_Error(y_true, y_pred), 23 | model.Mean_Linfty_Error(y_true, y_pred), 24 | model.Max_Linfty_Error(y_true, y_pred), 25 | ) 26 | 27 | def make_triple(ufname, xfname, Sfname): 28 | branch_data = np.loadtxt(ufname) 29 | x_grid = np.loadtxt(xfname) 30 | trunk_data = x_grid.reshape(-1, 1) 31 | output_data = np.loadtxt(Sfname) 32 | return branch_data, trunk_data, output_data 33 | 34 | def test(model, data): 35 | model.eval() 36 | preds = model(data.U_val, data.X_val) 37 | # errors of function values 38 | test_mse, test_L2Error, test_mean_LinfError, test_max_LinfError = get_errors( 39 | model, data.S_val, preds 40 | ) 41 | return test_mse, test_L2Error, test_mean_LinfError, test_max_LinfError 42 | 43 | 44 | def Plot(model, x_branch, x_trunk, output, fprefix): 45 | preds = model.forward(x_branch, x_trunk) 46 | preds = preds.detach().numpy() 47 | output = output.detach().numpy() 48 | fig1, axes1 = plt.subplots( 49 | len(x_branch), 2, squeeze=False, figsize=(12, 4 * len(x_branch)) 50 | ) 51 | for i in range(len(x_branch)): 52 | axes1[i][0].plot(x_trunk, preds[i], label="DeepONet") 53 | axes1[i][0].plot(x_trunk, output[i], label="reference") 54 | axes1[i][0].legend() 55 | axes1[i][0].set_xlabel("$x$") 56 | axes1[i][0].set_ylabel("$u(x,T=1)$") 57 | axes1[i][0].tick_params(labelsize=7) 58 | axes1[i][1].plot( 59 | x_trunk, 60 | preds[i] - output[i], 61 | label="error", 62 | ) 63 | axes1[i][1].legend() 64 | axes1[i][1].set_xlabel("$x$") 65 | axes1[i][1].set_ylabel("error") 66 | axes1[i][1].tick_params(labelsize=7) 67 | fig1.savefig(fprefix + "Plots.png",bbox_inches="tight") 68 | # plt.show() 69 | 70 | --------------------------------------------------------------------------------