├── redd.png
├── uk-dale.png
├── config.py
├── data
│   ├── uk_dale
│   │   ├── house_3
│   │   │   └── labels.dat
│   │   ├── house_4
│   │   │   └── labels.dat
│   │   ├── house_2
│   │   │   └── labels.dat
│   │   ├── house_5
│   │   │   └── labels.dat
│   │   └── house_1
│   │       └── labels.dat
│   └── redd_lf
│       ├── house_2
│       │   └── labels.dat
│       ├── house_6
│       │   └── labels.dat
│       ├── house_1
│       │   └── labels.dat
│       ├── house_4
│       │   └── labels.dat
│       ├── house_3
│       │   └── labels.dat
│       └── house_5
│           └── labels.dat
├── README.md
├── dataloader.py
├── train.py
├── utils.py
├── model.py
├── dataset.py
└── trainer.py
/redd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yueeeeeeee/BERT4NILM/HEAD/redd.png
--------------------------------------------------------------------------------
/uk-dale.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yueeeeeeee/BERT4NILM/HEAD/uk-dale.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | RAW_DATASET_ROOT_FOLDER = 'data'
2 | EXPERIMENT_ROOT_FOLDER = 'experiments'
3 |
--------------------------------------------------------------------------------
/data/uk_dale/house_3/labels.dat:
--------------------------------------------------------------------------------
1 | 1 aggregate
2 | 2 kettle
3 | 3 electric_heater
4 | 4 laptop
5 | 5 projector
6 |
--------------------------------------------------------------------------------
/data/uk_dale/house_4/labels.dat:
--------------------------------------------------------------------------------
1 | 1 aggregate
2 | 2 tv_dvd_digibox_lamp
3 | 3 kettle
4 | 4 gas_boiler
5 | 5 freezer
6 | 6 washing_machine
7 |
--------------------------------------------------------------------------------
/data/redd_lf/house_2/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 kitchen_outlets
4 | 4 lighting
5 | 5 stove
6 | 6 microwave
7 | 7 washer_dryer
8 | 8 kitchen_outlets
9 | 9 refrigerator
10 | 10 dishwasher
11 | 11 disposal
12 |
--------------------------------------------------------------------------------
/data/uk_dale/house_2/labels.dat:
--------------------------------------------------------------------------------
1 | 1 aggregate
2 | 2 laptop
3 | 3 monitor
4 | 4 speakers
5 | 5 server
6 | 6 router
7 | 7 server_hdd
8 | 8 kettle
9 | 9 rice_cooker
10 | 10 running_machine
11 | 11 laptop2
12 | 12 washing_machine
13 | 13 dishwasher
14 | 14 fridge
15 | 15 microwave
16 | 16 toaster
17 | 17 playstation
18 | 18 modem
19 | 19 cooker
20 |
--------------------------------------------------------------------------------
/data/redd_lf/house_6/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 kitchen_outlets
4 | 4 washer_dryer
5 | 5 stove
6 | 6 electronics
7 | 7 bathroom_gfi
8 | 8 refrigerator
9 | 9 dishwasher
10 | 10 outlets_unknown
11 | 11 outlets_unknown
12 | 12 electric_heat
13 | 13 kitchen_outlets
14 | 14 lighting
15 | 15 air_conditioning
16 | 16 air_conditioning
17 | 17 air_conditioning
18 |
--------------------------------------------------------------------------------
/data/redd_lf/house_1/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 oven
4 | 4 oven
5 | 5 refrigerator
6 | 6 dishwasher
7 | 7 kitchen_outlets
8 | 8 kitchen_outlets
9 | 9 lighting
10 | 10 washer_dryer
11 | 11 microwave
12 | 12 bathroom_gfi
13 | 13 electric_heat
14 | 14 stove
15 | 15 kitchen_outlets
16 | 16 kitchen_outlets
17 | 17 lighting
18 | 18 lighting
19 | 19 washer_dryer
20 | 20 washer_dryer
21 |
--------------------------------------------------------------------------------
/data/redd_lf/house_4/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 lighting
4 | 4 furance
5 | 5 kitchen_outlets
6 | 6 outlets_unknown
7 | 7 washer_dryer
8 | 8 stove
9 | 9 air_conditioning
10 | 10 air_conditioning
11 | 11 miscellaeneous
12 | 12 smoke_alarms
13 | 13 lighting
14 | 14 kitchen_outlets
15 | 15 dishwasher
16 | 16 bathroom_gfi
17 | 17 bathroom_gfi
18 | 18 lighting
19 | 19 lighting
20 | 20 air_conditioning
21 |
--------------------------------------------------------------------------------
/data/redd_lf/house_3/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 outlets_unknown
4 | 4 outlets_unknown
5 | 5 lighting
6 | 6 electronics
7 | 7 refrigerator
8 | 8 disposal
9 | 9 dishwasher
10 | 10 furance
11 | 11 lighting
12 | 12 outlets_unknown
13 | 13 washer_dryer
14 | 14 washer_dryer
15 | 15 lighting
16 | 16 microwave
17 | 17 lighting
18 | 18 smoke_alarms
19 | 19 lighting
20 | 20 bathroom_gfi
21 | 21 kitchen_outlets
22 | 22 kitchen_outlets
23 |
--------------------------------------------------------------------------------
/data/uk_dale/house_5/labels.dat:
--------------------------------------------------------------------------------
1 | 1 aggregate
2 | 2 stereo_speakers_bedroom
3 | 3 i7_desktop
4 | 4 hairdryer
5 | 5 primary_tv
6 | 6 24_inch_lcd_bedroom
7 | 7 treadmill
8 | 8 network_attached_storage
9 | 9 core2_server
10 | 10 24_inch_lcd
11 | 11 PS4
12 | 12 steam_iron
13 | 13 nespresso_pixie
14 | 14 atom_pc
15 | 15 toaster
16 | 16 home_theatre_amp
17 | 17 sky_hd_box
18 | 18 kettle
19 | 19 fridge
20 | 20 oven
21 | 21 electric_hob
22 | 22 dishwasher
23 | 23 microwave
24 | 24 washing_machine
25 | 25 vacuum_cleaner
26 |
--------------------------------------------------------------------------------
/data/redd_lf/house_5/labels.dat:
--------------------------------------------------------------------------------
1 | 1 mains
2 | 2 mains
3 | 3 microwave
4 | 4 lighting
5 | 5 outlets_unknown
6 | 6 furance
7 | 7 outlets_unknown
8 | 8 washer_dryer
9 | 9 washer_dryer
10 | 10 subpanel
11 | 11 subpanel
12 | 12 electric_heat
13 | 13 electric_heat
14 | 14 lighting
15 | 15 outlets_unknown
16 | 16 bathroom_gfi
17 | 17 lighting
18 | 18 refrigerator
19 | 19 lighting
20 | 20 dishwasher
21 | 21 disposal
22 | 22 electronics
23 | 23 lighting
24 | 24 kitchen_outlets
25 | 25 kitchen_outlets
26 | 26 outdoor_outlets
27 |
--------------------------------------------------------------------------------
/data/uk_dale/house_1/labels.dat:
--------------------------------------------------------------------------------
1 | 1 aggregate
2 | 2 boiler
3 | 3 solar_thermal_pump
4 | 4 laptop
5 | 5 washing_machine
6 | 6 dishwasher
7 | 7 tv
8 | 8 kitchen_lights
9 | 9 htpc
10 | 10 kettle
11 | 11 toaster
12 | 12 fridge
13 | 13 microwave
14 | 14 lcd_office
15 | 15 hifi_office
16 | 16 breadmaker
17 | 17 amp_livingroom
18 | 18 adsl_router
19 | 19 livingroom_s_lamp
20 | 20 soldering_iron
21 | 21 gigE_&_USBhub
22 | 22 hoover
23 | 23 kitchen_dt_lamp
24 | 24 bedroom_ds_lamp
25 | 25 lighting_circuit
26 | 26 livingroom_s_lamp2
27 | 27 iPad_charger
28 | 28 subwoofer_livingroom
29 | 29 livingroom_lamp_tv
30 | 30 DAB_radio_livingroom
31 | 31 kitchen_lamp2
32 | 32 kitchen_phone&stereo
33 | 33 utilityrm_lamp
34 | 34 samsung_charger
35 | 35 bedroom_d_lamp
36 | 36 coffee_machine
37 | 37 kitchen_radio
38 | 38 bedroom_chargers
39 | 39 hair_dryer
40 | 40 straighteners
41 | 41 iron
42 | 42 gas_oven
43 | 43 data_logger_pc
44 | 44 childs_table_lamp
45 | 45 childs_ds_lamp
46 | 46 baby_monitor_tx
47 | 47 battery_charger
48 | 48 office_lamp1
49 | 49 office_lamp2
50 | 50 office_lamp3
51 | 51 office_pc
52 | 52 office_fan
53 | 53 LED_printer
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BERT4NILM
2 |
3 | PyTorch Implementation of BERT4NILM: A Bidirectional Transformer Model for Non-Intrusive Load Monitoring
4 |
5 |
6 | ## Data
7 |
8 | The CSV datasets can be downloaded here: [REDD](http://redd.csail.mit.edu/) and [UK-DALE](https://jack-kelly.com/data/)
9 |
10 | We took the liberty of renaming certain appliances to 'dishwasher', 'fridge', 'microwave', 'washing_machine' and 'kettle' in the 'labels.dat' files; see the data folder
11 |
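For reference, each 'labels.dat' simply maps a channel index to an appliance name, and 'dataset.py' reads both the labels and the per-channel readings with pandas. Below is a minimal sketch of how a single house folder is parsed, assuming the default 'data' root from 'config.py' and UK-DALE house 1 as an example:

```python
import pandas as pd
from pathlib import Path

house = Path('data') / 'uk_dale' / 'house_1'   # RAW_DATASET_ROOT_FOLDER / dataset / house

# labels.dat: one "<channel index> <appliance name>" pair per line, space-separated, no header
labels = pd.read_csv(house / 'labels.dat', sep=' ', header=None)

# channel_1.dat holds the aggregate reading as "<unix timestamp> <power>" rows;
# dataset.py converts the timestamps and resamples everything to args.sampling (default '6s')
mains = pd.read_csv(house / 'channel_1.dat', sep=' ', header=None)
mains[0] = pd.to_datetime(mains[0], unit='s')
```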
12 |
13 | ## Training
14 |
15 | This is the PyTorch implementation of BERT4NILM, Bidirectional Encoder Representations from Transformers for energy disaggregation. In this repository we provide the BERT4NILM model as well as data functions for the low-frequency REDD and UK-DALE datasets. Run the following command to train an initial model; hyper-parameters (as well as the target appliances) can be tuned in utils.py, and the test runs automatically after training ends:
16 |
17 | ```bash
18 | python train.py
19 | ```
20 |
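Note that 'train.py' collects its main settings interactively through 'get_user_input' in 'utils.py' (device, dataset, target appliance and number of epochs), so a session looks roughly like this (the GPU prompt only appears when CUDA is available; the answers below are just an example):

```bash
$ python train.py
Input GPU ID: 0
Input r for REDD, u for UK_DALE: r
Input r, w, m or d for target appliance: m
Input training epochs: 100
```
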
21 | The trained model state dict will be saved under 'experiments/<dataset_code>/<appliance_names>/best_acc_model.pth'
22 |
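To reload a saved checkpoint later (for example to rerun only the test stage), here is a minimal sketch following the loading logic in 'train.py'; the example path and the minimal `args` below are assumptions and must match the settings used for training:

```python
import os
import torch
from argparse import Namespace
from model import BERT4NILM

# BERT4NILM only reads these three fields from args (see model.py);
# they must match the values used during training
args = Namespace(window_size=480, drop_out=0.1, output_size=1)
model = BERT4NILM(args)

export_root = os.path.join('experiments', 'redd_lf', 'microwave')  # hypothetical example path
state_dict = torch.load(os.path.join(export_root, 'best_acc_model.pth'), map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```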
23 |
24 | ## Performance
25 |
26 | Our models are trained for 100 / 20 epochs respectively for appliances from the REDD and UK-DALE datasets; all other parameters can be found in 'train.py' and 'utils.py'
27 |
28 | ### REDD
29 |
30 | ![REDD results](redd.png)
31 |
32 | ### UK-DALE
33 |
34 | ![UK-DALE results](uk-dale.png)
35 |
36 |
37 | ## Citing
38 | Please cite the following paper if you use our methods in your research:
39 | ```
40 | @inproceedings{yue2020bert4nilm,
41 | title={BERT4NILM: A Bidirectional Transformer Model for Non-Intrusive Load Monitoring},
42 | author={Yue, Zhenrui and Witzig, Camilo Requena and Jorde, Daniel and Jacobsen, Hans-Arno},
43 | booktitle={Proceedings of the 5th International Workshop on Non-Intrusive Load Monitoring},
44 | pages={89--93},
45 | year={2020}
46 | }
47 | ```
48 |
49 |
50 | ## Acknowledgement
51 |
52 | During the implementation we based our code mostly on the [BERT-pytorch](https://github.com/codertimo/BERT-pytorch) repository by Junseong Kim. We were also inspired by the [BERT4Rec](https://github.com/jaywonchung/BERT4Rec-VAE-Pytorch) implementation by Jaewon Chung and [Transformers](https://github.com/huggingface/transformers) from Hugging Face. Many thanks to these authors for their great work!
53 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data_utils
5 |
6 |
7 | torch.set_default_tensor_type(torch.DoubleTensor)
8 |
9 |
10 | class NILMDataloader():
11 | def __init__(self, args, dataset, bert=False):
12 | self.args = args
13 | self.mask_prob = args.mask_prob
14 | self.batch_size = args.batch_size
15 |
16 | if bert:
17 | self.train_dataset, self.val_dataset = dataset.get_bert_datasets(mask_prob=self.mask_prob)
18 | else:
19 | self.train_dataset, self.val_dataset = dataset.get_datasets()
20 |
21 | @classmethod
22 | def code(cls):
23 | return 'dataloader'
24 |
25 | def get_dataloaders(self):
26 | train_loader = self._get_loader(self.train_dataset)
27 | val_loader = self._get_loader(self.val_dataset)
28 | return train_loader, val_loader
29 |
30 | def _get_loader(self, dataset):
31 | dataloader = data_utils.DataLoader(
32 | dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True)
33 | return dataloader
34 |
35 |
36 | class NILMDataset(data_utils.Dataset):
37 | def __init__(self, x, y, status, window_size=480, stride=30):
38 | self.x = x
39 | self.y = y
40 | self.status = status
41 | self.window_size = window_size
42 | self.stride = stride
43 |
44 | def __len__(self):
45 | return int(np.ceil((len(self.x) - self.window_size) / self.stride) + 1)
46 |
47 | def __getitem__(self, index):
48 | start_index = index * self.stride
49 | end_index = np.min(
50 | (len(self.x), index * self.stride + self.window_size))
51 | x = self.padding_seqs(self.x[start_index: end_index])
52 | y = self.padding_seqs(self.y[start_index: end_index])
53 | status = self.padding_seqs(self.status[start_index: end_index])
54 | return torch.tensor(x), torch.tensor(y), torch.tensor(status)
55 |
56 | def padding_seqs(self, in_array):
57 | if len(in_array) == self.window_size:
58 | return in_array
59 | try:
60 | out_array = np.zeros((self.window_size, in_array.shape[1]))
61 | except IndexError:  # in_array is 1-D, pad a flat window instead
62 | out_array = np.zeros(self.window_size)
63 |
64 | length = len(in_array)
65 | out_array[:length] = in_array
66 | return out_array
67 |
68 |
69 | class BERTDataset(data_utils.Dataset):
70 | def __init__(self, x, y, status, window_size=480, stride=30, mask_prob=0.2):
71 | self.x = x
72 | self.y = y
73 | self.status = status
74 | self.window_size = window_size
75 | self.stride = stride
76 | self.mask_prob = mask_prob
77 | self.columns = y.shape[1]
78 |
79 | def __len__(self):
80 | return int(np.ceil((len(self.x) - self.window_size) / self.stride) + 1)
81 |
82 | def __getitem__(self, index):
83 | start_index = index * self.stride
84 | end_index = np.min(
85 | (len(self.x), index * self.stride + self.window_size))
86 | x = self.padding_seqs(self.x[start_index: end_index])
87 | y = self.padding_seqs(self.y[start_index: end_index])
88 | status = self.padding_seqs(self.status[start_index: end_index])
89 |
90 | tokens = []
91 | labels = []
92 | on_offs = []
93 | for i in range(len(x)):
94 | prob = random.random()
95 | if prob < self.mask_prob:
96 | prob = random.random()
97 | if prob < 0.8:
98 | tokens.append(-1)
99 | elif prob < 0.9:
100 | tokens.append(np.random.normal())
101 | else:
102 | tokens.append(x[i])
103 |
104 | labels.append(y[i])
105 | on_offs.append(status[i])
106 | else:
107 | tokens.append(x[i])
108 | temp = np.array([-1] * self.columns)
109 | labels.append(temp)
110 | on_offs.append(temp)
111 |
112 | return torch.tensor(tokens), torch.tensor(labels), torch.tensor(on_offs)
113 |
114 | def padding_seqs(self, in_array):
115 | if len(in_array) == self.window_size:
116 | return in_array
117 | try:
118 | out_array = np.zeros((self.window_size, in_array.shape[1]))
119 | except IndexError:  # in_array is 1-D, pad a flat window instead
120 | out_array = np.zeros(self.window_size)
121 |
122 | length = len(in_array)
123 | out_array[:length] = in_array
124 | return out_array
125 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from dataset import *
2 | from dataloader import *
3 | from trainer import *
4 | from config import *
5 | from utils import *
6 | from model import BERT4NILM
7 |
8 | import argparse
9 | import torch
10 |
11 |
12 | def train(args, export_root=None, resume=True):
13 | args.validation_size = 0.1
14 | if args.dataset_code == 'redd_lf':
15 | args.house_indicies = [2, 3, 4, 5, 6]
16 | dataset = REDD_LF_Dataset(args)
17 | elif args.dataset_code == 'uk_dale':
18 | args.house_indicies = [1, 3, 4, 5]
19 | dataset = UK_DALE_Dataset(args)
20 |
21 | x_mean, x_std = dataset.get_mean_std()
22 | stats = (x_mean, x_std)
23 |
24 | model = BERT4NILM(args)
25 |
26 | if export_root is None:
27 | folder_name = '-'.join(args.appliance_names)
28 | export_root = 'experiments/' + args.dataset_code + '/' + folder_name
29 |
30 | dataloader = NILMDataloader(args, dataset, bert=True)
31 | train_loader, val_loader = dataloader.get_dataloaders()
32 |
33 | trainer = Trainer(args, model, train_loader,
34 | val_loader, stats, export_root)
35 | if args.num_epochs > 0:
36 | if resume:
37 | try:
38 | model.load_state_dict(torch.load(os.path.join(
39 | export_root, 'best_acc_model.pth'), map_location='cpu'))
40 | print('Successfully loaded previous model, continue training...')
41 | except FileNotFoundError:
42 | print('Failed to load old model, continue training new model...')
43 | trainer.train()
44 |
45 | args.validation_size = 1.
46 | if args.dataset_code == 'redd_lf':
47 | args.house_indicies = [1]
48 | dataset = REDD_LF_Dataset(args, stats)
49 | elif args.dataset_code == 'uk_dale':
50 | args.house_indicies = [2]
51 | dataset = UK_DALE_Dataset(args, stats)
52 |
53 | dataloader = NILMDataloader(args, dataset)
54 | _, test_loader = dataloader.get_dataloaders()
55 | rel_err, abs_err, acc, prec, recall, f1 = trainer.test(test_loader)
56 | print('Mean Accuracy:', acc)
57 | print('Mean F1-Score:', f1)
58 | print('Mean Relative Error:', rel_err)
59 | print('Mean Absolute Error:', abs_err)
60 |
61 |
62 | def fix_random_seed_as(random_seed):
63 | random.seed(random_seed)
64 | torch.manual_seed(random_seed)
65 | torch.cuda.manual_seed_all(random_seed)
66 | np.random.seed(random_seed)
67 |
68 |
69 | torch.set_default_tensor_type(torch.DoubleTensor)
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--seed', type=int, default=12345)
72 | parser.add_argument('--dataset_code', type=str,
73 | default='redd_lf', choices=['redd_lf', 'uk_dale'])
74 | parser.add_argument('--validation_size', type=float, default=0.2)
75 | parser.add_argument('--batch_size', type=int, default=128)
76 | parser.add_argument('--house_indicies', type=list, default=[1, 2, 3, 4, 5])
77 | parser.add_argument('--appliance_names', type=list,
78 | default=['microwave', 'dishwasher'])
79 | parser.add_argument('--sampling', type=str, default='6s')
80 | parser.add_argument('--cutoff', type=dict, default=None)
81 | parser.add_argument('--threshold', type=dict, default=None)
82 | parser.add_argument('--min_on', type=dict, default=None)
83 | parser.add_argument('--min_off', type=dict, default=None)
84 | parser.add_argument('--window_size', type=int, default=480)
85 | parser.add_argument('--window_stride', type=int, default=120)
86 | parser.add_argument('--normalize', type=str, default='mean',
87 | choices=['mean', 'minmax'])
88 | parser.add_argument('--denom', type=int, default=2000)
89 | parser.add_argument('--model_size', type=str, default='gru',
90 | choices=['gru', 'lstm', 'dae'])
91 | parser.add_argument('--output_size', type=int, default=1)
92 | parser.add_argument('--drop_out', type=float, default=0.1)
93 | parser.add_argument('--mask_prob', type=float, default=0.25)
94 | parser.add_argument('--device', type=str, default='cpu',
95 | choices=['cpu', 'cuda'])
96 | parser.add_argument('--optimizer', type=str,
97 | default='adam', choices=['sgd', 'adam', 'adamw'])
98 | parser.add_argument('--lr', type=float, default=1e-4)
99 | parser.add_argument('--weight_decay', type=float, default=0.)
100 | parser.add_argument('--momentum', type=float, default=None)
101 | parser.add_argument('--decay_step', type=int, default=100)
102 | parser.add_argument('--gamma', type=float, default=0.1)
103 | parser.add_argument('--num_epochs', type=int, default=100)
104 | parser.add_argument('--c0', type=dict, default=None)
105 |
106 | args = parser.parse_args()
107 |
108 |
109 | if __name__ == "__main__":
110 | fix_random_seed_as(args.seed)
111 | get_user_input(args)
112 | set_template(args)
113 | train(args)
114 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import confusion_matrix
4 |
5 |
6 | def get_user_input(args):
7 | if torch.cuda.is_available():
8 | args.device = 'cuda:' + input('Input GPU ID: ')
9 | else:
10 | args.device = 'cpu'
11 |
12 | dataset_code = {'r': 'redd_lf', 'u': 'uk_dale'}
13 | args.dataset_code = dataset_code[input(
14 | 'Input r for REDD, u for UK_DALE: ')]
15 |
16 | if args.dataset_code == 'redd_lf':
17 | app_dict = {
18 | 'r': ['refrigerator'],
19 | 'w': ['washer_dryer'],
20 | 'm': ['microwave'],
21 | 'd': ['dishwasher'],
22 | }
23 | args.appliance_names = app_dict[input(
24 | 'Input r, w, m or d for target appliance: ')]
25 |
26 | elif args.dataset_code == 'uk_dale':
27 | app_dict = {
28 | 'k': ['kettle'],
29 | 'f': ['fridge'],
30 | 'w': ['washing_machine'],
31 | 'm': ['microwave'],
32 | 'd': ['dishwasher'],
33 | }
34 | args.appliance_names = app_dict[input(
35 | 'Input k, f, w, m or d for target appliance: ')]
36 |
37 | args.num_epochs = int(input('Input training epochs: '))
38 |
39 |
40 | def set_template(args):
41 | args.output_size = len(args.appliance_names)
42 | if args.dataset_code == 'redd_lf':
43 | args.window_stride = 120
44 | args.house_indicies = [1, 2, 3, 4, 5, 6]
45 |
46 | args.cutoff = {
47 | 'aggregate': 6000,
48 | 'refrigerator': 400,
49 | 'washer_dryer': 3500,
50 | 'microwave': 1800,
51 | 'dishwasher': 1200
52 | }
53 |
54 | args.threshold = {
55 | 'refrigerator': 50,
56 | 'washer_dryer': 20,
57 | 'microwave': 200,
58 | 'dishwasher': 10
59 | }
60 |
61 | args.min_on = {
62 | 'refrigerator': 10,
63 | 'washer_dryer': 300,
64 | 'microwave': 2,
65 | 'dishwasher': 300
66 | }
67 |
68 | args.min_off = {
69 | 'refrigerator': 2,
70 | 'washer_dryer': 26,
71 | 'microwave': 5,
72 | 'dishwasher': 300
73 | }
74 |
75 | args.c0 = {
76 | 'refrigerator': 1e-6,
77 | 'washer_dryer': 0.001,
78 | 'microwave': 1.,
79 | 'dishwasher': 1.
80 | }
81 |
82 | elif args.dataset_code == 'uk_dale':
83 | args.window_stride = 240
84 | args.house_indicies = [1, 2, 3, 4, 5]
85 |
86 | args.cutoff = {
87 | 'aggregate': 6000,
88 | 'kettle': 3100,
89 | 'fridge': 300,
90 | 'washing_machine': 2500,
91 | 'microwave': 3000,
92 | 'dishwasher': 2500
93 | }
94 |
95 | args.threshold = {
96 | 'kettle': 2000,
97 | 'fridge': 50,
98 | 'washing_machine': 20,
99 | 'microwave': 200,
100 | 'dishwasher': 10
101 | }
102 |
103 | args.min_on = {
104 | 'kettle': 2,
105 | 'fridge': 10,
106 | 'washing_machine': 300,
107 | 'microwave': 2,
108 | 'dishwasher': 300
109 | }
110 |
111 | args.min_off = {
112 | 'kettle': 0,
113 | 'fridge': 2,
114 | 'washing_machine': 26,
115 | 'microwave': 5,
116 | 'dishwasher': 300
117 | }
118 |
119 | args.c0 = {
120 | 'kettle': 1.,
121 | 'fridge': 1e-6,
122 | 'washing_machine': 0.01,
123 | 'microwave': 1.,
124 | 'dishwasher': 1.
125 | }
126 |
127 | args.optimizer = 'adam'
128 | args.lr = 1e-4
129 | args.enable_lr_schedule = False
130 | args.batch_size = 128
131 |
132 |
133 | def acc_precision_recall_f1_score(pred, status):
134 | assert pred.shape == status.shape
135 |
136 | pred = pred.reshape(-1, pred.shape[-1])
137 | status = status.reshape(-1, status.shape[-1])
138 | accs, precisions, recalls, f1_scores = [], [], [], []
139 |
140 | for i in range(status.shape[-1]):
141 | tn, fp, fn, tp = confusion_matrix(status[:, i], pred[:, i], labels=[
142 | 0, 1]).ravel()
143 | acc = (tn + tp) / (tn + fp + fn + tp)
144 | precision = tp / np.max((tp + fp, 1e-9))
145 | recall = tp / np.max((tp + fn, 1e-9))
146 | f1_score = 2 * (precision * recall) / \
147 | np.max((precision + recall, 1e-9))
148 |
149 | accs.append(acc)
150 | precisions.append(precision)
151 | recalls.append(recall)
152 | f1_scores.append(f1_score)
153 |
154 | return np.array(accs), np.array(precisions), np.array(recalls), np.array(f1_scores)
155 |
156 |
157 | def relative_absolute_error(pred, label):
158 | assert pred.shape == label.shape
159 |
160 | pred = pred.reshape(-1, pred.shape[-1])
161 | label = label.reshape(-1, label.shape[-1])
162 | temp = np.full(label.shape, 1e-9)
163 | relative, absolute = [], []
164 |
165 | for i in range(label.shape[-1]):
166 | relative_error = np.mean(np.nan_to_num(np.abs(label[:, i] - pred[:, i]) / np.max(
167 | (label[:, i], pred[:, i], temp[:, i]), axis=0)))
168 | absolute_error = np.mean(np.abs(label[:, i] - pred[:, i]))
169 |
170 | relative.append(relative_error)
171 | absolute.append(absolute_error)
172 |
173 | return np.array(relative), np.array(absolute)
174 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class GELU(nn.Module):
8 | def forward(self, x):
9 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
10 |
11 |
12 | class PositionalEmbedding(nn.Module):
13 | def __init__(self, max_len, d_model):
14 | super().__init__()
15 | self.pe = nn.Embedding(max_len, d_model)
16 |
17 | def forward(self, x):
18 | batch_size = x.size(0)
19 | return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)
20 |
21 |
22 | class LayerNorm(nn.Module):
23 | def __init__(self, features, eps=1e-6):
24 | super(LayerNorm, self).__init__()
25 | self.weight = nn.Parameter(torch.ones(features))
26 | self.bias = nn.Parameter(torch.zeros(features))
27 | self.eps = eps
28 |
29 | def forward(self, x):
30 | mean = x.mean(-1, keepdim=True)
31 | std = x.std(-1, keepdim=True)
32 | return self.weight * (x - mean) / (std + self.eps) + self.bias
33 |
34 |
35 | class Attention(nn.Module):
36 | def forward(self, query, key, value, mask=None, dropout=None):
37 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
38 | if mask is not None:
39 | scores = scores.masked_fill(mask == 0, -1e9)
40 |
41 | p_attn = F.softmax(scores, dim=-1)
42 | if dropout is not None:
43 | p_attn = dropout(p_attn)
44 |
45 | return torch.matmul(p_attn, value), p_attn
46 |
47 |
48 | class MultiHeadedAttention(nn.Module):
49 | def __init__(self, h, d_model, dropout=0.1):
50 | super().__init__()
51 | assert d_model % h == 0
52 |
53 | self.d_k = d_model // h
54 | self.h = h
55 |
56 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
57 | self.output_linear = nn.Linear(d_model, d_model)
58 | self.attention = Attention()
59 |
60 | self.dropout = nn.Dropout(p=dropout)
61 |
62 | def forward(self, query, key, value, mask=None):
63 | batch_size = query.size(0)
64 |
65 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
66 | for l, x in zip(self.linear_layers, (query, key, value))]
67 |
68 | x, attn = self.attention(
69 | query, key, value, mask=mask, dropout=self.dropout)
70 |
71 | x = x.transpose(1, 2).contiguous().view(
72 | batch_size, -1, self.h * self.d_k)
73 |
74 | return self.output_linear(x)
75 |
76 |
77 | class PositionwiseFeedForward(nn.Module):
78 | def __init__(self, d_model, d_ff):
79 | super(PositionwiseFeedForward, self).__init__()
80 | self.w_1 = nn.Linear(d_model, d_ff)
81 | self.w_2 = nn.Linear(d_ff, d_model)
82 | self.activation = GELU()
83 |
84 | def forward(self, x):
85 | return self.w_2(self.activation(self.w_1(x)))
86 |
87 |
88 | class SublayerConnection(nn.Module):
89 | def __init__(self, size, dropout):
90 | super(SublayerConnection, self).__init__()
91 | self.layer_norm = LayerNorm(size)
92 | self.dropout = nn.Dropout(dropout)
93 |
94 | def forward(self, x, sublayer):
95 | return self.layer_norm(x + self.dropout(sublayer(x)))
96 |
97 |
98 | class TransformerBlock(nn.Module):
99 | def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
100 | super().__init__()
101 | self.attention = MultiHeadedAttention(
102 | h=attn_heads, d_model=hidden, dropout=dropout)
103 | self.feed_forward = PositionwiseFeedForward(
104 | d_model=hidden, d_ff=feed_forward_hidden)
105 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
106 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
107 | self.dropout = nn.Dropout(p=dropout)
108 |
109 | def forward(self, x, mask):
110 | x = self.input_sublayer(
111 | x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
112 | x = self.output_sublayer(x, self.feed_forward)
113 | return self.dropout(x)
114 |
115 |
116 | class BERT4NILM(nn.Module):
117 | def __init__(self, args):
118 | super().__init__()
119 | self.args = args
120 |
121 | self.original_len = args.window_size
122 | self.latent_len = int(self.original_len / 2)
123 | self.dropout_rate = args.drop_out
124 |
125 | self.hidden = 256
126 | self.heads = 2
127 | self.n_layers = 2
128 | self.output_size = args.output_size
129 |
130 | self.conv = nn.Conv1d(in_channels=1, out_channels=self.hidden,
131 | kernel_size=5, stride=1, padding=2, padding_mode='replicate')
132 | self.pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2)
133 |
134 | self.position = PositionalEmbedding(
135 | max_len=self.latent_len, d_model=self.hidden)
136 | self.layer_norm = LayerNorm(self.hidden)
137 | self.dropout = nn.Dropout(p=self.dropout_rate)
138 |
139 | self.transformer_blocks = nn.ModuleList([TransformerBlock(
140 | self.hidden, self.heads, self.hidden * 4, self.dropout_rate) for _ in range(self.n_layers)])
141 |
142 | self.deconv = nn.ConvTranspose1d(
143 | in_channels=self.hidden, out_channels=self.hidden, kernel_size=4, stride=2, padding=1)
144 | self.linear1 = nn.Linear(self.hidden, 128)
145 | self.linear2 = nn.Linear(128, self.output_size)
146 |
147 | self.truncated_normal_init()
148 |
149 | def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
150 | params = list(self.named_parameters())
151 | for n, p in params:
152 | if 'layer_norm' in n:
153 | continue
154 | else:
155 | with torch.no_grad():
156 | l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
157 | u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.
158 | p.uniform_(2 * l - 1, 2 * u - 1)
159 | p.erfinv_()
160 | p.mul_(std * math.sqrt(2.))
161 | p.add_(mean)
162 |
163 | def forward(self, sequence):
164 | x_token = self.pool(self.conv(sequence.unsqueeze(1))).permute(0, 2, 1)
165 | embedding = x_token + self.position(sequence)
166 | x = self.dropout(self.layer_norm(embedding))
167 |
168 | mask = None
169 | for transformer in self.transformer_blocks:
170 | x = transformer.forward(x, mask)
171 |
172 | x = self.deconv(x.permute(0, 2, 1)).permute(0, 2, 1)
173 | x = torch.tanh(self.linear1(x))
174 | x = self.linear2(x)
175 | return x
176 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from abc import *
2 | from config import *
3 | from dataloader import *
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from pathlib import Path
8 | from collections import defaultdict
9 | import torch.utils.data as data_utils
10 |
11 |
12 | class AbstractDataset(metaclass=ABCMeta):
13 | def __init__(self, args, stats=None):
14 | self.house_indicies = args.house_indicies
15 | self.appliance_names = args.appliance_names
16 | self.normalize = args.normalize
17 | self.sampling = args.sampling
18 | self.cutoff = [args.cutoff[i]
19 | for i in ['aggregate'] + self.appliance_names]
20 |
21 | self.threshold = [args.threshold[i] for i in self.appliance_names]
22 | self.min_on = [args.min_on[i] for i in self.appliance_names]
23 | self.min_off = [args.min_off[i] for i in self.appliance_names]
24 |
25 | self.val_size = args.validation_size
26 | self.window_size = args.window_size
27 | self.window_stride = args.window_stride
28 |
29 | self.x, self.y = self.load_data()
30 | self.status = self.compute_status(self.y)
31 | print('Appliance:', self.appliance_names)
32 | print('Sum of ons:', np.sum(self.status, axis=0))
33 | print('Total length:', self.status.shape[0])
34 |
35 | if stats is None:
36 | self.x_mean = np.mean(self.x, axis=0)
37 | self.x_std = np.std(self.x, axis=0)
38 | else:
39 | self.x_mean, self.x_std = stats
40 |
41 | self.x = (self.x - self.x_mean) / self.x_std
42 |
43 | @classmethod
44 | @abstractmethod
45 | def code(cls):
46 | pass
47 |
48 | @classmethod
49 | def raw_code(cls):
50 | return cls.code()
51 |
52 | @abstractmethod
53 | def load_data(self):
54 | pass
55 |
56 | def get_data(self):
57 | return self.x, self.y, self.status
58 |
59 | def get_original_data(self):
60 | x_org = self.x * self.x_std + self.x_mean
61 | return x_org, self.y, self.status
62 |
63 | def get_mean_std(self):
64 | return self.x_mean, self.x_std
65 |
66 | def compute_status(self, data):
67 | status = np.zeros(data.shape)
68 | if len(data.squeeze().shape) == 1:
69 | columns = 1
70 | else:
71 | columns = data.squeeze().shape[-1]
72 |
73 | if not self.threshold:
74 | self.threshold = [10 for i in range(columns)]
75 | if not self.min_on:
76 | self.min_on = [1 for i in range(columns)]
77 | if not self.min_off:
78 | self.min_off = [1 for i in range(columns)]
79 |
80 | for i in range(columns):
81 | initial_status = data[:, i] >= self.threshold[i]
82 | status_diff = np.diff(initial_status)
83 | events_idx = status_diff.nonzero()
84 |
85 | events_idx = np.array(events_idx).squeeze()
86 | events_idx += 1
87 |
88 | if initial_status[0]:
89 | events_idx = np.insert(events_idx, 0, 0)
90 |
91 | if initial_status[-1]:
92 | events_idx = np.insert(
93 | events_idx, events_idx.size, initial_status.size)
94 |
95 | events_idx = events_idx.reshape((-1, 2))
96 | on_events = events_idx[:, 0].copy()
97 | off_events = events_idx[:, 1].copy()
98 | assert len(on_events) == len(off_events)
99 |
100 | if len(on_events) > 0:
101 | off_duration = on_events[1:] - off_events[:-1]
102 | off_duration = np.insert(off_duration, 0, 1000)
103 | on_events = on_events[off_duration > self.min_off[i]]
104 | off_events = off_events[np.roll(
105 | off_duration, -1) > self.min_off[i]]
106 |
107 | on_duration = off_events - on_events
108 | on_events = on_events[on_duration >= self.min_on[i]]
109 | off_events = off_events[on_duration >= self.min_on[i]]
110 | assert len(on_events) == len(off_events)
111 |
112 | temp_status = data[:, i].copy()
113 | temp_status[:] = 0
114 | for on, off in zip(on_events, off_events):
115 | temp_status[on: off] = 1
116 | status[:, i] = temp_status
117 |
118 | return status
119 |
120 | def get_status(self):
121 | return self.status
122 |
123 | def get_datasets(self):
124 | val_end = int(self.val_size * len(self.x))
125 | val = NILMDataset(self.x[:val_end], self.y[:val_end], self.status[:val_end],
126 | self.window_size, self.window_size)
127 | train = NILMDataset(self.x[val_end:], self.y[val_end:], self.status[val_end:],
128 | self.window_size, self.window_stride)
129 | return train, val
130 |
131 | def get_bert_datasets(self, mask_prob=0.25):
132 | val_end = int(self.val_size * len(self.x))
133 | val = NILMDataset(self.x[:val_end], self.y[:val_end], self.status[:val_end],
134 | self.window_size, self.window_size)
135 | train = BERTDataset(self.x[val_end:], self.y[val_end:], self.status[val_end:],
136 | self.window_size, self.window_stride, mask_prob=mask_prob)
137 | return train, val
138 |
139 | def _get_rawdata_root_path(self):
140 | return Path(RAW_DATASET_ROOT_FOLDER)
141 |
142 | def _get_folder_path(self):
143 | root = self._get_rawdata_root_path()
144 | return root.joinpath(self.raw_code())
145 |
146 |
147 | class REDD_LF_Dataset(AbstractDataset):
148 | @classmethod
149 | def code(cls):
150 | return 'redd_lf'
151 |
152 | @classmethod
153 | def _if_data_exists(cls):
154 | folder = Path(RAW_DATASET_ROOT_FOLDER).joinpath(cls.code())
155 | first_file = folder.joinpath('house_1', 'channel_1.dat')
156 | if first_file.is_file():
157 | return True
158 | return False
159 |
160 | def load_data(self):
161 | for appliance in self.appliance_names:
162 | assert appliance in ['dishwasher',
163 | 'refrigerator', 'microwave', 'washer_dryer']
164 |
165 | for house_id in self.house_indicies:
166 | assert house_id in [1, 2, 3, 4, 5, 6]
167 |
168 | if not self.cutoff:
169 | self.cutoff = [6000] * (len(self.appliance_names) + 1)
170 |
171 | if not self._if_data_exists():
172 | print('Please download, unzip and move data into',
173 | self._get_folder_path())
174 | raise FileNotFoundError
175 |
176 | else:
177 | directory = self._get_folder_path()
178 |
179 | for house_id in self.house_indicies:
180 | house_folder = directory.joinpath('house_' + str(house_id))
181 | house_label = pd.read_csv(house_folder.joinpath(
182 | 'labels.dat'), sep=' ', header=None)
183 |
184 | main_1 = pd.read_csv(house_folder.joinpath(
185 | 'channel_1.dat'), sep=' ', header=None)
186 | main_2 = pd.read_csv(house_folder.joinpath(
187 | 'channel_2.dat'), sep=' ', header=None)
188 | house_data = pd.merge(main_1, main_2, how='inner', on=0)
189 | house_data.iloc[:, 1] = house_data.iloc[:,
190 | 1] + house_data.iloc[:, 2]
191 | house_data = house_data.iloc[:, 0: 2]
192 |
193 | appliance_list = house_label.iloc[:, 1].values
194 | app_index_dict = defaultdict(list)
195 |
196 | for appliance in self.appliance_names:
197 | data_found = False
198 | for i in range(len(appliance_list)):
199 | if appliance_list[i] == appliance:
200 | app_index_dict[appliance].append(i + 1)
201 | data_found = True
202 |
203 | if not data_found:
204 | app_index_dict[appliance].append(-1)
205 |
206 | if np.sum(list(app_index_dict.values())) == -len(self.appliance_names):
207 | self.house_indicies.remove(house_id)
208 | continue
209 |
210 | for appliance in self.appliance_names:
211 | if app_index_dict[appliance][0] == -1:
212 | temp_values = house_data.copy().iloc[:, 1]
213 | temp_values[:] = 0
214 | temp_data = house_data.copy().iloc[:, :2]
215 | temp_data.iloc[:, 1] = temp_values
216 | else:
217 | temp_data = pd.read_csv(house_folder.joinpath(
218 | 'channel_' + str(app_index_dict[appliance][0]) + '.dat'), sep=' ', header=None)
219 |
220 | if len(app_index_dict[appliance]) > 1:
221 | for idx in app_index_dict[appliance][1:]:
222 | temp_data_ = pd.read_csv(house_folder.joinpath(
223 | 'channel_' + str(idx) + '.dat'), sep=' ', header=None)
224 | temp_data = pd.merge(
225 | temp_data, temp_data_, how='inner', on=0)
226 | temp_data.iloc[:, 1] = temp_data.iloc[:,
227 | 1] + temp_data.iloc[:, 2]
228 | temp_data = temp_data.iloc[:, 0: 2]
229 |
230 | house_data = pd.merge(
231 | house_data, temp_data, how='inner', on=0)
232 |
233 | house_data.iloc[:, 0] = pd.to_datetime(
234 | house_data.iloc[:, 0], unit='s')
235 | house_data.columns = ['time', 'aggregate'] + \
236 | [i for i in self.appliance_names]
237 | house_data = house_data.set_index('time')
238 | house_data = house_data.resample(self.sampling).mean().fillna(
239 | method='ffill', limit=30)
240 |
241 | if house_id == self.house_indicies[0]:
242 | entire_data = house_data
243 | else:
244 | entire_data = entire_data.append(
245 | house_data, ignore_index=True)
246 |
247 | entire_data = entire_data.dropna().copy()
248 | entire_data = entire_data[entire_data['aggregate'] > 0]
249 | entire_data[entire_data < 5] = 0
250 | entire_data = entire_data.clip(
251 | [0] * len(entire_data.columns), self.cutoff, axis=1)
252 |
253 | return entire_data.values[:, 0], entire_data.values[:, 1:]
254 |
255 |
256 | class UK_DALE_Dataset(AbstractDataset):
257 | @classmethod
258 | def code(cls):
259 | return 'uk_dale'
260 |
261 | @classmethod
262 | def _if_data_exists(cls):
263 | folder = Path(RAW_DATASET_ROOT_FOLDER).joinpath(cls.code())
264 | first_file = folder.joinpath('house_1', 'channel_1.dat')
265 | if first_file.is_file():
266 | return True
267 | return False
268 |
269 | def load_data(self):
270 | for appliance in self.appliance_names:
271 | assert appliance in ['dishwasher', 'fridge',
272 | 'microwave', 'washing_machine', 'kettle']
273 |
274 | for house_id in self.house_indicies:
275 | assert house_id in [1, 2, 3, 4, 5]
276 |
277 | if not self.cutoff:
278 | self.cutoff = [6000] * (len(self.appliance_names) + 1)
279 |
280 | if not self._if_data_exists():
281 | print('Please download, unzip and move data into',
282 | self._get_folder_path())
283 | raise FileNotFoundError
284 |
285 | else:
286 | directory = self._get_folder_path()
287 |
288 | for house_id in self.house_indicies:
289 | house_folder = directory.joinpath('house_' + str(house_id))
290 | house_label = pd.read_csv(house_folder.joinpath(
291 | 'labels.dat'), sep=' ', header=None)
292 |
293 | house_data = pd.read_csv(house_folder.joinpath(
294 | 'channel_1.dat'), sep=' ', header=None)
295 | house_data.iloc[:, 0] = pd.to_datetime(
296 | house_data.iloc[:, 0], unit='s')
297 | house_data.columns = ['time', 'aggregate']
298 | house_data = house_data.set_index('time')
299 | house_data = house_data.resample(self.sampling).mean().fillna(
300 | method='ffill', limit=30)
301 |
302 | appliance_list = house_label.iloc[:, 1].values
303 | app_index_dict = defaultdict(list)
304 |
305 | for appliance in self.appliance_names:
306 | data_found = False
307 | for i in range(len(appliance_list)):
308 | if appliance_list[i] == appliance:
309 | app_index_dict[appliance].append(i + 1)
310 | data_found = True
311 |
312 | if not data_found:
313 | app_index_dict[appliance].append(-1)
314 |
315 | if np.sum(list(app_index_dict.values())) == -len(self.appliance_names):
316 | self.house_indicies.remove(house_id)
317 | continue
318 |
319 | for appliance in self.appliance_names:
320 | if app_index_dict[appliance][0] == -1:
321 | house_data.insert(len(house_data.columns), appliance, np.zeros(len(house_data)))
322 | else:
323 | temp_data = pd.read_csv(house_folder.joinpath(
324 | 'channel_' + str(app_index_dict[appliance][0]) + '.dat'), sep=' ', header=None)
325 | temp_data.iloc[:, 0] = pd.to_datetime(
326 | temp_data.iloc[:, 0], unit='s')
327 | temp_data.columns = ['time', appliance]
328 | temp_data = temp_data.set_index('time')
329 | temp_data = temp_data.resample(self.sampling).mean().fillna(
330 | method='ffill', limit=30)
331 | house_data = pd.merge(
332 | house_data, temp_data, how='inner', on='time')
333 |
334 | if house_id == self.house_indicies[0]:
335 | entire_data = house_data
336 | if len(self.house_indicies) == 1:
337 | entire_data = entire_data.reset_index(drop=True)
338 | else:
339 | entire_data = entire_data.append(
340 | house_data, ignore_index=True)
341 |
342 | entire_data = entire_data.dropna().copy()
343 | entire_data = entire_data[entire_data['aggregate'] > 0]
344 | entire_data[entire_data < 5] = 0
345 | entire_data = entire_data.clip(
346 | [0] * len(entire_data.columns), self.cutoff, axis=1)
347 |
348 | return entire_data.values[:, 0], entire_data.values[:, 1:]
349 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 | from tqdm import tqdm
7 |
8 | import os
9 | import json
10 | import random
11 | import numpy as np
12 | from abc import *
13 | from pathlib import Path
14 |
15 | from utils import *
16 | import matplotlib.pyplot as plt
17 |
18 |
19 | torch.set_default_tensor_type(torch.DoubleTensor)
20 |
21 |
22 | class Trainer(metaclass=ABCMeta):
23 | def __init__(self, args, model, train_loader, val_loader, stats, export_root):
24 | self.args = args
25 | self.device = args.device
26 | self.num_epochs = args.num_epochs
27 | self.model = model.to(self.device)
28 | self.export_root = Path(export_root)
29 |
30 | self.cutoff = torch.tensor([args.cutoff[i]
31 | for i in args.appliance_names]).to(self.device)
32 | self.threshold = torch.tensor(
33 | [args.threshold[i] for i in args.appliance_names]).to(self.device)
34 | self.min_on = torch.tensor([args.min_on[i]
35 | for i in args.appliance_names]).to(self.device)
36 | self.min_off = torch.tensor(
37 | [args.min_off[i] for i in args.appliance_names]).to(self.device)
38 |
39 | self.normalize = args.normalize
40 | self.denom = args.denom
41 | if self.normalize == 'mean':
42 | self.x_mean, self.x_std = stats
43 | self.x_mean = torch.tensor(self.x_mean).to(self.device)
44 | self.x_std = torch.tensor(self.x_std).to(self.device)
45 |
46 | self.train_loader = train_loader
47 | self.val_loader = val_loader
48 |
49 | self.optimizer = self._create_optimizer()
50 | if args.enable_lr_schedule:
51 | self.lr_scheduler = optim.lr_scheduler.StepLR(
52 | self.optimizer, step_size=args.decay_step, gamma=args.gamma)
53 |
54 | self.C0 = torch.tensor(args.c0[args.appliance_names[0]]).to(self.device)
55 | print('C0: {}'.format(self.C0))
56 | self.kl = nn.KLDivLoss(reduction='batchmean')
57 | self.mse = nn.MSELoss()
58 | self.margin = nn.SoftMarginLoss()
59 | self.l1_on = nn.L1Loss(reduction='sum')
60 |
61 | def train(self):
62 | val_rel_err, val_abs_err = [], []
63 | val_acc, val_precision, val_recall, val_f1 = [], [], [], []
64 |
65 | best_rel_err, _, best_acc, _, _, best_f1 = self.validate()
66 | self._save_state_dict()
67 |
68 | for epoch in range(self.num_epochs):
69 | self.train_bert_one_epoch(epoch + 1)
70 |
71 | rel_err, abs_err, acc, precision, recall, f1 = self.validate()
72 | val_rel_err.append(rel_err.tolist())
73 | val_abs_err.append(abs_err.tolist())
74 | val_acc.append(acc.tolist())
75 | val_precision.append(precision.tolist())
76 | val_recall.append(recall.tolist())
77 | val_f1.append(f1.tolist())
78 |
79 | if f1.mean() + acc.mean() - rel_err.mean() > best_f1.mean() + best_acc.mean() - best_rel_err.mean():
80 | best_f1 = f1
81 | best_acc = acc
82 | best_rel_err = rel_err
83 | self._save_state_dict()
84 |
85 | def train_one_epoch(self, epoch):
86 | loss_values = []
87 | self.model.train()
88 | tqdm_dataloader = tqdm(self.train_loader)
89 | for batch_idx, batch in enumerate(tqdm_dataloader):
90 | seqs, labels_energy, status = batch
91 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device)
92 | self.optimizer.zero_grad()
93 | logits = self.model(seqs)
94 | labels = labels_energy / self.cutoff
95 | logits_energy = self.cutoff_energy(logits * self.cutoff)
96 | logits_status = self.compute_status(logits_energy)
97 |
98 | kl_loss = self.kl(torch.log(F.softmax(logits.squeeze() / 0.1, dim=-1) + 1e-9), F.softmax(labels.squeeze() / 0.1, dim=-1))
99 | mse_loss = self.mse(logits.contiguous().view(-1).double(),
100 | labels.contiguous().view(-1).double())
101 | margin_loss = self.margin((logits_status * 2 - 1).contiguous().view(-1).double(),
102 | (status * 2 - 1).contiguous().view(-1).double())
103 | total_loss = kl_loss + mse_loss + margin_loss
104 |
105 | on_mask = ((status == 1) + (status != logits_status.reshape(status.shape))) >= 1
106 | if on_mask.sum() > 0:
107 | total_size = torch.tensor(on_mask.shape).prod()
108 | logits_on = torch.masked_select(logits.reshape(on_mask.shape), on_mask)
109 | labels_on = torch.masked_select(labels.reshape(on_mask.shape), on_mask)
110 | loss_l1_on = self.l1_on(logits_on.contiguous().view(-1),
111 | labels_on.contiguous().view(-1))
112 | total_loss += self.C0 * loss_l1_on / total_size
113 |
114 | total_loss.backward()
115 | self.optimizer.step()
116 | loss_values.append(total_loss.item())
117 |
118 | average_loss = np.mean(np.array(loss_values))
119 | tqdm_dataloader.set_description('Epoch {}, loss {:.2f}'.format(epoch, average_loss))
120 |
121 | if self.args.enable_lr_schedule:
122 | self.lr_scheduler.step()
123 |
124 | def train_bert_one_epoch(self, epoch):
125 | loss_values = []
126 | self.model.train()
127 | tqdm_dataloader = tqdm(self.train_loader)
128 | for batch_idx, batch in enumerate(tqdm_dataloader):
129 | seqs, labels_energy, status = batch
130 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device)
131 | batch_shape = status.shape
132 | self.optimizer.zero_grad()
133 | logits = self.model(seqs)
134 | labels = labels_energy / self.cutoff
135 | logits_energy = self.cutoff_energy(logits * self.cutoff)
136 | logits_status = self.compute_status(logits_energy)
137 |
138 | mask = (status >= 0)
139 | labels_masked = torch.masked_select(labels, mask).view((-1, batch_shape[-1]))
140 | logits_masked = torch.masked_select(logits, mask).view((-1, batch_shape[-1]))
141 | status_masked = torch.masked_select(status, mask).view((-1, batch_shape[-1]))
142 | logits_status_masked = torch.masked_select(logits_status, mask).view((-1, batch_shape[-1]))
143 |
144 | kl_loss = self.kl(torch.log(F.softmax(logits_masked.squeeze() / 0.1, dim=-1) + 1e-9), F.softmax(labels_masked.squeeze() / 0.1, dim=-1))
145 | mse_loss = self.mse(logits_masked.contiguous().view(-1).double(),
146 | labels_masked.contiguous().view(-1).double())
147 | margin_loss = self.margin((logits_status_masked * 2 - 1).contiguous().view(-1).double(),
148 | (status_masked * 2 - 1).contiguous().view(-1).double())
149 | total_loss = kl_loss + mse_loss + margin_loss
150 |
151 | on_mask = (status >= 0) * (((status == 1) + (status != logits_status.reshape(status.shape))) >= 1)
152 | if on_mask.sum() > 0:
153 | total_size = torch.tensor(on_mask.shape).prod()
154 | logits_on = torch.masked_select(logits.reshape(on_mask.shape), on_mask)
155 | labels_on = torch.masked_select(labels.reshape(on_mask.shape), on_mask)
156 | loss_l1_on = self.l1_on(logits_on.contiguous().view(-1),
157 | labels_on.contiguous().view(-1))
158 | total_loss += self.C0 * loss_l1_on / total_size
159 |
160 | total_loss.backward()
161 | self.optimizer.step()
162 | loss_values.append(total_loss.item())
163 |
164 | average_loss = np.mean(np.array(loss_values))
165 | tqdm_dataloader.set_description('Epoch {}, loss {:.2f}'.format(epoch, average_loss))
166 |
167 | if self.args.enable_lr_schedule:
168 | self.lr_scheduler.step()
169 |
170 | def validate(self):
171 | self.model.eval()
172 | loss_values, relative_errors, absolute_errors = [], [], []
173 | acc_values, precision_values, recall_values, f1_values, = [], [], [], []
174 |
175 | with torch.no_grad():
176 | tqdm_dataloader = tqdm(self.val_loader)
177 | for batch_idx, batch in enumerate(tqdm_dataloader):
178 | seqs, labels_energy, status = batch
179 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device)
180 | logits = self.model(seqs)
181 | labels = labels_energy / self.cutoff
182 | logits_energy = self.cutoff_energy(logits * self.cutoff)
183 | logits_status = self.compute_status(logits_energy)
184 | logits_energy = logits_energy * logits_status
185 |
186 | rel_err, abs_err = relative_absolute_error(logits_energy.detach(
187 | ).cpu().numpy().squeeze(), labels_energy.detach().cpu().numpy().squeeze())
188 | relative_errors.append(rel_err.tolist())
189 | absolute_errors.append(abs_err.tolist())
190 |
191 | acc, precision, recall, f1 = acc_precision_recall_f1_score(logits_status.detach(
192 | ).cpu().numpy().squeeze(), status.detach().cpu().numpy().squeeze())
193 | acc_values.append(acc.tolist())
194 | precision_values.append(precision.tolist())
195 | recall_values.append(recall.tolist())
196 | f1_values.append(f1.tolist())
197 |
198 | average_acc = np.mean(np.array(acc_values).reshape(-1))
199 | average_f1 = np.mean(np.array(f1_values).reshape(-1))
200 | average_rel_err = np.mean(np.array(relative_errors).reshape(-1))
201 |
202 | tqdm_dataloader.set_description('Validation, rel_err {:.2f}, acc {:.2f}, f1 {:.2f}'.format(
203 | average_rel_err, average_acc, average_f1))
204 |
205 | return_rel_err = np.array(relative_errors).mean(axis=0)
206 | return_abs_err = np.array(absolute_errors).mean(axis=0)
207 | return_acc = np.array(acc_values).mean(axis=0)
208 | return_precision = np.array(precision_values).mean(axis=0)
209 | return_recall = np.array(recall_values).mean(axis=0)
210 | return_f1 = np.array(f1_values).mean(axis=0)
211 | return return_rel_err, return_abs_err, return_acc, return_precision, return_recall, return_f1
212 |
213 | def test(self, test_loader):
214 | self._load_best_model()
215 | self.model.eval()
216 | loss_values, relative_errors, absolute_errors = [], [], []
217 | acc_values, precision_values, recall_values, f1_values, = [], [], [], []
218 |
219 | label_curve = []
220 | e_pred_curve = []
221 | status_curve = []
222 | s_pred_curve = []
223 | with torch.no_grad():
224 | tqdm_dataloader = tqdm(test_loader)
225 | for batch_idx, batch in enumerate(tqdm_dataloader):
226 | seqs, labels_energy, status = batch
227 | seqs, labels_energy, status = seqs.to(self.device), labels_energy.to(self.device), status.to(self.device)
228 | logits = self.model(seqs)
229 | labels = labels_energy / self.cutoff
230 | logits_energy = self.cutoff_energy(logits * self.cutoff)
231 | logits_status = self.compute_status(logits_energy)
232 | logits_energy = logits_energy * logits_status
233 |
234 | acc, precision, recall, f1 = acc_precision_recall_f1_score(logits_status.detach(
235 | ).cpu().numpy().squeeze(), status.detach().cpu().numpy().squeeze())
236 | acc_values.append(acc.tolist())
237 | precision_values.append(precision.tolist())
238 | recall_values.append(recall.tolist())
239 | f1_values.append(f1.tolist())
240 |
241 | rel_err, abs_err = relative_absolute_error(logits_energy.detach(
242 | ).cpu().numpy().squeeze(), labels_energy.detach().cpu().numpy().squeeze())
243 | relative_errors.append(rel_err.tolist())
244 | absolute_errors.append(abs_err.tolist())
245 |
246 | average_acc = np.mean(np.array(acc_values).reshape(-1))
247 | average_f1 = np.mean(np.array(f1_values).reshape(-1))
248 | average_rel_err = np.mean(np.array(relative_errors).reshape(-1))
249 |
250 | tqdm_dataloader.set_description('Test, rel_err {:.2f}, acc {:.2f}, f1 {:.2f}'.format(
251 | average_rel_err, average_acc, average_f1))
252 |
253 | label_curve.append(labels_energy.detach().cpu().numpy().tolist())
254 | e_pred_curve.append(logits_energy.detach().cpu().numpy().tolist())
255 | status_curve.append(status.detach().cpu().numpy().tolist())
256 | s_pred_curve.append(logits_status.detach().cpu().numpy().tolist())
257 |
258 | label_curve = np.concatenate(label_curve).reshape(-1, self.args.output_size)
259 | e_pred_curve = np.concatenate(e_pred_curve).reshape(-1, self.args.output_size)
260 | status_curve = np.concatenate(status_curve).reshape(-1, self.args.output_size)
261 | s_pred_curve = np.concatenate(s_pred_curve).reshape(-1, self.args.output_size)
262 |
263 | self._save_result({'gt': label_curve.tolist(),
264 | 'pred': e_pred_curve.tolist()}, 'test_result.json')
265 |
266 | if self.args.output_size > 1:
267 | return_rel_err = np.array(relative_errors).mean(axis=0)
268 | else:
269 | return_rel_err = np.array(relative_errors).mean()
270 | return_rel_err, return_abs_err = relative_absolute_error(e_pred_curve, label_curve)
271 | return_acc, return_precision, return_recall, return_f1 = acc_precision_recall_f1_score(s_pred_curve, status_curve)
272 |
273 | return return_rel_err, return_abs_err, return_acc, return_precision, return_recall, return_f1
274 |
275 | def cutoff_energy(self, data):
276 | columns = data.squeeze().shape[-1]
277 |
278 | if self.cutoff.size(0) == 0:
279 | self.cutoff = torch.tensor(
280 | [3100 for i in range(columns)]).to(self.device)
281 |
282 | data[data < 5] = 0
283 | data = torch.min(data, self.cutoff.double())
284 | return data
285 |
286 | def compute_status(self, data):
287 | data_shape = data.shape
288 | columns = data.squeeze().shape[-1]
289 |
290 | if self.threshold.size(0) == 0:
291 | self.threshold = torch.tensor(
292 | [10 for i in range(columns)]).to(self.device)
293 |
294 | status = (data >= self.threshold) * 1
295 | return status
296 |
297 | def _create_optimizer(self):
298 | args = self.args
299 | param_optimizer = list(self.model.named_parameters())
300 | no_decay = ['bias', 'layer_norm']
301 | optimizer_grouped_parameters = [
302 | {
303 | 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
304 | 'weight_decay': args.weight_decay,
305 | },
306 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
307 | ]
308 | if args.optimizer.lower() == 'adamw':
309 | return optim.AdamW(optimizer_grouped_parameters, lr=args.lr)
310 | elif args.optimizer.lower() == 'adam':
311 | return optim.Adam(optimizer_grouped_parameters, lr=args.lr)
312 | elif args.optimizer.lower() == 'sgd':
313 | return optim.SGD(optimizer_grouped_parameters, lr=args.lr, momentum=args.momentum)
314 | else:
315 | raise ValueError
316 |
317 | def _load_best_model(self):
318 | try:
319 | self.model.load_state_dict(torch.load(
320 | self.export_root.joinpath('best_acc_model.pth')))
321 | self.model.to(self.device)
322 | except:
323 | print('Failed to load best model, continue testing with current model...')
324 |
325 | def _save_state_dict(self):
326 | if not os.path.exists(self.export_root):
327 | os.makedirs(self.export_root)
328 | print('Saving best model...')
329 | torch.save(self.model.state_dict(),
330 | self.export_root.joinpath('best_acc_model.pth'))
331 |
332 | def _save_values(self, filename):
333 | if not os.path.exists(self.export_root):
334 | os.makedirs(self.export_root)
335 | torch.save(self.model.state_dict(),
336 | self.export_root.joinpath('best_acc_model.pth'))
337 |
338 | def _save_result(self, data, filename):
339 | if not os.path.exists(self.export_root):
340 | os.makedirs(self.export_root)
341 | filepath = Path(self.export_root).joinpath(filename)
342 | with filepath.open('w') as f:
343 | json.dump(data, f, indent=2)
344 |
--------------------------------------------------------------------------------