├── .vscode
│   └── settings.json
├── README.md
├── __pycache__
│   ├── dataset.cpython-38.pyc
│   ├── models.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── data
│   └── si_quad_50.pkl
├── dataset.py
├── logs
│   └── forward_model_MLP_evalFalse_log.pkl
├── models.py
├── train_forward_model.py
├── train_inverse_model.py
├── trained_models
│   ├── forward_model_MLP_evalFalse.pt
│   └── inverse_model.pt
└── utils.py

/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.pythonPath": "/home/hzwang/anaconda3/envs/meta-learning/bin/python"
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tandem neural network training for metasurface design

This is a PyTorch implementation of the tandem-network approach to metasurface inverse design from: https://onlinelibrary.wiley.com/doi/abs/10.1002/adma.201905467

Training proceeds in two steps:

Step 1: train the forward network, which maps the four structural parameters (period, height, diameter, gap) to CIE color coordinates (x, y, Y).
Step 2: train the inverse network against the frozen, pretrained forward network.

Training scripts:

python train_forward_model.py --seed 42 --epochs 100 --device 0
python train_inverse_model.py --seed 42 --epochs 100 --device 0
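Once both networks are trained, the inverse model can propose structural parameters for a target color. The snippet below is a minimal inference sketch (not part of the training scripts); it assumes the default checkpoint names produced by the commands above and uses the dataset's fitted MinMaxScaler to map the sigmoid-scaled outputs back to physical units:

```python
import torch
from models import ForwardNet, InverseNet, TandemNet
from dataset import SiliconColorTaskLevel
from utils import load_model

forward_model = load_model(ForwardNet(), 'forward_model_MLP_evalFalse.pt')
inverse_model = load_model(InverseNet(), 'inverse_model.pt')
tandem_net = TandemNet(forward_model, inverse_model)

# target CIE (x, y, Y) coordinates -- an arbitrary example value
y_target = torch.tensor([[0.30, 0.30, 0.50]])
x_scaled, y_recon = tandem_net(y_target)

# undo the min-max scaling applied in dataset.py to recover physical parameters
scaler = SiliconColorTaskLevel('./data/si_quad_50.pkl', split='train').scaler
print(scaler.inverse_transform(x_scaled.detach().numpy()))
```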
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/data/si_quad_50.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/data/si_quad_50.pkl
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import pandas as pd  # required to unpickle the DataFrames stored in the .pkl file
import torch
import torch.utils.data as data

import numpy as np
from sklearn import preprocessing

import pickle as pkl


class SiliconColor(data.Dataset):

    def __init__(self, root, split='train'):
        self.root = root
        self.data = pkl.load(open(self.root, 'rb'))

        # structural parameters: period, height, diameter, gap
        self.X = np.array([self.data['period'].to_numpy().astype('int'),
                           self.data['height'].to_numpy().astype('int'),
                           self.data['diameter'].to_numpy().astype('int'),
                           self.data['gap'].to_numpy().astype('int')]).T
        # CIE color coordinates: x, y, Y
        self.Y = np.array([self.data['x'].to_numpy().astype('float'),
                           self.data['y'].to_numpy().astype('float'),
                           self.data['Y'].to_numpy().astype('float')]).T
        self.c = self.data['class'].to_numpy().astype('int')
        self.num_classes = len(np.unique(self.c))

        # 60/20/20 train/val/test split
        tr_size = int(len(self.X) * 0.6)
        val_size = int(len(self.X) * 0.2)

        self.X_tr, self.Y_tr, self.c_tr = self.X[:tr_size], self.Y[:tr_size], self.c[:tr_size]
        self.X_val, self.Y_val, self.c_val = (self.X[tr_size:tr_size + val_size],
                                              self.Y[tr_size:tr_size + val_size],
                                              self.c[tr_size:tr_size + val_size])
        self.X_te, self.Y_te, self.c_te = (self.X[tr_size + val_size:],
                                           self.Y[tr_size + val_size:],
                                           self.c[tr_size + val_size:])

        # fit the scaler on the training split only to avoid leakage
        self.scaler = preprocessing.MinMaxScaler()
        self.scaler.fit(self.X_tr)

        if split == 'train':
            self.X, self.Y, self.c = self.scaler.transform(self.X_tr), self.Y_tr, self.c_tr
        elif split == 'val':
            self.X, self.Y, self.c = self.scaler.transform(self.X_val), self.Y_val, self.c_val
        else:
            self.X, self.Y, self.c = self.scaler.transform(self.X_te), self.Y_te, self.c_te

    def __getitem__(self, index):
        return self.X[index], self.Y[index], self.c[index]

    def __len__(self):
        return len(self.X)


def get_datasets(root):
    tr_dataset = SiliconColor(root, 'train')
    val_dataset = SiliconColor(root, 'val')
    te_dataset = SiliconColor(root, 'test')

    return tr_dataset, val_dataset, te_dataset


class SiliconColorNShot:

    def __init__(self, root, batch_size, n_way, k_shot, k_query, device=None):
        self.device = device
        self.batch_size = batch_size
        self.n_way = n_way  # not used in regression mode
        self.k_shot = k_shot
        self.k_query = k_query

        self.dt_tr, self.dt_val, self.dt_te = get_datasets(root)
        self.cls2idx_tr = self.class_to_idx(self.dt_tr)
        self.cls2idx_val = self.class_to_idx(self.dt_val)
        self.cls2idx_te = self.class_to_idx(self.dt_te)

    def class_to_idx(self, dt):
        '''
        Build a hash map from each class to its sample indices.
        '''
        cls_to_idx = {}
        for i in range(dt.num_classes):
            cls_to_idx[i] = np.where(dt.c == i)[0]

        return cls_to_idx

    def next(self, mode='train'):
        '''
        First randomly sample tasks, then sample the data points from the sampled classes.
        '''
        if mode == 'train':
            dt = self.dt_tr
            cls2idx = self.cls2idx_tr
        elif mode == 'val':
            dt = self.dt_val
            cls2idx = self.cls2idx_val
        else:
            dt = self.dt_te
            cls2idx = self.cls2idx_te

        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        for b in range(self.batch_size):
            # remap the sampled classes to new indices in [0, n_way - 1]
            selected_classes = np.random.choice(np.arange(dt.num_classes), self.n_way)
            le = preprocessing.LabelEncoder()
            le.fit(selected_classes)

            x_spt, y_spt, x_qry, y_qry = [], [], [], []
            for cls in selected_classes:
                # sample with replacement only when the class has too few examples
                if len(cls2idx[cls]) >= self.k_shot + self.k_query:
                    sample_indices = np.random.choice(
                        cls2idx[cls], self.k_shot + self.k_query, replace=False)
                else:
                    sample_indices = np.random.choice(
                        cls2idx[cls], self.k_shot + self.k_query, replace=True)

                x_spt.append(dt.X[sample_indices[:self.k_shot]])
                x_qry.append(dt.X[sample_indices[self.k_shot:]])
                y_spt.append(le.transform(dt.c[sample_indices[:self.k_shot]]))
                y_qry.append(le.transform(dt.c[sample_indices[self.k_shot:]]))

            perm = np.random.permutation(self.n_way * self.k_shot)
            x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 4)[perm]
            y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]

            perm = np.random.permutation(self.n_way * self.k_query)
            x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 4)[perm]
            y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]

            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)

        x_spts = np.array(x_spts).astype(np.float32).reshape(
            self.batch_size, self.k_shot * self.n_way, 4)
        # np.int was removed in NumPy 1.24; use np.int64
        y_spts = np.array(y_spts).astype(np.int64).reshape(
            self.batch_size, self.k_shot * self.n_way)
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(
            self.batch_size, self.k_query * self.n_way, 4)
        y_qrys = np.array(y_qrys).astype(np.int64).reshape(
            self.batch_size, self.k_query * self.n_way)

        x_spts, y_spts, x_qrys, y_qrys = [torch.from_numpy(z).to(self.device)
                                          for z in [x_spts, y_spts, x_qrys, y_qrys]]

        return x_spts, y_spts, x_qrys, y_qrys


class SiliconColorRegression:

    def __init__(self, root, batch_size, k_shot, k_query, device=None):
        self.device = device
        self.batch_size = batch_size
        self.k_shot = k_shot
        self.k_query = k_query

        self.dt_tr, self.dt_val, self.dt_te = get_datasets(root)

    def next(self, mode='train'):
        '''
        First select the split, then sample the support and query points from it.
        '''
        if mode == 'train':
            dt = self.dt_tr
        elif mode == 'val':
            dt = self.dt_val
        else:
            dt = self.dt_te

        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        for b in range(self.batch_size):

            sample_indices = np.random.choice(len(dt.X), self.k_shot, replace=False)
            x_spt, y_spt = dt.X[sample_indices], dt.Y[sample_indices]
            perm = np.random.permutation(self.k_shot)
            x_spt = np.array(x_spt).reshape(self.k_shot, 4)[perm]
            y_spt = np.array(y_spt).reshape(self.k_shot, 3)[perm]

            sample_indices = np.random.choice(len(dt.X), self.k_query, replace=False)
            x_qry, y_qry = dt.X[sample_indices], dt.Y[sample_indices]
            perm = np.random.permutation(self.k_query)
            x_qry = np.array(x_qry).reshape(self.k_query, 4)[perm]
            y_qry = np.array(y_qry).reshape(self.k_query, 3)[perm]

            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)

        x_spts = np.array(x_spts).astype(np.float32).reshape(
            self.batch_size, self.k_shot, 4)
        y_spts = np.array(y_spts).astype(np.float32).reshape(
            self.batch_size, self.k_shot, 3)
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(
            self.batch_size, self.k_query, 4)
        y_qrys = np.array(y_qrys).astype(np.float32).reshape(
            self.batch_size, self.k_query, 3)

        x_spts, y_spts, x_qrys, y_qrys = [torch.from_numpy(z).to(self.device)
                                          for z in [x_spts, y_spts, x_qrys, y_qrys]]

        return x_spts, y_spts, x_qrys, y_qrys


# Meta-learning inverse design


class SiliconColorTaskLevel(data.Dataset):

    def __init__(self, root, split='train', center_task=None):
        '''
        Args:
            root: the path of the dataset
            split: 'train', 'val', or 'test'
            center_task: if provided, only return tasks neighboring the center task;
                otherwise return the entire split.
        '''
        self.root = root
        self.data = pkl.load(open(self.root, 'rb'))

        self.X_tr, self.Y_tr, self.c_tr = self.get_feature_label(self.data['train'])
        self.X_val, self.Y_val, self.c_val = self.get_feature_label(self.data['val'])
        self.X_te, self.Y_te, self.c_te = self.get_feature_label(self.data['test'])

        self.scaler = preprocessing.MinMaxScaler()
        self.scaler.fit(self.X_tr)

        if split == 'train':
            self.X, self.Y, self.c = self.scaler.transform(self.X_tr), self.Y_tr, self.c_tr
        elif split == 'val':
            self.X, self.Y, self.c = self.scaler.transform(self.X_val), self.Y_val, self.c_val
        else:
            self.X, self.Y, self.c = self.scaler.transform(self.X_te), self.Y_te, self.c_te

        # `is not None` so that class id 0 is also a valid center task
        if center_task is not None:
            c2neighbors = pkl.load(open('./data/kmeans_50_adjacency.pkl', 'rb'))
            neighbors = c2neighbors[center_task]
            subset = np.isin(self.c, neighbors)
            self.X, self.Y, self.c = self.X[subset], self.Y[subset], self.c[subset]
            print('Center task {}, neighbors {}, num_samples {}'.format(
                center_task, neighbors, len(self.X)))

    @staticmethod
    def get_feature_label(data):
        X = np.array([data['period'].to_numpy().astype('int'),
                      data['height'].to_numpy().astype('int'),
                      data['diameter'].to_numpy().astype('int'),
                      data['gap'].to_numpy().astype('int')]).T

        Y = np.array([data['x'].to_numpy().astype('float'),
                      data['y'].to_numpy().astype('float'),
                      data['Y'].to_numpy().astype('float')]).T

        c = data['class'].to_numpy().astype('int')

        return X, Y, c

    def __getitem__(self, index):
        return self.X[index], self.Y[index], self.c[index]

    def __len__(self):
        return len(self.X)


class SiliconColorRegressionTaskSplit:

    '''
    Sample tasks based on the task split.
    '''

    def __init__(self, root, batch_size, k_shot, k_query, device=None, ratio=1):
        self.device = device
        self.batch_size = batch_size
        self.k_shot = k_shot
        self.k_query = k_query

        self.dt = {'train': SiliconColorTaskLevel(root, 'train'),
                   'val': SiliconColorTaskLevel(root, 'val'),
                   'test': SiliconColorTaskLevel(root, 'test')}

        self.split_class = {'train': np.unique(self.dt['train'].c),
                            'val': np.unique(self.dt['val'].c),
                            'test': np.unique(self.dt['test'].c)}

        self.ratio = ratio
        if ratio < 1:
            # subsample the training split to the requested ratio
            num_total = len(self.dt['train'])
            num_train = int(num_total * ratio)
            self.dt['train'], _ = torch.utils.data.random_split(
                self.dt['train'], [num_train, num_total - num_train])

        print('training dataset size {}'.format(len(self.dt['train'])))

    def next(self, mode='train', return_task_id=False):
        '''
        First randomly sample tasks, then sample the data points from the sampled classes.
        '''
        dt = self.dt[mode]

        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        selected_classes = np.random.choice(self.split_class[mode], size=self.batch_size)

        for b in range(self.batch_size):

            if self.ratio == 1:
                sample_indices = np.random.choice(
                    np.where(dt.c == selected_classes[b])[0],
                    self.k_shot + self.k_query, replace=False)
                dataset = dt
            else:
                # dt is a torch Subset here; select candidates as indices into the
                # wrapped dataset (not subset-relative positions) before sampling
                subset_indices = np.asarray(dt.indices)
                candidates = subset_indices[dt.dataset.c[subset_indices] == selected_classes[b]]
                replace = len(candidates) < self.k_shot + self.k_query
                sample_indices = np.random.choice(
                    candidates, self.k_shot + self.k_query, replace=replace)
                dataset = dt.dataset

            x_spt = dataset.X[sample_indices[:self.k_shot]]
            y_spt = dataset.Y[sample_indices[:self.k_shot]]
            perm = np.random.permutation(self.k_shot)
            x_spt = np.array(x_spt).reshape(self.k_shot, 4)[perm]
            y_spt = np.array(y_spt).reshape(self.k_shot, 3)[perm]

            x_qry = dataset.X[sample_indices[self.k_shot:]]
            y_qry = dataset.Y[sample_indices[self.k_shot:]]
            perm = np.random.permutation(self.k_query)
            x_qry = np.array(x_qry).reshape(self.k_query, 4)[perm]
            y_qry = np.array(y_qry).reshape(self.k_query, 3)[perm]

            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)

        x_spts = np.array(x_spts).astype(np.float32).reshape(
            self.batch_size, self.k_shot, 4)
        y_spts = np.array(y_spts).astype(np.float32).reshape(
            self.batch_size, self.k_shot, 3)
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(
            self.batch_size, self.k_query, 4)
        y_qrys = np.array(y_qrys).astype(np.float32).reshape(
            self.batch_size, self.k_query, 3)

        x_spts, y_spts, x_qrys, y_qrys = [torch.from_numpy(z).to(self.device)
                                          for z in [x_spts, y_spts, x_qrys, y_qrys]]

        if not return_task_id:
            return x_spts, y_spts, x_qrys, y_qrys

        return x_spts, y_spts, x_qrys, y_qrys, selected_classes


if __name__ == '__main__':

    dt = pkl.load(open('./data/si_quad_50.pkl', 'rb'))['test']
    test_classes = np.unique(dt['class'])
    for c in test_classes:
        dataloader = data.DataLoader(SiliconColorTaskLevel(
            './data/si_quad_50.pkl', split='train', center_task=c), batch_size=20)
        for i in range(2):
            x, y, c_batch = next(iter(dataloader))
            print(x.size(), y.size(), c_batch.size())
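
    # Added sketch: the episodic sampler used by both training scripts. With the
    # arguments below (matching train_forward_model.py), next() returns support /
    # query tensors of shapes (4, 10, 4) and (4, 10, 3).
    sampler = SiliconColorRegressionTaskSplit('./data/si_quad_50.pkl', 4, 10, 10, device='cpu')
    x_spt, y_spt, x_qry, y_qry = sampler.next(mode='train')
    print(x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())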
--------------------------------------------------------------------------------
/logs/forward_model_MLP_evalFalse_log.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/logs/forward_model_MLP_evalFalse_log.pkl
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class ForwardNet(nn.Module):
    '''
    MLP mapping the 4 structural parameters to the 3 CIE color coordinates.
    '''

    def __init__(self):
        super(ForwardNet, self).__init__()
        self.linear1 = nn.Linear(4, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 128)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(128, 128)
        self.relu3 = nn.ReLU()
        self.out = nn.Linear(128, 3)

    def forward(self, x):
        h = self.relu1(self.linear1(x))
        h = self.relu2(self.linear2(h))
        h = self.relu3(self.linear3(h))
        o = self.out(h)
        return o


class InverseNet(nn.Module):
    '''
    MLP mapping the 3 CIE color coordinates to the 4 (min-max scaled) structural
    parameters; the sigmoid output transform keeps predictions in [0, 1].
    '''

    def __init__(self, out_transform=nn.Sigmoid()):
        super(InverseNet, self).__init__()
        self.linear1 = nn.Linear(3, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 128)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(128, 128)
        self.relu3 = nn.ReLU()
        self.out = nn.Linear(128, 4)
        self.out_transform = out_transform

    def forward(self, x):
        h = self.relu1(self.linear1(x))
        h = self.relu2(self.linear2(h))
        h = self.relu3(self.linear3(h))
        o = self.out(h)
        if self.out_transform:
            o = self.out_transform(o)
        return o


class TandemNet(nn.Module):

    def __init__(self, forward_model, inverse_model):
        super(TandemNet, self).__init__()
        self.forward_model = forward_model
        self.inverse_model = inverse_model

    def forward(self, y):
        '''
        Args:
            y: true CIE coordinates

        Returns:
            x_: predicted structural parameters
            y_: predicted CIE coordinates for the inversely-designed structure
        '''
        x_ = self.inverse_model(y)
        y_ = self.forward_model(x_)
        return x_, y_


if __name__ == '__main__':

    forward_model = ForwardNet()
    inverse_model = InverseNet()
    tandem_net = TandemNet(forward_model, inverse_model)

    x = torch.rand(128, 3)
    print(forward_model(inverse_model(x)))

    print(tandem_net(x))
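
    # Added sketch: how gradients behave during tandem training. train_inverse_model.py
    # keeps the forward model fixed by optimizing only inverse_model.parameters();
    # freezing it explicitly, as below, is an equivalent alternative.
    for p in forward_model.parameters():
        p.requires_grad_(False)

    y_target = torch.rand(16, 3)
    x_design, y_recon = tandem_net(y_target)
    loss = nn.functional.mse_loss(y_recon, y_target)
    loss.backward()  # gradients flow through the frozen forward model into the inverse model
    print(loss.item())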
--------------------------------------------------------------------------------
/train_forward_model.py:
--------------------------------------------------------------------------------
import pickle as pkl

import numpy as np
import torch
import tqdm
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from dataset import SiliconColorRegressionTaskSplit
from models import ForwardNet


def train_forward_model(forward_net, args):
    device = args.device
    datasets = SiliconColorRegressionTaskSplit(
        './data/si_quad_50.pkl', 4, 10, 10, device)
    tr_dataset, val_dataset, te_dataset = (datasets.dt['train'],
                                           datasets.dt['val'],
                                           datasets.dt['test'])
    tr_loader = DataLoader(tr_dataset, batch_size=256, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=128, num_workers=2)
    test_loader = DataLoader(te_dataset, batch_size=128, num_workers=2)

    optimizer = optim.Adam(forward_net.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=args.epochs // 5, gamma=0.2)
    criterion = nn.MSELoss()
    best_val_loss = float('inf')

    val_losses = []
    train_losses = []
    val_loss = val_forward(forward_net, criterion, val_loader, args)
    val_losses.append(val_loss)

    model_path = './trained_models/forward_model_{}_eval{}.pt'.format('MLP', args.eval)
    log_path = './logs/forward_model_{}_eval{}_log.pkl'.format('MLP', args.eval)

    for e in tqdm.tqdm(range(args.epochs)):
        tr_loss = train_forward_epoch(forward_net, optimizer, criterion, tr_loader, args)
        val_loss = val_forward(forward_net, criterion, val_loader, args)
        print('Epoch {}, Train loss {:.4f}, Val MSE {:.4f}'.format(e, tr_loss, val_loss))

        val_losses.append(val_loss)
        train_losses.append(tr_loss)
        # checkpoint on the best validation loss
        if val_loss < best_val_loss:
            torch.save(forward_net.state_dict(), model_path)
            best_val_loss = val_loss
            print('Serializing model...')

        scheduler.step()

    # evaluate the best checkpoint on the test split
    state_dict = torch.load(model_path)
    forward_net.load_state_dict(state_dict)
    test_loss = val_forward(forward_net, criterion, test_loader, args)
    print('test loss {}'.format(test_loss))

    forward_training_log = {'train_loss': train_losses,
                            'val_mse': val_losses, 'test_mse': test_loss}

    pkl.dump(forward_training_log, open(log_path, 'wb'))


def train_forward_epoch(net, optimizer, criterion, dataloader, args):
    net.train()
    device = args.device

    loss_tr = []
    for x, y, _ in dataloader:
        x, y = x.float().to(device), y.float().to(device)
        y_pred = net(x)
        loss = criterion(y_pred, y)

        # zero the gradients every batch, not once per epoch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_tr.append(loss.detach().cpu().item())

    return np.mean(loss_tr)


def val_forward(net, criterion, dataloader, args):
    net.eval()
    device = args.device

    loss_total = []
    with torch.no_grad():
        for x, y, _ in dataloader:
            x, y = x.float().to(device), y.float().to(device)
            bs = len(x)

            y_pred = net(x)
            loss = criterion(y_pred, y)
            # weight the batch loss by batch size for a correct dataset-level mean
            loss_total.append(loss.cpu().item() * bs)

    return np.sum(loss_total) / len(dataloader.dataset)


if __name__ == '__main__':

    '''
    Usage:
        python train_forward_model.py --eval --seed 42
        python train_forward_model.py --seed 10
    '''
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--eval', action='store_true',
                        help='train a separate model for evaluation')
    parser.add_argument('--seed', type=int)
    parser.add_argument('--device', type=int)
    parser.add_argument('--epochs', type=int, default=1000)

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    forward_net = ForwardNet().to(args.device)
    train_forward_model(forward_net, args)
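
    # Added sketch: reload the training log written above and summarize it;
    # the keys match those serialized by train_forward_model().
    log = pkl.load(open('./logs/forward_model_MLP_eval{}_log.pkl'.format(args.eval), 'rb'))
    print('best val MSE: {:.4f}'.format(min(log['val_mse'])))
    print('test MSE: {:.4f}'.format(log['test_mse']))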
--------------------------------------------------------------------------------
/train_inverse_model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from models import ForwardNet, InverseNet, TandemNet
from dataset import SiliconColorRegressionTaskSplit
from utils import load_model, serialize_model


def train_inverse_model(epochs, device):

    print('Start training inverse model...')

    datasets = SiliconColorRegressionTaskSplit(
        './data/si_quad_50.pkl', 4, 10, 10, device)
    tr_dataset, val_dataset = datasets.dt['train'], datasets.dt['val']
    tr_loader = DataLoader(tr_dataset, batch_size=256, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=128, num_workers=2)

    # fix the parameters of the forward model while training the inverse model:
    # only inverse_model.parameters() are handed to the optimizer below
    forward_model = ForwardNet()
    forward_model = load_model(forward_model, 'forward_model_MLP_evalFalse.pt').to(device)

    inverse_model = InverseNet().to(device)
    tandem_net = TandemNet(forward_model, inverse_model)
    optimizer = optim.Adam(inverse_model.parameters(), lr=5e-4, weight_decay=1e-4)
    criterion = nn.MSELoss()

    best_val_loss = float('inf')

    for e in range(epochs):
        tr_loss = train_inverse_epoch(tandem_net, optimizer, criterion, tr_loader, device)
        val_loss = val_inverse(tandem_net, criterion, val_loader, device)
        print('Epoch {}, Train loss {:.4f}, Val loss {:.4f}'.format(e, tr_loss, val_loss))

        if val_loss < best_val_loss:
            serialize_model('./trained_models', tandem_net.inverse_model, 'inverse_model.pt')
            best_val_loss = val_loss
            print('Serializing model...')


def train_inverse_epoch(net, optimizer, criterion, dataloader, device):
    net.train()
    loss_tr = []

    for x, y, _ in dataloader:
        x, y = x.float().to(device), y.float().to(device)
        # the tandem net takes y (the CIE coordinates) as its input
        x_pred, y_pred = net(y)
        loss = criterion(y_pred, y)

        # zero the gradients every batch, not once per epoch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_tr.append(loss.detach().cpu().item())

    return np.mean(loss_tr)


def val_inverse(net, criterion, dataloader, device):
    net.eval()

    loss_val = []
    with torch.no_grad():
        for x, y, _ in dataloader:
            x, y = x.float().to(device), y.float().to(device)
            x_pred, y_pred = net(y)
            loss = criterion(y_pred, y)
            loss_val.append(loss.cpu().item())

    return np.mean(loss_val)


if __name__ == '__main__':

    '''
    Usage:
        python train_inverse_model.py --seed 42 --epochs 100 --device 0
    '''
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int)
    parser.add_argument('--device', type=int)
    parser.add_argument('--epochs', type=int, default=1000)

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_inverse_model(args.epochs, args.device)
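
    # Added sketch: report the tandem reconstruction MSE on the held-out test
    # split with the best saved checkpoints, mirroring train_forward_model.py.
    datasets = SiliconColorRegressionTaskSplit('./data/si_quad_50.pkl', 4, 10, 10, args.device)
    test_loader = DataLoader(datasets.dt['test'], batch_size=128, num_workers=2)
    forward_model = load_model(ForwardNet(), 'forward_model_MLP_evalFalse.pt').to(args.device)
    inverse_model = load_model(InverseNet(), 'inverse_model.pt').to(args.device)
    tandem_net = TandemNet(forward_model, inverse_model)
    test_loss = val_inverse(tandem_net, nn.MSELoss(), test_loader, args.device)
    print('test loss {:.4f}'.format(test_loss))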
--------------------------------------------------------------------------------
/trained_models/forward_model_MLP_evalFalse.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/trained_models/forward_model_MLP_evalFalse.pt
--------------------------------------------------------------------------------
/trained_models/inverse_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammer-wang/tandem_neural_network/6a0da20dde4cb571057c38b3b29938449e0787ed/trained_models/inverse_model.pt
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np


def set_torch_seed(seed):
    """
    Set the PyTorch seed for the current experiment run.
    :param seed: the seed (int)
    :return: a random number generator to use
    """
    rng = np.random.RandomState(seed=seed)
    torch_seed = rng.randint(0, 999999)
    torch.manual_seed(seed=torch_seed)

    return rng


def load_model(net, name):
    state_dict = torch.load(os.path.join('./trained_models', name),
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(state_dict)
    del state_dict
    torch.cuda.empty_cache()
    return net


def serialize_model(log_path, net, name):
    print('serializing model to {}'.format(log_path))
    torch.save(net.state_dict(), os.path.join(log_path, name))
--------------------------------------------------------------------------------