├── mfnn ├── __init__.py ├── xydata.py ├── utils.py ├── mfnn.py └── trainer.py ├── readme.org ├── example2.py ├── example1.py └── snapshots ├── high.svg ├── data.svg ├── low.svg └── mfnn.svg /mfnn/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .mfnn import * 4 | from .xydata import * 5 | from .trainer import * 6 | 7 | from .utils import copy_to, DEVICE 8 | -------------------------------------------------------------------------------- /mfnn/xydata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | import typing 5 | import torch 6 | 7 | TensorFunc = typing.Callable[[torch.Tensor], torch.Tensor] 8 | 9 | __all__ = ('XYDataSet', ) 10 | 11 | 12 | class XYDataSet(torch.utils.data.Dataset): 13 | def __init__(self, x: torch.Tensor, y: torch.Tensor): 14 | if len(x) != len(y): 15 | raise ValueError('size of x and y not match') 16 | self.x = x 17 | self.y = y 18 | 19 | def __len__(self) -> int: 20 | return len(self.x) 21 | 22 | def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]: 23 | return self.x[idx], self.y[idx] 24 | -------------------------------------------------------------------------------- /readme.org: -------------------------------------------------------------------------------- 1 | * Multi-fidelity Neural Network 2 | 3 | Multi-fidelity neural network (MFNN) is used for modeling physical 4 | systems by multi-fidelity data. In real applications, low-fidelity 5 | data is usually abundant but less accurate, and high-fidelity data is 6 | scarce and expensive. MFNN makes uses of both low- and high-fidelity 7 | data for modeling the physical system, which significantly improves 8 | data accuracy by a small set of high-fidelity data. 9 | 10 | Meng and Karniadakis [1] gave an approach to MFNN by using a 11 | composite neural network. However, the code is absent for their 12 | paper. Meanwhile, the use of liner layers in high-fidelity DNN (NN_H1) 13 | is redundant, as linear features will always be modeled by the 14 | nonlinear DNN (NN_H2). 15 | 16 | Thus, in this repository, a modified version of MFNN is provided, 17 | where linear DNN (NN_H1) given by paper [1] are replaced by residual 18 | connection over the nonlinear DNN (NN_H2).The code is implemented using 19 | pytorch, and examples are provided for MFNN. 20 | 21 | * Snapshots 22 | 23 | ** Data 24 | 25 | [[./snapshots/data.svg]] 26 | 27 | ** Modeling using low-fidelity data 28 | 29 | [[./snapshots/low.svg]] 30 | 31 | ** Modeling by high-fidelity data 32 | 33 | [[./snapshots/high.svg]] 34 | 35 | ** Modeling by both low- and high-fidelity data 36 | 37 | [[./snapshots/mfnn.svg]] 38 | 39 | * References 40 | [1] Meng X, Karniadakis GE. A composite neural network that learns 41 | from multi-fidelity data: Application to function approximation and 42 | inverse PDE problems. Journal of Computational Physics 43 | 2020;401:109020. https://doi.org/10.1016/j.jcp.2019.109020. 44 | 45 | -------------------------------------------------------------------------------- /mfnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | 5 | import gc 6 | import sys 7 | import typing 8 | from io import BytesIO 9 | import typing 10 | 11 | import torch 12 | 13 | # __all__ = ('DEVICE', 'free_memory', 'copy_to', 'Statistics', ) 14 | 15 | 16 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | def free_memory(): 20 | gc.collect() 21 | torch.cuda.empty_cache() 22 | 23 | 24 | def copy_to(data, device: torch.device | int | str | None = None): 25 | """Copy data to device.""" 26 | # Avoid useless copy in gpu. 27 | # See https://discuss.pytorch.org/t/how-to-make-a-copy-of-a-gpu-model-on-the-cpu/90955/4 28 | if device is None: 29 | return data 30 | memory = BytesIO() 31 | torch.save(data, memory, pickle_protocol=-1) 32 | memory.seek(0) 33 | data = torch.load(memory, map_location=device) 34 | memory.close() 35 | return data 36 | 37 | 38 | T = typing.TypeVar('T', bound=float) 39 | 40 | 41 | class Statistics(typing.Generic[T]): 42 | """Get statistical value(sum, average, variance) of data.""" 43 | __slots__ = '_count', '_value', '_s1', '_s2' 44 | 45 | def __init__(self): 46 | self._count = 0 47 | self._value: T = 0. 48 | self._s1: float = 0. # sum of sample values 49 | self._s2: float = 0. # sum of squared sample values 50 | 51 | def update(self, value: T, n: int = 1) -> Statistics[T]: 52 | self._count += n 53 | self._value = value 54 | self._s1 += self.value * n 55 | self._s2 += value ** 2 * n 56 | return self 57 | 58 | @property 59 | def count(self) -> int: 60 | return self._count 61 | 62 | @property 63 | def value(self) -> T: 64 | return self._value 65 | 66 | @property 67 | def sum(self) -> float: 68 | return self._s1 69 | 70 | @property 71 | def average(self) -> float: 72 | return self._s1 / self._count 73 | 74 | @property 75 | def variance(self) -> float: 76 | return self._s2 / self._count - self.average ** 2 77 | 78 | @property 79 | def std(self) -> float: 80 | return self.variance ** .5 81 | 82 | 83 | class GatedStdout: 84 | def __init__(self, suppress: bool): 85 | self.suppress = suppress 86 | 87 | def write(self, s: typing.TypeVar('AnyStr', bytes, str)): 88 | if self.suppress: 89 | return 0 90 | return sys.stdout.write(s) 91 | -------------------------------------------------------------------------------- /example2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import torch 5 | 6 | from mfnn import FCNN, HFNN, MFNN, XYDataSet, Trainer 7 | 8 | import matplotlib as mpl 9 | import matplotlib.pyplot as plt 10 | mpl_fontpath = mpl.get_data_path() + '/fonts/ttf/STIXGeneral.ttf' 11 | mpl_fontprop = mpl.font_manager.FontProperties(fname=mpl_fontpath) 12 | plt.rc('font', family='STIXGeneral', weight='normal', size=10) 13 | plt.rc('mathtext', fontset='stix') 14 | 15 | 16 | def func_low(x: torch.Tensor) -> torch.Tensor: 17 | return torch.sin(8 * torch.pi * x) 18 | 19 | 20 | def func_high(x: torch.Tensor) -> torch.Tensor: 21 | return (x - 2**.5) * func_low(x) ** 2 22 | 23 | 24 | def figure1(x_low, y_low, x_high, y_high, x, y): 25 | "Plot the multi-fidelity data along with true data." 26 | y_high = func_high(x_high) 27 | 28 | fig = plt.figure(figsize=(3, 2.25)) 29 | ax = fig.add_subplot(111) 30 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low') 31 | ax.plot(x_high, y_high, 'rx', label='high') 32 | ax.plot(x, y, 'k:', label='true') 33 | ax.legend() 34 | ax.set_xlabel('$x$') 35 | ax.set_ylabel('$y$') 36 | ax.set_xlim(x[0], x[-1]) 37 | ax.grid() 38 | fig.tight_layout(pad=0) 39 | return fig 40 | 41 | 42 | def figure2(x_low, y_low, x_pred, y_pred, x, y): 43 | "Plot the regression result by low-fidelity data." 44 | fig = plt.figure(figsize=(3, 2.25)) 45 | ax = fig.add_subplot(111) 46 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low') 47 | ax.plot(x_pred, y_pred, 'r', label='$y_{\mathrm{low}}$') 48 | ax.plot(x, y, 'k:', label='true') 49 | ax.legend() 50 | ax.set_xlabel('$x$') 51 | ax.set_ylabel('$y$') 52 | ax.set_xlim(x[0], x[-1]) 53 | ax.grid() 54 | fig.tight_layout(pad=0) 55 | return fig 56 | 57 | 58 | def figure3(x_low, y_low, x_high, y_high, x_pred, y_pred, x, y): 59 | "Plot the regression result by low-fidelity data." 60 | fig = plt.figure(figsize=(3, 2.25)) 61 | ax = fig.add_subplot(111) 62 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='high') 63 | ax.plot(x_high, y_high, 'x', color='r', label='high') 64 | ax.plot(x_pred, y_pred, 'g', label=r'$y_{\mathrm{pred}}$') 65 | ax.plot(x, y, 'k:', label='true') 66 | ax.legend() 67 | ax.set_xlabel('$x$') 68 | ax.set_ylabel('$y$') 69 | ax.set_xlim(x[0], x[-1]) 70 | ax.grid() 71 | fig.tight_layout(pad=0) 72 | return fig 73 | 74 | 75 | if __name__ == '__main__': 76 | # Generate data. 77 | x = torch.linspace(0, 1, 501).reshape(-1, 1) 78 | y = func_high(x) 79 | x_low = torch.linspace(0, 1, 51).reshape(-1, 1) 80 | x_high = torch.linspace(0, 1, 14).reshape(-1, 1) 81 | y_low = func_low(x_low) 82 | y_high = func_high(x_high) 83 | loader_low = torch.utils.data.DataLoader(XYDataSet(x_low, y_low), 84 | batch_size=len(x_low)) 85 | loader_high = torch.utils.data.DataLoader(XYDataSet(x_high, y_high), 86 | batch_size=len(x_low)) 87 | figure1(x_low, y_low, x_high, y_high, x, y) 88 | 89 | model = MFNN(1, 1, [16], [16, 16], [16, 16, 16, 16], torch.nn.Tanh) 90 | model_low = model.low 91 | model_high = model.high 92 | optimizer_low = torch.optim.Adam( 93 | model_low.parameters(), lr=1e-2, weight_decay=1e-4) 94 | optimizer_high = torch.optim.Adam( 95 | model_high.parameters(), lr=1e-2, weight_decay=1e-4) 96 | scheduler_low = torch.optim.lr_scheduler.MultiStepLR( 97 | optimizer_low, milestones=[2000, 8000]) 98 | scheduler_high = torch.optim.lr_scheduler.MultiStepLR( 99 | optimizer_high, milestones=[2000, 8000]) 100 | loss = torch.nn.MSELoss() 101 | 102 | # low-fidelity data 103 | trainer1 = Trainer(model_low, optimizer_low, loss, 104 | scheduler=scheduler_low, suppress_display=True) 105 | for i in range(10000): 106 | res = trainer1.train(loader_low) 107 | if i % 100 == 0: 108 | print(f'niter: {i}', f'avg loss: {res.average:.4e}', 109 | f'loss std: {res.std:.4e}', sep=', ') 110 | y1 = model_low.eval()(x).detach() 111 | figure2(x_low, y_low, x, y1, x, y) 112 | 113 | # high-fidelity data 114 | 115 | trainer2 = Trainer(model_high, optimizer_high, 116 | loss, scheduler=scheduler_high, suppress_display=True) 117 | for i in range(10000): 118 | res = trainer2.train(loader_high) 119 | if i % 100 == 0: 120 | print(f'niter: {i}', f'avg loss: {res.average:.4e}', 121 | f'loss std: {res.std:.4e}', sep=', ') 122 | y3 = model_high.eval()(x).detach() 123 | figure3(x_low, y_low, x_high, y_high, x, y3, x, y) 124 | 125 | plt.show() 126 | -------------------------------------------------------------------------------- /example1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import torch 5 | 6 | from mfnn import FCNN, HFNN, XYDataSet, Trainer 7 | 8 | 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | mpl_fontpath = mpl.get_data_path() + '/fonts/ttf/STIXGeneral.ttf' 12 | mpl_fontprop = mpl.font_manager.FontProperties(fname=mpl_fontpath) 13 | plt.rc('font', family='STIXGeneral', weight='normal', size=10) 14 | plt.rc('mathtext', fontset='stix') 15 | 16 | 17 | def func_low(x: torch.Tensor) -> torch.Tensor: 18 | return torch.sin(8 * torch.pi * x) 19 | 20 | 21 | def func_high(x: torch.Tensor) -> torch.Tensor: 22 | return (x - 2**.5) * func_low(x) ** 2 23 | 24 | 25 | def figure1(x_low, y_low, x_high, y_high, x, y): 26 | "Plot the multi-fidelity data along with true data." 27 | y_high = func_high(x_high) 28 | 29 | fig = plt.figure(figsize=(3, 2.25)) 30 | ax = fig.add_subplot(111) 31 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low') 32 | ax.plot(x_high, y_high, 'rx', label='high') 33 | ax.plot(x, y, 'k:', label='true') 34 | ax.legend() 35 | ax.set_xlabel('$x$') 36 | ax.set_ylabel('$y$') 37 | ax.set_xlim(x[0], x[-1]) 38 | ax.grid() 39 | fig.tight_layout(pad=0) 40 | return fig 41 | 42 | 43 | def figure2(x_low, y_low, x_pred, y_pred, x, y): 44 | "Plot the regression result by low-fidelity data." 45 | fig = plt.figure(figsize=(3, 2.25)) 46 | ax = fig.add_subplot(111) 47 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low') 48 | ax.plot(x_pred, y_pred, 'r', label='$y_{\mathrm{low}}$') 49 | ax.plot(x, y, 'k:', label='true') 50 | ax.legend() 51 | ax.set_xlabel('$x$') 52 | ax.set_ylabel('$y$') 53 | ax.set_xlim(x[0], x[-1]) 54 | ax.grid() 55 | fig.tight_layout(pad=0) 56 | return fig 57 | 58 | 59 | def figure3(x_high, y_high, x_pred, y_pred, x, y): 60 | "Plot the regression result by low-fidelity data." 61 | fig = plt.figure(figsize=(3, 2.25)) 62 | ax = fig.add_subplot(111) 63 | ax.plot(x_high, y_high, 'o', color='None', 64 | markeredgecolor='b', label='high') 65 | ax.plot(x_pred, y_pred, 'r', label='$y_{\mathrm{high}}$') 66 | ax.plot(x, y, 'k:', label='true') 67 | ax.legend() 68 | ax.set_xlabel('$x$') 69 | ax.set_ylabel('$y$') 70 | ax.set_xlim(x[0], x[-1]) 71 | ax.grid() 72 | fig.tight_layout(pad=0) 73 | return fig 74 | 75 | 76 | def figure4(x_low, y_low, x_high, y_high, x_pred, y_pred, x, y): 77 | "Plot the regression result by low-fidelity data." 78 | fig = plt.figure(figsize=(3, 2.25)) 79 | ax = fig.add_subplot(111) 80 | ax.plot(x_low, y_low, 'o', color='None', markeredgecolor='b', label='low') 81 | ax.plot(x_high, y_high, 'x', color='r', label='high') 82 | ax.plot(x_pred, y_pred, 'g', label=r'$y_{\mathrm{pred}}$') 83 | ax.plot(x, y, 'k:', label='true') 84 | ax.legend() 85 | ax.set_xlabel('$x$') 86 | ax.set_ylabel('$y$') 87 | ax.set_xlim(x[0], x[-1]) 88 | ax.grid() 89 | fig.tight_layout(pad=0) 90 | return fig 91 | 92 | 93 | if __name__ == '__main__': 94 | # Generate data. 95 | x = torch.linspace(0, 1, 501).reshape(-1, 1) 96 | y = func_high(x) 97 | x_low = torch.linspace(0, 1, 51).reshape(-1, 1) 98 | x_high = torch.linspace(0, 1, 15)[:14].reshape(-1, 1) 99 | y_low = func_low(x_low) 100 | y_high = func_high(x_high) 101 | loader_low = torch.utils.data.DataLoader(XYDataSet(x_low, y_low), 102 | batch_size=len(x_low)) 103 | loader_high = torch.utils.data.DataLoader(XYDataSet(x_high, y_high), 104 | batch_size=len(x_low)) 105 | figure1(x_low, y_low, x_high, y_high, x, y) 106 | 107 | # low-fidelity data 108 | model1 = FCNN(1, 1, [16, 16], torch.nn.Tanh) 109 | optimizer = torch.optim.Adam(model1.parameters(), lr=1e-2, weight_decay=1e-4) 110 | loss = torch.nn.MSELoss() 111 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 112 | optimizer, milestones=[2000, 8000]) 113 | trainer1 = Trainer(model1, optimizer, loss, 114 | scheduler=scheduler, 115 | suppress_display=True) 116 | for i in range(10000): 117 | res = trainer1.train(loader_low) 118 | if i % 100 == 0: 119 | print(f'niter: {i}', f'avg loss: {res.average:.4e}', 120 | f'loss std: {res.std:.4e}', sep=', ') 121 | y1 = model1.eval()(x).detach() 122 | figure2(x_low, y_low, x, y1, x, y) 123 | 124 | # high-fidelity data 125 | model2 = FCNN(1, 1, [16, 16], torch.nn.Tanh) 126 | optimizer = torch.optim.Adam( 127 | model2.parameters(), lr=1e-2, weight_decay=1e-4) 128 | loss = torch.nn.MSELoss() 129 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 130 | optimizer, milestones=[2000, 8000]) 131 | trainer2 = Trainer(model2, optimizer, loss, 132 | scheduler=scheduler, suppress_display=True) 133 | for i in range(10000): 134 | res = trainer2.train(loader_high) 135 | if i % 100 == 0: 136 | print(f'niter: {i}', f'avg loss: {res.average:.4e}', 137 | f'loss std: {res.std:.4e}', sep=', ') 138 | y2 = model2.eval()(x).detach() 139 | figure3(x_high, y_high, x, y2, x, y) 140 | 141 | # MFNN 142 | model3 = HFNN(model1, 1, 1, [16, 16], torch.nn.Tanh) 143 | loss = torch.nn.MSELoss() 144 | optimizer = torch.optim.Adam( 145 | model3.parameters(), lr=1e-2, weight_decay=1e-4) 146 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 147 | optimizer, milestones=[2000, 8000]) 148 | trainer3 = Trainer(model3, optimizer, loss, 149 | scheduler=scheduler, suppress_display=True) 150 | for i in range(10000): 151 | res = trainer3.train(loader_high) 152 | if i % 100 == 0: 153 | print(f'niter: {i}', f'avg loss: {res.average:.4e}', 154 | f'loss std: {res.std:.4e}', sep=', ') 155 | y3 = model3.eval()(x).detach() 156 | figure4(x_low, y_low, x_high, y_high, x, y3, x, y) 157 | 158 | plt.show() 159 | -------------------------------------------------------------------------------- /mfnn/mfnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | import typing 5 | import torch 6 | 7 | TensorFunc = typing.Callable[[torch.Tensor], torch.Tensor] 8 | 9 | __all__ = ('FCNN', 'HFNN', 'MFNN') 10 | 11 | 12 | class BasicBlock(torch.nn.Module): 13 | """Basic block for a fully-connected layer.""" 14 | 15 | def __init__(self, in_features: int, 16 | out_features: int, 17 | activation: type[torch.nn.Module]): 18 | super().__init__() 19 | self.linear_layer = torch.nn.Linear(in_features, out_features) 20 | self.activation = activation() 21 | 22 | def forward(self, x): 23 | x = self.linear_layer(x) 24 | x = self.activation(x) 25 | return x 26 | 27 | 28 | class FCNN(torch.nn.Module): 29 | def __init__(self, 30 | in_features: int, 31 | out_features: int, 32 | midlayer_features: list[int], 33 | activation: type[torch.nn.Module] = torch.nn.Tanh 34 | ): 35 | super().__init__() 36 | 37 | layer_sizes = [in_features] + midlayer_features 38 | 39 | self.layers = torch.nn.Sequential(*[ 40 | BasicBlock(layer_sizes[i], layer_sizes[i+1], activation) 41 | for i in range(len(midlayer_features)) 42 | ]) 43 | self.fc = torch.nn.Linear(layer_sizes[-1], out_features) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | x = self.layers(x) 47 | x = self.fc(x) 48 | return x 49 | 50 | 51 | class HFNN(torch.nn.Module): 52 | def __init__(self, 53 | lfnn: torch.nn.Module | TensorFunc, 54 | in_features: int, 55 | out_features: int, 56 | midlayer_features: list[int], 57 | activation: type[torch.nn.Module] | None = None): 58 | activation = torch.nn.Tanh if activation is None else activation 59 | super().__init__() 60 | # Avoid updating the trained low-fidelity model during training. 61 | object.__setattr__(self, 'lfnn', lfnn) 62 | self.lfnn: torch.nn.Module | TensorFunc 63 | 64 | self.layers = FCNN(in_features + out_features, 65 | out_features, 66 | midlayer_features, 67 | activation) 68 | self.shortcut = FCNN(in_features + out_features, 69 | out_features, 70 | [], 71 | torch.nn.Identity) 72 | 73 | # Use He Kaiming's normal initialization 74 | for m in self.modules(): 75 | for name, parameter in m.named_parameters(): 76 | if name == 'weight': 77 | torch.nn.init.kaiming_normal_(parameter) 78 | elif name == 'bias': 79 | torch.nn.init.zeros_(parameter) 80 | else: 81 | assert "impossible!" 82 | 83 | def forward(self, x: torch.Tensor) -> torch.Tensor: 84 | y_low = self.lfnn(x) 85 | x_comb = torch.concat([x, y_low], dim=1) 86 | y_layer = self.layers(x_comb) 87 | y_shortcut = self.shortcut(x_comb) 88 | y = y_layer + y_shortcut 89 | return y 90 | 91 | 92 | class MFNN(): 93 | """A flexible network for regression problem of multi-fidelity data. 94 | 95 | This class contains two networks, for fitting low-fidelity data and 96 | high fidelity data respectively. 97 | """ 98 | 99 | def __init__(self, 100 | in_features: int, 101 | out_features: int, 102 | backbone_layers: list[int], 103 | layers_low: list[int], 104 | layers_high: list[int], 105 | activation: type[torch.nn.Module] | None = None): 106 | activation = torch.nn.Tanh if activation is None else activation 107 | 108 | layer_sizes = [in_features] + backbone_layers 109 | self.backbone = torch.nn.Sequential(*[ 110 | BasicBlock(layer_sizes[i], layer_sizes[i+1], activation) 111 | for i in range(len(backbone_layers)) 112 | ]) 113 | 114 | self.fc_low = FCNN(layer_sizes[-1], 115 | out_features, 116 | layers_low, 117 | activation) 118 | self.fc_high = FCNN(out_features + layer_sizes[-1], 119 | out_features, 120 | layers_high, 121 | activation) 122 | self.shortcut = FCNN(out_features + layer_sizes[-1], 123 | out_features, 124 | [], 125 | torch.nn.Identity) 126 | 127 | # Use He Kaiming's normal initialization 128 | for m in [self.backbone, self.fc_low, self.fc_high, self.shortcut]: 129 | for name, parameter in m.named_parameters(): 130 | if name == 'weight': 131 | torch.nn.init.kaiming_normal_(parameter) 132 | elif name == 'bias': 133 | torch.nn.init.zeros_(parameter) 134 | else: 135 | assert "impossible!" 136 | 137 | self.low = MFNNLow(self.backbone, self.fc_low) 138 | self.high = MFNNHigh(self.low, self.fc_high, self.shortcut) 139 | 140 | 141 | class MFNNLow(torch.nn.Module): 142 | def __init__(self, 143 | backbone: torch.nn.Module, 144 | fc_low: torch.nn.Module): 145 | super().__init__() 146 | self.backbone = backbone 147 | self.fc_low = fc_low 148 | 149 | def forward(self, x: torch.Tensor) -> torch.Tensor: 150 | features = self.backbone(x) 151 | y = self.fc_low(features) 152 | return y 153 | 154 | 155 | class MFNNHigh(torch.nn.Module): 156 | def __init__(self, 157 | mfnnlow: torch.nn.Module, 158 | fc_high: torch.nn.Module, 159 | shortcut: torch.nn.Module): 160 | super().__init__() 161 | # Avoid updating the trained low-fidelity model during training. 162 | object.__setattr__(self, 'mfnnlow', mfnnlow) 163 | self.mfnnlow: torch.nn.Module 164 | self.fc_high = fc_high 165 | self.shortcut = shortcut 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | features = self.mfnnlow.backbone(x) 169 | y_low = self.mfnnlow.fc_low(features) 170 | x_combine = torch.concat([y_low.reshape(-1, 1), features], dim=1) 171 | y_fc = self.fc_high(x_combine) 172 | y_shortcut = self.shortcut(x_combine) 173 | y = y_fc + y_shortcut 174 | return y 175 | 176 | 177 | if __name__ == '__main__': 178 | 179 | x = torch.linspace(0, 1, 101).reshape(-1, 1) 180 | 181 | lfnn = FCNN(1, 1, [16, 16, 16], activation=torch.nn.Tanh) 182 | hfnn = HFNN(lfnn, 1, 1, [32, 32]) 183 | mfnn = MFNN(1, 1, [16], [16, 16], [16], activation=torch.nn.Tanh) 184 | 185 | ylow = lfnn(x) 186 | yhigh = hfnn(x) 187 | ylow = mfnn.low(x) 188 | yhigh = mfnn.high(x) 189 | -------------------------------------------------------------------------------- /mfnn/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | import sys 5 | import time 6 | import pickle 7 | import typing 8 | 9 | import tqdm 10 | import torch 11 | 12 | 13 | from .utils import DEVICE, Statistics, free_memory, copy_to, GatedStdout 14 | 15 | __all__ = ('Trainer',) 16 | 17 | 18 | _VALID_SECOND_TIME_WARNING = True 19 | 20 | 21 | class Trainer(): 22 | """Class for training a model.""" 23 | 24 | def __init__(self, 25 | model: torch.nn.Module, 26 | optimizer: torch.optim.Optimizer, 27 | critrion: torch.nn.Module, 28 | *, 29 | device: torch.device | int | str | None = None, 30 | start_epoch: int = 0, 31 | filename: str | None = None, 32 | scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, 33 | forced_gc: bool = False, 34 | suppress_display: bool = False 35 | ): 36 | self.model = model 37 | self.optimizer = optimizer 38 | self.critrion = critrion # loss function 39 | # Use property setter to move model and loss function to target device 40 | self.device = DEVICE if device is None else device 41 | 42 | self.epoch = start_epoch 43 | self.filename = 'trainer.trainer' if filename is None else filename 44 | self.scheduler = scheduler 45 | self.is_forced_gc = forced_gc 46 | self.stdout = GatedStdout(suppress_display) 47 | 48 | self.history: dict[str, list[float]] = {'train_loss': [], 49 | 'validate_loss': [], 50 | } 51 | 52 | @property 53 | def device(self) -> torch.device: 54 | return self._device 55 | 56 | @device.setter 57 | def device(self, device: torch.device | int | str) -> Trainer: 58 | self._device = torch.device(device) 59 | self.model.to(self._device) 60 | self.critrion.to(self._device) 61 | return self 62 | 63 | @property 64 | def lr(self) -> list[float]: 65 | return [pg['lr'] for pg in self.optimizer.param_groups] 66 | 67 | @lr.setter 68 | def lr(self, lr: float): 69 | for param_group in self.optimizer.param_groups: 70 | param_group['lr'] = lr 71 | 72 | def save(self, device: torch.device | int | str = "cpu"): 73 | """Save trainer object. 74 | 75 | The trainer is saved to the given `device`, along with the 76 | model. The `device` specified here does not change the device 77 | type of the trainer instance, but only saves all the variables 78 | to this `device`. The default target device is "cpu". 79 | 80 | The `device` specified in this method has nothing to do with 81 | the `load` method's `device` argument. The `device` argument 82 | is introduced here to solve the problem that a cuda model can 83 | not be saved and then loaded on another computer without 84 | gpu. So it is always suggested to set the `device` argument 85 | here to "cpu" to make sure it can be loaded on any computer. 86 | """ 87 | data = copy_to(self.__dict__, torch.device(device)) 88 | with open(self.filename, 'wb') as f: 89 | f.write(pickle.dumps((data, self.device))) 90 | 91 | def save_as(self, filename: str): 92 | self.filename = filename 93 | return self.save() 94 | 95 | @staticmethod 96 | def load(filename: str, 97 | device: torch.device | int | str | None = None 98 | ) -> Trainer: 99 | """Load a trainer object. 100 | 101 | Load the trainer to given `device`. Specifying `device` 102 | argument here would also change the loaded trainer's `device` 103 | property. If device is not given, it is defaulted to the 104 | object's `device` property. 105 | """ 106 | with open(filename, 'rb') as f: 107 | data, default_device = pickle.loads(f.read()) 108 | if device is None: 109 | data = copy_to(data, default_device) 110 | else: 111 | device = torch.device(device) 112 | data = copy_to(data, device) 113 | data['_device'] = device 114 | res = object.__new__(Trainer) 115 | res.__dict__.update(data) 116 | return res 117 | 118 | def train(self, 119 | loader: torch.utils.data.DataLoader 120 | ) -> Statistics[float]: 121 | "Train the model by given dataloader." 122 | t_start = time.time() 123 | self.model.train() 124 | tq = tqdm.tqdm(loader, 125 | desc="train", 126 | ncols=None, 127 | leave=False, 128 | file=self.stdout, 129 | unit="batch") 130 | loss_meter: Statistics[float] = Statistics() 131 | # User defined preprocess 132 | additional_data = self.additional_train_preprocess(tq) 133 | for x, y in tq: 134 | current_batch_size = x.shape[0] 135 | x = x.to(self.device) 136 | y = y.to(self.device) 137 | 138 | # compute prediction error 139 | y_pred = self.model(x) 140 | loss = self.critrion(y_pred, y) 141 | # backpropagation 142 | self.optimizer.zero_grad() 143 | loss.backward() 144 | self.optimizer.step() 145 | 146 | # record results 147 | loss_meter.update(loss.item(), current_batch_size) 148 | tq.set_postfix(loss=f"{loss_meter.value:.4e}") 149 | 150 | # Do some user-defined process. 151 | self.additional_train_process(additional_data, 152 | y_pred, y, loss, tq) 153 | 154 | # Free some space before next round. 155 | del x, y, y_pred, loss 156 | if self.is_forced_gc: 157 | free_memory() 158 | 159 | # Save information for this epoch. 160 | self.epoch += 1 161 | self.history['train_loss'].append(loss_meter.average) 162 | 163 | # User_defined postprocess. 164 | self.additional_train_postprocess(additional_data) 165 | 166 | print(f'train result [{self.epoch}]: ' 167 | f'avg loss = {loss_meter.average:.4e}, ' 168 | f'wall time = {time.time()- t_start:.2f}s', 169 | file=self.stdout) 170 | if self.scheduler: 171 | self.scheduler.step() 172 | 173 | return loss_meter 174 | 175 | def validate(self, 176 | loader: torch.utils.data.DataLoader, 177 | ) -> Statistics[float]: 178 | """Validate the model.""" 179 | t_start = time.time() 180 | self.model.eval() 181 | tq = tqdm.tqdm(loader, 182 | desc="valid", 183 | ncols=None, 184 | leave=False, 185 | file=self.stdout, 186 | unit="batch") 187 | loss_meter: Statistics[float] = Statistics() 188 | # User-defined preprocess 189 | additional_data = self.additional_validate_preprocess(tq) 190 | 191 | for x, y in tq: 192 | current_batch_size = x.size(0) 193 | x = x.to(self.device) 194 | y = y.to(self.device) 195 | 196 | with torch.no_grad(): 197 | y_pred = self.model(x) 198 | loss = self.critrion(y_pred, y) 199 | loss_meter.update(loss.item(), current_batch_size) 200 | tq.set_postfix(loss=f"{loss_meter.value:.4e}") 201 | 202 | # Do some user-defined process. 203 | self.additional_validate_process(additional_data, 204 | y_pred, y, loss, tq) 205 | 206 | del x, y, y_pred, loss 207 | if self.is_forced_gc: 208 | free_memory() 209 | 210 | # Save validation results only the fisrt run. 211 | if len(self.history['validate_loss']) < self.epoch: 212 | self.history['validate_loss'].append(loss_meter.average) 213 | else: 214 | global _VALID_SECOND_TIME_WARNING 215 | if _VALID_SECOND_TIME_WARNING: 216 | sys.stderr.write("The model is validated for the " 217 | "second time in the same epoch, " 218 | "validation result will not be " 219 | "recorded. " 220 | "This warning will be " 221 | "turned off in this session.\n") 222 | _VALID_SECOND_TIME_WARNING = False 223 | 224 | # User_defined postprocess. 225 | self.additional_validate_postprocess(additional_data) 226 | 227 | print(f'valid result [{self.epoch}]: ' 228 | f'avg loss = {loss_meter.average:.4e}, ' 229 | f'wall time = {time.time()- t_start:.2f}s', 230 | file=self.stdout) 231 | return loss_meter 232 | 233 | def step(self, 234 | train_dataloader: torch.utils.data.DataLoader, 235 | valid_dataloader: torch.utils.data.DataLoader, 236 | save_trainer: bool = True, 237 | save_best_model: str | None = None, 238 | ) -> bool: 239 | """Train and validate the model, return if it is the best model. 240 | 241 | Note that if `save_best_model` is enabled, it selects the best 242 | model by its validation loss, which might not be good for 243 | classification problems. 244 | 245 | Parameters 246 | ---------- 247 | save_trainer: bool, optional 248 | whether to save the trainer after this epoch. defaults to 249 | True. 250 | save_best_model: str, optional 251 | Filename for saving the model if it is the best model 252 | indicated by the validation procedure. If not given, the 253 | model will not be saved automatically. 254 | 255 | """ 256 | print(f' ---- Epoch {self.epoch} ---- ', file=self.stdout) 257 | self.train(train_dataloader) 258 | loss = self.validate(valid_dataloader) 259 | if save_trainer: 260 | self.save() 261 | if loss == min(self.history['validate_loss']): 262 | if save_best_model: 263 | print('This model will be saved as the best model.', 264 | file=self.stdout) 265 | with open(save_best_model, 'wb') as f: 266 | torch.save(self.model, f) 267 | return True 268 | return False 269 | 270 | # Class methods can be overwritten. 271 | 272 | def additional_train_preprocess(self, tq: tqdm.std.tqdm) -> typing.Any: 273 | """Additional pre-process in each epoch. 274 | 275 | This method can be overwritten to do some additional work 276 | before iterations in each epoch (eg, prepare dict for 277 | `additional_train_process`). The return value of the method 278 | will be used for `additional_train_process` and 279 | `additional_train_postprocess` 280 | 281 | Parameters 282 | ---------- 283 | tq: tqdm object 284 | The tqdm object. Can modify display here. 285 | 286 | Return 287 | ------ 288 | additional_data: typing.Any 289 | The returned data will be used processed in each batch 290 | and the end of the epoch. 291 | """ 292 | return None 293 | 294 | def additional_train_process(self, 295 | additional_data: typing.Any, 296 | y_pred: torch.Tensor, 297 | y_true: torch.Tensor, 298 | loss: torch.Tensor, 299 | tq: tqdm.std.tqdm): 300 | """Additional process after training the model in each batch. 301 | 302 | This method can be overwritten to do some additional work 303 | after the loss is calculated in each batch (eg, calculate 304 | top-k error in classification task). This function has no 305 | return value. 306 | 307 | Parameters 308 | ---------- 309 | additional_data: typing.Any 310 | Defined in `additional_train_preprocess`. 311 | y_pred: torch.Tensor 312 | Predicted value by model. 313 | y_true: torch.Tensor 314 | True value given by train dataset. 315 | loss: torch.Tensor 316 | Loss of this batch given by critrion. 317 | tq: tqdm object 318 | Can modify display here. 319 | """ 320 | pass 321 | 322 | def additional_train_postprocess(self, 323 | additional_data: typing.Any): 324 | """Additional postprocess in each epoch. 325 | 326 | This method can be overwritten to do some additional work 327 | after the epoch is finished (eg. saving `additional_data`). 328 | 329 | Parameters 330 | ---------- 331 | additional_data: typing.Any 332 | Defined in `additional_train_preprocess`. 333 | """ 334 | pass 335 | 336 | def additional_validate_preprocess(self, tq: tqdm.std.tqdm) -> typing.Any: 337 | """Additional pre-process in each epoch. 338 | 339 | This method can be overwritten to do some additional work 340 | before iterations in each epoch (eg, prepare dict for 341 | `additional_validate_process`). The return value of the method 342 | will be used for `additional_validate_process` and 343 | `additional_validate_postprocess` 344 | 345 | Parameters 346 | ---------- 347 | tq: tqdm object 348 | The tqdm object. Can modify display here. 349 | 350 | Return 351 | ------ 352 | additional_data: typing.Any 353 | The returned data will be used processed in each batch 354 | and the end of the epoch. 355 | """ 356 | return None 357 | 358 | def additional_validate_process(self, 359 | additional_data: typing.Any, 360 | y_pred: torch.Tensor, 361 | y_true: torch.Tensor, 362 | loss: torch.Tensor, 363 | tq: tqdm.std.tqdm): 364 | """Additional process after model validation in each batch. 365 | 366 | This method can be overwritten to do some additional work 367 | after the loss is calculated in each batch (eg, calculate 368 | top-k error in classification task). This function has no 369 | return value. 370 | 371 | Parameters 372 | ---------- 373 | additional_data: typing.Any 374 | Defined in `additional_validate_preprocess`. 375 | y_pred: torch.Tensor 376 | Predicted value by model. 377 | y_true: torch.Tensor 378 | True value given by validation dataset. 379 | loss: torch.Tensor 380 | Loss of this batch given by critrion. 381 | tq: tqdm object 382 | Can modify display here. 383 | """ 384 | pass 385 | 386 | def additional_validate_postprocess(self, 387 | additional_data: typing.Any): 388 | """Additional postprocess in each epoch. 389 | 390 | This method can be overwritten to do some additional work 391 | after the epoch is finished (eg. saving `additional_data`). 392 | 393 | Parameters 394 | ---------- 395 | additional_data: typing.Any 396 | Defined in `additional_validate_preprocess`. 397 | """ 398 | pass 399 | -------------------------------------------------------------------------------- /snapshots/high.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2023-05-05T09:04:23.390764 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.6.1, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | 48 | 49 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 87 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 403 | 404 | 405 | 406 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 425 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 527 | 528 | 529 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 558 | 559 | 560 | 561 | 562 | 563 | 564 | 565 | 566 | 567 | 568 | 571 | 572 | 573 | 574 | 575 | 576 | 577 | 578 | 579 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 634 | 645 | 646 | 647 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 660 | 661 | 662 | 663 | 664 | 833 | 834 | 835 | 1060 | 1061 | 1062 | 1065 | 1066 | 1067 | 1070 | 1071 | 1072 | 1075 | 1076 | 1077 | 1080 | 1081 | 1082 | 1083 | 1094 | 1095 | 1096 | 1097 | 1098 | 1099 | 1100 | 1101 | 1102 | 1103 | 1104 | 1139 | 1167 | 1223 | 1224 | 1225 | 1226 | 1227 | 1228 | 1229 | 1230 | 1231 | 1235 | 1236 | 1237 | 1238 | 1239 | 1240 | 1241 | 1242 | 1243 | 1244 | 1245 | 1246 | 1247 | 1251 | 1252 | 1253 | 1254 | 1255 | 1256 | 1278 | 1307 | 1338 | 1362 | 1363 | 1364 | 1365 | 1366 | 1367 | 1368 | 1369 | 1370 | 1371 | 1372 | 1373 | 1374 | 1375 | 1376 | 1377 | 1378 | -------------------------------------------------------------------------------- /snapshots/data.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2023-05-05T09:05:02.405522 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.6.1, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | 48 | 49 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 87 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 403 | 404 | 405 | 406 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 634 | 635 | 636 | 637 | 638 | 639 | 640 | 641 | 642 | 643 | 644 | 645 | 646 | 647 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 660 | 665 | 666 | 667 | 668 | 669 | 670 | 671 | 672 | 673 | 674 | 675 | 676 | 677 | 678 | 679 | 680 | 681 | 682 | 683 | 684 | 903 | 904 | 905 | 908 | 909 | 910 | 913 | 914 | 915 | 918 | 919 | 920 | 923 | 924 | 925 | 926 | 937 | 938 | 939 | 940 | 941 | 942 | 943 | 944 | 945 | 946 | 947 | 965 | 986 | 1027 | 1028 | 1029 | 1030 | 1031 | 1032 | 1033 | 1034 | 1035 | 1036 | 1037 | 1038 | 1039 | 1040 | 1041 | 1042 | 1077 | 1105 | 1161 | 1162 | 1163 | 1164 | 1165 | 1166 | 1167 | 1168 | 1169 | 1173 | 1174 | 1175 | 1176 | 1177 | 1178 | 1200 | 1229 | 1260 | 1284 | 1285 | 1286 | 1287 | 1288 | 1289 | 1290 | 1291 | 1292 | 1293 | 1294 | 1295 | 1296 | 1297 | 1298 | 1299 | 1300 | -------------------------------------------------------------------------------- /snapshots/low.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2023-05-05T09:04:47.417465 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.6.1, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | 48 | 49 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 87 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 403 | 404 | 405 | 406 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 634 | 635 | 636 | 637 | 638 | 639 | 640 | 641 | 642 | 643 | 644 | 645 | 646 | 647 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 822 | 823 | 824 | 1043 | 1044 | 1045 | 1048 | 1049 | 1050 | 1053 | 1054 | 1055 | 1058 | 1059 | 1060 | 1063 | 1064 | 1065 | 1066 | 1077 | 1078 | 1079 | 1080 | 1081 | 1082 | 1083 | 1084 | 1085 | 1086 | 1087 | 1105 | 1126 | 1167 | 1168 | 1169 | 1170 | 1171 | 1172 | 1173 | 1174 | 1178 | 1179 | 1180 | 1181 | 1182 | 1183 | 1184 | 1185 | 1186 | 1187 | 1188 | 1189 | 1193 | 1194 | 1195 | 1196 | 1197 | 1198 | 1220 | 1249 | 1280 | 1304 | 1305 | 1306 | 1307 | 1308 | 1309 | 1310 | 1311 | 1312 | 1313 | 1314 | 1315 | 1316 | 1317 | 1318 | 1319 | 1320 | -------------------------------------------------------------------------------- /snapshots/mfnn.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2023-05-05T09:04:04.559854 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.6.1, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | 48 | 49 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 87 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 403 | 404 | 405 | 406 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 634 | 635 | 636 | 637 | 638 | 639 | 640 | 641 | 642 | 643 | 644 | 645 | 646 | 647 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 660 | 665 | 666 | 667 | 668 | 669 | 670 | 671 | 672 | 673 | 674 | 675 | 676 | 677 | 678 | 679 | 680 | 681 | 682 | 683 | 684 | 907 | 908 | 909 | 1128 | 1129 | 1130 | 1133 | 1134 | 1135 | 1138 | 1139 | 1140 | 1143 | 1144 | 1145 | 1148 | 1149 | 1150 | 1151 | 1162 | 1163 | 1164 | 1165 | 1166 | 1167 | 1168 | 1169 | 1170 | 1171 | 1172 | 1190 | 1211 | 1252 | 1253 | 1254 | 1255 | 1256 | 1257 | 1258 | 1259 | 1260 | 1261 | 1262 | 1263 | 1264 | 1265 | 1266 | 1267 | 1302 | 1330 | 1386 | 1387 | 1388 | 1389 | 1390 | 1391 | 1392 | 1393 | 1394 | 1398 | 1399 | 1400 | 1401 | 1402 | 1403 | 1440 | 1469 | 1493 | 1528 | 1529 | 1530 | 1531 | 1532 | 1533 | 1534 | 1535 | 1536 | 1537 | 1541 | 1542 | 1543 | 1544 | 1545 | 1546 | 1568 | 1599 | 1600 | 1601 | 1602 | 1603 | 1604 | 1605 | 1606 | 1607 | 1608 | 1609 | 1610 | 1611 | 1612 | 1613 | 1614 | 1615 | --------------------------------------------------------------------------------