├── .idea
│   ├── DeepBSDE-pytorch.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── AllenCahn_training_history.csv
├── README.md
├── config.py
├── equation.py
├── solver.py
└── train.py
/AllenCahn_training_history.csv:
--------------------------------------------------------------------------------
1 | step,loss_function,target_value,elapsed_time
2 | 0,5.39418e-02,3.26080e-01,0
3 | 100,3.44411e-02,2.78099e-01,13
4 | 200,2.31878e-02,2.35126e-01,26
5 | 300,1.58378e-02,1.97588e-01,39
6 | 400,9.66999e-03,1.65539e-01,52
7 | 500,6.76427e-03,1.38360e-01,65
8 | 600,6.04422e-03,1.15831e-01,78
9 | 700,4.32909e-03,9.79632e-02,92
10 | 800,4.50971e-03,8.42173e-02,105
11 | 900,3.48606e-03,7.42119e-02,118
12 | 1000,3.63141e-03,6.66798e-02,131
13 | 1100,3.00041e-03,6.09193e-02,144
14 | 1200,3.87503e-03,5.73990e-02,157
15 | 1300,2.88035e-03,5.50455e-02,170
16 | 1400,3.58197e-03,5.41830e-02,183
17 | 1500,3.43098e-03,5.44178e-02,196
18 | 1600,3.31472e-03,5.32332e-02,209
19 | 1700,3.04519e-03,5.22913e-02,223
20 | 1800,3.61448e-03,5.29079e-02,236
21 | 1900,4.06360e-03,5.27411e-02,249
22 | 2000,2.80823e-03,5.29730e-02,262
23 | 2100,3.39954e-03,5.13978e-02,275
24 | 2200,3.80409e-03,5.22298e-02,288
25 | 2300,4.08157e-03,5.20794e-02,301
26 | 2400,3.93781e-03,5.21230e-02,315
27 | 2500,3.23211e-03,5.20974e-02,328
28 | 2600,3.28405e-03,5.37664e-02,342
29 | 2700,3.89592e-03,5.27176e-02,355
30 | 2800,3.16535e-03,5.19796e-02,368
31 | 2900,3.55510e-03,5.14435e-02,382
32 | 3000,3.41030e-03,5.23272e-02,395
33 | 3100,3.02276e-03,5.29629e-02,409
34 | 3200,4.09172e-03,5.16499e-02,422
35 | 3300,3.32843e-03,5.28555e-02,435
36 | 3400,3.21832e-03,5.10399e-02,448
37 | 3500,3.13571e-03,5.17579e-02,461
38 | 3600,3.32771e-03,5.34426e-02,475
39 | 3700,3.44333e-03,5.26195e-02,488
40 | 3800,3.02391e-03,5.36960e-02,502
41 | 3900,3.37896e-03,5.22397e-02,515
42 | 4000,3.35890e-03,5.15655e-02,528
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A PyTorch reproduction of "Solving high-dimensional partial differential equations using deep learning" (Han, Jentzen, and E, PNAS 2018).
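2 | 
3 | `train.py` trains the deep BSDE solver on the 100-dimensional Allen-Cahn equation.
4 | `config.py` also lists configurations for the paper's other examples, but `equation.py`
5 | currently implements only `AllenCahn`. The included `AllenCahn_training_history.csv`
6 | shows the estimated Y0 settling at roughly 0.052.
7 | 
8 | A minimal way to run the example (a sketch; it assumes a CUDA-capable GPU, and note
9 | that `train.py` sets `CUDA_VISIBLE_DEVICES="1"`, i.e. the second GPU):
10 | 
11 | ```python
12 | from config import get_config
13 | from equation import get_equation
14 | from train import train
15 | 
16 | cfg = get_config('AllenCahn')
17 | bsde = get_equation('AllenCahn', cfg.dim, cfg.total_time, cfg.num_time_interval)
18 | train(cfg, bsde)  # writes AllenCahn_training_history.csv
19 | ```
20 | 
21 | Running `python train.py` from the repository root does the same.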
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Config(object):
5 | n_layer = 4
6 | batch_size = 64
7 | valid_size = 256
8 | step_boundaries = [2000, 4000]
9 | num_iterations = 6000
10 | logging_frequency = 100
11 | verbose = True
12 | y_init_range = [0, 1]
13 |
14 |
15 | class AllenCahnConfig(Config):
16 | total_time = 0.3
17 | num_time_interval = 20
18 | dim = 100
19 | lr_values = list(np.array([5e-4, 5e-4]))
20 | lr_boundaries = [2000]
21 | num_iterations = 4000
22 | num_hiddens = [dim, dim + 10, dim + 10, dim]
23 | y_init_range = [0.3, 0.6]
24 |
25 |
26 | class HJBConfig(Config):
27 | # Y_0 is about 4.5901.
28 | dim = 100
29 | total_time = 1.0
30 | num_time_interval = 20
31 | lr_boundaries = [400]
32 | num_iterations = 2000
33 | lr_values = list(np.array([1e-2, 1e-2]))
34 | num_hiddens = [dim, dim+10, dim+10, dim]
35 | y_init_range = [0, 1]
36 |
37 |
38 | class PricingOptionConfig(Config):
39 | dim = 100
40 | total_time = 0.5
41 | num_time_interval = 20
42 | lr_values = list(np.array([5e-3, 5e-3]))
43 | lr_boundaries = [2000]
44 | num_iterations = 4000
45 | num_hiddens = [dim, dim+10, dim+10, dim]
46 | y_init_range = [15, 18]
47 |
48 |
49 | class PricingDefaultRiskConfig(Config):
50 | dim = 100
51 | total_time = 1
52 | num_time_interval = 40
53 | lr_values = list(np.array([8e-3, 8e-3]))
54 | lr_boundaries = [3000]
55 | num_iterations = 6000
56 | num_hiddens = [dim, dim+10, dim+10, dim]
57 | y_init_range = [40, 50]
58 |
59 |
60 | class BurgesTypeConfig(Config):
61 | dim = 50
62 | total_time = 0.2
63 | num_time_interval = 30
64 | lr_values = list(np.array([1e-2, 1e-3, 1e-4]))
65 | lr_boundaries = [15000, 25000]
66 | num_iterations = 30000
67 | num_hiddens = [dim, dim+10, dim+10, dim]
68 | y_init_range = [2, 4]
69 |
70 |
71 | class QuadraticGradientsConfig(Config):
72 | dim = 100
73 | total_time = 1.0
74 | num_time_interval = 30
75 | lr_values = list(np.array([5e-3, 5e-3]))
76 | lr_boundaries = [2000]
77 | num_iterations = 4000
78 | num_hiddens = [dim, dim+10, dim+10, dim]
79 | y_init_range = [2, 4]
80 |
81 |
82 | class ReactionDiffusionConfig(Config):
83 | dim = 100
84 | total_time = 1.0
85 | num_time_interval = 30
86 | lr_values = list(np.array([1e-2, 1e-2, 1e-2]))
87 | lr_boundaries = [8000, 16000]
88 | num_iterations = 24000
89 | num_hiddens = [dim, dim+10, dim+10, dim]
90 |
91 |
92 | def get_config(name):
93 | try:
94 | return globals()[name+'Config']
95 | except KeyError:
96 | raise KeyError("Config for the required problem not found.")
97 |
--------------------------------------------------------------------------------
/equation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.stats import multivariate_normal as normal
4 |
5 |
6 | class Equation(object):
7 | """Base class for defining PDE related function."""
8 |
9 | def __init__(self, dim, total_time, num_time_interval):
10 | self._dim = dim
11 | self._total_time = total_time
12 | self._num_time_interval = num_time_interval
13 | self._delta_t = (self._total_time + 0.0) / self._num_time_interval
14 | self._sqrt_delta_t = np.sqrt(self._delta_t)
15 | self._y_init = None
16 |
17 | def sample(self, num_sample):
18 | """Sample forward SDE."""
19 | raise NotImplementedError
20 |
21 | def f_th(self, t, x, y, z):
22 | """Generator function in the PDE."""
23 | raise NotImplementedError
24 |
25 | def g_th(self, t, x):
26 | """Terminal condition of the PDE."""
27 | raise NotImplementedError
28 |
29 | @property
30 | def y_init(self):
31 | return self._y_init
32 |
33 | @property
34 | def dim(self):
35 | return self._dim
36 |
37 | @property
38 | def num_time_interval(self):
39 | return self._num_time_interval
40 |
41 | @property
42 | def total_time(self):
43 | return self._total_time
44 |
45 | @property
46 | def delta_t(self):
47 | return self._delta_t
48 |
49 |
50 | def get_equation(name, dim, total_time, num_time_interval):
51 | try:
52 | return globals()[name](dim, total_time, num_time_interval)
53 | except KeyError:
54 | raise KeyError("Equation for the required problem not found.")
55 |
56 |
57 | class AllenCahn(Equation):
58 | def __init__(self, dim, total_time, num_time_interval):
59 | super(AllenCahn, self).__init__(dim, total_time, num_time_interval)
60 | self._x_init = np.zeros(self._dim)
61 | self._sigma = np.sqrt(2.0)
62 |
63 | def sample(self, num_sample):
64 | dw_sample = normal.rvs(size=[num_sample,
65 | self._dim,
66 | self._num_time_interval]) * self._sqrt_delta_t
67 | x_sample = np.zeros([num_sample, self._dim, self._num_time_interval + 1])
68 | x_sample[:, :, 0] = np.ones([num_sample, self._dim]) * self._x_init
69 | for i in range(self._num_time_interval):
70 | x_sample[:, :, i + 1] = x_sample[:, :, i] + self._sigma * dw_sample[:, :, i]
71 | return torch.FloatTensor(dw_sample), torch.FloatTensor(x_sample)
72 |
73 | def f_th(self, t, x, y, z):
74 | return y - torch.pow(y, 3)
75 |
76 | def g_th(self, t, x):
77 | return 0.5 / (1 + 0.2 * torch.sum(x**2, dim=1, keepdim=True))
78 |
79 |
80 |
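81 | # Illustrative sketch (not part of the original code): config.py also defines
82 | # HJBConfig, but no corresponding equation is implemented here. Following the
83 | # reference TensorFlow implementation, an HJB equation could look roughly like
84 | # the class below; the constants and the generator should be double-checked
85 | # against the paper before use.
86 | class HJB(Equation):
87 |     def __init__(self, dim, total_time, num_time_interval):
88 |         super(HJB, self).__init__(dim, total_time, num_time_interval)
89 |         self._x_init = np.zeros(self._dim)
90 |         self._sigma = np.sqrt(2.0)
91 |         self._lambda = 1.0
92 | 
93 |     def sample(self, num_sample):
94 |         dw_sample = normal.rvs(size=[num_sample,
95 |                                      self._dim,
96 |                                      self._num_time_interval]) * self._sqrt_delta_t
97 |         x_sample = np.zeros([num_sample, self._dim, self._num_time_interval + 1])
98 |         x_sample[:, :, 0] = np.ones([num_sample, self._dim]) * self._x_init
99 |         for i in range(self._num_time_interval):
100 |             x_sample[:, :, i + 1] = x_sample[:, :, i] + self._sigma * dw_sample[:, :, i]
101 |         return torch.FloatTensor(dw_sample), torch.FloatTensor(x_sample)
102 | 
103 |     def f_th(self, t, x, y, z):
104 |         return -self._lambda * torch.sum(z ** 2, dim=1, keepdim=True)
105 | 
106 |     def g_th(self, t, x):
107 |         return torch.log((1 + torch.sum(x ** 2, dim=1, keepdim=True)) / 2)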
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import Parameter
6 |
7 | TH_DTYPE = torch.float32
8 |
9 | MOMENTUM = 0.99  # note: PyTorch BatchNorm momentum weights the *new* batch statistics (a TF-style decay of 0.99 corresponds to momentum=0.01 here)
10 | EPSILON = 1e-6
11 | DELTA_CLIP = 50.0
12 |
13 | class Dense(nn.Module):
14 |
15 |     def __init__(self, cin, cout, batch_norm=True, activate=True):
16 |         super(Dense, self).__init__()
17 |         self.cout = cout
18 |         self.linear = nn.Linear(cin, cout)
19 |         self.activate = activate
20 |         if batch_norm:
21 |             self.bn = nn.BatchNorm1d(cout, eps=EPSILON, momentum=MOMENTUM)
22 |         else:
23 |             self.bn = None
24 |         nn.init.normal_(self.linear.weight, std=5.0 / np.sqrt(cin + cout))
25 |
26 | def forward(self,x):
27 | x = self.linear(x)
28 | if self.bn is not None:
29 | x = self.bn(x)
30 | if self.activate:
31 | x = F.relu(x)
32 | return x
33 |
34 |
35 | class Subnetwork(nn.Module):
36 |
37 | def __init__(self, config):
38 | super(Subnetwork, self).__init__()
39 | self._config = config
40 |         self.bn = nn.BatchNorm1d(config.dim, eps=EPSILON, momentum=MOMENTUM)
41 | self.layers = [Dense(config.num_hiddens[i-1], config.num_hiddens[i]) for i in range(1, len(config.num_hiddens)-1)]
42 | self.layers += [Dense(config.num_hiddens[-2], config.num_hiddens[-1], activate=False)]
43 | self.layers = nn.Sequential(*self.layers)
44 |
45 | def forward(self,x):
46 | x = self.bn(x)
47 | x = self.layers(x)
48 | return x
49 |
50 |
51 |
52 | class FeedForwardModel(nn.Module):
53 | """The fully connected neural network model."""
54 | def __init__(self, config, bsde):
55 | super(FeedForwardModel, self).__init__()
56 | self._config = config
57 | self._bsde = bsde
58 |
59 | # make sure consistent with FBSDE equation
60 | self._dim = bsde.dim
61 | self._num_time_interval = bsde.num_time_interval
62 | self._total_time = bsde.total_time
63 |
64 | self._y_init = Parameter(torch.Tensor([1]))
65 | self._y_init.data.uniform_(self._config.y_init_range[0], self._config.y_init_range[1])
66 |         self._subnetworkList = nn.ModuleList([Subnetwork(config) for _ in range(self._num_time_interval - 1)])
67 |
68 |
69 | def forward(self, x, dw):
70 |
71 | time_stamp = np.arange(0, self._bsde.num_time_interval) * self._bsde.delta_t
72 |
73 |         # initial Z: re-drawn uniformly on every forward pass (not a learnable parameter)
74 |         z_init = torch.zeros([1, self._dim], dtype=TH_DTYPE, device=dw.device).uniform_(-0.1, 0.1)
75 |         all_one_vec = torch.ones((dw.shape[0], 1), dtype=TH_DTYPE, device=dw.device)
76 | y = all_one_vec * self._y_init
77 |
78 | z = torch.matmul(all_one_vec, z_init)
79 |
80 |         for t in range(0, self._num_time_interval - 1):
81 |             # Euler step of the discretized BSDE:
82 |             # y_{t+1} = y_t - f(t, x_t, y_t, z_t) * dt + z_t · dW_t
83 |             y = y - self._bsde.delta_t * (
84 |                 self._bsde.f_th(time_stamp[t], x[:, :, t], y, z)
85 |             )
86 |             add = torch.sum(z * dw[:, :, t], dim=1, keepdim=True)
87 |             y = y + add
88 |             # Z at the next time point is approximated by a per-step subnetwork (scaled by 1/dim)
89 |             z = self._subnetworkList[t](x[:, :, t + 1]) / self._dim
90 | 
91 |         # terminal time step (reuses the last subnetwork output for z)
92 |         y = y - self._bsde.delta_t * self._bsde.f_th(
93 |             time_stamp[-1], x[:, :, -2], y, z
94 |         ) + torch.sum(z * dw[:, :, -1], dim=1, keepdim=True)
95 |
96 | delta = y - self._bsde.g_th(self._total_time, x[:, :, -1])
97 |
98 | # use linear approximation outside the clipped range
99 | loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, delta**2,
100 | 2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))
101 | return loss, self._y_init
102 |
103 |
104 |
105 |
106 |
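107 | # Quick smoke test (illustrative sketch, not part of the original code):
108 | #
109 | #     from config import AllenCahnConfig
110 | #     from equation import AllenCahn
111 | #     cfg = AllenCahnConfig
112 | #     bsde = AllenCahn(cfg.dim, cfg.total_time, cfg.num_time_interval)
113 | #     model = FeedForwardModel(cfg, bsde)
114 | #     dw, x = bsde.sample(8)
115 | #     loss, y0 = model(x, dw)  # scalar loss tensor and the current Y0 estimate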
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from solver import FeedForwardModel
3 | import logging
4 | import torch.optim as optim
5 | import numpy as np
6 | import os
7 | import torch
8 | os.environ["CUDA_VISIBLE_DEVICES"]="1"
9 | torch.backends.cudnn.benchmark=True
10 |
11 | def train(config, bsde):
12 | logging.basicConfig(level=logging.INFO,
13 | format='%(levelname)-6s %(message)s')
14 | if bsde.y_init:
15 | logging.info('Y0_true: %.4e' % bsde.y_init)
16 |
17 | # build and train
18 | net = FeedForwardModel(config,bsde)
19 | net.cuda()
20 |
21 |     optimizer = optim.SGD(net.parameters(), lr=5e-4)  # fixed learning rate; config.lr_values / lr_boundaries are not used here
22 | start_time = time.time()
23 | # to save iteration results
24 | training_history = []
25 | # for validation
26 | dw_valid, x_valid = bsde.sample(config.valid_size)
27 |
28 | # begin sgd iteration
29 | for step in range(config.num_iterations + 1):
30 | if step % config.logging_frequency == 0:
31 |             net.eval()
32 |             with torch.no_grad():
33 |                 loss, init = net(x_valid.cuda(), dw_valid.cuda())
34 | elapsed_time = time.time() - start_time
35 |             training_history.append([step, loss.item(), init.item(), elapsed_time])
36 | if config.verbose:
37 | logging.info("step: %5u, loss: %.4e, Y0: %.4e, elapsed time %3u" % (
38 |                     step, loss.item(), init.item(), elapsed_time))
39 |
40 | dw_train, x_train = bsde.sample(config.batch_size)
41 | optimizer.zero_grad()
42 | net.train()
43 | loss, _ = net(x_train.cuda(), dw_train.cuda())
44 | loss.backward()
45 |
46 | optimizer.step()
47 |
48 |     training_history = np.array(training_history)
49 |
50 | if bsde.y_init:
51 | logging.info('relative error of Y0: %s',
52 | '{:.2%}'.format(
53 | abs(bsde.y_init - training_history[-1, 2]) / bsde.y_init))
54 |
55 |
56 | np.savetxt('{}_training_history.csv'.format(bsde.__class__.__name__),
57 | training_history,
58 | fmt=['%d', '%.5e', '%.5e', '%d'],
59 | delimiter=",",
60 | header="step,loss_function,target_value,elapsed_time",
61 | comments='')
62 |
63 | if __name__ == '__main__':
64 |     from config import get_config
65 |     from equation import get_equation
66 |     cfg = get_config('AllenCahn')
67 |     bsde = get_equation('AllenCahn', cfg.dim, cfg.total_time, cfg.num_time_interval)
68 |     train(cfg, bsde)
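69 | 
70 | # Note (illustrative, not wired in above): config.lr_values / config.lr_boundaries
71 | # are ignored by this script, which uses a fixed learning rate. A piecewise-constant
72 | # schedule in the spirit of the original TensorFlow code could be applied per step:
73 | #
74 | #     def piecewise_lr(step, boundaries, values):
75 | #         # `values` has one more entry than `boundaries` (TF piecewise_constant style)
76 | #         for b, v in zip(boundaries, values):
77 | #             if step < b:
78 | #                 return v
79 | #         return values[-1]
80 | #
81 | #     for g in optimizer.param_groups:
82 | #         g['lr'] = piecewise_lr(step, config.lr_boundaries, config.lr_values)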
--------------------------------------------------------------------------------