├── 1-wx_inner.ipynb
├── 2-ne2wx.ipynb
├── 3-nmc2lfp.ipynb
├── README.md
├── common.py
├── data
│   ├── ne_data
│   │   └── test.txt
│   ├── nmc_data
│   │   └── test.txt
│   └── our_data
│       └── test.txt
├── fig1.png
├── model
│   ├── ne2wx
│   │   └── ne_pretrain_best.pt
│   ├── nmc2lfp
│   │   └── nmc_pretrain_best.pt
│   └── wx_inner
│       └── wx_inner_pretrain_end.pt
├── net.py
├── prepare_ne_data.py
├── prepare_nmc_data.py
├── requirements.txt
└── tool.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Health_status_prediction
2 | This is a PyTorch implementation of the paper: Real-time personalized health status prediction of lithium-ion batteries using deep transfer learning.
3 | Ye Yuan, Guijun Ma, Songpei Xu
4 | ![image](https://github.com/HAIRLAB/Health_status_prediction/blob/main/fig1.png)
5 | ## Environment Setup
6 | 1. **System**
7 |    - OS: Ubuntu 18.04
8 |    - GPU (one card):
9 |      - NVIDIA GeForce RTX 3090 (24 GB)
10 |      - CUDA: 11.1
11 |      - Driver: 470.57.02
12 | 2. **Python version**
13 |    - Python 3.8.8
14 | ## Requirements
15 | The model is implemented in Python 3, with dependencies specified in requirements.txt.
16 | ```
17 | # Install PyTorch; see the official website for details: https://pytorch.org/
18 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
19 | 
20 | # Install other dependencies
21 | pip install -r requirements.txt
22 | ```
23 | ## Data Preparation
24 | [Data Download](https://doi.org/10.17632/nsc7hnsg4s.2)
25 | Yuan, Ye; Ma, Guijun; Xu, Songpei (2022), “The Dataset for: Real-time personalized health status prediction of lithium-ion batteries using deep transfer learning”, Mendeley Data, V2, doi: 10.17632/nsc7hnsg4s.2
26 | ## Code Introduction
27 | - [tool.py](https://github.com/HAIRLAB/Health_status_prediction/blob/main/tool.py) : Early-stopping utility
28 | - [common.py](https://github.com/HAIRLAB/Health_status_prediction/blob/main/common.py) : Data preprocessing, model training, and validation (see the Quick Start sketch below)
29 | - [net.py](https://github.com/HAIRLAB/Health_status_prediction/blob/main/net.py) : Model architecture
30 | - [prepare_ne_data.py](https://github.com/HAIRLAB/Health_status_prediction/blob/main/prepare_ne_data.py) : Data preprocessing for Task B
31 | - [prepare_nmc_data.py](https://github.com/HAIRLAB/Health_status_prediction/blob/main/prepare_nmc_data.py) : Data preprocessing for Task C
32 | - [1-wx_inner.ipynb](https://github.com/HAIRLAB/Health_status_prediction/blob/main/1-wx_inner.ipynb) : The pipeline for Task A
33 | - [2-ne2wx.ipynb](https://github.com/HAIRLAB/Health_status_prediction/blob/main/2-ne2wx.ipynb) : The pipeline for Task B
34 | - [3-nmc2lfp.ipynb](https://github.com/HAIRLAB/Health_status_prediction/blob/main/3-nmc2lfp.ipynb) : The pipeline for Task C
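35 | ## Quick Start
36 | A minimal sketch of how these pieces fit together, assuming a preprocessed cell pickle in `./data/our_data`; the cell name, hyper-parameters, and normalization bounds below are placeholders, and the notebooks contain the values actually used:
37 | ```
38 | import torch
39 | from torch.utils.data import DataLoader, TensorDataset
40 | from common import get_xy, seed_torch, Trainer
41 | from net import CRNN
42 | 
43 | seed_torch()
44 | 
45 | # Window the charge curves: 100 past cycles, keep every 10th, 100 interpolated points each.
46 | fea, lbl = get_xy('cell_name', n_cyc=100, in_stride=10, fea_num=100,
47 |                   v_low=3.36, v_upp=3.60, q_low=0.61, q_upp=1.19,
48 |                   rul_factor=3000, cap_factor=1.19)
49 | 
50 | # CRNN expects [batch, channel, height, width] = [b, 4, fea_num, n_cyc // in_stride].
51 | x = torch.tensor(fea, dtype=torch.float32).permute(0, 3, 2, 1)
52 | y = torch.tensor(lbl, dtype=torch.float32)
53 | loader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)
54 | 
55 | model = CRNN(ni=100, nc=4, no=64, nh=64)  # no=64 must match the soh head in net.py
56 | trainer = Trainer(lr=8e-4, n_epochs=500, device='cuda', patience=50,
57 |                   lamda=1.0, alpha=[0.1] * 10, model_name='./model/demo')
58 | # the training loader doubles as the validation loader only to keep the sketch short
59 | model, train_loss, valid_loss, total_loss = trainer.train(loader, loader, model, load_model=True)
60 | ```
61 | ## Contact
62 | If you have any questions, please contact songpeix@hust.edu.cn
63 | 
64 | 

--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pickle
4 | import math
5 | import random
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from scipy import interpolate
9 | from datetime import datetime
10 | import pandas as pd
11 | from tool import EarlyStopping
12 | from sklearn.metrics import roc_auc_score,mean_squared_error
13 | 
14 | import torch
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | 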
18 | from torch import nn
19 | from torch.autograd import Variable
20 | from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset
21 | from torch.utils.data.sampler import RandomSampler
22 | 
23 | plt.rcParams['font.sans-serif'] = ['SimHei']
24 | plt.rcParams['axes.unicode_minus'] = False
25 | 
26 | import warnings
27 | warnings.filterwarnings('ignore')
28 | 
29 | 
30 | # save a dict to disk
31 | def save_obj(obj,name):
32 |     with open(name + '.pkl','wb') as f:
33 |         pickle.dump(obj,f)
34 | 
35 | # load a dict from disk
36 | def load_obj(name):
37 |     with open(name +'.pkl','rb') as f:
38 |         return pickle.load(f)
39 | 
40 | def seed_torch(seed=1029):
41 |     random.seed(seed)
42 |     os.environ['PYTHONHASHSEED'] = str(seed)
43 |     np.random.seed(seed)
44 |     torch.manual_seed(seed)
45 |     torch.cuda.manual_seed(seed)
46 |     torch.cuda.manual_seed_all(seed)
47 |     torch.backends.cudnn.benchmark = False
48 |     torch.backends.cudnn.deterministic = True
49 | 
50 | def interp(v, q, num):
51 |     # linearly resample q onto num evenly spaced samples of v
52 |     f = interpolate.interp1d(v,q,kind='linear')
53 |     v_new = np.linspace(v[0],v[-1],num)
54 |     q_new = f(v_new)
55 |     return q_new
56 | 
57 | def get_xy(name, n_cyc, in_stride, fea_num, v_low, v_upp, q_low, q_upp, rul_factor, cap_factor):
58 |     """
59 |     Args:
60 |         n_cyc (int): Number of previous cycles used as model input
61 |         in_stride (int): Sampling stride within the input cycle window
62 |         fea_num (int): Number of interpolation points per cycle
63 |         v_low (float): Voltage minimum for normalization
64 |         v_upp (float): Voltage maximum for normalization
65 |         q_low (float): Capacity minimum for normalization
66 |         q_upp (float): Capacity maximum for normalization
67 |         rul_factor (float): RUL normalization factor
68 |         cap_factor (float): Capacity normalization factor
69 |     """
70 |     A = load_obj(f'./data/our_data/{name}')[name]
71 |     A_rul = A['rul']
72 |     A_dq = A['dq']
73 |     A_df = A['data']
74 | 
75 |     all_idx = list(A_dq.keys())[9:]
76 |     all_fea, rul_lbl, cap_lbl = [], [], []
77 |     for cyc in all_idx:
78 |         tmp = A_df[cyc]
79 |         tmp = tmp.loc[tmp['Status'].apply(lambda x: 'discharge' not in x)]
80 | 
81 |         left = (tmp['Current (mA)']<5000).argmax() + 1
82 |         right = (tmp['Current (mA)']<1090).argmax() - 2
83 | 
84 |         tmp = tmp.iloc[left:right]
85 | 
86 |         tmp_v = tmp['Voltage (V)'].values
87 |         tmp_q = tmp['Capacity (mAh)'].values
88 |         tmp_t = tmp['Time (s)'].values
89 |         v_fea = interp(tmp_t, tmp_v, fea_num)
90 |         q_fea = interp(tmp_t, tmp_q, fea_num)
91 | 
92 |         tmp_fea = np.hstack((v_fea.reshape(-1,1), q_fea.reshape(-1,1)))
93 | 
94 |         all_fea.append(np.expand_dims(tmp_fea,axis=0))
95 |         rul_lbl.append(A_rul[cyc])
96 |         cap_lbl.append(A_dq[cyc])
97 |     all_fea = np.vstack(all_fea)
98 |     rul_lbl = np.array(rul_lbl)
99 |     cap_lbl = np.array(cap_lbl)
100 | 
101 |     all_fea_c = all_fea.copy()
102 |     all_fea_c[:,:,0] = (all_fea_c[:,:,0]-v_low)/(v_upp-v_low)
103 |     all_fea_c[:,:,1] = (all_fea_c[:,:,1]-q_low)/(q_upp-q_low)
104 |     dif_fea = all_fea_c - all_fea_c[0:1,:,:]
105 |     all_fea = np.concatenate((all_fea,dif_fea),axis=2)
106 | 
107 |     all_fea = np.lib.stride_tricks.sliding_window_view(all_fea,(n_cyc,fea_num,4))
108 |     cap_lbl = np.lib.stride_tricks.sliding_window_view(cap_lbl,(n_cyc,))
109 |     all_fea = all_fea.squeeze(axis=(1,2,))
110 |     rul_lbl = rul_lbl[n_cyc-1:]
111 |     all_fea = all_fea[:,(in_stride - 1)::in_stride,:,:]
112 |     cap_lbl = cap_lbl[:,(in_stride - 1)::in_stride,]
113 | 
114 |     all_fea_new = np.zeros(all_fea.shape)
115 |     all_fea_new[:,:,:,0] = (all_fea[:,:,:,0]-v_low)/(v_upp-v_low)
116 |     all_fea_new[:,:,:,1] = (all_fea[:,:,:,1]-q_low)/(q_upp-q_low)
117 |     all_fea_new[:,:,:,2] = all_fea[:,:,:,2]
118 |     all_fea_new[:,:,:,3] = all_fea[:,:,:,3]
119 |     print(f'{name} length is {all_fea_new.shape[0]}',
120 |           'v_max:', '%.4f'%all_fea_new[:,:,:,0].max(),
121 |           'q_max:', '%.4f'%all_fea_new[:,:,:,1].max(),
122 |           'dv_max:', '%.4f'%all_fea_new[:,:,:,2].max(),
123 |           'dq_max:', '%.4f'%all_fea_new[:,:,:,3].max())
124 |     rul_lbl = rul_lbl / rul_factor
125 |     cap_lbl = cap_lbl / cap_factor
126 | 
127 |     return all_fea_new,np.hstack((rul_lbl.reshape(-1,1),cap_lbl))
128 | 
129 | 
130 | class Trainer():
131 | 
132 |     def __init__(self, lr, n_epochs, device, patience, lamda, alpha, model_name):
133 |         """
134 |         Args:
135 |             lr (float): Learning rate
136 |             n_epochs (int): Number of training epochs
137 |             device: 'cuda' or 'cpu'
138 |             patience (int): Epochs to wait after the last validation-loss improvement.
139 |             lamda (float): Weight of the RUL loss
140 |             alpha (List[float]): Weights of the per-cycle capacity losses
141 |             model_name (str): The model save path
142 |         """
143 |         self.lr = lr
144 |         self.n_epochs = n_epochs
145 |         self.device = device
146 |         self.patience = patience
147 |         self.model_name = model_name
148 |         self.lamda = lamda
149 |         self.alpha = alpha
150 | 
151 |     def train(self, train_loader, valid_loader, model, load_model):
152 |         model = model.to(self.device)
153 |         device = self.device
154 |         optimizer = optim.Adam(model.parameters(), lr=self.lr)
155 |         model_name = self.model_name
156 |         lamda = self.lamda
157 |         alpha = self.alpha
158 | 
159 |         loss_fn = nn.MSELoss()
160 |         early_stopping = EarlyStopping(self.patience, verbose=True)
161 |         loss_fn.to(self.device)
162 | 
163 |         # Training
164 |         train_loss = []
165 |         valid_loss = []
166 |         total_loss = []
167 | 
168 |         for epoch in range(self.n_epochs):
169 |             model.train()
170 |             y_true, y_pred = [], []
171 |             losses = []
172 |             for step, (x,y) in enumerate(train_loader):
173 |                 optimizer.zero_grad()
174 | 
175 |                 x = x.to(device)
176 |                 y = y.to(device)
177 |                 y_, soh_ = model(x)
178 | 
179 |                 loss = lamda * loss_fn(y_.squeeze(), y[:,0])  # RUL term
180 | 
181 |                 for i in range(y.shape[1] - 1):
182 |                     loss += loss_fn(soh_[:,i], y[:,i+1]) * alpha[i]  # capacity (SOH) terms
183 | 
184 |                 loss.backward()
185 |                 optimizer.step()
186 |                 losses.append(loss.cpu().detach().numpy())
187 | 
188 |                 y_pred.append(y_)
189 |                 y_true.append(y[:,0])
190 | 
191 |             y_true = torch.cat(y_true, axis=0)
192 |             y_pred = torch.cat(y_pred, axis=0)
193 | 
194 |             epoch_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
195 |             train_loss.append(epoch_loss)
196 | 
197 |             losses = np.mean(losses)
198 |             total_loss.append(losses)
199 | 
200 |             # validate
201 |             model.eval()
202 |             y_true, y_pred = [], []
203 |             with torch.no_grad():
204 |                 for step, (x,y) in enumerate(valid_loader):
205 |                     x = x.to(device)
206 |                     y = y.to(device)
207 |                     y_, soh_ = model(x)
208 | 
209 |                     y_pred.append(y_)
210 |                     y_true.append(y[:,0])
211 | 
212 |             y_true = torch.cat(y_true, axis=0)
213 |             y_pred = torch.cat(y_pred, axis=0)
214 |             epoch_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
215 |             valid_loss.append(epoch_loss)
216 | 
217 |             if self.n_epochs > 100:
218 |                 if (epoch % 100 == 0 and epoch != 0):
219 |                     print('Epoch number : ', epoch)
220 |                     print(f'-- "train" loss {train_loss[-1]:.4}', f'-- "valid" loss {epoch_loss:.4}', f'-- "total" loss {losses:.4}')
221 |             else:
222 |                 print(f'-- "train" loss {train_loss[-1]:.4}', f'-- "valid" loss {epoch_loss:.4}', f'-- "total" loss {losses:.4}')
223 | 
224 |             early_stopping(epoch_loss, model, f'{model_name}_best.pt')
225 |             if early_stopping.early_stop:
226 |                 break
227 | 
228 |         if load_model:
229 |             model.load_state_dict(torch.load(f'{model_name}_best.pt'))
230 |         else:
231 |             torch.save(model.state_dict(), f'{model_name}_end.pt')
232 | 
233 |         return model, train_loss, valid_loss, total_loss
234 | 
235 |     def test(self, test_loader, model):
236 |         model = model.to(self.device)
237 |         device = self.device
238 | 
239 |         y_true, y_pred, soh_true, soh_pred = [], [], [], []
240 |         model.eval()
241 |         with torch.no_grad():
242 |             for step, (x, y) in enumerate(test_loader):
243 |                 x = x.to(device)
244 |                 y = y.to(device)
245 |                 y_, soh_ = model(x)
246 | 
247 |                 y_pred.append(y_)
248 |                 y_true.append(y[:,0])
249 |                 soh_pred.append(soh_)
250 |                 soh_true.append(y[:,1:])
251 | 
252 |         y_true = torch.cat(y_true, axis=0)
253 |         y_pred = torch.cat(y_pred, axis=0)
254 |         soh_true = torch.cat(soh_true, axis=0)
255 |         soh_pred = torch.cat(soh_pred, axis=0)
256 |         mse_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
257 |         return y_true, y_pred, mse_loss, soh_true, soh_pred
258 | 
259 | 
260 | class FineTrainer():
261 | 
262 |     def __init__(self, lr, n_epochs, device, patience, lamda, train_alpha, valid_alpha, model_name):
263 |         """
264 |         Args:
265 |             lr (float): Learning rate
266 |             n_epochs (int): Number of training epochs
267 |             device: 'cuda' or 'cpu'
268 |             patience (int): Epochs to wait after the last validation-loss improvement.
269 |             lamda (float): Weight of the RUL loss; set to 0 during fine-tuning.
270 |             train_alpha (List[float]): Weights of the capacity losses during training
271 |             valid_alpha (List[float]): Weights of the capacity losses during validation
272 |             model_name (str): The model save path
273 |         """
274 |         self.lr = lr
275 |         self.n_epochs = n_epochs
276 |         self.device = device
277 |         self.patience = patience
278 |         self.model_name = model_name
279 |         self.lamda = lamda
280 |         self.train_alpha = train_alpha
281 |         self.valid_alpha = valid_alpha
282 | 
283 |     def train(self, train_loader, valid_loader, model, load_model):
284 |         model = model.to(self.device)
285 |         device = self.device
286 |         optimizer = optim.Adam(model.parameters(), lr=self.lr)
287 |         model_name = self.model_name
288 |         lamda = self.lamda
289 |         train_alpha = self.train_alpha
290 |         valid_alpha = self.valid_alpha
291 | 
292 |         loss_fn = nn.MSELoss()
293 |         early_stopping = EarlyStopping(self.patience, verbose=True)
294 |         loss_fn.to(self.device)
295 | 
296 |         # Training
297 |         train_loss = []
298 |         valid_loss = []
299 |         total_loss = []
300 |         added_loss = []
301 | 
302 |         for epoch in range(self.n_epochs):
303 |             model.train()
304 |             y_true, y_pred = [], []
305 |             losses = []
306 |             for step, (x,y) in enumerate(train_loader):
307 |                 optimizer.zero_grad()
308 | 
309 |                 x = x.to(device)
310 |                 y = y.to(device)
311 |                 y_, soh_ = model(x)
312 |                 soh_ = soh_.view(y_.shape[0], -1)
313 | 
314 |                 loss = lamda * loss_fn(y_.squeeze(), y[:,0])
315 | 
316 |                 for i in range(y.shape[1] - 1):
317 |                     loss += loss_fn(soh_[:,i], y[:,i+1]) * train_alpha[i]
318 | 
319 |                 loss.backward()
320 |                 optimizer.step()
321 |                 losses.append(loss.cpu().detach().numpy())
322 | 
323 |                 y_pred.append(y_)
324 |                 y_true.append(y[:,0])
325 | 
326 |             y_true = torch.cat(y_true, axis=0)
327 |             y_pred = torch.cat(y_pred, axis=0)
328 | 
329 |             epoch_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
330 |             train_loss.append(epoch_loss)
331 | 
332 |             losses = np.mean(losses)
333 |             total_loss.append(losses)
334 | 
335 |             # validate
336 |             model.eval()
337 |             y_true, y_pred, all_true, all_pred = [], [], [], []
338 |             with torch.no_grad():
339 |                 for step, (x,y) in enumerate(valid_loader):
340 |                     x = x.to(device)
341 |                     y = y.to(device)
342 |                     y_, soh_ = model(x)
343 |                     soh_ = soh_.view(y_.shape[0], -1)
344 | 
345 |                     y_pred.append(y_)
346 |                     y_true.append(y[:,0])
347 |                     all_true.append(y[:,1:])
348 |                     all_pred.append(soh_)
349 | 
350 |             y_true = torch.cat(y_true, axis=0)
351 |             y_pred = torch.cat(y_pred, axis=0)
352 |             all_true = torch.cat(all_true, axis=0)
353 |             all_pred = torch.cat(all_pred, axis=0)
354 |             epoch_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
355 |             valid_loss.append(epoch_loss)
356 | 
357 |             temp = 0  # weighted capacity error, computed on the first validation window only
358 |             for i in range(all_true.shape[1]):
359 |                 temp += mean_squared_error(all_true[0:1,i].cpu().detach().numpy(),
360 |                                            all_pred[0:1,i].cpu().detach().numpy()) * valid_alpha[i]
361 |             added_loss.append(temp)
362 | 
363 |             if self.n_epochs > 10:
364 |                 if (epoch % 200 == 0 and epoch != 0):
365 |                     print('Epoch number : ', epoch)
366 |                     print(f'-- "train" loss {train_loss[-1]:.4}', f'-- "valid" loss {epoch_loss:.4}',
367 |                           f'-- "total" loss {losses:.4}', f'-- "added" loss {temp:.4}')
368 |             else:
369 |                 print(f'-- "train" loss {train_loss[-1]:.4}', f'-- "valid" loss {epoch_loss:.4}',
370 |                       f'-- "total" loss {losses:.4}', f'-- "added" loss {temp:.4}')
371 | 
372 |             early_stopping(temp, model, f'{model_name}_fine_best.pt')  # early stopping tracks the added capacity loss
373 |             if early_stopping.early_stop:
374 |                 break
375 | 
376 |         if load_model:
377 |             model.load_state_dict(torch.load(f'{model_name}_fine_best.pt'))
378 |         else:
379 |             torch.save(model.state_dict(), f'{model_name}_fine_end.pt')
380 | 
381 |         return model, train_loss, valid_loss, total_loss, added_loss
382 | 
383 |     def test(self, test_loader, model):
384 |         model = model.to(self.device)
385 |         device = self.device
386 | 
387 |         y_true, y_pred, soh_true, soh_pred = [], [], [], []
388 |         model.eval()
389 |         with torch.no_grad():
390 |             for step, (x, y) in enumerate(test_loader):
391 |                 x = x.to(device)
392 |                 y = y.to(device)
393 |                 y_, soh_ = model(x)
394 |                 soh_ = soh_.view(y_.shape[0], -1)
395 | 
396 |                 y_pred.append(y_)
397 |                 y_true.append(y[:,0])
398 |                 soh_pred.append(soh_)
399 |                 soh_true.append(y[:,1:])
400 | 
401 |         y_true = torch.cat(y_true, axis=0)
402 |         y_pred = torch.cat(y_pred, axis=0)
403 |         soh_true = torch.cat(soh_true, axis=0)
404 |         soh_pred = torch.cat(soh_pred, axis=0)
405 |         mse_loss = mean_squared_error(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
406 |         return y_true, y_pred, mse_loss, soh_true, soh_pred
--------------------------------------------------------------------------------
/data/ne_data/test.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/data/nmc_data/test.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/data/our_data/test.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HAIRLAB/Health_status_prediction/8e9073edbc54c2620b5356e601a5d48abf70573d/fig1.png
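The checkpoints below are state dicts saved by the notebooks. A hypothetical sketch of restoring one and fine-tuning it on target-domain data; the CRNN arguments and hyper-parameters are assumptions that must match whatever was used when the checkpoint was saved, and `target_loader` is a placeholder:

```
import torch
from net import CRNN
from common import FineTrainer

model = CRNN(ni=100, nc=4, no=64, nh=64)  # assumed architecture; must match the saved model
model.load_state_dict(torch.load('./model/ne2wx/ne_pretrain_best.pt', map_location='cpu'))

# Per the FineTrainer docstring, lamda=0 drops the RUL term during fine-tuning.
tuner = FineTrainer(lr=1e-4, n_epochs=200, device='cpu', patience=20,
                    lamda=0, train_alpha=[1.0] * 10, valid_alpha=[1.0] * 10,
                    model_name='./model/ne2wx/demo')
# model, train_loss, valid_loss, total_loss, added_loss = \
#     tuner.train(target_loader, target_loader, model, load_model=True)
```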
--------------------------------------------------------------------------------
/model/ne2wx/ne_pretrain_best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HAIRLAB/Health_status_prediction/8e9073edbc54c2620b5356e601a5d48abf70573d/model/ne2wx/ne_pretrain_best.pt
--------------------------------------------------------------------------------
/model/nmc2lfp/nmc_pretrain_best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HAIRLAB/Health_status_prediction/8e9073edbc54c2620b5356e601a5d48abf70573d/model/nmc2lfp/nmc_pretrain_best.pt
--------------------------------------------------------------------------------
/model/wx_inner/wx_inner_pretrain_end.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HAIRLAB/Health_status_prediction/8e9073edbc54c2620b5356e601a5d48abf70573d/model/wx_inner/wx_inner_pretrain_end.pt
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.optim as optim
4 | 
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset
8 | from torch.utils.data.sampler import RandomSampler
9 | 
10 | class BidirectionalLSTM(nn.Module):
11 |     def __init__(self, nIn, nHidden, nOut, dropout):
12 |         super(BidirectionalLSTM, self).__init__()
13 |         """
14 |         Args:
15 |             nIn (int): Number of input units
16 |             nHidden (int): Number of hidden units
17 |             nOut (int): Number of output units
18 |         """
19 |         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=False, batch_first=True)  # unidirectional despite the class name
20 |         self.embedding = nn.Linear(nHidden, nOut)
21 |         if dropout:
22 |             self.dropout = nn.Dropout(p=0.5)
23 |         else:
24 |             self.dropout = dropout  # dropout=False disables it
25 | 
26 |     def forward(self, input):
27 |         recurrent, _ = self.rnn(input)
28 |         b, T, h = recurrent.size()
29 |         t_rec = recurrent.contiguous().view(b * T, h)
30 | 
31 |         if self.dropout:
32 |             t_rec = self.dropout(t_rec)
33 |         output = self.embedding(t_rec)
34 |         output = output.contiguous().view(b, T, -1)
35 | 
36 |         return output
37 | 
38 | class CRNN(nn.Module):
39 |     def __init__(self, ni, nc, no, nh, n_rnn=2, leakyRelu=False, sigmoid=False):
40 |         """
41 |         Args:
42 |             ni (int): Number of input units
43 |             nc (int): Number of input channels
44 |             no (int): Number of output units
45 |             nh (int): Number of hidden units
46 |         """
47 |         super(CRNN, self).__init__()
48 | 
49 |         ks = [3, 3, 3, 3, 3, 3, 3]
50 |         ps = [0, 0, 0, 0, 0, 0, 0]
51 |         ss = [2, 2, 2, 2, 2, 2, 1]
52 |         nm = [8, 16, 64, 64, 64, 64, 64]
53 | 
54 |         cnn = nn.Sequential()
55 | 
56 |         def convRelu(i, cnn, batchNormalization=False):
57 |             nIn = nc if i == 0 else nm[i - 1]
58 |             if i == 3: nIn = 64
59 |             nOut = nm[i]
60 |             cnn.add_module('conv{0}'.format(i),
61 |                            nn.Conv2d(nIn, nOut, (ks[i],1), (ss[i],1), (ps[i],0)))
62 |             if batchNormalization:
63 |                 cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
64 |             if leakyRelu:
65 |                 cnn.add_module('relu{0}'.format(i),
66 |                                nn.LeakyReLU(0.2, inplace=True))
67 |             else:
68 |                 cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
69 | 
70 |         convRelu(0,cnn)
71 |         cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d((2,1), (2,1)))
72 |         convRelu(1,cnn)
73 |         cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d((2,1), (2,1)))
74 |         convRelu(2,cnn)
75 |         cnn.add_module('pooling{0}'.format(2),
76 |                        nn.MaxPool2d((2, 1), (2, 1), (0, 0)))
77 |         self.sigmoid = sigmoid
78 |         self.cnn = cnn
79 |         self.rnn = nn.Sequential(
80 |             BidirectionalLSTM(64, nh, nh, False),
81 |             BidirectionalLSTM(nh, nh, no, False),)
82 |         self.rul = nn.Linear(10, 1)
83 |         self.soh = nn.Linear(64, 1)
84 | 
85 | 
86 |     def forward(self, input):
87 |         """
88 |         Input shape: [b, c, h, w]
89 |         Output shape:
90 |             rul [b, 1]
91 |             soh [b, 10]
92 |         """
93 |         conv = self.cnn(input)
94 |         b, c, h, w = conv.size()
95 |         conv = conv.squeeze(2)
96 |         conv = conv.permute(0, 2, 1)
97 |         output = self.rnn(conv)
98 |         soh = self.soh(output).squeeze()
99 | 
100 |         if not self.sigmoid:
101 |             rul = self.rul(soh)
102 |         else:
103 |             rul = torch.sigmoid(self.rul(soh))
104 | 
105 |         return rul, soh
106 | 
--------------------------------------------------------------------------------
/prepare_ne_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import h5py
4 | import matplotlib.pyplot as plt
5 | from scipy import interpolate
6 | from copy import deepcopy
7 | from scipy import stats
8 | from scipy.optimize import leastsq
9 | from scipy.stats import pearsonr
10 | import pickle
11 | 
12 | def save_obj(obj,name):
13 |     with open(name + '.pkl','wb') as f:
14 |         pickle.dump(obj,f)
15 | 
16 | def load_obj(name):
17 |     with open(name +'.pkl','rb') as f:
18 |         return pickle.load(f)
19 | 
20 | def change(arr,t,num):
21 |     x_new = np.linspace(t[0],t[-1],num)
22 |     f_linear = interpolate.interp1d(t,arr)
23 |     y_new = f_linear(x_new)
24 |     return y_new
25 | 
26 | path1 = './data/ne_data/2017-05-12_batchdata_updated_struct_errorcorrect.mat'
27 | path2 = './data/ne_data/2017-06-30_batchdata_updated_struct_errorcorrect.mat'
28 | path3 = './data/ne_data/2018-04-12_batchdata_updated_struct_errorcorrect.mat'
29 | 
30 | temp1=h5py.File(path1,'r')
31 | temp2=h5py.File(path2,'r')
32 | temp3=h5py.File(path3,'r')
33 | 
34 | batch1=temp1['batch']
35 | batch2=temp2['batch']
36 | batch3=temp3['batch']
37 | 
38 | cycle_life=dict()
39 | 
40 | '''the first batch'''
41 | temp=temp1
42 | batch=batch1
43 | for bat_num in range(batch['cycles'].shape[0]):
44 |     a=int(list(temp[batch['cycle_life'][bat_num,0]])[0][0])
45 |     #cl=temp[batch['cycle_life'][bat_num,0]].value
46 |     cycle_life.update({'a'+str(bat_num):a})
47 | 
48 | '''the second batch'''
49 | temp=temp2
50 | batch=batch2
51 | for bat_num in range(batch['cycles'].shape[0]):
52 |     a=int(list(temp[batch['cycle_life'][bat_num,0]])[0][0])
53 |     cycle_life.update({'b'+str(bat_num):a})
54 | 
55 | '''the third batch'''
56 | temp=temp3
57 | batch=batch3
58 | for bat_num in range(batch['cycles'].shape[0]):
59 |     if bat_num!=23 and bat_num!=32:
60 |         a=int(list(temp[batch['cycle_life'][bat_num,0]])[0][0])
61 |         cycle_life.update({'c'+str(bat_num):a})
62 |         continue
63 |     cycle_life.update({'c23':2190})
64 |     cycle_life.update({'c32':2238})
65 | 
66 | # There are five cells from batch1 that carried into batch2; we'll remove the data from batch2
67 | # and put it with the correct cell from batch1
68 | cycle_life['a0']=cycle_life['a0']+cycle_life['b7']-1
69 | cycle_life['a1']=cycle_life['a1']+cycle_life['b8']-1
70 | cycle_life['a2']=cycle_life['a2']+cycle_life['b9']-1
71 | cycle_life['a3']=cycle_life['a3']+cycle_life['b15']-1
72 | cycle_life['a4']=cycle_life['a4']+cycle_life['b16']-1
73 | 
74 | # remove batteries that do not reach 80% capacity
75 | del cycle_life['a8']
76 | del cycle_life['a10']
77 | del cycle_life['a12']
78 | del cycle_life['a13']
79 | del cycle_life['a22']
80 | 
81 | # remove the batch2 copies of the cells that carried over from batch1
82 | del
cycle_life['b7']
83 | del cycle_life['b8']
84 | del cycle_life['b9']
85 | del cycle_life['b15']
86 | del cycle_life['b16']
87 | 
88 | # remove noisy channels from c
89 | del cycle_life['c37']
90 | del cycle_life['c2']
91 | del cycle_life['c23']
92 | del cycle_life['c32']
93 | del cycle_life['c38']
94 | del cycle_life['c39']
95 | 
96 | fea_num = 100
97 | n_cyc = 100
98 | in_stride = 10
99 | stride = 1
100 | 
101 | v_low = 3.36
102 | v_upp = 3.60
103 | q_low = 0.61
104 | q_upp = 1.19
105 | lbl_factor = 3000
106 | aux_factor = 1190
107 | 
108 | a0_4 = {}
109 | ay0_4 = {}
110 | list_a = [0,1,2,3,4]
111 | list_b = [7,8,9,15,16]
112 | for i,num in enumerate(list_a):
113 |     fea_list = []
114 |     label_list = []
115 |     fea_list.append(None)
116 |     label_list.append(None)
117 |     b_num = list_b[i]
118 |     bat_life = cycle_life['a'+str(num)]
119 |     cyc_num = int(list(temp1[batch1['cycle_life'][num,0]])[0][0])-1
120 |     for j in range(1,cyc_num):
121 |         I = list(temp1[temp1[batch1['cycles'][num,0]]['I'][j,:][0]])[0]
122 |         try:
123 |             left_id = 0
124 |             left = np.where(np.abs(I - 1)<0.001)[0][left_id]
125 |             while list(temp1[temp1[batch1['cycles'][num,0]]['Qc'][j,:][0]])[0][left] < 0.4:  # advance past 1C points before 0.4 Ah has been charged
126 |                 left_id += 1
127 |                 left = np.where(np.abs(I - 1)<0.001)[0][left_id]
128 |             right = np.where(np.abs(I - 1)<0.001)[0][-1]
129 |         except:continue
130 |         if right - left <=1:
131 |             continue
132 | 
133 |         t = list(temp1[temp1[batch1['cycles'][num,0]]['t'][j,:][0]])[0][left:right]
134 |         V = list(temp1[temp1[batch1['cycles'][num,0]]['V'][j,:][0]])[0][left:right]
135 |         Vc = change(V,t,fea_num)
136 |         Qc = list(temp1[temp1[batch1['cycles'][num,0]]['Qc'][j,:][0]])[0][left:right]
137 |         Qc = change(Qc,t,fea_num)
138 |         QD = list(temp1[batch1['summary'][num,0]]['QDischarge'][0,:])[j]
139 |         tmp_fea = np.hstack((Vc.reshape(-1,1), Qc.reshape(-1,1)))
140 | 
141 |         fea_list.append(tmp_fea)
142 |         label_list.append(QD)
143 |     new_num = int(list(temp2[batch2['cycle_life'][b_num,0]])[0][0])-1
144 |     for j in range(new_num):
145 | 
146 |         I = list(temp2[temp2[batch2['cycles'][b_num,0]]['I'][j,:][0]])[0]
147 |         try:
148 |             left_id = 0
149 |             left = np.where(np.abs(I - 1)<0.001)[0][left_id]
150 |             while list(temp2[temp2[batch2['cycles'][b_num,0]]['Qc'][j,:][0]])[0][left] < 0.4:  # same guard for the batch2 continuation
151 |                 left_id += 1
152 |                 left = np.where(np.abs(I - 1)<0.001)[0][left_id]
153 |             right = np.where(np.abs(I - 1)<0.001)[0][-1]
154 |         except:continue
155 |         if right - left <=1:
156 |             continue
157 | 
158 |         t = list(temp2[temp2[batch2['cycles'][b_num,0]]['t'][j,:][0]])[0][left:right]
159 |         V = list(temp2[temp2[batch2['cycles'][b_num,0]]['V'][j,:][0]])[0][left:right]
160 |         Vc = change(V,t,fea_num)
161 |         Qc = list(temp2[temp2[batch2['cycles'][b_num,0]]['Qc'][j,:][0]])[0][left:right]
162 |         Qc = change(Qc,t,fea_num)
163 |         QD = list(temp2[batch2['summary'][b_num,0]]['QDischarge'][0,:])[j]
164 |         tmp_fea = np.hstack((Vc.reshape(-1,1), Qc.reshape(-1,1)))
165 | 
166 |         fea_list.append(tmp_fea)
167 |         label_list.append(QD)
168 |     a0_4.update({num:fea_list})
169 |     ay0_4.update({num:label_list})
170 | 
171 | numBat1 = 0
172 | numBat2 = 0
173 | numBat3 = 0
174 | for key in cycle_life.keys():
175 |     if 'a' in key:
176 |         numBat1 += 1
177 |     elif 'b' in key:
178 |         numBat2 += 1
179 |     elif 'c' in key:
180 |         numBat3 += 1
181 | numBat = numBat1 + numBat2 + numBat3
182 | 
183 | # Train and Test Split
184 | # If you are interested in using the same train/test split as the paper, use the indices specified below
185 | test_ind = np.hstack((np.arange(0,(numBat1+numBat2),2),83))
186 | train_ind = np.arange(1,(numBat1+numBat2-1),2)
187
| secondary_test_ind = np.arange(numBat-numBat3,numBat) 188 | 189 | # print(len(train_ind),len(test_ind), len(secondary_test_ind)) 190 | 191 | cycle_train = [] 192 | cycle_test = [] 193 | cycle_secondary = [] 194 | 195 | for i,key in enumerate(cycle_life.keys()): 196 | if i in train_ind:cycle_train.append(key) 197 | elif i in test_ind:cycle_test.append(key) 198 | elif i in secondary_test_ind:cycle_secondary.append(key) 199 | 200 | print(len(cycle_train), len(cycle_test), len(cycle_secondary)) 201 | 202 | def get_xy(cyc_num): 203 | fea = dict() 204 | label = dict() 205 | for i in cyc_num: 206 | key = i 207 | bat_life = cycle_life[key] 208 | fea_i = [] 209 | label_i = [] 210 | aux_lbl = [] 211 | for j in range(11,bat_life): 212 | if key[0]== 'a': 213 | temp=temp1 214 | batch=batch1 215 | elif key[0] == 'b': 216 | temp=temp2 217 | batch=batch2 218 | else: 219 | temp=temp3 220 | batch=batch3 221 | 222 | bat_num=int(key[1:]) 223 | if key[0] == 'a' and bat_num in [0,1,2,3,4]: 224 | try: 225 | tmp_fea = a0_4[bat_num][j-1] 226 | QD = ay0_4[bat_num][j-1] 227 | except:continue 228 | else: 229 | I = list(temp[temp[batch['cycles'][bat_num,0]]['I'][j-1,:][0]])[0] 230 | try: 231 | left_id = 0 232 | left = np.where(np.abs(I - 1)<0.001)[0][left_id] 233 | while list(temp[temp[batch['cycles'][bat_num,0]]['Qc'][j-1,:][0]])[0][left] < 0.4: 234 | left_id += 1 235 | left = np.where(np.abs(I - 1)<0.001)[0][left_id] 236 | right = np.where(np.abs(I - 1)<0.001)[0][-1] 237 | except:continue 238 | if right - left <=1: 239 | continue 240 | 241 | t = list(temp[temp[batch['cycles'][bat_num,0]]['t'][j-1,:][0]])[0][left:right] 242 | V = list(temp[temp[batch['cycles'][bat_num,0]]['V'][j-1,:][0]])[0][left:right] 243 | Qc = list(temp[temp[batch['cycles'][bat_num,0]]['Qc'][j-1,:][0]])[0][left:right] 244 | Vc = change(V,t,fea_num) 245 | Qc = change(Qc,t,fea_num) 246 | QD = list(temp[batch['summary'][bat_num,0]]['QDischarge'][0,:])[j-1] 247 | tmp_fea = np.hstack((Vc.reshape(-1,1), Qc.reshape(-1,1))) 248 | fea_i.append(np.expand_dims(tmp_fea,axis=0)) 249 | label_i.append(bat_life-j) 250 | aux_lbl.append(QD) 251 | 252 | all_fea = np.vstack(fea_i) 253 | all_lbl = np.array(label_i).reshape(-1,1) 254 | aux_lbl = np.array(aux_lbl) 255 | 256 | all_fea_c = all_fea.copy() 257 | all_fea_c[:,:,0] = (all_fea_c[:,:,0]-v_low)/(v_upp-v_low) 258 | all_fea_c[:,:,1] = (all_fea_c[:,:,1]-q_low)/(q_upp-q_low) 259 | dif_fea = all_fea_c - all_fea_c[0:1,:,:] 260 | all_fea = np.concatenate((all_fea,dif_fea),axis=2) 261 | 262 | all_fea = np.lib.stride_tricks.sliding_window_view(all_fea,(n_cyc,fea_num,4)) 263 | aux_lbl = np.lib.stride_tricks.sliding_window_view(aux_lbl,(n_cyc,)) 264 | all_fea = all_fea.squeeze(axis=(1,2,)) 265 | all_lbl = all_lbl[n_cyc-1:] 266 | all_fea = all_fea[::stride] 267 | all_fea = all_fea[:,::in_stride,:,:] 268 | all_lbl = all_lbl[::stride] 269 | aux_lbl = aux_lbl[::stride] 270 | aux_lbl = aux_lbl[:,::in_stride,] 271 | 272 | all_fea_new = np.zeros(all_fea.shape) 273 | all_fea_new[:,:,:,0] = (all_fea[:,:,:,0]-v_low)/(v_upp-v_low) 274 | all_fea_new[:,:,:,1] = (all_fea[:,:,:,1]-q_low)/(q_upp-q_low) 275 | all_fea_new[:,:,:,2] = all_fea[:,:,:,2] 276 | all_fea_new[:,:,:,3] = all_fea[:,:,:,3] 277 | print(f'{key} length is {all_fea_new.shape[0]}', 278 | 'v_max:', '%.4f'%all_fea_new[:,:,:,0].max(), 279 | 'q_max:', '%.4f'%all_fea_new[:,:,:,1].max(), 280 | 'dv_max:', '%.4f'%all_fea_new[:,:,:,2].max(), 281 | 'dq_max:', '%.4f'%all_fea_new[:,:,:,3].max()) 282 | all_lbl = all_lbl / lbl_factor 283 | aux_lbl = aux_lbl / aux_factor 284 | 285 | 
fea.update({key:all_fea_new}) 286 | label.update({key:np.hstack((all_lbl.reshape(-1,1),aux_lbl))}) 287 | return fea, label 288 | 289 | fea, label = get_xy(cycle_train) 290 | save_obj(fea,'./data/ne_data/fea_train') 291 | save_obj(label,'./data/ne_data/label_train') 292 | 293 | fea, label = get_xy(cycle_test) 294 | save_obj(fea,'./data/ne_data/fea_test') 295 | save_obj(label,'./data/ne_data/label_test') 296 | 297 | fea, label = get_xy(cycle_secondary) 298 | save_obj(fea,'./data/ne_data/fea_sec') 299 | save_obj(label,'./data/ne_data/label_sec') -------------------------------------------------------------------------------- /prepare_nmc_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import scipy.io as sio 4 | import matplotlib.pyplot as plt 5 | import time 6 | import os 7 | import pickle 8 | import re 9 | import xlrd 10 | import tqdm 11 | 12 | from scipy import interpolate 13 | from copy import deepcopy 14 | from scipy import stats 15 | from scipy.optimize import leastsq 16 | from scipy.stats import pearsonr 17 | 18 | from common import * 19 | 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | data_path = './data/nmc_data/original/' 24 | files = os.listdir(data_path) 25 | files = [file for file in files if 'SNL' in file] 26 | bat_prefix = list(set([i[:-15] for i in files])) 27 | 28 | bat_prefix = tqdm.tqdm(bat_prefix) 29 | for prefix in bat_prefix: 30 | cyc_v = {} 31 | cyc_rul = {} 32 | cyc_dq = {} 33 | 34 | cycle_df = prefix + '_cycle_data.csv' 35 | time_df = prefix + '_timeseries.csv' 36 | cycle_df = pd.read_csv(data_path + cycle_df) 37 | time_df = pd.read_csv(data_path + time_df) 38 | 39 | tmp = cycle_df[['Cycle_Index','Discharge_Capacity (Ah)']] 40 | init_cap = tmp['Discharge_Capacity (Ah)'].iloc[0] 41 | end_cap = init_cap * 0.8 42 | tmp = tmp[tmp['Discharge_Capacity (Ah)'] < end_cap] 43 | tmp = tmp.Cycle_Index.values 44 | for i in range(len(tmp) - 1, 0, -1): 45 | if tmp[i] - tmp[i - 1] != 1: 46 | break 47 | life_cyc = tmp[i] 48 | 49 | cyc_list = [] 50 | for i in range(len(cycle_df)): 51 | if 0 < i < len(cycle_df) - 1: 52 | if abs(cycle_df.iloc[i]['Discharge_Capacity (Ah)'] - cycle_df.iloc[i - 1]['Discharge_Capacity (Ah)']) >= 0.05: 53 | continue 54 | if abs(cycle_df.iloc[i]['Discharge_Capacity (Ah)'] - cycle_df.iloc[i + 1]['Discharge_Capacity (Ah)']) >= 0.05: 55 | continue 56 | tmp = cycle_df.iloc[i] 57 | if 1 <= tmp.Cycle_Index < life_cyc and end_cap <= tmp['Discharge_Capacity (Ah)']<= init_cap: 58 | cyc = tmp.Cycle_Index 59 | cyc_list.append(cyc) 60 | cyc_rul.update({cyc:life_cyc-cyc}) 61 | cyc_dq.update({cyc:tmp['Discharge_Capacity (Ah)']}) 62 | for cyc in cyc_list: 63 | tmp = time_df[time_df.Cycle_Index == cyc] 64 | tmp = tmp.reset_index(drop=True) 65 | tmp['Test_Time (s)'] = tmp['Test_Time (s)'] - tmp['Test_Time (s)'].iloc[0] 66 | cyc_v.update({cyc: tmp}) 67 | 68 | bats_dic = {} 69 | bats_dic.update({prefix:{'rul':cyc_rul, 70 | 'dq':cyc_dq, 71 | 'data':cyc_v}}) 72 | save_obj(bats_dic,'./data/nmc_data/'+prefix) 73 | 74 | pkl_list = os.listdir('./data/nmc_data/') 75 | pkl_list = [i for i in pkl_list if 'SNL' in i] 76 | 77 | train_name = [] 78 | for name in pkl_list: 79 | train_name.append(name[:-4]) 80 | 81 | def get_xy(name): 82 | A = load_obj(f'./data/nmc_data/{name}')[name] 83 | A_rul = A['rul'] 84 | A_dq = A['dq'] 85 | A_df = A['data'] 86 | 87 | all_idx = list(A_dq.keys())[9:] 88 | all_fea, all_lbl, aux_lbl = [], [], [] 89 | for cyc in all_idx: 90 | tmp = A_df[cyc] 91 
| 92 | init_cap = tmp['Charge_Capacity (Ah)'].iloc[-1] * 0.8 93 | left = (tmp['Charge_Capacity (Ah)'] > init_cap).argmax() - 20 94 | 95 | current = tmp['Current (A)'].values 96 | for i in range(len(current)): 97 | if current[i] > 0: 98 | break 99 | i += 1 100 | pos = np.where(current < current[i])[0] 101 | for j in pos: 102 | if j > i: 103 | break 104 | right = j + 20 105 | 106 | if left >= right - 1: 107 | continue 108 | 109 | tmp = tmp.iloc[left:right] 110 | 111 | tmp_v = tmp['Voltage (V)'].values 112 | tmp_q = tmp['Charge_Capacity (Ah)'].values 113 | tmp_t = tmp['Test_Time (s)'].values 114 | v_fea = interp(tmp_t, tmp_v, fea_num) 115 | q_fea = interp(tmp_t, tmp_q, fea_num) 116 | 117 | tmp_fea = np.hstack((v_fea.reshape(-1,1), q_fea.reshape(-1,1))) 118 | 119 | all_fea.append(np.expand_dims(tmp_fea,axis=0)) 120 | all_lbl.append(A_rul[cyc]) 121 | aux_lbl.append(A_dq[cyc]) 122 | # print(len(all_fea)) 123 | all_fea = np.vstack(all_fea) 124 | all_lbl = np.array(all_lbl) 125 | aux_lbl = np.array(aux_lbl) 126 | 127 | all_fea_c = all_fea.copy() 128 | all_fea_c[:,:,0] = (all_fea_c[:,:,0]-v_low)/(v_upp-v_low) 129 | all_fea_c[:,:,1] = (all_fea_c[:,:,1]-q_low)/(q_upp-q_low) 130 | dif_fea = all_fea_c - all_fea_c[0:1,:,:] 131 | all_fea = np.concatenate((all_fea,dif_fea),axis=2) 132 | # print(all_fea.shape,all_lbl.shape) 133 | 134 | # print(all_fea.shape) 135 | all_fea = np.lib.stride_tricks.sliding_window_view(all_fea,(n_cyc,fea_num,4)) 136 | aux_lbl = np.lib.stride_tricks.sliding_window_view(aux_lbl,(n_cyc,)) 137 | # print(all_fea.shape) 138 | all_fea = all_fea.squeeze(axis=(1,2,)) 139 | # print(all_fea.shape) 140 | all_lbl = all_lbl[n_cyc-1:] 141 | all_fea = all_fea[::stride] 142 | all_fea = all_fea[:,::in_stride,:,:] 143 | all_lbl = all_lbl[::stride] 144 | aux_lbl = aux_lbl[::stride] 145 | aux_lbl = aux_lbl[:,::in_stride,] 146 | 147 | all_fea_new = np.zeros(all_fea.shape) 148 | all_fea_new[:,:,:,0] = (all_fea[:,:,:,0]-v_low)/(v_upp-v_low) 149 | all_fea_new[:,:,:,1] = (all_fea[:,:,:,1]-q_low)/(q_upp-q_low) 150 | all_fea_new[:,:,:,2] = all_fea[:,:,:,2] 151 | all_fea_new[:,:,:,3] = all_fea[:,:,:,3] 152 | print(f'{name} length is {all_fea_new.shape[0]}', 153 | 'v_max:', '%.4f'%all_fea_new[:,:,:,0].max(), 154 | 'q_max:', '%.4f'%all_fea_new[:,:,:,1].max(), 155 | 'dv_max:', '%.4f'%all_fea_new[:,:,:,2].max(), 156 | 'dq_max:', '%.4f'%all_fea_new[:,:,:,3].max()) 157 | all_lbl = all_lbl / lbl_factor 158 | aux_lbl = aux_lbl / aux_factor 159 | 160 | return all_fea_new,np.hstack((all_lbl.reshape(-1,1),aux_lbl)) 161 | 162 | n_cyc = 30 163 | in_stride = 3 164 | fea_num = 100 165 | 166 | v_low = 3 167 | v_upp = 4.3 168 | q_low = 1.2 169 | q_upp = 2.9 170 | lbl_factor = 2000 171 | aux_factor = 2.9 172 | 173 | stride = 1 174 | all_loader = dict() 175 | all_fea = [] 176 | all_lbl = [] 177 | print('----init_train----') 178 | for name in train_name: 179 | tmp_fea, tmp_lbl = get_xy(name) 180 | all_loader.update({name:{'fea':tmp_fea,'lbl':tmp_lbl}}) 181 | all_fea.append(tmp_fea) 182 | all_lbl.append(tmp_lbl) 183 | save_obj(all_loader,'./data/nmc_data/nmc_loader') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | numpy==1.20.1 3 | pandas==1.2.4 4 | scikit_learn==1.1.1 5 | scipy==1.6.2 6 | tqdm==4.64.0 7 | -------------------------------------------------------------------------------- /tool.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from copy import deepcopy 4 | import numpy as np 5 | import pandas as pd 6 | import torch.nn as nn 7 | import random 8 | from tqdm import tqdm 9 | from sklearn.model_selection import train_test_split 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | class EarlyStopping: 13 | """Early stops the training if validation loss doesn't improve after a given patience.""" 14 | def __init__(self, patience=7, verbose=False, delta=0): 15 | """ 16 | Args: 17 | patience (int): How long to wait after last time validation loss improved. 18 | Default: 7 19 | verbose (bool): If True, prints a message for each validation loss improvement. 20 | Default: False 21 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 22 | Default: 0 23 | """ 24 | self.patience = patience 25 | self.verbose = verbose 26 | self.counter = 0 27 | self.best_score = None 28 | self.early_stop = False 29 | self.val_loss_min = np.Inf 30 | self.delta = delta 31 | 32 | def __call__(self, val_loss, model, check_name='checkpoint.pt'): 33 | 34 | score = -val_loss 35 | 36 | if self.best_score is None: 37 | self.best_score = score 38 | self.save_checkpoint(val_loss, model, check_name) 39 | elif score < self.best_score + self.delta: 40 | self.counter += 1 41 | if self.counter >= self.patience: 42 | self.early_stop = True 43 | else: 44 | self.best_score = score 45 | self.save_checkpoint(val_loss, model, check_name) 46 | self.counter = 0 47 | 48 | def save_checkpoint(self, val_loss, model, check_name): 49 | '''Saves model when validation loss decrease.''' 50 | torch.save(model.state_dict(), check_name) 51 | self.val_loss_min = val_loss 52 | --------------------------------------------------------------------------------
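For reference, a small usage sketch of the EarlyStopping helper above; the model and the loss values are placeholders:

```
import torch.nn as nn
from tool import EarlyStopping

model = nn.Linear(4, 1)  # stand-in for a real model
stopper = EarlyStopping(patience=3, verbose=False)
for epoch, val_loss in enumerate([0.9, 0.8, 0.81, 0.82, 0.83, 0.84]):
    stopper(val_loss, model, 'demo_checkpoint.pt')  # saves a checkpoint whenever the loss improves
    if stopper.early_stop:
        print(f'stopping early at epoch {epoch}')
        break
```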