├── .travis.yml ├── papers └── Online_Learning_for_LDA_Hoffman.pdf ├── requirements.txt ├── web ├── settings.py ├── templates │ ├── base.html │ └── index.html ├── models.py ├── config.py └── app.py ├── src ├── distances.py ├── utils.py └── models.py ├── LICENSE ├── README.md └── .gitignore /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.4" 4 | - "3.5" 5 | install: 6 | - pip install flake8 7 | 8 | script: flake8 . --max-line-length=100 9 | -------------------------------------------------------------------------------- /papers/Online_Learning_for_LDA_Hoffman.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huyhoang17/LDA_Viblo_Recommender_System/HEAD/papers/Online_Learning_for_LDA_Hoffman.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim 2 | requests 3 | beautifulsoup4 4 | pyvi 5 | scikit-learn 6 | numpy 7 | pandas 8 | scipy 9 | matplotlib 10 | flask 11 | mistune 12 | pymongo -------------------------------------------------------------------------------- /web/settings.py: -------------------------------------------------------------------------------- 1 | MONGODB_SETTINGS = { 2 | 'db': 'rsframgia', 3 | 'collection': 'viblo_posts', 4 | 'host': 'mongodb://localhost:27017/' 5 | } 6 | PATH_DICTIONARY = "models/id2word.dictionary" 7 | PATH_CORPUS = "models/corpus.mm" 8 | PATH_LDA_MODEL = "models/LDA.model" 9 | PATH_DOC_TOPIC_DIST = "models/doc_topic_dist.dat" 10 | -------------------------------------------------------------------------------- /web/templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block meta %}{% endblock %} 5 | 6 | 7 | {% block content %}{% endblock %} 8 | 9 | {% block footer %} 10 | {% endblock %} 11 | 12 | {% block js %}{% endblock %} 13 | 14 | 15 | -------------------------------------------------------------------------------- /web/models.py: -------------------------------------------------------------------------------- 1 | from mongoengine import connect, Document, StringField 2 | 3 | 4 | connect(db="rsframgia", host="mongodb://localhost:27017") 5 | 6 | 7 | class Books(Document): 8 | id_ = StringField() 9 | slug = StringField() 10 | title = StringField() 11 | user_id = StringField() 12 | canonical_url = StringField() 13 | contents = StringField() 14 | idrs = StringField() 15 | -------------------------------------------------------------------------------- /web/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class BaseConfig: 5 | """Base configuration""" 6 | DEBUG = False 7 | TESTING = False 8 | MONGODB_SETTINGS = { 9 | 'db': 'rsframgia', 10 | 'host': 'mongodb://localhost:27017/' 11 | } 12 | SECRET_KEY = os.environ.get("SECRET_KEY", "framgia123") 13 | 14 | 15 | class DevelopmentConfig(BaseConfig): 16 | """Development configuration""" 17 | DEBUG = True 18 | 19 | 20 | class TestingConfig(BaseConfig): 21 | """Testing configuration""" 22 | DEBUG = True 23 | TESTING = True 24 | 25 | 26 | class ProductionConfig(BaseConfig): 27 | """Production configuration""" 28 | DEBUG = False 29 | -------------------------------------------------------------------------------- /web/templates/index.html: 
-------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 | 4 |

Framgia - Viblo Recommender System

5 | 6 | {% if random_posts %} 7 |

Random Posts

8 | 9 | {% for random_post in random_posts %} 10 |

{{ random_post.title }}

11 | {{ random_post.url }} 12 |
13 | Link 14 | {% endfor %} 15 | {% endif %} 16 | 17 | {% if main_post %} 18 |

Main Post

19 |

{{ main_post.title }}

20 | {{ main_post.url }} 21 |
22 | Link 23 | {% endif %} 24 | 25 | {% if posts %} 26 | 27 |

Related Posts

28 | {% for post in posts %} 29 |

{{ post.title }}

30 | {{ post.url }} 31 |
32 | Link 33 | {% endfor %} 34 | 35 | {% endif %} 36 | 37 | {% endblock %} 38 | -------------------------------------------------------------------------------- /src/distances.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import entropy 3 | 4 | 5 | def jensen_shannon(query, matrix): 6 | """ 7 | This function computes the Jensen-Shannon distance 8 | between the input query (an LDA topic distribution for a document) 9 | and the entire corpus of topic distributions. 10 | It returns an array of length M (the number of documents in the corpus) 11 | """ 12 | # let's keep the p, q notation used above 13 | p = query[None, :].T # take transpose 14 | q = matrix.T # transpose matrix 15 | 16 | m = 0.5 * (p + q) 17 | return np.sqrt(0.5 * (entropy(p, m) + entropy(q, m))) 18 | 19 | 20 | def get_most_similar_documents(query, matrix, k=10): 21 | """ 22 | This function uses the Jensen-Shannon distance above 23 | and returns the indices of the k smallest Jensen-Shannon distances 24 | """ 25 | # list of Jensen-Shannon distances 26 | sims = jensen_shannon(query, matrix) 27 | # the positional indices of the k smallest Jensen-Shannon distances 28 | return sims.argsort()[:k] 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Phan Hoang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/huyhoang17/LDA_Viblo_Recommender_System.svg?branch=master)](https://travis-ci.org/huyhoang17/LDA_Viblo_Recommender_System) 2 | 3 | # Simple Recommender System for Viblo Website using LDA (Latent Dirichlet Allocation) 4 | 5 | Blog Post 6 | --- 7 | 8 | - https://viblo.asia/p/xay-dung-he-thong-goi-y-don-gian-cho-website-viblo-YWOZrgLYlQ0 9 | 10 | Command 11 | --- 12 | 13 | ``` 14 | export PYTHONPATH="path_to_Recommender_System_Viblo_project" 15 | ``` 16 | 17 | Viblo's API 18 | --- 19 | 20 | - Posts: https://viblo.asia/api/posts 21 | - 1 Post with slug field: https://viblo.asia/api/posts/GrLZDXBBZk0 22 | 23 | 24 | Reference 25 | --- 26 | 27 | LDA 28 | - https://www.machinelearningplus.com/nlp/topic-modeling-gensim-python/ 29 | - https://radimrehurek.com/topic_modeling_tutorial/2%20-%20Topic%20Modeling.html 30 | - https://markroxor.github.io/gensim/static/notebooks/gensim_news_classification.html#topic=1&lambda=1&term= 31 | - https://nlpforhackers.io/topic-modeling/ 32 | 33 | Multicore LDA 34 | - https://rare-technologies.com/multicore-lda-in-python-from-over-night-to-over-lunch/ 35 | 36 | Online Learning LDA 37 | - https://radimrehurek.com/gensim/models/ldamodel.html#usage-examples 38 | - https://radimrehurek.com/gensim/wiki.html#latent-dirichlet-allocation 39 | - https://wellecks.wordpress.com/2014/10/26/ldaoverflow-with-online-lda/ 40 | 41 | Similarity 42 | - https://www.kaggle.com/ktattan/lda-and-document-similarity 43 | 44 | Visual 45 | - https://www.kaggle.com/yohanb/lda-visualized-using-t-sne-and-bokeh 46 | 47 | Other 48 | - https://miningthedetails.com/blog/python/lda/GensimLDA/ 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | data 3 | notebooks 4 | models 5 | temp 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from bs4 import BeautifulSoup 5 | from markdown import markdown 6 | from pyvi import ViTokenizer 7 | 8 | logging.basicConfig(format='%(levelname)s : %(message)s', level=logging.INFO) 9 | logging.root.level = logging.INFO 10 | 11 | 12 | with open('data/vni_stopwords.txt') as f: 13 | stopwords = [] 14 | for line in f: 15 | stopwords.append("_".join(line.strip().split())) 16 | 17 | 18 | def preprocessing_tags(soup, tags=None): 19 | if tags is not None: 20 | for tag in tags: 21 | for sample in soup.find_all(tag): 22 | sample.replaceWith('') 23 | else: 24 | raise NotImplementedError("Tags must be set!") 25 | 26 | return soup.get_text() 27 | 28 | 29 | def markdown_to_text(markdown_string, parser="html.parser", 30 | tags=['pre', 'code', 'a', 'img', 'i']): 31 | """ Converts a markdown string to plaintext 32 | https://stackoverflow.com/questions/18453176 33 | """ 34 | 35 | import mistune # noqa 36 | # md -> html -> text since BeautifulSoup can extract text cleanly 37 | markdown = mistune.Markdown() 38 | html = markdown(markdown_string) 39 | 40 | soup = BeautifulSoup(html, parser) 41 | # remove code snippets 42 | text = preprocessing_tags(soup, tags) 43 | 44 | text = remove_links_content(text) 45 | text = remove_emails(text) 46 | text = remove_punctuation(text) 47 | text = text.replace('\n', ' ') 48 | text = remove_numeric(text) 49 | text = remove_multiple_space(text) 50 | text = text.lower().strip() 51 | text = ViTokenizer.tokenize(text) 52 | text = remove_stopwords(text, stopwords=stopwords) 53 | 54 | return text 55 | 56 | 57 | def markdown_process(content, markdown=markdown, tags_space=None): 58 | """ 59 | Author: Hoang Anh Pham 60 | :param tags_space: technology keyword to replace to remain tags 61 | ruby on rails -> ruby_on_rails 62 | """ 63 | import mistune # noqa 64 | 65 | markdown = mistune.Markdown() 66 | html_doc = markdown(content) 67 | soup = BeautifulSoup(html_doc, 'html.parser') 68 | 69 | for tag in soup.find_all(['pre']): 70 | tag.replace_with('') 71 | for tag in soup.find_all(['img']): 72 | tag.replace_with('') 73 | for tag in 
soup.find_all(['a']): 74 | tag.replace_with('') 75 | 76 | text = soup.text 77 | text = text.replace('\n', ' ') 78 | text = re.sub(r'[^\w\s]', ' ', text) 79 | text = text.lower() 80 | text = text.strip() 81 | 82 | for tag in (tags_space or {}): 83 | text = text.replace(tag, tags_space[tag]) 84 | 85 | return text 86 | 87 | 88 | def remain_tags_space(text, tags_space): 89 | """ 90 | :param tags_space: mapping of tags to keep as single tokens 91 | { 92 | "ruby on rails": "ruby_on_rails", 93 | ... 94 | } 95 | """ 96 | for tag in tags_space: 97 | text = text.replace(tag, tags_space[tag]) 98 | return text 99 | 100 | 101 | def remove_emails(text): 102 | return re.sub(r'\S*@\S*\s?', '', text) 103 | 104 | 105 | def remove_newline_characters(text): 106 | return re.sub(r'\s+', ' ', text) 107 | 108 | 109 | def remove_links_content(text): 110 | text = re.sub(r"http\S+", "", text) 111 | return text 112 | 113 | 114 | def remove_multiple_space(text): 115 | return re.sub(r"\s\s+", " ", text) 116 | 117 | 118 | def remove_punctuation(text): 119 | """https://stackoverflow.com/a/37221663""" 120 | import string # noqa 121 | table = str.maketrans({key: None for key in string.punctuation}) 122 | return text.translate(table) 123 | 124 | 125 | def remove_numeric(text): 126 | import string # noqa 127 | table = str.maketrans({key: None for key in string.digits}) 128 | return text.translate(table) 129 | 130 | 131 | def remove_html_tags(text): 132 | """Remove html tags from a string""" 133 | clean = re.compile('<.*?>') 134 | return re.sub(clean, '', text) 135 | 136 | 137 | def remove_stopwords(text, stopwords): 138 | return " ".join([word for word in text.split() if word not in stopwords]) 139 | -------------------------------------------------------------------------------- /web/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | 5 | from flask import Flask, jsonify, render_template 6 | import numpy as np 7 | import pymongo 8 | 9 | import settings 10 | from src.distances import get_most_similar_documents 11 | from src.models import make_texts_corpus 12 | from src.utils import markdown_to_text 13 | 14 | client = pymongo.MongoClient(settings.MONGODB_SETTINGS["host"]) 15 | db = client[settings.MONGODB_SETTINGS["db"]] 16 | mongo_col = db[settings.MONGODB_SETTINGS["collection"]] 17 | 18 | app = Flask(__name__) 19 | app.secret_key = os.environ.get("SECRET_KEY", "framgia123") 20 | 21 | # app.config.from_object('web.config.DevelopmentConfig') 22 | logging.basicConfig( 23 | format='%(asctime)s : %(levelname)s : %(message)s', 24 | level=logging.INFO 25 | ) 26 | 27 | 28 | def load_model(): 29 | import gensim # noqa 30 | from sklearn.externals import joblib # noqa 31 | # load LDA model 32 | lda_model = gensim.models.LdaModel.load( 33 | settings.PATH_LDA_MODEL 34 | ) 35 | # load corpus 36 | corpus = gensim.corpora.MmCorpus( 37 | settings.PATH_CORPUS 38 | ) 39 | # load dictionary 40 | id2word = gensim.corpora.Dictionary.load( 41 | settings.PATH_DICTIONARY 42 | ) 43 | # load documents topic distribution matrix 44 | doc_topic_dist = joblib.load( 45 | settings.PATH_DOC_TOPIC_DIST 46 | ) 47 | # doc_topic_dist = np.array([np.array(dist) for dist in doc_topic_dist]) 48 | 49 | return lda_model, corpus, id2word, doc_topic_dist 50 | 51 | 52 | lda_model, corpus, id2word, doc_topic_dist = load_model() 53 | 54 | 55 | @app.route('/ping', methods=['GET']) 56 | def ping_pong(): 57 | return jsonify({ 58 | 'call': 'success', 59 | 'message': 'pong!'
60 | }) 61 | 62 | 63 | @app.route('/posts/', methods=["GET"]) 64 | def show_posts(): 65 | idrss = random.sample(range(0, mongo_col.count()), 10) 66 | posts = mongo_col.find({"idrs": {"$in": idrss}}) 67 | random_posts = [ 68 | { 69 | "url": post["canonical_url"], 70 | "title": post["title"], 71 | "slug": post["slug"] 72 | } 73 | for post in posts 74 | ] 75 | return render_template('index.html', random_posts=random_posts) 76 | 77 | 78 | @app.route('/posts/', methods=["GET"]) 79 | def show_post(slug): 80 | main_post = mongo_col.find_one({"slug": slug}) 81 | main_post = { 82 | "url": main_post["canonical_url"], 83 | "title": main_post["title"], 84 | "slug": main_post["slug"], 85 | "content": main_post["contents"] 86 | } 87 | 88 | # preprocessing 89 | content = markdown_to_text(main_post["content"]) 90 | text_corpus = make_texts_corpus([content]) 91 | bow = id2word.doc2bow(next(text_corpus)) 92 | doc_distribution = np.array( 93 | [doc_top[1] for doc_top in lda_model.get_document_topics(bow=bow)] 94 | ) 95 | 96 | # recommender posts 97 | most_sim_ids = list(get_most_similar_documents( 98 | doc_distribution, doc_topic_dist))[1:] 99 | 100 | most_sim_ids = [int(id_) for id_ in most_sim_ids] 101 | posts = mongo_col.find({"idrs": {"$in": most_sim_ids}}) 102 | related_posts = [ 103 | { 104 | "url": post["canonical_url"], 105 | "title": post["title"], 106 | "slug": post["slug"] 107 | } 108 | for post in posts 109 | ][1:] 110 | 111 | return render_template( 112 | 'index.html', main_post=main_post, posts=related_posts 113 | ) 114 | 115 | 116 | @app.route('/posts_HAU/', methods=["GET"]) 117 | def show_post_HAU(slug): 118 | """ 119 | Author: Thanh Hau 120 | """ 121 | from sklearn.externals import joblib # noqa 122 | sim_topics = joblib.load('data/similarity_dict_HAU.pkl') 123 | main_post = mongo_col.find_one({"slug": slug}) 124 | main_post = [ 125 | { 126 | "url": main_post["canonical_url"], 127 | "title": main_post["title"], 128 | "slug": main_post["slug"], 129 | "content": main_post["contents"] 130 | } 131 | ] 132 | main_post = main_post[0] 133 | 134 | most_sim_slugs = sim_topics[slug] 135 | posts = mongo_col.find({"slug": {"$in": most_sim_slugs}}) 136 | related_posts = [ 137 | { 138 | "url": post["canonical_url"], 139 | "title": post["title"], 140 | "slug": post["slug"] 141 | } 142 | for post in posts 143 | ] 144 | 145 | return render_template( 146 | 'index.html', main_post=main_post, posts=related_posts 147 | ) 148 | 149 | 150 | if __name__ == "__main__": 151 | app.run(host='0.0.0.0', debug=True) 152 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import gensim 7 | from gensim.utils import simple_preprocess 8 | from sklearn.externals import joblib 9 | 10 | from src.distances import get_most_similar_documents 11 | 12 | logging.basicConfig(format='%(levelname)s : %(message)s', level=logging.INFO) 13 | logging.root.level = logging.INFO 14 | 15 | 16 | PATH_DICTIONARY = "models/id2word.dictionary" 17 | PATH_CORPUS = "models/corpus.mm" 18 | PATH_LDA_MODEL = "models/LDA.model" 19 | PATH_DOC_TOPIC_DIST = "models/doc_topic_dist.dat" 20 | 21 | 22 | def head(stream, n=10): 23 | """ 24 | Return the first `n` elements of the stream, as plain list. 
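    Example (illustrative):
        >>> head(iter(range(100)), 3)
        [0, 1, 2]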
25 | """ 26 | return list(itertools.islice(stream, n)) 27 | 28 | 29 | def tokenize(text, STOPWORDS): 30 | # deacc=True to remove punctuations 31 | return [token for token in simple_preprocess(text, deacc=True) 32 | if token not in STOPWORDS] 33 | 34 | 35 | def make_texts_corpus(sentences): 36 | for sentence in sentences: 37 | yield simple_preprocess(sentence, deacc=True) 38 | 39 | 40 | class StreamCorpus(object): 41 | def __init__(self, sentences, dictionary, clip_docs=None): 42 | """ 43 | Parse the first `clip_docs` documents 44 | Yield each document in turn, as a list of tokens. 45 | """ 46 | self.sentences = sentences 47 | self.dictionary = dictionary 48 | self.clip_docs = clip_docs 49 | 50 | def __iter__(self): 51 | for tokens in itertools.islice(make_texts_corpus(self.sentences), 52 | self.clip_docs): 53 | yield self.dictionary.doc2bow(tokens) 54 | 55 | def __len__(self): 56 | return self.clip_docs 57 | 58 | 59 | class LDAModel: 60 | 61 | def __init__(self, num_topics, passes, chunksize, 62 | random_state=100, update_every=1, alpha='auto', 63 | per_word_topics=False): 64 | """ 65 | :param sentences: list or iterable (recommend) 66 | """ 67 | 68 | # data 69 | self.sentences = None 70 | 71 | # params 72 | self.lda_model = None 73 | self.dictionary = None 74 | self.corpus = None 75 | 76 | # hyperparams 77 | self.num_topics = num_topics 78 | self.passes = passes 79 | self.chunksize = chunksize 80 | self.random_state = random_state 81 | self.update_every = update_every 82 | self.alpha = alpha 83 | self.per_word_topics = per_word_topics 84 | 85 | # init model 86 | # self._make_dictionary() 87 | # self._make_corpus_bow() 88 | 89 | def _make_corpus_bow(self, sentences): 90 | self.corpus = StreamCorpus(sentences, self.id2word) 91 | # save corpus 92 | gensim.corpora.MmCorpus.serialize(PATH_CORPUS, self.corpus) 93 | 94 | def _make_corpus_tfidf(self): 95 | pass 96 | 97 | def _make_dictionary(self, sentences): 98 | self.texts_corpus = make_texts_corpus(sentences) 99 | self.id2word = gensim.corpora.Dictionary(self.texts_corpus) 100 | self.id2word.filter_extremes(no_below=10, no_above=0.25) 101 | self.id2word.compactify() 102 | self.id2word.save(PATH_DICTIONARY) 103 | 104 | def documents_topic_distribution(self): 105 | doc_topic_dist = np.array( 106 | [[tup[1] for tup in lst] for lst in self.lda_model[self.corpus]] 107 | ) 108 | # save documents-topics matrix 109 | joblib.dump(doc_topic_dist, PATH_DOC_TOPIC_DIST) 110 | return doc_topic_dist 111 | 112 | def fit(self, sentences): 113 | from itertools import tee 114 | sentences_1, sentences_2 = tee(sentences) 115 | self._make_dictionary(sentences_1) 116 | self._make_corpus_bow(sentences_2) 117 | self.lda_model = gensim.models.ldamodel.LdaModel( 118 | self.corpus, id2word=self.id2word, num_topics=64, passes=5, 119 | chunksize=100, random_state=42, alpha=1e-2, eta=0.5e-2, 120 | minimum_probability=0.0, per_word_topics=False 121 | ) 122 | self.lda_model.save(PATH_LDA_MODEL) 123 | 124 | def transform(self, sentence): 125 | """ 126 | :param document: preprocessed document 127 | """ 128 | document_corpus = next(make_texts_corpus([sentence])) 129 | corpus = self.id2word.doc2bow(document_corpus) 130 | document_dist = np.array( 131 | [tup[1] for tup in self.lda_model.get_document_topics(bow=corpus)] 132 | ) 133 | return corpus, document_dist 134 | 135 | def predict(self, document_dist): 136 | doc_topic_dist = self.documents_topic_distribution() 137 | return get_most_similar_documents(document_dist, doc_topic_dist) 138 | 139 | def update(self, new_corpus): # 
TODO 140 | """ 141 | Online Learning LDA 142 | https://radimrehurek.com/gensim/models/ldamodel.html#usage-examples 143 | https://radimrehurek.com/gensim/wiki.html#latent-dirichlet-allocation 144 | """ 145 | self.lda_model.update(new_corpus) 146 | # get topic probability distribution for documents 147 | for corpus in new_corpus: 148 | yield self.lda_model[corpus] 149 | 150 | def model_perplexity(self): 151 | logging.info(self.lda_model.log_perplexity(self.corpus)) 152 | 153 | def coherence_score(self): 154 | self.coherence_model_lda = gensim.models.coherencemodel.CoherenceModel( 155 | model=self.lda_model, texts=self.corpus, 156 | dictionary=self.id2word, coherence='c_v' 157 | ) 158 | logging.info(self.coherence_model_lda.get_coherence()) 159 | 160 | def compute_coherence_values(self, mallet_path, dictionary, corpus, 161 | texts, end=40, start=2, step=3): 162 | """ 163 | Compute c_v coherence for various numbers of topics 164 | 165 | Parameters: 166 | ---------- 167 | dictionary : Gensim dictionary 168 | corpus : Gensim corpus 169 | texts : List of input texts 170 | end : Max num of topics 171 | 172 | Returns: 173 | ------- 174 | model_list : List of LDA topic models 175 | coherence_values : Coherence values corresponding to the LDA model 176 | with respective number of topics 177 | """ 178 | coherence_values = [] 179 | model_list = [] 180 | for num_topics in range(start, end, step): 181 | model = gensim.models.wrappers.LdaMallet( 182 | mallet_path, corpus=corpus, 183 | num_topics=num_topics, id2word=dictionary) 184 | model_list.append(model) 185 | coherencemodel = gensim.models.coherencemodel.CoherenceModel( 186 | model=model, texts=texts, 187 | dictionary=dictionary, coherence='c_v' 188 | ) 189 | coherence_values.append(coherencemodel.get_coherence()) 190 | 191 | return model_list, coherence_values 192 | 193 | def plot(self, coherence_values, end=40, start=2, step=3): 194 | x = range(start, end, step) 195 | plt.plot(x, coherence_values) 196 | plt.xlabel("Num Topics") 197 | plt.ylabel("Coherence score") 198 | plt.legend(["coherence_values"], loc='best') 199 | plt.show() 200 | 201 | def print_topics(self): 202 | pass 203 | 204 | 205 | def main(): 206 | # TODO 207 | sentences = None 208 | sentences = make_texts_corpus(sentences) 209 | id2word = gensim.corpora.Dictionary(sentences) 210 | id2word.filter_extremes(no_below=20, no_above=0.1) 211 | id2word.compactify() 212 | 213 | # save dictionary 214 | # id2word.save('path_to_save_file.dictionary') 215 | corpus = StreamCorpus(sentences, id2word) 216 | # save corpus 217 | # gensim.corpora.MmCorpus.serialize('path_to_save_file.mm', corpus) 218 | # load corpus 219 | # mm_corpus = gensim.corpora.MmCorpus('path_to_save_file.mm') 220 | lda_model = gensim.models.ldamodel.LdaModel( 221 | corpus, num_topics=64, id2word=id2word, passes=10, chunksize=100 222 | ) 223 | # save model 224 | # lda_model.save('path_to_save_model.model') 225 | lda_model.print_topics(-1) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | --------------------------------------------------------------------------------
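A condensed sketch of the serving flow used in web/app.py, for reference: it assumes the artifacts under models/ have already been built by src/models.py, and cleaned_text is a stand-in for a post body already run through src.utils.markdown_to_text.

import gensim
import numpy as np
from sklearn.externals import joblib  # older scikit-learn, as used in web/app.py

from src.distances import get_most_similar_documents
from src.models import make_texts_corpus

# load the artifacts produced by src/models.py (paths follow web/settings.py)
lda_model = gensim.models.LdaModel.load("models/LDA.model")
id2word = gensim.corpora.Dictionary.load("models/id2word.dictionary")
doc_topic_dist = joblib.load("models/doc_topic_dist.dat")

cleaned_text = "..."  # placeholder: a preprocessed Viblo post
bow = id2word.doc2bow(next(make_texts_corpus([cleaned_text])))
query_dist = np.array(
    [prob for _, prob in lda_model.get_document_topics(bow, minimum_probability=0.0)]
)

# row indices in doc_topic_dist of the 10 posts closest by Jensen-Shannon distance
most_similar_ids = get_most_similar_documents(query_dist, doc_topic_dist)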