├── .gitattributes ├── .github └── workflows │ └── testing.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README-zh.md ├── README.md ├── docs ├── api │ ├── keybert.md │ ├── maxsum.md │ └── mmr.md ├── changelog.md ├── faq.md ├── guides │ ├── embeddings.md │ └── quickstart.md ├── index.md └── style.css ├── images ├── highlight.png ├── icon.png └── logo.png ├── mkdocs.yml ├── scripts └── pip.sh ├── setup.py ├── tests ├── __init__.py ├── example1.txt ├── test_model.py └── utils.py ├── theme └── style.css └── zhkeybert ├── __init__.py ├── _extract_kws.py ├── _highlight.py ├── _maxsum.py ├── _mmr.py ├── _model.py └── backend ├── __init__.py ├── _base.py ├── _flair.py ├── _gensim.py ├── _sentencetransformers.py ├── _spacy.py ├── _use.py └── _utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Code Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -e ".[dev]" 30 | - name: Run Checking Mechanisms 31 | run: make check 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | 57 | # Jupyter Notebook 58 | .ipynb_checkpoints 59 | 60 | # IPython 61 | profile_default/ 62 | ipython_config.py 63 | 64 | # pyenv 65 | .python-version 66 | 67 | # Environments 68 | .env 69 | .venv 70 | env/ 71 | venv/ 72 | ENV/ 73 | env.bak/ 74 | venv.bak/ 75 | 76 | .idea 77 | .idea/ 78 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020, Maarten P. 
Grootendorst 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | pytest 3 | 4 | install: 5 | python -m pip install -e . 6 | 7 | install-test: 8 | python -m pip install -e ".[test]" 9 | python -m pip install -e ".[all]" 10 | 11 | pypi: 12 | python setup.py sdist 13 | twine upload dist/* 14 | 15 | clean: 16 | rm -rf **/.ipynb_checkpoints **/.pytest_cache **/__pycache__ **/**/__pycache__ .ipynb_checkpoints .pytest_cache 17 | 18 | check: test clean 19 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | [![PyPI - Python](https://img.shields.io/badge/python-3.6%20|%203.7%20|%203.8-blue.svg)](https://pypi.org/project/keybert/) 2 | [![PyPI - License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/MaartenGr/keybert/blob/master/LICENSE) 3 | 4 | # ZhKeyBERT 5 | 6 | 基于[KeyBERT](https://github.com/MaartenGr/KeyBERT), 增强模型对中文 7 | 关键词提取的支持。 8 | 9 | ## 目录 10 | 11 | 1. [开始](#gettingstarted) 12 | 1.1. [安装](#installation) 13 | 1.2. [用法](#usage) 14 | 1.3. [对比](#compare) 15 | 1.4. [模型](#embeddings) 16 | 17 | 18 | 19 | ### 1.1. 安装 20 | 21 | ``` 22 | pip install zhkeybert --user 23 | ``` 24 | 25 | 26 | ### 1.2. 用法 27 | 28 | 最简单的样例: 29 | 30 | ```python 31 | from zhkeybert import KeyBERT, extract_kws_zh 32 | 33 | docs = """时值10月25日抗美援朝纪念日,《长津湖》片方发布了“纪念中国人民志愿军抗美援朝出国作战71周年特别短片”,再次向伟大的志愿军致敬! 
34 | 电影《长津湖》全情全景地还原了71年前抗美援朝战场上那场史诗战役,志愿军奋不顾身的英勇精神令观众感叹:“岁月峥嵘英雄不灭,丹心铁骨军魂永存!”影片上映以来票房屡创新高,目前突破53亿元,暂列中国影史票房总榜第三名。 35 | 值得一提的是,这部影片的很多主创或有军人的血脉,或有当兵的经历,或者家人是军人。提起这些他们也充满自豪,影片总监制黄建新称:“当兵以后会有一种特别能坚持的劲儿。”饰演雷公的胡军透露:“我父亲曾经参加过抗美援朝,还得了一个三等功。”影片历史顾问王树增表示:“我当了五十多年的兵,我的老部队就是上甘岭上下来的,那些老兵都是我的偶像。” 36 | “身先士卒卫华夏家国,血战无畏护山河无恙。”片中饰演七连连长伍千里的吴京感叹:“要永远记住这些先烈们,他们给我们带来今天的和平。感谢他们的付出,才让我们有今天的幸福生活。”饰演新兵伍万里的易烊千玺表示:“战争的残酷、碾压式的伤害,其实我们现在的年轻人几乎很难能体会到,希望大家看完电影后能明白,是那些先辈们的牺牲奉献,换来了我们的现在。” 37 | 影片对战争群像的恢弘呈现,对个体命运的深切关怀,令许多观众无法控制自己的眼泪,观众称:“当看到影片中的惊险战斗场面,看到英雄们壮怀激烈的拼杀,为国捐躯的英勇无畏和无悔付出,我明白了为什么说今天的幸福生活来之不易。”(记者 王金跃)""" 38 | kw_model = KeyBERT() 39 | extract_kws_zh(docs, kw_model) 40 | ``` 41 | 42 | `ngram_range`决定了结果短句可以由多少个词语构成 43 | 44 | ```python 45 | >>> extract_kws_zh(docs, kw_model, ngram_range=(1, 1)) 46 | [('中国人民志愿军', 0.6094), 47 | ('长津湖', 0.505), 48 | ('周年', 0.4504), 49 | ('影片', 0.447), 50 | ('英雄', 0.4297)] 51 | ``` 52 | 53 | ```python 54 | >>> extract_kws_zh(docs, kw_model, ngram_range=(1, 2)) 55 | [('纪念中国人民志愿军', 0.6894), 56 | ('电影长津湖', 0.6285), 57 | ('年前抗美援朝', 0.476), 58 | ('中国人民志愿军抗美援朝', 0.6349), 59 | ('中国影史', 0.5534)] 60 | ``` 61 | 62 | `use_mmr`默认为`True`,会使用Maximal Marginal Relevance(MMR)算法 63 | 增加结果的多样性,`diversity`([0,1]之间,默认`0.25`)控制了多样性的程度, 64 | 值越大程度越高。若`use_mmr=False`,则容易出现多个结果包含同一个词语的情况。 65 | 66 | 高多样性的结果很杂乱: 67 | ```python 68 | >>> extract_kws_zh(docs, kw_model, use_mmr=True, diversity=0.7) 69 | [('纪念中国人民志愿军抗美援朝', 0.7034), 70 | ('观众无法控制自己', 0.1212), 71 | ('山河无恙', 0.2233), 72 | ('影片上映以来', 0.5427), 73 | ('53亿元', 0.3287)] 74 | ``` 75 | 76 | 低多样性的结果重复度相对较高: 77 | ```python 78 | >>> extract_kws_zh(docs, kw_model, use_mmr=True, diversity=0.7) 79 | [('纪念中国人民志愿军抗美援朝', 0.7034), 80 | ('电影长津湖', 0.6285), 81 | ('纪念中国人民志愿军', 0.6894), 82 | ('周年特别短片', 0.5668), 83 | ('作战71周年', 0.5637)] 84 | ``` 85 | 86 | `diversity=0.0`的结果与`use_mmr=False`的结果无异,权衡两种极端,比较推荐默认值`diversity=0.25`. 87 | 88 | ### 对比 89 | 本项目对KeyBERT的主要改进有: 90 | - 细化候选关键词的筛选,避免跨句组合等情况 91 | - 调整超参数,寻找效果较优的组合(例如原始模型中`use_maxsum`的效果奇差) 92 | - 找出效率和效果均比较优秀的模型`paraphrase-multilingual-MiniLM-L12-v2` 93 | 94 | ```python 95 | >>> from zhkeybert import KeyBERT, extract_kws_zh 96 | >>> import jieba 97 | >>> lines = [] 98 | >>> with open('tests/example1.txt', 'r') as f: 99 | for line in f: 100 | lines.append(line) 101 | >>> kw_model = KeyBERT(model='paraphrase-multilingual-MiniLM-L12-v2') 102 | >>> for line in lines: 103 | print(extract_kws_zh(line, kw_model)) 104 | [('网络文明大会', 0.7627), ('推进文明办网', 0.7084), ('北京国家会议', 0.5802), ('文明办网', 0.7105), ('大会主题为', 0.6182)] 105 | [('国家自然科学奖评选出', 0.7038), ('前沿研究领域', 0.6102), ('的重要科学', 0.62), ('自然科学奖', 0.693), ('自然科学奖一等奖', 0.6887)] 106 | [('等蔬菜价格下降', 0.7361), ('滑县瓜菜种植', 0.649), ('蔬菜均价每公斤', 0.6768), ('全国蔬菜均价', 0.709), ('村蔬菜种植', 0.6536)] 107 | [('中国共产党的任务', 0.7928), ('一个中国原则', 0.751), ('中国共产党的', 0.7541), ('统一是中国', 0.7095), ('中国人民捍卫', 0.7081)] 108 | >>> for line in lines: 109 | print(kw_model.extract_keywords(' '.join(jieba.cut(line)), keyphrase_ngram_range=(1, 3), 110 | use_mmr=True, diversity=0.25)) 111 | [('中国 网络 文明', 0.7355), ('网络 文明 大会', 0.7018), ('北京 国家 会议', 0.6802), ('首届 中国 网络', 0.723), ('打造 我国 网络', 0.6766)] 112 | [('基础 研究 国家自然科学奖', 0.7712), ('领域 重要 科学', 0.7054), ('研究 国家自然科学奖 评选', 0.7441), ('研究 国家自然科学奖', 0.7499), ('自然科学 一等奖', 0.7193)] 113 | [('目前 蔬菜 储备', 0.8036), ('滑县 瓜菜 种植', 0.7484), ('大省 蔬菜 面积', 0.798), ('居民 蔬菜 供应', 0.7902), ('设施 蔬菜 大省', 0.792)] 114 | [('统一 中国 全体', 0.7614), ('谈到 祖国统一', 0.7532), ('习近平 总书记 表示', 0.6338), ('祖国统一 问题 总书记', 0.7368), ('中国共产党 任务 实现', 0.679)] 115 | ``` 116 | 117 | 118 | ### 1.4. 
模型 119 | KeyBERT支持许多embedding模型,但是对于中文语料,应该采用带有`multilingual`的模型, 120 | 默认模型为`paraphrase-multilingual-MiniLM-L12-v2`,也可手动指定:`KeyBERT(model=...)`。 121 | 122 | 可以从KeyBERT的[文档](https://maartengr.github.io/KeyBERT/guides/embeddings.html)中 123 | 了解如何加载各种来源的模型 124 | 125 | 下面是一些测试的结果 126 | 127 | |模型名称|计算速度|精确度|大小| 128 | |------|-------|-----|------| 129 | |universal-sentence-encoder-multilingual-large|很慢|较好|近1G| 130 | |universal-sentence-encoder-multilingual|一般|一般|小几百M| 131 | |paraphrase-multilingual-MiniLM-L12-v2|快|较好|400M| 132 | |bert-base-multilingual-cased (Flair)|慢|一般|几M| 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI - Python](https://img.shields.io/badge/python-3.6%20|%203.7%20|%203.8-blue.svg)](https://pypi.org/project/keybert/) 2 | [![PyPI - License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/MaartenGr/keybert/blob/master/LICENSE) 3 | 4 | # ZhKeyBERT 5 | 6 | [中文文档](https://github.com/deepdialog/ZhKeyBERT/blob/master/README-zh.md) 7 | 8 | 9 | Based on [KeyBERT](https://github.com/MaartenGr/KeyBERT), enhance the keyword 10 | extraction model for Chinese. 11 | 12 | Corresponding medium post can be found [here](https://towardsdatascience.com/keyword-extraction-with-bert-724efca412ea). 13 | 14 | 15 | ## Table of Contents 16 | 17 | 1. [About the Project](#about) 18 | 2. [Getting Started](#gettingstarted) 19 | 2.1. [Installation](#installation) 20 | 2.2. [Basic Usage](#usage) 21 | 2.3. [Maximal Marginal Relevance](#maximal) 22 | 2.4. [Embedding Models](#embeddings) 23 | 24 | 25 | 26 | 27 | ## 1. About the Project 28 | [Back to ToC](#toc) 29 | 30 | Although there are already many methods available for keyword generation 31 | (e.g., 32 | [Rake](https://github.com/aneesha/RAKE), 33 | [YAKE!](https://github.com/LIAAD/yake), TF-IDF, etc.) 34 | I wanted to create a very basic, but powerful method for extracting keywords and keyphrases. 35 | This is where **KeyBERT** comes in! Which uses BERT-embeddings and simple cosine similarity 36 | to find the sub-phrases in a document that are the most similar to the document itself. 37 | 38 | First, document embeddings are extracted with BERT to get a document-level representation. 39 | Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity 40 | to find the words/phrases that are the most similar to the document. The most similar words could 41 | then be identified as the words that best describe the entire document. 42 | 43 | KeyBERT is by no means unique and is created as a quick and easy method 44 | for creating keywords and keyphrases. Although there are many great 45 | papers and solutions out there that use BERT-embeddings 46 | (e.g., 47 | [1](https://github.com/pranav-ust/BERT-keyphrase-extraction), 48 | [2](https://github.com/ibatra/BERT-Keyword-Extractor), 49 | [3](https://www.preprints.org/manuscript/201908.0073/download/final_file), 50 | ), I could not find a BERT-based solution that did not have to be trained from scratch and 51 | could be used for beginners (**correct me if I'm wrong!**). 52 | Thus, the goal was a `pip install keybert` and at most 3 lines of code in usage. 53 | 54 | 55 | ## 2. Getting Started 56 | [Back to ToC](#toc) 57 | 58 | 59 | ### 2.1. Installation 60 | 61 | ``` 62 | git clone https://github.com/deepdialog/ZhKeyBERT 63 | cd ZhKeyBERT 64 | python setup.py install --user 65 | ``` 66 | 67 | 68 | ### 2.2. 
Usage 69 | 70 | The most minimal example can be seen below for the extraction of keywords: 71 | ```python 72 | from zhkeybert import KeyBERT, extract_kws_zh 73 | 74 | docs = """时值10月25日抗美援朝纪念日,《长津湖》片方发布了“纪念中国人民志愿军抗美援朝出国作战71周年特别短片”,再次向伟大的志愿军致敬! 75 | 电影《长津湖》全情全景地还原了71年前抗美援朝战场上那场史诗战役,志愿军奋不顾身的英勇精神令观众感叹:“岁月峥嵘英雄不灭,丹心铁骨军魂永存!”影片上映以来票房屡创新高,目前突破53亿元,暂列中国影史票房总榜第三名。 76 | 值得一提的是,这部影片的很多主创或有军人的血脉,或有当兵的经历,或者家人是军人。提起这些他们也充满自豪,影片总监制黄建新称:“当兵以后会有一种特别能坚持的劲儿。”饰演雷公的胡军透露:“我父亲曾经参加过抗美援朝,还得了一个三等功。”影片历史顾问王树增表示:“我当了五十多年的兵,我的老部队就是上甘岭上下来的,那些老兵都是我的偶像。” 77 | “身先士卒卫华夏家国,血战无畏护山河无恙。”片中饰演七连连长伍千里的吴京感叹:“要永远记住这些先烈们,他们给我们带来今天的和平。感谢他们的付出,才让我们有今天的幸福生活。”饰演新兵伍万里的易烊千玺表示:“战争的残酷、碾压式的伤害,其实我们现在的年轻人几乎很难能体会到,希望大家看完电影后能明白,是那些先辈们的牺牲奉献,换来了我们的现在。” 78 | 影片对战争群像的恢弘呈现,对个体命运的深切关怀,令许多观众无法控制自己的眼泪,观众称:“当看到影片中的惊险战斗场面,看到英雄们壮怀激烈的拼杀,为国捐躯的英勇无畏和无悔付出,我明白了为什么说今天的幸福生活来之不易。”(记者 王金跃)""" 79 | kw_model = KeyBERT(model='paraphrase-multilingual-MiniLM-L12-v2') 80 | extract_kws_zh(docs, kw_model) 81 | ``` 82 | 83 | Comparison 84 | ```python 85 | >>> extract_kws_zh(docs, kw_model) 86 | 87 | [('纪念中国人民志愿军抗美援朝', 0.7034), 88 | ('电影长津湖', 0.6285), 89 | ('周年特别短片', 0.5668), 90 | ('纪念中国人民志愿军', 0.6894), 91 | ('作战71周年', 0.5637)] 92 | >>> import jieba; kw_model.extract_keywords(' '.join(jieba.cut(docs)), keyphrase_ngram_range=(1, 3), 93 | use_mmr=True, diversity=0.25) 94 | 95 | [('抗美援朝 纪念日 长津湖', 0.796), 96 | ('纪念 中国人民志愿军 抗美援朝', 0.7577), 97 | ('作战 71 周年', 0.6126), 98 | ('25 抗美援朝 纪念日', 0.635), 99 | ('致敬 电影 长津湖', 0.6514)] 100 | ``` 101 | 102 | You can set `ngram_range`, whose default value is `(1, 3)`, 103 | to set the length of the resulting keywords/keyphrases: 104 | 105 | ```python 106 | >>> extract_kws_zh(docs, kw_model, ngram_range=(1, 1)) 107 | [('中国人民志愿军', 0.6094), 108 | ('长津湖', 0.505), 109 | ('周年', 0.4504), 110 | ('影片', 0.447), 111 | ('英雄', 0.4297)] 112 | ``` 113 | 114 | ```python 115 | >>> extract_kws_zh(docs, kw_model, ngram_range=(1, 2)) 116 | [('纪念中国人民志愿军', 0.6894), 117 | ('电影长津湖', 0.6285), 118 | ('年前抗美援朝', 0.476), 119 | ('中国人民志愿军抗美援朝', 0.6349), 120 | ('中国影史', 0.5534)] 121 | ``` 122 | 123 | **NOTE**: For a full overview of all possible transformer models see [sentence-transformer](https://www.sbert.net/docs/pretrained_models.html). 124 | I would advise `"paraphrase-multilingual-MiniLM-L12-v2"` Chinese documents for efficiency 125 | and acceptable accuracy. 126 | 127 | 128 | ### 2.3. Maximal Marginal Relevance 129 | 130 | It's recommended to use Maximal Margin Relevance (MMR) for diversity by 131 | setting the optional parameter `use_mmr`, which is `True` in default. 132 | To diversify the results, we can use MMR to create 133 | keywords / keyphrases which is also based on cosine similarity. The results 134 | with **high diversity**: 135 | 136 | ```python 137 | >>> extract_kws_zh(docs, kw_model, use_mmr = True, diversity=0.7) 138 | [('纪念中国人民志愿军抗美援朝', 0.7034), 139 | ('观众无法控制自己', 0.1212), 140 | ('山河无恙', 0.2233), 141 | ('影片上映以来', 0.5427), 142 | ('53亿元', 0.3287)] 143 | ``` 144 | 145 | The results with **low diversity**: 146 | 147 | ```python 148 | >>> extract_kws_zh(docs, kw_model, use_mmr = True, diversity=0.2) 149 | [('纪念中国人民志愿军抗美援朝', 0.7034), 150 | ('电影长津湖', 0.6285), 151 | ('纪念中国人民志愿军', 0.6894), 152 | ('周年特别短片', 0.5668), 153 | ('作战71周年', 0.5637)] 154 | ``` 155 | 156 | And the default and recommended `diversity` is `0.25`. 157 | 158 | 159 | ### 2.4. 
Embedding Models 160 | KeyBERT supports many embedding models that can be used to embed the documents and words: 161 | 162 | * Sentence-Transformers 163 | * Flair 164 | * Spacy 165 | * Gensim 166 | * USE 167 | 168 | Click [here](https://maartengr.github.io/KeyBERT/guides/embeddings.html) for a full overview of all supported embedding models. 169 | 170 | **Sentence-Transformers** 171 | You can select any model from `sentence-transformers` [here](https://www.sbert.net/docs/pretrained_models.html) 172 | and pass it through KeyBERT with `model`: 173 | 174 | ```python 175 | from zhkeybert import KeyBERT 176 | kw_model = KeyBERT(model='all-MiniLM-L6-v2') 177 | ``` 178 | 179 | Or select a SentenceTransformer model with your own parameters: 180 | 181 | ```python 182 | from zhkeybert import KeyBERT 183 | from sentence_transformers import SentenceTransformer 184 | 185 | sentence_model = SentenceTransformer("all-MiniLM-L6-v2") 186 | kw_model = KeyBERT(model=sentence_model) 187 | ``` 188 | 189 | For Chinese keywords extraction, you should choose multilingual models 190 | like `paraphrase-multilingual-mpnet-base-v2` and `paraphrase-multilingual-MiniLM-L12-v2`. 191 | 192 | **MUSE** 193 | Multilingual Universal Sentence Encoder([MUSE](https://arxiv.org/abs/1907.04307)) 194 | 195 | ```python 196 | from zhkeybert import KeyBERT 197 | import tensorflow_hub import hub 198 | 199 | module_url = 'https://hub.tensorflow.google.cn/google/universal-sentence-encoder-multilingual-large/3' 200 | 201 | model = hub.load(module_url) 202 | kw_model = KeyBERT(model=model) ## slow but acceptable performance 203 | ``` 204 | 205 | ## Citation 206 | To cite KeyBERT in your work, please use the following bibtex reference: 207 | 208 | ```bibtex 209 | @misc{grootendorst2020keybert, 210 | author = {Maarten Grootendorst}, 211 | title = {KeyBERT: Minimal keyword extraction with BERT.}, 212 | year = 2020, 213 | publisher = {Zenodo}, 214 | version = {v0.3.0}, 215 | doi = {10.5281/zenodo.4461265}, 216 | url = {https://doi.org/10.5281/zenodo.4461265} 217 | } 218 | ``` 219 | 220 | ## References 221 | Below, you can find several resources that were used for the creation of KeyBERT 222 | but most importantly, these are amazing resources for creating impressive keyword extraction models: 223 | 224 | **Papers**: 225 | * Sharma, P., & Li, Y. (2019). [Self-Supervised Contextual Keyword and Keyphrase Retrieval with Self-Labelling.](https://www.preprints.org/manuscript/201908.0073/download/final_file) 226 | 227 | **Github Repos**: 228 | * https://github.com/thunlp/BERT-KPE 229 | * https://github.com/ibatra/BERT-Keyword-Extractor 230 | * https://github.com/pranav-ust/BERT-keyphrase-extraction 231 | * https://github.com/swisscom/ai-research-keyphrase-extraction 232 | 233 | **MMR**: 234 | The selection of keywords/keyphrases was modeled after: 235 | * https://github.com/swisscom/ai-research-keyphrase-extraction 236 | 237 | **NOTE**: If you find a paper or github repo that has an easy-to-use implementation 238 | of BERT-embeddings for keyword/keyphrase extraction, let me know! I'll make sure to 239 | add a reference to this repo. 
240 | 241 | -------------------------------------------------------------------------------- /docs/api/keybert.md: -------------------------------------------------------------------------------- 1 | # `KeyBERT` 2 | 3 | ::: keybert.model.KeyBERT 4 | -------------------------------------------------------------------------------- /docs/api/maxsum.md: -------------------------------------------------------------------------------- 1 | # `Max Sum Similarity` 2 | 3 | ::: keybert.maxsum.max_sum_similarity 4 | -------------------------------------------------------------------------------- /docs/api/mmr.md: -------------------------------------------------------------------------------- 1 | # `Maximal Marginal Relevance` 2 | 3 | ::: keybert.mmr.mmr 4 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ## **Version 0.5.0** 2 | *Release date: 28 September, 2021* 3 | 4 | **Highlights**: 5 | 6 | * Added Guided KeyBERT 7 | * kw_model.extract_keywords(doc, seed_keywords=seed_keywords) 8 | * Thanks to [@zolekode](https://github.com/zolekode) for the inspiration! 9 | * Use the newest all-* models from SBERT 10 | 11 | **Miscellaneous**: 12 | 13 | * Added instructions in the FAQ to extract keywords from Chinese documents 14 | 15 | ## **Version 0.4.0** 16 | *Release date: 23 June, 2021* 17 | 18 | **Highlights**: 19 | 20 | * Highlight a document's keywords with: 21 | * ```keywords = kw_model.extract_keywords(doc, highlight=True)``` 22 | * Use `paraphrase-MiniLM-L6-v2` as the default embedder which gives great results! 23 | 24 | **Miscellaneous**: 25 | 26 | * Update Flair dependencies 27 | * Added FAQ 28 | 29 | ## **Version 0.3.0** 30 | *Release date: 10 May, 2021* 31 | 32 | The two main features are **candidate keywords** 33 | and several **backends** to use instead of Flair and SentenceTransformers! 34 | 35 | **Highlights**: 36 | 37 | * Use candidate words instead of extracting those from the documents ([#25](https://github.com/MaartenGr/KeyBERT/issues/25)) 38 | * ```KeyBERT().extract_keywords(doc, candidates)``` 39 | * Spacy, Gensim, USE, and Custom Backends were added (see documentation [here](https://maartengr.github.io/KeyBERT/guides/embeddings.html)) 40 | 41 | **Fixes**: 42 | 43 | * Improved imports 44 | * Fix encoding error when locally installing KeyBERT ([#30](https://github.com/MaartenGr/KeyBERT/issues/30)) 45 | 46 | **Miscellaneous**: 47 | 48 | * Improved documentation (ReadMe & MKDocs) 49 | * Add the main tutorial as a shield 50 | * Typos ([#31](https://github.com/MaartenGr/KeyBERT/pull/31), [#35](https://github.com/MaartenGr/KeyBERT/pull/35)) 51 | 52 | 53 | ## **Version 0.2.0** 54 | *Release date: 9 Feb, 2021* 55 | 56 | **Highlights**: 57 | 58 | * Add similarity scores to the output 59 | * Add Flair as a possible back-end 60 | * Update documentation + improved testing 61 | 62 | ## **Version 0.1.2* 63 | *Release date: 28 Oct, 2020* 64 | 65 | Added Max Sum Similarity as an option to diversify your results. 66 | 67 | 68 | ## **Version 0.1.0** 69 | *Release date: 27 Oct, 2020* 70 | 71 | This first release includes keyword/keyphrase extraction using BERT and simple cosine similarity. 72 | There is also an option to use Maximal Marginal Relevance to select the candidate keywords/keyphrases. 
73 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | ## **Which embedding model works best for which language?** 2 | Unfortunately, there is not a definitive list of the best models for each language, this highly depends 3 | on your data, the model, and your specific use-case. However, the default model in KeyBERT 4 | (`"all-MiniLM-L6-v2"`) works great for **English** documents. In contrast, for **multi-lingual** 5 | documents or any other language, `"paraphrase-multilingual-MiniLM-L12-v2""` has shown great performance. 6 | 7 | If you want to use a model that provides a higher quality, but takes more compute time, then I would advise using `paraphrase-mpnet-base-v2` and `paraphrase-multilingual-mpnet-base-v2` instead. 8 | 9 | 10 | ## **Should I preprocess the data?** 11 | No. By using document embeddings there is typically no need to preprocess the data as all parts of a document 12 | are important in understanding the general topic of the document. Although this holds true in 99% of cases, if you 13 | have data that contains a lot of noise, for example, HTML-tags, then it would be best to remove them. HTML-tags 14 | typically do not contribute to the meaning of a document and should therefore be removed. However, if you apply 15 | topic modeling to HTML-code to extract topics of code, then it becomes important. 16 | 17 | 18 | ## **Can I use the GPU to speed up the model?** 19 | Yes! Since KeyBERT uses embeddings as its backend, a GPU is actually prefered when using this package. 20 | Although it is possible to use it without a dedicated GPU, the inference speed will be significantly slower. 21 | 22 | ## **How can I use KeyBERT with Chinese documents?** 23 | You need to make sure you use a Tokenizer in KeyBERT that supports tokenization of Chinese. I suggest installing [`jieba`](https://github.com/fxsjy/jieba) for this: 24 | 25 | ```python 26 | from sklearn.feature_extraction.text import CountVectorizer 27 | import jieba 28 | 29 | def tokenize_zh(text): 30 | words = jieba.lcut(text) 31 | return words 32 | 33 | vectorizer = CountVectorizer(tokenizer=tokenize_zh) 34 | ``` 35 | 36 | Then, simply pass the vectorizer to your KeyBERT instance: 37 | 38 | ```python 39 | from keybert import KeyBERT 40 | 41 | kw_model = KeyBERT() 42 | keywords = kw_model.extract_keywords(doc, vectorizer=vectorizer) 43 | ``` -------------------------------------------------------------------------------- /docs/guides/embeddings.md: -------------------------------------------------------------------------------- 1 | # Embedding Models 2 | In this tutorial we will be going through the embedding models that can be used in KeyBERT. 3 | Having the option to choose embedding models allow you to leverage pre-trained embeddings that suit your use-case. 
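Whichever backend you choose, it is passed to KeyBERT through the same `model` argument; as a quick preview, a minimal sketch (the example document is purely illustrative):

```python
from keybert import KeyBERT

doc = "Supervised learning is the machine learning task of learning a function that maps an input to an output."

# Any of the backends described below can be passed via `model`
kw_model = KeyBERT(model="all-MiniLM-L6-v2")
keywords = kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 2))
```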
4 | 5 | ### **Sentence Transformers** 6 | You can select any model from sentence-transformers [here](https://www.sbert.net/docs/pretrained_models.html) 7 | and pass it through KeyBERT with `model`: 8 | 9 | ```python 10 | from keybert import KeyBERT 11 | kw_model = KeyBERT(model="all-MiniLM-L6-v2") 12 | ``` 13 | 14 | Or select a SentenceTransformer model with your own parameters: 15 | 16 | ```python 17 | from sentence_transformers import SentenceTransformer 18 | 19 | sentence_model = SentenceTransformer("all-MiniLM-L6-v2") 20 | kw_model = KeyBERT(model=sentence_model) 21 | ``` 22 | 23 | ### **Flair** 24 | [Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that 25 | is publicly available. Flair can be used as follows: 26 | 27 | ```python 28 | from flair.embeddings import TransformerDocumentEmbeddings 29 | 30 | roberta = TransformerDocumentEmbeddings('roberta-base') 31 | kw_model = KeyBERT(model=roberta) 32 | ``` 33 | 34 | You can select any 🤗 transformers model [here](https://huggingface.co/models). 35 | 36 | Moreover, you can also use Flair to use word embeddings and pool them to create document embeddings. 37 | Under the hood, Flair simply averages all word embeddings in a document. Then, we can easily 38 | pass it to KeyBERT in order to use those word embeddings as document embeddings: 39 | 40 | ```python 41 | from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings 42 | 43 | glove_embedding = WordEmbeddings('crawl') 44 | document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding]) 45 | 46 | kw_model = KeyBERT(model=document_glove_embeddings) 47 | ``` 48 | 49 | ### **Spacy** 50 | [Spacy](https://github.com/explosion/spaCy) is an amazing framework for processing text. There are 51 | many models available across many languages for modeling text. 52 | 53 | allows you to choose almost any embedding model that 54 | is publicly available. Flair can be used as follows: 55 | 56 | To use Spacy's non-transformer models in KeyBERT: 57 | 58 | ```python 59 | import spacy 60 | 61 | nlp = spacy.load("en_core_web_md", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 62 | 63 | kw_model = KeyBERT(model=nlp) 64 | ``` 65 | 66 | Using spacy-transformer models: 67 | 68 | ```python 69 | import spacy 70 | 71 | spacy.prefer_gpu() 72 | nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 73 | 74 | kw_model = KeyBERT(model=nlp) 75 | ``` 76 | 77 | If you run into memory issues with spacy-transformer models, try: 78 | 79 | ```python 80 | import spacy 81 | from thinc.api import set_gpu_allocator, require_gpu 82 | 83 | nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 84 | set_gpu_allocator("pytorch") 85 | require_gpu(0) 86 | 87 | kw_model = KeyBERT(model=nlp) 88 | ``` 89 | 90 | ### **Universal Sentence Encoder (USE)** 91 | The Universal Sentence Encoder encodes text into high dimensional vectors that are used here 92 | for embedding the documents. The model is trained and optimized for greater-than-word length text, 93 | such as sentences, phrases or short paragraphs. 94 | 95 | Using USE in KeyBERT is rather straightforward: 96 | 97 | ```python 98 | import tensorflow_hub 99 | embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-sentence-encoder/4") 100 | kw_model = KeyBERT(model=embedding_model) 101 | ``` 102 | 103 | ### **Gensim** 104 | For Gensim, KeyBERT supports its `gensim.downloader` module. 
Here, we can download any model word embedding model 105 | to be used in KeyBERT. Note that Gensim is primarily used for Word Embedding models. This works typically 106 | best for short documents since the word embeddings are pooled. 107 | 108 | ```python 109 | import gensim.downloader as api 110 | ft = api.load('fasttext-wiki-news-subwords-300') 111 | kw_model = KeyBERT(model=ft) 112 | ``` 113 | 114 | ### **Custom Backend** 115 | If your backend or model cannot be found in the ones currently available, you can use the `keybert.backend.BaseEmbedder` class to 116 | create your own backend. Below, you will find an example of creating a SentenceTransformer backend for KeyBERT: 117 | 118 | ```python 119 | from keybert.backend import BaseEmbedder 120 | from sentence_transformers import SentenceTransformer 121 | 122 | class CustomEmbedder(BaseEmbedder): 123 | def __init__(self, embedding_model): 124 | super().__init__() 125 | self.embedding_model = embedding_model 126 | 127 | def embed(self, documents, verbose=False): 128 | embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose) 129 | return embeddings 130 | 131 | # Create custom backend 132 | distilbert = SentenceTransformer("paraphrase-MiniLM-L6-v2") 133 | custom_embedder = CustomEmbedder(embedding_model=distilbert) 134 | 135 | # Pass custom backend to keybert 136 | kw_model = KeyBERT(model=custom_embedder) 137 | ``` 138 | -------------------------------------------------------------------------------- /docs/guides/quickstart.md: -------------------------------------------------------------------------------- 1 | ## **Installation** 2 | Installation can be done using [pypi](https://pypi.org/project/keybert/): 3 | 4 | ``` 5 | pip install keybert 6 | ``` 7 | 8 | You may want to install more depending on the transformers and language backends that you will be using. The possible installations are: 9 | 10 | ``` 11 | pip install keybert[flair] 12 | pip install keybert[gensim] 13 | pip install keybert[spacy] 14 | pip install keybert[use] 15 | ``` 16 | 17 | To install all backends: 18 | 19 | ``` 20 | pip install keybert[all] 21 | ``` 22 | 23 | ## **Usage** 24 | 25 | The most minimal example can be seen below for the extraction of keywords: 26 | ```python 27 | from keybert import KeyBERT 28 | 29 | doc = """ 30 | Supervised learning is the machine learning task of learning a function that 31 | maps an input to an output based on example input-output pairs.[1] It infers a 32 | function from labeled training data consisting of a set of training examples.[2] 33 | In supervised learning, each example is a pair consisting of an input object 34 | (typically a vector) and a desired output value (also called the supervisory signal). 35 | A supervised learning algorithm analyzes the training data and produces an inferred function, 36 | which can be used for mapping new examples. An optimal scenario will allow for the 37 | algorithm to correctly determine the class labels for unseen instances. This requires 38 | the learning algorithm to generalize from the training data to unseen situations in a 39 | 'reasonable' way (see inductive bias). 
40 | """ 41 | kw_model = KeyBERT() 42 | keywords = kw_model.extract_keywords(doc) 43 | ``` 44 | 45 | You can set `keyphrase_ngram_range` to set the length of the resulting keywords/keyphrases: 46 | 47 | ```python 48 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 1), stop_words=None) 49 | [('learning', 0.4604), 50 | ('algorithm', 0.4556), 51 | ('training', 0.4487), 52 | ('class', 0.4086), 53 | ('mapping', 0.3700)] 54 | ``` 55 | 56 | To extract keyphrases, simply set `keyphrase_ngram_range` to (1, 2) or higher depending on the number 57 | of words you would like in the resulting keyphrases: 58 | 59 | ```python 60 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 2), stop_words=None) 61 | [('learning algorithm', 0.6978), 62 | ('machine learning', 0.6305), 63 | ('supervised learning', 0.5985), 64 | ('algorithm analyzes', 0.5860), 65 | ('learning function', 0.5850)] 66 | ``` 67 | 68 | We can highlight the keywords in the document by simply setting `hightlight`: 69 | 70 | ```python 71 | keywords = kw_model.extract_keywords(doc, highlight=True) 72 | ``` 73 | 74 | **NOTE**: For a full overview of all possible transformer models see [sentence-transformer](https://www.sbert.net/docs/pretrained_models.html). 75 | I would advise either `"all-MiniLM-L6-v2"` for English documents or `"paraphrase-multilingual-MiniLM-L12-v2"` 76 | for multi-lingual documents or any other language. 77 | 78 | ### Max Sum Similarity 79 | 80 | To diversify the results, we take the 2 x top_n most similar words/phrases to the document. 81 | Then, we take all top_n combinations from the 2 x top_n words and extract the combination 82 | that are the least similar to each other by cosine similarity. 83 | 84 | ```python 85 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(3, 3), stop_words='english', 86 | use_maxsum=True, nr_candidates=20, top_n=5) 87 | [('set training examples', 0.7504), 88 | ('generalize training data', 0.7727), 89 | ('requires learning algorithm', 0.5050), 90 | ('supervised learning algorithm', 0.3779), 91 | ('learning machine learning', 0.2891)] 92 | ``` 93 | 94 | ### Maximal Marginal Relevance 95 | 96 | To diversify the results, we can use Maximal Margin Relevance (MMR) to create 97 | keywords / keyphrases which is also based on cosine similarity. The results 98 | with **high diversity**: 99 | 100 | ```python 101 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(3, 3), stop_words='english', 102 | use_mmr=True, diversity=0.7) 103 | [('algorithm generalize training', 0.7727), 104 | ('labels unseen instances', 0.1649), 105 | ('new examples optimal', 0.4185), 106 | ('determine class labels', 0.4774), 107 | ('supervised learning algorithm', 0.7502)] 108 | ``` 109 | 110 | The results with **low diversity**: 111 | 112 | ```python 113 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(3, 3), stop_words='english', 114 | use_mmr=True, diversity=0.2) 115 | [('algorithm generalize training', 0.7727), 116 | ('supervised learning algorithm', 0.7502), 117 | ('learning machine learning', 0.7577), 118 | ('learning algorithm analyzes', 0.7587), 119 | ('learning algorithm generalize', 0.7514)] 120 | ``` 121 | 122 | ### Candidate Keywords/Keyphrases 123 | In some cases, one might want to be using candidate keywords generated by other keyword algorithms or retrieved from a select list of possible keywords/keyphrases. 
In KeyBERT, you can easily use those candidate keywords to perform keyword extraction: 124 | 125 | ```python 126 | import yake 127 | from keybert import KeyBERT 128 | 129 | doc = """ 130 | Supervised learning is the machine learning task of learning a function that 131 | maps an input to an output based on example input-output pairs.[1] It infers a 132 | function from labeled training data consisting of a set of training examples.[2] 133 | In supervised learning, each example is a pair consisting of an input object 134 | (typically a vector) and a desired output value (also called the supervisory signal). 135 | A supervised learning algorithm analyzes the training data and produces an inferred function, 136 | which can be used for mapping new examples. An optimal scenario will allow for the 137 | algorithm to correctly determine the class labels for unseen instances. This requires 138 | the learning algorithm to generalize from the training data to unseen situations in a 139 | 'reasonable' way (see inductive bias). 140 | """ 141 | 142 | # Create candidates 143 | kw_extractor = yake.KeywordExtractor(top=50) 144 | candidates = kw_extractor.extract_keywords(doc) 145 | candidates = [candidate[0] for candidate in candidates] 146 | 147 | # KeyBERT init 148 | kw_model = KeyBERT() 149 | keywords = kw_model.extract_keywords(doc, candidates) 150 | ``` 151 | 152 | ### Guided KeyBERT 153 | 154 | Guided KeyBERT is similar to Guided Topic Modeling in that it tries to steer the training towards a set of seeded terms. When applying KeyBERT it automatically extracts the most related keywords to a specific document. However, there are times when stakeholders and users are looking for specific types of keywords. For example, when publishing an article on your website through contentful, you typically already know the global keywords related to the article. However, there might be a specific topic in the article that you would like to be extracted through the keywords. To achieve this, we simply give KeyBERT a set of related seeded keywords (it can also be a single one!) and search for keywords that are similar to both the document and the seeded keywords. 155 | 156 | Using this feature is as simple as defining a list of seeded keywords and passing them to KeyBERT: 157 | 158 | 159 | ```python 160 | doc = """ 161 | Supervised learning is the machine learning task of learning a function that 162 | maps an input to an output based on example input-output pairs.[1] It infers a 163 | function from labeled training data consisting of a set of training examples.[2] 164 | In supervised learning, each example is a pair consisting of an input object 165 | (typically a vector) and a desired output value (also called the supervisory signal). 166 | A supervised learning algorithm analyzes the training data and produces an inferred function, 167 | which can be used for mapping new examples. An optimal scenario will allow for the 168 | algorithm to correctly determine the class labels for unseen instances. This requires 169 | the learning algorithm to generalize from the training data to unseen situations in a 170 | 'reasonable' way (see inductive bias). 
171 | """ 172 | 173 | kw_model = KeyBERT() 174 | seed_keywords = ["information"] 175 | keywords = kw_model.extract_keywords(doc, use_mmr=True, diversity=0.1, seed_keywords=seed_keywords) 176 | ``` -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # KeyBERT 4 | 5 | KeyBERT is a minimal and easy-to-use keyword extraction technique that leverages BERT embeddings to 6 | create keywords and keyphrases that are most similar to a document. 7 | 8 | ## About the Project 9 | 10 | Although there are already many methods available for keyword generation 11 | (e.g., 12 | [Rake](https://github.com/aneesha/RAKE), 13 | [YAKE!](https://github.com/LIAAD/yake), TF-IDF, etc.) 14 | I wanted to create a very basic, but powerful method for extracting keywords and keyphrases. 15 | This is where **KeyBERT** comes in! Which uses BERT-embeddings and simple cosine similarity 16 | to find the sub-phrases in a document that are the most similar to the document itself. 17 | 18 | First, document embeddings are extracted with BERT to get a document-level representation. 19 | Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity 20 | to find the words/phrases that are the most similar to the document. The most similar words could 21 | then be identified as the words that best describe the entire document. 22 | 23 | KeyBERT is by no means unique and is created as a quick and easy method 24 | for creating keywords and keyphrases. Although there are many great 25 | papers and solutions out there that use BERT-embeddings 26 | (e.g., 27 | [1](https://github.com/pranav-ust/BERT-keyphrase-extraction), 28 | [2](https://github.com/ibatra/BERT-Keyword-Extractor), 29 | [3](https://www.preprints.org/manuscript/201908.0073/download/final_file), 30 | ), I could not find a BERT-based solution that did not have to be trained from scratch and 31 | could be used for beginners (**correct me if I'm wrong!**). 32 | Thus, the goal was a `pip install keybert` and at most 3 lines of code in usage. 33 | 34 | ## Installation 35 | Installation can be done using [pypi](https://pypi.org/project/keybert/): 36 | 37 | ``` 38 | pip install keybert 39 | ``` 40 | 41 | You may want to install more depending on the transformers and language backends that you will be using. The possible installations are: 42 | 43 | ``` 44 | pip install keybert[flair] 45 | pip install keybert[gensim] 46 | pip install keybert[spacy] 47 | pip install keybert[use] 48 | ``` 49 | 50 | To install all backends: 51 | 52 | ``` 53 | pip install keybert[all] 54 | ``` 55 | 56 | 57 | ## Usage 58 | 59 | 60 | The most minimal example can be seen below for the extraction of keywords: 61 | ```python 62 | from keybert import KeyBERT 63 | 64 | doc = """ 65 | Supervised learning is the machine learning task of learning a function that 66 | maps an input to an output based on example input-output pairs. It infers a 67 | function from labeled training data consisting of a set of training examples. 68 | In supervised learning, each example is a pair consisting of an input object 69 | (typically a vector) and a desired output value (also called the supervisory signal). 70 | A supervised learning algorithm analyzes the training data and produces an inferred function, 71 | which can be used for mapping new examples. 
An optimal scenario will allow for the 72 | algorithm to correctly determine the class labels for unseen instances. This requires 73 | the learning algorithm to generalize from the training data to unseen situations in a 74 | 'reasonable' way (see inductive bias). 75 | """ 76 | kw_model = KeyBERT() 77 | keywords = kw_model.extract_keywords(doc) 78 | ``` 79 | 80 | You can set `keyphrase_ngram_range` to set the length of the resulting keywords/keyphrases: 81 | 82 | ```python 83 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 1), stop_words=None) 84 | [('learning', 0.4604), 85 | ('algorithm', 0.4556), 86 | ('training', 0.4487), 87 | ('class', 0.4086), 88 | ('mapping', 0.3700)] 89 | ``` 90 | 91 | To extract keyphrases, simply set `keyphrase_ngram_range` to (1, 2) or higher depending on the number 92 | of words you would like in the resulting keyphrases: 93 | 94 | ```python 95 | >>> kw_model.extract_keywords(doc, keyphrase_ngram_range=(1, 2), stop_words=None) 96 | [('learning algorithm', 0.6978), 97 | ('machine learning', 0.6305), 98 | ('supervised learning', 0.5985), 99 | ('algorithm analyzes', 0.5860), 100 | ('learning function', 0.5850)] 101 | ``` 102 | -------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/docs/style.css -------------------------------------------------------------------------------- /images/highlight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/images/highlight.png -------------------------------------------------------------------------------- /images/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/images/icon.png -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/images/logo.png -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: KeyBERT 2 | extra_css: [style.css] 3 | repo_url: https://github.com/MaartenGr/keyBERT 4 | site_url: https://maartengr.github.io/keyBERT/ 5 | site_description: Leveraging BERT to extract important keywords 6 | site_author: Maarten P. Grootendorst 7 | use_directory_urls: false 8 | nav: 9 | - Home: index.md 10 | - Guides: 11 | - Quickstart: guides/quickstart.md 12 | - Embedding Models: guides/embeddings.md 13 | - API: 14 | - KeyBERT: api/keybert.md 15 | - MMR: api/mmr.md 16 | - MaxSum: api/maxsum.md 17 | - FAQ: faq.md 18 | - Changelog: changelog.md 19 | plugins: 20 | - mkdocstrings: 21 | watch: 22 | - keybert 23 | - search 24 | copyright: Copyright © 2021 Maintained by Maarten. 
25 | theme: 26 | custom_dir: images/ 27 | name: material 28 | icon: 29 | logo: material/library 30 | font: 31 | text: Ubuntu 32 | code: Ubuntu Mono 33 | favicon: icon.png 34 | logo: icon.png 35 | feature: 36 | tabs: true 37 | palette: 38 | primary: black 39 | accent: blue 40 | markdown_extensions: 41 | - pymdownx.highlight: 42 | - pymdownx.superfences: 43 | - toc: 44 | permalink: true 45 | -------------------------------------------------------------------------------- /scripts/pip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # python -m pip install -U twine keyring keyrings.alt 3 | 4 | set -e 5 | rm -rf ./dist 6 | python3 setup.py bdist_wheel 7 | twine upload dist/*.whl 8 | 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | test_packages = [ 4 | "pytest>=5.4.3", 5 | "pytest-cov>=2.6.1" 6 | ] 7 | 8 | base_packages = [ 9 | "jieba", 10 | "sentence-transformers>=0.3.8", 11 | "scikit-learn>=0.22.2", 12 | "numpy>=1.18.5", 13 | "rich>=10.4.0" 14 | ] 15 | 16 | docs_packages = [ 17 | "mkdocs>=1.1", 18 | "mkdocs-material>=4.6.3", 19 | "mkdocstrings>=0.8.0", 20 | ] 21 | 22 | flair_packages = [ 23 | "transformers==3.5.1", 24 | "torch>=1.4.0,<1.7.1", 25 | "flair==0.7" 26 | ] 27 | 28 | spacy_packages = [ 29 | "spacy>=3.0.1" 30 | ] 31 | 32 | use_packages = [ 33 | "tensorflow", 34 | "tensorflow_hub", 35 | "tensorflow_text" 36 | ] 37 | 38 | gensim_packages = [ 39 | "gensim>=3.6.0" 40 | ] 41 | 42 | extra_packages = flair_packages + spacy_packages + use_packages + gensim_packages 43 | 44 | dev_packages = docs_packages + test_packages 45 | 46 | with open("README.md", "r", encoding='utf-8') as fh: 47 | long_description = fh.read() 48 | 49 | setup( 50 | name="zhkeybert", 51 | packages=find_packages(exclude=["notebooks", "docs"]), 52 | version="0.1.2", 53 | author="Maarten Grootendorst, Yao Su", 54 | author_email="maartengrootendorst@gmail.com, 1092702101@qq.com", 55 | description="Based on KeyBERT performs Chinese documents keyword extraction with state-of-the-art transformer models.", 56 | long_description=long_description, 57 | long_description_content_type="text/markdown", 58 | url="https://github.com/deepdialog/ZhKeyBERT", 59 | keywords="nlp bert keyword extraction embeddings for Chinese", 60 | classifiers=[ 61 | "Programming Language :: Python", 62 | "Intended Audience :: Science/Research", 63 | "Intended Audience :: Developers", 64 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 65 | "License :: OSI Approved :: MIT License", 66 | "Topic :: Scientific/Engineering", 67 | "Operating System :: Microsoft :: Windows", 68 | "Operating System :: POSIX", 69 | "Operating System :: Unix", 70 | "Operating System :: MacOS", 71 | "Programming Language :: Python :: 3.7", 72 | "Programming Language :: Python :: 3.6", 73 | "Programming Language :: Python :: 3.8", 74 | ], 75 | install_requires=base_packages, 76 | extras_require={ 77 | "test": test_packages, 78 | "docs": docs_packages, 79 | "dev": dev_packages, 80 | "flair": flair_packages 81 | }, 82 | python_requires='>=3.6', 83 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/tests/__init__.py -------------------------------------------------------------------------------- /tests/example1.txt: -------------------------------------------------------------------------------- 1 | 首届中国网络文明大会将于11月19日在北京国家会议中心举办。大会主题为“汇聚向上向善力量,携手建设网络文明”,会上将发布新时代网络文明建设十件大事和共建网络文明行动倡议。大会由中央网信办、中央文明办和中共北京市委、北京市人民政府共同主办。中央网信办副主任、国家网信办副主任盛荣华在会上介绍说,大会定位于打造我国网络文明的理念宣介平台、经验交流平台、成果展示平台和国际网络文明互鉴平台。主要内容包括开幕式、主论坛和分论坛。有关部门负责同志,互联网企业和网络社会组织代表,以及专家学者和网民代表将在会上开展对话交流。据介绍,大会的举办旨在坚持以人民为中心的发展思想,大力发展积极健康的网络文化,净化网络生态、滋养网络空间,满足亿万网民对美好生活的向往。同时,通过理念宣介、经验交流、成果展示、文明互鉴等方式,全面推进文明办网、文明用网、文明上网、文明兴网,为加快建设网络强国,全面建设社会主义现代化国家,实现第二个百年奋斗目标提供坚强思想保证、强大精神动力、有力舆论支持和良好文化条件 2 | 持续激励基础研究。国家自然科学奖评选出一批原创性成果,有的聚焦基础研究,如数学研究在现代数论的前沿研究领域取得了重要突破,“具有界面效应的复合材料细观力学研究”处于国际先进水平;有的瞄准应用基础研究或民生领域的重要科学问题,如“麻风危害发生的免疫遗传学机制”研究成果加速了我国消除麻风危害的进程。2项自然科学奖一等奖全部由化学领域的研究成果摘得:中科院大连化物所包信和团队原创性提出了“纳米限域催化”新概念并成功实践,为催化过程和催化剂设计走向“精准”建立了理论基础;复旦大学赵东元团队的研究成果“有序介孔高分子和碳材料的创制和应用”在国际上率先提出有机—有机自组装软模板合成介孔材料思路,首次将功能介孔材料从无机骨架扩展到有机高分子材料。 3 | 为应对菜价上涨,近日,河南滑县多措并举保障居民蔬菜供应。那么,目前蔬菜的储备和生产形势如何?走进安阳市滑县大寨乡东冯营村的香菜种植田,这里的400余亩香菜正陆续进行采收,每天采收量可达2万斤。随着采收季的到来,此前上涨的价格也有明显回落。河南安阳市滑县大寨乡东冯营村香菜种植户武子威:最好的时候11每斤,现在行情有所回落,今天走的是4.8一斤。在滑县八里营镇刘苑村蔬菜种植基地,近200个黄瓜大棚陆续开棚。在村口的交易市场内,前来选购黄瓜的客商络绎不绝,打包装车后黄瓜将销往河北、湖北、上海等地。今年,滑县瓜菜种植面积达58万余亩,产量约231万余吨。记者了解到,目前,云南、海南等“南菜北运”的七个大省加上山东、河北等六个北方设施蔬菜大省,在田蔬菜面积达到9130万亩,同比增加350万亩。随着这些地区陆续进入蔬菜采收旺季,11月第2周,全国蔬菜均价每公斤5.79元,比较10月份下降2.3%,其中菠菜、生菜等蔬菜价格下降明显。 4 | 孙中山先生说过:“‘统一’是中国全体国民的希望。能够统一,全国人民便享福;不能统一,便要受害。”完成祖国统一是历史交与中国共产党的任务,是实现中华民族伟大复兴的必然要求。在讲话中,习近平总书记重点谈到祖国统一问题。总书记坚定表示,台湾问题因民族弱乱而产生,必将随着民族复兴而解决。他呼吁,两岸同胞都要站在历史正确的一边,共同创造祖国完全统一、民族伟大复兴的光荣伟业。同时,习近平总书记郑重宣示了坚持一个中国原则和“九二共识”、坚决遏制“台独”分裂活动、捍卫国家主权和领土完整的坚强决心与坚定意志。对于干涉台湾问题的外部势力,习近平总书记表示,任何人都不要低估中国人民捍卫国家主权和领土完整的坚强决心、坚定意志、强大能力!这指出了祖国完全统一的光明前景,也昭示了中华民族复兴伟业的光明前景。 5 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from .utils import get_test_data 3 | from sklearn.feature_extraction.text import CountVectorizer 4 | from keybert import KeyBERT 5 | 6 | doc_one, doc_two = get_test_data() 7 | model = KeyBERT(model='all-MiniLM-L6-v2') 8 | 9 | 10 | @pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)]) 11 | @pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]) 12 | def test_single_doc(keyphrase_length, vectorizer): 13 | """ Test whether the keywords are correctly extracted """ 14 | top_n = 5 15 | 16 | keywords = model.extract_keywords(doc_one, 17 | keyphrase_ngram_range=keyphrase_length, 18 | min_df=1, 19 | top_n=top_n, 20 | vectorizer=vectorizer) 21 | 22 | assert isinstance(keywords, list) 23 | assert isinstance(keywords[0], tuple) 24 | assert isinstance(keywords[0][0], str) 25 | assert isinstance(keywords[0][1], float) 26 | assert len(keywords) == top_n 27 | for keyword in keywords: 28 | assert len(keyword[0].split(" ")) <= keyphrase_length[1] 29 | 30 | 31 | @pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [((1, i+1), truth, not truth) 32 | for i in range(4) 33 | for truth in [True, False]]) 34 | @pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]) 35 | def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer): 36 | """ Test extraction of protected single document method """ 37 | top_n = 5 38 | keywords = 
model._extract_keywords_single_doc(doc_one, 39 | top_n=top_n, 40 | keyphrase_ngram_range=keyphrase_length, 41 | use_mmr=mmr, 42 | use_maxsum=maxsum, 43 | diversity=0.5, 44 | vectorizer=vectorizer) 45 | assert isinstance(keywords, list) 46 | assert isinstance(keywords[0][0], str) 47 | assert isinstance(keywords[0][1], float) 48 | assert len(keywords) == top_n 49 | for keyword in keywords: 50 | assert len(keyword[0].split(" ")) <= keyphrase_length[1] 51 | 52 | 53 | @pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)]) 54 | def test_extract_keywords_multiple_docs(keyphrase_length): 55 | """ Test extractino of protected multiple document method""" 56 | top_n = 5 57 | keywords_list = model._extract_keywords_multiple_docs([doc_one, doc_two], 58 | top_n=top_n, 59 | keyphrase_ngram_range=keyphrase_length) 60 | assert isinstance(keywords_list, list) 61 | assert isinstance(keywords_list[0], list) 62 | assert len(keywords_list) == 2 63 | 64 | for keywords in keywords_list: 65 | assert len(keywords) == top_n 66 | 67 | for keyword in keywords: 68 | assert len(keyword[0].split(" ")) <= keyphrase_length[1] 69 | 70 | 71 | def test_guided(): 72 | """ Test whether the keywords are correctly extracted """ 73 | top_n = 5 74 | seed_keywords = ["time", "night", "day", "moment"] 75 | keywords = model.extract_keywords(doc_one, 76 | min_df=1, 77 | top_n=top_n, 78 | seed_keywords=seed_keywords) 79 | 80 | assert isinstance(keywords, list) 81 | assert isinstance(keywords[0], tuple) 82 | assert isinstance(keywords[0][0], str) 83 | assert isinstance(keywords[0][1], float) 84 | assert len(keywords) == top_n 85 | 86 | 87 | def test_error(): 88 | """ Empty doc should raise a ValueError """ 89 | with pytest.raises(AttributeError): 90 | doc = [] 91 | model._extract_keywords_single_doc(doc) 92 | 93 | 94 | def test_empty_doc(): 95 | """ Test empty document """ 96 | doc = "" 97 | result = model._extract_keywords_single_doc(doc) 98 | 99 | assert result == [] 100 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | def get_test_data(): 2 | doc_one = "\n\nI am sure some bashers of Pens fans are pretty confused about the lack\nof " \ 3 | "any kind of posts about the recent Pens massacre of the Devils. Actually,\nI am " \ 4 | "bit puzzled too and a bit relieved. However, I am going to put an end\nto non-PIttsburghers' " \ 5 | "relief with a bit of praise for the Pens. Man, they\nare killing those Devils worse than I thought. " \ 6 | "Jagr just showed you why\nhe is much better than his regular season stats. " \ 7 | "He is also a lot\nfo fun to watch in the playoffs. Bowman should let JAgr have " \ 8 | "a lot of\nfun in the next couple of games since the Pens are going to beat the " \ 9 | "pulp out of Jersey anyway. I was very disappointed not to see the Islanders lose " \ 10 | "the final\nregular season game. PENS RULE!!!\n\n" 11 | 12 | doc_two = "\n[stuff deleted]\n\nOk, here's the solution to your problem. " \ 13 | "Move to Canada. Yesterday I was able\nto watch FOUR games...the NJ-PITT " \ 14 | "at 1:00 on ABC, LA-CAL at 3:00 (CBC), \nBUFF-BOS at 7:00 (TSN and FOX), " \ 15 | "and MON-QUE at 7:30 (CBC). 
I think that if\neach series goes its max I " \ 16 | "could be watching hockey playoffs for 40-some odd\nconsecutive nights " \ 17 | "(I haven't counted so that's a pure guess).\n\nI have two tv's in my house, " \ 18 | "and I set them up side-by-side to watch MON-QUE\nand keep an eye on " \ 19 | "BOS-BUFF at the same time. I did the same for the two\nafternoon games." \ 20 | "\n\nBtw, those ABC commentaters were great! I was quite impressed; they " \ 21 | "seemed\nto know that their audience wasn't likely to be well-schooled in " \ 22 | "hockey lore\nand they did an excellent job. They were quite impartial also, IMO.\n\n" 23 | 24 | return doc_one, doc_two 25 | -------------------------------------------------------------------------------- /theme/style.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepdialog/ZhKeyBERT/15926f4418323a5d2983edaff784edb30a4f3068/theme/style.css -------------------------------------------------------------------------------- /zhkeybert/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import KeyBERT 2 | from ._extract_kws import extract_kws_zh 3 | 4 | __version__ = "0.1.2" 5 | -------------------------------------------------------------------------------- /zhkeybert/_extract_kws.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | import re 3 | import jieba 4 | 5 | from ._model import KeyBERT 6 | 7 | 8 | def gen_candidates_zh(docs: str, ngram_range: Tuple[int, int]) -> List[str]: 9 | """split the Chinese document into keyword candidates 10 | 11 | Args: 12 | docs (str): the Chinese document 13 | ngram_range (Tuple[int, int]): Length, in words, of the extracted keywords/keyphrases 14 | 15 | Returns: 16 | List[str]: keyword candidates 17 | """ 18 | sdocs = re.split(r'[。!;?,,.?::、“”]', docs) 19 | res = set() 20 | for sdoc in sdocs: 21 | res 22 | cdoc = list(jieba.cut(re.sub('\W*', '', sdoc))) 23 | for i in range(ngram_range[0], ngram_range[1] + 1): 24 | for j in range(i, len(cdoc) + 1): 25 | res.add(''.join(cdoc[j-i:j])) 26 | return list(res) 27 | 28 | 29 | def extract_kws_zh(docs: str, model: KeyBERT, 30 | ngram_range: Tuple[int, int] = (1, 3), 31 | top_n: int = 5, 32 | use_mmr: bool = True, 33 | diversity: float = 0.25,) -> Union[List[Tuple[str, float]], 34 | List[List[Tuple[str, float]]]]: 35 | """extract keywords from Chinese document 36 | 37 | Args: 38 | docs (str): the Chinese document 39 | model (keybert.KeyBERT): the KeyBERT model to do extraction 40 | ngram_range (Tuple[int, int], optional): Length, in words, of the extracted 41 | keywords/keyphrases. Defaults to (1, 3). 42 | top_n (int, optional): extract n keywords. Defaults to 5. 43 | use_mmr (bool, optional): Whether to use MMR. Defaults to True. 44 | diversity (float, optional): The diversity of results between 0 and 1 45 | if use_mmr is True. Defaults to 0.25. 
46 | Returns: 47 | Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]: the top n keywords for a document 48 | """ 49 | 50 | candi = gen_candidates_zh(docs, ngram_range) 51 | return model.extract_keywords(docs, candi, 52 | stop_words=None, 53 | top_n=top_n, 54 | use_mmr=use_mmr, 55 | diversity=diversity) 56 | -------------------------------------------------------------------------------- /zhkeybert/_highlight.py: -------------------------------------------------------------------------------- 1 | import re 2 | from rich.console import Console 3 | from rich.highlighter import RegexHighlighter 4 | from typing import Tuple, List 5 | 6 | 7 | class NullHighlighter(RegexHighlighter): 8 | """Apply style to anything that looks like an email.""" 9 | 10 | base_style = "" 11 | highlights = [r""] 12 | 13 | 14 | def highlight_document(doc: str, 15 | keywords: List[Tuple[str, float]]): 16 | """ Highlight keywords in a document 17 | 18 | Arguments: 19 | doc: The document for which to extract keywords/keyphrases 20 | keywords: the top n keywords for a document with their respective distances 21 | to the input document 22 | 23 | Returns: 24 | highlighted_text: The document with additional tags to highlight keywords 25 | according to the rich package 26 | """ 27 | keywords_only = [keyword for keyword, _ in keywords] 28 | max_len = max([len(token.split(" ")) for token in keywords_only]) 29 | 30 | if max_len == 1: 31 | highlighted_text = _highlight_one_gram(doc, keywords_only) 32 | else: 33 | highlighted_text = _highlight_n_gram(doc, keywords_only) 34 | 35 | console = Console(highlighter=NullHighlighter()) 36 | console.print(highlighted_text) 37 | 38 | 39 | def _highlight_one_gram(doc: str, 40 | keywords: List[str]) -> str: 41 | """ Highlight 1-gram keywords in a document 42 | 43 | Arguments: 44 | doc: The document for which to extract keywords/keyphrases 45 | keywords: the top n keywords for a document 46 | 47 | Returns: 48 | highlighted_text: The document with additional tags to highlight keywords 49 | according to the rich package 50 | """ 51 | tokens = re.sub(r' +', ' ', doc.replace("\n", " ")).split(" ") 52 | 53 | highlighted_text = " ".join([f"[black on #FFFF00]{token}[/]" 54 | if token.lower() in keywords 55 | else f"{token}" 56 | for token in tokens]).strip() 57 | return highlighted_text 58 | 59 | 60 | def _highlight_n_gram(doc: str, 61 | keywords: List[str]) -> str: 62 | """ Highlight n-gram keywords in a document 63 | 64 | Arguments: 65 | doc: The document for which to extract keywords/keyphrases 66 | keywords: the top n keywords for a document 67 | 68 | Returns: 69 | highlighted_text: The document with additional tags to highlight keywords 70 | according to the rich package 71 | """ 72 | max_len = max([len(token.split(" ")) for token in keywords]) 73 | tokens = re.sub(r' +', ' ', doc.replace("\n", " ")).strip().split(" ") 74 | n_gram_tokens = [[" ".join(tokens[i: i + max_len][0: j + 1]) for j in range(max_len)] for i, _ in enumerate(tokens)] 75 | highlighted_text = [] 76 | skip = False 77 | 78 | for n_grams in n_gram_tokens: 79 | candidate = False 80 | 81 | if not skip: 82 | for index, n_gram in enumerate(n_grams): 83 | 84 | if n_gram.lower() in keywords: 85 | candidate = f"[black on #FFFF00]{n_gram}[/]" + n_grams[-1].split(n_gram)[-1] 86 | skip = index + 1 87 | 88 | if not candidate: 89 | candidate = n_grams[0] 90 | 91 | highlighted_text.append(candidate) 92 | 93 | else: 94 | skip = skip - 1 95 | highlighted_text = " ".join(highlighted_text) 96 | return highlighted_text 97 | 
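# Illustrative note (not part of the original module): `highlight_document` is normally
# reached through `KeyBERT.extract_keywords(doc, highlight=True)`; since the helpers above
# split on spaces, highlighting is only meaningful for space-delimited (non-Chinese) text.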
--------------------------------------------------------------------------------
/zhkeybert/_maxsum.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | from sklearn.metrics.pairwise import cosine_similarity
4 | from typing import List, Tuple
5 | 
6 | 
7 | def max_sum_similarity(doc_embedding: np.ndarray,
8 |                        word_embeddings: np.ndarray,
9 |                        words: List[str],
10 |                        top_n: int,
11 |                        nr_candidates: int) -> List[Tuple[str, float]]:
12 |     """ Calculate Max Sum Distance for extraction of keywords
13 | 
14 |     We take the `nr_candidates` words/phrases most similar to the document.
15 |     Then, we consider all top_n combinations of those candidates and
16 |     extract the combination whose members are the least similar to each other
17 |     by cosine similarity.
18 | 
19 |     NOTE:
20 |         This enumerates every top_n combination of the candidates, so it is not advised for large top_n
21 | 
22 |     Arguments:
23 |         doc_embedding: The document embeddings
24 |         word_embeddings: The embeddings of the selected candidate keywords/phrases
25 |         words: The selected candidate keywords/keyphrases
26 |         top_n: The number of keywords/keyphrases to return
27 |         nr_candidates: The number of candidates to consider
28 | 
29 |     Returns:
30 |          List[Tuple[str, float]]: The selected keywords/keyphrases with their distances
31 |     """
32 |     if nr_candidates < top_n:
33 |         raise Exception("Make sure that the number of candidates is at least the number "
34 |                         "of keywords to return.")
35 | 
36 |     # Calculate distances and extract keywords
37 |     distances = cosine_similarity(doc_embedding, word_embeddings)
38 |     distances_words = cosine_similarity(word_embeddings, word_embeddings)
39 | 
40 |     # Get the nr_candidates words most similar to the document as candidates
41 |     words_idx = list(distances.argsort()[0][-nr_candidates:])
42 |     words_vals = [words[index] for index in words_idx]
43 |     candidates = distances_words[np.ix_(words_idx, words_idx)]
44 | 
45 |     # Calculate the combination of words that are the least similar to each other
46 |     min_sim = 100_000
47 |     candidate = None
48 |     for combination in itertools.combinations(range(len(words_idx)), top_n):
49 |         sim = sum([candidates[i][j] for i in combination for j in combination if i != j])
50 |         if sim < min_sim:
51 |             candidate = combination
52 |             min_sim = sim
53 | 
54 |     return [(words_vals[idx], round(float(distances[0][idx]), 4)) for idx in candidate]
55 | 
--------------------------------------------------------------------------------
/zhkeybert/_mmr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics.pairwise import cosine_similarity
3 | from typing import List, Tuple
4 | 
5 | 
6 | def mmr(doc_embedding: np.ndarray,
7 |         word_embeddings: np.ndarray,
8 |         words: List[str],
9 |         top_n: int = 5,
10 |         diversity: float = 0.8) -> List[Tuple[str, float]]:
11 |     """ Calculate Maximal Marginal Relevance (MMR)
12 |     between candidate keywords and the document.
13 | 
14 | 
15 |     MMR considers the similarity of keywords/keyphrases with the
16 |     document, along with the similarity of already selected
17 |     keywords and keyphrases. This results in a selection of keywords
18 |     that are relevant to the document while remaining mutually diverse. 
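    Concretely (this merely restates the computation in the loop below): at each step
    the candidate c maximising (1 - diversity) * sim(c, doc) - diversity * max_{s in selected} sim(c, s)
    is added to the selection.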
19 | 
20 |     Arguments:
21 |         doc_embedding: The document embeddings
22 |         word_embeddings: The embeddings of the selected candidate keywords/phrases
23 |         words: The selected candidate keywords/keyphrases
24 |         top_n: The number of keywords/keyphrases to return
25 |         diversity: How diverse the selected keywords/keyphrases are.
26 |                    Values between 0 and 1 with 0 being not diverse at all
27 |                    and 1 being most diverse.
28 | 
29 |     Returns:
30 |          List[Tuple[str, float]]: The selected keywords/keyphrases with their distances
31 | 
32 |     """
33 | 
34 |     # Extract similarity within words, and between words and the document
35 |     word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
36 |     word_similarity = cosine_similarity(word_embeddings)
37 | 
38 |     # Initialize candidates and already choose the best keyword/keyphrase
39 |     keywords_idx = [np.argmax(word_doc_similarity)]
40 |     candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]
41 | 
42 |     for _ in range(top_n - 1):
43 |         # Extract similarities within candidates and
44 |         # between candidates and selected keywords/phrases
45 |         candidate_similarities = word_doc_similarity[candidates_idx, :]
46 |         target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
47 | 
48 |         # Calculate MMR
49 |         mmr = (1-diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
50 |         mmr_idx = candidates_idx[np.argmax(mmr)]
51 | 
52 |         # Update keywords & candidates
53 |         keywords_idx.append(mmr_idx)
54 |         candidates_idx.remove(mmr_idx)
55 | 
56 |     return [(words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4)) for idx in keywords_idx]
57 | 
58 | 
--------------------------------------------------------------------------------
/zhkeybert/_model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore", category=FutureWarning)
3 | 
4 | import numpy as np
5 | from tqdm import tqdm
6 | from typing import List, Union, Tuple
7 | from sklearn.metrics.pairwise import cosine_similarity
8 | from sklearn.feature_extraction.text import CountVectorizer
9 | 
10 | # KeyBERT
11 | from ._mmr import mmr
12 | from ._maxsum import max_sum_similarity
13 | from ._highlight import highlight_document
14 | from .backend._utils import select_backend
15 | 
16 | 
17 | class KeyBERT:
18 |     """
19 |     A minimal method for keyword extraction with BERT
20 | 
21 |     The keyword extraction is done by finding the sub-phrases in
22 |     a document that are the most similar to the document itself.
23 | 
24 |     First, document embeddings are extracted with BERT to get a
25 |     document-level representation. Then, word embeddings are extracted
26 |     for N-gram words/phrases. Finally, we use cosine similarity to find the
27 |     words/phrases that are the most similar to the document.
28 | 
29 |     The most similar words could then be identified as the words that
30 |     best describe the entire document.
31 | 
32 |     """
33 |     def __init__(self,
34 |                  model="paraphrase-multilingual-MiniLM-L12-v2"):
35 |         """ KeyBERT initialization
36 | 
37 |         Arguments:
38 |             model: Use a custom embedding model. 
39 | The following backends are currently supported 40 | * SentenceTransformers 41 | * Flair 42 | * Spacy 43 | * Gensim 44 | * USE (TF-Hub) 45 | You can also pass in a string that points to one of the following 46 | sentence-transformers models: 47 | * https://www.sbert.net/docs/pretrained_models.html 48 | """ 49 | self.model = select_backend(model) 50 | 51 | def extract_keywords(self, 52 | docs: Union[str, List[str]], 53 | candidates: List[str] = None, 54 | keyphrase_ngram_range: Tuple[int, int] = (1, 1), 55 | stop_words: Union[str, List[str]] = 'english', 56 | top_n: int = 5, 57 | min_df: int = 1, 58 | use_maxsum: bool = False, 59 | use_mmr: bool = False, 60 | diversity: float = 0.5, 61 | nr_candidates: int = 20, 62 | vectorizer: CountVectorizer = None, 63 | highlight: bool = False, 64 | seed_keywords: List[str] = None) -> Union[List[Tuple[str, float]], 65 | List[List[Tuple[str, float]]]]: 66 | """ Extract keywords/keyphrases 67 | 68 | NOTE: 69 | I would advise you to iterate over single documents as they 70 | will need the least amount of memory. Even though this is slower, 71 | you are not likely to run into memory errors. 72 | 73 | Multiple Documents: 74 | There is an option to extract keywords for multiple documents 75 | that is faster than extraction for multiple single documents. 76 | 77 | However...this method assumes that you can keep the word embeddings 78 | for all words in the vocabulary in memory which might be troublesome. 79 | 80 | I would advise against using this option and simply iterating 81 | over documents instead if you have limited hardware. 82 | 83 | Arguments: 84 | docs: The document(s) for which to extract keywords/keyphrases 85 | candidates: Candidate keywords/keyphrases to use instead of extracting them from the document(s) 86 | keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases 87 | stop_words: Stopwords to remove from the document 88 | top_n: Return the top n keywords/keyphrases 89 | min_df: Minimum document frequency of a word across all documents 90 | if keywords for multiple documents need to be extracted 91 | use_maxsum: Whether to use Max Sum Similarity for the selection 92 | of keywords/keyphrases 93 | use_mmr: Whether to use Maximal Marginal Relevance (MMR) for the 94 | selection of keywords/keyphrases 95 | diversity: The diversity of the results between 0 and 1 if use_mmr 96 | is set to True 97 | nr_candidates: The number of candidates to consider if use_maxsum is 98 | set to True 99 | vectorizer: Pass in your own CountVectorizer from scikit-learn 100 | highlight: Whether to print the document and highlight 101 | its keywords/keyphrases. NOTE: This does not work if 102 | multiple documents are passed. 
103 |             seed_keywords: Seed keywords that may guide the extraction of keywords by
104 |                            steering the similarities towards the seeded keywords
105 | 
106 |         Returns:
107 |             keywords: the top n keywords for a document with their respective distances
108 |                       to the input document
109 | 
110 |         """
111 | 
112 |         if isinstance(docs, str):
113 |             keywords = self._extract_keywords_single_doc(doc=docs,
114 |                                                          candidates=candidates,
115 |                                                          keyphrase_ngram_range=keyphrase_ngram_range,
116 |                                                          stop_words=stop_words,
117 |                                                          top_n=top_n,
118 |                                                          use_maxsum=use_maxsum,
119 |                                                          use_mmr=use_mmr,
120 |                                                          diversity=diversity,
121 |                                                          nr_candidates=nr_candidates,
122 |                                                          vectorizer=vectorizer,
123 |                                                          seed_keywords=seed_keywords)
124 |             if highlight:
125 |                 highlight_document(docs, keywords)
126 | 
127 |             return keywords
128 | 
129 |         elif isinstance(docs, list):
130 |             warnings.warn("Although extracting keywords for multiple documents is faster "
131 |                           "than iterating over single documents, it requires significantly more memory "
132 |                           "to hold all word embeddings. Use this at your own discretion!")
133 |             return self._extract_keywords_multiple_docs(docs,
134 |                                                         keyphrase_ngram_range,
135 |                                                         stop_words,
136 |                                                         top_n,
137 |                                                         min_df,
138 |                                                         vectorizer)
139 | 
140 |     def _extract_keywords_single_doc(self,
141 |                                      doc: str,
142 |                                      candidates: List[str] = None,
143 |                                      keyphrase_ngram_range: Tuple[int, int] = (1, 1),
144 |                                      stop_words: Union[str, List[str]] = 'english',
145 |                                      top_n: int = 5,
146 |                                      use_maxsum: bool = False,
147 |                                      use_mmr: bool = False,
148 |                                      diversity: float = 0.5,
149 |                                      nr_candidates: int = 20,
150 |                                      vectorizer: CountVectorizer = None,
151 |                                      seed_keywords: List[str] = None) -> List[Tuple[str, float]]:
152 |         """ Extract keywords/keyphrases for a single document
153 | 
154 |         Arguments:
155 |             doc: The document for which to extract keywords/keyphrases
156 |             candidates: Candidate keywords/keyphrases to use instead of extracting them from the document(s)
157 |             keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases
158 |             stop_words: Stopwords to remove from the document
159 |             top_n: Return the top n keywords/keyphrases
160 |             use_maxsum: Whether to use Max Sum Similarity
161 |             use_mmr: Whether to use MMR
162 |             diversity: The diversity of results between 0 and 1 if use_mmr is True
163 |             nr_candidates: The number of candidates to consider if use_maxsum is set to True
164 |             vectorizer: Pass in your own CountVectorizer from scikit-learn
165 |             seed_keywords: Seed keywords that may guide the extraction of keywords by
166 |                            steering the similarities towards the seeded keywords
167 | 
168 |         Returns:
169 |             keywords: the top n keywords for a document with their respective distances
170 |                       to the input document
171 |         """
172 |         try:
173 |             # Extract Words
174 |             if candidates is None:
175 |                 if vectorizer:
176 |                     count = vectorizer.fit([doc])
177 |                 else:
178 |                     count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words).fit([doc])
179 |                 candidates = count.get_feature_names()
180 | 
181 |             # Extract Embeddings
182 |             doc_embedding = self.model.embed([doc])
183 |             candidate_embeddings = self.model.embed(candidates)
184 | 
185 |             # Guided KeyBERT with seed keywords
186 |             if seed_keywords is not None:
187 |                 seed_embeddings = self.model.embed([" ".join(seed_keywords)])
188 |                 doc_embedding = np.average([doc_embedding, seed_embeddings], axis=0, weights=[3, 1])
189 | 
190 |             # Calculate distances and extract keywords
191 |             if use_mmr:
192 |                 keywords = mmr(doc_embedding, candidate_embeddings, candidates, top_n, diversity)
193 |             elif use_maxsum: 
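                # Max Sum Distance (see _maxsum.py): from the nr_candidates words closest to
                # the document, keep the top_n combination that is least similar internally.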
194 | keywords = max_sum_similarity(doc_embedding, candidate_embeddings, candidates, top_n, nr_candidates) 195 | else: 196 | distances = cosine_similarity(doc_embedding, candidate_embeddings) 197 | keywords = [(candidates[index], round(float(distances[0][index]), 4)) 198 | for index in distances.argsort()[0][-top_n:]][::-1] 199 | 200 | return keywords 201 | except ValueError: 202 | return [] 203 | 204 | def _extract_keywords_multiple_docs(self, 205 | docs: List[str], 206 | keyphrase_ngram_range: Tuple[int, int] = (1, 1), 207 | stop_words: str = 'english', 208 | top_n: int = 5, 209 | min_df: int = 1, 210 | vectorizer: CountVectorizer = None) -> List[List[Tuple[str, float]]]: 211 | """ Extract keywords/keyphrases for a multiple documents 212 | 213 | This currently does not use MMR and Max Sum Similarity as it cannot 214 | process these methods in bulk. 215 | 216 | Arguments: 217 | docs: The document for which to extract keywords/keyphrases 218 | keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases 219 | stop_words: Stopwords to remove from the document 220 | top_n: Return the top n keywords/keyphrases 221 | min_df: The minimum frequency of words 222 | vectorizer: Pass in your own CountVectorizer from scikit-learn 223 | 224 | Returns: 225 | keywords: the top n keywords for a document with their respective distances 226 | to the input document 227 | """ 228 | # Extract words 229 | if vectorizer: 230 | count = vectorizer.fit(docs) 231 | else: 232 | count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words, min_df=min_df).fit(docs) 233 | words = count.get_feature_names() 234 | df = count.transform(docs) 235 | 236 | # Extract embeddings 237 | doc_embeddings = self.model.embed(docs) 238 | word_embeddings = self.model.embed(words) 239 | 240 | # Extract keywords 241 | keywords = [] 242 | for index, doc in tqdm(enumerate(docs)): 243 | doc_words = [words[i] for i in df[index].nonzero()[1]] 244 | 245 | if doc_words: 246 | doc_word_embeddings = np.array([word_embeddings[i] for i in df[index].nonzero()[1]]) 247 | distances = cosine_similarity([doc_embeddings[index]], doc_word_embeddings)[0] 248 | doc_keywords = [(doc_words[i], round(float(distances[i]), 4)) for i in distances.argsort()[-top_n:]] 249 | keywords.append(doc_keywords) 250 | else: 251 | keywords.append(["None Found"]) 252 | 253 | return keywords 254 | -------------------------------------------------------------------------------- /zhkeybert/backend/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._base import BaseEmbedder 3 | 4 | __all__ = [ 5 | "BaseEmbedder" 6 | ] -------------------------------------------------------------------------------- /zhkeybert/backend/_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | 4 | 5 | class BaseEmbedder: 6 | """ The Base Embedder used for creating embedding models 7 | Arguments: 8 | embedding_model: The main embedding model to be used for extracting 9 | document and word embedding 10 | word_embedding_model: The embedding model used for extracting word 11 | embeddings only. If this model is selected, 12 | then the `embedding_model` is purely used for 13 | creating document embeddings. 
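
    Usage (an illustrative sketch of a custom backend; `RandomEmbedder` and the
    embedding size are placeholders, not part of the library):

    ```python
    import numpy as np
    from zhkeybert import KeyBERT
    from zhkeybert.backend import BaseEmbedder

    class RandomEmbedder(BaseEmbedder):
        def embed(self, documents, verbose=False):
            # a real backend would call its embedding model here; random vectors
            # are used only to show the expected (n_documents, embedding_dim) shape
            return np.random.rand(len(documents), 384)

    kb = KeyBERT(model=RandomEmbedder())
    ```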
14 | """ 15 | def __init__(self, 16 | embedding_model=None, 17 | word_embedding_model=None): 18 | self.embedding_model = embedding_model 19 | self.word_embedding_model = word_embedding_model 20 | 21 | def embed(self, 22 | documents: List[str], 23 | verbose: bool = False) -> np.ndarray: 24 | """ Embed a list of n documents/words into an n-dimensional 25 | matrix of embeddings 26 | Arguments: 27 | documents: A list of documents or words to be embedded 28 | verbose: Controls the verbosity of the process 29 | Returns: 30 | Document/words embeddings with shape (n, m) with `n` documents/words 31 | that each have an embeddings size of `m` 32 | """ 33 | pass 34 | -------------------------------------------------------------------------------- /zhkeybert/backend/_flair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from typing import Union, List 4 | from flair.data import Sentence 5 | from flair.embeddings import DocumentEmbeddings, TokenEmbeddings, DocumentPoolEmbeddings 6 | 7 | from keybert.backend import BaseEmbedder 8 | 9 | 10 | class FlairBackend(BaseEmbedder): 11 | """ Flair Embedding Model 12 | The Flair embedding model used for generating document and 13 | word embeddings. 14 | Arguments: 15 | embedding_model: A Flair embedding model 16 | Usage: 17 | ```python 18 | from keybert.backend import FlairBackend 19 | from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings 20 | 21 | # Create a Flair Embedding model 22 | glove_embedding = WordEmbeddings('crawl') 23 | document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding]) 24 | 25 | # Pass the Flair model to create a new backend 26 | flair_embedder = FlairBackend(document_glove_embeddings) 27 | ``` 28 | """ 29 | def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]): 30 | super().__init__() 31 | 32 | # Flair word embeddings 33 | if isinstance(embedding_model, TokenEmbeddings): 34 | self.embedding_model = DocumentPoolEmbeddings([embedding_model]) 35 | 36 | # Flair document embeddings + disable fine tune to prevent CUDA OOM 37 | # https://github.com/flairNLP/flair/issues/1719 38 | elif isinstance(embedding_model, DocumentEmbeddings): 39 | if "fine_tune" in embedding_model.__dict__: 40 | embedding_model.fine_tune = False 41 | self.embedding_model = embedding_model 42 | 43 | else: 44 | raise ValueError("Please select a correct Flair model by either using preparing a token or document " 45 | "embedding model: \n" 46 | "`from flair.embeddings import TransformerDocumentEmbeddings` \n" 47 | "`roberta = TransformerDocumentEmbeddings('roberta-base')`") 48 | 49 | def embed(self, 50 | documents: List[str], 51 | verbose: bool = False) -> np.ndarray: 52 | """ Embed a list of n documents/words into an n-dimensional 53 | matrix of embeddings 54 | Arguments: 55 | documents: A list of documents or words to be embedded 56 | verbose: Controls the verbosity of the process 57 | Returns: 58 | Document/words embeddings with shape (n, m) with `n` documents/words 59 | that each have an embeddings size of `m` 60 | """ 61 | embeddings = [] 62 | for index, document in tqdm(enumerate(documents), disable=not verbose): 63 | try: 64 | sentence = Sentence(document) if document else Sentence("an empty document") 65 | self.embedding_model.embed(sentence) 66 | except RuntimeError: 67 | sentence = Sentence("an empty document") 68 | self.embedding_model.embed(sentence) 69 | embedding = sentence.embedding.detach().cpu().numpy() 70 | 
embeddings.append(embedding) 71 | embeddings = np.asarray(embeddings) 72 | return embeddings -------------------------------------------------------------------------------- /zhkeybert/backend/_gensim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from typing import List 4 | from keybert.backend import BaseEmbedder 5 | from gensim.models.keyedvectors import Word2VecKeyedVectors 6 | 7 | 8 | class GensimBackend(BaseEmbedder): 9 | """ Gensim Embedding Model 10 | 11 | The Gensim embedding model is typically used for word embeddings with 12 | GloVe, Word2Vec or FastText. 13 | 14 | Arguments: 15 | embedding_model: A Gensim embedding model 16 | 17 | Usage: 18 | 19 | ```python 20 | from keybert.backend import GensimBackend 21 | import gensim.downloader as api 22 | 23 | ft = api.load('fasttext-wiki-news-subwords-300') 24 | ft_embedder = GensimBackend(ft) 25 | ``` 26 | """ 27 | def __init__(self, embedding_model: Word2VecKeyedVectors): 28 | super().__init__() 29 | 30 | if isinstance(embedding_model, Word2VecKeyedVectors): 31 | self.embedding_model = embedding_model 32 | else: 33 | raise ValueError("Please select a correct Gensim model: \n" 34 | "`import gensim.downloader as api` \n" 35 | "`ft = api.load('fasttext-wiki-news-subwords-300')`") 36 | 37 | def embed(self, 38 | documents: List[str], 39 | verbose: bool = False) -> np.ndarray: 40 | """ Embed a list of n documents/words into an n-dimensional 41 | matrix of embeddings 42 | 43 | Arguments: 44 | documents: A list of documents or words to be embedded 45 | verbose: Controls the verbosity of the process 46 | 47 | Returns: 48 | Document/words embeddings with shape (n, m) with `n` documents/words 49 | that each have an embeddings size of `m` 50 | """ 51 | vector_shape = self.embedding_model.word_vec(list(self.embedding_model.vocab.keys())[0]).shape 52 | empty_vector = np.zeros(vector_shape[0]) 53 | 54 | embeddings = [] 55 | for doc in tqdm(documents, disable=not verbose, position=0, leave=True): 56 | doc_embedding = [] 57 | 58 | # Extract word embeddings 59 | for word in doc.split(" "): 60 | try: 61 | word_embedding = self.embedding_model.word_vec(word) 62 | doc_embedding.append(word_embedding) 63 | except KeyError: 64 | doc_embedding.append(empty_vector) 65 | 66 | # Pool word embeddings 67 | doc_embedding = np.mean(doc_embedding, axis=0) 68 | embeddings.append(doc_embedding) 69 | 70 | embeddings = np.array(embeddings) 71 | return embeddings 72 | -------------------------------------------------------------------------------- /zhkeybert/backend/_sentencetransformers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Union 3 | from sentence_transformers import SentenceTransformer 4 | 5 | from . import BaseEmbedder 6 | 7 | 8 | class SentenceTransformerBackend(BaseEmbedder): 9 | """ Sentence-transformers embedding model 10 | The sentence-transformers embedding model used for generating document and 11 | word embeddings. 
12 | Arguments: 13 | embedding_model: A sentence-transformers embedding model 14 | Usage: 15 | To create a model, you can load in a string pointing to a 16 | sentence-transformers model: 17 | ```python 18 | from keybert.backend import SentenceTransformerBackend 19 | sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2") 20 | ``` 21 | or you can instantiate a model yourself: 22 | ```python 23 | from keybert.backend import SentenceTransformerBackend 24 | from sentence_transformers import SentenceTransformer 25 | embedding_model = SentenceTransformer("all-MiniLM-L6-v2") 26 | sentence_model = SentenceTransformerBackend(embedding_model) 27 | ``` 28 | """ 29 | def __init__(self, embedding_model: Union[str, SentenceTransformer]): 30 | super().__init__() 31 | 32 | if isinstance(embedding_model, SentenceTransformer): 33 | self.embedding_model = embedding_model 34 | elif isinstance(embedding_model, str): 35 | self.embedding_model = SentenceTransformer(embedding_model) 36 | else: 37 | raise ValueError("Please select a correct SentenceTransformers model: \n" 38 | "`from sentence_transformers import SentenceTransformer` \n" 39 | "`model = SentenceTransformer('all-MiniLM-L6-v2')`") 40 | 41 | def embed(self, 42 | documents: List[str], 43 | verbose: bool = False) -> np.ndarray: 44 | """ Embed a list of n documents/words into an n-dimensional 45 | matrix of embeddings 46 | Arguments: 47 | documents: A list of documents or words to be embedded 48 | verbose: Controls the verbosity of the process 49 | Returns: 50 | Document/words embeddings with shape (n, m) with `n` documents/words 51 | that each have an embeddings size of `m` 52 | """ 53 | embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose) 54 | return embeddings 55 | -------------------------------------------------------------------------------- /zhkeybert/backend/_spacy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from typing import List 4 | from keybert.backend import BaseEmbedder 5 | 6 | 7 | class SpacyBackend(BaseEmbedder): 8 | """ Spacy embedding model 9 | 10 | The Spacy embedding model used for generating document and 11 | word embeddings. 
12 | 13 | Arguments: 14 | embedding_model: A spacy embedding model 15 | 16 | Usage: 17 | 18 | To create a Spacy backend, you need to create an nlp object and 19 | pass it through this backend: 20 | 21 | ```python 22 | import spacy 23 | from keybert.backend import SpacyBackend 24 | 25 | nlp = spacy.load("en_core_web_md", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 26 | spacy_model = SpacyBackend(nlp) 27 | ``` 28 | 29 | To load in a transformer model use the following: 30 | 31 | ```python 32 | import spacy 33 | from thinc.api import set_gpu_allocator, require_gpu 34 | from keybert.backend import SpacyBackend 35 | 36 | nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 37 | set_gpu_allocator("pytorch") 38 | require_gpu(0) 39 | spacy_model = SpacyBackend(nlp) 40 | ``` 41 | 42 | If you run into gpu/memory-issues, please use: 43 | 44 | ```python 45 | import spacy 46 | from keybert.backend import SpacyBackend 47 | 48 | spacy.prefer_gpu() 49 | nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer']) 50 | spacy_model = SpacyBackend(nlp) 51 | ``` 52 | """ 53 | def __init__(self, embedding_model): 54 | super().__init__() 55 | 56 | if "spacy" in str(type(embedding_model)): 57 | self.embedding_model = embedding_model 58 | else: 59 | raise ValueError("Please select a correct Spacy model by either using a string such as 'en_core_web_md' " 60 | "or create a nlp model using: `nlp = spacy.load('en_core_web_md')") 61 | 62 | def embed(self, 63 | documents: List[str], 64 | verbose: bool = False) -> np.ndarray: 65 | """ Embed a list of n documents/words into an n-dimensional 66 | matrix of embeddings 67 | 68 | Arguments: 69 | documents: A list of documents or words to be embedded 70 | verbose: Controls the verbosity of the process 71 | 72 | Returns: 73 | Document/words embeddings with shape (n, m) with `n` documents/words 74 | that each have an embeddings size of `m` 75 | """ 76 | 77 | # Extract embeddings from a transformer model 78 | if "transformer" in self.embedding_model.component_names: 79 | embeddings = [] 80 | for doc in tqdm(documents, position=0, leave=True, disable=not verbose): 81 | try: 82 | embedding = self.embedding_model(doc)._.trf_data.tensors[-1][0].tolist() 83 | except: 84 | embedding = self.embedding_model("An empty document")._.trf_data.tensors[-1][0].tolist() 85 | embeddings.append(embedding) 86 | embeddings = np.array(embeddings) 87 | 88 | # Extract embeddings from a general spacy model 89 | else: 90 | embeddings = [] 91 | for doc in tqdm(documents, position=0, leave=True, disable=not verbose): 92 | try: 93 | vector = self.embedding_model(doc).vector 94 | except ValueError: 95 | vector = self.embedding_model("An empty document").vector 96 | embeddings.append(vector) 97 | embeddings = np.array(embeddings) 98 | 99 | return embeddings 100 | -------------------------------------------------------------------------------- /zhkeybert/backend/_use.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from typing import List 4 | 5 | from keybert.backend import BaseEmbedder 6 | 7 | 8 | class USEBackend(BaseEmbedder): 9 | """ Universal Sentence Encoder 10 | 11 | USE encodes text into high-dimensional vectors that 12 | are used for semantic similarity in KeyBERT. 
13 | 
14 |     Arguments:
15 |         embedding_model: A USE embedding model
16 | 
17 |     Usage:
18 | 
19 |     ```python
20 |     import tensorflow_hub
21 |     from keybert.backend import USEBackend
22 | 
23 |     embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
24 |     use_embedder = USEBackend(embedding_model)
25 |     ```
26 |     """
27 |     def __init__(self, embedding_model):
28 |         super().__init__()
29 | 
30 |         try:
31 |             embedding_model(["test sentence"])
32 |             self.embedding_model = embedding_model
33 |         except TypeError:
34 |             raise ValueError("Please select a correct USE model: \n"
35 |                              "`import tensorflow_hub` \n"
36 |                              "`embedding_model = tensorflow_hub.load(path_to_model)`")
37 | 
38 |     def embed(self,
39 |               documents: List[str],
40 |               verbose: bool = False) -> np.ndarray:
41 |         """ Embed a list of n documents/words into an n-dimensional
42 |             matrix of embeddings
43 | 
44 |         Arguments:
45 |             documents: A list of documents or words to be embedded
46 |             verbose: Controls the verbosity of the process
47 | 
48 |         Returns:
49 |             Document/words embeddings with shape (n, m) with `n` documents/words
50 |             that each have an embeddings size of `m`
51 |         """
52 |         embeddings = np.array([self.embedding_model([doc]).cpu().numpy()[0]
53 |                                for doc in tqdm(documents, disable=not verbose)])
54 |         return embeddings
55 | 
--------------------------------------------------------------------------------
/zhkeybert/backend/_utils.py:
--------------------------------------------------------------------------------
1 | from ._base import BaseEmbedder
2 | from ._sentencetransformers import SentenceTransformerBackend
3 | 
4 | 
5 | def select_backend(embedding_model) -> BaseEmbedder:
6 |     """ Select an embedding backend based on language or a specific sentence-transformers model.
7 |         When selecting a language, we choose `all-MiniLM-L6-v2` for English and
8 |         `paraphrase-multilingual-MiniLM-L12-v2` for all other languages as it supports 100+ languages.
9 | 
10 |     Returns:
11 |         model: The selected embedding backend (Sentence-Transformers, Flair, Spacy, Gensim or USE)
12 |     """
13 |     # keybert language backend
14 |     if isinstance(embedding_model, BaseEmbedder):
15 |         return embedding_model
16 | 
17 |     # Flair word embeddings
18 |     if "flair" in str(type(embedding_model)):
19 |         from keybert.backend._flair import FlairBackend
20 |         return FlairBackend(embedding_model)
21 | 
22 |     # Spacy embeddings
23 |     if "spacy" in str(type(embedding_model)):
24 |         from keybert.backend._spacy import SpacyBackend
25 |         return SpacyBackend(embedding_model)
26 | 
27 |     # Gensim embeddings
28 |     if "gensim" in str(type(embedding_model)):
29 |         from keybert.backend._gensim import GensimBackend
30 |         return GensimBackend(embedding_model)
31 | 
32 |     # USE embeddings
33 |     if "tensorflow" in str(type(embedding_model)) and "saved_model" in str(type(embedding_model)):
34 |         from keybert.backend._use import USEBackend
35 |         return USEBackend(embedding_model)
36 | 
37 |     # Sentence Transformer embeddings
38 |     if "sentence_transformers" in str(type(embedding_model)):
39 |         return SentenceTransformerBackend(embedding_model)
40 | 
41 |     # Create a Sentence Transformer model based on a string
42 |     if isinstance(embedding_model, str):
43 |         return SentenceTransformerBackend(embedding_model)
44 | 
45 |     return SentenceTransformerBackend("paraphrase-multilingual-MiniLM-L12-v2")
46 | 
--------------------------------------------------------------------------------