├── .gitignore
├── example.py
├── README.md
└── kneser_ney.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc

--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
from nltk.corpus import gutenberg
from nltk.util import ngrams
from kneser_ney import KneserNeyLM

gut_ngrams = (
    ngram for sent in gutenberg.sents() for ngram in ngrams(sent, 3,
    pad_left=True, pad_right=True, left_pad_symbol='<s>',
    right_pad_symbol='</s>'))
lm = KneserNeyLM(3, gut_ngrams, end_pad_symbol='</s>')
print(lm.score_sent(('This', 'is', 'a', 'sample', 'sentence', '.')))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# kneser-ney
An implementation of [Kneser-Ney](https://en.wikipedia.org/wiki/Kneser%E2%80%93Ney_smoothing) language modeling in Python 3.
This is not a particularly optimized implementation, but it is hopefully helpful for learning and works fine for corpora that aren't too large.

# Usage

The KneserNeyLM class estimates a language model when given a sequence of ngrams.

```python
class KneserNeyLM:

    def __init__(self, highest_order, ngrams, start_pad_symbol='<s>',
            end_pad_symbol='</s>'):
        """
        Constructor for KneserNeyLM.

        Params:
            highest_order [int] The order of the language model.
            ngrams [list->tuple->string] Ngrams of the highest_order specified.
                Ngrams at the beginning / end of sentences should be padded.
            start_pad_symbol [string] The symbol used to pad the beginning of
                sentences.
            end_pad_symbol [string] The symbol used to pad the end of
                sentences.
        """
```

It is easy to create a KneserNeyLM from an NLTK corpus (see example.py).

```python
from nltk.corpus import gutenberg
from nltk.util import ngrams
from kneser_ney import KneserNeyLM

gut_ngrams = (
    ngram for sent in gutenberg.sents() for ngram in ngrams(sent, 3,
    pad_left=True, pad_right=True, left_pad_symbol='<s>',
    right_pad_symbol='</s>'))
lm = KneserNeyLM(3, gut_ngrams, end_pad_symbol='</s>')
```

The language model can then be used to score sentences or to generate new ones.

```python
lm.score_sent(('This', 'is', 'a', 'sample', 'sentence', '.'))
lm.generate_sentence()
```

--------------------------------------------------------------------------------
/kneser_ney.py:
--------------------------------------------------------------------------------
import math
import random
from collections import Counter, defaultdict


class KneserNeyLM:

    def __init__(self, highest_order, ngrams, start_pad_symbol='<s>',
            end_pad_symbol='</s>'):
        """
        Constructor for KneserNeyLM.

        Params:
            highest_order [int] The order of the language model.
            ngrams [list->tuple->string] Ngrams of the highest_order specified.
                Ngrams at the beginning / end of sentences should be padded.
            start_pad_symbol [string] The symbol used to pad the beginning of
                sentences.
            end_pad_symbol [string] The symbol used to pad the end of
                sentences.
        """
        self.highest_order = highest_order
        self.start_pad_symbol = start_pad_symbol
        self.end_pad_symbol = end_pad_symbol
        self.lm = self.train(ngrams)

    def train(self, ngrams):
        """
        Train the language model on the given ngrams.

        Params:
            ngrams [list->tuple->string] Ngrams of the highest_order specified.
        """
        kgram_counts = self._calc_adj_counts(Counter(ngrams))
        probs = self._calc_probs(kgram_counts)
        return probs

    def highest_order_probs(self):
        return self.lm[0]

    def _calc_adj_counts(self, highest_order_counts):
        """
        Calculates the adjusted counts for all ngrams up to the highest order.

        Params:
            highest_order_counts [dict{tuple->string, int}] Counts of the
                highest order ngrams.

        Returns:
            kgrams_counts [list->dict] List of dicts from kgram to counts,
                where k descends from highest_order to 1.
        """
        kgrams_counts = [highest_order_counts]
        for i in range(1, self.highest_order):
            last_order = kgrams_counts[-1]
            new_order = defaultdict(int)
            # The adjusted count of a lower-order kgram is its continuation
            # count: the number of distinct kgrams of the next higher order
            # that end with it.
            for ngram in last_order.keys():
                suffix = ngram[1:]
                new_order[suffix] += 1
            kgrams_counts.append(new_order)
        return kgrams_counts

    def _calc_probs(self, orders):
        """
        Calculates interpolated probabilities of kgrams for all orders.
        """
        backoffs = []
        for order in orders[:-1]:
            backoff = self._calc_order_backoff_probs(order)
            backoffs.append(backoff)
        orders[-1] = self._calc_unigram_probs(orders[-1])
        backoffs.append(defaultdict(int))
        self._interpolate(orders, backoffs)
        return orders

    def _calc_unigram_probs(self, unigrams):
        sum_vals = sum(v for v in unigrams.values())
        unigrams = dict(
            (k, math.log(v / sum_vals)) for k, v in unigrams.items())
        return unigrams

    def _calc_order_backoff_probs(self, order):
        # n_r: the number of kgrams that occur exactly r times (only r <= 4
        # is needed by the discount formula).
        num_kgrams_with_count = Counter(
            value for value in order.values() if value <= 4)
        discounts = self._calc_discounts(num_kgrams_with_count)
        prefix_sums = defaultdict(int)
        backoffs = defaultdict(int)
        for key in order.keys():
            prefix = key[:-1]
            count = order[key]
            prefix_sums[prefix] += count
            discount = self._get_discount(discounts, count)
            order[key] -= discount
            backoffs[prefix] += discount
        for key in order.keys():
            prefix = key[:-1]
            order[key] = math.log(order[key] / prefix_sums[prefix])
        for prefix in backoffs.keys():
            backoffs[prefix] = math.log(backoffs[prefix] / prefix_sums[prefix])
        return backoffs

    def _get_discount(self, discounts, count):
        if count > 3:
            return discounts[3]
        return discounts[count]

    def _calc_discounts(self, num_with_count):
        """
        Calculate the optimal discount values for kgrams with counts 1, 2, & 3+.
        """
        # Modified Kneser-Ney discounts (Chen & Goodman):
        # Y = n1 / (n1 + 2 * n2), D_i = i - (i + 1) * Y * n_{i+1} / n_i.
        common = num_with_count[1] / (num_with_count[1] + 2 * num_with_count[2])
        # Init discounts[0] to 0 so that discounts[i] is for counts of i
        discounts = [0]
        for i in range(1, 4):
            if num_with_count[i] == 0:
                discount = 0
            else:
                discount = (i - (i + 1) * common
                    * num_with_count[i + 1] / num_with_count[i])
            discounts.append(discount)
        if any(d <= 0 for d in discounts[1:]):
            raise Exception(
                '***Warning*** Non-positive discounts detected. '
                'Your dataset is probably too small.')
        return discounts

    def _interpolate(self, orders, backoffs):
        """
        Interpolate each order's log probabilities with the next lower
        order's log probabilities and the backoff weights of their prefixes,
        working from the lowest order upward.
        """
        for last_order, order, backoff in zip(
                reversed(orders), reversed(orders[:-1]), reversed(backoffs[:-1])):
            for kgram in order.keys():
                prefix, suffix = kgram[:-1], kgram[1:]
                order[kgram] += last_order[suffix] + backoff[prefix]

    def logprob(self, ngram):
        # Back off to successively shorter suffixes until a known kgram is
        # found; returns None if even the final word is unseen.
        for i, order in enumerate(self.lm):
            if ngram[i:] in order:
                return order[ngram[i:]]
        return None

    def score_sent(self, sent):
        """
        Return log prob of the sentence.

        Params:
            sent [tuple->string] The words in the unpadded sentence.
        """
        padded = (
            (self.start_pad_symbol,) * (self.highest_order - 1) + sent +
            (self.end_pad_symbol,))
        sent_logprob = 0
        for i in range(len(padded) - self.highest_order + 1):
            ngram = padded[i:i + self.highest_order]
            sent_logprob += self.logprob(ngram)
        return sent_logprob

    def generate_sentence(self, min_length=4):
        """
        Generate a sentence using the probabilities in the language model.

        Params:
            min_length [int] The minimum number of words in the sentence.
        """
        sent = []
        probs = self.highest_order_probs()
        # Resample until the sentence has at least min_length words, not
        # counting the padding symbols.
        while len(sent) < min_length + self.highest_order:
            sent = [self.start_pad_symbol] * (self.highest_order - 1)
            # Append first to avoid the case where the start & end symbols
            # are the same.
            sent.append(self._generate_next_word(sent, probs))
            while sent[-1] != self.end_pad_symbol:
                sent.append(self._generate_next_word(sent, probs))
        # Strip the padding symbols before returning.
        return ' '.join(sent[(self.highest_order - 1):-1])

    def _get_context(self, sentence):
        """
        Extract context to predict next word from sentence.

        Params:
            sentence [tuple->string] The words currently in sentence.
        """
        return sentence[(len(sentence) - self.highest_order + 1):]

    def _generate_next_word(self, sent, probs):
        context = tuple(self._get_context(sent))
        pos_ngrams = list(
            (ngram, logprob) for ngram, logprob in probs.items()
            if ngram[:-1] == context)
        # Normalize to get conditional probability.
        # Subtract max logprob from all logprobs to avoid underflow.
        _, max_logprob = max(pos_ngrams, key=lambda x: x[1])
        pos_ngrams = list(
            (ngram, math.exp(prob - max_logprob)) for ngram, prob in pos_ngrams)
        total_prob = sum(prob for ngram, prob in pos_ngrams)
        pos_ngrams = list(
            (ngram, prob / total_prob) for ngram, prob in pos_ngrams)
        rand = random.random()
        for ngram, prob in pos_ngrams:
            rand -= prob
            if rand < 0:
                return ngram[-1]
        # Fall back to the last candidate if floating point error leaves rand
        # slightly above zero.
        return ngram[-1]

--------------------------------------------------------------------------------
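
For reference, a minimal end-to-end sketch that combines example.py with the scoring and generation calls shown in the README (not part of the original repository; it assumes NLTK is installed and the Gutenberg corpus has been downloaded, e.g. via `nltk.download('gutenberg')`):

```python
from nltk.corpus import gutenberg
from nltk.util import ngrams

from kneser_ney import KneserNeyLM

# Build padded trigrams from every sentence in the Gutenberg corpus.
gut_ngrams = (
    ngram for sent in gutenberg.sents() for ngram in ngrams(
        sent, 3, pad_left=True, pad_right=True,
        left_pad_symbol='<s>', right_pad_symbol='</s>'))

# Train a trigram Kneser-Ney model; end_pad_symbol must match the padding above.
lm = KneserNeyLM(3, gut_ngrams, end_pad_symbol='</s>')

# Log probability of a whole (unpadded) sentence.
print(lm.score_sent(('This', 'is', 'a', 'sample', 'sentence', '.')))

# Log probability of a single trigram; logprob returns None if the trigram
# and all of its lower-order suffixes are unseen.
print(lm.logprob(('This', 'is', 'a')))

# Sample a sentence from the model.
print(lm.generate_sentence())
```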