├── redd.png ├── uk-dale.png ├── config.py ├── data ├── uk_dale │ ├── house_3 │ │ └── labels.dat │ ├── house_4 │ │ └── labels.dat │ ├── house_2 │ │ └── labels.dat │ ├── house_5 │ │ └── labels.dat │ └── house_1 │ │ └── labels.dat └── redd_lf │ ├── house_2 │ └── labels.dat │ ├── house_6 │ └── labels.dat │ ├── house_1 │ └── labels.dat │ ├── house_4 │ └── labels.dat │ ├── house_3 │ └── labels.dat │ └── house_5 │ └── labels.dat ├── README.md ├── dataloader.py ├── train.py ├── utils.py ├── model.py ├── dataset.py └── trainer.py /redd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yueeeeeeee/BERT4NILM/HEAD/redd.png -------------------------------------------------------------------------------- /uk-dale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yueeeeeeee/BERT4NILM/HEAD/uk-dale.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | RAW_DATASET_ROOT_FOLDER = 'data' 2 | EXPERIMENT_ROOT_FOLDER = 'experiments' 3 | -------------------------------------------------------------------------------- /data/uk_dale/house_3/labels.dat: -------------------------------------------------------------------------------- 1 | 1 aggregate 2 | 2 kettle 3 | 3 electric_heater 4 | 4 laptop 5 | 5 projector 6 | -------------------------------------------------------------------------------- /data/uk_dale/house_4/labels.dat: -------------------------------------------------------------------------------- 1 | 1 aggregate 2 | 2 tv_dvd_digibox_lamp 3 | 3 kettle 4 | 4 gas_boiler 5 | 5 freezer 6 | 6 washing_machine 7 | -------------------------------------------------------------------------------- /data/redd_lf/house_2/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 kitchen_outlets 4 | 4 lighting 5 | 5 stove 6 | 6 microwave 7 | 7 washer_dryer 8 | 8 kitchen_outlets 9 | 9 refrigerator 10 | 10 dishwasher 11 | 11 disposal 12 | -------------------------------------------------------------------------------- /data/uk_dale/house_2/labels.dat: -------------------------------------------------------------------------------- 1 | 1 aggregate 2 | 2 laptop 3 | 3 monitor 4 | 4 speakers 5 | 5 server 6 | 6 router 7 | 7 server_hdd 8 | 8 kettle 9 | 9 rice_cooker 10 | 10 running_machine 11 | 11 laptop2 12 | 12 washing_machine 13 | 13 dishwasher 14 | 14 fridge 15 | 15 microwave 16 | 16 toaster 17 | 17 playstation 18 | 18 modem 19 | 19 cooker 20 | -------------------------------------------------------------------------------- /data/redd_lf/house_6/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 kitchen_outlets 4 | 4 washer_dryer 5 | 5 stove 6 | 6 electronics 7 | 7 bathroom_gfi 8 | 8 refrigerator 9 | 9 dishwasher 10 | 10 outlets_unknown 11 | 11 outlets_unknown 12 | 12 electric_heat 13 | 13 kitchen_outlets 14 | 14 lighting 15 | 15 air_conditioning 16 | 16 air_conditioning 17 | 17 air_conditioning 18 | -------------------------------------------------------------------------------- /data/redd_lf/house_1/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 oven 4 | 4 oven 5 | 5 refrigerator 6 | 6 dishwasher 7 | 7 kitchen_outlets 8 | 8 kitchen_outlets 9 | 9 
lighting 10 | 10 washer_dryer 11 | 11 microwave 12 | 12 bathroom_gfi 13 | 13 electric_heat 14 | 14 stove 15 | 15 kitchen_outlets 16 | 16 kitchen_outlets 17 | 17 lighting 18 | 18 lighting 19 | 19 washer_dryer 20 | 20 washer_dryer 21 | -------------------------------------------------------------------------------- /data/redd_lf/house_4/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 lighting 4 | 4 furance 5 | 5 kitchen_outlets 6 | 6 outlets_unknown 7 | 7 washer_dryer 8 | 8 stove 9 | 9 air_conditioning 10 | 10 air_conditioning 11 | 11 miscellaeneous 12 | 12 smoke_alarms 13 | 13 lighting 14 | 14 kitchen_outlets 15 | 15 dishwasher 16 | 16 bathroom_gfi 17 | 17 bathroom_gfi 18 | 18 lighting 19 | 19 lighting 20 | 20 air_conditioning 21 | -------------------------------------------------------------------------------- /data/redd_lf/house_3/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 outlets_unknown 4 | 4 outlets_unknown 5 | 5 lighting 6 | 6 electronics 7 | 7 refrigerator 8 | 8 disposal 9 | 9 dishwasher 10 | 10 furance 11 | 11 lighting 12 | 12 outlets_unknown 13 | 13 washer_dryer 14 | 14 washer_dryer 15 | 15 lighting 16 | 16 microwave 17 | 17 lighting 18 | 18 smoke_alarms 19 | 19 lighting 20 | 20 bathroom_gfi 21 | 21 kitchen_outlets 22 | 22 kitchen_outlets 23 | -------------------------------------------------------------------------------- /data/uk_dale/house_5/labels.dat: -------------------------------------------------------------------------------- 1 | 1 aggregate 2 | 2 stereo_speakers_bedroom 3 | 3 i7_desktop 4 | 4 hairdryer 5 | 5 primary_tv 6 | 6 24_inch_lcd_bedroom 7 | 7 treadmill 8 | 8 network_attached_storage 9 | 9 core2_server 10 | 10 24_inch_lcd 11 | 11 PS4 12 | 12 steam_iron 13 | 13 nespresso_pixie 14 | 14 atom_pc 15 | 15 toaster 16 | 16 home_theatre_amp 17 | 17 sky_hd_box 18 | 18 kettle 19 | 19 fridge 20 | 20 oven 21 | 21 electric_hob 22 | 22 dishwasher 23 | 23 microwave 24 | 24 washing_machine 25 | 25 vacuum_cleaner 26 | -------------------------------------------------------------------------------- /data/redd_lf/house_5/labels.dat: -------------------------------------------------------------------------------- 1 | 1 mains 2 | 2 mains 3 | 3 microwave 4 | 4 lighting 5 | 5 outlets_unknown 6 | 6 furance 7 | 7 outlets_unknown 8 | 8 washer_dryer 9 | 9 washer_dryer 10 | 10 subpanel 11 | 11 subpanel 12 | 12 electric_heat 13 | 13 electric_heat 14 | 14 lighting 15 | 15 outlets_unknown 16 | 16 bathroom_gfi 17 | 17 lighting 18 | 18 refrigerator 19 | 19 lighting 20 | 20 dishwasher 21 | 21 disposal 22 | 22 electronics 23 | 23 lighting 24 | 24 kitchen_outlets 25 | 25 kitchen_outlets 26 | 26 outdoor_outlets 27 | -------------------------------------------------------------------------------- /data/uk_dale/house_1/labels.dat: -------------------------------------------------------------------------------- 1 | 1 aggregate 2 | 2 boiler 3 | 3 solar_thermal_pump 4 | 4 laptop 5 | 5 washing_machine 6 | 6 dishwasher 7 | 7 tv 8 | 8 kitchen_lights 9 | 9 htpc 10 | 10 kettle 11 | 11 toaster 12 | 12 fridge 13 | 13 microwave 14 | 14 lcd_office 15 | 15 hifi_office 16 | 16 breadmaker 17 | 17 amp_livingroom 18 | 18 adsl_router 19 | 19 livingroom_s_lamp 20 | 20 soldering_iron 21 | 21 gigE_&_USBhub 22 | 22 hoover 23 | 23 kitchen_dt_lamp 24 | 24 bedroom_ds_lamp 25 | 25 lighting_circuit 26 | 26 livingroom_s_lamp2 27 | 27 iPad_charger 28 | 28 subwoofer_livingroom 29 
| 29 livingroom_lamp_tv 30 | 30 DAB_radio_livingroom 31 | 31 kitchen_lamp2 32 | 32 kitchen_phone&stereo 33 | 33 utilityrm_lamp 34 | 34 samsung_charger 35 | 35 bedroom_d_lamp 36 | 36 coffee_machine 37 | 37 kitchen_radio 38 | 38 bedroom_chargers 39 | 39 hair_dryer 40 | 40 straighteners 41 | 41 iron 42 | 42 gas_oven 43 | 43 data_logger_pc 44 | 44 childs_table_lamp 45 | 45 childs_ds_lamp 46 | 46 baby_monitor_tx 47 | 47 battery_charger 48 | 48 office_lamp1 49 | 49 office_lamp2 50 | 50 office_lamp3 51 | 51 office_pc 52 | 52 office_fan 53 | 53 LED_printer 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT4NILM 2 | 3 | PyTorch Implementation of BERT4NILM: A Bidirectional Transformer Model for Non-Intrusive Load Monitoring 4 | 5 | 6 | ## Data 7 | 8 | The CSV datasets can be downloaded here: [REDD](http://redd.csail.mit.edu/) and [UK-DALE](https://jack-kelly.com/data/) 9 | 10 | We took the liberty of modifying certain appliance names to 'dishwasher', 'fridge', 'microwave', 'washing_machine' and 'kettle' in the 'labels.dat' files; see the data folder 11 | 12 | 13 | ## Training 14 | 15 | This is the PyTorch implementation of BERT4NILM, a model based on bidirectional encoder representations from transformers (BERT) for energy disaggregation. In this repository we provide the BERT4NILM model as well as data functions for the low-frequency REDD and UK-DALE datasets. Run the following command to train an initial model; hyper-parameters (as well as target appliances) can be tuned in utils.py, and testing runs automatically after training ends: 16 | 17 | ```bash 18 | python train.py 19 | ``` 20 | 21 | The trained model state dict will be saved under 'experiments/<dataset_code>/<appliance_names>/best_acc_model.pth' (a sketch for reloading it is shown at the end of this README) 22 | 23 | 24 | ## Performance 25 | 26 | Our models are trained for 100 and 20 epochs respectively for appliances from the REDD and UK-DALE datasets; all other parameters can be found in 'train.py' and 'utils.py' 27 | 28 | ### REDD 29 | 30 | 31 | 32 | ### UK-DALE 33 | 34 | 35 | 36 | 37 | ## Citing 38 | Please cite the following paper if you use our methods in your research: 39 | ``` 40 | @inproceedings{yue2020bert4nilm, 41 | title={BERT4NILM: A Bidirectional Transformer Model for Non-Intrusive Load Monitoring}, 42 | author={Yue, Zhenrui and Witzig, Camilo Requena and Jorde, Daniel and Jacobsen, Hans-Arno}, 43 | booktitle={Proceedings of the 5th International Workshop on Non-Intrusive Load Monitoring}, 44 | pages={89--93}, 45 | year={2020} 46 | } 47 | ``` 48 | 49 | 50 | ## Acknowledgement 51 | 52 | During the implementation we base our code mostly on the [BERT-pytorch](https://github.com/codertimo/BERT-pytorch) by Junseong Kim; we are also inspired by the [BERT4Rec](https://github.com/jaywonchung/BERT4Rec-VAE-Pytorch) implementation by Jaewon Chung and [Transformers](https://github.com/huggingface/transformers) from Hugging Face. Many thanks to these authors for their great work! 
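A minimal sketch for reloading the saved checkpoint for evaluation, assuming `args` is populated with the same hyper-parameters that were used for training (e.g. via the argument parsing in train.py and `set_template` in utils.py):

```python
import os
import torch
from model import BERT4NILM

# The repository trains in double precision, so match it here.
torch.set_default_tensor_type(torch.DoubleTensor)

# `args` is assumed to carry the training configuration
# (dataset_code, appliance_names, window_size, drop_out, output_size, ...).
model = BERT4NILM(args)

# Checkpoint path follows train.py: experiments/<dataset_code>/<appliance_names>/
export_root = os.path.join('experiments', args.dataset_code,
                           '-'.join(args.appliance_names))
state_dict = torch.load(os.path.join(export_root, 'best_acc_model.pth'),
                        map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```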
53 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data_utils 5 | 6 | 7 | torch.set_default_tensor_type(torch.DoubleTensor) 8 | 9 | 10 | class NILMDataloader(): 11 | def __init__(self, args, dataset, bert=False): 12 | self.args = args 13 | self.mask_prob = args.mask_prob 14 | self.batch_size = args.batch_size 15 | 16 | if bert: 17 | self.train_dataset, self.val_dataset = dataset.get_bert_datasets(mask_prob=self.mask_prob) 18 | else: 19 | self.train_dataset, self.val_dataset = dataset.get_datasets() 20 | 21 | @classmethod 22 | def code(cls): 23 | return 'dataloader' 24 | 25 | def get_dataloaders(self): 26 | train_loader = self._get_loader(self.train_dataset) 27 | val_loader = self._get_loader(self.val_dataset) 28 | return train_loader, val_loader 29 | 30 | def _get_loader(self, dataset): 31 | dataloader = data_utils.DataLoader( 32 | dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True) 33 | return dataloader 34 | 35 | 36 | class NILMDataset(data_utils.Dataset): 37 | def __init__(self, x, y, status, window_size=480, stride=30): 38 | self.x = x 39 | self.y = y 40 | self.status = status 41 | self.window_size = window_size 42 | self.stride = stride 43 | 44 | def __len__(self): 45 | return int(np.ceil((len(self.x) - self.window_size) / self.stride) + 1) 46 | 47 | def __getitem__(self, index): 48 | start_index = index * self.stride 49 | end_index = np.min( 50 | (len(self.x), index * self.stride + self.window_size)) 51 | x = self.padding_seqs(self.x[start_index: end_index]) 52 | y = self.padding_seqs(self.y[start_index: end_index]) 53 | status = self.padding_seqs(self.status[start_index: end_index]) 54 | return torch.tensor(x), torch.tensor(y), torch.tensor(status) 55 | 56 | def padding_seqs(self, in_array): 57 | if len(in_array) == self.window_size: 58 | return in_array 59 | try: 60 | out_array = np.zeros((self.window_size, in_array.shape[1])) 61 | except: 62 | out_array = np.zeros(self.window_size) 63 | 64 | length = len(in_array) 65 | out_array[:length] = in_array 66 | return out_array 67 | 68 | 69 | class BERTDataset(data_utils.Dataset): 70 | def __init__(self, x, y, status, window_size=480, stride=30, mask_prob=0.2): 71 | self.x = x 72 | self.y = y 73 | self.status = status 74 | self.window_size = window_size 75 | self.stride = stride 76 | self.mask_prob = mask_prob 77 | self.columns = y.shape[1] 78 | 79 | def __len__(self): 80 | return int(np.ceil((len(self.x) - self.window_size) / self.stride) + 1) 81 | 82 | def __getitem__(self, index): 83 | start_index = index * self.stride 84 | end_index = np.min( 85 | (len(self.x), index * self.stride + self.window_size)) 86 | x = self.padding_seqs(self.x[start_index: end_index]) 87 | y = self.padding_seqs(self.y[start_index: end_index]) 88 | status = self.padding_seqs(self.status[start_index: end_index]) 89 | 90 | tokens = [] 91 | labels = [] 92 | on_offs = [] 93 | for i in range(len(x)): 94 | prob = random.random() 95 | if prob < self.mask_prob: 96 | prob = random.random() 97 | if prob < 0.8: 98 | tokens.append(-1) 99 | elif prob < 0.9: 100 | tokens.append(np.random.normal()) 101 | else: 102 | tokens.append(x[i]) 103 | 104 | labels.append(y[i]) 105 | on_offs.append(status[i]) 106 | else: 107 | tokens.append(x[i]) 108 | temp = np.array([-1] * self.columns) 109 | labels.append(temp) 110 | on_offs.append(temp) 111 | 112 | 
return torch.tensor(tokens), torch.tensor(labels), torch.tensor(on_offs) 113 | 114 | def padding_seqs(self, in_array): 115 | if len(in_array) == self.window_size: 116 | return in_array 117 | try: 118 | out_array = np.zeros((self.window_size, in_array.shape[1])) 119 | except: 120 | out_array = np.zeros(self.window_size) 121 | 122 | length = len(in_array) 123 | out_array[:length] = in_array 124 | return out_array 125 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from dataset import * 2 | from dataloader import * 3 | from trainer import * 4 | from config import * 5 | from utils import * 6 | from model import BERT4NILM 7 | 8 | import argparse 9 | import torch 10 | 11 | 12 | def train(args, export_root=None, resume=True): 13 | args.validation_size = 0.1 14 | if args.dataset_code == 'redd_lf': 15 | args.house_indicies = [2, 3, 4, 5, 6] 16 | dataset = REDD_LF_Dataset(args) 17 | elif args.dataset_code == 'uk_dale': 18 | args.house_indicies = [1, 3, 4, 5] 19 | dataset = UK_DALE_Dataset(args) 20 | 21 | x_mean, x_std = dataset.get_mean_std() 22 | stats = (x_mean, x_std) 23 | 24 | model = BERT4NILM(args) 25 | 26 | if export_root == None: 27 | folder_name = '-'.join(args.appliance_names) 28 | export_root = 'experiments/' + args.dataset_code + '/' + folder_name 29 | 30 | dataloader = NILMDataloader(args, dataset, bert=True) 31 | train_loader, val_loader = dataloader.get_dataloaders() 32 | 33 | trainer = Trainer(args, model, train_loader, 34 | val_loader, stats, export_root) 35 | if args.num_epochs > 0: 36 | if resume: 37 | try: 38 | model.load_state_dict(torch.load(os.path.join( 39 | export_root, 'best_acc_model.pth'), map_location='cpu')) 40 | print('Successfully loaded previous model, continue training...') 41 | except FileNotFoundError: 42 | print('Failed to load old model, continue training new model...') 43 | trainer.train() 44 | 45 | args.validation_size = 1. 
46 | if args.dataset_code == 'redd_lf': 47 | args.house_indicies = [1] 48 | dataset = REDD_LF_Dataset(args, stats) 49 | elif args.dataset_code == 'uk_dale': 50 | args.house_indicies = [2] 51 | dataset = UK_DALE_Dataset(args, stats) 52 | 53 | dataloader = NILMDataloader(args, dataset) 54 | _, test_loader = dataloader.get_dataloaders() 55 | rel_err, abs_err, acc, prec, recall, f1 = trainer.test(test_loader) 56 | print('Mean Accuracy:', acc) 57 | print('Mean F1-Score:', f1) 58 | print('Mean Relative Error:', rel_err) 59 | print('Mean Absolute Error:', abs_err) 60 | 61 | 62 | def fix_random_seed_as(random_seed): 63 | random.seed(random_seed) 64 | torch.manual_seed(random_seed) 65 | torch.cuda.manual_seed_all(random_seed) 66 | np.random.seed(random_seed) 67 | 68 | 69 | torch.set_default_tensor_type(torch.DoubleTensor) 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--seed', type=int, default=12345) 72 | parser.add_argument('--dataset_code', type=str, 73 | default='redd_lf', choices=['redd_lf', 'uk_dale']) 74 | parser.add_argument('--validation_size', type=float, default=0.2) 75 | parser.add_argument('--batch_size', type=int, default=128) 76 | parser.add_argument('--house_indicies', type=list, default=[1, 2, 3, 4, 5]) 77 | parser.add_argument('--appliance_names', type=list, 78 | default=['microwave', 'dishwasher']) 79 | parser.add_argument('--sampling', type=str, default='6s') 80 | parser.add_argument('--cutoff', type=dict, default=None) 81 | parser.add_argument('--threshold', type=dict, default=None) 82 | parser.add_argument('--min_on', type=dict, default=None) 83 | parser.add_argument('--min_off', type=dict, default=None) 84 | parser.add_argument('--window_size', type=int, default=480) 85 | parser.add_argument('--window_stride', type=int, default=120) 86 | parser.add_argument('--normalize', type=str, default='mean', 87 | choices=['mean', 'minmax']) 88 | parser.add_argument('--denom', type=int, default=2000) 89 | parser.add_argument('--model_size', type=str, default='gru', 90 | choices=['gru', 'lstm', 'dae']) 91 | parser.add_argument('--output_size', type=int, default=1) 92 | parser.add_argument('--drop_out', type=float, default=0.1) 93 | parser.add_argument('--mask_prob', type=float, default=0.25) 94 | parser.add_argument('--device', type=str, default='cpu', 95 | choices=['cpu', 'cuda']) 96 | parser.add_argument('--optimizer', type=str, 97 | default='adam', choices=['sgd', 'adam', 'adamw']) 98 | parser.add_argument('--lr', type=float, default=1e-4) 99 | parser.add_argument('--weight_decay', type=float, default=0.) 
100 | parser.add_argument('--momentum', type=float, default=None) 101 | parser.add_argument('--decay_step', type=int, default=100) 102 | parser.add_argument('--gamma', type=float, default=0.1) 103 | parser.add_argument('--num_epochs', type=int, default=100) 104 | parser.add_argument('--c0', type=dict, default=None) 105 | 106 | args = parser.parse_args() 107 | 108 | 109 | if __name__ == "__main__": 110 | fix_random_seed_as(args.seed) 111 | get_user_input(args) 112 | set_template(args) 113 | train(args) 114 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | 6 | def get_user_input(args): 7 | if torch.cuda.is_available(): 8 | args.device = 'cuda:' + input('Input GPU ID: ') 9 | else: 10 | args.device = 'cpu' 11 | 12 | dataset_code = {'r': 'redd_lf', 'u': 'uk_dale'} 13 | args.dataset_code = dataset_code[input( 14 | 'Input r for REDD, u for UK_DALE: ')] 15 | 16 | if args.dataset_code == 'redd_lf': 17 | app_dict = { 18 | 'r': ['refrigerator'], 19 | 'w': ['washer_dryer'], 20 | 'm': ['microwave'], 21 | 'd': ['dishwasher'], 22 | } 23 | args.appliance_names = app_dict[input( 24 | 'Input r, w, m or d for target appliance: ')] 25 | 26 | elif args.dataset_code == 'uk_dale': 27 | app_dict = { 28 | 'k': ['kettle'], 29 | 'f': ['fridge'], 30 | 'w': ['washing_machine'], 31 | 'm': ['microwave'], 32 | 'd': ['dishwasher'], 33 | } 34 | args.appliance_names = app_dict[input( 35 | 'Input k, f, w, m or d for target appliance: ')] 36 | 37 | args.num_epochs = int(input('Input training epochs: ')) 38 | 39 | 40 | def set_template(args): 41 | args.output_size = len(args.appliance_names) 42 | if args.dataset_code == 'redd_lf': 43 | args.window_stride = 120 44 | args.house_indicies = [1, 2, 3, 4, 5, 6] 45 | 46 | args.cutoff = { 47 | 'aggregate': 6000, 48 | 'refrigerator': 400, 49 | 'washer_dryer': 3500, 50 | 'microwave': 1800, 51 | 'dishwasher': 1200 52 | } 53 | 54 | args.threshold = { 55 | 'refrigerator': 50, 56 | 'washer_dryer': 20, 57 | 'microwave': 200, 58 | 'dishwasher': 10 59 | } 60 | 61 | args.min_on = { 62 | 'refrigerator': 10, 63 | 'washer_dryer': 300, 64 | 'microwave': 2, 65 | 'dishwasher': 300 66 | } 67 | 68 | args.min_off = { 69 | 'refrigerator': 2, 70 | 'washer_dryer': 26, 71 | 'microwave': 5, 72 | 'dishwasher': 300 73 | } 74 | 75 | args.c0 = { 76 | 'refrigerator': 1e-6, 77 | 'washer_dryer': 0.001, 78 | 'microwave': 1., 79 | 'dishwasher': 1. 80 | } 81 | 82 | elif args.dataset_code == 'uk_dale': 83 | args.window_stride = 240 84 | args.house_indicies = [1, 2, 3, 4, 5] 85 | 86 | args.cutoff = { 87 | 'aggregate': 6000, 88 | 'kettle': 3100, 89 | 'fridge': 300, 90 | 'washing_machine': 2500, 91 | 'microwave': 3000, 92 | 'dishwasher': 2500 93 | } 94 | 95 | args.threshold = { 96 | 'kettle': 2000, 97 | 'fridge': 50, 98 | 'washing_machine': 20, 99 | 'microwave': 200, 100 | 'dishwasher': 10 101 | } 102 | 103 | args.min_on = { 104 | 'kettle': 2, 105 | 'fridge': 10, 106 | 'washing_machine': 300, 107 | 'microwave': 2, 108 | 'dishwasher': 300 109 | } 110 | 111 | args.min_off = { 112 | 'kettle': 0, 113 | 'fridge': 2, 114 | 'washing_machine': 26, 115 | 'microwave': 5, 116 | 'dishwasher': 300 117 | } 118 | 119 | args.c0 = { 120 | 'kettle': 1., 121 | 'fridge': 1e-6, 122 | 'washing_machine': 0.01, 123 | 'microwave': 1., 124 | 'dishwasher': 1. 
125 | } 126 | 127 | args.optimizer = 'adam' 128 | args.lr = 1e-4 129 | args.enable_lr_schedule = False 130 | args.batch_size = 128 131 | 132 | 133 | def acc_precision_recall_f1_score(pred, status): 134 | assert pred.shape == status.shape 135 | 136 | pred = pred.reshape(-1, pred.shape[-1]) 137 | status = status.reshape(-1, status.shape[-1]) 138 | accs, precisions, recalls, f1_scores = [], [], [], [] 139 | 140 | for i in range(status.shape[-1]): 141 | tn, fp, fn, tp = confusion_matrix(status[:, i], pred[:, i], labels=[ 142 | 0, 1]).ravel() 143 | acc = (tn + tp) / (tn + fp + fn + tp) 144 | precision = tp / np.max((tp + fp, 1e-9)) 145 | recall = tp / np.max((tp + fn, 1e-9)) 146 | f1_score = 2 * (precision * recall) / \ 147 | np.max((precision + recall, 1e-9)) 148 | 149 | accs.append(acc) 150 | precisions.append(precision) 151 | recalls.append(recall) 152 | f1_scores.append(f1_score) 153 | 154 | return np.array(accs), np.array(precisions), np.array(recalls), np.array(f1_scores) 155 | 156 | 157 | def relative_absolute_error(pred, label): 158 | assert pred.shape == label.shape 159 | 160 | pred = pred.reshape(-1, pred.shape[-1]) 161 | label = label.reshape(-1, label.shape[-1]) 162 | temp = np.full(label.shape, 1e-9) 163 | relative, absolute, sum_err = [], [], [] 164 | 165 | for i in range(label.shape[-1]): 166 | relative_error = np.mean(np.nan_to_num(np.abs(label[:, i] - pred[:, i]) / np.max( 167 | (label[:, i], pred[:, i], temp[:, i]), axis=0))) 168 | absolute_error = np.mean(np.abs(label[:, i] - pred[:, i])) 169 | 170 | relative.append(relative_error) 171 | absolute.append(absolute_error) 172 | 173 | return np.array(relative), np.array(absolute) 174 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class GELU(nn.Module): 8 | def forward(self, x): 9 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 10 | 11 | 12 | class PositionalEmbedding(nn.Module): 13 | def __init__(self, max_len, d_model): 14 | super().__init__() 15 | self.pe = nn.Embedding(max_len, d_model) 16 | 17 | def forward(self, x): 18 | batch_size = x.size(0) 19 | return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1) 20 | 21 | 22 | class LayerNorm(nn.Module): 23 | def __init__(self, features, eps=1e-6): 24 | super(LayerNorm, self).__init__() 25 | self.weight = nn.Parameter(torch.ones(features)) 26 | self.bias = nn.Parameter(torch.zeros(features)) 27 | self.eps = eps 28 | 29 | def forward(self, x): 30 | mean = x.mean(-1, keepdim=True) 31 | std = x.std(-1, keepdim=True) 32 | return self.weight * (x - mean) / (std + self.eps) + self.bias 33 | 34 | 35 | class Attention(nn.Module): 36 | def forward(self, query, key, value, mask=None, dropout=None): 37 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) 38 | if mask is not None: 39 | scores = scores.masked_fill(mask == 0, -1e9) 40 | 41 | p_attn = F.softmax(scores, dim=-1) 42 | if dropout is not None: 43 | p_attn = dropout(p_attn) 44 | 45 | return torch.matmul(p_attn, value), p_attn 46 | 47 | 48 | class MultiHeadedAttention(nn.Module): 49 | def __init__(self, h, d_model, dropout=0.1): 50 | super().__init__() 51 | assert d_model % h == 0 52 | 53 | self.d_k = d_model // h 54 | self.h = h 55 | 56 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 57 | 
self.output_linear = nn.Linear(d_model, d_model) 58 | self.attention = Attention() 59 | 60 | self.dropout = nn.Dropout(p=dropout) 61 | 62 | def forward(self, query, key, value, mask=None): 63 | batch_size = query.size(0) 64 | 65 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 66 | for l, x in zip(self.linear_layers, (query, key, value))] 67 | 68 | x, attn = self.attention( 69 | query, key, value, mask=mask, dropout=self.dropout) 70 | 71 | x = x.transpose(1, 2).contiguous().view( 72 | batch_size, -1, self.h * self.d_k) 73 | 74 | return self.output_linear(x) 75 | 76 | 77 | class PositionwiseFeedForward(nn.Module): 78 | def __init__(self, d_model, d_ff): 79 | super(PositionwiseFeedForward, self).__init__() 80 | self.w_1 = nn.Linear(d_model, d_ff) 81 | self.w_2 = nn.Linear(d_ff, d_model) 82 | self.activation = GELU() 83 | 84 | def forward(self, x): 85 | return self.w_2(self.activation(self.w_1(x))) 86 | 87 | 88 | class SublayerConnection(nn.Module): 89 | def __init__(self, size, dropout): 90 | super(SublayerConnection, self).__init__() 91 | self.layer_norm = LayerNorm(size) 92 | self.dropout = nn.Dropout(dropout) 93 | 94 | def forward(self, x, sublayer): 95 | return self.layer_norm(x + self.dropout(sublayer(x))) 96 | 97 | 98 | class TransformerBlock(nn.Module): 99 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): 100 | super().__init__() 101 | self.attention = MultiHeadedAttention( 102 | h=attn_heads, d_model=hidden, dropout=dropout) 103 | self.feed_forward = PositionwiseFeedForward( 104 | d_model=hidden, d_ff=feed_forward_hidden) 105 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 106 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 107 | self.dropout = nn.Dropout(p=dropout) 108 | 109 | def forward(self, x, mask): 110 | x = self.input_sublayer( 111 | x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 112 | x = self.output_sublayer(x, self.feed_forward) 113 | return self.dropout(x) 114 | 115 | 116 | class BERT4NILM(nn.Module): 117 | def __init__(self, args): 118 | super().__init__() 119 | self.args = args 120 | 121 | self.original_len = args.window_size 122 | self.latent_len = int(self.original_len / 2) 123 | self.dropout_rate = args.drop_out 124 | 125 | self.hidden = 256 126 | self.heads = 2 127 | self.n_layers = 2 128 | self.output_size = args.output_size 129 | 130 | self.conv = nn.Conv1d(in_channels=1, out_channels=self.hidden, 131 | kernel_size=5, stride=1, padding=2, padding_mode='replicate') 132 | self.pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2) 133 | 134 | self.position = PositionalEmbedding( 135 | max_len=self.latent_len, d_model=self.hidden) 136 | self.layer_norm = LayerNorm(self.hidden) 137 | self.dropout = nn.Dropout(p=self.dropout_rate) 138 | 139 | self.transformer_blocks = nn.ModuleList([TransformerBlock( 140 | self.hidden, self.heads, self.hidden * 4, self.dropout_rate) for _ in range(self.n_layers)]) 141 | 142 | self.deconv = nn.ConvTranspose1d( 143 | in_channels=self.hidden, out_channels=self.hidden, kernel_size=4, stride=2, padding=1) 144 | self.linear1 = nn.Linear(self.hidden, 128) 145 | self.linear2 = nn.Linear(128, self.output_size) 146 | 147 | self.truncated_normal_init() 148 | 149 | def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04): 150 | params = list(self.named_parameters()) 151 | for n, p in params: 152 | if 'layer_norm' in n: 153 | continue 154 | else: 155 | with torch.no_grad(): 156 | l = (1. 
+ math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2. 157 | u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2. 158 | p.uniform_(2 * l - 1, 2 * u - 1) 159 | p.erfinv_() 160 | p.mul_(std * math.sqrt(2.)) 161 | p.add_(mean) 162 | 163 | def forward(self, sequence): 164 | x_token = self.pool(self.conv(sequence.unsqueeze(1))).permute(0, 2, 1) 165 | embedding = x_token + self.position(sequence) 166 | x = self.dropout(self.layer_norm(embedding)) 167 | 168 | mask = None 169 | for transformer in self.transformer_blocks: 170 | x = transformer.forward(x, mask) 171 | 172 | x = self.deconv(x.permute(0, 2, 1)).permute(0, 2, 1) 173 | x = torch.tanh(self.linear1(x)) 174 | x = self.linear2(x) 175 | return x 176 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from abc import * 2 | from config import * 3 | from dataloader import * 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from pathlib import Path 8 | from collections import defaultdict 9 | import torch.utils.data as data_utils 10 | 11 | 12 | class AbstractDataset(metaclass=ABCMeta): 13 | def __init__(self, args, stats=None): 14 | self.house_indicies = args.house_indicies 15 | self.appliance_names = args.appliance_names 16 | self.normalize = args.normalize 17 | self.sampling = args.sampling 18 | self.cutoff = [args.cutoff[i] 19 | for i in ['aggregate'] + self.appliance_names] 20 | 21 | self.threshold = [args.threshold[i] for i in self.appliance_names] 22 | self.min_on = [args.min_on[i] for i in self.appliance_names] 23 | self.min_off = [args.min_off[i] for i in self.appliance_names] 24 | 25 | self.val_size = args.validation_size 26 | self.window_size = args.window_size 27 | self.window_stride = args.window_stride 28 | 29 | self.x, self.y = self.load_data() 30 | self.status = self.compute_status(self.y) 31 | print('Appliance:', self.appliance_names) 32 | print('Sum of ons:', np.sum(self.status, axis=0)) 33 | print('Total length:', self.status.shape[0]) 34 | 35 | if stats is None: 36 | self.x_mean = np.mean(self.x, axis=0) 37 | self.x_std = np.std(self.x, axis=0) 38 | else: 39 | self.x_mean, self.x_std = stats 40 | 41 | self.x = (self.x - self.x_mean) / self.x_std 42 | 43 | @classmethod 44 | @abstractmethod 45 | def code(cls): 46 | pass 47 | 48 | @classmethod 49 | def raw_code(cls): 50 | return cls.code() 51 | 52 | @abstractmethod 53 | def load_data(self): 54 | pass 55 | 56 | def get_data(self): 57 | return self.x, self.y, self.status 58 | 59 | def get_original_data(self): 60 | x_org = self.x * self.x_std + self.x_mean 61 | return x_org, self.y, self.status 62 | 63 | def get_mean_std(self): 64 | return self.x_mean, self.x_std 65 | 66 | def compute_status(self, data): 67 | status = np.zeros(data.shape) 68 | if len(data.squeeze().shape) == 1: 69 | columns = 1 70 | else: 71 | columns = data.squeeze().shape[-1] 72 | 73 | if not self.threshold: 74 | self.threshold = [10 for i in range(columns)] 75 | if not self.min_on: 76 | self.min_on = [1 for i in range(columns)] 77 | if not self.min_off: 78 | self.min_off = [1 for i in range(columns)] 79 | 80 | for i in range(columns): 81 | initial_status = data[:, i] >= self.threshold[i] 82 | status_diff = np.diff(initial_status) 83 | events_idx = status_diff.nonzero() 84 | 85 | events_idx = np.array(events_idx).squeeze() 86 | events_idx += 1 87 | 88 | if initial_status[0]: 89 | events_idx = np.insert(events_idx, 0, 0) 90 | 91 | if initial_status[-1]: 92 | events_idx = 
np.insert( 93 | events_idx, events_idx.size, initial_status.size) 94 | 95 | events_idx = events_idx.reshape((-1, 2)) 96 | on_events = events_idx[:, 0].copy() 97 | off_events = events_idx[:, 1].copy() 98 | assert len(on_events) == len(off_events) 99 | 100 | if len(on_events) > 0: 101 | off_duration = on_events[1:] - off_events[:-1] 102 | off_duration = np.insert(off_duration, 0, 1000) 103 | on_events = on_events[off_duration > self.min_off[i]] 104 | off_events = off_events[np.roll( 105 | off_duration, -1) > self.min_off[i]] 106 | 107 | on_duration = off_events - on_events 108 | on_events = on_events[on_duration >= self.min_on[i]] 109 | off_events = off_events[on_duration >= self.min_on[i]] 110 | assert len(on_events) == len(off_events) 111 | 112 | temp_status = data[:, i].copy() 113 | temp_status[:] = 0 114 | for on, off in zip(on_events, off_events): 115 | temp_status[on: off] = 1 116 | status[:, i] = temp_status 117 | 118 | return status 119 | 120 | def get_status(self): 121 | return self.status 122 | 123 | def get_datasets(self): 124 | val_end = int(self.val_size * len(self.x)) 125 | val = NILMDataset(self.x[:val_end], self.y[:val_end], self.status[:val_end], 126 | self.window_size, self.window_size) 127 | train = NILMDataset(self.x[val_end:], self.y[val_end:], self.status[val_end:], 128 | self.window_size, self.window_stride) 129 | return train, val 130 | 131 | def get_bert_datasets(self, mask_prob=0.25): 132 | val_end = int(self.val_size * len(self.x)) 133 | val = NILMDataset(self.x[:val_end], self.y[:val_end], self.status[:val_end], 134 | self.window_size, self.window_size) 135 | train = BERTDataset(self.x[val_end:], self.y[val_end:], self.status[val_end:], 136 | self.window_size, self.window_stride, mask_prob=mask_prob) 137 | return train, val 138 | 139 | def _get_rawdata_root_path(self): 140 | return Path(RAW_DATASET_ROOT_FOLDER) 141 | 142 | def _get_folder_path(self): 143 | root = self._get_rawdata_root_path() 144 | return root.joinpath(self.raw_code()) 145 | 146 | 147 | class REDD_LF_Dataset(AbstractDataset): 148 | @classmethod 149 | def code(cls): 150 | return 'redd_lf' 151 | 152 | @classmethod 153 | def _if_data_exists(self): 154 | folder = Path(RAW_DATASET_ROOT_FOLDER).joinpath(self.code()) 155 | first_file = folder.joinpath('house_1', 'channel_1.dat') 156 | if first_file.is_file(): 157 | return True 158 | return False 159 | 160 | def load_data(self): 161 | for appliance in self.appliance_names: 162 | assert appliance in ['dishwasher', 163 | 'refrigerator', 'microwave', 'washer_dryer'] 164 | 165 | for house_id in self.house_indicies: 166 | assert house_id in [1, 2, 3, 4, 5, 6] 167 | 168 | if not self.cutoff: 169 | self.cutoff = [6000] * (len(self.appliance_names) + 1) 170 | 171 | if not self._if_data_exists(): 172 | print('Please download, unzip and move data into', 173 | self._get_folder_path()) 174 | raise FileNotFoundError 175 | 176 | else: 177 | directory = self._get_folder_path() 178 | 179 | for house_id in self.house_indicies: 180 | house_folder = directory.joinpath('house_' + str(house_id)) 181 | house_label = pd.read_csv(house_folder.joinpath( 182 | 'labels.dat'), sep=' ', header=None) 183 | 184 | main_1 = pd.read_csv(house_folder.joinpath( 185 | 'channel_1.dat'), sep=' ', header=None) 186 | main_2 = pd.read_csv(house_folder.joinpath( 187 | 'channel_2.dat'), sep=' ', header=None) 188 | house_data = pd.merge(main_1, main_2, how='inner', on=0) 189 | house_data.iloc[:, 1] = house_data.iloc[:, 190 | 1] + house_data.iloc[:, 2] 191 | house_data = house_data.iloc[:, 0: 2] 192 
| 193 | appliance_list = house_label.iloc[:, 1].values 194 | app_index_dict = defaultdict(list) 195 | 196 | for appliance in self.appliance_names: 197 | data_found = False 198 | for i in range(len(appliance_list)): 199 | if appliance_list[i] == appliance: 200 | app_index_dict[appliance].append(i + 1) 201 | data_found = True 202 | 203 | if not data_found: 204 | app_index_dict[appliance].append(-1) 205 | 206 | if np.sum(list(app_index_dict.values())) == -len(self.appliance_names): 207 | self.house_indicies.remove(house_id) 208 | continue 209 | 210 | for appliance in self.appliance_names: 211 | if app_index_dict[appliance][0] == -1: 212 | temp_values = house_data.copy().iloc[:, 1] 213 | temp_values[:] = 0 214 | temp_data = house_data.copy().iloc[:, :2] 215 | temp_data.iloc[:, 1] = temp_values 216 | else: 217 | temp_data = pd.read_csv(house_folder.joinpath( 218 | 'channel_' + str(app_index_dict[appliance][0]) + '.dat'), sep=' ', header=None) 219 | 220 | if len(app_index_dict[appliance]) > 1: 221 | for idx in app_index_dict[appliance][1:]: 222 | temp_data_ = pd.read_csv(house_folder.joinpath( 223 | 'channel_' + str(idx) + '.dat'), sep=' ', header=None) 224 | temp_data = pd.merge( 225 | temp_data, temp_data_, how='inner', on=0) 226 | temp_data.iloc[:, 1] = temp_data.iloc[:, 227 | 1] + temp_data.iloc[:, 2] 228 | temp_data = temp_data.iloc[:, 0: 2] 229 | 230 | house_data = pd.merge( 231 | house_data, temp_data, how='inner', on=0) 232 | 233 | house_data.iloc[:, 0] = pd.to_datetime( 234 | house_data.iloc[:, 0], unit='s') 235 | house_data.columns = ['time', 'aggregate'] + \ 236 | [i for i in self.appliance_names] 237 | house_data = house_data.set_index('time') 238 | house_data = house_data.resample(self.sampling).mean().fillna( 239 | method='ffill', limit=30) 240 | 241 | if house_id == self.house_indicies[0]: 242 | entire_data = house_data 243 | else: 244 | entire_data = entire_data.append( 245 | house_data, ignore_index=True) 246 | 247 | entire_data = entire_data.dropna().copy() 248 | entire_data = entire_data[entire_data['aggregate'] > 0] 249 | entire_data[entire_data < 5] = 0 250 | entire_data = entire_data.clip( 251 | [0] * len(entire_data.columns), self.cutoff, axis=1) 252 | 253 | return entire_data.values[:, 0], entire_data.values[:, 1:] 254 | 255 | 256 | class UK_DALE_Dataset(AbstractDataset): 257 | @classmethod 258 | def code(cls): 259 | return 'uk_dale' 260 | 261 | @classmethod 262 | def _if_data_exists(self): 263 | folder = Path(RAW_DATASET_ROOT_FOLDER).joinpath(self.code()) 264 | first_file = folder.joinpath('house_1', 'channel_1.dat') 265 | if first_file.is_file(): 266 | return True 267 | return False 268 | 269 | def load_data(self): 270 | for appliance in self.appliance_names: 271 | assert appliance in ['dishwasher', 'fridge', 272 | 'microwave', 'washing_machine', 'kettle'] 273 | 274 | for house_id in self.house_indicies: 275 | assert house_id in [1, 2, 3, 4, 5] 276 | 277 | if not self.cutoff: 278 | self.cutoff = [6000] * (len(self.appliance_names) + 1) 279 | 280 | if not self._if_data_exists(): 281 | print('Please download, unzip and move data into', 282 | self._get_folder_path()) 283 | raise FileNotFoundError 284 | 285 | else: 286 | directory = self._get_folder_path() 287 | 288 | for house_id in self.house_indicies: 289 | house_folder = directory.joinpath('house_' + str(house_id)) 290 | house_label = pd.read_csv(house_folder.joinpath( 291 | 'labels.dat'), sep=' ', header=None) 292 | 293 | house_data = pd.read_csv(house_folder.joinpath( 294 | 'channel_1.dat'), sep=' ', header=None) 295 
| house_data.iloc[:, 0] = pd.to_datetime( 296 | house_data.iloc[:, 0], unit='s') 297 | house_data.columns = ['time', 'aggregate'] 298 | house_data = house_data.set_index('time') 299 | house_data = house_data.resample(self.sampling).mean().fillna( 300 | method='ffill', limit=30) 301 | 302 | appliance_list = house_label.iloc[:, 1].values 303 | app_index_dict = defaultdict(list) 304 | 305 | for appliance in self.appliance_names: 306 | data_found = False 307 | for i in range(len(appliance_list)): 308 | if appliance_list[i] == appliance: 309 | app_index_dict[appliance].append(i + 1) 310 | data_found = True 311 | 312 | if not data_found: 313 | app_index_dict[appliance].append(-1) 314 | 315 | if np.sum(list(app_index_dict.values())) == -len(self.appliance_names): 316 | self.house_indicies.remove(house_id) 317 | continue 318 | 319 | for appliance in self.appliance_names: 320 | if app_index_dict[appliance][0] == -1: 321 | house_data.insert(len(house_data.columns), appliance, np.zeros(len(house_data))) 322 | else: 323 | temp_data = pd.read_csv(house_folder.joinpath( 324 | 'channel_' + str(app_index_dict[appliance][0]) + '.dat'), sep=' ', header=None) 325 | temp_data.iloc[:, 0] = pd.to_datetime( 326 | temp_data.iloc[:, 0], unit='s') 327 | temp_data.columns = ['time', appliance] 328 | temp_data = temp_data.set_index('time') 329 | temp_data = temp_data.resample(self.sampling).mean().fillna( 330 | method='ffill', limit=30) 331 | house_data = pd.merge( 332 | house_data, temp_data, how='inner', on='time') 333 | 334 | if house_id == self.house_indicies[0]: 335 | entire_data = house_data 336 | if len(self.house_indicies) == 1: 337 | entire_data = entire_data.reset_index(drop=True) 338 | else: 339 | entire_data = entire_data.append( 340 | house_data, ignore_index=True) 341 | 342 | entire_data = entire_data.dropna().copy() 343 | entire_data = entire_data[entire_data['aggregate'] > 0] 344 | entire_data[entire_data < 5] = 0 345 | entire_data = entire_data.clip( 346 | [0] * len(entire_data.columns), self.cutoff, axis=1) 347 | 348 | return entire_data.values[:, 0], entire_data.values[:, 1:] 349 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.autograd.gradcheck import zero_gradients 6 | from tqdm import tqdm 7 | 8 | import os 9 | import json 10 | import random 11 | import numpy as np 12 | from abc import * 13 | from pathlib import Path 14 | 15 | from utils import * 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | torch.set_default_tensor_type(torch.DoubleTensor) 20 | 21 | 22 | class Trainer(metaclass=ABCMeta): 23 | def __init__(self, args, model, train_loader, val_loader, stats, export_root): 24 | self.args = args 25 | self.device = args.device 26 | self.num_epochs = args.num_epochs 27 | self.model = model.to(self.device) 28 | self.export_root = Path(export_root) 29 | 30 | self.cutoff = torch.tensor([args.cutoff[i] 31 | for i in args.appliance_names]).to(self.device) 32 | self.threshold = torch.tensor( 33 | [args.threshold[i] for i in args.appliance_names]).to(self.device) 34 | self.min_on = torch.tensor([args.min_on[i] 35 | for i in args.appliance_names]).to(self.device) 36 | self.min_off = torch.tensor( 37 | [args.min_off[i] for i in args.appliance_names]).to(self.device) 38 | 39 | self.normalize = args.normalize 40 | self.denom = args.denom 41 | if 
self.normalize == 'mean': 42 | self.x_mean, self.x_std = stats 43 | self.x_mean = torch.tensor(self.x_mean).to(self.device) 44 | self.x_std = torch.tensor(self.x_std).to(self.device) 45 | 46 | self.train_loader = train_loader 47 | self.val_loader = val_loader 48 | 49 | self.optimizer = self._create_optimizer() 50 | if args.enable_lr_schedule: 51 | self.lr_scheduler = optim.lr_scheduler.StepLR( 52 | self.optimizer, step_size=args.decay_step, gamma=args.gamma) 53 | 54 | self.C0 = torch.tensor(args.c0[args.appliance_names[0]]).to(self.device) 55 | print('C0: {}'.format(self.C0)) 56 | self.kl = nn.KLDivLoss(reduction='batchmean') 57 | self.mse = nn.MSELoss() 58 | self.margin = nn.SoftMarginLoss() 59 | self.l1_on = nn.L1Loss(reduction='sum') 60 | 61 | def train(self): 62 | val_rel_err, val_abs_err = [], [] 63 | val_acc, val_precision, val_recall, val_f1 = [], [], [], [] 64 | 65 | best_rel_err, _, best_acc, _, _, best_f1 = self.validate() 66 | self._save_state_dict() 67 | 68 | for epoch in range(self.num_epochs): 69 | self.train_bert_one_epoch(epoch + 1) 70 | 71 | rel_err, abs_err, acc, precision, recall, f1 = self.validate() 72 | val_rel_err.append(rel_err.tolist()) 73 | val_abs_err.append(abs_err.tolist()) 74 | val_acc.append(acc.tolist()) 75 | val_precision.append(precision.tolist()) 76 | val_recall.append(recall.tolist()) 77 | val_f1.append(f1.tolist()) 78 | 79 | if f1.mean() + acc.mean() - rel_err.mean() > best_f1.mean() + best_acc.mean() - best_rel_err.mean(): 80 | best_f1 = f1 81 | best_acc = acc 82 | best_rel_err = rel_err 83 | self._save_state_dict() 84 | 85 | def train_one_epoch(self, epoch): 86 | loss_values = [] 87 | self.model.train() 88 | tqdm_dataloader = tqdm(self.train_loader) 89 | for batch_idx, batch in enumerate(tqdm_dataloader): 90 | seqs, labels_energy, status = batch 91 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device) 92 | self.optimizer.zero_grad() 93 | logits = self.model(seqs) 94 | labels = labels_energy / self.cutoff 95 | logits_energy = self.cutoff_energy(logits * self.cutoff) 96 | logits_status = self.compute_status(logits_energy) 97 | 98 | kl_loss = self.kl(torch.log(F.softmax(logits.squeeze() / 0.1, dim=-1) + 1e-9), F.softmax(labels.squeeze() / 0.1, dim=-1)) 99 | mse_loss = self.mse(logits.contiguous().view(-1).double(), 100 | labels.contiguous().view(-1).double()) 101 | margin_loss = self.margin((logits_status * 2 - 1).contiguous().view(-1).double(), 102 | (status * 2 - 1).contiguous().view(-1).double()) 103 | total_loss = kl_loss + mse_loss + margin_loss 104 | 105 | on_mask = ((status == 1) + (status != logits_status.reshape(status.shape))) >= 1 106 | if on_mask.sum() > 0: 107 | total_size = torch.tensor(on_mask.shape).prod() 108 | logits_on = torch.masked_select(logits.reshape(on_mask.shape), on_mask) 109 | labels_on = torch.masked_select(labels.reshape(on_mask.shape), on_mask) 110 | loss_l1_on = self.l1_on(logits_on.contiguous().view(-1), 111 | labels_on.contiguous().view(-1)) 112 | total_loss += self.C0 * loss_l1_on / total_size 113 | 114 | total_loss.backward() 115 | self.optimizer.step() 116 | loss_values.append(total_loss.item()) 117 | 118 | average_loss = np.mean(np.array(loss_values)) 119 | tqdm_dataloader.set_description('Epoch {}, loss {:.2f}'.format(epoch, average_loss)) 120 | 121 | if self.args.enable_lr_schedule: 122 | self.lr_scheduler.step() 123 | 124 | def train_bert_one_epoch(self, epoch): 125 | loss_values = [] 126 | self.model.train() 127 | tqdm_dataloader = tqdm(self.train_loader) 128 
| for batch_idx, batch in enumerate(tqdm_dataloader): 129 | seqs, labels_energy, status = batch 130 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device) 131 | batch_shape = status.shape 132 | self.optimizer.zero_grad() 133 | logits = self.model(seqs) 134 | labels = labels_energy / self.cutoff 135 | logits_energy = self.cutoff_energy(logits * self.cutoff) 136 | logits_status = self.compute_status(logits_energy) 137 | 138 | mask = (status >= 0) 139 | labels_masked = torch.masked_select(labels, mask).view((-1, batch_shape[-1])) 140 | logits_masked = torch.masked_select(logits, mask).view((-1, batch_shape[-1])) 141 | status_masked = torch.masked_select(status, mask).view((-1, batch_shape[-1])) 142 | logits_status_masked = torch.masked_select(logits_status, mask).view((-1, batch_shape[-1])) 143 | 144 | kl_loss = self.kl(torch.log(F.softmax(logits_masked.squeeze() / 0.1, dim=-1) + 1e-9), F.softmax(labels_masked.squeeze() / 0.1, dim=-1)) 145 | mse_loss = self.mse(logits_masked.contiguous().view(-1).double(), 146 | labels_masked.contiguous().view(-1).double()) 147 | margin_loss = self.margin((logits_status_masked * 2 - 1).contiguous().view(-1).double(), 148 | (status_masked * 2 - 1).contiguous().view(-1).double()) 149 | total_loss = kl_loss + mse_loss + margin_loss 150 | 151 | on_mask = (status >= 0) * (((status == 1) + (status != logits_status.reshape(status.shape))) >= 1) 152 | if on_mask.sum() > 0: 153 | total_size = torch.tensor(on_mask.shape).prod() 154 | logits_on = torch.masked_select(logits.reshape(on_mask.shape), on_mask) 155 | labels_on = torch.masked_select(labels.reshape(on_mask.shape), on_mask) 156 | loss_l1_on = self.l1_on(logits_on.contiguous().view(-1), 157 | labels_on.contiguous().view(-1)) 158 | total_loss += self.C0 * loss_l1_on / total_size 159 | 160 | total_loss.backward() 161 | self.optimizer.step() 162 | loss_values.append(total_loss.item()) 163 | 164 | average_loss = np.mean(np.array(loss_values)) 165 | tqdm_dataloader.set_description('Epoch {}, loss {:.2f}'.format(epoch, average_loss)) 166 | 167 | if self.args.enable_lr_schedule: 168 | self.lr_scheduler.step() 169 | 170 | def validate(self): 171 | self.model.eval() 172 | loss_values, relative_errors, absolute_errors = [], [], [] 173 | acc_values, precision_values, recall_values, f1_values, = [], [], [], [] 174 | 175 | with torch.no_grad(): 176 | tqdm_dataloader = tqdm(self.val_loader) 177 | for batch_idx, batch in enumerate(tqdm_dataloader): 178 | seqs, labels_energy, status = batch 179 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device) 180 | logits = self.model(seqs) 181 | labels = labels_energy / self.cutoff 182 | logits_energy = self.cutoff_energy(logits * self.cutoff) 183 | logits_status = self.compute_status(logits_energy) 184 | logits_energy = logits_energy * logits_status 185 | 186 | rel_err, abs_err = relative_absolute_error(logits_energy.detach( 187 | ).cpu().numpy().squeeze(), labels_energy.detach().cpu().numpy().squeeze()) 188 | relative_errors.append(rel_err.tolist()) 189 | absolute_errors.append(abs_err.tolist()) 190 | 191 | acc, precision, recall, f1 = acc_precision_recall_f1_score(logits_status.detach( 192 | ).cpu().numpy().squeeze(), status.detach().cpu().numpy().squeeze()) 193 | acc_values.append(acc.tolist()) 194 | precision_values.append(precision.tolist()) 195 | recall_values.append(recall.tolist()) 196 | f1_values.append(f1.tolist()) 197 | 198 | average_acc = 
np.mean(np.array(acc_values).reshape(-1)) 199 | average_f1 = np.mean(np.array(f1_values).reshape(-1)) 200 | average_rel_err = np.mean(np.array(relative_errors).reshape(-1)) 201 | 202 | tqdm_dataloader.set_description('Validation, rel_err {:.2f}, acc {:.2f}, f1 {:.2f}'.format( 203 | average_rel_err, average_acc, average_f1)) 204 | 205 | return_rel_err = np.array(relative_errors).mean(axis=0) 206 | return_abs_err = np.array(absolute_errors).mean(axis=0) 207 | return_acc = np.array(acc_values).mean(axis=0) 208 | return_precision = np.array(precision_values).mean(axis=0) 209 | return_recall = np.array(recall_values).mean(axis=0) 210 | return_f1 = np.array(f1_values).mean(axis=0) 211 | return return_rel_err, return_abs_err, return_acc, return_precision, return_recall, return_f1 212 | 213 | def test(self, test_loader): 214 | self._load_best_model() 215 | self.model.eval() 216 | loss_values, relative_errors, absolute_errors = [], [], [] 217 | acc_values, precision_values, recall_values, f1_values, = [], [], [], [] 218 | 219 | label_curve = [] 220 | e_pred_curve = [] 221 | status_curve = [] 222 | s_pred_curve = [] 223 | with torch.no_grad(): 224 | tqdm_dataloader = tqdm(test_loader) 225 | for batch_idx, batch in enumerate(tqdm_dataloader): 226 | seqs, labels_energy, status = batch 227 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device) 228 | logits = self.model(seqs) 229 | labels = labels_energy / self.cutoff 230 | logits_energy = self.cutoff_energy(logits * self.cutoff) 231 | logits_status = self.compute_status(logits_energy) 232 | logits_energy = logits_energy * logits_status 233 | 234 | acc, precision, recall, f1 = acc_precision_recall_f1_score(logits_status.detach( 235 | ).cpu().numpy().squeeze(), status.detach().cpu().numpy().squeeze()) 236 | acc_values.append(acc.tolist()) 237 | precision_values.append(precision.tolist()) 238 | recall_values.append(recall.tolist()) 239 | f1_values.append(f1.tolist()) 240 | 241 | rel_err, abs_err = relative_absolute_error(logits_energy.detach( 242 | ).cpu().numpy().squeeze(), labels_energy.detach().cpu().numpy().squeeze()) 243 | relative_errors.append(rel_err.tolist()) 244 | absolute_errors.append(abs_err.tolist()) 245 | 246 | average_acc = np.mean(np.array(acc_values).reshape(-1)) 247 | average_f1 = np.mean(np.array(f1_values).reshape(-1)) 248 | average_rel_err = np.mean(np.array(relative_errors).reshape(-1)) 249 | 250 | tqdm_dataloader.set_description('Test, rel_err {:.2f}, acc {:.2f}, f1 {:.2f}'.format( 251 | average_rel_err, average_acc, average_f1)) 252 | 253 | label_curve.append(labels_energy.detach().cpu().numpy().tolist()) 254 | e_pred_curve.append(logits_energy.detach().cpu().numpy().tolist()) 255 | status_curve.append(status.detach().cpu().numpy().tolist()) 256 | s_pred_curve.append(logits_status.detach().cpu().numpy().tolist()) 257 | 258 | label_curve = np.concatenate(label_curve).reshape(-1, self.args.output_size) 259 | e_pred_curve = np.concatenate(e_pred_curve).reshape(-1, self.args.output_size) 260 | status_curve = np.concatenate(status_curve).reshape(-1, self.args.output_size) 261 | s_pred_curve = np.concatenate(s_pred_curve).reshape(-1, self.args.output_size) 262 | 263 | self._save_result({'gt': label_curve.tolist(), 264 | 'pred': e_pred_curve.tolist()}, 'test_result.json') 265 | 266 | if self.args.output_size > 1: 267 | return_rel_err = np.array(relative_errors).mean(axis=0) 268 | else: 269 | return_rel_err = np.array(relative_errors).mean() 270 | return_rel_err, return_abs_err = 
relative_absolute_error(e_pred_curve, label_curve) 271 | return_acc, return_precision, return_recall, return_f1 = acc_precision_recall_f1_score(s_pred_curve, status_curve) 272 | 273 | return return_rel_err, return_abs_err, return_acc, return_precision, return_recall, return_f1 274 | 275 | def cutoff_energy(self, data): 276 | columns = data.squeeze().shape[-1] 277 | 278 | if self.cutoff.size(0) == 0: 279 | self.cutoff = torch.tensor( 280 | [3100 for i in range(columns)]).to(self.device) 281 | 282 | data[data < 5] = 0 283 | data = torch.min(data, self.cutoff.double()) 284 | return data 285 | 286 | def compute_status(self, data): 287 | data_shape = data.shape 288 | columns = data.squeeze().shape[-1] 289 | 290 | if self.threshold.size(0) == 0: 291 | self.threshold = torch.tensor( 292 | [10 for i in range(columns)]).to(self.device) 293 | 294 | status = (data >= self.threshold) * 1 295 | return status 296 | 297 | def _create_optimizer(self): 298 | args = self.args 299 | param_optimizer = list(self.model.named_parameters()) 300 | no_decay = ['bias', 'layer_norm'] 301 | optimizer_grouped_parameters = [ 302 | { 303 | 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 304 | 'weight_decay': args.weight_decay, 305 | }, 306 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 307 | ] 308 | if args.optimizer.lower() == 'adamw': 309 | return optim.AdamW(optimizer_grouped_parameters, lr=args.lr) 310 | elif args.optimizer.lower() == 'adam': 311 | return optim.Adam(optimizer_grouped_parameters, lr=args.lr) 312 | elif args.optimizer.lower() == 'sgd': 313 | return optim.SGD(optimizer_grouped_parameters, lr=args.lr, momentum=args.momentum) 314 | else: 315 | raise ValueError 316 | 317 | def _load_best_model(self): 318 | try: 319 | self.model.load_state_dict(torch.load( 320 | self.export_root.joinpath('best_acc_model.pth'))) 321 | self.model.to(self.device) 322 | except: 323 | print('Failed to load best model, continue testing with current model...') 324 | 325 | def _save_state_dict(self): 326 | if not os.path.exists(self.export_root): 327 | os.makedirs(self.export_root) 328 | print('Saving best model...') 329 | torch.save(self.model.state_dict(), 330 | self.export_root.joinpath('best_acc_model.pth')) 331 | 332 | def _save_values(self, filename): 333 | if not os.path.exists(self.export_root): 334 | os.makedirs(self.export_root) 335 | torch.save(self.model.state_dict(), 336 | self.export_root.joinpath('best_acc_model.pth')) 337 | 338 | def _save_result(self, data, filename): 339 | if not os.path.exists(self.export_root): 340 | os.makedirs(self.export_root) 341 | filepath = Path(self.export_root).joinpath(filename) 342 | with filepath.open('w') as f: 343 | json.dump(data, f, indent=2) 344 | --------------------------------------------------------------------------------