├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ └── python-package.yml ├── .gitignore ├── .tool-versions ├── 2023-09-01 - reddit - depression dataset - etm - example.ipynb ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── build.sh ├── conftest.py ├── create_test_resources.py ├── dev_requirements.txt ├── embedded_topic_model ├── __init__.py ├── data │ └── 20ng │ │ ├── bow_tr_counts.mat │ │ ├── bow_tr_tokens.mat │ │ ├── bow_ts_counts.mat │ │ ├── bow_ts_h1_counts.mat │ │ ├── bow_ts_h1_tokens.mat │ │ ├── bow_ts_h2_counts.mat │ │ ├── bow_ts_h2_tokens.mat │ │ ├── bow_ts_tokens.mat │ │ ├── bow_va_counts.mat │ │ ├── bow_va_tokens.mat │ │ └── vocab.pkl ├── models │ ├── __init__.py │ ├── etm.py │ └── model.py ├── scripts │ ├── __init__.py │ └── datasets │ │ ├── 20ng │ │ ├── bow_tr_counts.mat │ │ ├── bow_tr_tokens.mat │ │ ├── bow_ts_counts.mat │ │ ├── bow_ts_h1_counts.mat │ │ ├── bow_ts_h1_tokens.mat │ │ ├── bow_ts_h2_counts.mat │ │ ├── bow_ts_h2_tokens.mat │ │ ├── bow_ts_tokens.mat │ │ ├── bow_va_counts.mat │ │ ├── bow_va_tokens.mat │ │ └── vocab.pkl │ │ ├── __init__.py │ │ ├── data_20ng.py │ │ ├── data_nyt.py │ │ ├── data_reddit_nouns_only.py │ │ ├── data_reddit_raw_pt.py │ │ └── stops.txt └── utils │ ├── __init__.py │ ├── data.py │ ├── embedding.py │ ├── metrics.py │ └── preprocessing.py ├── lint.sh ├── publish.txt ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── integration │ ├── __init__.py │ └── test_etm.py ├── resources │ ├── train_resources.test │ ├── train_w2v_embeddings.wordvectors │ ├── train_w2v_embeddings.wordvectors.bin │ └── train_w2v_embeddings.wordvectors.txt └── unit │ ├── __init__.py │ ├── test_embedding.py │ └── test_preprocessing.py └── train_resources.test /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG] Bug description" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Set-up resources 16 | 2. Call an API 17 | 3. See error 18 | 19 | **Reproduction example** 20 | Code to reproduce the problem. You can put inline code here or a link to a publicly available an runnable reproduction example. Feel free to add further instructions to set-up the reproduction environment. 21 | 22 | [Reproduction example](https://github.com/lffloyd/embedded-topic-model) 23 | 24 | Example code: 25 | ```python 26 | # put your code here 27 | print("something is off...") 28 | ``` 29 | 30 | **Expected behavior** 31 | A clear and concise description of what you expected to happen. 32 | 33 | **Screenshots** 34 | If applicable, add screenshots to help explain your problem. 35 | 36 | **Environment (please complete the following information):** 37 | - OS: [e.g. Ubuntu] 38 | - Environment [e.g. memory-size, `torch.device` used, or anything appliable] 39 | - Package Version [e.g. 1.1.0] 40 | - Python version [e.g. 3.7.6] 41 | - Dependencies versions [e.g. `gensim==3.8.3`, `torch==1.6.0`, etc.] 42 | 43 | **Additional context** 44 | Add any other context about the problem here. 45 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEAT] Feature request" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ['3.9', '3.10', '3.11'] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | pip install -r dev_requirements.txt 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | python 3.11.5 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | This changelog was inspired by the [keep-a-changelog](https://github.com/olivierlacan/keep-a-changelog) project and follows [semantic versioning](https://semver.org). 4 | 5 | ## [1.2.1] - 2023-09-06 6 | 7 | ### Changed 8 | 9 | - ([#cf35c3](https://github.com/lffloyd/embedded-topic-model/commit/cf35c3)) fixes minimum python version to be `python>=3.9` 10 | 11 | ## [1.2.0] - 2023-09-06 12 | 13 | ### Added 14 | 15 | - ([#61730d](https://github.com/lffloyd/embedded-topic-model/commit/61730d), [#224995](https://github.com/lffloyd/embedded-topic-model/commit/224995), [#331fc0](https://github.com/lffloyd/embedded-topic-model/commit/331fc0)) adds support for macOS MPS devices and updates outdated `numpy`/`sklearn` code - thanks to [@d-jiao](https://github.com/d-jiao) 16 | - ([#c48016](https://github.com/lffloyd/embedded-topic-model/commit/c48016), [#2fe517](https://github.com/lffloyd/embedded-topic-model/commit/2fe517), [#c965b1](https://github.com/lffloyd/embedded-topic-model/commit/c965b1), [#5578ca](https://github.com/lffloyd/embedded-topic-model/commit/5578ca), [#5b0d85](https://github.com/lffloyd/embedded-topic-model/commit/5b0d85)) adds security guidelines and request templates 17 | 18 | ### Changed 19 | 20 | - ([#331fc0](https://github.com/lffloyd/embedded-topic-model/commit/331fc0)) updates actions pipeline, supported python versions and internal dependencies to the latest available like `torch`, `gensim`, among others. Support for `python<=3.8` was dropped as a result. Numerous security vulnerabilities were solved 21 | 22 | ## [1.1.0] - 2023-09-05 23 | 24 | ### Added 25 | 26 | - ([#3f27ee](https://github.com/lffloyd/embedded-topic-model/commit/3f27ee)) adds `transform` method 27 | - ([#f98f3f](https://github.com/lffloyd/embedded-topic-model/commit/f98f3f)) adds example jupyter notebook 28 | - ([#683bec](https://github.com/lffloyd/embedded-topic-model/commit/683bec)) adds contributing and conduct guidelines 29 | 30 | ### Changed 31 | 32 | - ([#f98f3f](https://github.com/lffloyd/embedded-topic-model/commit/f98f3f), [#c918a4](https://github.com/lffloyd/embedded-topic-model/commit/c918a4)) updates documentation 33 | 34 | ## [1.0.2] - 2021-06-23 35 | 36 | ### Changed 37 | 38 | - deactivates debug mode by default 39 | - documents get_most_similar_words method 40 | 41 | ## [1.0.1] - 2021-02-15 42 | 43 | ### Changed 44 | 45 | - optimizes original word2vec TXT file input for model training 46 | - updates README.md 47 | 48 | ## [1.0.0] - 2021-02-15 49 | 50 | ### Added 51 | 52 | - adds support for original word2vec pretrained embeddings files on both formats (BIN/TXT) 53 | 54 | ### Changed 55 | 56 | - optimizes handling of gensim's word2vec mapping file for better memory usage 57 | 58 | ## [0.1.1] - 2021-02-01 59 | 60 | ### Added 61 | 62 | - support for python 3.6 63 | 64 | ## [0.1.0] - 2021-02-01 65 | 66 | ### Added 67 | 68 | - ETM training with partially tested support for original ETM features. 69 | - ETM corpus preprocessing scripts - including word2vec embeddings training - adapted from the original code. 70 | - adds methods to retrieve document-topic and topic-word probability distributions from the trained model. 71 | - adds docstrings for tested API methods. 72 | - adds unit and integration tests for ETM and preprocessing scripts. 73 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | lfmatosmelo@id.uff.br. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | 135 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing to embedded-topic-model 3 | 4 | First off, thanks for taking the time to contribute! ❤️ 5 | 6 | All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉 7 | 8 | > And if you like the project, but just don't have time to contribute, that's fine. There are other easy ways to support the project and show your appreciation, which we would also be very happy about: 9 | > - Star the project 10 | > - Tweet about it 11 | > - Refer this project in your project's readme 12 | > - Mention the project at local meetups and tell your friends/colleagues 13 | 14 | 15 | ## Table of Contents 16 | 17 | - [I Have a Question](#i-have-a-question) 18 | - [I Want To Contribute](#i-want-to-contribute) 19 | - [Reporting Bugs](#reporting-bugs) 20 | - [Suggesting Enhancements](#suggesting-enhancements) 21 | 22 | - [Improving The Documentation](#improving-the-documentation) 23 | - [Styleguides](#styleguides) 24 | - [Commit Messages](#commit-messages) 25 | 26 | 27 | 28 | 29 | ## I Have a Question 30 | 31 | > If you want to ask a question, we assume that you have read the available [Documentation](./README.md). 32 | 33 | Before you ask a question, it is best to search for existing [Issues](https://github.com/lffloyd/embedded-topic-model/issues) that might help you. In case you have found a suitable issue and still need clarification, you can write your question in this issue. It is also advisable to search the internet for answers first. 34 | 35 | If you then still feel the need to ask a question and need clarification, we recommend the following: 36 | 37 | - Open an [Issue](https://github.com/lffloyd/embedded-topic-model/issues/new). 38 | - Provide as much context as you can about what you're running into. 39 | - Provide project and platform versions (`python`, `torch`, OS, etc). 40 | 41 | We will then take care of the issue as soon as possible. 42 | 43 | 57 | 58 | ## I Want To Contribute 59 | 60 | > ### Legal Notice 61 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license. 62 | 63 | ### Reporting Bugs 64 | 65 | 66 | #### Before Submitting a Bug Report 67 | 68 | A good bug report shouldn't leave others needing to chase you up for more information. Therefore, we ask you to investigate carefully, collect information and describe the issue in detail in your report. Please complete the following steps in advance to help us fix any potential bug as fast as possible. 69 | 70 | - Make sure that you are using the latest version. 71 | - Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions (Make sure that you have read the [documentation](./README.md). If you are looking for support, you might want to check [this section](#i-have-a-question)). 72 | - To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](https://github.com/lffloyd/embedded-topic-modelissues?q=label%3Abug). 73 | - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue. 74 | - Collect information about the bug: 75 | - Stack trace (Traceback) 76 | - OS, Platform and Version (Windows, Linux, macOS, x86, ARM) 77 | - Version of the interpreter, compiler, SDK, runtime environment, package manager, depending on what seems relevant. 78 | - Possibly your input and the output 79 | - Can you reliably reproduce the issue? And can you also reproduce it with older versions? A reproduction example would be very useful for us 80 | 81 | 82 | #### How Do I Submit a Good Bug Report? 83 | 84 | > You must never report security related issues, vulnerabilities or bugs including sensitive information to the issue tracker, or elsewhere in public. Instead sensitive bugs must be sent by email to . 85 | 86 | 87 | We use GitHub issues to track bugs and errors. If you run into an issue with the project: 88 | 89 | - Open an [Issue](https://github.com/lffloyd/embedded-topic-model/issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.) 90 | - Explain the behavior you would expect and the actual behavior. 91 | - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. 92 | - Provide the information you collected in the previous section. 93 | 94 | Once it's filed: 95 | 96 | - The project team will label the issue accordingly. 97 | - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps and mark the issue as `needs-repro`. Bugs with the `needs-repro` tag will not be addressed until they are reproduced. 98 | - If the team is able to reproduce the issue, it will be marked `needs-fix`, as well as possibly other tags (such as `critical`), and the issue will be left to be [implemented by someone](#your-first-code-contribution). 99 | 100 | 101 | 102 | 103 | ### Suggesting Enhancements 104 | 105 | This section guides you through submitting an enhancement suggestion for embedded-topic-model, **including completely new features and minor improvements to existing functionality**. Following these guidelines will help maintainers and the community to understand your suggestion and find related suggestions. 106 | 107 | 108 | #### Before Submitting an Enhancement 109 | 110 | - Make sure that you are using the latest version. 111 | - Read the [documentation](./README.md) carefully and find out if the functionality is already covered, maybe by an individual configuration. 112 | - Perform a [search](https://github.com/lffloyd/embedded-topic-model/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. 113 | - Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing an add-on/plugin library. 114 | 115 | 116 | #### How Do I Submit a Good Enhancement Suggestion? 117 | 118 | Enhancement suggestions are tracked as [GitHub issues](https://github.com/lffloyd/embedded-topic-model/issues). 119 | 120 | - Use a **clear and descriptive title** for the issue to identify the suggestion. 121 | - Provide a **step-by-step description of the suggested enhancement** in as many details as possible. 122 | - **Describe the current behavior** and **explain which behavior you expected to see instead** and why. At this point you can also tell which alternatives do not work for you. 123 | - You may want to **include screenshots and animated GIFs** which help you demonstrate the steps or point out the part which the suggestion is related to. You can use [this tool](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [this tool](https://github.com/colinkeenan/silentcast) or [this tool](https://github.com/GNOME/byzanz) on Linux. 124 | - **Explain why this enhancement would be useful** to most embedded-topic-model users. You may also want to point out the other projects that solved it better and which could serve as inspiration. 125 | 126 | 127 | 128 | 134 | 135 | ### Improving The Documentation 136 | 140 | 141 | The current [documentation](./README.md) is not extensive at all. The public API is mostly documented with [docstrings](https://peps.python.org/pep-0257/), but that is not ideal. If you're interested in contributing with the documentation, you can: 142 | 143 | * Add docstrings to undocumented APIs; 144 | * Add usage examples for each public APIs. 145 | 146 | I'm looking into that arena too, so feel free to suggest any improvements you might find feasible. 147 | 148 | ## Styleguides 149 | ### Commit Messages 150 | 153 | 156 | 157 | 158 | ## Attribution 159 | This guide is based on the **contributing-gen**. [Make your own](https://github.com/bttger/contributing-gen)! 160 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei 4 | Copyright (c) 2021 Luiz Matos 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Embedded Topic Model 2 | [![PyPI version](https://badge.fury.io/py/embedded-topic-model.svg)](https://badge.fury.io/py/embedded-topic-model) 3 | [![Actions Status](https://github.com/lffloyd/embedded-topic-model/workflows/Python%20package/badge.svg)](https://github.com/lffloyd/embedded-topic-model/actions) 4 | ![GitHub contributors](https://img.shields.io/github/contributors/lffloyd/embedded-topic-model) 5 | ![GitHub Repo stars](https://img.shields.io/github/stars/lffloyd/embedded-topic-model) 6 | [![Downloads](https://static.pepy.tech/badge/embedded-topic-model/month)](https://pepy.tech/project/embedded-topic-model) 7 | [![License](http://img.shields.io/badge/license-MIT-blue.svg?style=flat)](https://github.com/lffloyd/embedded-topic-model/blob/main/LICENSE) 8 | 9 | This package was made to easily run embedded topic modelling on a given corpus. 10 | 11 | ETM is a topic model that marries the probabilistic topic modelling of Latent Dirichlet Allocation with the 12 | contextual information brought by word embeddings-most specifically, word2vec. ETM models topics as points 13 | in the word embedding space, arranging together topics and words with similar context. 14 | As such, ETM can either learn word embeddings alongside topics, or be given pretrained embeddings to discover 15 | the topic patterns on the corpus. 16 | 17 | ETM was originally published by Adji B. Dieng, Francisco J. R. Ruiz, and David M. Blei on a article titled ["Topic Modeling in Embedding Spaces"](https://arxiv.org/abs/1907.04907) in 2019. This code is an adaptation of the [original](https://github.com/adjidieng/ETM) provided with the article and is not affiliated in any manner with the original authors. Most of the original code was kept here, with some changes here and there, mostly for ease of usage. This package was created to facilitate research purposes. If you want a more stable and feature-rich package to train ETM and other models, take a look at [OCTIS](https://github.com/MIND-Lab/OCTIS). 18 | 19 | With the tools provided here, you can run ETM on your dataset using simple steps. 20 | 21 | ## Index 22 | 23 | * [:beer: Installation](#beer-installation) 24 | * [:wrench: Usage](#wrench-usage) 25 | * [:microscope: Examples](#microscope-examples) 26 | * [:books: Citation](#books-citation) 27 | * [:heart: Contributing](#heart-contributing) 28 | * [:v: Acknowledgements](#v-acknowledgements) 29 | * [:pushpin: License](#pushpin-license) 30 | 31 | ## :beer: Installation 32 | You can install the package using ```pip``` by running: ```pip install -U embedded_topic_model``` 33 | 34 | ## :wrench: Usage 35 | To use ETM on your corpus, you must first preprocess the documents into a format understandable by the model. 36 | This package has a quick-use preprocessing script. The only requirement is that the corpus must be composed 37 | by a list of strings, where each string corresponds to a document in the corpus. 38 | 39 | You can preprocess your corpus as follows: 40 | 41 | ```python 42 | from embedded_topic_model.utils import preprocessing 43 | import json 44 | 45 | # Loading a dataset in JSON format. As said, documents must be composed by string sentences 46 | corpus_file = 'datasets/example_dataset.json' 47 | documents_raw = json.load(open(corpus_file, 'r')) 48 | documents = [document['body'] for document in documents_raw] 49 | 50 | # Preprocessing the dataset 51 | vocabulary, train_dataset, _, = preprocessing.create_etm_datasets( 52 | documents, 53 | min_df=0.01, 54 | max_df=0.75, 55 | train_size=0.85, 56 | ) 57 | ``` 58 | 59 | Then, you can train word2vec embeddings to use with the ETM model. This is optional, and if you're not interested 60 | on training your embeddings, you can either pass a pretrained word2vec embeddings file for ETM or learn the embeddings 61 | using ETM itself. If you want ETM to learn its word embeddings, just pass ```train_embeddings=True``` as an instance parameter. 62 | 63 | To pretrain the embeddings, you can do the following: 64 | 65 | ```python 66 | from embedded_topic_model.utils import embedding 67 | 68 | # Training word2vec embeddings 69 | embeddings_mapping = embedding.create_word2vec_embedding_from_dataset(documents) 70 | ``` 71 | 72 | To create and fit the model using the training data, execute: 73 | 74 | ```python 75 | from embedded_topic_model.models.etm import ETM 76 | 77 | # Training an ETM instance 78 | etm_instance = ETM( 79 | vocabulary, 80 | embeddings=embeddings_mapping, # You can pass here the path to a word2vec file or 81 | # a KeyedVectors instance 82 | num_topics=8, 83 | epochs=100, 84 | debug_mode=True, 85 | train_embeddings=False, # Optional. If True, ETM will learn word embeddings jointly with 86 | # topic embeddings. By default, is False. If 'embeddings' argument 87 | # is being passed, this argument must not be True 88 | ) 89 | 90 | etm_instance.fit(train_dataset) 91 | ``` 92 | 93 | You can get the topic words with this method. Note that you can select how many word per topic you're interest in: 94 | ```python 95 | t_w_mtx = etm_instance.get_topics(top_n_words=20) 96 | ``` 97 | 98 | You can get the topic word matrix with this method. Note that it will return all word for each topic: 99 | ```python 100 | t_w_mtx = etm_instance.get_topic_word_matrix() 101 | ``` 102 | 103 | You can get the topic word distribution matrix and the document topic distribution matrix with the following methods, both return a normalized distribution matrix: 104 | ```python 105 | t_w_dist_mtx = etm_instance.get_topic_word_dist() 106 | d_t_dist_mtx = etm_instance.get_document_topic_dist() 107 | ``` 108 | 109 | Also, to obtain topic coherence or topic diversity of the model, you can do as follows: 110 | 111 | ```python 112 | topics = etm_instance.get_topics(20) 113 | topic_coherence = etm_instance.get_topic_coherence() 114 | topic_diversity = etm_instance.get_topic_diversity() 115 | ``` 116 | 117 | You can also predict topics for unseen documents with the following. 118 | 119 | ```python 120 | from embedded_topic_model.utils import preprocessing 121 | from embedded_topic_model.models.etm import ETM 122 | 123 | corpus_file = 'datasets/example_dataset.json' 124 | documents_raw = json.load(open(corpus_file, 'r')) 125 | documents = [document['body'] for document in documents_raw] 126 | 127 | # Splits into train/test datasets 128 | train = documents[:len(documents)-100] 129 | test = documents[len(documents)-100:] 130 | 131 | # Model fitting 132 | # ... 133 | 134 | # The vocabulary must be the same one created during preprocessing of the training dataset (see above) 135 | preprocessed_test = preprocessing.create_bow_dataset(test, vocabulary) 136 | # Transforms test dataset and returns normalized document topic distribution 137 | test_d_t_dist = etm_instance.transform(preprocessed_test) 138 | print(f'test_d_t_dist: {test_d_t_dist}') 139 | ``` 140 | 141 | For further details, see [examples](#microscope-examples). 142 | 143 | ### :microscope: Examples 144 | 145 | | title | link | 146 | | :-------------: | :--: | 147 | | ETM example - Reddit (r/depression) dataset | [Jupyter Notebook](./2023-09-01%20-%20reddit%20-%20depression%20dataset%20-%20etm%20-%20example.ipynb) | 148 | 149 | ## :books: Citation 150 | To cite ETM, use the original article's citation: 151 | 152 | ``` 153 | @article{dieng2019topic, 154 | title = {Topic modeling in embedding spaces}, 155 | author = {Dieng, Adji B and Ruiz, Francisco J R and Blei, David M}, 156 | journal = {arXiv preprint arXiv: 1907.04907}, 157 | year = {2019} 158 | } 159 | ``` 160 | 161 | ## :heart: Contributing 162 | Contributions are always welcomed :heart:! You can take a look at []() to see some guidelines. Feel free to contact through issues, to elaborate on desired enhancements and to check if work is already being done on the matter. 163 | 164 | ## :v: Acknowledgements 165 | Credits given to Adji B. Dieng, Francisco J. R. Ruiz, and David M. Blei for the original work. 166 | 167 | ## :pushpin: License 168 | Licensed under [MIT](LICENSE) license. 169 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | The following table lists currently supported versions of this package. If you have further questions, open an [issue](https://github.com/lffloyd/embedded-topic-model/issues). 6 | 7 | | Version | Supported | 8 | | ------- | ------------------ | 9 | | 1.2.x | :white_check_mark: | 10 | | 1.1.x | :white_check_mark: | 11 | | 1.0.2 | :x: | 12 | | 1.0.x | :x: | 13 | | 0.1.x | :x: | 14 | 15 | ## Reporting a Vulnerability 16 | If you've found a vulnerability on this package or any of its dependencies, feel free to open an [issue](https://github.com/lffloyd/embedded-topic-model/issues) reporting that. Please note, however, that you must never report security related issues, vulnerabilities or bugs including sensitive information to the issue tracker, or elsewhere in public. Instead sensitive bugs must be sent by email to lfmatosmelo@id.uff.br. 17 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf build && rm -rf dist 4 | pip install -r dev_requirements.txt || exit 1 5 | pip install -r requirements.txt || exit 1 6 | pytest || exit 1 7 | python setup.py sdist bdist_wheel || exit 1 8 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/conftest.py -------------------------------------------------------------------------------- /create_test_resources.py: -------------------------------------------------------------------------------- 1 | from embedded_topic_model.utils import embedding, preprocessing 2 | import os 3 | import joblib 4 | 5 | sentences = [ 6 | "Peanut butter and jelly caused the elderly lady to think about her past.", 7 | "Toddlers feeding raccoons surprised even the seasoned park ranger.", 8 | "You realize you're not alone as you sit in your bedroom massaging your calves after a long day of playing tug-of-war with Grandpa Joe in the hospital.", 9 | "She wondered what his eyes were saying beneath his mirrored sunglasses.", 10 | "He was disappointed when he found the beach to be so sandy and the sun so sunny.", 11 | "Flesh-colored yoga pants were far worse than even he feared.", 12 | "The wake behind the boat told of the past while the open sea for told life in the unknown future.", 13 | "Improve your goldfish's physical fitness by getting him a bicycle.", 14 | "Harrold felt confident that nobody would ever suspect his spy pigeon.", 15 | "Nudist colonies shun fig-leaf couture.", 16 | ] 17 | vocabulary, train_dataset, test_dataset = preprocessing.create_etm_datasets( 18 | sentences, debug_mode=True) 19 | 20 | embeddings = embedding.create_word2vec_embedding_from_dataset( 21 | sentences, 22 | embedding_file_path='tests/resources/train_w2v_embeddings.wordvectors', 23 | save_c_format_w2vec=True, 24 | debug_mode=True, 25 | ) 26 | 27 | os.makedirs(os.path.dirname('tests/resources/train_resources.test'), exist_ok=True) 28 | joblib.dump( 29 | (vocabulary, 30 | embeddings, 31 | train_dataset, 32 | test_dataset), 33 | './train_resources.test', 34 | compress=8) 35 | 36 | print('the end') 37 | -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | joblib 2 | flake8 3 | pytest>=7.4 4 | autopep8 5 | bump2version 6 | twine 7 | notebook 8 | -------------------------------------------------------------------------------- /embedded_topic_model/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.2.1' 2 | -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_tr_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_tr_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_tr_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_tr_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_h1_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_h1_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_h1_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_h1_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_h2_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_h2_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_h2_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_h2_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_ts_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_ts_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_va_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_va_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/bow_va_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/bow_va_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/data/20ng/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/data/20ng/vocab.pkl -------------------------------------------------------------------------------- /embedded_topic_model/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/models/__init__.py -------------------------------------------------------------------------------- /embedded_topic_model/models/etm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import numpy as np 5 | import os 6 | import math 7 | from typing import List 8 | from torch import optim 9 | from gensim.models import KeyedVectors 10 | 11 | from embedded_topic_model.models.model import Model 12 | from embedded_topic_model.utils import data 13 | from embedded_topic_model.utils import embedding 14 | from embedded_topic_model.utils import metrics 15 | 16 | 17 | class ETM(object): 18 | """ 19 | Creates an embedded topic model instance. The model hyperparameters are: 20 | 21 | vocabulary (list of str): training dataset vocabulary 22 | embeddings (str or KeyedVectors): KeyedVectors instance containing word-vector mapping for embeddings, or its path 23 | use_c_format_w2vec (bool): wheter input embeddings use word2vec C format. Both BIN and TXT formats are supported 24 | model_path (str): path to save trained model. If None, the model won't be automatically saved 25 | batch_size (int): input batch size for training 26 | num_topics (int): number of topics 27 | rho_size (int): dimension of rho 28 | emb_size (int): dimension of embeddings 29 | t_hidden_size (int): dimension of hidden space of q(theta) 30 | theta_act (str): tanh, softplus, relu, rrelu, leakyrelu, elu, selu, glu) 31 | train_embeddings (bool): whether to fix rho or train it 32 | lr (float): learning rate 33 | lr_factor (float): divide learning rate by this... 34 | epochs (int): number of epochs to train. 150 for 20ng 100 for others 35 | optimizer_type (str): choice of optimizer 36 | seed (int): random seed (default: 1) 37 | enc_drop (float): dropout rate on encoder 38 | clip (float): gradient clipping 39 | nonmono (int): number of bad hits allowed 40 | wdecay (float): some l2 regularization 41 | anneal_lr (bool): whether to anneal the learning rate or not 42 | bow_norm (bool): normalize the bows or not 43 | num_words (int): number of words for topic viz 44 | log_interval (int): when to log training 45 | visualize_every (int): when to visualize results 46 | eval_batch_size (int): input batch size for evaluation 47 | eval_perplexity (bool): whether to compute perplexity on document completion task 48 | debug_mode (bool): wheter or not should log model operations 49 | """ 50 | 51 | def __init__( 52 | self, 53 | vocabulary, 54 | embeddings=None, 55 | use_c_format_w2vec=False, 56 | model_path=None, 57 | batch_size=1000, 58 | num_topics=50, 59 | rho_size=300, 60 | emb_size=300, 61 | t_hidden_size=800, 62 | theta_act='relu', 63 | train_embeddings=False, 64 | lr=0.005, 65 | lr_factor=4.0, 66 | epochs=20, 67 | optimizer_type='adam', 68 | seed=2019, 69 | enc_drop=0.0, 70 | clip=0.0, 71 | nonmono=10, 72 | wdecay=1.2e-6, 73 | anneal_lr=False, 74 | bow_norm=True, 75 | num_words=10, 76 | log_interval=2, 77 | visualize_every=10, 78 | eval_batch_size=1000, 79 | eval_perplexity=False, 80 | debug_mode=False, 81 | ): 82 | self.vocabulary = vocabulary 83 | self.vocabulary_size = len(self.vocabulary) 84 | self.model_path = model_path 85 | self.batch_size = batch_size 86 | self.num_topics = num_topics 87 | self.rho_size = rho_size 88 | self.emb_size = emb_size 89 | self.t_hidden_size = t_hidden_size 90 | self.theta_act = theta_act 91 | self.lr_factor = lr_factor 92 | self.epochs = epochs 93 | self.seed = seed 94 | self.enc_drop = enc_drop 95 | self.clip = clip 96 | self.nonmono = nonmono 97 | self.anneal_lr = anneal_lr 98 | self.bow_norm = bow_norm 99 | self.num_words = num_words 100 | self.log_interval = log_interval 101 | self.visualize_every = visualize_every 102 | self.eval_batch_size = eval_batch_size 103 | self.eval_perplexity = eval_perplexity 104 | self.debug_mode = debug_mode 105 | 106 | device = 'cpu' 107 | if torch.cuda.is_available(): 108 | device = 'cuda' 109 | elif torch.backends.mps.is_available(): 110 | device = 'mps' 111 | self.device = torch.device(device) 112 | torch.manual_seed(self.seed) 113 | 114 | np.random.seed(self.seed) 115 | 116 | self.embeddings = None if train_embeddings else self._initialize_embeddings( 117 | embeddings, use_c_format_w2vec=use_c_format_w2vec) 118 | 119 | self.model = Model( 120 | self.device, 121 | self.num_topics, 122 | self.vocabulary_size, 123 | self.t_hidden_size, 124 | self.rho_size, 125 | self.emb_size, 126 | self.theta_act, 127 | self.embeddings, 128 | train_embeddings, 129 | self.enc_drop, 130 | self.debug_mode).to( 131 | self.device) 132 | self.optimizer = self._get_optimizer(optimizer_type, lr, wdecay) 133 | 134 | def __str__(self): 135 | return f'{self.model}' 136 | 137 | def _get_extension(self, path): 138 | assert isinstance(path, str), 'path extension is not str' 139 | filename = path.split(os.path.sep)[-1] 140 | return filename.split('.')[-1] 141 | 142 | def _get_embeddings_from_original_word2vec(self, embeddings_file): 143 | if self._get_extension(embeddings_file) == 'txt': 144 | if self.debug_mode: 145 | print('Reading embeddings from original word2vec TXT file...') 146 | vectors = {} 147 | iterator = embedding.MemoryFriendlyFileIterator(embeddings_file) 148 | for line in iterator: 149 | word = line[0] 150 | if word in self.vocabulary: 151 | vect = np.array(line[1:]).astype(float) 152 | vectors[word] = vect 153 | return vectors 154 | elif self._get_extension(embeddings_file) == 'bin': 155 | if self.debug_mode: 156 | print('Reading embeddings from original word2vec BIN file...') 157 | return KeyedVectors.load_word2vec_format( 158 | embeddings_file, 159 | binary=True 160 | ) 161 | else: 162 | raise Exception('Original Word2Vec file without BIN/TXT extension') 163 | 164 | def _initialize_embeddings( 165 | self, 166 | embeddings, 167 | use_c_format_w2vec=False 168 | ): 169 | vectors = embeddings if isinstance(embeddings, KeyedVectors) else {} 170 | 171 | if use_c_format_w2vec: 172 | vectors = self._get_embeddings_from_original_word2vec(embeddings) 173 | elif isinstance(embeddings, str): 174 | if self.debug_mode: 175 | print('Reading embeddings from word2vec file...') 176 | vectors = KeyedVectors.load(embeddings, mmap='r') 177 | 178 | model_embeddings = np.zeros((self.vocabulary_size, self.emb_size)) 179 | 180 | for i, word in enumerate(self.vocabulary): 181 | try: 182 | model_embeddings[i] = vectors[word] 183 | except KeyError: 184 | model_embeddings[i] = np.random.normal( 185 | scale=0.6, size=(self.emb_size, )) 186 | return torch.from_numpy(model_embeddings.astype(np.float32)).to(self.device) 187 | 188 | def _get_optimizer(self, optimizer_type, learning_rate, wdecay): 189 | if optimizer_type == 'adam': 190 | return optim.Adam( 191 | self.model.parameters(), 192 | lr=learning_rate, 193 | weight_decay=wdecay) 194 | elif optimizer_type == 'adagrad': 195 | return optim.Adagrad( 196 | self.model.parameters(), 197 | lr=learning_rate, 198 | weight_decay=wdecay) 199 | elif optimizer_type == 'adadelta': 200 | return optim.Adadelta( 201 | self.model.parameters(), 202 | lr=learning_rate, 203 | weight_decay=wdecay) 204 | elif optimizer_type == 'rmsprop': 205 | return optim.RMSprop( 206 | self.model.parameters(), 207 | lr=learning_rate, 208 | weight_decay=wdecay) 209 | elif optimizer_type == 'asgd': 210 | return optim.ASGD( 211 | self.model.parameters(), 212 | lr=learning_rate, 213 | t0=0, 214 | lambd=0., 215 | weight_decay=wdecay) 216 | else: 217 | if self.debug_mode: 218 | print('Defaulting to vanilla SGD') 219 | return optim.SGD(self.model.parameters(), lr=learning_rate) 220 | 221 | def _set_training_data(self, train_data): 222 | self.train_tokens = train_data['tokens'] 223 | self.train_counts = train_data['counts'] 224 | self.num_docs_train = len(self.train_tokens) 225 | 226 | def _set_test_data(self, test_data): 227 | self.test_tokens = test_data['test']['tokens'] 228 | self.test_counts = test_data['test']['counts'] 229 | self.num_docs_test = len(self.test_tokens) 230 | self.test_1_tokens = test_data['test1']['tokens'] 231 | self.test_1_counts = test_data['test1']['counts'] 232 | self.num_docs_test_1 = len(self.test_1_tokens) 233 | self.test_2_tokens = test_data['test2']['tokens'] 234 | self.test_2_counts = test_data['test2']['counts'] 235 | self.num_docs_test_2 = len(self.test_2_tokens) 236 | 237 | def _train(self, epoch): 238 | self.model.train() 239 | acc_loss = 0 240 | acc_kl_theta_loss = 0 241 | cnt = 0 242 | indices = torch.randperm(self.num_docs_train) 243 | indices = torch.split(indices, self.batch_size) 244 | for idx, ind in enumerate(indices): 245 | self.optimizer.zero_grad() 246 | self.model.zero_grad() 247 | 248 | data_batch = data.get_batch( 249 | self.train_tokens, 250 | self.train_counts, 251 | ind, 252 | self.vocabulary_size, 253 | self.device) 254 | sums = data_batch.sum(1).unsqueeze(1) 255 | if self.bow_norm: 256 | normalized_data_batch = data_batch / sums 257 | else: 258 | normalized_data_batch = data_batch 259 | recon_loss, kld_theta = self.model( 260 | data_batch, normalized_data_batch) 261 | total_loss = recon_loss + kld_theta 262 | total_loss.backward() 263 | 264 | if self.clip > 0: 265 | torch.nn.utils.clip_grad_norm_( 266 | self.model.parameters(), self.clip) 267 | self.optimizer.step() 268 | 269 | acc_loss += torch.sum(recon_loss).item() 270 | acc_kl_theta_loss += torch.sum(kld_theta).item() 271 | cnt += 1 272 | 273 | if idx % self.log_interval == 0 and idx > 0: 274 | cur_loss = round(acc_loss / cnt, 2) 275 | cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 276 | cur_real_loss = round(cur_loss + cur_kl_theta, 2) 277 | 278 | cur_loss = round(acc_loss / cnt, 2) 279 | cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 280 | cur_real_loss = round(cur_loss + cur_kl_theta, 2) 281 | 282 | if self.debug_mode: 283 | print('Epoch {} - Learning Rate: {} - KL theta: {} - Rec loss: {} - NELBO: {}'.format( 284 | epoch, self.optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss)) 285 | 286 | def _perplexity(self, test_data) -> float: 287 | """Computes perplexity on document completion for a given testing data. 288 | 289 | The document completion task is described on the original ETM's article: https://arxiv.org/pdf/1907.04907.pdf 290 | 291 | Parameters: 292 | === 293 | test_data (dict): BOW testing dataset, split in tokens and counts and used for perplexity 294 | 295 | Returns: 296 | === 297 | float: perplexity score on document completion task 298 | """ 299 | self._set_test_data(test_data) 300 | 301 | self.model.eval() 302 | with torch.no_grad(): 303 | # get \beta here 304 | beta = self.model.get_beta() 305 | 306 | # do dc here 307 | acc_loss = 0 308 | cnt = 0 309 | indices_1 = torch.split( 310 | torch.tensor( 311 | range( 312 | self.num_docs_test_1)), 313 | self.eval_batch_size) 314 | for idx, ind in enumerate(indices_1): 315 | # get theta from first half of docs 316 | data_batch_1 = data.get_batch( 317 | self.test_1_tokens, 318 | self.test_1_counts, 319 | ind, 320 | self.vocabulary_size, 321 | self.device) 322 | sums_1 = data_batch_1.sum(1).unsqueeze(1) 323 | if self.bow_norm: 324 | normalized_data_batch_1 = data_batch_1 / sums_1 325 | else: 326 | normalized_data_batch_1 = data_batch_1 327 | theta, _ = self.model.get_theta(normalized_data_batch_1) 328 | 329 | # get prediction loss using second half 330 | data_batch_2 = data.get_batch( 331 | self.test_2_tokens, 332 | self.test_2_counts, 333 | ind, 334 | self.vocabulary_size, 335 | self.device) 336 | sums_2 = data_batch_2.sum(1).unsqueeze(1) 337 | res = torch.mm(theta, beta) 338 | preds = torch.log(res) 339 | recon_loss = -(preds * data_batch_2).sum(1) 340 | 341 | loss = recon_loss / sums_2.squeeze() 342 | loss = loss.mean().item() 343 | acc_loss += loss 344 | cnt += 1 345 | 346 | cur_loss = acc_loss / cnt 347 | ppl_dc = round(math.exp(cur_loss), 1) 348 | 349 | if self.debug_mode: 350 | print(f'Document Completion Task Perplexity: {ppl_dc}') 351 | 352 | return ppl_dc 353 | 354 | def get_topics(self, top_n_words=10) -> List[str]: 355 | """ 356 | Gets topics. By default, returns the 10 most relevant terms for each topic. 357 | 358 | Parameters: 359 | === 360 | top_n_words (int): number of top words per topic to return 361 | 362 | Returns: 363 | === 364 | list of str: topic list 365 | """ 366 | 367 | with torch.no_grad(): 368 | topics = [] 369 | gammas = self.model.get_beta() 370 | 371 | for k in range(self.num_topics): 372 | gamma = gammas[k] 373 | top_words = list(gamma.cpu().numpy().argsort() 374 | [-top_n_words:][::-1]) 375 | topic_words = [self.vocabulary[a] for a in top_words] 376 | topics.append(topic_words) 377 | 378 | return topics 379 | 380 | def get_most_similar_words(self, queries, n_most_similar=20) -> dict: 381 | """ 382 | Gets the nearest neighborhoring words for a list of tokens. By default, returns the 20 most similar words for each token in 'queries' array. 383 | 384 | Parameters: 385 | === 386 | queries (list of str): words to find similar ones 387 | n_most_similar (int): number of most similar words to get for each word given in the input. By default is 20 388 | 389 | Returns: 390 | === 391 | dict of (str, list of str): dictionary containing the mapping between query words given and their respective similar words 392 | """ 393 | 394 | self.model.eval() 395 | 396 | # visualize word embeddings by using V to get nearest neighbors 397 | with torch.no_grad(): 398 | try: 399 | self.embeddings = self.model.rho.weight # Vocab_size x E 400 | except BaseException: 401 | self.embeddings = self.model.rho # Vocab_size x E 402 | 403 | neighbors = {} 404 | for word in queries: 405 | neighbors[word] = metrics.nearest_neighbors( 406 | word, self.embeddings, self.vocabulary, n_most_similar) 407 | 408 | return neighbors 409 | 410 | def fit(self, train_data, test_data=None): 411 | """ 412 | Trains the model with the given training data. 413 | 414 | Optionally receives testing data for perplexity calculation. The testing data is 415 | only used if the 'eval_perplexity' model parameter is True. 416 | 417 | Parameters: 418 | === 419 | train_data (dict): BOW training dataset, split in tokens and counts 420 | test_data (dict): optional. BOW testing dataset, split in tokens and counts. Used for perplexity calculation, if activated 421 | 422 | Returns: 423 | === 424 | self (ETM): the instance itself 425 | """ 426 | self._set_training_data(train_data) 427 | 428 | best_val_ppl = 1e9 429 | all_val_ppls = [] 430 | 431 | if self.debug_mode: 432 | print(f'Topics before training: {self.get_topics()}') 433 | 434 | for epoch in range(1, self.epochs): 435 | self._train(epoch) 436 | 437 | if self.eval_perplexity: 438 | val_ppl = self._perplexity( 439 | test_data) 440 | if val_ppl < best_val_ppl: 441 | if self.model_path is not None: 442 | self._save_model(self.model_path) 443 | best_val_ppl = val_ppl 444 | else: 445 | # check whether to anneal lr 446 | lr = self.optimizer.param_groups[0]['lr'] 447 | if self.anneal_lr and (len(all_val_ppls) > self.nonmono and val_ppl > min( 448 | all_val_ppls[:-self.nonmono]) and lr > 1e-5): 449 | self.optimizer.param_groups[0]['lr'] /= self.lr_factor 450 | 451 | all_val_ppls.append(val_ppl) 452 | 453 | if self.debug_mode and (epoch % self.visualize_every == 0): 454 | print(f'Topics: {self.get_topics()}') 455 | 456 | if self.model_path is not None: 457 | self._save_model(self.model_path) 458 | 459 | if self.eval_perplexity and self.model_path is not None: 460 | self._load_model(self.model_path) 461 | val_ppl = self._perplexity(train_data) 462 | 463 | return self 464 | 465 | def get_topic_word_matrix(self) -> List[List[str]]: 466 | """ 467 | Obtains the topic word matrix learned for the model. 468 | 469 | The topic word matrix lists all words for each discovered topic. 470 | As such, this method will return a matrix representing the words. 471 | 472 | Returns: 473 | === 474 | list of list of str: topic word matrix. 475 | Example: 476 | [['world', 'planet', 'stars', 'moon', 'astrophysics'], ...] 477 | """ 478 | self.model = self.model.to(self.device) 479 | self.model.eval() 480 | 481 | with torch.no_grad(): 482 | beta = self.model.get_beta() 483 | 484 | topics = [] 485 | 486 | for i in range(self.num_topics): 487 | words = list(beta[i].cpu().numpy()) 488 | topic_words = [self.vocabulary[a] for a, _ in enumerate(words)] 489 | topics.append(topic_words) 490 | 491 | return topics 492 | 493 | def get_topic_word_dist(self) -> torch.Tensor: 494 | """ 495 | Obtains the topic word distribution matrix. 496 | 497 | The topic word distribution matrix lists the probabilities for each word on each topic. 498 | 499 | This is a normalized distribution matrix, and as such, each row sums to one. 500 | 501 | Returns: 502 | === 503 | torch.Tensor: topic word distribution matrix, with KxV dimension, where 504 | K is the number of topics and V is the vocabulary size 505 | Example: 506 | tensor([[3.2238e-04, 3.7851e-03, 3.2811e-04, ..., 8.4206e-05, 7.9504e-05, 507 | 4.0738e-04], 508 | [3.6089e-05, 3.0677e-03, 1.3650e-04, ..., 4.5665e-05, 1.3241e-04, 509 | 5.8661e-05]]) 510 | """ 511 | self.model = self.model.to(self.device) 512 | self.model.eval() 513 | 514 | with torch.no_grad(): 515 | return self.model.get_beta() 516 | 517 | def get_document_topic_dist(self) -> torch.Tensor: 518 | """ 519 | Obtains the document topic distribution matrix. 520 | 521 | The document topic distribution matrix lists the probabilities for each topic on each document. 522 | 523 | This is a normalized distribution matrix, and as such, each row sums to one. 524 | 525 | Returns: 526 | === 527 | torch.Tensor: document topic distribution matrix, with DxK dimension, where 528 | D is the number of documents in the training corpus and K is the number of topics 529 | Example: 530 | tensor([[0.1840, 0.0489, 0.1020, 0.0726, 0.1952, 0.1042, 0.1275, 0.1657], 531 | [0.1417, 0.0918, 0.2263, 0.0840, 0.0900, 0.1635, 0.1209, 0.0817]]) 532 | """ 533 | self.model = self.model.to(self.device) 534 | self.model.eval() 535 | 536 | with torch.no_grad(): 537 | indices = torch.tensor(range(self.num_docs_train)) 538 | indices = torch.split(indices, self.batch_size) 539 | 540 | thetas = [] 541 | 542 | for ind in indices: 543 | data_batch = data.get_batch( 544 | self.train_tokens, 545 | self.train_counts, 546 | ind, 547 | self.vocabulary_size, 548 | self.device) 549 | sums = data_batch.sum(1).unsqueeze(1) 550 | normalized_data_batch = data_batch / sums if self.bow_norm else data_batch 551 | theta, _ = self.model.get_theta(normalized_data_batch) 552 | 553 | thetas.append(theta) 554 | 555 | return torch.cat(tuple(thetas), 0) 556 | 557 | def transform(self, X) -> torch.Tensor: 558 | """ 559 | Transforms the given data with the learned distribution, outputting prediction for unseen data. 560 | 561 | Parameters: 562 | === 563 | X (dict): BOW dataset, split in tokens and counts 564 | 565 | Returns: 566 | === 567 | torch.Tensor: document topic distribution matrix, with DxK dimension, where 568 | D is the number of documents in the corpus X and K is the number of topics. 569 | This is a normalized distribution matrix, and as such, each row sums to one. 570 | Example: 571 | tensor([[0.1840, 0.0489, 0.1020, 0.0726, 0.1952, 0.1042, 0.1275, 0.1657], 572 | [0.1417, 0.0918, 0.2263, 0.0840, 0.0900, 0.1635, 0.1209, 0.0817]]) 573 | """ 574 | self.model = self.model.to(self.device) 575 | self.model.eval() 576 | 577 | with torch.no_grad(): 578 | indices = torch.tensor(range(len(X["tokens"]))) 579 | indices = torch.split(indices, self.batch_size) 580 | 581 | thetas = [] 582 | 583 | for ind in indices: 584 | data_batch = data.get_batch( 585 | X["tokens"], 586 | X["counts"], 587 | ind, 588 | self.vocabulary_size, 589 | self.device) 590 | sums = data_batch.sum(1).unsqueeze(1) 591 | normalized_data_batch = data_batch / sums if self.bow_norm else data_batch 592 | theta, _ = self.model.get_theta(normalized_data_batch) 593 | 594 | thetas.append(theta) 595 | 596 | return torch.cat(tuple(thetas), 0) 597 | 598 | def get_topic_coherence(self, top_n=10) -> float: 599 | """ 600 | Calculates NPMI topic coherence for the model. 601 | 602 | By default, considers the 10 most relevant terms for each topic in coherence computation. 603 | 604 | Parameters: 605 | === 606 | top_n (int): number of words per topic to consider in coherence computation 607 | 608 | Returns: 609 | === 610 | float: the model's topic coherence 611 | """ 612 | self.model = self.model.to(self.device) 613 | self.model.eval() 614 | 615 | with torch.no_grad(): 616 | beta = self.model.get_beta().data.cpu().numpy() 617 | return metrics.get_topic_coherence( 618 | beta, self.train_tokens, self.vocabulary, top_n) 619 | 620 | def get_topic_diversity(self, top_n=25) -> float: 621 | """ 622 | Calculates topic diversity for the model. 623 | 624 | By default, considers the 25 most relevant terms for each topic in diversity computation. 625 | 626 | Parameters: 627 | === 628 | top_n (int): number of words per topic to consider in diversity computation 629 | 630 | Returns: 631 | === 632 | float: the model's topic diversity 633 | """ 634 | self.model = self.model.to(self.device) 635 | self.model.eval() 636 | 637 | with torch.no_grad(): 638 | beta = self.model.get_beta().data.cpu().numpy() 639 | return metrics.get_topic_diversity(beta, top_n) 640 | 641 | def _save_model(self, model_path): 642 | assert self.model is not None, \ 643 | 'no model to save' 644 | 645 | if not os.path.exists(model_path): 646 | os.makedirs(os.path.dirname(model_path), exist_ok=True) 647 | 648 | with open(model_path, 'wb') as file: 649 | torch.save(self.model, file) 650 | 651 | def _load_model(self, model_path): 652 | assert os.path.exists(model_path), \ 653 | "model path doesn't exists" 654 | 655 | with open(model_path, 'rb') as file: 656 | self.model = torch.load(file) 657 | self.model = self.model.to(self.device) 658 | -------------------------------------------------------------------------------- /embedded_topic_model/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class Model(nn.Module): 7 | def __init__( 8 | self, 9 | device, 10 | num_topics, 11 | vocab_size, 12 | t_hidden_size, 13 | rho_size, 14 | emsize, 15 | theta_act, 16 | embeddings=None, 17 | train_embeddings=True, 18 | enc_drop=0.5, 19 | debug_mode=False): 20 | super(Model, self).__init__() 21 | 22 | # define hyperparameters 23 | self.num_topics = num_topics 24 | self.vocab_size = vocab_size 25 | self.t_hidden_size = t_hidden_size 26 | self.rho_size = rho_size 27 | self.enc_drop = enc_drop 28 | self.emsize = emsize 29 | self.t_drop = nn.Dropout(enc_drop) 30 | self.debug_mode = debug_mode 31 | self.theta_act = self.get_activation(theta_act) 32 | 33 | self.device = device 34 | 35 | # define the word embedding matrix \rho 36 | if train_embeddings: 37 | self.rho = nn.Linear(rho_size, vocab_size, bias=False) 38 | else: 39 | num_embeddings, emsize = embeddings.size() 40 | self.rho = embeddings.clone().float().to(self.device) 41 | 42 | # define the matrix containing the topic embeddings 43 | # nn.Parameter(torch.randn(rho_size, num_topics)) 44 | self.alphas = nn.Linear(rho_size, num_topics, bias=False) 45 | 46 | # define variational distribution for \theta_{1:D} via amortizartion 47 | self.q_theta = nn.Sequential( 48 | nn.Linear(vocab_size, t_hidden_size), 49 | self.theta_act, 50 | nn.Linear(t_hidden_size, t_hidden_size), 51 | self.theta_act, 52 | ) 53 | self.mu_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True) 54 | self.logsigma_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True) 55 | 56 | def get_activation(self, act): 57 | if act == 'tanh': 58 | act = nn.Tanh() 59 | elif act == 'relu': 60 | act = nn.ReLU() 61 | elif act == 'softplus': 62 | act = nn.Softplus() 63 | elif act == 'rrelu': 64 | act = nn.RReLU() 65 | elif act == 'leakyrelu': 66 | act = nn.LeakyReLU() 67 | elif act == 'elu': 68 | act = nn.ELU() 69 | elif act == 'selu': 70 | act = nn.SELU() 71 | elif act == 'glu': 72 | act = nn.GLU() 73 | else: 74 | act = nn.Tanh() 75 | if self.debug_mode: 76 | print('Defaulting to tanh activation') 77 | return act 78 | 79 | def reparameterize(self, mu, logvar): 80 | """Returns a sample from a Gaussian distribution via reparameterization. 81 | """ 82 | if self.training: 83 | std = torch.exp(0.5 * logvar) 84 | eps = torch.randn_like(std) 85 | return eps.mul_(std).add_(mu) 86 | else: 87 | return mu 88 | 89 | def encode(self, bows): 90 | """Returns paramters of the variational distribution for \theta. 91 | 92 | input: bows 93 | batch of bag-of-words...tensor of shape bsz x V 94 | output: mu_theta, log_sigma_theta 95 | """ 96 | q_theta = self.q_theta(bows) 97 | if self.enc_drop > 0: 98 | q_theta = self.t_drop(q_theta) 99 | mu_theta = self.mu_q_theta(q_theta) 100 | logsigma_theta = self.logsigma_q_theta(q_theta) 101 | kl_theta = -0.5 * \ 102 | torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1).mean() 103 | return mu_theta, logsigma_theta, kl_theta 104 | 105 | def get_beta(self): 106 | try: 107 | # torch.mm(self.rho, self.alphas) 108 | logit = self.alphas(self.rho.weight) 109 | except BaseException: 110 | logit = self.alphas(self.rho) 111 | beta = F.softmax( 112 | logit, dim=0).transpose( 113 | 1, 0) # softmax over vocab dimension 114 | return beta 115 | 116 | def get_theta(self, normalized_bows): 117 | mu_theta, logsigma_theta, kld_theta = self.encode(normalized_bows) 118 | z = self.reparameterize(mu_theta, logsigma_theta) 119 | theta = F.softmax(z, dim=-1) 120 | return theta, kld_theta 121 | 122 | def decode(self, theta, beta): 123 | res = torch.mm(theta, beta) 124 | preds = torch.log(res + 1e-6) 125 | return preds 126 | 127 | def forward(self, bows, normalized_bows, theta=None, aggregate=True): 128 | # get \theta 129 | if theta is None: 130 | theta, kld_theta = self.get_theta(normalized_bows) 131 | else: 132 | kld_theta = None 133 | 134 | # get \beta 135 | beta = self.get_beta() 136 | 137 | # get prediction loss 138 | preds = self.decode(theta, beta) 139 | recon_loss = -(preds * bows).sum(1) 140 | if aggregate: 141 | recon_loss = recon_loss.mean() 142 | return recon_loss, kld_theta 143 | -------------------------------------------------------------------------------- /embedded_topic_model/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/__init__.py -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_tr_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_tr_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_tr_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_tr_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_h1_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_h1_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_h1_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_h1_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_h2_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_h2_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_h2_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_h2_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_ts_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_ts_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_va_counts.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_va_counts.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/bow_va_tokens.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/bow_va_tokens.mat -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/20ng/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/20ng/vocab.pkl -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/scripts/datasets/__init__.py -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/data_20ng.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | from sklearn.datasets import fetch_20newsgroups 3 | import numpy as np 4 | import pickle 5 | from scipy import sparse 6 | from scipy.io import savemat 7 | import re 8 | import string 9 | import os 10 | 11 | # Maximum / minimum document frequency 12 | max_df = 0.7 13 | min_df = 10 # choose desired value for min_df 14 | 15 | # Read stopwords 16 | with open('stops.txt', 'r') as f: 17 | stops = f.read().split('\n') 18 | 19 | # Read data 20 | print('reading data...') 21 | train_data = fetch_20newsgroups(subset='train') 22 | test_data = fetch_20newsgroups(subset='test') 23 | 24 | init_docs_tr = [ 25 | re.findall( 26 | r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''', 27 | train_data.data[doc]) for doc in range( 28 | len( 29 | train_data.data))] 30 | init_docs_ts = [ 31 | re.findall( 32 | r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''', 33 | test_data.data[doc]) for doc in range( 34 | len( 35 | test_data.data))] 36 | 37 | 38 | def contains_punctuation(w): 39 | return any(char in string.punctuation for char in w) 40 | 41 | 42 | def contains_numeric(w): 43 | return any(char.isdigit() for char in w) 44 | 45 | 46 | init_docs = init_docs_tr + init_docs_ts 47 | init_docs = [[w.lower() for w in init_docs[doc] if not contains_punctuation(w)] 48 | for doc in range(len(init_docs))] 49 | init_docs = [[w for w in init_docs[doc] if not contains_numeric( 50 | w)] for doc in range(len(init_docs))] 51 | init_docs = [[w for w in init_docs[doc] if len( 52 | w) > 1] for doc in range(len(init_docs))] 53 | init_docs = [" ".join(init_docs[doc]) for doc in range(len(init_docs))] 54 | 55 | # Create count vectorizer 56 | print('counting document frequency of words...') 57 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 58 | cvz = cvectorizer.fit_transform(init_docs).sign() 59 | 60 | # Get vocabulary 61 | print('building the vocabulary...') 62 | sum_counts = cvz.sum(axis=0) 63 | v_size = sum_counts.shape[1] 64 | sum_counts_np = np.zeros(v_size, dtype=int) 65 | for v in range(v_size): 66 | sum_counts_np[v] = sum_counts[0, v] 67 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) 68 | for w in cvectorizer.vocabulary_]) 69 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) 70 | for w in cvectorizer.vocabulary_]) 71 | del cvectorizer 72 | print(' initial vocabulary size: {}'.format(v_size)) 73 | 74 | # Sort elements in vocabulary 75 | idx_sort = np.argsort(sum_counts_np) 76 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 77 | 78 | # Filter out stopwords (if any) 79 | vocab_aux = [w for w in vocab_aux if w not in stops] 80 | print( 81 | ' vocabulary size after removing stopwords from list: {}'.format( 82 | len(vocab_aux))) 83 | 84 | # Create dictionary and inverse dictionary 85 | vocab = vocab_aux 86 | del vocab_aux 87 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 88 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 89 | 90 | # Split in train/test/valid 91 | print('tokenizing documents and splitting into train/test/valid...') 92 | num_docs_tr = len(init_docs_tr) 93 | trSize = num_docs_tr - 100 94 | tsSize = len(init_docs_ts) 95 | vaSize = 100 96 | idx_permute = np.random.permutation(num_docs_tr).astype(int) 97 | 98 | # Remove words not in train_data 99 | vocab = list(set([w for idx_d in range(trSize) 100 | for w in init_docs[idx_permute[idx_d]].split() if w in word2id])) 101 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 102 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 103 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 104 | 105 | # Split in train/test/valid 106 | docs_tr = [[word2id[w] for w in init_docs[idx_permute[idx_d]].split() if w in word2id] 107 | for idx_d in range(trSize)] 108 | docs_va = [[word2id[w] for w in init_docs[idx_permute[idx_d + trSize] 109 | ].split() if w in word2id] for idx_d in range(vaSize)] 110 | docs_ts = [[word2id[w] for w in init_docs[idx_d + num_docs_tr].split() if w in word2id] 111 | for idx_d in range(tsSize)] 112 | 113 | print( 114 | ' number of documents (train): {} [this should be equal to {}]'.format( 115 | len(docs_tr), 116 | trSize)) 117 | print( 118 | ' number of documents (test): {} [this should be equal to {}]'.format( 119 | len(docs_ts), 120 | tsSize)) 121 | print( 122 | ' number of documents (valid): {} [this should be equal to {}]'.format( 123 | len(docs_va), 124 | vaSize)) 125 | 126 | # Remove empty documents 127 | print('removing empty documents...') 128 | 129 | 130 | def remove_empty(in_docs): 131 | return [doc for doc in in_docs if doc != []] 132 | 133 | 134 | docs_tr = remove_empty(docs_tr) 135 | docs_ts = remove_empty(docs_ts) 136 | docs_va = remove_empty(docs_va) 137 | 138 | # Remove test documents with length=1 139 | docs_ts = [doc for doc in docs_ts if len(doc) > 1] 140 | 141 | # Split test set in 2 halves 142 | print('splitting test documents in 2 halves...') 143 | docs_ts_h1 = [[w for i, w in enumerate( 144 | doc) if i <= len(doc) / 2.0 - 1] for doc in docs_ts] 145 | docs_ts_h2 = [[w for i, w in enumerate( 146 | doc) if i > len(doc) / 2.0 - 1] for doc in docs_ts] 147 | 148 | # Getting lists of words and doc_indices 149 | print('creating lists of words...') 150 | 151 | 152 | def create_list_words(in_docs): 153 | return [x for y in in_docs for x in y] 154 | 155 | 156 | words_tr = create_list_words(docs_tr) 157 | words_ts = create_list_words(docs_ts) 158 | words_ts_h1 = create_list_words(docs_ts_h1) 159 | words_ts_h2 = create_list_words(docs_ts_h2) 160 | words_va = create_list_words(docs_va) 161 | 162 | print(' len(words_tr): ', len(words_tr)) 163 | print(' len(words_ts): ', len(words_ts)) 164 | print(' len(words_ts_h1): ', len(words_ts_h1)) 165 | print(' len(words_ts_h2): ', len(words_ts_h2)) 166 | print(' len(words_va): ', len(words_va)) 167 | 168 | # Get doc indices 169 | print('getting doc indices...') 170 | 171 | 172 | def create_doc_indices(in_docs): 173 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 174 | return [int(x) for y in aux for x in y] 175 | 176 | 177 | doc_indices_tr = create_doc_indices(docs_tr) 178 | doc_indices_ts = create_doc_indices(docs_ts) 179 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 180 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 181 | doc_indices_va = create_doc_indices(docs_va) 182 | 183 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format( 184 | len(np.unique(doc_indices_tr)), len(docs_tr))) 185 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format( 186 | len(np.unique(doc_indices_ts)), len(docs_ts))) 187 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format( 188 | len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 189 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format( 190 | len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 191 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format( 192 | len(np.unique(doc_indices_va)), len(docs_va))) 193 | 194 | # Number of documents in each set 195 | n_docs_tr = len(docs_tr) 196 | n_docs_ts = len(docs_ts) 197 | n_docs_ts_h1 = len(docs_ts_h1) 198 | n_docs_ts_h2 = len(docs_ts_h2) 199 | n_docs_va = len(docs_va) 200 | 201 | # Remove unused variables 202 | del docs_tr 203 | del docs_ts 204 | del docs_ts_h1 205 | del docs_ts_h2 206 | del docs_va 207 | 208 | # Create bow representation 209 | print('creating bow representation...') 210 | 211 | 212 | def create_bow(doc_indices, words, n_docs, vocab_size): 213 | return sparse.coo_matrix( 214 | ([1] * len(doc_indices), 215 | (doc_indices, 216 | words)), 217 | shape=( 218 | n_docs, 219 | vocab_size)).tocsr() 220 | 221 | 222 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 223 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 224 | bow_ts_h1 = create_bow( 225 | doc_indices_ts_h1, 226 | words_ts_h1, 227 | n_docs_ts_h1, 228 | len(vocab)) 229 | bow_ts_h2 = create_bow( 230 | doc_indices_ts_h2, 231 | words_ts_h2, 232 | n_docs_ts_h2, 233 | len(vocab)) 234 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 235 | 236 | del words_tr 237 | del words_ts 238 | del words_ts_h1 239 | del words_ts_h2 240 | del words_va 241 | del doc_indices_tr 242 | del doc_indices_ts 243 | del doc_indices_ts_h1 244 | del doc_indices_ts_h2 245 | del doc_indices_va 246 | 247 | # Write the vocabulary to a file 248 | path_save = './min_df_' + str(min_df) + '/' 249 | if not os.path.isdir(path_save): 250 | os.system('mkdir -p ' + path_save) 251 | 252 | with open(path_save + 'vocab.pkl', 'wb') as f: 253 | pickle.dump(vocab, f) 254 | del vocab 255 | 256 | # Split bow intro token/value pairs 257 | print('splitting bow intro token/value pairs and saving to disk...') 258 | 259 | 260 | def split_bow(bow_in, n_docs): 261 | indices = [[w for w in bow_in[doc, :].indices] for doc in range(n_docs)] 262 | counts = [[c for c in bow_in[doc, :].data] for doc in range(n_docs)] 263 | return indices, counts 264 | 265 | 266 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 267 | savemat(path_save + 'bow_tr_tokens', 268 | {'tokens': bow_tr_tokens}, 269 | do_compression=True) 270 | savemat(path_save + 'bow_tr_counts', 271 | {'counts': bow_tr_counts}, 272 | do_compression=True) 273 | del bow_tr 274 | del bow_tr_tokens 275 | del bow_tr_counts 276 | 277 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 278 | savemat(path_save + 'bow_ts_tokens', 279 | {'tokens': bow_ts_tokens}, 280 | do_compression=True) 281 | savemat(path_save + 'bow_ts_counts', 282 | {'counts': bow_ts_counts}, 283 | do_compression=True) 284 | del bow_ts 285 | del bow_ts_tokens 286 | del bow_ts_counts 287 | 288 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 289 | savemat(path_save + 'bow_ts_h1_tokens', 290 | {'tokens': bow_ts_h1_tokens}, 291 | do_compression=True) 292 | savemat(path_save + 'bow_ts_h1_counts', 293 | {'counts': bow_ts_h1_counts}, 294 | do_compression=True) 295 | del bow_ts_h1 296 | del bow_ts_h1_tokens 297 | del bow_ts_h1_counts 298 | 299 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 300 | savemat(path_save + 'bow_ts_h2_tokens', 301 | {'tokens': bow_ts_h2_tokens}, 302 | do_compression=True) 303 | savemat(path_save + 'bow_ts_h2_counts', 304 | {'counts': bow_ts_h2_counts}, 305 | do_compression=True) 306 | del bow_ts_h2 307 | del bow_ts_h2_tokens 308 | del bow_ts_h2_counts 309 | 310 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 311 | savemat(path_save + 'bow_va_tokens', 312 | {'tokens': bow_va_tokens}, 313 | do_compression=True) 314 | savemat(path_save + 'bow_va_counts', 315 | {'counts': bow_va_counts}, 316 | do_compression=True) 317 | del bow_va 318 | del bow_va_tokens 319 | del bow_va_counts 320 | 321 | print('Data ready !!') 322 | print('*************') 323 | -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/data_nyt.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | import numpy as np 3 | import pickle 4 | from scipy import sparse 5 | from scipy.io import savemat 6 | import os 7 | 8 | # Maximum / minimum document frequency 9 | max_df = 0.7 10 | min_df = 100 # choose desired value for min_df 11 | 12 | # Read stopwords 13 | with open('stops.txt', 'r') as f: 14 | stops = f.read().split('\n') 15 | 16 | # Read data 17 | print('reading text file...') 18 | data_file = 'raw/new_york_times_text/nyt_docs.txt' 19 | with open(data_file, 'r') as f: 20 | docs = f.readlines() 21 | 22 | # Create count vectorizer 23 | print('counting document frequency of words...') 24 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 25 | cvz = cvectorizer.fit_transform(docs).sign() 26 | 27 | # Get vocabulary 28 | print('building the vocabulary...') 29 | sum_counts = cvz.sum(axis=0) 30 | v_size = sum_counts.shape[1] 31 | sum_counts_np = np.zeros(v_size, dtype=int) 32 | for v in range(v_size): 33 | sum_counts_np[v] = sum_counts[0, v] 34 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) 35 | for w in cvectorizer.vocabulary_]) 36 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) 37 | for w in cvectorizer.vocabulary_]) 38 | del cvectorizer 39 | print(' initial vocabulary size: {}'.format(v_size)) 40 | 41 | # Sort elements in vocabulary 42 | idx_sort = np.argsort(sum_counts_np) 43 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 44 | 45 | # Filter out stopwords (if any) 46 | vocab_aux = [w for w in vocab_aux if w not in stops] 47 | print( 48 | ' vocabulary size after removing stopwords from list: {}'.format( 49 | len(vocab_aux))) 50 | print(' vocabulary after removing stopwords: {}'.format(len(vocab_aux))) 51 | 52 | # Create dictionary and inverse dictionary 53 | vocab = vocab_aux 54 | del vocab_aux 55 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 56 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 57 | 58 | # Split in train/test/valid 59 | print('tokenizing documents and splitting into train/test/valid...') 60 | num_docs = cvz.shape[0] 61 | trSize = int(np.floor(0.85 * num_docs)) 62 | tsSize = int(np.floor(0.10 * num_docs)) 63 | vaSize = int(num_docs - trSize - tsSize) 64 | del cvz 65 | idx_permute = np.random.permutation(num_docs).astype(int) 66 | 67 | # Remove words not in train_data 68 | vocab = list(set([w for idx_d in range(trSize) 69 | for w in docs[idx_permute[idx_d]].split() if w in word2id])) 70 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 71 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 72 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 73 | 74 | docs_tr = [[word2id[w] for w in docs[idx_permute[idx_d]].split() if w in word2id] 75 | for idx_d in range(trSize)] 76 | docs_ts = [[word2id[w] for w in docs[idx_permute[idx_d + trSize] 77 | ].split() if w in word2id] for idx_d in range(tsSize)] 78 | docs_va = [[word2id[w] for w in docs[idx_permute[idx_d + trSize + 79 | tsSize]].split() if w in word2id] for idx_d in range(vaSize)] 80 | del docs 81 | 82 | print( 83 | ' number of documents (train): {} [this should be equal to {}]'.format( 84 | len(docs_tr), 85 | trSize)) 86 | print( 87 | ' number of documents (test): {} [this should be equal to {}]'.format( 88 | len(docs_ts), 89 | tsSize)) 90 | print( 91 | ' number of documents (valid): {} [this should be equal to {}]'.format( 92 | len(docs_va), 93 | vaSize)) 94 | 95 | # Remove empty documents 96 | print('removing empty documents...') 97 | 98 | 99 | def remove_empty(in_docs): 100 | return [doc for doc in in_docs if doc != []] 101 | 102 | 103 | docs_tr = remove_empty(docs_tr) 104 | docs_ts = remove_empty(docs_ts) 105 | docs_va = remove_empty(docs_va) 106 | 107 | # Remove test documents with length=1 108 | docs_ts = [doc for doc in docs_ts if len(doc) > 1] 109 | 110 | # Split test set in 2 halves 111 | print('splitting test documents in 2 halves...') 112 | docs_ts_h1 = [[w for i, w in enumerate( 113 | doc) if i <= len(doc) / 2.0 - 1] for doc in docs_ts] 114 | docs_ts_h2 = [[w for i, w in enumerate( 115 | doc) if i > len(doc) / 2.0 - 1] for doc in docs_ts] 116 | 117 | # Getting lists of words and doc_indices 118 | print('creating lists of words...') 119 | 120 | 121 | def create_list_words(in_docs): 122 | return [x for y in in_docs for x in y] 123 | 124 | 125 | words_tr = create_list_words(docs_tr) 126 | words_ts = create_list_words(docs_ts) 127 | words_ts_h1 = create_list_words(docs_ts_h1) 128 | words_ts_h2 = create_list_words(docs_ts_h2) 129 | words_va = create_list_words(docs_va) 130 | 131 | print(' len(words_tr): ', len(words_tr)) 132 | print(' len(words_ts): ', len(words_ts)) 133 | print(' len(words_ts_h1): ', len(words_ts_h1)) 134 | print(' len(words_ts_h2): ', len(words_ts_h2)) 135 | print(' len(words_va): ', len(words_va)) 136 | 137 | # Get doc indices 138 | print('getting doc indices...') 139 | 140 | 141 | def create_doc_indices(in_docs): 142 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 143 | return [int(x) for y in aux for x in y] 144 | 145 | 146 | doc_indices_tr = create_doc_indices(docs_tr) 147 | doc_indices_ts = create_doc_indices(docs_ts) 148 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 149 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 150 | doc_indices_va = create_doc_indices(docs_va) 151 | 152 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format( 153 | len(np.unique(doc_indices_tr)), len(docs_tr))) 154 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format( 155 | len(np.unique(doc_indices_ts)), len(docs_ts))) 156 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format( 157 | len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 158 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format( 159 | len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 160 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format( 161 | len(np.unique(doc_indices_va)), len(docs_va))) 162 | 163 | # Number of documents in each set 164 | n_docs_tr = len(docs_tr) 165 | n_docs_ts = len(docs_ts) 166 | n_docs_ts_h1 = len(docs_ts_h1) 167 | n_docs_ts_h2 = len(docs_ts_h2) 168 | n_docs_va = len(docs_va) 169 | 170 | # Remove unused variables 171 | del docs_tr 172 | del docs_ts 173 | del docs_ts_h1 174 | del docs_ts_h2 175 | del docs_va 176 | 177 | # Create bow representation 178 | print('creating bow representation...') 179 | 180 | 181 | def create_bow(doc_indices, words, n_docs, vocab_size): 182 | return sparse.coo_matrix( 183 | ([1] * len(doc_indices), 184 | (doc_indices, 185 | words)), 186 | shape=( 187 | n_docs, 188 | vocab_size)).tocsr() 189 | 190 | 191 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 192 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 193 | bow_ts_h1 = create_bow( 194 | doc_indices_ts_h1, 195 | words_ts_h1, 196 | n_docs_ts_h1, 197 | len(vocab)) 198 | bow_ts_h2 = create_bow( 199 | doc_indices_ts_h2, 200 | words_ts_h2, 201 | n_docs_ts_h2, 202 | len(vocab)) 203 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 204 | 205 | del words_tr 206 | del words_ts 207 | del words_ts_h1 208 | del words_ts_h2 209 | del words_va 210 | del doc_indices_tr 211 | del doc_indices_ts 212 | del doc_indices_ts_h1 213 | del doc_indices_ts_h2 214 | del doc_indices_va 215 | 216 | # Save vocabulary to file 217 | path_save = './min_df_' + str(min_df) + '/' 218 | if not os.path.isdir(path_save): 219 | os.system('mkdir -p ' + path_save) 220 | 221 | with open(path_save + 'vocab.pkl', 'wb') as f: 222 | pickle.dump(vocab, f) 223 | del vocab 224 | 225 | # Split bow intro token/value pairs 226 | print('splitting bow intro token/value pairs and saving to disk...') 227 | 228 | 229 | def split_bow(bow_in, n_docs): 230 | indices = [[w for w in bow_in[doc, :].indices] for doc in range(n_docs)] 231 | counts = [[c for c in bow_in[doc, :].data] for doc in range(n_docs)] 232 | return indices, counts 233 | 234 | 235 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 236 | savemat(path_save + 'bow_tr_tokens', 237 | {'tokens': bow_tr_tokens}, 238 | do_compression=True) 239 | savemat(path_save + 'bow_tr_counts', 240 | {'counts': bow_tr_counts}, 241 | do_compression=True) 242 | del bow_tr 243 | del bow_tr_tokens 244 | del bow_tr_counts 245 | 246 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 247 | savemat(path_save + 'bow_ts_tokens', 248 | {'tokens': bow_ts_tokens}, 249 | do_compression=True) 250 | savemat(path_save + 'bow_ts_counts', 251 | {'counts': bow_ts_counts}, 252 | do_compression=True) 253 | del bow_ts 254 | del bow_ts_tokens 255 | del bow_ts_counts 256 | 257 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 258 | savemat(path_save + 'bow_ts_h1_tokens', 259 | {'tokens': bow_ts_h1_tokens}, 260 | do_compression=True) 261 | savemat(path_save + 'bow_ts_h1_counts', 262 | {'counts': bow_ts_h1_counts}, 263 | do_compression=True) 264 | del bow_ts_h1 265 | del bow_ts_h1_tokens 266 | del bow_ts_h1_counts 267 | 268 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 269 | savemat(path_save + 'bow_ts_h2_tokens', 270 | {'tokens': bow_ts_h2_tokens}, 271 | do_compression=True) 272 | savemat(path_save + 'bow_ts_h2_counts', 273 | {'counts': bow_ts_h2_counts}, 274 | do_compression=True) 275 | del bow_ts_h2 276 | del bow_ts_h2_tokens 277 | del bow_ts_h2_counts 278 | 279 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 280 | savemat(path_save + 'bow_va_tokens', 281 | {'tokens': bow_va_tokens}, 282 | do_compression=True) 283 | savemat(path_save + 'bow_va_counts', 284 | {'counts': bow_va_counts}, 285 | do_compression=True) 286 | del bow_va 287 | del bow_va_tokens 288 | del bow_va_counts 289 | 290 | print('Data ready !!') 291 | print('*************') 292 | -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/data_reddit_nouns_only.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | import numpy as np 3 | import pickle 4 | from scipy import sparse 5 | from scipy.io import savemat 6 | import json 7 | from nltk.corpus import stopwords 8 | import os 9 | 10 | # Maximum / minimum document frequency 11 | # max_df = 0.85 12 | # min_df = 0.01 # choose desired value for min_df 13 | max_df = 1.0 14 | min_df = 1 15 | 16 | # min_df=0.01, max_df=0.85 17 | 18 | # Read stopwords 19 | # with open('stops.txt', 'r') as f: 20 | # stops = f.read().split('\n') 21 | stops = stopwords.words("portuguese") 22 | 23 | # Read data 24 | print('reading json file...') 25 | # data_file = 'reddit_gatherer.pt_submissions[2008_2020]_[preprocessed_dataset][nouns_only].json' 26 | # data_file = "../../datasets/reddit_gatherer.pt_submissions[original_dataset][2008_2020].json" 27 | data_file = "data/processed/reddit_gatherer.pt_submissions[original_dataset][2008_2020][nouns].json" 28 | docs = [ 29 | " ".join( 30 | data["body"]) for data in json.load( 31 | open( 32 | data_file, 33 | 'r'))] # Line for nouns only corpus 34 | # docs = [ data["body"] for data in json.load(open(data_file, 'r')) ] 35 | # #Line for raw corpus 36 | docs[0] 37 | 38 | # Create count vectorizer 39 | print('counting document frequency of words...') 40 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 41 | cvz = cvectorizer.fit_transform(docs).sign() 42 | 43 | # Get vocabulary 44 | print('building the vocabulary...') 45 | sum_counts = cvz.sum(axis=0) 46 | v_size = sum_counts.shape[1] 47 | sum_counts_np = np.zeros(v_size, dtype=int) 48 | for v in range(v_size): 49 | sum_counts_np[v] = sum_counts[0, v] 50 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) 51 | for w in cvectorizer.vocabulary_]) 52 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) 53 | for w in cvectorizer.vocabulary_]) 54 | del cvectorizer 55 | print(' initial vocabulary size: {}'.format(v_size)) 56 | 57 | # Sort elements in vocabulary 58 | idx_sort = np.argsort(sum_counts_np) 59 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 60 | 61 | # Filter out stopwords (if any) 62 | vocab_aux = [w for w in vocab_aux if w not in stops] 63 | print( 64 | ' vocabulary size after removing stopwords from list: {}'.format( 65 | len(vocab_aux))) 66 | print(' vocabulary after removing stopwords: {}'.format(len(vocab_aux))) 67 | 68 | # Create dictionary and inverse dictionary 69 | vocab = vocab_aux 70 | del vocab_aux 71 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 72 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 73 | 74 | # Split in train/test/valid 75 | print('tokenizing documents and splitting into train/test/valid...') 76 | num_docs = cvz.shape[0] 77 | trSize = int(np.floor(0.85 * num_docs)) 78 | tsSize = int(np.floor(0.10 * num_docs)) 79 | # trSize = int(np.floor(1.0*num_docs)) 80 | # tsSize = int(np.floor(0*num_docs)) 81 | vaSize = int(num_docs - trSize - tsSize) 82 | del cvz 83 | idx_permute = np.random.permutation(num_docs).astype(int) 84 | print("idx_permute: ", idx_permute) 85 | 86 | # Remove words not in train_data 87 | vocab = list(set([w for idx_d in range(trSize) 88 | for w in docs[idx_permute[idx_d]].split() if w in word2id])) 89 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 90 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 91 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 92 | 93 | docs_tr = [[word2id[w] for w in docs[idx_permute[idx_d]].split() if w in word2id] 94 | for idx_d in range(trSize)] 95 | docs_ts = [[word2id[w] for w in docs[idx_permute[idx_d + trSize] 96 | ].split() if w in word2id] for idx_d in range(tsSize)] 97 | docs_va = [[word2id[w] for w in docs[idx_permute[idx_d + trSize + 98 | tsSize]].split() if w in word2id] for idx_d in range(vaSize)] 99 | del docs 100 | 101 | print( 102 | ' number of documents (train): {} [this should be equal to {}]'.format( 103 | len(docs_tr), 104 | trSize)) 105 | print( 106 | ' number of documents (test): {} [this should be equal to {}]'.format( 107 | len(docs_ts), 108 | tsSize)) 109 | print( 110 | ' number of documents (valid): {} [this should be equal to {}]'.format( 111 | len(docs_va), 112 | vaSize)) 113 | 114 | # Remove empty documents 115 | print('removing empty documents...') 116 | 117 | 118 | def remove_empty(in_docs): 119 | return [doc for doc in in_docs if doc != []] 120 | 121 | 122 | # docs_tr = remove_empty(docs_tr) 123 | # docs_ts = remove_empty(docs_ts) 124 | # docs_va = remove_empty(docs_va) 125 | docs_tr = remove_empty(docs_tr) 126 | print(f'docs_tr[0]: {docs_tr[0]}') 127 | docs_ts = remove_empty(docs_ts) 128 | print(f'docs_ts[0]: {docs_ts[0]}') 129 | docs_va = remove_empty(docs_va) 130 | print(f'docs_va[0]: {docs_va[0]}') 131 | 132 | # Remove test documents with length=1 133 | docs_ts = [doc for doc in docs_ts if len(doc) > 1] 134 | 135 | # Split test set in 2 halves 136 | print('splitting test documents in 2 halves...') 137 | docs_ts_h1 = [[w for i, w in enumerate( 138 | doc) if i <= len(doc) / 2.0 - 1] for doc in docs_ts] 139 | docs_ts_h2 = [[w for i, w in enumerate( 140 | doc) if i > len(doc) / 2.0 - 1] for doc in docs_ts] 141 | 142 | # Getting lists of words and doc_indices 143 | print('creating lists of words...') 144 | 145 | 146 | def create_list_words(in_docs): 147 | return [x for y in in_docs for x in y] 148 | 149 | 150 | words_tr = create_list_words(docs_tr) 151 | words_ts = create_list_words(docs_ts) 152 | words_ts_h1 = create_list_words(docs_ts_h1) 153 | words_ts_h2 = create_list_words(docs_ts_h2) 154 | words_va = create_list_words(docs_va) 155 | 156 | print(' len(words_tr): ', len(words_tr)) 157 | print(' len(words_ts): ', len(words_ts)) 158 | print(' len(words_ts_h1): ', len(words_ts_h1)) 159 | print(' len(words_ts_h2): ', len(words_ts_h2)) 160 | print(' len(words_va): ', len(words_va)) 161 | 162 | # Get doc indices 163 | print('getting doc indices...') 164 | 165 | 166 | def create_doc_indices(in_docs): 167 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 168 | return [int(x) for y in aux for x in y] 169 | 170 | 171 | doc_indices_tr = create_doc_indices(docs_tr) 172 | doc_indices_ts = create_doc_indices(docs_ts) 173 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 174 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 175 | doc_indices_va = create_doc_indices(docs_va) 176 | 177 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format( 178 | len(np.unique(doc_indices_tr)), len(docs_tr))) 179 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format( 180 | len(np.unique(doc_indices_ts)), len(docs_ts))) 181 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format( 182 | len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 183 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format( 184 | len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 185 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format( 186 | len(np.unique(doc_indices_va)), len(docs_va))) 187 | 188 | # Number of documents in each set 189 | n_docs_tr = len(docs_tr) 190 | n_docs_ts = len(docs_ts) 191 | n_docs_ts_h1 = len(docs_ts_h1) 192 | n_docs_ts_h2 = len(docs_ts_h2) 193 | n_docs_va = len(docs_va) 194 | 195 | # Remove unused variables 196 | del docs_tr 197 | del docs_ts 198 | del docs_ts_h1 199 | del docs_ts_h2 200 | del docs_va 201 | 202 | # Create bow representation 203 | print('creating bow representation...') 204 | 205 | 206 | def create_bow(doc_indices, words, n_docs, vocab_size): 207 | return sparse.coo_matrix( 208 | ([1] * len(doc_indices), 209 | (doc_indices, 210 | words)), 211 | shape=( 212 | n_docs, 213 | vocab_size)).tocsr() 214 | 215 | 216 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 217 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 218 | bow_ts_h1 = create_bow( 219 | doc_indices_ts_h1, 220 | words_ts_h1, 221 | n_docs_ts_h1, 222 | len(vocab)) 223 | bow_ts_h2 = create_bow( 224 | doc_indices_ts_h2, 225 | words_ts_h2, 226 | n_docs_ts_h2, 227 | len(vocab)) 228 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 229 | 230 | del words_tr 231 | del words_ts 232 | del words_ts_h1 233 | del words_ts_h2 234 | del words_va 235 | del doc_indices_tr 236 | del doc_indices_ts 237 | del doc_indices_ts_h1 238 | del doc_indices_ts_h2 239 | del doc_indices_va 240 | 241 | # Save vocabulary to file 242 | path_save = './TEST_min_df_' + str(min_df) + '/' 243 | if not os.path.isdir(path_save): 244 | os.system('mkdir -p ' + path_save) 245 | 246 | with open(path_save + 'vocab.pkl', 'wb') as f: 247 | pickle.dump(vocab, f) 248 | del vocab 249 | 250 | # Split bow intro token/value pairs 251 | print('splitting bow intro token/value pairs and saving to disk...') 252 | 253 | 254 | def split_bow(bow_in, n_docs): 255 | indices = [[w for w in bow_in[doc, :].indices] for doc in range(n_docs)] 256 | counts = [[c for c in bow_in[doc, :].data] for doc in range(n_docs)] 257 | return indices, counts 258 | 259 | 260 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 261 | savemat(path_save + 'bow_tr_tokens.mat', 262 | {'tokens': bow_tr_tokens}, 263 | do_compression=True) 264 | savemat(path_save + 'bow_tr_counts.mat', 265 | {'counts': bow_tr_counts}, 266 | do_compression=True) 267 | del bow_tr 268 | del bow_tr_tokens 269 | del bow_tr_counts 270 | 271 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 272 | savemat(path_save + 'bow_ts_tokens.mat', 273 | {'tokens': bow_ts_tokens}, 274 | do_compression=True) 275 | savemat(path_save + 'bow_ts_counts.mat', 276 | {'counts': bow_ts_counts}, 277 | do_compression=True) 278 | del bow_ts 279 | del bow_ts_tokens 280 | del bow_ts_counts 281 | 282 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 283 | savemat(path_save + 'bow_ts_h1_tokens.mat', 284 | {'tokens': bow_ts_h1_tokens}, do_compression=True) 285 | savemat(path_save + 'bow_ts_h1_counts.mat', 286 | {'counts': bow_ts_h1_counts}, do_compression=True) 287 | del bow_ts_h1 288 | del bow_ts_h1_tokens 289 | del bow_ts_h1_counts 290 | 291 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 292 | savemat(path_save + 'bow_ts_h2_tokens.mat', 293 | {'tokens': bow_ts_h2_tokens}, do_compression=True) 294 | savemat(path_save + 'bow_ts_h2_counts.mat', 295 | {'counts': bow_ts_h2_counts}, do_compression=True) 296 | del bow_ts_h2 297 | del bow_ts_h2_tokens 298 | del bow_ts_h2_counts 299 | 300 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 301 | savemat(path_save + 'bow_va_tokens.mat', 302 | {'tokens': bow_va_tokens}, 303 | do_compression=True) 304 | savemat(path_save + 'bow_va_counts.mat', 305 | {'counts': bow_va_counts}, 306 | do_compression=True) 307 | del bow_va 308 | del bow_va_tokens 309 | del bow_va_counts 310 | 311 | print('Data ready !!') 312 | print('*************') 313 | -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/data_reddit_raw_pt.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | import numpy as np 3 | import pickle 4 | from scipy import sparse 5 | from scipy.io import savemat 6 | import json 7 | from nltk.corpus import stopwords 8 | import os 9 | 10 | # Maximum / minimum document frequency 11 | # max_df = 0.85 12 | # min_df = 0.01 # choose desired value for min_df 13 | max_df = 1.0 14 | min_df = 1 15 | 16 | # min_df=0.01, max_df=0.85 17 | 18 | # Read stopwords 19 | # with open('stops.txt', 'r') as f: 20 | # stops = f.read().split('\n') 21 | stops = stopwords.words("portuguese") 22 | 23 | # Read data 24 | print('reading json file...') 25 | # data_file = 'reddit_gatherer.pt_submissions[2008_2020]_[preprocessed_dataset][nouns_only].json' 26 | # data_file = "../../datasets/reddit_gatherer.pt_submissions[original_dataset][2008_2020].json" 27 | data_file = "data/processed/reddit_gatherer.pt_submissions[original_dataset][2008_2020]_[original_dataset_without_duplicates][processed].json" 28 | docs = [ 29 | " ".join( 30 | data["body"]) for data in json.load( 31 | open( 32 | data_file, 33 | 'r'))] # Line for nouns only corpus 34 | # docs = [ data["body"] for data in json.load(open(data_file, 'r')) ] 35 | # #Line for raw corpus 36 | docs[0] 37 | 38 | # Create count vectorizer 39 | print('counting document frequency of words...') 40 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 41 | cvz = cvectorizer.fit_transform(docs).sign() 42 | 43 | # Get vocabulary 44 | print('building the vocabulary...') 45 | sum_counts = cvz.sum(axis=0) 46 | v_size = sum_counts.shape[1] 47 | sum_counts_np = np.zeros(v_size, dtype=int) 48 | for v in range(v_size): 49 | sum_counts_np[v] = sum_counts[0, v] 50 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) 51 | for w in cvectorizer.vocabulary_]) 52 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) 53 | for w in cvectorizer.vocabulary_]) 54 | del cvectorizer 55 | print(' initial vocabulary size: {}'.format(v_size)) 56 | 57 | # Sort elements in vocabulary 58 | idx_sort = np.argsort(sum_counts_np) 59 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 60 | 61 | # Filter out stopwords (if any) 62 | # vocab_aux = [w for w in vocab_aux if w not in stops] 63 | # print(' vocabulary size after removing stopwords from list: {}'.format(len(vocab_aux))) 64 | # print(' vocabulary after removing stopwords: {}'.format(len(vocab_aux))) 65 | 66 | # Create dictionary and inverse dictionary 67 | vocab = vocab_aux 68 | del vocab_aux 69 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 70 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 71 | 72 | # Split in train/test/valid 73 | print('tokenizing documents and splitting into train/test/valid...') 74 | num_docs = cvz.shape[0] 75 | trSize = int(np.floor(0.85 * num_docs)) 76 | tsSize = int(np.floor(0.10 * num_docs)) 77 | # trSize = int(np.floor(1.0*num_docs)) 78 | # tsSize = int(np.floor(0*num_docs)) 79 | vaSize = int(num_docs - trSize - tsSize) 80 | del cvz 81 | idx_permute = np.random.permutation(num_docs).astype(int) 82 | 83 | # Remove words not in train_data 84 | vocab = list(set([w for idx_d in range(trSize) 85 | for w in docs[idx_permute[idx_d]].split() if w in word2id])) 86 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 87 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 88 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 89 | 90 | docs_tr = [[word2id[w] for w in docs[idx_permute[idx_d]].split() if w in word2id] 91 | for idx_d in range(trSize)] 92 | docs_ts = [[word2id[w] for w in docs[idx_permute[idx_d + trSize] 93 | ].split() if w in word2id] for idx_d in range(tsSize)] 94 | docs_va = [[word2id[w] for w in docs[idx_permute[idx_d + trSize + 95 | tsSize]].split() if w in word2id] for idx_d in range(vaSize)] 96 | del docs 97 | 98 | print( 99 | ' number of documents (train): {} [this should be equal to {}]'.format( 100 | len(docs_tr), 101 | trSize)) 102 | print( 103 | ' number of documents (test): {} [this should be equal to {}]'.format( 104 | len(docs_ts), 105 | tsSize)) 106 | print( 107 | ' number of documents (valid): {} [this should be equal to {}]'.format( 108 | len(docs_va), 109 | vaSize)) 110 | 111 | # Remove empty documents 112 | print('removing empty documents...') 113 | 114 | 115 | def remove_empty(in_docs): 116 | return [doc for doc in in_docs if doc != []] 117 | 118 | 119 | docs_tr = remove_empty(docs_tr) 120 | docs_ts = remove_empty(docs_ts) 121 | docs_va = remove_empty(docs_va) 122 | 123 | # Remove test documents with length=1 124 | docs_ts = [doc for doc in docs_ts if len(doc) > 1] 125 | 126 | # Split test set in 2 halves 127 | print('splitting test documents in 2 halves...') 128 | docs_ts_h1 = [[w for i, w in enumerate( 129 | doc) if i <= len(doc) / 2.0 - 1] for doc in docs_ts] 130 | docs_ts_h2 = [[w for i, w in enumerate( 131 | doc) if i > len(doc) / 2.0 - 1] for doc in docs_ts] 132 | 133 | # Getting lists of words and doc_indices 134 | print('creating lists of words...') 135 | 136 | 137 | def create_list_words(in_docs): 138 | return [x for y in in_docs for x in y] 139 | 140 | 141 | words_tr = create_list_words(docs_tr) 142 | words_ts = create_list_words(docs_ts) 143 | words_ts_h1 = create_list_words(docs_ts_h1) 144 | words_ts_h2 = create_list_words(docs_ts_h2) 145 | words_va = create_list_words(docs_va) 146 | 147 | print(' len(words_tr): ', len(words_tr)) 148 | print(' len(words_ts): ', len(words_ts)) 149 | print(' len(words_ts_h1): ', len(words_ts_h1)) 150 | print(' len(words_ts_h2): ', len(words_ts_h2)) 151 | print(' len(words_va): ', len(words_va)) 152 | 153 | # Get doc indices 154 | print('getting doc indices...') 155 | 156 | 157 | def create_doc_indices(in_docs): 158 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 159 | return [int(x) for y in aux for x in y] 160 | 161 | 162 | doc_indices_tr = create_doc_indices(docs_tr) 163 | doc_indices_ts = create_doc_indices(docs_ts) 164 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 165 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 166 | doc_indices_va = create_doc_indices(docs_va) 167 | 168 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format( 169 | len(np.unique(doc_indices_tr)), len(docs_tr))) 170 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format( 171 | len(np.unique(doc_indices_ts)), len(docs_ts))) 172 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format( 173 | len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 174 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format( 175 | len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 176 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format( 177 | len(np.unique(doc_indices_va)), len(docs_va))) 178 | 179 | # Number of documents in each set 180 | n_docs_tr = len(docs_tr) 181 | n_docs_ts = len(docs_ts) 182 | n_docs_ts_h1 = len(docs_ts_h1) 183 | n_docs_ts_h2 = len(docs_ts_h2) 184 | n_docs_va = len(docs_va) 185 | 186 | # Remove unused variables 187 | del docs_tr 188 | del docs_ts 189 | del docs_ts_h1 190 | del docs_ts_h2 191 | del docs_va 192 | 193 | # Create bow representation 194 | print('creating bow representation...') 195 | 196 | 197 | def create_bow(doc_indices, words, n_docs, vocab_size): 198 | return sparse.coo_matrix( 199 | ([1] * len(doc_indices), 200 | (doc_indices, 201 | words)), 202 | shape=( 203 | n_docs, 204 | vocab_size)).tocsr() 205 | 206 | 207 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 208 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 209 | bow_ts_h1 = create_bow( 210 | doc_indices_ts_h1, 211 | words_ts_h1, 212 | n_docs_ts_h1, 213 | len(vocab)) 214 | bow_ts_h2 = create_bow( 215 | doc_indices_ts_h2, 216 | words_ts_h2, 217 | n_docs_ts_h2, 218 | len(vocab)) 219 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 220 | 221 | del words_tr 222 | del words_ts 223 | del words_ts_h1 224 | del words_ts_h2 225 | del words_va 226 | del doc_indices_tr 227 | del doc_indices_ts 228 | del doc_indices_ts_h1 229 | del doc_indices_ts_h2 230 | del doc_indices_va 231 | 232 | # Save vocabulary to file 233 | path_save = './min_df_' + str(min_df) + '/' 234 | if not os.path.isdir(path_save): 235 | os.system('mkdir -p ' + path_save) 236 | 237 | with open(path_save + 'vocab.pkl', 'wb') as f: 238 | pickle.dump(vocab, f) 239 | del vocab 240 | 241 | # Split bow intro token/value pairs 242 | print('splitting bow intro token/value pairs and saving to disk...') 243 | 244 | 245 | def split_bow(bow_in, n_docs): 246 | indices = [[w for w in bow_in[doc, :].indices] for doc in range(n_docs)] 247 | counts = [[c for c in bow_in[doc, :].data] for doc in range(n_docs)] 248 | return indices, counts 249 | 250 | 251 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 252 | savemat(path_save + 'bow_tr_tokens.mat', 253 | {'tokens': bow_tr_tokens}, 254 | do_compression=True) 255 | savemat(path_save + 'bow_tr_counts.mat', 256 | {'counts': bow_tr_counts}, 257 | do_compression=True) 258 | del bow_tr 259 | del bow_tr_tokens 260 | del bow_tr_counts 261 | 262 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 263 | savemat(path_save + 'bow_ts_tokens.mat', 264 | {'tokens': bow_ts_tokens}, 265 | do_compression=True) 266 | savemat(path_save + 'bow_ts_counts.mat', 267 | {'counts': bow_ts_counts}, 268 | do_compression=True) 269 | del bow_ts 270 | del bow_ts_tokens 271 | del bow_ts_counts 272 | 273 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 274 | savemat(path_save + 'bow_ts_h1_tokens.mat', 275 | {'tokens': bow_ts_h1_tokens}, do_compression=True) 276 | savemat(path_save + 'bow_ts_h1_counts.mat', 277 | {'counts': bow_ts_h1_counts}, do_compression=True) 278 | del bow_ts_h1 279 | del bow_ts_h1_tokens 280 | del bow_ts_h1_counts 281 | 282 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 283 | savemat(path_save + 'bow_ts_h2_tokens.mat', 284 | {'tokens': bow_ts_h2_tokens}, do_compression=True) 285 | savemat(path_save + 'bow_ts_h2_counts.mat', 286 | {'counts': bow_ts_h2_counts}, do_compression=True) 287 | del bow_ts_h2 288 | del bow_ts_h2_tokens 289 | del bow_ts_h2_counts 290 | 291 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 292 | savemat(path_save + 'bow_va_tokens.mat', 293 | {'tokens': bow_va_tokens}, 294 | do_compression=True) 295 | savemat(path_save + 'bow_va_counts.mat', 296 | {'counts': bow_va_counts}, 297 | do_compression=True) 298 | del bow_va 299 | del bow_va_tokens 300 | del bow_va_counts 301 | 302 | print('Data ready !!') 303 | print('*************') 304 | -------------------------------------------------------------------------------- /embedded_topic_model/scripts/datasets/stops.txt: -------------------------------------------------------------------------------- 1 | a 2 | able 3 | about 4 | above 5 | according 6 | accordingly 7 | across 8 | actually 9 | after 10 | afterwards 11 | again 12 | against 13 | all 14 | allow 15 | allows 16 | almost 17 | alone 18 | along 19 | already 20 | also 21 | although 22 | always 23 | am 24 | among 25 | amongst 26 | an 27 | and 28 | another 29 | any 30 | anybody 31 | anyhow 32 | anyone 33 | anything 34 | anyway 35 | anyways 36 | anywhere 37 | apart 38 | appear 39 | appreciate 40 | appropriate 41 | are 42 | around 43 | as 44 | aside 45 | ask 46 | asking 47 | associated 48 | at 49 | available 50 | away 51 | awfully 52 | b 53 | be 54 | became 55 | because 56 | become 57 | becomes 58 | becoming 59 | been 60 | before 61 | beforehand 62 | behind 63 | being 64 | believe 65 | below 66 | beside 67 | besides 68 | best 69 | better 70 | between 71 | beyond 72 | both 73 | brief 74 | but 75 | by 76 | c 77 | came 78 | can 79 | cannot 80 | cant 81 | cause 82 | causes 83 | certain 84 | certainly 85 | changes 86 | clearly 87 | co 88 | com 89 | come 90 | comes 91 | concerning 92 | consequently 93 | consider 94 | considering 95 | contain 96 | containing 97 | contains 98 | corresponding 99 | could 100 | course 101 | currently 102 | d 103 | definitely 104 | described 105 | despite 106 | did 107 | different 108 | do 109 | does 110 | doing 111 | done 112 | down 113 | downwards 114 | during 115 | e 116 | each 117 | edu 118 | eg 119 | eight 120 | either 121 | else 122 | elsewhere 123 | enough 124 | entirely 125 | especially 126 | et 127 | etc 128 | even 129 | ever 130 | every 131 | everybody 132 | everyone 133 | everything 134 | everywhere 135 | ex 136 | exactly 137 | example 138 | except 139 | f 140 | far 141 | few 142 | fifth 143 | first 144 | five 145 | followed 146 | following 147 | follows 148 | for 149 | former 150 | formerly 151 | forth 152 | four 153 | from 154 | further 155 | furthermore 156 | g 157 | get 158 | gets 159 | getting 160 | given 161 | gives 162 | go 163 | goes 164 | going 165 | gone 166 | got 167 | gotten 168 | greetings 169 | h 170 | had 171 | happens 172 | hardly 173 | has 174 | have 175 | having 176 | he 177 | hello 178 | help 179 | hence 180 | her 181 | here 182 | hereafter 183 | hereby 184 | herein 185 | hereupon 186 | hers 187 | herself 188 | hi 189 | him 190 | himself 191 | his 192 | hither 193 | hopefully 194 | how 195 | howbeit 196 | however 197 | i 198 | ie 199 | if 200 | ignored 201 | immediate 202 | in 203 | inasmuch 204 | inc 205 | indeed 206 | indicate 207 | indicated 208 | indicates 209 | inner 210 | insofar 211 | instead 212 | into 213 | inward 214 | is 215 | it 216 | its 217 | itself 218 | j 219 | just 220 | k 221 | keep 222 | keeps 223 | kept 224 | know 225 | knows 226 | known 227 | l 228 | last 229 | lately 230 | later 231 | latter 232 | latterly 233 | least 234 | less 235 | lest 236 | let 237 | like 238 | liked 239 | likely 240 | little 241 | look 242 | looking 243 | looks 244 | ltd 245 | m 246 | mainly 247 | many 248 | may 249 | maybe 250 | me 251 | mean 252 | meanwhile 253 | merely 254 | might 255 | more 256 | moreover 257 | most 258 | mostly 259 | much 260 | must 261 | my 262 | myself 263 | n 264 | name 265 | namely 266 | nd 267 | near 268 | nearly 269 | necessary 270 | need 271 | needs 272 | neither 273 | never 274 | nevertheless 275 | new 276 | next 277 | nine 278 | no 279 | nobody 280 | non 281 | none 282 | noone 283 | nor 284 | normally 285 | not 286 | nothing 287 | novel 288 | now 289 | nowhere 290 | o 291 | obviously 292 | of 293 | off 294 | often 295 | oh 296 | ok 297 | okay 298 | old 299 | on 300 | once 301 | one 302 | ones 303 | only 304 | onto 305 | or 306 | other 307 | others 308 | otherwise 309 | ought 310 | our 311 | ours 312 | ourselves 313 | out 314 | outside 315 | over 316 | overall 317 | own 318 | p 319 | particular 320 | particularly 321 | per 322 | perhaps 323 | placed 324 | please 325 | plus 326 | possible 327 | presumably 328 | probably 329 | provides 330 | q 331 | que 332 | quite 333 | qv 334 | r 335 | rather 336 | rd 337 | re 338 | really 339 | reasonably 340 | regarding 341 | regardless 342 | regards 343 | relatively 344 | respectively 345 | right 346 | s 347 | said 348 | same 349 | saw 350 | say 351 | saying 352 | says 353 | second 354 | secondly 355 | see 356 | seeing 357 | seem 358 | seemed 359 | seeming 360 | seems 361 | seen 362 | self 363 | selves 364 | sensible 365 | sent 366 | serious 367 | seriously 368 | seven 369 | several 370 | shall 371 | she 372 | should 373 | since 374 | six 375 | so 376 | some 377 | somebody 378 | somehow 379 | someone 380 | something 381 | sometime 382 | sometimes 383 | somewhat 384 | somewhere 385 | soon 386 | sorry 387 | specified 388 | specify 389 | specifying 390 | still 391 | sub 392 | such 393 | sup 394 | sure 395 | t 396 | take 397 | taken 398 | tell 399 | tends 400 | th 401 | than 402 | thank 403 | thanks 404 | thanx 405 | that 406 | thats 407 | the 408 | their 409 | theirs 410 | them 411 | themselves 412 | then 413 | thence 414 | there 415 | thereafter 416 | thereby 417 | therefore 418 | therein 419 | theres 420 | thereupon 421 | these 422 | they 423 | think 424 | third 425 | this 426 | thorough 427 | thoroughly 428 | those 429 | though 430 | three 431 | through 432 | throughout 433 | thru 434 | thus 435 | to 436 | together 437 | too 438 | took 439 | toward 440 | towards 441 | tried 442 | tries 443 | truly 444 | try 445 | trying 446 | twice 447 | two 448 | u 449 | un 450 | under 451 | unfortunately 452 | unless 453 | unlikely 454 | until 455 | unto 456 | up 457 | upon 458 | us 459 | use 460 | used 461 | useful 462 | uses 463 | using 464 | usually 465 | uucp 466 | v 467 | value 468 | various 469 | very 470 | via 471 | viz 472 | vs 473 | w 474 | want 475 | wants 476 | was 477 | way 478 | we 479 | welcome 480 | well 481 | went 482 | were 483 | what 484 | whatever 485 | when 486 | whence 487 | whenever 488 | where 489 | whereafter 490 | whereas 491 | whereby 492 | wherein 493 | whereupon 494 | wherever 495 | whether 496 | which 497 | while 498 | whither 499 | who 500 | whoever 501 | whole 502 | whom 503 | whose 504 | why 505 | will 506 | willing 507 | wish 508 | with 509 | within 510 | without 511 | wonder 512 | would 513 | would 514 | x 515 | y 516 | yes 517 | yet 518 | you 519 | your 520 | yours 521 | yourself 522 | yourselves 523 | z 524 | zero 525 | -------------------------------------------------------------------------------- /embedded_topic_model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/embedded_topic_model/utils/__init__.py -------------------------------------------------------------------------------- /embedded_topic_model/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import scipy.io 6 | 7 | 8 | def _fetch(path, name): 9 | if name == 'train': 10 | token_file = os.path.join(path, 'bow_tr_tokens.mat') 11 | count_file = os.path.join(path, 'bow_tr_counts.mat') 12 | elif name == 'valid': 13 | token_file = os.path.join(path, 'bow_va_tokens.mat') 14 | count_file = os.path.join(path, 'bow_va_counts.mat') 15 | else: 16 | token_file = os.path.join(path, 'bow_ts_tokens.mat') 17 | count_file = os.path.join(path, 'bow_ts_counts.mat') 18 | tokens = scipy.io.loadmat(token_file)['tokens'].squeeze() 19 | counts = scipy.io.loadmat(count_file)['counts'].squeeze() 20 | if name == 'test': 21 | token_1_file = os.path.join(path, 'bow_ts_h1_tokens.mat') 22 | count_1_file = os.path.join(path, 'bow_ts_h1_counts.mat') 23 | token_2_file = os.path.join(path, 'bow_ts_h2_tokens.mat') 24 | count_2_file = os.path.join(path, 'bow_ts_h2_counts.mat') 25 | tokens_1 = scipy.io.loadmat(token_1_file)['tokens'].squeeze() 26 | counts_1 = scipy.io.loadmat(count_1_file)['counts'].squeeze() 27 | tokens_2 = scipy.io.loadmat(token_2_file)['tokens'].squeeze() 28 | counts_2 = scipy.io.loadmat(count_2_file)['counts'].squeeze() 29 | return {'tokens': tokens, 'counts': counts, 30 | 'tokens_1': tokens_1, 'counts_1': counts_1, 31 | 'tokens_2': tokens_2, 'counts_2': counts_2} 32 | return {'tokens': tokens, 'counts': counts} 33 | 34 | 35 | def get_data(path): 36 | with open(os.path.join(path, 'vocab.pkl'), 'rb') as f: 37 | vocab = pickle.load(f) 38 | 39 | train = _fetch(path, 'train') 40 | valid = _fetch(path, 'valid') 41 | test = _fetch(path, 'test') 42 | 43 | return vocab, train, valid, test 44 | 45 | 46 | def get_batch(tokens, counts, ind, vocab_size, device, emsize=300): 47 | """fetch input data by batch.""" 48 | batch_size = len(ind) 49 | data_batch = np.zeros((batch_size, vocab_size)) 50 | 51 | for i, doc_id in enumerate(ind): 52 | doc = tokens[doc_id] 53 | count = counts[doc_id] 54 | if len(doc) == 1: 55 | doc = [doc.squeeze()] 56 | count = [count.squeeze()] 57 | else: 58 | doc = doc.squeeze() 59 | count = count.squeeze() 60 | if doc_id != -1: 61 | for j, word in enumerate(doc): 62 | data_batch[i, word] = count[j] 63 | data_batch = torch.from_numpy(data_batch).float().to(device) 64 | return data_batch 65 | -------------------------------------------------------------------------------- /embedded_topic_model/utils/embedding.py: -------------------------------------------------------------------------------- 1 | from gensim.models import Word2Vec, KeyedVectors 2 | 3 | 4 | # Class for a memory-friendly iterator over the dataset 5 | class MemoryFriendlyFileIterator(object): 6 | def __init__(self, filename): 7 | self.filename = filename 8 | 9 | def __iter__(self): 10 | for line in open(self.filename): 11 | yield line.split() 12 | 13 | 14 | def create_word2vec_embedding_from_dataset( 15 | dataset, dim_rho=300, min_count=1, sg=1, 16 | workers=25, negative_samples=10, window_size=4, iters=50, 17 | embedding_file_path=None, save_c_format_w2vec=False, debug_mode=False) -> KeyedVectors: 18 | """ 19 | Creates a Word2Vec embedding from dataset file or a list of sentences. 20 | If a file path is given, the file must be composed 21 | by a sequence of sentences separated by \\n. 22 | 23 | If the dataset is big, prefer using its file path. 24 | 25 | Parameters: 26 | === 27 | dataset (str or list of str): txt file containing the dataset or a list of sentences 28 | dim_rho (int): dimensionality of the word embeddings 29 | min_count (int): minimum term frequency (to define the vocabulary) 30 | sg (int): whether to use skip-gram 31 | workers (int): number of CPU cores 32 | negative_samples (int): number of negative samples 33 | window_size (int): window size to determine context 34 | iters (int): number of iterations 35 | embedding_file_path (str): optional. File to save the word embeddings 36 | save_c_format_w2vec (bool): wheter to save embeddings as word2vec C format (BIN and TXT files) 37 | debug_mode (bool): wheter or not to log function's operations to the console. By default, no logs are made 38 | 39 | Returns: 40 | === 41 | Word2VecKeyedVectors: mapping between words and their vector representations. 42 | Example: 43 | { 'water': nd.array([0.024187922, 0.053684134, 0.034520667, ... ]) } 44 | """ 45 | assert isinstance(dataset, str) or isinstance(dataset, list), \ 46 | 'dataset must be file path or list of sentences' 47 | 48 | if isinstance(dataset, str): 49 | assert isinstance(embedding_file_path, str), \ 50 | 'if dataset is a file path, an output embeddings file path must be given' 51 | 52 | if save_c_format_w2vec: 53 | assert isinstance(embedding_file_path, str), \ 54 | 'if save_c_format_w2vec is True, an output embeddings file path must be given' 55 | 56 | if debug_mode: 57 | print('Creating memory-friendly iterator...') 58 | 59 | sentences = MemoryFriendlyFileIterator(dataset) if isinstance( 60 | dataset, str) else [document.split() for document in dataset] 61 | 62 | if debug_mode: 63 | print('Training Word2Vec model with dataset...') 64 | 65 | model = Word2Vec( 66 | sentences, 67 | min_count=min_count, 68 | sg=sg, 69 | vector_size=dim_rho, 70 | epochs=iters, 71 | workers=workers, 72 | negative=negative_samples, 73 | window=window_size) 74 | 75 | embeddings = model.wv 76 | 77 | if embedding_file_path is not None: 78 | if debug_mode: 79 | print('Saving word-vector mappings to file...') 80 | 81 | embeddings.save(embedding_file_path) 82 | 83 | if save_c_format_w2vec: 84 | if debug_mode: 85 | print('Saving BIN/TXT original C Word2vec files...') 86 | 87 | embeddings.save_word2vec_format( 88 | f'{embedding_file_path}.bin', binary=True) 89 | embeddings.save_word2vec_format( 90 | f'{embedding_file_path}.txt', binary=False) 91 | 92 | return embeddings 93 | -------------------------------------------------------------------------------- /embedded_topic_model/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_topic_diversity(beta, topk=25): 5 | num_topics = beta.shape[0] 6 | list_w = np.zeros((num_topics, topk)) 7 | for k in range(num_topics): 8 | idx = beta[k, :].argsort()[-topk:][::-1] 9 | list_w[k, :] = idx 10 | n_unique = len(np.unique(list_w)) 11 | TD = n_unique / (topk * num_topics) 12 | return TD 13 | 14 | 15 | def get_document_frequency(data, wi, wj=None): 16 | if wj is None: 17 | D_wi = 0 18 | for document in data: 19 | # FIXME: 'if' for original article's code, 'else' for updated 20 | doc = document.squeeze(0) if document.shape[0] == 1 else document 21 | 22 | if wi in doc: 23 | D_wi += 1 24 | return D_wi 25 | 26 | D_wj = 0 27 | D_wi_wj = 0 28 | for document in data: 29 | # FIXME: 'if' for original article's code, 'else' for updated version 30 | doc = document.squeeze(0) if document.shape[0] == 1 else document 31 | 32 | if wj in doc: 33 | D_wj += 1 34 | if wi in doc: 35 | D_wi_wj += 1 36 | return D_wj, D_wi_wj 37 | 38 | 39 | def get_topic_coherence(beta, data, vocab, top_n=10): 40 | D = len(data) # number of docs...data is list of documents 41 | TC = [] 42 | num_topics = len(beta) 43 | for k in range(num_topics): 44 | beta_top_n = list(beta[k].argsort()[-top_n:][::-1]) 45 | TC_k = 0 46 | counter = 0 47 | for i, word in enumerate(beta_top_n): 48 | # get D(w_i) 49 | D_wi = get_document_frequency(data, word) 50 | j = i + 1 51 | tmp = 0 52 | while j < len(beta_top_n) and j > i: 53 | # get D(w_j) and D(w_i, w_j) 54 | D_wj, D_wi_wj = get_document_frequency( 55 | data, word, beta_top_n[j]) 56 | # get f(w_i, w_j) 57 | if D_wi_wj == 0: 58 | f_wi_wj = -1 59 | else: 60 | f_wi_wj = -1 + (np.log(D_wi) + np.log(D_wj) - 61 | 2.0 * np.log(D)) / (np.log(D_wi_wj) - np.log(D)) 62 | # update tmp: 63 | tmp += f_wi_wj 64 | j += 1 65 | counter += 1 66 | # update TC_k 67 | TC_k += tmp 68 | TC.append(TC_k) 69 | TC = np.mean(TC) / counter 70 | return TC 71 | 72 | 73 | def nearest_neighbors(word, embeddings, vocab, n_most_similar=20): 74 | vectors = embeddings.data.cpu().numpy() 75 | index = vocab.index(word) 76 | query = vectors[index] 77 | ranks = vectors.dot(query).squeeze() 78 | denom = query.T.dot(query).squeeze() 79 | denom = denom * np.sum(vectors**2, 1) 80 | denom = np.sqrt(denom) 81 | ranks = ranks / denom 82 | mostSimilar = [] 83 | [mostSimilar.append(idx) for idx in ranks.argsort()[::-1]] 84 | nearest_neighbors = mostSimilar[:n_most_similar] 85 | nearest_neighbors = [vocab[comp] for comp in nearest_neighbors] 86 | return nearest_neighbors 87 | -------------------------------------------------------------------------------- /embedded_topic_model/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | import numpy as np 3 | from scipy import sparse 4 | from typing import Tuple, List 5 | 6 | 7 | def _remove_empty_documents(documents): 8 | return [doc for doc in documents if doc != []] 9 | 10 | 11 | def _create_list_words(documents): 12 | return [word for document in documents for word in document] 13 | 14 | 15 | def _create_document_indices(documents): 16 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(documents)] 17 | return [int(x) for y in aux for x in y] 18 | 19 | 20 | def _create_bow(document_indices, words, num_docs, vocab_size): 21 | return sparse.coo_matrix( 22 | ([1] * 23 | len(document_indices), 24 | (document_indices, 25 | words)), 26 | shape=( 27 | num_docs, 28 | vocab_size)).tocsr() 29 | 30 | 31 | def _split_bow(bow_in, num_docs): 32 | indices = [[w for w in bow_in[doc, :].indices] for doc in range(num_docs)] 33 | counts = [[c for c in bow_in[doc, :].data] for doc in range(num_docs)] 34 | return indices, counts 35 | 36 | 37 | def _create_dictionaries(vocabulary): 38 | word2id = dict([(w, j) for j, w in enumerate(vocabulary)]) 39 | id2word = dict([(j, w) for j, w in enumerate(vocabulary)]) 40 | 41 | return word2id, id2word 42 | 43 | 44 | def _to_numpy_array(documents): 45 | return np.array([[np.array(doc) for doc in documents]], 46 | dtype=object).squeeze() 47 | 48 | 49 | def create_bow_dataset( 50 | dataset: List[str], 51 | vocabulary: List[str], 52 | debug_mode=False) -> Tuple[list, dict, dict]: 53 | """ 54 | Creates a bag-of-words (BOW) dataset from a given corpus. The vocabulary **must** be the same of the training corpus. 55 | 56 | Parameters: 57 | === 58 | dataset (list of str): original corpus to be preprocessed. Is composed by a list of sentences 59 | vocabulary (list of str): words vocabulary. Doesn't includes words not in the training dataset 60 | debug_mode (bool): Wheter or not to log function's operations to the console. By default, no logs are made 61 | 62 | Returns: 63 | === 64 | train_dataset (dict): BOW training dataset, split in tokens and counts. Must be used on ETM's fit() method. 65 | """ 66 | vectorizer = CountVectorizer(vocabulary=vocabulary) 67 | vectorized_documents = vectorizer.transform(dataset) 68 | 69 | dataset = [ 70 | [word for word in document.split()] 71 | for document in dataset] 72 | 73 | signed_documents = vectorized_documents.sign() 74 | 75 | if debug_mode: 76 | print('Building vocabulary...') 77 | 78 | sum_counts = signed_documents.sum(axis=0) 79 | v_size = sum_counts.shape[1] 80 | sum_counts_np = np.zeros(v_size, dtype=int) 81 | for v in range(v_size): 82 | sum_counts_np[v] = sum_counts[0, v] 83 | 84 | # Sort elements in vocabulary 85 | if debug_mode: 86 | print('Tokenizing documents and splitting into train/test...') 87 | 88 | num_docs = signed_documents.shape[0] 89 | train_dataset_size = num_docs 90 | idx_permute = np.random.permutation(num_docs).astype(int) 91 | 92 | # Create dictionary and inverse dictionary 93 | word2id, id2word = _create_dictionaries(vocabulary) 94 | 95 | docs_train = [[word2id[w] for w in dataset[idx_permute[idx_d]] 96 | if w in word2id] for idx_d in range(train_dataset_size)] 97 | 98 | if debug_mode: 99 | print( 100 | 'Number of documents (train_dataset): {} [this should be equal to {}]'.format( 101 | len(docs_train), 102 | train_dataset_size)) 103 | 104 | if debug_mode: 105 | print('Removing empty documents...') 106 | 107 | docs_train = _remove_empty_documents(docs_train) 108 | 109 | # Obtains the training and test datasets as word lists 110 | words_train = [[id2word[w] for w in doc] for doc in docs_train] 111 | 112 | words_train = _create_list_words(docs_train) 113 | 114 | if debug_mode: 115 | print('len(words_train): ', len(words_train)) 116 | 117 | doc_indices_train = _create_document_indices(docs_train) 118 | 119 | if debug_mode: 120 | print('len(np.unique(doc_indices_train)): {} [this should be {}]'.format( 121 | len(np.unique(doc_indices_train)), len(docs_train))) 122 | 123 | # Number of documents in each set 124 | n_docs_train = len(docs_train) 125 | 126 | bow_train = _create_bow( 127 | doc_indices_train, 128 | words_train, 129 | n_docs_train, 130 | len(vocabulary)) 131 | 132 | bow_train_tokens, bow_train_counts = _split_bow(bow_train, n_docs_train) 133 | 134 | return { 135 | 'tokens': _to_numpy_array(bow_train_tokens), 136 | 'counts': _to_numpy_array(bow_train_counts), 137 | } 138 | 139 | def create_etm_datasets( 140 | dataset: List[str], 141 | train_size=1.0, 142 | min_df=1, 143 | max_df=1.0, 144 | debug_mode=False) -> Tuple[list, dict, dict]: 145 | """ 146 | Creates vocabulary and train / test datasets from a given corpus. The vocabulary and datasets can 147 | be used to train an ETM model. 148 | 149 | By default, creates a train dataset with all the preprocessed documents in the corpus and an empty 150 | test dataset. 151 | 152 | This function preprocesses the given dataset, removing most and least frequent terms on the corpus - given minimum and maximum document-frequencies - and produces a BOW vocabulary. 153 | 154 | Parameters: 155 | === 156 | dataset (list of str): original corpus to be preprocessed. Is composed by a list of sentences 157 | train_size (float): fraction of the original corpus to be used for the train dataset. By default, uses entire corpus 158 | min_df (float): Minimum document-frequency for terms. Removes terms with a frequency below this threshold 159 | max_df (float): Maximum document-frequency for terms. Removes terms with a frequency above this threshold 160 | debug_mode (bool): Wheter or not to log function's operations to the console. By default, no logs are made 161 | 162 | Returns: 163 | === 164 | vocabulary (list of str): words vocabulary. Doesn't includes words not in the training dataset 165 | train_dataset (dict): BOW training dataset, split in tokens and counts. Must be used on ETM's fit() method. 166 | test_dataset (dict): BOW testing dataset, split in tokens and counts. Can be use on ETM's perplexity() method. 167 | """ 168 | vectorizer = CountVectorizer(min_df=min_df, max_df=max_df) 169 | vectorized_documents = vectorizer.fit_transform(dataset) 170 | 171 | documents_without_stop_words = [ 172 | [word for word in document.split() 173 | if word not in vectorizer.stop_words_] 174 | for document in dataset] 175 | 176 | signed_documents = vectorized_documents.sign() 177 | 178 | if debug_mode: 179 | print('Building vocabulary...') 180 | 181 | sum_counts = signed_documents.sum(axis=0) 182 | v_size = sum_counts.shape[1] 183 | sum_counts_np = np.zeros(v_size, dtype=int) 184 | for v in range(v_size): 185 | sum_counts_np[v] = sum_counts[0, v] 186 | word2id = dict([(w, vectorizer.vocabulary_.get(w)) 187 | for w in vectorizer.vocabulary_]) 188 | id2word = dict([(vectorizer.vocabulary_.get(w), w) 189 | for w in vectorizer.vocabulary_]) 190 | 191 | if debug_mode: 192 | print('Initial vocabulary size: {}'.format(v_size)) 193 | 194 | # Sort elements in vocabulary 195 | idx_sort = np.argsort(sum_counts_np) 196 | # Creates vocabulary 197 | vocabulary = [id2word[idx_sort[cc]] for cc in range(v_size)] 198 | 199 | if debug_mode: 200 | print('Tokenizing documents and splitting into train/test...') 201 | 202 | num_docs = signed_documents.shape[0] 203 | train_dataset_size = int(np.floor(train_size * num_docs)) 204 | test_dataset_size = int(num_docs - train_dataset_size) 205 | idx_permute = np.random.permutation(num_docs).astype(int) 206 | 207 | # Remove words not in train_data 208 | vocabulary = list(set([w for idx_d in range(train_dataset_size) 209 | for w in documents_without_stop_words[idx_permute[idx_d]] if w in word2id])) 210 | 211 | # Create dictionary and inverse dictionary 212 | word2id, id2word = _create_dictionaries(vocabulary) 213 | 214 | if debug_mode: 215 | print( 216 | 'vocabulary after removing words not in train: {}'.format( 217 | len(vocabulary))) 218 | 219 | docs_train = [[word2id[w] for w in documents_without_stop_words[idx_permute[idx_d]] 220 | if w in word2id] for idx_d in range(train_dataset_size)] 221 | docs_test = [ 222 | [word2id[w] for w in 223 | documents_without_stop_words[idx_permute[idx_d + train_dataset_size]] 224 | if w in word2id] for idx_d in range(test_dataset_size)] 225 | 226 | if debug_mode: 227 | print( 228 | 'Number of documents (train_dataset): {} [this should be equal to {}]'.format( 229 | len(docs_train), 230 | train_dataset_size)) 231 | print( 232 | 'Number of documents (test_dataset): {} [this should be equal to {}]'.format( 233 | len(docs_test), 234 | test_dataset_size)) 235 | 236 | if debug_mode: 237 | print('Removing empty documents...') 238 | 239 | docs_train = _remove_empty_documents(docs_train) 240 | docs_test = _remove_empty_documents(docs_test) 241 | 242 | # Remove test documents with length=1 243 | docs_test = [doc for doc in docs_test if len(doc) > 1] 244 | 245 | # Obtains the training and test datasets as word lists 246 | words_train = [[id2word[w] for w in doc] for doc in docs_train] 247 | words_test = [[id2word[w] for w in doc] for doc in docs_test] 248 | 249 | docs_test_h1 = [[w for i, w in enumerate( 250 | doc) if i <= len(doc) / 2.0 - 1] for doc in docs_test] 251 | docs_test_h2 = [[w for i, w in enumerate( 252 | doc) if i > len(doc) / 2.0 - 1] for doc in docs_test] 253 | 254 | words_train = _create_list_words(docs_train) 255 | words_test = _create_list_words(docs_test) 256 | words_ts_h1 = _create_list_words(docs_test_h1) 257 | words_ts_h2 = _create_list_words(docs_test_h2) 258 | 259 | if debug_mode: 260 | print('len(words_train): ', len(words_train)) 261 | print('len(words_test): ', len(words_test)) 262 | print('len(words_ts_h1): ', len(words_ts_h1)) 263 | print('len(words_ts_h2): ', len(words_ts_h2)) 264 | 265 | doc_indices_train = _create_document_indices(docs_train) 266 | doc_indices_test = _create_document_indices(docs_test) 267 | doc_indices_test_h1 = _create_document_indices(docs_test_h1) 268 | doc_indices_test_h2 = _create_document_indices(docs_test_h2) 269 | 270 | if debug_mode: 271 | print('len(np.unique(doc_indices_train)): {} [this should be {}]'.format( 272 | len(np.unique(doc_indices_train)), len(docs_train))) 273 | print('len(np.unique(doc_indices_test)): {} [this should be {}]'.format( 274 | len(np.unique(doc_indices_test)), len(docs_test))) 275 | print('len(np.unique(doc_indices_test_h1)): {} [this should be {}]'.format( 276 | len(np.unique(doc_indices_test_h1)), len(docs_test_h1))) 277 | print('len(np.unique(doc_indices_test_h2)): {} [this should be {}]'.format( 278 | len(np.unique(doc_indices_test_h2)), len(docs_test_h2))) 279 | 280 | # Number of documents in each set 281 | n_docs_train = len(docs_train) 282 | n_docs_test = len(docs_test) 283 | n_docs_test_h1 = len(docs_test_h1) 284 | n_docs_test_h2 = len(docs_test_h2) 285 | 286 | bow_train = _create_bow( 287 | doc_indices_train, 288 | words_train, 289 | n_docs_train, 290 | len(vocabulary)) 291 | bow_test = _create_bow( 292 | doc_indices_test, 293 | words_test, 294 | n_docs_test, 295 | len(vocabulary)) 296 | bow_test_h1 = _create_bow( 297 | doc_indices_test_h1, 298 | words_ts_h1, 299 | n_docs_test_h1, 300 | len(vocabulary)) 301 | bow_test_h2 = _create_bow( 302 | doc_indices_test_h2, 303 | words_ts_h2, 304 | n_docs_test_h2, 305 | len(vocabulary)) 306 | 307 | bow_train_tokens, bow_train_counts = _split_bow(bow_train, n_docs_train) 308 | bow_test_tokens, bow_test_counts = _split_bow(bow_test, n_docs_test) 309 | bow_test_h1_tokens, bow_test_h1_counts = _split_bow( 310 | bow_test_h1, n_docs_test_h1) 311 | bow_test_h2_tokens, bow_test_h2_counts = _split_bow( 312 | bow_test_h2, n_docs_test_h2) 313 | 314 | train_dataset = { 315 | 'tokens': _to_numpy_array(bow_train_tokens), 316 | 'counts': _to_numpy_array(bow_train_counts), 317 | } 318 | 319 | test_dataset = { 320 | 'test': { 321 | 'tokens': _to_numpy_array(bow_test_tokens), 322 | 'counts': _to_numpy_array(bow_test_counts), 323 | }, 324 | 'test1': { 325 | 'tokens': _to_numpy_array(bow_test_h1_tokens), 326 | 'counts': _to_numpy_array(bow_test_h1_counts), 327 | }, 328 | 'test2': { 329 | 'tokens': _to_numpy_array(bow_test_h2_tokens), 330 | 'counts': _to_numpy_array(bow_test_h2_counts), 331 | } 332 | } 333 | 334 | return vocabulary, train_dataset, test_dataset 335 | -------------------------------------------------------------------------------- /lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 4 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 5 | -------------------------------------------------------------------------------- /publish.txt: -------------------------------------------------------------------------------- 1 | # Bump version 2 | bumpversion --current-version setup.py embedded_topic_model/__init__.py 3 | 4 | # Pack 5 | python setup.py sdist bdist_wheel 6 | 7 | # Check packing 8 | tar tzf dist/embedded_topic_model-1.0.0.tar.gz 9 | 10 | # Check dist 11 | twine check dist/* 12 | 13 | # Upload to TestPyPI 14 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 15 | 16 | # Install and check package 17 | pip install -U embedded_topic_model 18 | 19 | # Upload to PyPI 20 | twine upload dist/* 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim>=4.3 2 | nltk>=3.8 3 | numpy>=1.25 4 | scikit-learn>=1.3 5 | scipy>=1.11 6 | torch>=2.0 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md') as readme_file: 4 | readme = readme_file.read() 5 | 6 | with open('CHANGELOG.md') as changelog_file: 7 | changelog = changelog_file.read() 8 | 9 | with open('requirements.txt') as f: 10 | requirements = f.read().splitlines() 11 | 12 | with open('dev_requirements.txt') as f: 13 | dev_requirements = f.read().splitlines() 14 | 15 | setup( 16 | author='Luiz F. Matos', 17 | author_email='lfmatosmelo@id.uff.br', 18 | python_requires='>=3.9', 19 | classifiers=[ 20 | 'Intended Audience :: Science/Research', 21 | 'Intended Audience :: Developers', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Natural Language :: English', 24 | 'Programming Language :: Python', 25 | 'Programming Language :: Python :: 3.9', 26 | 'Programming Language :: Python :: 3.10', 27 | 'Programming Language :: Python :: 3.11', 28 | 'Topic :: Software Development', 29 | 'Topic :: Scientific/Engineering', 30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 31 | 'Operating System :: MacOS', 32 | 'Operating System :: Unix' 33 | ], 34 | description='A package to run embedded topic modelling', 35 | install_requires=requirements, 36 | license='MIT license', 37 | long_description=readme + changelog, 38 | long_description_content_type='text/markdown', 39 | include_package_data=True, 40 | keywords='embedded_topic_model', 41 | name='embedded_topic_model', 42 | packages=find_packages(include=['embedded_topic_model', 'embedded_topic_model.models', 'embedded_topic_model.utils']), 43 | setup_requires=dev_requirements, 44 | test_suite='tests', 45 | tests_require=dev_requirements, 46 | url='https://github.com/lffloyd/embedded-topic-model', 47 | version='1.2.1', 48 | zip_safe=False, 49 | ) 50 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_etm.py: -------------------------------------------------------------------------------- 1 | from embedded_topic_model.models import etm 2 | import joblib 3 | import torch 4 | 5 | 6 | class TestETM: 7 | def test_etm_training_with_preprocessed_dummy_dataset(self): 8 | vocabulary, embeddings, train_dataset, _ = joblib.load( 9 | 'tests/resources/train_resources.test') 10 | 11 | etm_instance = etm.ETM( 12 | vocabulary, 13 | embeddings, 14 | num_topics=3, 15 | epochs=50, 16 | train_embeddings=False, 17 | ) 18 | 19 | expected_no_topics = 3 20 | expected_no_documents = 10 21 | expected_t_w_dist_sums = torch.ones(expected_no_topics).to(torch.device('cpu')) 22 | expected_d_t_dist_sums = torch.ones(expected_no_documents).to(torch.device('cpu')) 23 | 24 | etm_instance.fit(train_dataset) 25 | 26 | topics = etm_instance.get_topics() 27 | 28 | assert len(topics) == expected_no_topics, \ 29 | "no. of topics error: exp = {}, result = {}".format(expected_no_topics, len(topics)) 30 | 31 | coherence = etm_instance.get_topic_coherence() 32 | 33 | assert coherence != 0.0, \ 34 | 'zero coherence returned' 35 | 36 | diversity = etm_instance.get_topic_diversity() 37 | 38 | assert diversity != 0.0, \ 39 | 'zero diversity returned' 40 | 41 | t_w_mtx = etm_instance.get_topic_word_matrix() 42 | 43 | assert len(t_w_mtx) == expected_no_topics, \ 44 | "no. of topics in topic-word matrix error: exp = {}, result = {}".format(expected_no_topics, len(t_w_mtx)) 45 | 46 | t_w_dist = etm_instance.get_topic_word_dist() 47 | assert len(t_w_dist) == expected_no_topics, \ 48 | "topic-word distribution error: exp = {}, result = {}".format(expected_no_topics, len(t_w_dist)) 49 | 50 | similar_words = etm_instance.get_most_similar_words(["boat", "day", "sun"], 5) 51 | assert len(similar_words) == 3, \ 52 | "get_most_similar_words error: expected {} keys but {} were returned for method call".format(3, len(similar_words)) 53 | for key in similar_words.keys(): 54 | assert 0 <= len(similar_words[key]) <= 5, \ 55 | "get_most_similar_words error: expected <= {} elements but got {} for {} key".format(5, len(similar_words[key]), key) 56 | 57 | 58 | t_w_dist_below_zero_elems = t_w_dist[t_w_dist < 0] 59 | assert len(t_w_dist_below_zero_elems) == 0, \ 60 | 'there are elements smaller than 0 in the topic-word distribution' 61 | 62 | t_w_dist_sums = torch.sum(t_w_dist, 1).to(torch.device('cpu')) 63 | assert torch.allclose( 64 | t_w_dist_sums, expected_t_w_dist_sums), "t_w_dist_sums error: exp = {}, result = {}".format( 65 | expected_t_w_dist_sums, t_w_dist_sums) 66 | 67 | d_t_dist = etm_instance.get_document_topic_dist() 68 | assert len(d_t_dist) == expected_no_documents, \ 69 | "document-topics distribution error: exp = {}, result = {}".format(expected_no_documents, len(d_t_dist)) 70 | 71 | d_t_dist_below_zero_elems = d_t_dist[d_t_dist < 0] 72 | assert len(d_t_dist_below_zero_elems) == 0, \ 73 | 'there are elements smaller than 0 in the document-topic distribution' 74 | 75 | d_t_dist_sums = torch.sum(d_t_dist, 1).to(torch.device('cpu')) 76 | assert torch.allclose( 77 | d_t_dist_sums, expected_d_t_dist_sums), "d_t_dist_sums error: exp = {}, result = {}".format( 78 | expected_d_t_dist_sums, d_t_dist_sums) 79 | 80 | def test_etm_training_with_preprocessed_dummy_dataset_and_embeddings_file( 81 | self): 82 | vocabulary, _, train_dataset, _ = joblib.load( 83 | 'tests/resources/train_resources.test') 84 | 85 | etm_instance = etm.ETM( 86 | vocabulary, 87 | embeddings='tests/resources/train_w2v_embeddings.wordvectors', 88 | num_topics=3, 89 | epochs=50, 90 | train_embeddings=False, 91 | ) 92 | 93 | expected_no_topics = 3 94 | expected_no_documents = 10 95 | expected_t_w_dist_sums = torch.ones(expected_no_topics).to(torch.device('cpu')) 96 | expected_d_t_dist_sums = torch.ones(expected_no_documents).to(torch.device('cpu')) 97 | 98 | etm_instance.fit(train_dataset) 99 | 100 | topics = etm_instance.get_topics() 101 | 102 | assert len(topics) == expected_no_topics, \ 103 | "no. of topics error: exp = {}, result = {}".format(expected_no_topics, len(topics)) 104 | 105 | coherence = etm_instance.get_topic_coherence() 106 | 107 | assert coherence != 0.0, \ 108 | 'zero coherence returned' 109 | 110 | diversity = etm_instance.get_topic_diversity() 111 | 112 | assert diversity != 0.0, \ 113 | 'zero diversity returned' 114 | 115 | t_w_mtx = etm_instance.get_topic_word_matrix() 116 | 117 | assert len(t_w_mtx) == expected_no_topics, \ 118 | "no. of topics in topic-word matrix error: exp = {}, result = {}".format(expected_no_topics, len(t_w_mtx)) 119 | 120 | t_w_dist = etm_instance.get_topic_word_dist() 121 | assert len(t_w_dist) == expected_no_topics, \ 122 | "topic-word distribution error: exp = {}, result = {}".format(expected_no_topics, len(t_w_dist)) 123 | 124 | t_w_dist_below_zero_elems = t_w_dist[t_w_dist < 0] 125 | assert len(t_w_dist_below_zero_elems) == 0, \ 126 | 'there are elements smaller than 0 in the topic-word distribution' 127 | 128 | t_w_dist_sums = torch.sum(t_w_dist, 1).to(torch.device('cpu')) 129 | assert torch.allclose( 130 | t_w_dist_sums, expected_t_w_dist_sums), "t_w_dist_sums error: exp = {}, result = {}".format( 131 | expected_t_w_dist_sums, t_w_dist_sums) 132 | 133 | d_t_dist = etm_instance.get_document_topic_dist() 134 | assert len(d_t_dist) == expected_no_documents, \ 135 | "document-topics distribution error: exp = {}, result = {}".format(expected_no_documents, len(d_t_dist)) 136 | 137 | d_t_dist_below_zero_elems = d_t_dist[d_t_dist < 0] 138 | assert len(d_t_dist_below_zero_elems) == 0, \ 139 | 'there are elements smaller than 0 in the document-topic distribution' 140 | 141 | d_t_dist_sums = torch.sum(d_t_dist, 1).to(torch.device('cpu')) 142 | assert torch.allclose( 143 | d_t_dist_sums, expected_d_t_dist_sums), "d_t_dist_sums error: exp = {}, result = {}".format( 144 | expected_d_t_dist_sums, d_t_dist_sums) 145 | 146 | def test_etm_training_with_preprocessed_dummy_dataset_and_c_wordvec_txt_embeddings_file( 147 | self): 148 | vocabulary, _, train_dataset, _ = joblib.load( 149 | 'tests/resources/train_resources.test') 150 | 151 | etm_instance = etm.ETM( 152 | vocabulary, 153 | embeddings='tests/resources/train_w2v_embeddings.wordvectors.txt', 154 | num_topics=3, 155 | epochs=50, 156 | train_embeddings=False, 157 | use_c_format_w2vec=True, 158 | ) 159 | 160 | expected_no_topics = 3 161 | expected_no_documents = 10 162 | expected_t_w_dist_sums = torch.ones(expected_no_topics).to(torch.device('cpu')) 163 | expected_d_t_dist_sums = torch.ones(expected_no_documents).to(torch.device('cpu')) 164 | 165 | etm_instance.fit(train_dataset) 166 | 167 | topics = etm_instance.get_topics() 168 | 169 | assert len(topics) == expected_no_topics, \ 170 | "no. of topics error: exp = {}, result = {}".format(expected_no_topics, len(topics)) 171 | 172 | coherence = etm_instance.get_topic_coherence() 173 | 174 | assert coherence != 0.0, \ 175 | 'zero coherence returned' 176 | 177 | diversity = etm_instance.get_topic_diversity() 178 | 179 | assert diversity != 0.0, \ 180 | 'zero diversity returned' 181 | 182 | t_w_mtx = etm_instance.get_topic_word_matrix() 183 | 184 | assert len(t_w_mtx) == expected_no_topics, \ 185 | "no. of topics in topic-word matrix error: exp = {}, result = {}".format(expected_no_topics, len(t_w_mtx)) 186 | 187 | t_w_dist = etm_instance.get_topic_word_dist() 188 | assert len(t_w_dist) == expected_no_topics, \ 189 | "topic-word distribution error: exp = {}, result = {}".format(expected_no_topics, len(t_w_dist)) 190 | 191 | t_w_dist_below_zero_elems = t_w_dist[t_w_dist < 0] 192 | assert len(t_w_dist_below_zero_elems) == 0, \ 193 | 'there are elements smaller than 0 in the topic-word distribution' 194 | 195 | t_w_dist_sums = torch.sum(t_w_dist, 1).to(torch.device('cpu')) 196 | assert torch.allclose( 197 | t_w_dist_sums, expected_t_w_dist_sums), "t_w_dist_sums error: exp = {}, result = {}".format( 198 | expected_t_w_dist_sums, t_w_dist_sums) 199 | 200 | d_t_dist = etm_instance.get_document_topic_dist() 201 | assert len(d_t_dist) == expected_no_documents, \ 202 | "document-topics distribution error: exp = {}, result = {}".format(expected_no_documents, len(d_t_dist)) 203 | 204 | d_t_dist_below_zero_elems = d_t_dist[d_t_dist < 0] 205 | assert len(d_t_dist_below_zero_elems) == 0, \ 206 | 'there are elements smaller than 0 in the document-topic distribution' 207 | 208 | d_t_dist_sums = torch.sum(d_t_dist, 1).to(torch.device('cpu')) 209 | assert torch.allclose( 210 | d_t_dist_sums, expected_d_t_dist_sums), "d_t_dist_sums error: exp = {}, result = {}".format( 211 | expected_d_t_dist_sums, d_t_dist_sums) 212 | 213 | def test_etm_training_with_preprocessed_dummy_dataset_and_c_wordvec_bin_embeddings_file( 214 | self): 215 | vocabulary, _, train_dataset, _ = joblib.load( 216 | 'tests/resources/train_resources.test') 217 | 218 | etm_instance = etm.ETM( 219 | vocabulary, 220 | embeddings='tests/resources/train_w2v_embeddings.wordvectors.bin', 221 | num_topics=3, 222 | epochs=50, 223 | train_embeddings=False, 224 | use_c_format_w2vec=True, 225 | ) 226 | 227 | expected_no_topics = 3 228 | expected_no_documents = 10 229 | expected_t_w_dist_sums = torch.ones(expected_no_topics).to(torch.device('cpu')) 230 | expected_d_t_dist_sums = torch.ones(expected_no_documents).to(torch.device('cpu')) 231 | 232 | etm_instance.fit(train_dataset) 233 | 234 | topics = etm_instance.get_topics() 235 | 236 | assert len(topics) == expected_no_topics, \ 237 | "no. of topics error: exp = {}, result = {}".format(expected_no_topics, len(topics)) 238 | 239 | coherence = etm_instance.get_topic_coherence() 240 | 241 | assert coherence != 0.0, \ 242 | 'zero coherence returned' 243 | 244 | diversity = etm_instance.get_topic_diversity() 245 | 246 | assert diversity != 0.0, \ 247 | 'zero diversity returned' 248 | 249 | t_w_mtx = etm_instance.get_topic_word_matrix() 250 | 251 | assert len(t_w_mtx) == expected_no_topics, \ 252 | "no. of topics in topic-word matrix error: exp = {}, result = {}".format(expected_no_topics, len(t_w_mtx)) 253 | 254 | t_w_dist = etm_instance.get_topic_word_dist() 255 | assert len(t_w_dist) == expected_no_topics, \ 256 | "topic-word distribution error: exp = {}, result = {}".format(expected_no_topics, len(t_w_dist)) 257 | 258 | t_w_dist_below_zero_elems = t_w_dist[t_w_dist < 0] 259 | assert len(t_w_dist_below_zero_elems) == 0, \ 260 | 'there are elements smaller than 0 in the topic-word distribution' 261 | 262 | t_w_dist_sums = torch.sum(t_w_dist, 1).to(torch.device('cpu')) 263 | assert torch.allclose( 264 | t_w_dist_sums, expected_t_w_dist_sums), "t_w_dist_sums error: exp = {}, result = {}".format( 265 | expected_t_w_dist_sums, t_w_dist_sums) 266 | 267 | d_t_dist = etm_instance.get_document_topic_dist() 268 | assert len(d_t_dist) == expected_no_documents, \ 269 | "document-topics distribution error: exp = {}, result = {}".format(expected_no_documents, len(d_t_dist)) 270 | 271 | d_t_dist_below_zero_elems = d_t_dist[d_t_dist < 0] 272 | assert len(d_t_dist_below_zero_elems) == 0, \ 273 | 'there are elements smaller than 0 in the document-topic distribution' 274 | 275 | d_t_dist_sums = torch.sum(d_t_dist, 1).to(torch.device('cpu')) 276 | assert torch.allclose( 277 | d_t_dist_sums, expected_d_t_dist_sums), "d_t_dist_sums error: exp = {}, result = {}".format( 278 | expected_d_t_dist_sums, d_t_dist_sums) 279 | -------------------------------------------------------------------------------- /tests/resources/train_resources.test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/resources/train_resources.test -------------------------------------------------------------------------------- /tests/resources/train_w2v_embeddings.wordvectors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/resources/train_w2v_embeddings.wordvectors -------------------------------------------------------------------------------- /tests/resources/train_w2v_embeddings.wordvectors.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/resources/train_w2v_embeddings.wordvectors.bin -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_embedding.py: -------------------------------------------------------------------------------- 1 | from embedded_topic_model.utils import embedding 2 | from gensim.models import KeyedVectors 3 | 4 | 5 | def test_create_word2vec_embedding_from_dataset(): 6 | documents = [ 7 | "Peanut butter and jelly caused the elderly lady to think about her past.", 8 | "Toddlers feeding raccoons surprised even the seasoned park ranger.", 9 | "You realize you're not alone as you sit in your bedroom massaging your calves after a long day of playing tug-of-war with Grandpa Joe in the hospital.", 10 | "She wondered what his eyes were saying beneath his mirrored sunglasses.", 11 | "He was disappointed when he found the beach to be so sandy and the sun so sunny.", 12 | "Flesh-colored yoga pants were far worse than even he feared.", 13 | "The wake behind the boat told of the past while the open sea for told life in the unknown future.", 14 | "Improve your goldfish's physical fitness by getting him a bicycle.", 15 | "Harrold felt confident that nobody would ever suspect his spy pigeon.", 16 | "Nudist colonies shun fig-leaf couture.", 17 | ] 18 | 19 | dimensionality = 240 20 | embeddings = embedding.create_word2vec_embedding_from_dataset( 21 | documents, dim_rho=dimensionality) 22 | 23 | assert isinstance( 24 | embeddings, KeyedVectors), "embeddings isn't KeyedVectors instance" 25 | 26 | vector = embeddings['Peanut'] 27 | 28 | assert len(vector) == dimensionality, "lenght of 'Peanut' vector doesn't match: exp = {}, result = {}".format( 29 | dimensionality, len(vector)) 30 | -------------------------------------------------------------------------------- /tests/unit/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | from embedded_topic_model.utils import preprocessing 2 | 3 | 4 | def test_create_etm_datasets(): 5 | documents = [ 6 | "Peanut butter and jelly caused the elderly lady to think about her past.", 7 | "Toddlers feeding raccoons surprised even the seasoned park ranger.", 8 | "You realize you're not alone as you sit in your bedroom massaging your calves after a long day of playing tug-of-war with Grandpa Joe in the hospital.", 9 | "She wondered what his eyes were saying beneath his mirrored sunglasses.", 10 | "He was disappointed when he found the beach to be so sandy and the sun so sunny.", 11 | "Flesh-colored yoga pants were far worse than even he feared.", 12 | "The wake behind the boat told of the past while the open sea for told life in the unknown future.", 13 | "Improve your goldfish's physical fitness by getting him a bicycle.", 14 | "Harrold felt confident that nobody would ever suspect his spy pigeon.", 15 | "Nudist colonies shun fig-leaf couture.", 16 | ] 17 | 18 | no_documents_in_train = 7 19 | no_documents_in_test = 3 20 | 21 | vocabulary, train_dataset, test_dataset = preprocessing.create_etm_datasets( 22 | documents, train_size=0.7) 23 | 24 | assert isinstance(vocabulary, list), "vocabulary isn't list" 25 | 26 | assert len(train_dataset['tokens']) == no_documents_in_train and len( 27 | train_dataset['counts']) == no_documents_in_train, "lengths of tokens and counts for training dataset doesn't match" 28 | 29 | assert len(test_dataset['test']['tokens']) == no_documents_in_test and len( 30 | test_dataset['test']['counts']) == no_documents_in_test, "lengths of tokens and counts for testing dataset doesn't match" 31 | -------------------------------------------------------------------------------- /train_resources.test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfmatosm/embedded-topic-model/71996073d584ec38070dbd62095a021f80bcdb19/train_resources.test --------------------------------------------------------------------------------