├── .gitignore
├── README.md
├── utils.py
├── dataloader.py
├── train.py
└── CTLSTM.py

/.gitignore:
--------------------------------------------------------------------------------
data/
model/
.vscode/
__pycache__/
*.txt
*.sh
log/
error/
.env/
.DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A simple PyTorch implementation of the neural Hawkes process

This repository is a simple PyTorch implementation of the neural Hawkes process from the paper *[The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328)*.

## Data format
The data can be obtained from the author's [GitHub page](https://github.com/HMEIatJHU/neurawkes). Download the pickle files into a local `data/` folder.

Here is an example of the training set:
```python
{
    'train': [
        [{'type_event': 'A',
          'time_since_last_event': 1.0},   # First event
         {'type_event': 'B',
          'time_since_last_event': 1.2},
         {'type_event': 'C',
          'time_since_last_event': 2.0}, ],  # First sequence

        [{'type_event': 'B',
          'time_since_last_event': 1.1},
         {'type_event': 'A',
          'time_since_last_event': 0.6},
         {'type_event': 'C',
          'time_since_last_event': 2.3}, ],  # Second sequence
    ],

    'dev': [],   # Only non-empty in the development dataset
    'test': [],  # Only non-empty in the test dataset
}
```
In the actual pickle files, `type_event` is an integer index, which is what `dataloader.py` expects.

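For a quick sanity check, the pickle can be loaded directly. This is a minimal sketch; it assumes the file has been downloaded to `data/train.pkl`:
```python
import pickle

with open('data/train.pkl', 'rb') as f:
    data = pickle.load(f, encoding='latin1')

first_seq = data['train'][0]
print(len(data['train']))                     # number of training sequences
print(first_seq[0]['type_event'])             # integer type of the first event
print(first_seq[0]['time_since_last_event'])  # time gap to the previous event
```
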
## How to Run
For a quick run:
```bash
python train.py
```
Hyperparameters and data paths are configured in the `settings` dictionary at the bottom of `train.py`.

Please use PyTorch version 1.0 or later.


## What's more
This repository is meant as an aid to understanding the neural Hawkes process, so not all experiments in the paper are implemented and tested. Feel free to take the code and extend it yourself.

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""Utility functions for the CTLSTM model."""

import torch


def generate_sim_time_seqs(time_seqs, seqs_length):
    """Generate simulated time-interval sequences by sampling times uniformly over each sequence's time span.

    Args:
        time_seqs: padded FloatTensor of inter-event times, shape (batch, max_len)
        seqs_length: LongTensor of true sequence lengths, shape (batch,)
    Returns:
        sim_time_seqs: FloatTensor, shape (batch, max_len - 1); for each simulated
            time, the interval since the most recent preceding event
        sim_index_seqs: LongTensor, shape (batch, max_len - 1); the index of that
            preceding event
    """
    sim_time_seqs = torch.zeros((time_seqs.size()[0], time_seqs.size()[1]-1)).float()
    sim_index_seqs = torch.zeros((time_seqs.size()[0], time_seqs.size()[1]-1)).long()
    restore_time_seqs, restore_sim_time_seqs = [], []
    for idx, time_seq in enumerate(time_seqs):
        restore_time_seq = torch.stack([torch.sum(time_seq[0:i]) for i in range(1, seqs_length[idx]+1)])
        restore_sim_time_seq, _ = torch.sort(torch.empty(seqs_length[idx]-1).uniform_(0, restore_time_seq[-1]))

        sim_time_seq = torch.zeros(seqs_length[idx]-1)
        sim_index_seq = torch.zeros(seqs_length[idx]-1).long()

        for idx_t, t in enumerate(restore_time_seq):
            indices_to_update = restore_sim_time_seq > t

            sim_time_seq[indices_to_update] = restore_sim_time_seq[indices_to_update] - t
            sim_index_seq[indices_to_update] = idx_t

        restore_time_seqs.append(restore_time_seq)
        restore_sim_time_seqs.append(restore_sim_time_seq)
        sim_time_seqs[idx, :seqs_length[idx]-1] = sim_time_seq
        sim_index_seqs[idx, :seqs_length[idx]-1] = sim_index_seq

    return sim_time_seqs, sim_index_seqs


def pad_bos(batch_data, type_size):
    """Prepend a BOS event (index == type_size) with a zero inter-event time to each padded sequence."""
    event_seqs, time_seqs, total_time_seqs, seqs_length = batch_data
    pad_event_seqs = torch.zeros((event_seqs.size()[0], event_seqs.size()[1]+1)).long()
    pad_time_seqs = torch.zeros((time_seqs.size()[0], event_seqs.size()[1]+1)).float()

    pad_event_seqs[:, 1:] = event_seqs.clone()
    pad_event_seqs[:, 0] = type_size
    pad_time_seqs[:, 1:] = time_seqs.clone()

    return pad_event_seqs, pad_time_seqs, total_time_seqs, seqs_length


if __name__ == '__main__':
    a = torch.tensor([0., 1., 2., 3., 4., 5.])
    b = torch.tensor([0., 2., 4., 6., 0., 0.])

    sim_time_seqs, sim_index_seqs = \
        generate_sim_time_seqs(torch.stack([a, b]), torch.LongTensor([6, 4]))
    print(sim_time_seqs)
    print(sim_index_seqs)
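
    # A small demo of pad_bos on a toy collated batch (the numbers and the
    # type_size below are arbitrary): it prepends a BOS event with index
    # type_size and a zero inter-event time to every padded sequence.
    toy_batch = (torch.LongTensor([[1, 2, 0], [3, 0, 0]]),               # padded event types
                 torch.FloatTensor([[0.5, 1.0, 0.0], [0.7, 0.0, 0.0]]),  # padded inter-event times
                 torch.FloatTensor([1.5, 0.7]),                          # total time per sequence
                 torch.LongTensor([2, 1]))                               # true sequence lengths
    pad_event_seqs, pad_time_seqs, total_time_seqs, lengths = pad_bos(toy_batch, type_size=4)
    print(pad_event_seqs)
    print(pad_time_seqs)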

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Dataloader for the neural Hawkes process.

Defines CTLSTMDataset, which reads pickled event sequences from a file path,
and pad_batch_fn, a collate function that pads variable-length sequences into
batch tensors.
"""

import pickle
import torch
from torch.utils.data import Dataset, DataLoader


class CTLSTMDataset(Dataset):
    """Dataset class for neural Hawkes data."""

    def __init__(self, file_path):
        self.file_path = file_path
        self.event_seqs = []
        self.time_seqs = []

        with open(self.file_path, 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        if 'dev' in file_path:
            seqs = data['dev']
        elif 'train' in file_path:
            seqs = data['train']
        elif 'test' in file_path:
            seqs = data['test']
        else:
            raise ValueError("Cannot infer the split ('train', 'dev' or 'test') from path: {}".format(file_path))

        for seq in seqs:
            self.event_seqs.append(torch.LongTensor([int(event['type_event']) for event in seq]))
            self.time_seqs.append(torch.FloatTensor([float(event['time_since_last_event']) for event in seq]))

    def __len__(self):
        return len(self.event_seqs)

    def __getitem__(self, index):
        sample = {
            'event_seq': self.event_seqs[index],
            'time_seq': self.time_seqs[index]
        }

        return sample

# def pad_batch_fn(batch_data):
#     sorted_batch = sorted(batch_data, key=lambda x: len(x['event_seq']), reverse=True)

#     event_seqs = [seq['event_seq'] for seq in sorted_batch]
#     time_seqs = [seq['time_seq'] for seq in sorted_batch]
#     seqs_length = list(map(len, event_seqs))

#     for idx, (event_seq, time_seq, seq_length) in enumerate(zip(event_seqs, time_seqs, seqs_length)):
#         tmp_event_seq = torch.zeros(seqs_length[0])
#         tmp_event_seq[:seq_length] = torch.IntTensor(event_seq)
#         event_seqs[idx] = tmp_event_seq

#         tmp_time_seq = torch.zeros(seqs_length[0])
#         tmp_time_seq[:seq_length] = torch.FloatTensor(time_seq)
#         time_seqs[idx] = tmp_time_seq

#     return event_seqs, time_seqs, seqs_length

def pad_batch_fn(batch_data):
    """Collate function: sort samples by length (descending), pad with zeros, and
    return batch tensors together with per-sequence total times and true lengths."""
    sorted_batch = sorted(batch_data, key=lambda x: len(x['event_seq']), reverse=True)
    event_seqs = [seq['event_seq'].long() for seq in sorted_batch]
    time_seqs = [seq['time_seq'].float() for seq in sorted_batch]
    seqs_length = torch.LongTensor(list(map(len, event_seqs)))
    last_time_seqs = torch.stack([torch.sum(time_seq) for time_seq in time_seqs])

    event_seqs_tensor = torch.zeros(len(sorted_batch), seqs_length.max()).long()
    time_seqs_tensor = torch.zeros(len(sorted_batch), seqs_length.max()).float()

    for idx, (event_seq, time_seq, seqlen) in enumerate(zip(event_seqs, time_seqs, seqs_length)):
        event_seqs_tensor[idx, :seqlen] = event_seq
        time_seqs_tensor[idx, :seqlen] = time_seq

    return event_seqs_tensor, time_seqs_tensor, last_time_seqs, seqs_length

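
# A minimal smoke test for the dataset and the collate function. It assumes a
# pickle in the format shown in the README exists at 'data/train.pkl'; adjust
# the path and batch size as needed.
if __name__ == '__main__':
    dataset = CTLSTMDataset('data/train.pkl')
    loader = DataLoader(dataset, batch_size=4, collate_fn=pad_batch_fn, shuffle=False)
    event_seqs, time_seqs, total_times, seq_lengths = next(iter(loader))
    print(event_seqs.shape, time_seqs.shape)  # (batch, max_len), zero-padded
    print(total_times, seq_lengths)           # per-sequence total time and true length
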
# def restore_batch(sample_batched, type_size):
#     event_seqs, time_seqs, seqs_length = sample_batched

#     event_seqs_list, time_seqs_list = [], []
#     total_time_list = []

#     for idx, (event_seq, time_seq, seq_length) in enumerate(zip(event_seqs, time_seqs, seqs_length)):
#         tmp_event_seq = torch.ones(seq_length + 1, dtype=torch.int32) * type_size
#         tmp_event_seq[1:] = event_seq[:seq_length]
#         event_seqs_list.append(tmp_event_seq)

#         tmp_time_seq = torch.zeros(seq_length + 1, dtype=torch.float)
#         tmp_time_seq[1:] = time_seq[:seq_length]
#         time_seqs_list.append(tmp_time_seq)

#         total_time_list.append(torch.sum(tmp_time_seq))

#     return event_seqs_list, time_seqs_list, total_time_list

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Training code for the neural Hawkes model."""
import time
import datetime
import torch
import torch.optim as opt
from torch.utils.data import DataLoader

import dataloader
import CTLSTM
import utils


def train(settings):
    """Training process."""
    hidden_size = settings['hidden_size']
    type_size = settings['type_size']
    train_path = settings['train_path']
    dev_path = settings['dev_path']
    batch_size = settings['batch_size']
    epoch_num = settings['epoch_num']
    current_date = settings['current_date']

    model = CTLSTM.CTLSTM(hidden_size, type_size)
    optim = opt.Adam(model.parameters())
    train_dataset = dataloader.CTLSTMDataset(train_path)
    dev_dataset = dataloader.CTLSTMDataset(dev_path)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataloader.pad_batch_fn, shuffle=True)
    dev_dataloader = DataLoader(dev_dataset, collate_fn=dataloader.pad_batch_fn, shuffle=True)

    last_dev_loss = 0.0
    for epoch in range(epoch_num):
        tic_epoch = time.time()
        epoch_train_loss = 0.0
        epoch_dev_loss = 0.0
        train_event_num = 0
        dev_event_num = 0
        print('Epoch.{} starts.'.format(epoch))
        tic_train = time.time()
        for i_batch, sample_batched in enumerate(train_dataloader):
            tic_batch = time.time()

            optim.zero_grad()

            event_seqs, time_seqs, total_time_seqs, seqs_length = utils.pad_bos(sample_batched, model.type_size)

            sim_time_seqs, sim_index_seqs = utils.generate_sim_time_seqs(time_seqs, seqs_length)

            model(event_seqs, time_seqs)
            likelihood = model.log_likelihood(event_seqs, sim_time_seqs, sim_index_seqs, total_time_seqs, seqs_length)
            batch_event_num = torch.sum(seqs_length).item()
            batch_loss = -likelihood

            batch_loss.backward()
            optim.step()

            toc_batch = time.time()
            if i_batch % 100 == 0:
                print('Epoch.{} Batch.{}:\nBatch Likelihood per event: {:.5f} nats\nTrain Time: {:.2f} s'.format(epoch, i_batch, likelihood.item()/batch_event_num, toc_batch-tic_batch))
            epoch_train_loss += batch_loss.item()
            train_event_num += batch_event_num

        toc_train = time.time()
        print('---\nEpoch.{} Training set\nTrain Likelihood per event: {:.5f} nats\nTraining Time: {:.2f} s'.format(epoch, -epoch_train_loss/train_event_num, toc_train-tic_train))

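        # Evaluate on the development set: forward passes and likelihoods only,
        # no parameter updates are made here.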
        tic_eval = time.time()
        for i_batch, sample_batched in enumerate(dev_dataloader):
            event_seqs, time_seqs, total_time_seqs, seqs_length = utils.pad_bos(sample_batched, model.type_size)
            sim_time_seqs, sim_index_seqs = utils.generate_sim_time_seqs(time_seqs, seqs_length)
            model(event_seqs, time_seqs)
            likelihood = model.log_likelihood(event_seqs, sim_time_seqs, sim_index_seqs, total_time_seqs, seqs_length)

            dev_event_num += torch.sum(seqs_length).item()
            epoch_dev_loss -= likelihood.item()

        toc_eval = time.time()
        toc_epoch = time.time()
        print('Epoch.{} Development set\nDev Likelihood per event: {:.5f} nats\nEval Time: {:.2f} s.\n'.format(epoch, -epoch_dev_loss/dev_event_num, toc_eval-tic_eval))

        with open("loss_{}.txt".format(current_date), 'a') as l:
            l.write("Epoch {}:\n".format(epoch))
            l.write("Train Event Number:\t\t{}\n".format(train_event_num))
            l.write("Train Likelihood per event:\t{:.5f}\n".format(-epoch_train_loss/train_event_num))
            l.write("Training time:\t\t\t{:.2f} s\n".format(toc_train-tic_train))
            l.write("Dev Event Number:\t\t{}\n".format(dev_event_num))
            l.write("Dev Likelihood per event:\t{:.5f}\n".format(-epoch_dev_loss/dev_event_num))
            l.write("Dev evaluating time:\t\t{:.2f} s\n".format(toc_eval-tic_eval))
            l.write("Epoch time:\t\t\t{:.2f} s\n".format(toc_epoch-tic_epoch))
            l.write("\n")

        # Stop when the dev likelihood per event has converged.
        gap = epoch_dev_loss/dev_event_num - last_dev_loss
        if abs(gap) < 1e-4:
            print('Final log likelihood: {} nats'.format(-epoch_dev_loss/dev_event_num))
            break

        last_dev_loss = epoch_dev_loss/dev_event_num

    return


if __name__ == "__main__":
    settings = {
        'hidden_size': 32,
        'type_size': 8,
        'train_path': 'data/train.pkl',
        'dev_path': 'data/dev.pkl',
        'batch_size': 32,
        'epoch_num': 100,
        'current_date': datetime.date.today()
    }

    train(settings)

--------------------------------------------------------------------------------
/CTLSTM.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""A continuous-time LSTM network."""
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
import dataloader
from torch.utils.data import DataLoader


class CTLSTM(nn.Module):
    """Continuous-time LSTM network with decay function."""
    def __init__(self, hidden_size, type_size, batch_first=True):
        super(CTLSTM, self).__init__()

        self.hidden_size = hidden_size
        self.type_size = type_size
        self.batch_first = batch_first
        self.num_layers = 1

        # Parameters
        # recurrent cell: a single linear map that produces all seven gates
        self.rec = nn.Linear(2*self.hidden_size, 7*self.hidden_size)
        # output mapping from hidden vectors to unnormalized intensity
        self.wa = nn.Linear(self.hidden_size, self.type_size)
        # embedding layer for valid events, including BOS
        self.emb = nn.Embedding(self.type_size+1, self.hidden_size)

    def init_states(self, batch_size):
        self.h_d = torch.zeros(batch_size, self.hidden_size, dtype=torch.float)
        self.c_d = torch.zeros(batch_size, self.hidden_size, dtype=torch.float)
        self.c_bar = torch.zeros(batch_size, self.hidden_size, dtype=torch.float)
        self.c = torch.zeros(batch_size, self.hidden_size, dtype=torch.float)

    def recurrence(self, emb_event_t, h_d_tm1, c_tm1, c_bar_tm1):
        feed = torch.cat((emb_event_t, h_d_tm1), dim=1)
        # B * 2H
        (gate_i,
         gate_f,
         gate_z,
         gate_o,
         gate_i_bar,
         gate_f_bar,
         gate_delta) = torch.chunk(self.rec(feed), 7, -1)

        gate_i = torch.sigmoid(gate_i)
        gate_f = torch.sigmoid(gate_f)
        gate_z = torch.tanh(gate_z)
        gate_o = torch.sigmoid(gate_o)
        gate_i_bar = torch.sigmoid(gate_i_bar)
        gate_f_bar = torch.sigmoid(gate_f_bar)
        gate_delta = F.softplus(gate_delta)

        c_t = gate_f * c_tm1 + gate_i * gate_z
        c_bar_t = gate_f_bar * c_bar_tm1 + gate_i_bar * gate_z

        return c_t, c_bar_t, gate_o, gate_delta

    def decay(self, c_t, c_bar_t, o_t, delta_t, duration_t):
        c_d_t = c_bar_t + (c_t - c_bar_t) * \
            torch.exp(-delta_t * duration_t.view(-1, 1))

        h_d_t = o_t * torch.tanh(c_d_t)

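        # The cell state decays from c_t toward its steady-state value c_bar_t
        # at rate delta_t as the elapsed time duration_t grows; the hidden state
        # is read out from the decayed cell state.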
        return c_d_t, h_d_t

    def forward(self, event_seqs, duration_seqs, batch_first=True):
        if batch_first:
            event_seqs = event_seqs.transpose(0, 1)
            duration_seqs = duration_seqs.transpose(0, 1)

        batch_size = event_seqs.size()[1]
        batch_length = event_seqs.size()[0]

        # Initialize the states once per batch; they carry over between events
        # within a sequence.
        self.init_states(batch_size)
        h_list, c_list, c_bar_list, o_list, delta_list = [], [], [], [], []

        for t in range(batch_length):
            c, self.c_bar, o_t, delta_t = self.recurrence(self.emb(event_seqs[t]), self.h_d, self.c_d, self.c_bar)
            self.c_d, self.h_d = self.decay(c, self.c_bar, o_t, delta_t, duration_seqs[t])
            h_list.append(self.h_d)
            c_list.append(c)
            c_bar_list.append(self.c_bar)
            o_list.append(o_t)
            delta_list.append(delta_t)
        h_seq = torch.stack(h_list)
        c_seq = torch.stack(c_list)
        c_bar_seq = torch.stack(c_bar_list)
        o_seq = torch.stack(o_list)
        delta_seq = torch.stack(delta_list)

        self.output = torch.stack((h_seq, c_seq, c_bar_seq, o_seq, delta_seq))
        return self.output

    def log_likelihood(self, event_seqs, sim_time_seqs, sim_index_seqs, total_time_seqs, seqs_length, batch_first=True):
        """Calculate the log likelihood per sequence: the sum of log intensities at
        the observed events minus a Monte Carlo estimate of the integrated total
        intensity."""
        batch_size, batch_length = event_seqs.shape
        h, c, c_bar, o, delta = torch.chunk(self.output, 5, 0)
        # L * B * H
        h = torch.squeeze(h, 0)
        c = torch.squeeze(c, 0)
        c_bar = torch.squeeze(c_bar, 0)
        o = torch.squeeze(o, 0)
        delta = torch.squeeze(delta, 0)

        # Sum of log intensities of the observed events in each sequence
        original_loglikelihood = torch.zeros(batch_size)
        lambda_k = F.softplus(self.wa(h)).transpose(0, 1)

        for idx, (event_seq, seq_len) in enumerate(zip(event_seqs, seqs_length)):
            original_loglikelihood[idx] = torch.sum(torch.log(
                lambda_k[idx, torch.arange(seq_len).long(), event_seq[1:seq_len+1]]))

        # Monte Carlo estimate of the integrated total intensity, evaluated at
        # the simulated times
        h_d_list = []
        if batch_first:
            sim_time_seqs = sim_time_seqs.transpose(0, 1)
        for idx, sim_duration in enumerate(sim_time_seqs):
            _, h_d_idx = self.decay(c[idx], c_bar[idx], o[idx], delta[idx], sim_duration)
            h_d_list.append(h_d_idx)
        h_d = torch.stack(h_d_list)

        sim_lambda_k = F.softplus(self.wa(h_d)).transpose(0, 1)
        simulated_likelihood = torch.zeros(batch_size)
        for idx, (total_time, seq_len) in enumerate(zip(total_time_seqs, seqs_length)):
            mc_coefficient = total_time / (seq_len)
            simulated_likelihood[idx] = mc_coefficient * torch.sum(torch.sum(sim_lambda_k[idx, torch.arange(seq_len).long(), :]))

        loglikelihood = torch.sum(original_loglikelihood - simulated_likelihood)
        return loglikelihood

--------------------------------------------------------------------------------