├── HiddenMarkovModel.ipynb
├── HiddenMarkovModel.py
├── README.md
└── images
    ├── BW.png
    ├── BW1.png
    ├── BW2.png
    ├── BW3.png
    ├── BW4.png
    ├── BW5.png
    ├── BW6.png
    ├── EM1.png
    ├── EM2.png
    ├── EM3.jpg
    ├── FB1.png
    ├── FB2.png
    ├── FB3.png
    ├── FB4.png
    ├── Screen Shot 2016-05-08 at 8.27.19 PM.png
    ├── Viterbi.gif
    ├── eq1.png
    ├── eq2.png
    ├── eq3.png
    ├── eq4.png
    ├── eq5.png
    ├── eq6.png
    ├── eq7.png
    ├── eq8.png
    ├── eq9.png
    ├── graph1.gif
    ├── graph2.gif
    ├── graph3.gif
    ├── hmm.png
    ├── trans.png
    ├── trans2.png
    ├── viterbi.png
    ├── viterbi2.png
    ├── viterbi3.png
    └── viterbi4.png

--------------------------------------------------------------------------------
/HiddenMarkovModel.py:
--------------------------------------------------------------------------------
import torch


class HiddenMarkovModel(object):
    """
    Hidden Markov Model class.

    Parameters:
    -----------

    - S: Number of states.
    - T: numpy.array Transition matrix of size S by S;
         T[i, j] stores the probability of moving from state i to state j.
    - E: numpy.array Emission matrix of size O by S, where O is the number of
         possible observation symbols; E[o, s] stores the probability of
         observing symbol o from state s.
    - T0: numpy.array Initial state probabilities of size S.
    """

    def __init__(self, T, E, T0, epsilon=0.001, maxStep=10):
        # maximum number of EM iterations
        self.maxStep = maxStep
        # convergence criterion
        self.epsilon = epsilon
        # number of possible states
        self.S = T.shape[0]
        # number of possible observation symbols
        self.O = E.shape[0]
        self.prob_state_1 = []
        # emission matrix
        self.E = torch.tensor(E)
        # transition matrix
        self.T = torch.tensor(T)
        # initial state vector
        self.T0 = torch.tensor(T0)

    def initialize_viterbi_variables(self, shape):
        # backpointers are state indices, so store them as integers
        pathStates = torch.zeros(shape, dtype=torch.int64)
        pathScores = torch.zeros(shape, dtype=torch.float64)
        states_seq = torch.zeros([shape[0]], dtype=torch.int64)
        return pathStates, pathScores, states_seq

    def belief_propagation(self, scores):
        return scores.view(-1, 1) + torch.log(self.T)

    def viterbi_inference(self, x):  # x: observation sequence
        self.N = len(x)
        shape = [self.N, self.S]
        pathStates, pathScores, states_seq = self.initialize_viterbi_variables(shape)
        # log probabilities of the emission sequence
        obs_prob_full = torch.log(self.E[x])
        # initialize with the state starting log-priors
        pathScores[0] = torch.log(self.T0) + obs_prob_full[0]
        for step, obs_prob in enumerate(obs_prob_full[1:]):
            # propagate state belief
            belief = self.belief_propagation(pathScores[step, :])
            # record the most likely predecessor of each state (backpointer)
            pathStates[step + 1] = torch.argmax(belief, 0)
            # update the score matrix
            pathScores[step + 1] = torch.max(belief, 0)[0] + obs_prob
        # infer the most likely last state
        states_seq[self.N - 1] = torch.argmax(pathScores[self.N - 1, :], 0)
        for step in range(self.N - 1, 0, -1):
            # follow the backpointers to recover the most likely path
            state = states_seq[step]
            states_seq[step - 1] = pathStates[step][state]
        return states_seq, torch.exp(pathScores)  # turn log-scores back into probabilities

    def initialize_forw_back_variables(self, shape):
        self.forward = torch.zeros(shape, dtype=torch.float64)
        self.backward = torch.zeros_like(self.forward)
        self.posterior = torch.zeros_like(self.forward)

    def _forward(self, obs_prob_seq):
        # scale factors keep the probabilities from underflowing
        self.scale = torch.zeros([self.N], dtype=torch.float64)
        # initialize with the state starting priors
        init_prob = self.T0 * obs_prob_seq[0]
        # scaling factor at t=0
        self.scale[0] = 1.0 / init_prob.sum()
        # scaled belief at t=0
        self.forward[0] = self.scale[0] * init_prob
        # propagate belief
        for step, obs_prob in enumerate(obs_prob_seq[1:]):
            # previous state probability
            prev_prob = self.forward[step].unsqueeze(0)
            # transition prior
            prior_prob = torch.matmul(prev_prob, self.T)
            # forward belief propagation
            forward_score = prior_prob * obs_prob
            forward_prob = torch.squeeze(forward_score)
            # scaling factor
            self.scale[step + 1] = 1 / forward_prob.sum()
            # update the forward matrix
            self.forward[step + 1] = self.scale[step + 1] * forward_prob

    def _backward(self, obs_prob_seq_rev):
        # obs_prob_seq_rev is the emission sequence reversed along the time axis
        # initialize with the state ending priors
        self.backward[0] = self.scale[self.N - 1] * torch.ones([self.S], dtype=torch.float64)
        # propagate belief backwards in time
        for step, obs_prob in enumerate(obs_prob_seq_rev[:-1]):
            # next state probability
            next_prob = self.backward[step, :].unsqueeze(1)
            # observation emission probabilities
            obs_prob_d = torch.diag(obs_prob)
            # transition prior
            prior_prob = torch.matmul(self.T, obs_prob_d)
            # backward belief propagation
            backward_prob = torch.matmul(prior_prob, next_prob).squeeze()
            # update the backward matrix
            self.backward[step + 1] = self.scale[self.N - 2 - step] * backward_prob
        # restore chronological order (flip the time axis back)
        self.backward = torch.flip(self.backward, [0])

    def forward_backward(self, obs_prob_seq):
        """
        Runs the forward-backward algorithm on an observation sequence.

        Arguments
        ---------
        - obs_prob_seq : matrix of size N by S, where N is the number of
          timesteps and S is the number of states

        Fills
        -----
        - self.forward : matrix of size N by S holding the scaled forward
          probability of each state at each time step
        - self.backward : matrix of size N by S holding the scaled backward
          probability of each state at each time step
        - self.posterior : matrix of size N by S holding the posterior
          probability of each state at each time step
        """
        self._forward(obs_prob_seq)
        # reverse along the time axis only; reversing the state axis as well
        # would silently permute the states against the transition matrix
        obs_prob_seq_rev = torch.flip(obs_prob_seq, [0])
        self._backward(obs_prob_seq_rev)
        # the smoothed state marginals, normalized per time step
        self.posterior = self.forward * self.backward
        self.posterior = self.posterior / self.posterior.sum(1, keepdim=True)

    def re_estimate_transition(self, x):
        # M[t, i, j] is the posterior probability of being in state i at time t
        # and state j at time t+1
        self.M = torch.zeros([self.N - 1, self.S, self.S], dtype=torch.float64)

        for t in range(self.N - 1):
            tmp_0 = torch.matmul(self.forward[t].unsqueeze(0), self.T)
            tmp_1 = tmp_0 * self.E[x[t + 1]].unsqueeze(0)
            denom = torch.matmul(tmp_1, self.backward[t + 1].unsqueeze(1)).squeeze()

            trans_re_estimate = torch.zeros([self.S, self.S], dtype=torch.float64)

            for i in range(self.S):
                numer = self.forward[t, i] * self.T[i, :] * self.E[x[t + 1]] * self.backward[t + 1]
                trans_re_estimate[i] = numer / denom

            self.M[t] = trans_re_estimate

        self.gamma = self.M.sum(2).squeeze()
        T_new = self.M.sum(0) / self.gamma.sum(0).unsqueeze(1)

        T0_new = self.gamma[0, :]

        # append the state posterior for the final time step, which M does not cover
        prod = (self.forward[self.N - 1] * self.backward[self.N - 1]).unsqueeze(0)
        s = prod / prod.sum()
        self.gamma = torch.cat([self.gamma, s], 0)
        # track the posterior of the first state across EM iterations
        self.prob_state_1.append(self.gamma[:, 0])
        return T0_new, T_new

    def re_estimate_emission(self, x):
        states_marginal = self.gamma.sum(0)
        # one-hot encode the observation sequence
        seq_one_hot = torch.zeros([len(x), self.O], dtype=torch.float64)
        seq_one_hot.scatter_(1, torch.tensor(x).unsqueeze(1), 1)
        emission_score = torch.matmul(seq_one_hot.transpose(1, 0), self.gamma)
        return emission_score / states_marginal

    def check_convergence(self, new_T0, new_transition, new_emission):
        delta_T0 = torch.max(torch.abs(self.T0 - new_T0)).item() < self.epsilon
        delta_T = torch.max(torch.abs(self.T - new_transition)).item() < self.epsilon
        delta_E = torch.max(torch.abs(self.E - new_emission)).item() < self.epsilon
        return delta_T0 and delta_T and delta_E

    def expectation_maximization_step(self, obs_seq):
        # probability of the emission sequence
        obs_prob_seq = self.E[obs_seq]

        self.forward_backward(obs_prob_seq)

        new_T0, new_transition = self.re_estimate_transition(obs_seq)

        new_emission = self.re_estimate_emission(obs_seq)

        converged = self.check_convergence(new_T0, new_transition, new_emission)

        self.T0 = new_T0
        self.E = new_emission
        self.T = new_transition

        return converged

    def Baum_Welch_EM(self, obs_seq):
        # length of the observed sequence
        self.N = len(obs_seq)

        # shape of the forward/backward variables
        shape = [self.N, self.S]

        # initialize variables
        self.initialize_forw_back_variables(shape)

        converged = False

        for i in range(self.maxStep):
            converged = self.expectation_maximization_step(obs_seq)
            if converged:
                print('converged at step {}'.format(i))
                break
        return self.T0, self.T, self.E, converged
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hidden Markov Model in Pytorch

This is a practice project for learning probabilistic graphical models (PGMs), based on the [tensorflow implementation](https://github.com/MarvinBertin/HiddenMarkovModel_TensorFlow) by [Marvin Bertin](https://github.com/MarvinBertin).

Viterbi, Forward-Backward and Baum-Welch are implemented.
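For reference, the smoothing posterior produced by the forward-backward pass, and reused by the Baum-Welch re-estimation step, is the standard marginal over hidden states (written here in the usual α/β notation, which is not the notation used in the notebook):

$$\gamma_t(i) = \frac{\alpha_t(i)\,\beta_t(i)}{\sum_{j=1}^{S}\alpha_t(j)\,\beta_t(j)}$$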
For details, please check the [Jupyter Notebook Guide](https://github.com/TreB1eN/HiddenMarkovModel_Pytorch/blob/master/HiddenMarkovModel.ipynb)
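As a quick start, here is a minimal usage sketch; the two-state transition matrix, three-symbol emission matrix and observation sequence below are made-up toy values, not taken from the notebook:

```python
import numpy as np
from HiddenMarkovModel import HiddenMarkovModel

# toy parameters: 2 hidden states, 3 observation symbols (illustrative only)
T = np.array([[0.7, 0.3],
              [0.4, 0.6]])       # transition matrix, S x S
E = np.array([[0.5, 0.1],
              [0.4, 0.3],
              [0.1, 0.6]])       # emission matrix, O x S
T0 = np.array([0.6, 0.4])        # initial state distribution

model = HiddenMarkovModel(T, E, T0, epsilon=0.001, maxStep=20)

obs_seq = [0, 1, 2, 2, 1, 0]     # a sequence of observation indices

# most likely hidden state sequence (Viterbi)
states_seq, path_probs = model.viterbi_inference(obs_seq)

# re-estimate T0, T and E from the observations (Baum-Welch EM)
T0_new, T_new, E_new, converged = model.Baum_Welch_EM(obs_seq)
```

Note that `Baum_Welch_EM` mutates the model's `T0`, `T` and `E` in place and also returns them.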