├── mfnn
│   ├── __init__.py
│   ├── xydata.py
│   ├── utils.py
│   ├── mfnn.py
│   └── trainer.py
├── readme.org
├── example2.py
├── example1.py
└── snapshots
    ├── high.svg
    ├── data.svg
    ├── low.svg
    └── mfnn.svg
/mfnn/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from .mfnn import *
4 | from .xydata import *
5 | from .trainer import *
6 |
7 | from .utils import copy_to, DEVICE
8 |
--------------------------------------------------------------------------------
/mfnn/xydata.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from __future__ import annotations
4 | import typing
5 | import torch
6 |
7 | TensorFunc = typing.Callable[[torch.Tensor], torch.Tensor]
8 |
9 | __all__ = ('XYDataSet', )
10 |
11 |
12 | class XYDataSet(torch.utils.data.Dataset):
13 | def __init__(self, x: torch.Tensor, y: torch.Tensor):
14 | if len(x) != len(y):
15 |             raise ValueError('sizes of x and y do not match')
16 | self.x = x
17 | self.y = y
18 |
19 | def __len__(self) -> int:
20 | return len(self.x)
21 |
22 | def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
23 | return self.x[idx], self.y[idx]
24 |
--------------------------------------------------------------------------------
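
A minimal usage sketch for XYDataSet (the data below is illustrative):

    import torch
    from mfnn import XYDataSet

    # Toy regression data: 100 scalar samples.
    x = torch.linspace(0, 1, 100).reshape(-1, 1)
    y = torch.sin(x)

    # XYDataSet only pairs the tensors; batching is left to DataLoader.
    loader = torch.utils.data.DataLoader(XYDataSet(x, y), batch_size=25)
    for xb, yb in loader:
        print(xb.shape, yb.shape)  # torch.Size([25, 1]) torch.Size([25, 1])
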
/readme.org:
--------------------------------------------------------------------------------
1 | * Multi-fidelity Neural Network
2 |
3 | A multi-fidelity neural network (MFNN) models physical systems from
4 | multi-fidelity data. In real applications, low-fidelity data is
5 | usually abundant but less accurate, while high-fidelity data is
6 | scarce and expensive. An MFNN makes use of both low- and
7 | high-fidelity data to model the physical system, which significantly
8 | improves model accuracy with only a small set of high-fidelity data.
9 |
10 | Meng and Karniadakis [1] proposed an approach to MFNN using a
11 | composite neural network; however, no code is provided with their
12 | paper. Moreover, the linear layers in the high-fidelity DNN (NN_H1)
13 | are redundant, as linear features can always be modeled by the
14 | nonlinear DNN (NN_H2).
15 |
16 | Thus, this repository provides a modified version of MFNN, where the
17 | linear DNN (NN_H1) of paper [1] is replaced by a residual connection
18 | over the nonlinear DNN (NN_H2). The code is implemented in PyTorch,
19 | and examples are provided.
20 |
21 | * Snapshots
22 |
23 | ** Data
24 |
25 | [[./snapshots/data.svg]]
26 |
27 | ** Modeling with low-fidelity data
28 |
29 | [[./snapshots/low.svg]]
30 |
31 | ** Modeling with high-fidelity data
32 |
33 | [[./snapshots/high.svg]]
34 |
35 | ** Modeling with both low- and high-fidelity data
36 |
37 | [[./snapshots/mfnn.svg]]
38 |
39 | * References
40 | [1] Meng X, Karniadakis GE. A composite neural network that learns
41 | from multi-fidelity data: Application to function approximation and
42 | inverse PDE problems. Journal of Computational Physics
43 | 2020;401:109020. https://doi.org/10.1016/j.jcp.2019.109020.
44 |
45 |
--------------------------------------------------------------------------------
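
The two-stage workflow described in the readme is spelled out in
example1.py and example2.py below; a condensed sketch (the layer sizes
and iteration counts are illustrative):

    import torch
    from mfnn import FCNN, HFNN, XYDataSet, Trainer

    def func_low(x):                       # cheap, abundant data source
        return torch.sin(8 * torch.pi * x)

    def func_high(x):                      # expensive, scarce data source
        return (x - 2 ** .5) * func_low(x) ** 2

    x_low = torch.linspace(0, 1, 51).reshape(-1, 1)
    x_high = torch.linspace(0, 1, 14).reshape(-1, 1)
    loader_low = torch.utils.data.DataLoader(
        XYDataSet(x_low, func_low(x_low)), batch_size=len(x_low))
    loader_high = torch.utils.data.DataLoader(
        XYDataSet(x_high, func_high(x_high)), batch_size=len(x_high))

    # Stage 1: fit the low-fidelity network on the abundant data.
    model_low = FCNN(1, 1, [16, 16], torch.nn.Tanh)
    trainer_low = Trainer(model_low,
                          torch.optim.Adam(model_low.parameters(), lr=1e-2),
                          torch.nn.MSELoss(), suppress_display=True)
    for _ in range(1000):
        trainer_low.train(loader_low)

    # Stage 2: fit the high-fidelity correction; HFNN keeps model_low
    # frozen by hiding it from its registered parameters.
    model_high = HFNN(model_low, 1, 1, [16, 16], torch.nn.Tanh)
    trainer_high = Trainer(model_high,
                           torch.optim.Adam(model_high.parameters(), lr=1e-2),
                           torch.nn.MSELoss(), suppress_display=True)
    for _ in range(1000):
        trainer_high.train(loader_high)
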
/mfnn/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from __future__ import annotations
4 |
5 | import gc
6 | import sys
7 | import typing
8 | from io import BytesIO
9 |
10 |
11 | import torch
12 |
13 | # __all__ = ('DEVICE', 'free_memory', 'copy_to', 'Statistics', )
14 |
15 |
16 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 |
19 | def free_memory():
20 | gc.collect()
21 | torch.cuda.empty_cache()
22 |
23 |
24 | def copy_to(data, device: torch.device | int | str | None = None):
25 | """Copy data to device."""
26 | # Avoid useless copy in gpu.
27 | # See https://discuss.pytorch.org/t/how-to-make-a-copy-of-a-gpu-model-on-the-cpu/90955/4
28 | if device is None:
29 | return data
30 | memory = BytesIO()
31 | torch.save(data, memory, pickle_protocol=-1)
32 | memory.seek(0)
33 | data = torch.load(memory, map_location=device)
34 | memory.close()
35 | return data
36 |
37 |
38 | T = typing.TypeVar('T', bound=float)
39 |
40 |
41 | class Statistics(typing.Generic[T]):
42 | """Get statistical value(sum, average, variance) of data."""
43 | __slots__ = '_count', '_value', '_s1', '_s2'
44 |
45 | def __init__(self):
46 | self._count = 0
47 | self._value: T = 0.
48 | self._s1: float = 0. # sum of sample values
49 | self._s2: float = 0. # sum of squared sample values
50 |
51 | def update(self, value: T, n: int = 1) -> Statistics[T]:
52 | self._count += n
53 | self._value = value
54 |         self._s1 += value * n
55 | self._s2 += value ** 2 * n
56 | return self
57 |
58 | @property
59 | def count(self) -> int:
60 | return self._count
61 |
62 | @property
63 | def value(self) -> T:
64 | return self._value
65 |
66 | @property
67 | def sum(self) -> float:
68 | return self._s1
69 |
70 | @property
71 | def average(self) -> float:
72 | return self._s1 / self._count
73 |
74 | @property
75 | def variance(self) -> float:
76 | return self._s2 / self._count - self.average ** 2
77 |
78 | @property
79 | def std(self) -> float:
80 | return self.variance ** .5
81 |
82 |
83 | class GatedStdout:
84 | def __init__(self, suppress: bool):
85 | self.suppress = suppress
86 |
87 |     def write(self, s: str) -> int:
88 | if self.suppress:
89 | return 0
90 | return sys.stdout.write(s)
91 |
--------------------------------------------------------------------------------
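
A sketch of how Statistics accumulates weighted running statistics
(the values are illustrative). Trainer uses the same pattern to average
the per-batch loss over an epoch via loss_meter.update(loss.item(), n):

    from mfnn.utils import Statistics

    stats = Statistics()
    stats.update(2.0, n=3)   # three samples with value 2.0
    stats.update(4.0)        # one sample with value 4.0

    print(stats.count)       # 4
    print(stats.average)     # (2.0 * 3 + 4.0) / 4 = 2.5
    print(stats.variance)    # (4.0 * 3 + 16.0) / 4 - 2.5 ** 2 = 0.75
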
/example2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | import torch
5 |
6 | from mfnn import MFNN, XYDataSet, Trainer
7 |
8 | import matplotlib as mpl
9 | import matplotlib.pyplot as plt
10 | mpl_fontpath = mpl.get_data_path() + '/fonts/ttf/STIXGeneral.ttf'
11 | mpl_fontprop = mpl.font_manager.FontProperties(fname=mpl_fontpath)
12 | plt.rc('font', family='STIXGeneral', weight='normal', size=10)
13 | plt.rc('mathtext', fontset='stix')
14 |
15 |
16 | def func_low(x: torch.Tensor) -> torch.Tensor:
17 | return torch.sin(8 * torch.pi * x)
18 |
19 |
20 | def func_high(x: torch.Tensor) -> torch.Tensor:
21 | return (x - 2**.5) * func_low(x) ** 2
22 |
23 |
24 | def figure1(x_low, y_low, x_high, y_high, x, y):
25 | "Plot the multi-fidelity data along with true data."
26 |
27 |
28 | fig = plt.figure(figsize=(3, 2.25))
29 | ax = fig.add_subplot(111)
30 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
31 | ax.plot(x_high, y_high, 'rx', label='high')
32 | ax.plot(x, y, 'k:', label='true')
33 | ax.legend()
34 | ax.set_xlabel('$x$')
35 | ax.set_ylabel('$y$')
36 | ax.set_xlim(x[0], x[-1])
37 | ax.grid()
38 | fig.tight_layout(pad=0)
39 | return fig
40 |
41 |
42 | def figure2(x_low, y_low, x_pred, y_pred, x, y):
43 |     "Plot the regression result from low-fidelity data."
44 | fig = plt.figure(figsize=(3, 2.25))
45 | ax = fig.add_subplot(111)
46 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
47 |     ax.plot(x_pred, y_pred, 'r', label=r'$y_{\mathrm{low}}$')
48 | ax.plot(x, y, 'k:', label='true')
49 | ax.legend()
50 | ax.set_xlabel('$x$')
51 | ax.set_ylabel('$y$')
52 | ax.set_xlim(x[0], x[-1])
53 | ax.grid()
54 | fig.tight_layout(pad=0)
55 | return fig
56 |
57 |
58 | def figure3(x_low, y_low, x_high, y_high, x_pred, y_pred, x, y):
59 |     "Plot the MFNN regression result using both low- and high-fidelity data."
60 | fig = plt.figure(figsize=(3, 2.25))
61 | ax = fig.add_subplot(111)
62 |     ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
63 | ax.plot(x_high, y_high, 'x', color='r', label='high')
64 | ax.plot(x_pred, y_pred, 'g', label=r'$y_{\mathrm{pred}}$')
65 | ax.plot(x, y, 'k:', label='true')
66 | ax.legend()
67 | ax.set_xlabel('$x$')
68 | ax.set_ylabel('$y$')
69 | ax.set_xlim(x[0], x[-1])
70 | ax.grid()
71 | fig.tight_layout(pad=0)
72 | return fig
73 |
74 |
75 | if __name__ == '__main__':
76 | # Generate data.
77 | x = torch.linspace(0, 1, 501).reshape(-1, 1)
78 | y = func_high(x)
79 | x_low = torch.linspace(0, 1, 51).reshape(-1, 1)
80 | x_high = torch.linspace(0, 1, 14).reshape(-1, 1)
81 | y_low = func_low(x_low)
82 | y_high = func_high(x_high)
83 | loader_low = torch.utils.data.DataLoader(XYDataSet(x_low, y_low),
84 | batch_size=len(x_low))
85 | loader_high = torch.utils.data.DataLoader(XYDataSet(x_high, y_high),
86 |                                                batch_size=len(x_high))
87 | figure1(x_low, y_low, x_high, y_high, x, y)
88 |
89 | model = MFNN(1, 1, [16], [16, 16], [16, 16, 16, 16], torch.nn.Tanh)
90 | model_low = model.low
91 | model_high = model.high
92 | optimizer_low = torch.optim.Adam(
93 | model_low.parameters(), lr=1e-2, weight_decay=1e-4)
94 | optimizer_high = torch.optim.Adam(
95 | model_high.parameters(), lr=1e-2, weight_decay=1e-4)
96 | scheduler_low = torch.optim.lr_scheduler.MultiStepLR(
97 | optimizer_low, milestones=[2000, 8000])
98 | scheduler_high = torch.optim.lr_scheduler.MultiStepLR(
99 | optimizer_high, milestones=[2000, 8000])
100 | loss = torch.nn.MSELoss()
101 |
102 | # low-fidelity data
103 | trainer1 = Trainer(model_low, optimizer_low, loss,
104 | scheduler=scheduler_low, suppress_display=True)
105 | for i in range(10000):
106 | res = trainer1.train(loader_low)
107 | if i % 100 == 0:
108 | print(f'niter: {i}', f'avg loss: {res.average:.4e}',
109 | f'loss std: {res.std:.4e}', sep=', ')
110 | y1 = model_low.eval()(x).detach()
111 | figure2(x_low, y_low, x, y1, x, y)
112 |
113 | # high-fidelity data
114 |
115 | trainer2 = Trainer(model_high, optimizer_high,
116 | loss, scheduler=scheduler_high, suppress_display=True)
117 | for i in range(10000):
118 | res = trainer2.train(loader_high)
119 | if i % 100 == 0:
120 | print(f'niter: {i}', f'avg loss: {res.average:.4e}',
121 | f'loss std: {res.std:.4e}', sep=', ')
122 | y3 = model_high.eval()(x).detach()
123 | figure3(x_low, y_low, x_high, y_high, x, y3, x, y)
124 |
125 | plt.show()
126 |
--------------------------------------------------------------------------------
/example1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | import torch
5 |
6 | from mfnn import FCNN, HFNN, XYDataSet, Trainer
7 |
8 |
9 | import matplotlib as mpl
10 | import matplotlib.pyplot as plt
11 | mpl_fontpath = mpl.get_data_path() + '/fonts/ttf/STIXGeneral.ttf'
12 | mpl_fontprop = mpl.font_manager.FontProperties(fname=mpl_fontpath)
13 | plt.rc('font', family='STIXGeneral', weight='normal', size=10)
14 | plt.rc('mathtext', fontset='stix')
15 |
16 |
17 | def func_low(x: torch.Tensor) -> torch.Tensor:
18 | return torch.sin(8 * torch.pi * x)
19 |
20 |
21 | def func_high(x: torch.Tensor) -> torch.Tensor:
22 | return (x - 2**.5) * func_low(x) ** 2
23 |
24 |
25 | def figure1(x_low, y_low, x_high, y_high, x, y):
26 | "Plot the multi-fidelity data along with true data."
27 |
28 |
29 | fig = plt.figure(figsize=(3, 2.25))
30 | ax = fig.add_subplot(111)
31 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
32 | ax.plot(x_high, y_high, 'rx', label='high')
33 | ax.plot(x, y, 'k:', label='true')
34 | ax.legend()
35 | ax.set_xlabel('$x$')
36 | ax.set_ylabel('$y$')
37 | ax.set_xlim(x[0], x[-1])
38 | ax.grid()
39 | fig.tight_layout(pad=0)
40 | return fig
41 |
42 |
43 | def figure2(x_low, y_low, x_pred, y_pred, x, y):
44 |     "Plot the regression result from low-fidelity data."
45 | fig = plt.figure(figsize=(3, 2.25))
46 | ax = fig.add_subplot(111)
47 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
48 |     ax.plot(x_pred, y_pred, 'r', label=r'$y_{\mathrm{low}}$')
49 | ax.plot(x, y, 'k:', label='true')
50 | ax.legend()
51 | ax.set_xlabel('$x$')
52 | ax.set_ylabel('$y$')
53 | ax.set_xlim(x[0], x[-1])
54 | ax.grid()
55 | fig.tight_layout(pad=0)
56 | return fig
57 |
58 |
59 | def figure3(x_high, y_high, x_pred, y_pred, x, y):
60 |     "Plot the regression result from high-fidelity data."
61 | fig = plt.figure(figsize=(3, 2.25))
62 | ax = fig.add_subplot(111)
63 | ax.plot(x_high, y_high, 'o', color='None',
64 | markeredgecolor='b', label='high')
65 |     ax.plot(x_pred, y_pred, 'r', label=r'$y_{\mathrm{high}}$')
66 | ax.plot(x, y, 'k:', label='true')
67 | ax.legend()
68 | ax.set_xlabel('$x$')
69 | ax.set_ylabel('$y$')
70 | ax.set_xlim(x[0], x[-1])
71 | ax.grid()
72 | fig.tight_layout(pad=0)
73 | return fig
74 |
75 |
76 | def figure4(x_low, y_low, x_high, y_high, x_pred, y_pred, x, y):
77 |     "Plot the MFNN regression result using both low- and high-fidelity data."
78 | fig = plt.figure(figsize=(3, 2.25))
79 | ax = fig.add_subplot(111)
80 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low')
81 | ax.plot(x_high, y_high, 'x', color='r', label='high')
82 | ax.plot(x_pred, y_pred, 'g', label=r'$y_{\mathrm{pred}}$')
83 | ax.plot(x, y, 'k:', label='true')
84 | ax.legend()
85 | ax.set_xlabel('$x$')
86 | ax.set_ylabel('$y$')
87 | ax.set_xlim(x[0], x[-1])
88 | ax.grid()
89 | fig.tight_layout(pad=0)
90 | return fig
91 |
92 |
93 | if __name__ == '__main__':
94 | # Generate data.
95 | x = torch.linspace(0, 1, 501).reshape(-1, 1)
96 | y = func_high(x)
97 | x_low = torch.linspace(0, 1, 51).reshape(-1, 1)
98 | x_high = torch.linspace(0, 1, 15)[:14].reshape(-1, 1)
99 | y_low = func_low(x_low)
100 | y_high = func_high(x_high)
101 | loader_low = torch.utils.data.DataLoader(XYDataSet(x_low, y_low),
102 | batch_size=len(x_low))
103 | loader_high = torch.utils.data.DataLoader(XYDataSet(x_high, y_high),
104 |                                                batch_size=len(x_high))
105 | figure1(x_low, y_low, x_high, y_high, x, y)
106 |
107 | # low-fidelity data
108 | model1 = FCNN(1, 1, [16, 16], torch.nn.Tanh)
109 | optimizer = torch.optim.Adam(model1.parameters(), lr=1e-2, weight_decay=1e-4)
110 | loss = torch.nn.MSELoss()
111 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
112 | optimizer, milestones=[2000, 8000])
113 | trainer1 = Trainer(model1, optimizer, loss,
114 | scheduler=scheduler,
115 | suppress_display=True)
116 | for i in range(10000):
117 | res = trainer1.train(loader_low)
118 | if i % 100 == 0:
119 | print(f'niter: {i}', f'avg loss: {res.average:.4e}',
120 | f'loss std: {res.std:.4e}', sep=', ')
121 | y1 = model1.eval()(x).detach()
122 | figure2(x_low, y_low, x, y1, x, y)
123 |
124 | # high-fidelity data
125 | model2 = FCNN(1, 1, [16, 16], torch.nn.Tanh)
126 | optimizer = torch.optim.Adam(
127 | model2.parameters(), lr=1e-2, weight_decay=1e-4)
128 | loss = torch.nn.MSELoss()
129 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
130 | optimizer, milestones=[2000, 8000])
131 | trainer2 = Trainer(model2, optimizer, loss,
132 | scheduler=scheduler, suppress_display=True)
133 | for i in range(10000):
134 | res = trainer2.train(loader_high)
135 | if i % 100 == 0:
136 | print(f'niter: {i}', f'avg loss: {res.average:.4e}',
137 | f'loss std: {res.std:.4e}', sep=', ')
138 | y2 = model2.eval()(x).detach()
139 | figure3(x_high, y_high, x, y2, x, y)
140 |
141 | # MFNN
142 | model3 = HFNN(model1, 1, 1, [16, 16], torch.nn.Tanh)
143 | loss = torch.nn.MSELoss()
144 | optimizer = torch.optim.Adam(
145 | model3.parameters(), lr=1e-2, weight_decay=1e-4)
146 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
147 | optimizer, milestones=[2000, 8000])
148 | trainer3 = Trainer(model3, optimizer, loss,
149 | scheduler=scheduler, suppress_display=True)
150 | for i in range(10000):
151 | res = trainer3.train(loader_high)
152 | if i % 100 == 0:
153 | print(f'niter: {i}', f'avg loss: {res.average:.4e}',
154 | f'loss std: {res.std:.4e}', sep=', ')
155 | y3 = model3.eval()(x).detach()
156 | figure4(x_low, y_low, x_high, y_high, x, y3, x, y)
157 |
158 | plt.show()
159 |
--------------------------------------------------------------------------------
/mfnn/mfnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from __future__ import annotations
4 | import typing
5 | import torch
6 |
7 | TensorFunc = typing.Callable[[torch.Tensor], torch.Tensor]
8 |
9 | __all__ = ('FCNN', 'HFNN', 'MFNN')
10 |
11 |
12 | class BasicBlock(torch.nn.Module):
13 | """Basic block for a fully-connected layer."""
14 |
15 | def __init__(self, in_features: int,
16 | out_features: int,
17 | activation: type[torch.nn.Module]):
18 | super().__init__()
19 | self.linear_layer = torch.nn.Linear(in_features, out_features)
20 | self.activation = activation()
21 |
22 | def forward(self, x):
23 | x = self.linear_layer(x)
24 | x = self.activation(x)
25 | return x
26 |
27 |
28 | class FCNN(torch.nn.Module):
29 | def __init__(self,
30 | in_features: int,
31 | out_features: int,
32 | midlayer_features: list[int],
33 | activation: type[torch.nn.Module] = torch.nn.Tanh
34 | ):
35 | super().__init__()
36 |
37 | layer_sizes = [in_features] + midlayer_features
38 |
39 | self.layers = torch.nn.Sequential(*[
40 | BasicBlock(layer_sizes[i], layer_sizes[i+1], activation)
41 | for i in range(len(midlayer_features))
42 | ])
43 | self.fc = torch.nn.Linear(layer_sizes[-1], out_features)
44 |
45 | def forward(self, x: torch.Tensor) -> torch.Tensor:
46 | x = self.layers(x)
47 | x = self.fc(x)
48 | return x
49 |
50 |
51 | class HFNN(torch.nn.Module):
52 | def __init__(self,
53 | lfnn: torch.nn.Module | TensorFunc,
54 | in_features: int,
55 | out_features: int,
56 | midlayer_features: list[int],
57 | activation: type[torch.nn.Module] | None = None):
58 | activation = torch.nn.Tanh if activation is None else activation
59 | super().__init__()
60 | # Avoid updating the trained low-fidelity model during training.
61 | object.__setattr__(self, 'lfnn', lfnn)
62 | self.lfnn: torch.nn.Module | TensorFunc
63 |
64 | self.layers = FCNN(in_features + out_features,
65 | out_features,
66 | midlayer_features,
67 | activation)
68 | self.shortcut = FCNN(in_features + out_features,
69 | out_features,
70 | [],
71 | torch.nn.Identity)
72 |
73 |         # Use He (Kaiming) normal initialization on the new layers.
74 |         # lfnn is deliberately excluded: it is not a registered submodule.
75 |         for name, parameter in self.named_parameters():
76 |             if name.endswith('weight'):
77 |                 torch.nn.init.kaiming_normal_(parameter)
78 |             elif name.endswith('bias'):
79 |                 torch.nn.init.zeros_(parameter)
80 |             else:
81 |                 raise AssertionError(f'unexpected parameter {name!r}')
82 |
83 | def forward(self, x: torch.Tensor) -> torch.Tensor:
84 | y_low = self.lfnn(x)
85 | x_comb = torch.concat([x, y_low], dim=1)
86 | y_layer = self.layers(x_comb)
87 | y_shortcut = self.shortcut(x_comb)
88 | y = y_layer + y_shortcut
89 | return y
90 |
91 |
92 | class MFNN:
93 |     """A flexible network for regression problems with multi-fidelity data.
94 |
95 |     This class contains two networks, for fitting low-fidelity data and
96 |     high-fidelity data respectively.
97 |     """
98 |
99 | def __init__(self,
100 | in_features: int,
101 | out_features: int,
102 | backbone_layers: list[int],
103 | layers_low: list[int],
104 | layers_high: list[int],
105 | activation: type[torch.nn.Module] | None = None):
106 | activation = torch.nn.Tanh if activation is None else activation
107 |
108 | layer_sizes = [in_features] + backbone_layers
109 | self.backbone = torch.nn.Sequential(*[
110 | BasicBlock(layer_sizes[i], layer_sizes[i+1], activation)
111 | for i in range(len(backbone_layers))
112 | ])
113 |
114 | self.fc_low = FCNN(layer_sizes[-1],
115 | out_features,
116 | layers_low,
117 | activation)
118 | self.fc_high = FCNN(out_features + layer_sizes[-1],
119 | out_features,
120 | layers_high,
121 | activation)
122 | self.shortcut = FCNN(out_features + layer_sizes[-1],
123 | out_features,
124 | [],
125 | torch.nn.Identity)
126 |
127 |         # Use He (Kaiming) normal initialization on all sub-networks.
128 |         for m in [self.backbone, self.fc_low, self.fc_high, self.shortcut]:
129 |             for name, parameter in m.named_parameters():
130 |                 if name.endswith('weight'):
131 |                     torch.nn.init.kaiming_normal_(parameter)
132 |                 elif name.endswith('bias'):
133 |                     torch.nn.init.zeros_(parameter)
134 |                 else:
135 |                     raise AssertionError(f'unexpected parameter {name!r}')
136 |
137 | self.low = MFNNLow(self.backbone, self.fc_low)
138 | self.high = MFNNHigh(self.low, self.fc_high, self.shortcut)
139 |
140 |
141 | class MFNNLow(torch.nn.Module):
142 | def __init__(self,
143 | backbone: torch.nn.Module,
144 | fc_low: torch.nn.Module):
145 | super().__init__()
146 | self.backbone = backbone
147 | self.fc_low = fc_low
148 |
149 | def forward(self, x: torch.Tensor) -> torch.Tensor:
150 | features = self.backbone(x)
151 | y = self.fc_low(features)
152 | return y
153 |
154 |
155 | class MFNNHigh(torch.nn.Module):
156 | def __init__(self,
157 | mfnnlow: torch.nn.Module,
158 | fc_high: torch.nn.Module,
159 | shortcut: torch.nn.Module):
160 | super().__init__()
161 | # Avoid updating the trained low-fidelity model during training.
162 | object.__setattr__(self, 'mfnnlow', mfnnlow)
163 | self.mfnnlow: torch.nn.Module
164 | self.fc_high = fc_high
165 | self.shortcut = shortcut
166 |
167 | def forward(self, x: torch.Tensor) -> torch.Tensor:
168 | features = self.mfnnlow.backbone(x)
169 | y_low = self.mfnnlow.fc_low(features)
170 |         x_combine = torch.concat([y_low, features], dim=1)
171 | y_fc = self.fc_high(x_combine)
172 | y_shortcut = self.shortcut(x_combine)
173 | y = y_fc + y_shortcut
174 | return y
175 |
176 |
177 | if __name__ == '__main__':
178 |
179 | x = torch.linspace(0, 1, 101).reshape(-1, 1)
180 |
181 | lfnn = FCNN(1, 1, [16, 16, 16], activation=torch.nn.Tanh)
182 | hfnn = HFNN(lfnn, 1, 1, [32, 32])
183 | mfnn = MFNN(1, 1, [16], [16, 16], [16], activation=torch.nn.Tanh)
184 |
185 | ylow = lfnn(x)
186 | yhigh = hfnn(x)
187 | ylow = mfnn.low(x)
188 | yhigh = mfnn.high(x)
189 |
--------------------------------------------------------------------------------
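
A quick shape and parameter check for the residual high-fidelity
network (the layer sizes are illustrative):

    import torch
    from mfnn import FCNN, HFNN

    lfnn = FCNN(1, 1, [16, 16])        # trained low-fidelity surrogate
    hfnn = HFNN(lfnn, 1, 1, [32, 32])  # nonlinear DNN + linear shortcut

    x = torch.rand(8, 1)
    assert hfnn(x).shape == (8, 1)

    # lfnn is hidden from hfnn's registered modules (object.__setattr__),
    # so an optimizer over hfnn.parameters() leaves it untouched.
    names = [name for name, _ in hfnn.named_parameters()]
    assert all(name.startswith(('layers', 'shortcut')) for name in names)
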
/mfnn/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from __future__ import annotations
4 | import sys
5 | import time
6 | import pickle
7 | import typing
8 |
9 | import tqdm
10 | import torch
11 |
12 |
13 | from .utils import DEVICE, Statistics, free_memory, copy_to, GatedStdout
14 |
15 | __all__ = ('Trainer',)
16 |
17 |
18 | _VALID_SECOND_TIME_WARNING = True
19 |
20 |
21 | class Trainer():
22 | """Class for training a model."""
23 |
24 | def __init__(self,
25 | model: torch.nn.Module,
26 | optimizer: torch.optim.Optimizer,
27 |                  criterion: torch.nn.Module,
28 | *,
29 | device: torch.device | int | str | None = None,
30 | start_epoch: int = 0,
31 | filename: str | None = None,
32 | scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
33 | forced_gc: bool = False,
34 | suppress_display: bool = False
35 | ):
36 | self.model = model
37 | self.optimizer = optimizer
38 |         self.criterion = criterion  # loss function
39 | # Use property setter to move model and loss function to target device
40 | self.device = DEVICE if device is None else device
41 |
42 | self.epoch = start_epoch
43 | self.filename = 'trainer.trainer' if filename is None else filename
44 | self.scheduler = scheduler
45 | self.is_forced_gc = forced_gc
46 | self.stdout = GatedStdout(suppress_display)
47 |
48 | self.history: dict[str, list[float]] = {'train_loss': [],
49 | 'validate_loss': [],
50 | }
51 |
52 | @property
53 | def device(self) -> torch.device:
54 | return self._device
55 |
56 | @device.setter
57 |     def device(self, device: torch.device | int | str):
58 |         # Moving the trainer also moves the model and the loss function.
59 |         self._device = torch.device(device)
60 |         self.model.to(self._device)
61 |         self.criterion.to(self._device)
62 |
63 | @property
64 | def lr(self) -> list[float]:
65 | return [pg['lr'] for pg in self.optimizer.param_groups]
66 |
67 | @lr.setter
68 | def lr(self, lr: float):
69 | for param_group in self.optimizer.param_groups:
70 | param_group['lr'] = lr
71 |
72 | def save(self, device: torch.device | int | str = "cpu"):
73 | """Save trainer object.
74 |
75 | The trainer is saved to the given `device`, along with the
76 | model. The `device` specified here does not change the device
77 | type of the trainer instance, but only saves all the variables
78 | to this `device`. The default target device is "cpu".
79 |
80 |         The `device` specified in this method is independent of the
81 |         `load` method's `device` argument. The argument exists so that
82 |         a CUDA model can be saved and then loaded on another computer
83 |         without a GPU. It is therefore always suggested to keep the
84 |         default "cpu" here, so that the saved file can be loaded on
85 |         any computer.
86 | """
87 | data = copy_to(self.__dict__, torch.device(device))
88 | with open(self.filename, 'wb') as f:
89 | f.write(pickle.dumps((data, self.device)))
90 |
91 | def save_as(self, filename: str):
92 | self.filename = filename
93 | return self.save()
94 |
95 | @staticmethod
96 | def load(filename: str,
97 | device: torch.device | int | str | None = None
98 | ) -> Trainer:
99 | """Load a trainer object.
100 |
101 |         Load the trainer onto the given `device`. Specifying the
102 |         `device` argument here also changes the loaded trainer's
103 |         `device` property. If `device` is not given, it defaults to
104 |         the device recorded when the trainer was saved.
105 | """
106 | with open(filename, 'rb') as f:
107 | data, default_device = pickle.loads(f.read())
108 | if device is None:
109 | data = copy_to(data, default_device)
110 | else:
111 | device = torch.device(device)
112 | data = copy_to(data, device)
113 | data['_device'] = device
114 | res = object.__new__(Trainer)
115 | res.__dict__.update(data)
116 | return res
117 |
118 | def train(self,
119 | loader: torch.utils.data.DataLoader
120 | ) -> Statistics[float]:
121 | "Train the model by given dataloader."
122 | t_start = time.time()
123 | self.model.train()
124 | tq = tqdm.tqdm(loader,
125 | desc="train",
126 | ncols=None,
127 | leave=False,
128 | file=self.stdout,
129 | unit="batch")
130 | loss_meter: Statistics[float] = Statistics()
131 |         # User-defined preprocess
132 | additional_data = self.additional_train_preprocess(tq)
133 | for x, y in tq:
134 | current_batch_size = x.shape[0]
135 | x = x.to(self.device)
136 | y = y.to(self.device)
137 |
138 | # compute prediction error
139 | y_pred = self.model(x)
140 |             loss = self.criterion(y_pred, y)
141 | # backpropagation
142 | self.optimizer.zero_grad()
143 | loss.backward()
144 | self.optimizer.step()
145 |
146 | # record results
147 | loss_meter.update(loss.item(), current_batch_size)
148 | tq.set_postfix(loss=f"{loss_meter.value:.4e}")
149 |
150 | # Do some user-defined process.
151 | self.additional_train_process(additional_data,
152 | y_pred, y, loss, tq)
153 |
154 | # Free some space before next round.
155 | del x, y, y_pred, loss
156 | if self.is_forced_gc:
157 | free_memory()
158 |
159 | # Save information for this epoch.
160 | self.epoch += 1
161 | self.history['train_loss'].append(loss_meter.average)
162 |
163 |         # User-defined postprocess.
164 | self.additional_train_postprocess(additional_data)
165 |
166 | print(f'train result [{self.epoch}]: '
167 | f'avg loss = {loss_meter.average:.4e}, '
168 |               f'wall time = {time.time() - t_start:.2f}s',
169 | file=self.stdout)
170 | if self.scheduler:
171 | self.scheduler.step()
172 |
173 | return loss_meter
174 |
175 | def validate(self,
176 | loader: torch.utils.data.DataLoader,
177 | ) -> Statistics[float]:
178 | """Validate the model."""
179 | t_start = time.time()
180 | self.model.eval()
181 | tq = tqdm.tqdm(loader,
182 | desc="valid",
183 | ncols=None,
184 | leave=False,
185 | file=self.stdout,
186 | unit="batch")
187 | loss_meter: Statistics[float] = Statistics()
188 | # User-defined preprocess
189 | additional_data = self.additional_validate_preprocess(tq)
190 |
191 | for x, y in tq:
192 | current_batch_size = x.size(0)
193 | x = x.to(self.device)
194 | y = y.to(self.device)
195 |
196 | with torch.no_grad():
197 | y_pred = self.model(x)
198 |                 loss = self.criterion(y_pred, y)
199 | loss_meter.update(loss.item(), current_batch_size)
200 | tq.set_postfix(loss=f"{loss_meter.value:.4e}")
201 |
202 | # Do some user-defined process.
203 | self.additional_validate_process(additional_data,
204 | y_pred, y, loss, tq)
205 |
206 | del x, y, y_pred, loss
207 | if self.is_forced_gc:
208 | free_memory()
209 |
210 |         # Save validation results only on the first run.
211 | if len(self.history['validate_loss']) < self.epoch:
212 | self.history['validate_loss'].append(loss_meter.average)
213 | else:
214 | global _VALID_SECOND_TIME_WARNING
215 | if _VALID_SECOND_TIME_WARNING:
216 | sys.stderr.write("The model is validated for the "
217 | "second time in the same epoch, "
218 | "validation result will not be "
219 | "recorded. "
220 | "This warning will be "
221 | "turned off in this session.\n")
222 | _VALID_SECOND_TIME_WARNING = False
223 |
224 |         # User-defined postprocess.
225 | self.additional_validate_postprocess(additional_data)
226 |
227 | print(f'valid result [{self.epoch}]: '
228 | f'avg loss = {loss_meter.average:.4e}, '
229 |               f'wall time = {time.time() - t_start:.2f}s',
230 | file=self.stdout)
231 | return loss_meter
232 |
233 | def step(self,
234 | train_dataloader: torch.utils.data.DataLoader,
235 | valid_dataloader: torch.utils.data.DataLoader,
236 | save_trainer: bool = True,
237 | save_best_model: str | None = None,
238 | ) -> bool:
239 |         """Train and validate the model; return whether it is the best model.
240 |
241 | Note that if `save_best_model` is enabled, it selects the best
242 | model by its validation loss, which might not be good for
243 | classification problems.
244 |
245 | Parameters
246 | ----------
247 | save_trainer: bool, optional
248 |             Whether to save the trainer after this epoch. Defaults to
249 |             True.
250 | save_best_model: str, optional
251 | Filename for saving the model if it is the best model
252 | indicated by the validation procedure. If not given, the
253 | model will not be saved automatically.
254 |
255 | """
256 | print(f' ---- Epoch {self.epoch} ---- ', file=self.stdout)
257 | self.train(train_dataloader)
258 | loss = self.validate(valid_dataloader)
259 | if save_trainer:
260 | self.save()
261 |         if loss.average == min(self.history['validate_loss']):
262 | if save_best_model:
263 | print('This model will be saved as the best model.',
264 | file=self.stdout)
265 | with open(save_best_model, 'wb') as f:
266 | torch.save(self.model, f)
267 | return True
268 | return False
269 |
270 | # Class methods can be overwritten.
271 |
272 | def additional_train_preprocess(self, tq: tqdm.std.tqdm) -> typing.Any:
273 | """Additional pre-process in each epoch.
274 |
275 | This method can be overwritten to do some additional work
276 |         before iterations in each epoch (e.g., prepare a dict for
277 | `additional_train_process`). The return value of the method
278 | will be used for `additional_train_process` and
279 | `additional_train_postprocess`
280 |
281 | Parameters
282 | ----------
283 | tq: tqdm object
284 | The tqdm object. Can modify display here.
285 |
286 |         Returns
287 |         -------
288 |         additional_data: typing.Any
289 |             The returned data will be processed in each batch and at
290 |             the end of the epoch.
291 | """
292 | return None
293 |
294 | def additional_train_process(self,
295 | additional_data: typing.Any,
296 | y_pred: torch.Tensor,
297 | y_true: torch.Tensor,
298 | loss: torch.Tensor,
299 | tq: tqdm.std.tqdm):
300 | """Additional process after training the model in each batch.
301 |
302 | This method can be overwritten to do some additional work
303 |         after the loss is calculated in each batch (e.g., calculate
304 | top-k error in classification task). This function has no
305 | return value.
306 |
307 | Parameters
308 | ----------
309 | additional_data: typing.Any
310 | Defined in `additional_train_preprocess`.
311 | y_pred: torch.Tensor
312 | Predicted value by model.
313 | y_true: torch.Tensor
314 | True value given by train dataset.
315 | loss: torch.Tensor
316 |             Loss of this batch given by the criterion.
317 | tq: tqdm object
318 | Can modify display here.
319 | """
320 | pass
321 |
322 | def additional_train_postprocess(self,
323 | additional_data: typing.Any):
324 | """Additional postprocess in each epoch.
325 |
326 | This method can be overwritten to do some additional work
327 |         after the epoch is finished (e.g., saving `additional_data`).
328 |
329 | Parameters
330 | ----------
331 | additional_data: typing.Any
332 | Defined in `additional_train_preprocess`.
333 | """
334 | pass
335 |
336 | def additional_validate_preprocess(self, tq: tqdm.std.tqdm) -> typing.Any:
337 | """Additional pre-process in each epoch.
338 |
339 | This method can be overwritten to do some additional work
340 |         before iterations in each epoch (e.g., prepare a dict for
341 | `additional_validate_process`). The return value of the method
342 | will be used for `additional_validate_process` and
343 | `additional_validate_postprocess`
344 |
345 | Parameters
346 | ----------
347 | tq: tqdm object
348 | The tqdm object. Can modify display here.
349 |
350 |         Returns
351 |         -------
352 |         additional_data: typing.Any
353 |             The returned data will be processed in each batch and at
354 |             the end of the epoch.
355 | """
356 | return None
357 |
358 | def additional_validate_process(self,
359 | additional_data: typing.Any,
360 | y_pred: torch.Tensor,
361 | y_true: torch.Tensor,
362 | loss: torch.Tensor,
363 | tq: tqdm.std.tqdm):
364 | """Additional process after model validation in each batch.
365 |
366 | This method can be overwritten to do some additional work
367 |         after the loss is calculated in each batch (e.g., calculate
368 | top-k error in classification task). This function has no
369 | return value.
370 |
371 | Parameters
372 | ----------
373 | additional_data: typing.Any
374 | Defined in `additional_validate_preprocess`.
375 | y_pred: torch.Tensor
376 | Predicted value by model.
377 | y_true: torch.Tensor
378 | True value given by validation dataset.
379 | loss: torch.Tensor
380 |             Loss of this batch given by the criterion.
381 | tq: tqdm object
382 | Can modify display here.
383 | """
384 | pass
385 |
386 | def additional_validate_postprocess(self,
387 | additional_data: typing.Any):
388 | """Additional postprocess in each epoch.
389 |
390 | This method can be overwritten to do some additional work
391 |         after the epoch is finished (e.g., saving `additional_data`).
392 |
393 | Parameters
394 | ----------
395 | additional_data: typing.Any
396 | Defined in `additional_validate_preprocess`.
397 | """
398 | pass
399 |
--------------------------------------------------------------------------------
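
The additional_* methods are the intended extension points of Trainer;
a sketch of a subclass that also tracks mean absolute error during
training (the metric is illustrative):

    import torch
    from mfnn import Trainer

    class MAETrainer(Trainer):
        """Trainer that also reports mean absolute error per epoch."""

        def additional_train_preprocess(self, tq):
            # Accumulate |y_pred - y_true| over the whole epoch.
            return {'abs_err': 0.0, 'count': 0}

        def additional_train_process(self, additional_data, y_pred, y_true,
                                     loss, tq):
            additional_data['abs_err'] += (y_pred - y_true).abs().sum().item()
            additional_data['count'] += y_true.numel()

        def additional_train_postprocess(self, additional_data):
            mae = additional_data['abs_err'] / max(additional_data['count'], 1)
            print(f'train MAE [{self.epoch}]: {mae:.4e}', file=self.stdout)
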
/snapshots/high.svg:
--------------------------------------------------------------------------------
(SVG markup omitted; the figure shows the regression result using high-fidelity data only)
--------------------------------------------------------------------------------
/snapshots/data.svg:
--------------------------------------------------------------------------------
(SVG markup omitted; the figure shows the multi-fidelity training data and the true function)
--------------------------------------------------------------------------------
/snapshots/low.svg:
--------------------------------------------------------------------------------
(SVG markup omitted; the figure shows the regression result using low-fidelity data only)
--------------------------------------------------------------------------------
/snapshots/mfnn.svg:
--------------------------------------------------------------------------------
(SVG markup omitted; the figure shows the MFNN regression result using both low- and high-fidelity data)
--------------------------------------------------------------------------------