├── .gitignore
├── LICENSE
├── README.md
├── extract_ngrams.py
├── learnmdl.py
├── ngram.py
└── paraphrase.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Mark Fishel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ngram2vec
Embeddings for n-grams via sampling.

## Learning a word2vec model (gensim word2vec)
The extraction and training parameters are hard-coded at the top of learnmdl.py and can be changed there.

```console
$ python3 learnmdl.py preproc.data.en model.en
```

Now model.en can be loaded and used from Python as a gensim word2vec model.
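A minimal sketch of loading it (learnmdl.py saves the final vectors in binary word2vec format; the n-gram key below is hypothetical, joined tokens use `__` as the separator):

```python
from gensim.models import KeyedVectors

# learnmdl.py writes the vectors with save_word2vec_format(..., binary=True)
vectors = KeyedVectors.load_word2vec_format("model.en", binary=True)

# joined n-grams are single vocabulary entries; the key here is a made-up example
print(vectors.most_similar("new__york", topn=5))
```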
## Extracting only the n-grams
Extracting only the n-grams is useful mainly because training fastText through Python is very slow, while the compiled C++ code is fast. It also adds modularity: the n-grams are extracted once, after which any embeddings can be trained on them, whether different models (word2vec, GloVe, fastText, etc.) or different hyperparameters.

### Example use with fastText (compiled C++):

```console
$ python3 extract_ngrams.py data.clean.en data.ngrams.en
$ ./fasttext cbow -input data.ngrams.en -thread 16 -dim 300 -output ngram.mdl.d3.en
```
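fastText writes its vectors to `<output>.vec` in plain-text word2vec format alongside the binary model, so the result of the run above can also be loaded back into gensim; a minimal sketch, assuming the default output naming:

```python
from gensim.models import KeyedVectors

# ngram.mdl.d3.en.vec is the text-format output of the fastText command above
vectors = KeyedVectors.load_word2vec_format("ngram.mdl.d3.en.vec", binary=False)
```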
--------------------------------------------------------------------------------
/extract_ngrams.py:
--------------------------------------------------------------------------------
import sys

import ngram

if __name__ == "__main__":
    dataFile = sys.argv[1]
    outFile = sys.argv[2]
    beta = 0.125                # n-gram subsampling strength
    freqFilter = [20, 120, 90]  # minimum counts for uni-, bi- and trigrams

    # the sampler collects the n-gram statistics during its first pass over the data
    lines = ngram.SentenceNgramSampler(dataFile, minCounts=freqFilter, ngramThresholdBeta=beta)

    with open(outFile, "w") as f:
        for i in lines:
            f.write(" ".join(i) + "\n")
    print("Job finished")
--------------------------------------------------------------------------------
/learnmdl.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import sys
import logging
import os

import ngram

from datetime import datetime
from gensim.models import Word2Vec

def debug(msg):
    sys.stderr.write("{0}: {1}\n".format(str(datetime.now()), msg))

if __name__ == "__main__":
    dataFile = sys.argv[1]
    modelFile = sys.argv[2]

    # Factored Estonian data:
    tokFactor = 1
    posFactor = 2
    firstPosFilter = "A,S,H,V,X,D,G,U,Y"
    lastPosFilter = "S,H,V,X,K,Y"

    freqFilter = [5, 50]
    somePosFilter = None
    crazyBigMFCorpus = True
    beta = 0.125
    epochs = 10

    logging.basicConfig(level=logging.INFO)

    lines = ngram.SentenceNgramSampler(dataFile, minCounts=freqFilter, tokFactor=tokFactor,
                                       posFactor=posFactor, firstPosFilter=firstPosFilter,
                                       lastPosFilter=lastPosFilter, atLeastOnePosFilter=somePosFilter,
                                       ngramThresholdBeta=beta, crazyBigMFCorpus=crazyBigMFCorpus)

    if len(freqFilter) > 1:
        # n-grams are used, so do a first pass over the corpus to collect their statistics
        debug("Initializing")
        for line in lines:
            pass

    if epochs > 0:
        model = Word2Vec(workers=60, sg=1, hs=1, iter=epochs, min_count=freqFilter[0])

        debug("Building vocab")
        model.build_vocab(lines)

        debug("Learning")
        for i in range(epochs):
            model.train(lines, total_examples=len(lines), epochs=1)
            model.save(modelFile + ".trainable." + str(i))
            model.wv.save_word2vec_format(modelFile + "." + str(i), binary=True)
            debug("Iteration {0} done".format(i))

        # keep the vectors of the last epoch under the requested file name
        os.rename(modelFile + "." + str(epochs - 1), modelFile)
--------------------------------------------------------------------------------
/ngram.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author: Mark Fishel

import re
import random
import math
import sys

import logging

from collections import defaultdict

logger = logging.getLogger('ngram iter')

class SentenceNgramSampler:
    """Iterates over a tokenized corpus, joining frequent n-grams into single
    tokens ("new", "york" -> "new__york") with a frequency-dependent sampling
    probability. The first pass over the corpus collects the n-gram counts."""

    def __init__(self, filename, minCounts=[5, 30, 50], ngramThresholdBeta=0.125,
                 firstPosFilter=None, lastPosFilter=None, atLeastOnePosFilter=None,
                 crazyBigMFCorpus=False, tokFactor=None, posFactor=None):

        self.maxNgramLen = len(minCounts)
        self.minCounts = minCounts

        self.ngramThresholdBeta = ngramThresholdBeta

        self.firstPosFilter = self._maybeReadFilter(firstPosFilter)
        self.lastPosFilter = self._maybeReadFilter(lastPosFilter)
        self.atLeastOnePosFilter = self._maybeReadFilter(atLeastOnePosFilter)

        self.tokFactor = tokFactor
        self.posFactor = posFactor

        # when True, the corpus is re-read from disk on every epoch
        # instead of being kept in memory
        self.crazyBigMFCorpus = crazyBigMFCorpus
        self.filename = filename

        # per-instance state (kept here rather than as class attributes,
        # which would be shared between instances)
        self.ngramDict = defaultdict(lambda: defaultdict(int))
        self.length = 0
        self.currSntIdx = None
        self.storedData = []
        self.firstIter = True

        self.fileHandle = open(filename, 'r')

    def __len__(self):
        return self.length

    def __next__(self):
        result = None

        while not result:
            factoredSnt = self._getNextSentence()

            ngramsAndSpecs = list(self.ngrams(factoredSnt))
            random.shuffle(ngramsAndSpecs)

            toJoin = self._getNonoverlappingNgrams(ngramsAndSpecs)

            result = self._applyJoinOps(factoredSnt, toJoin)

        return result

    def _maybeReadFilter(self, rawFilterSpec):
        if rawFilterSpec is None:
            return None
        else:
            return set(rawFilterSpec.split(","))

    def _getFactors(self, rawToken):
        if self.tokFactor is None:
            f1 = rawToken.lower()
            f2 = None
        else:
            factors = rawToken.split("|")

            try:
                f1 = factors[self.tokFactor].lower()
                f2 = factors[self.posFactor]
            except IndexError:
                # the data turned out not to be factored; fall back to plain tokens
                self.tokFactor = None
                self.posFactor = None

                f1 = rawToken.lower()
                f2 = None

        return (f1, f2)

    def _cleanSentence(self, rawSnt):
        result = [self._getFactors(t) for t in rawSnt.strip().split()]

        return [(t, p) for t, p in result if re.search(r'[a-zäöüõšž]', t)]

    def _tryGetNext(self):
        # either the first iteration, or re-reading the file every time
        if self.currSntIdx is None:
            rawSnt = next(self.fileHandle)
            return self._cleanSentence(rawSnt)

        # or reading from data in memory
        else:
            snt = self.storedData[self.currSntIdx]
            self.currSntIdx += 1
            return snt

    def _handleEndOfFile(self):
        self.fileHandle.close()

        if self.firstIter:
            self._filterDict()

        self.firstIter = False

        if self.crazyBigMFCorpus:
            self.fileHandle = open(self.filename, 'r')
        else:
            self.currSntIdx = 0

        raise StopIteration

    def _handleEndOfList(self):
        self.currSntIdx = 0
        raise StopIteration

    def _updateNgramDict(self, fsnt):
        # update the unigram frequencies
        for w in fsnt:
            self.ngramDict[0][w[0]] += 1

        # update the n-gram frequencies
        for ngram, spec in self.ngrams(fsnt):
            nlen = len(spec) - 1

            if self._acceptableNgram(fsnt, spec):
                self.ngramDict[nlen][ngram] += 1

    def _getNextSentence(self):
        try:
            factoredSnt = self._tryGetNext()

        except IndexError:
            self._handleEndOfList()

        except StopIteration:
            self._handleEndOfFile()

        if self.firstIter:
            self.length += 1
            self._updateNgramDict(factoredSnt)

            if not self.crazyBigMFCorpus:
                self.storedData.append(factoredSnt)

        return factoredSnt

    def _acceptableNgram(self, fsnt, ngramSpec):
        if self.posFactor is None:
            return True

        factors = [fsnt[i][1] for i in sorted(ngramSpec)]

        firstOk = (self.firstPosFilter is None or factors[0] in self.firstPosFilter or factors[0] is None)
        lastOk = (self.lastPosFilter is None or factors[-1] in self.lastPosFilter or factors[-1] is None)
        someOk = (self.atLeastOnePosFilter is None or set(factors) & self.atLeastOnePosFilter)

        return bool(firstOk and lastOk and someOk)

    def _filterDict(self):
        for nlen in self.ngramDict:
            before = len(self.ngramDict[nlen])
            self.ngramDict[nlen] = {k: v for k, v in self.ngramDict[nlen].items() if v >= self.minCounts[nlen]}
            after = len(self.ngramDict[nlen])
            logger.info("Filtered {0}-grams from {1} down to {2}".format(nlen + 1, before, after))

    def ngrams(self, fSnt):
        for idx in range(len(fSnt)):
            for nlen in range(1, self.maxNgramLen):
                if idx - nlen >= 0:
                    spec = range(idx - nlen, idx + 1)

                    yield "__".join([fSnt[i][0] for i in spec]), set(spec)

    def _getNonoverlappingNgrams(self, ngramsAndSpecs):
        result = []
        covMap = set()

        for ngram, spec in ngramsAndSpecs:
            nlen = len(spec) - 1

            if ngram in self.ngramDict[nlen] and self.ngramDict[nlen][ngram] >= self.minCounts[nlen]:
                if not (spec & covMap):
                    # accept the n-gram with probability count**(-beta), so that
                    # frequent n-grams are also regularly left unjoined
                    threshold = math.exp((-math.log(self.ngramDict[nlen][ngram])) * self.ngramThresholdBeta)
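                    # For example, with the default beta = 0.125 an n-gram seen
                    # 256 times passes with probability 256 ** -0.125 = 0.5 and
                    # one seen 65536 times with probability 0.25 (hypothetical
                    # counts, for illustration only).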
                    if random.random() < threshold:
                        result.append(spec)
                        covMap.update(spec)

        return result

    def _applyJoinOps(self, sentence, toJoin):
        result = [t for t, _ in sentence]

        # apply the join operations right-to-left so that earlier indices stay valid
        for op in sorted(toJoin, key=lambda x: -min(x)):
            result = result[:min(op)] + ["__".join([sentence[i][0] for i in sorted(op)])] + result[max(op) + 1:]

        return result

    def __iter__(self):
        return self

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sampler = SentenceNgramSampler(sys.argv[1], minCounts=[2, 2, 2])

    for snt in sampler:
        print(snt)

    print("Second iteration")

    for snt in sampler:
        print(snt)
--------------------------------------------------------------------------------
/paraphrase.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import sys
import logging
import math
import rnnlm  # external LM module providing loadModels() and score(); not part of this repository
from gensim.models import KeyedVectors

from collections import defaultdict
from apply_bpe import BPE  # from subword-nmt's apply_bpe.py

logger = logging.getLogger('paraphrase.py')

class State:
    def __init__(self, expansions, ngram=None, prev=None, covVec=None, simProb=0.0, lmProb=0.0):
        self.lmProb = lmProb
        self.simProb = simProb
        self.covVec = covVec if covVec is not None else set()
        self.ngram = ngram if ngram is not None else []
        self.prev = prev

        self.expansions = [exp for exp in expansions if self.compatible(exp)]

    def __repr__(self):
        return self.getKey() + ": " + str(self.getProb())

    def getProb(self):
        return self.simProb + self.lmProb

    def getKey(self):
        return str(self.covVec) + "//" + str(self.getFullForm())

    def nextStates(self, lmMdl):
        for inNgram, inNgramSpec, outNgramScoreList in self.expansions:
            for outNgram, simProb in outNgramScoreList:
                newstate = State(self.expansions, prev=self, ngram=outNgram,
                                 covVec=self.combineCovVec(inNgramSpec),
                                 simProb=self.simProb + math.log(simProb))
                newstate.lmProb = rnnlm.score(newstate.getFullForm(), lmMdl)
                yield newstate

    def compatible(self, expansion):
        return not expansion[1] & self.covVec

    def combineCovVec(self, ngramSpec):
        return self.covVec | ngramSpec

    def isEnd(self, query):
        return self.covVec == set(range(len(query)))

    def getFullForm(self):
        result = []
        state = self
        while state.prev is not None:
            result = state.ngram + result
            state = state.prev
        return result

    def getExplanation(self):
        result = []
        state = self

        # work on copies of the coverage sets so the states are not mutated
        while state.prev is not None:
            result = [[state.ngram, set(state.covVec)]] + result
            state = state.prev

        prevCov = set()

        for resultElem in result:
            x = resultElem[1]
            resultElem[1] -= prevCov
            prevCov |= x

        return ", ".join([str(a) + "/" + str(b) for a, b in result])
def ngrams(seq, simMdl, qn, maxNgramLen=4):
    for i in range(len(seq)):
        for l in range(maxNgramLen):
            if i - l >= 0:
                currSpec = range(i - l, i + 1)
                currNgram = [seq[j] for j in currSpec]

                currNgramStr = "__".join(currNgram)

                if currNgramStr in simMdl:
                    yield currNgram, set(currSpec), [(ngram.split("__"), prob) for ngram, prob in simMdl.most_similar(currNgramStr, topn=qn)]

def paraphrase(query, simMdl, lmMdl, n=5, qn=10):
    logger.debug("Paraphrasing " + str(query))
    grid = defaultdict(lambda: dict())

    scoredNgramPairs = list(ngrams(query, simMdl, qn))
    logger.debug("search space: " + "\n".join(["> " + str(x) for x in scoredNgramPairs]))

    startState = State(scoredNgramPairs)

    grid[0] = {startState.getKey(): startState}

    currLev = 0
    maxLev = 0

    results = []

    # levels in between can stay empty when every expansion adds more than one
    # token, so iterate up to the highest level seen rather than stopping at
    # the first empty one
    while currLev <= maxLev:
        logger.debug("Processing level {0}".format(currLev))
        # expand only the n best states of the current level (beam search)
        for state in sorted(grid[currLev].values(), key=lambda x: -x.getProb())[:n]:
            logger.debug(" Processing level {0} state {1}".format(currLev, state))

            if state.isEnd(query):
                expl = state.getExplanation()
                fullForm = state.getFullForm()
                prob = state.getProb()

                logger.debug(" End state: {0}, {1}".format(prob, expl))

                results.append((fullForm, prob, expl))
            else:
                deadEnd = True

                logger.debug(" Next states:")

                for nextState in state.nextStates(lmMdl):
                    lev = currLev + len(nextState.ngram)
                    maxLev = max(maxLev, lev)
                    key = nextState.getKey()

                    logger.debug(" --> {0}".format(nextState))

                    deadEnd = False

                    # keep only the best-scoring state per key (recombination)
                    if key not in grid[lev] or grid[lev][key].getProb() < nextState.getProb():
                        grid[lev][key] = nextState

                if deadEnd:
                    logger.debug(" Dead end")

            logger.debug("Finished processing level {0} state {1}".format(currLev, state))
            logger.debug("---------------------------------------")
        logger.debug("Finished processing level {0}".format(currLev))
        logger.debug("=======================================")

        currLev += 1

    return sorted(results, key=lambda x: -x[1])
{0}".format(nextState)) 129 | 130 | deadEnd = False 131 | 132 | if not key in grid[lev] or grid[lev][key].getProb() < nextState.getProb(): 133 | grid[lev][key] = nextState 134 | 135 | if deadEnd: 136 | logger.debug(" Dead end") 137 | 138 | logger.debug("Finished processing level {0} state {1}".format(currLev, state)) 139 | logger.debug("---------------------------------------") 140 | logger.debug("Finished processing level {0}".format(currLev)) 141 | logger.debug("=======================================") 142 | 143 | currLev += 1 144 | 145 | return sorted(results, key=lambda x: -x[1]) 146 | 147 | def loadBpe(bpeMdlFile): 148 | with open(bpeMdlFile, 'r') as codes: 149 | bpeMdl = BPE(codes, separator = '') 150 | return bpeMdl 151 | 152 | def bpeSplit(query, bpeMdl): 153 | result = [] 154 | for segm in query: 155 | result.append(bpeMdl.segment(segm)) 156 | 157 | return (" ".join(result)).split() 158 | 159 | if __name__ == "__main__": 160 | logging.basicConfig(level = logging.INFO) 161 | 162 | try: 163 | simMdlFile = sys.argv[1] 164 | bpeMdlFile = sys.argv[2] 165 | lmMdlFile = sys.argv[3] 166 | dictFile = sys.argv[4] 167 | except IndexError: 168 | print("Usage: paraphrase.py ngramMdl bpeMdl languageMdl dictMdl") 169 | else: 170 | logger.info("Loading similarity model") 171 | simMdl = KeyedVectors.load_word2vec_format(simMdlFile, binary=True) 172 | 173 | a = simMdl.most_similar(list(simMdl.vocab)[5]) 174 | logger.debug(str(a)) 175 | 176 | logger.info("Loading LM") 177 | lmMdl = rnnlm.loadModels(lmMdlFile, dictFile) 178 | 179 | logger.info("Loading BPE model") 180 | bpeMdl = loadBpe(bpeMdlFile) 181 | 182 | logger.info("Ready to paraphrase (enter 'quit' to quit)") 183 | 184 | query = "-" 185 | 186 | while query != ["quit"]: 187 | sys.stdout.write("\nQuery: ") 188 | 189 | query = input().lower().split() 190 | 191 | splitQuery = bpeSplit(query, bpeMdl) 192 | 193 | print("(split as " + "|".join(splitQuery) + ")") 194 | results = paraphrase(splitQuery, simMdl, lmMdl, n = 5, qn = 10) 195 | for phrase, prob, expl in results[:5]: 196 | print("{2} (p={1} / {3})".format("".join(phrase), prob, "|".join(phrase), expl)) 197 | --------------------------------------------------------------------------------