├── .codecov.yml ├── .coveragerc ├── .flake8 ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── usage-question.md ├── .gitignore ├── .travis.yml ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── artworks └── matchzoo-logo.png ├── docs ├── DOCCHECK.md ├── Makefile ├── Readme.md ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── index.rst │ ├── model_reference.rst │ └── modules.rst ├── matchzoo ├── __init__.py ├── auto │ ├── __init__.py │ ├── preparer │ │ ├── __init__.py │ │ ├── prepare.py │ │ └── preparer.py │ └── tuner │ │ ├── __init__.py │ │ ├── tune.py │ │ └── tuner.py ├── data_pack │ ├── __init__.py │ ├── data_pack.py │ └── pack.py ├── dataloader │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── histogram.py │ │ ├── lambda_callback.py │ │ ├── ngram.py │ │ └── padding.py │ ├── dataloader.py │ ├── dataloader_builder.py │ ├── dataset.py │ └── dataset_builder.py ├── datasets │ ├── __init__.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── embed_10_glove.txt │ │ ├── embed_10_word2vec.txt │ │ ├── embed_err.txt.gb2312 │ │ ├── embed_rank.txt │ │ ├── embed_word.txt │ │ ├── load_fasttext_embedding.py │ │ └── load_glove_embedding.py │ ├── quora_qp │ │ ├── __init__.py │ │ └── load_data.py │ ├── snli │ │ ├── __init__.py │ │ └── load_data.py │ ├── toy │ │ ├── __init__.py │ │ ├── dev.csv │ │ ├── embedding.2d.txt │ │ ├── test.csv │ │ └── train.csv │ └── wiki_qa │ │ ├── __init__.py │ │ └── load_data.py ├── embedding │ ├── __init__.py │ └── embedding.py ├── engine │ ├── __init__.py │ ├── base_callback.py │ ├── base_metric.py │ ├── base_model.py │ ├── base_preprocessor.py │ ├── base_task.py │ ├── hyper_spaces.py │ ├── param.py │ └── param_table.py ├── losses │ ├── __init__.py │ ├── rank_cross_entropy_loss.py │ └── rank_hinge_loss.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── average_precision.py │ ├── cross_entropy.py │ ├── discounted_cumulative_gain.py │ ├── mean_average_precision.py │ ├── mean_reciprocal_rank.py │ ├── normalized_discounted_cumulative_gain.py │ └── precision.py ├── models │ ├── README.rst │ ├── __init__.py │ ├── anmm.py │ ├── arci.py │ ├── arcii.py │ ├── bert.py │ ├── bimpm.py │ ├── cdssm.py │ ├── conv_knrm.py │ ├── dense_baseline.py │ ├── diin.py │ ├── drmm.py │ ├── drmmtks.py │ ├── dssm.py │ ├── duet.py │ ├── esim.py │ ├── hbmp.py │ ├── knrm.py │ ├── match_pyramid.py │ ├── match_srnn.py │ ├── matchlstm.py │ ├── mvlstm.py │ └── parameter_readme_generator.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── bert_module.py │ ├── character_embedding.py │ ├── dense_net.py │ ├── dropout.py │ ├── gaussian_kernel.py │ ├── matching.py │ ├── matching_tensor.py │ ├── semantic_composite.py │ ├── spatial_gru.py │ └── stacked_brnn.py ├── preprocessors │ ├── __init__.py │ ├── basic_preprocessor.py │ ├── bert_preprocessor.py │ ├── build_unit_from_data_pack.py │ ├── build_vocab_unit.py │ ├── chain_transform.py │ ├── naive_preprocessor.py │ └── units │ │ ├── __init__.py │ │ ├── character_index.py │ │ ├── digit_removal.py │ │ ├── frequency_filter.py │ │ ├── lemmatization.py │ │ ├── lowercase.py │ │ ├── matching_histogram.py │ │ ├── ngram_letter.py │ │ ├── punc_removal.py │ │ ├── stateful_unit.py │ │ ├── stemming.py │ │ ├── stop_removal.py │ │ ├── tokenize.py │ │ ├── truncated_length.py │ │ ├── unit.py │ │ ├── vocabulary.py │ │ ├── word_exact_match.py │ │ └── word_hashing.py ├── tasks │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── trainers │ ├── __init__.py 
│ └── trainer.py ├── utils │ ├── __init__.py │ ├── average_meter.py │ ├── early_stopping.py │ ├── get_file.py │ ├── list_recursive_subclasses.py │ ├── one_hot.py │ ├── parse.py │ ├── tensor_type.py │ └── timer.py └── version.py ├── pytest.ini ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── data_pack │ └── test_datapack.py ├── dataloader │ ├── test_callbacks.py │ └── test_dataset.py ├── engine │ ├── test_base_preprocessor.py │ ├── test_base_task.py │ ├── test_hyper_spaces.py │ └── test_param_table.py ├── models │ ├── test_base_model.py │ └── test_models.py ├── modules │ └── test_modules.py ├── tasks │ └── test_tasks.py ├── test_datasets.py ├── test_embedding.py ├── test_losses.py ├── test_metrics.py ├── test_utils.py └── trainer │ └── test_trainer.py └── tutorials ├── classification ├── bert.ipynb ├── esim.ipynb └── init.ipynb └── ranking ├── anmm.ipynb ├── arci.ipynb ├── arcii.ipynb ├── bert.ipynb ├── cdssm.ipynb ├── conv_knrm.ipynb ├── drmm.ipynb ├── drmmtks.ipynb ├── dssm.ipynb ├── duet.ipynb ├── esim.ipynb ├── init.ipynb ├── knrm.ipynb ├── match_pyramid.ipynb ├── match_srnn.ipynb └── matchlstm.ipynb /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # basic 6 | target: auto 7 | threshold: 3% 8 | base: auto 9 | # advanced 10 | branches: null 11 | if_no_uploads: error 12 | if_not_found: success 13 | if_ci_failed: error 14 | only_pulls: false 15 | flags: null 16 | paths: null 17 | patch: 18 | default: 19 | threshold: 1% -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [report] 3 | # regexes for lines to exclude from consideration 4 | exclude_lines = 5 | if __name__ == .__main__.: 6 | ValueError 7 | TypeError 8 | NotImplementedError 9 | omit = 10 | matchzoo/__init__.py 11 | matchzoo/version.py 12 | matchzoo/*/__init__.py 13 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | # Maximum number of characters on a single line. Ideally, lines should be under 79 characters, 4 | # but we allow some leeway before calling it an error. 5 | max-line-length = 90 6 | 7 | ignore = 8 | # D401 First line should be in imperative mood 9 | D401, 10 | # D202 No blank lines allowed after function docstring 11 | D202, 12 | 13 | # For doctests: 14 | # D207 Docstring is under-indented 15 | D207, 16 | # D301 Use r""" if any backslashes in a docstring 17 | D301, 18 | # F401 'blah blah' imported but unused 19 | F401, 20 | 21 | # D100 Missing docstring in public module 22 | D100, -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/* linguist-vendored -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Describe the bug 11 | Please provide a clear and concise description of what the bug is.
If applicable, add screenshots to help explain your problem, especially for visualization-related problems. 12 | 13 | ### To Reproduce 14 | Please provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) here. We hope we can simply copy/paste/run it. It is also nice to share a hosted runnable script (e.g. Google Colab), especially for hardware-related problems. 15 | 16 | ### Describe your attempts 17 | - [ ] I checked the documentation and found no answer 18 | - [ ] I checked to make sure that this is not a duplicate issue 19 | 20 | You should also provide the code snippets you tried as a workaround, a StackOverflow solution that you have walked through, or your best guess at the cause if you can't locate it (e.g. cosmic radiation). 21 | 22 | ### Context 23 | - **OS** [e.g. Windows 10, macOS 10.14]: 24 | - **Hardware** [e.g. CPU only, GTX 1080 Ti]: 25 | 26 | In addition, figure out your MatchZoo version by running `import matchzoo; matchzoo.__version__`. 27 | 28 | ### Additional Information 29 | Other things you want the developers to know. 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] I checked to make sure that this is not a duplicate issue 11 | - [ ] I'm submitting the request to the correct repository (for model requests, see [here](https://github.com/NTMC-Community/awaresome-neural-models-for-semantic-match)) 12 | 13 | ### Is your feature request related to a problem? Please describe. 14 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 15 | 16 | ### Describe the solution you'd like 17 | A clear and concise description of what you want to happen. 18 | 19 | ### Describe alternatives you've considered 20 | A clear and concise description of any alternative solutions or features you've considered. 21 | 22 | ### Additional Information 23 | Other things you want the developers to know. 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/usage-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Usage Question 3 | about: Ask a question about MatchZoo usage 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Describe the Question 11 | Please provide a clear and concise description of what the question is. 12 | 13 | ### Describe your attempts 14 | - [ ] I walked through the tutorials 15 | - [ ] I checked the documentation 16 | - [ ] I checked to make sure that this is not a duplicate question 17 | 18 | You may also provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) you tried as a workaround, or a StackOverflow solution that you have walked through. 19 | 20 | In addition, figure out your MatchZoo version by running `import matchzoo; matchzoo.__version__`.
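For example (the printed value depends on the version you have installed): ```python import matchzoo; print(matchzoo.__version__) ```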
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.swp 4 | *.bak 5 | *.weights 6 | *.trec 7 | *.ranklist 8 | *.DS_Store 9 | .vscode 10 | .coverage 11 | .ipynb_checkpoints/ 12 | predict.* 13 | build/ 14 | dist/ 15 | data/ 16 | save/ 17 | log/* 18 | .ipynb_checkpoints/ 19 | matchzoo/log/* 20 | matchzoo/querydecision/ 21 | log/* 22 | .idea/ 23 | .pytest_cache/ 24 | MatchZoo.egg-info/ 25 | notebooks/wikiqa/.ipynb_checkpoints/* 26 | .cache 27 | .tmpdir 28 | htmlcov/ 29 | docs/_build 30 | matchzoo_py.egg-info/ 31 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | cache: pip 4 | 5 | sudo: true 6 | 7 | env: 8 | global: 9 | - PYTHONPATH=$PYTHONPATH:$TRAVIS_BUILD_DIR/tests:$TRAVIS_BUILD_DIR/matchzoo 10 | 11 | matrix: 12 | allow_failures: 13 | - os: osx 14 | include: 15 | - os: linux 16 | dist: xenial 17 | python: 3.6 18 | - os: osx 19 | osx_image: xcode10.2 20 | language: shell 21 | 22 | install: 23 | - pip3 install --progress-bar off -r requirements.txt 24 | - python3 -m nltk.downloader punkt 25 | - python3 -m nltk.downloader wordnet 26 | - python3 -m nltk.downloader stopwords 27 | 28 | script: 29 | - stty cols 80 30 | - export COLUMNS=80 31 | - if [ "$TRAVIS_EVENT_TYPE" == "pull_request" ]; then make push; fi 32 | - if [ "$TRAVIS_EVENT_TYPE" == "push" ]; then make push; fi 33 | - if [ "$TRAVIS_EVENT_TYPE" == "cron" ]; then make cron; fi 34 | 35 | 36 | after_success: 37 | - codecov 38 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Watchers and contributors to MatchZoo repo directories/packages/files 2 | # Please see documentation of use of CODEOWNERS file at 3 | # https://help.github.com/articles/about-codeowners/ and 4 | # https://github.com/blog/2392-introducing-code-owners 5 | # 6 | # Anybody can add themselves or a team as additional watcher or contributor 7 | # to get notified about changes in a specific package. 8 | # See https://help.github.com/articles/about-teams how to setup teams. 9 | 10 | # Define individuals or teams that are responsible for code in a repository. 11 | 12 | # global owner. 
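# Note: when several patterns match a path, the last matching pattern in a CODEOWNERS file takes precedence.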
13 | * @faneshion 14 | * @Chriskuei 15 | 16 | # third-party & project configuration 17 | .codecov.yml @Chriskuei 18 | .coveragerc @Chriskuei 19 | .flake8 @Chriskuei 20 | .gitignore @Chriskuei 21 | .travis.yml @Chriskuei 22 | CONTRIBUTING.MD @Chriskuei 23 | Makefile @Chriskuei 24 | pytest.ini @Chriskuei 25 | README.md @faneshion @Chriskuei 26 | readthedocs.yml @wqh17101 27 | requirements.txt @Chriskuei @faneshion 28 | setup.py @Chriskuei @faneshion 29 | 30 | # artworks 31 | /artworks/ @faneshion 32 | 33 | # tutorials 34 | /tutorials/ @Chriskuei @faneshion @caiyinqiong 35 | 36 | # docs 37 | /docs/ @wqh17101 38 | 39 | # tests 40 | /tests/ @Chriskuei @faneshion 41 | 42 | # matchzoo 43 | 44 | /matchzoo/auto/ @Chriskuei 45 | /matchzoo/data_pack/ @caiyinqiong @faneshion 46 | /matchzoo/dataloader/ @caiyinqiong @Chriskuei 47 | /matchzoo/datasets/ @caiyinqiong 48 | /matchzoo/embedding/ @caiyinqiong 49 | /matchzoo/engine/ @faneshion @Chriskuei 50 | /matchzoo/losses/ @faneshion @Chriskuei 51 | /matchzoo/metrics/ @faneshion @Chriskuei 52 | /matchzoo/models/ @Chriskuei @faneshion @caiyinqiong 53 | /matchzoo/modules/ @Chriskuei @caiyinqiong 54 | /matchzoo/preprocessors/ @caiyinqiong @faneshion 55 | /matchzoo/tasks/ @Chriskuei 56 | /matchzoo/trainers/ @Chriskuei 57 | /matchzoo/utils/ @Chriskuei @caiyinqiong 58 | /matchzoo/* @faneshion @Chriskuei 59 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing to MatchZoo-py 2 | ---------- 3 | 4 | > Note: MatchZoo-py is developed under Python 3.6. 5 | 6 | Welcome! MatchZoo-py is a community project that aims to work for a wide range of NLP and IR tasks such as Question Answering, Information Retrieval, Paraphrase identification etc. Your experience and what you can contribute are important to the project's success. 7 | 8 | Discussion 9 | ---------- 10 | 11 | If you've run into behavior in MatchZoo-py you don't understand, or you're having trouble working out a good way to apply it to your code, or you've found a bug or would like a feature it doesn't have, we want to hear from you! 12 | 13 | Our main forum for discussion is the project's [GitHub issue tracker](https://github.com/NTMC-Community/MatchZoo-py/issues). This is the right place to start a discussion of any of the above or most any other topic concerning the project. 14 | 15 | For less formal discussion we have a chat room on WeChat (mostly Chinese speakers). MatchZoo-py core developers are almost always present; feel free to find us there and we're happy to chat. Please add *YQ-Cai1198593462* as your WeChat friend, she will invite you to join the chat room. 16 | 17 | First Time Contributors 18 | ----------------------- 19 | 20 | MatchZoo-py appreciates your contribution! If you are interested in helping improve MatchZoo-py, there are several ways to get started: 21 | 22 | * Work on [new models](https://github.com/NTMC-Community/awaresome-neural-models-for-semantic-match). 23 | * Work on [tutorials](https://github.com/NTMC-Community/MatchZoo-py/tree/master/tutorials). 24 | * Work on [documentation](https://github.com/NTMC-Community/MatchZoo-py/tree/master/docs). 25 | * Try to answer questions on [the issue tracker](https://github.com/NTMC-Community/MatchZoo-py/issues). 26 | 27 | Submitting Changes 28 | ------------------ 29 | 30 | Even more excellent than a good bug report is a fix for a bug, or the implementation of a much-needed new model. 
31 | 32 | (*) We'd love to have your contributions. 33 | 34 | (*) If your new feature will be a lot of work, we recommend talking to us early -- see below. 35 | 36 | We use the usual GitHub pull-request flow, which may be familiar to you if you've contributed to other projects on GitHub -- see below. 37 | 38 | Anyone interested in MatchZoo-py may review your code. One of the MatchZoo-py core developers will merge your pull request when they think it's ready. 39 | For every pull request, we aim to promptly either merge it or say why it's not yet ready; if you go a few days without a reply, please feel 40 | free to ping the thread by adding a new comment. 41 | 42 | For a list of MatchZoo-py core developers, see [README](https://github.com/NTMC-Community/MatchZoo-py/blob/master/README.md). 43 | 44 | Contributing Flow 45 | ------------------ 46 | 47 | 1. Fork the latest version of [MatchZoo-py](https://github.com/NTMC-Community/MatchZoo-py) into your repo. 48 | 2. Create an issue under [NTMC-Community/MatchZoo-py](https://github.com/NTMC-Community/MatchZoo-py/issues) and write a description of the bug/enhancement. 49 | 3. Clone your forked MatchZoo onto your machine and add your changes together with associated tests. 50 | 4. Run `make push` in a terminal and ensure all unit tests & integration tests pass on your computer. 51 | 5. Push to your forked repo, then send the pull request to the official repo. In the pull request, you need to link to the issue you created using `#[issue_id]` and describe what has been changed. 52 | 6. Wait until [continuous integration](https://travis-ci.org/NTMC-Community/MatchZoo-py) passes. 53 | 7. Wait for [Codecov](https://codecov.io/gh/NTMC-Community/MatchZoo-py) to generate the coverage report. 54 | 8. We'll assign reviewers to review your code. 55 | 56 | 57 | Your PR will be merged if: 58 | - It functionally benefits the project. 59 | - It passed continuous integration (all unit tests, integration tests and the [PEP8](https://www.python.org/dev/peps/pep-0008/) check passed). 60 | - Test coverage didn't decrease (we use [pytest](https://docs.pytest.org/en/latest/)). 61 | - It has proper docstrings; see the codebase for examples. 62 | - It has type hints; see [typing](https://docs.python.org/3/library/typing.html). 63 | - All reviewers approved your changes. 64 | 65 | 66 | **Thanks and let's improve MatchZoo-py together!** -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include matchzoo/datasets/toy * 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Usage: 2 | # 3 | # to install matchzoo dependencies: 4 | # $ make init 5 | # 6 | # to run all matchzoo tests, recommended for big PRs and new versions: 7 | # $ make test 8 | # 9 | # there are three kinds of tests: 10 | # 11 | # 1. "quick" tests 12 | # - run in seconds 13 | # - include all unit tests without marks and all doctests 14 | # - for rapid prototyping 15 | # - CI runs this for all PRs 16 | # 17 | # 2. "slow" tests 18 | # - run in minutes 19 | # - include all unit tests marked "slow" 20 | # - CI runs this for all PRs 21 | # 22 | # 3. "cron" tests 23 | # - run in minutes 24 | # - involves nondeterministic behaviors (e.g.
network connection) 25 | # - include all unit tests marked "cron" 26 | # - CI runs this on a daily basis 27 | # 28 | # to run quick tests, excluding time-consuming tests and crons: 29 | # $ make quick 30 | # 31 | # to run slow tests, excluding normal tests and crons: 32 | # $ make slow 33 | # 34 | # to run crons: 35 | # $ make cron 36 | # 37 | # to run all tests: 38 | # $ make test 39 | # 40 | # to run CI push/PR tests: 41 | # $ make push 42 | # 43 | # to run docstring style check: 44 | # $ make flake 45 | 46 | init: 47 | pip install -r requirements.txt 48 | 49 | TEST_ARGS = -v --full-trace -l --doctest-modules --doctest-continue-on-failure --cov matchzoo/ --cov-report term-missing --cov-report html --cov-config .coveragerc matchzoo/ tests/ -W ignore::DeprecationWarning --ignore=matchzoo/contrib 50 | FLAKE_ARGS = ./matchzoo --exclude=__init__.py,matchzoo/contrib 51 | 52 | test: 53 | pytest $(TEST_ARGS) 54 | flake8 $(FLAKE_ARGS) 55 | 56 | push: 57 | pytest -m 'not cron' $(TEST_ARGS) ${ARGS} 58 | flake8 $(FLAKE_ARGS) 59 | 60 | quick: 61 | pytest -m 'not slow and not cron' $(TEST_ARGS) ${ARGS} 62 | 63 | slow: 64 | pytest -m 'slow and not cron' $(TEST_ARGS) ${ARGS} 65 | 66 | cron: 67 | pytest -m 'cron' $(TEST_ARGS) ${ARGS} 68 | 69 | flake: 70 | flake8 $(FLAKE_ARGS) ${ARGS} 71 | -------------------------------------------------------------------------------- /artworks/matchzoo-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo-py/0e5c04e1e948aa9277abd5c85ff99d9950d8527f/artworks/matchzoo-logo.png -------------------------------------------------------------------------------- /docs/DOCCHECK.md: -------------------------------------------------------------------------------- 1 | Documentation Checking Process (Only for the developers) 2 | ========================================================== 3 | 4 | # Why 5 | 6 | All developers need to generate the rst files, which help us check the documentation. 7 | 8 | # When 9 | 10 | 1. You add a new function to one of the scripts in {MatchZoo/matchzoo} or its subdirs 11 | 1. You add a new script to {MatchZoo/matchzoo} or its subdirs 12 | 1. You add a new directory to {MatchZoo/matchzoo} or its subdirs 13 | 14 | # How 15 | ## Make sure you have installed Sphinx 16 | 17 | 1. Enter the docs directory 18 | 19 | ``` 20 | cd {MatchZoo/docs} 21 | ``` 22 | 23 | 2. Generate the rst files 24 | 25 | ``` 26 | sphinx-apidoc -f -o source ../matchzoo 27 | ``` 28 | 29 | 3. Commit 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = MatchZoo 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
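# For example, `make html O="-W"` would pass `-W` to sphinx-build and turn warnings into errors (a hypothetical invocation, not a project default).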
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/Readme.md: -------------------------------------------------------------------------------- 1 | ## Build Documentation: 2 | 3 | 4 | 5 | #### Install Requirements 6 | 7 | ```bash 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | 12 | 13 | #### Build Documentation 14 | 15 | ```bash 16 | # Enter docs folder. 17 | cd docs 18 | # Use sphinx autodoc to generate rst. 19 | sphinx-apidoc -o source/ ../matchzoo/ 20 | # Generate html from rst 21 | make clean 22 | make html 23 | ``` 24 | This approach imports all the packages used in the code, which can cause errors (see this [issue](https://github.com/readthedocs/readthedocs.org/issues/5882)) 25 | and is not necessary. 26 | 27 | So we have a new way to generate documents. 28 | Follow this [link](https://sphinx-autoapi.readthedocs.io/en/latest/tutorials.html): 29 | ```bash 30 | pip install sphinx-autoapi 31 | ``` 32 | then modify the conf.py 33 | ```python 34 | extensions = ['autoapi.extension'] 35 | autoapi_dirs = ['../mypackage'] 36 | ``` 37 | then 38 | ```bash 39 | make html 40 | ``` 41 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=MatchZoo 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx >= 1.7.5 2 | sphinx_rtd_theme >= 0.4.0 3 | numpy >= 1.12.1 4 | pandas >= 0.23.1 5 | sphinx_autodoc_typehints>=1.6.0 6 | sphinx-autoapi -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. MatchZoo documentation master file, created by 2 | sphinx-quickstart on Mon May 28 16:40:41 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to MatchZoo's documentation! 7 | ==================================== 8 | 9 | 10 | .. image:: https://travis-ci.org/NTMC-Community/MatchZoo-py.svg?branch=master 11 | :alt: ci 12 | :target: https://travis-ci.org/NTMC-Community/MatchZoo-py/ 13 | 14 | 15 | .. image:: ../../artworks/matchzoo-logo.png 16 | :alt: logo 17 | :align: center 18 | 19 | 20 | MatchZoo is a toolkit for text matching.
It was developed with a focus on facilitating the designing, comparing and sharing of deep text matching models. There are a number of deep matching methods, such as DRMM, MatchPyramid, MV-LSTM, aNMM, DUET, ARC-I, ARC-II, DSSM, and CDSSM, designed with a unified interface. Potential tasks related to MatchZoo include document retrieval, question answering, conversational response ranking, paraphrase identification, etc. We are always happy to receive any code contributions, suggestions, comments from all our MatchZoo users. 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Contents: 26 | 27 | modules 28 | model_reference 29 | 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | * :ref:`search` 37 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | matchzoo 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | matchzoo 8 | -------------------------------------------------------------------------------- /matchzoo/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | USER_DIR = Path.expanduser(Path('~')).joinpath('.matchzoo') 4 | if not USER_DIR.exists(): 5 | USER_DIR.mkdir() 6 | USER_DATA_DIR = USER_DIR.joinpath('datasets') 7 | if not USER_DATA_DIR.exists(): 8 | USER_DATA_DIR.mkdir() 9 | USER_TUNED_MODELS_DIR = USER_DIR.joinpath('tuned_models') 10 | 11 | from .version import __version__ 12 | 13 | from .data_pack import DataPack 14 | from .data_pack import pack 15 | from .data_pack import load_data_pack 16 | 17 | from . import preprocessors 18 | from . import dataloader 19 | 20 | from .preprocessors.chain_transform import chain_transform 21 | 22 | from . import auto 23 | from . import tasks 24 | from . import metrics 25 | from . import losses 26 | from . import engine 27 | from . import models 28 | from . import trainers 29 | from . import embedding 30 | from . import datasets 31 | from . import modules 32 | 33 | from .engine import hyper_spaces 34 | from .engine.base_preprocessor import load_preprocessor 35 | from .engine.param import Param 36 | from .engine.param_table import ParamTable 37 | 38 | from .embedding.embedding import Embedding 39 | 40 | from .preprocessors.build_unit_from_data_pack import build_unit_from_data_pack 41 | from .preprocessors.build_vocab_unit import build_vocab_unit 42 | -------------------------------------------------------------------------------- /matchzoo/auto/__init__.py: -------------------------------------------------------------------------------- 1 | from .preparer import prepare 2 | from .preparer import Preparer 3 | 4 | from .tuner import Tuner 5 | from .tuner import tune 6 | 7 | from . 
import tuner 8 | -------------------------------------------------------------------------------- /matchzoo/auto/preparer/__init__.py: -------------------------------------------------------------------------------- 1 | from .preparer import Preparer 2 | from .prepare import prepare 3 | -------------------------------------------------------------------------------- /matchzoo/auto/preparer/prepare.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import matchzoo as mz 4 | from .preparer import Preparer 5 | from matchzoo.engine.base_task import BaseTask 6 | from matchzoo.engine.base_model import BaseModel 7 | from matchzoo.engine.base_callback import BaseCallback 8 | from matchzoo.engine.base_preprocessor import BasePreprocessor 9 | 10 | 11 | def prepare( 12 | task: BaseTask, 13 | model_class: typing.Type[BaseModel], 14 | data_pack: mz.DataPack, 15 | callback: typing.Optional[BaseCallback] = None, 16 | preprocessor: typing.Optional[BasePreprocessor] = None, 17 | embedding: typing.Optional['mz.Embedding'] = None, 18 | config: typing.Optional[dict] = None, 19 | ): 20 | """ 21 | A simple shorthand for using :class:`matchzoo.Preparer`. 22 | 23 | `config` is used to control specific behaviors. The default `config` 24 | will be updated accordingly if a `config` dictionary is passed. e.g. to 25 | override the default `bin_size`, pass `config={'bin_size': 15}`. 26 | 27 | :param task: Task. 28 | :param model_class: Model class. 29 | :param data_pack: DataPack used to fit the preprocessor. 30 | :param callback: Callback used to pad a batch. 31 | (default: the default callback of `model_class`) 32 | :param preprocessor: Preprocessor used to fit the `data_pack`. 33 | (default: the default preprocessor of `model_class`) 34 | :param embedding: Embedding to build an embedding matrix. If not set, 35 | then a correctly shaped randomized matrix will be built. 36 | :param config: Configuration of specific behaviors. (default: return 37 | value of `mz.Preparer.get_default_config()`) 38 | 39 | :return: A tuple of `(model, preprocessor, data_generator_builder, 40 | embedding_matrix)`. 41 | 42 | """ 43 | preparer = Preparer(task=task, config=config) 44 | return preparer.prepare( 45 | model_class=model_class, 46 | data_pack=data_pack, 47 | callback=callback, 48 | preprocessor=preprocessor, 49 | embedding=embedding 50 | ) 51 | -------------------------------------------------------------------------------- /matchzoo/auto/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from .tuner import Tuner 2 | from .tune import tune 3 | -------------------------------------------------------------------------------- /matchzoo/data_pack/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_pack import DataPack, load_data_pack 2 | from .pack import pack 3 | -------------------------------------------------------------------------------- /matchzoo/data_pack/pack.py: -------------------------------------------------------------------------------- 1 | """Convert a list of input into :class:`DataPack` expected format.""" 2 | 3 | import typing 4 | 5 | import pandas as pd 6 | import numpy as np 7 | 8 | import matchzoo 9 | from matchzoo.engine.base_task import BaseTask 10 | 11 | 12 | def pack( 13 | df: pd.DataFrame, 14 | task: typing.Union[str, BaseTask] = 'ranking', 15 | ) -> 'matchzoo.DataPack': 16 | """ 17 | Pack a :class:`DataPack` using `df`.
18 | 19 | The `df` must have `text_left` and `text_right` columns. Optionally, 20 | the `df` can have `id_left`, `id_right` to index `text_left` and 21 | `text_right` respectively. `id_left`, `id_right` will be automatically 22 | generated if not specified. 23 | 24 | :param df: Input :class:`pandas.DataFrame` to use. 25 | :param task: Could be one of `ranking`, `classification` or a 26 | :class:`matchzoo.engine.BaseTask` instance. 27 | 28 | Examples:: 29 | >>> import matchzoo as mz 30 | >>> import pandas as pd 31 | >>> df = pd.DataFrame(data={'text_left': list('AABC'), 32 | ... 'text_right': list('abbc'), 33 | ... 'label': [0, 1, 1, 0]}) 34 | >>> mz.pack(df, task='classification').frame() 35 | id_left text_left id_right text_right label 36 | 0 L-0 A R-0 a 0 37 | 1 L-0 A R-1 b 1 38 | 2 L-1 B R-1 b 1 39 | 3 L-2 C R-2 c 0 40 | >>> mz.pack(df, task='ranking').frame() 41 | id_left text_left id_right text_right label 42 | 0 L-0 A R-0 a 0.0 43 | 1 L-0 A R-1 b 1.0 44 | 2 L-1 B R-1 b 1.0 45 | 3 L-2 C R-2 c 0.0 46 | 47 | """ 48 | if 'text_left' not in df or 'text_right' not in df: 49 | raise ValueError( 50 | 'Input data frame must have `text_left` and `text_right`.') 51 | 52 | df = df.dropna(axis=0, how='any').reset_index(drop=True) 53 | 54 | # Gather IDs 55 | if 'id_left' not in df: 56 | id_left = _gen_ids(df, 'text_left', 'L-') 57 | else: 58 | id_left = df['id_left'] 59 | if 'id_right' not in df: 60 | id_right = _gen_ids(df, 'text_right', 'R-') 61 | else: 62 | id_right = df['id_right'] 63 | 64 | # Build Relation 65 | relation = pd.DataFrame(data={'id_left': id_left, 'id_right': id_right}) 66 | for col in df: 67 | if col not in ['id_left', 'id_right', 'text_left', 'text_right']: 68 | relation[col] = df[col] 69 | if 'label' in relation: 70 | if task == 'classification' or isinstance( 71 | task, matchzoo.tasks.Classification): 72 | relation['label'] = relation['label'].astype(int) 73 | elif task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking): 74 | relation['label'] = relation['label'].astype(float) 75 | else: 76 | raise ValueError(f"{task} is not a valid task.") 77 | 78 | # Build Left and Right 79 | left = _merge(df, id_left, 'text_left', 'id_left') 80 | right = _merge(df, id_right, 'text_right', 'id_right') 81 | return matchzoo.DataPack(relation, left, right) 82 | 83 | 84 | def _merge(data: pd.DataFrame, ids: typing.Union[list, np.array], 85 | text_label: str, id_label: str): 86 | left = pd.DataFrame(data={ 87 | text_label: data[text_label], id_label: ids 88 | }) 89 | left.drop_duplicates(id_label, inplace=True) 90 | left.set_index(id_label, inplace=True) 91 | return left 92 | 93 | 94 | def _gen_ids(data: pd.DataFrame, col: str, prefix: str): 95 | lookup = {} 96 | for text in data[col].unique(): 97 | lookup[text] = prefix + str(len(lookup)) 98 | return data[col].map(lookup) 99 | -------------------------------------------------------------------------------- /matchzoo/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import callbacks 2 | from .dataset import Dataset 3 | from .dataloader import DataLoader 4 | from .dataloader_builder import DataLoaderBuilder 5 | from .dataset_builder import DatasetBuilder 6 | -------------------------------------------------------------------------------- /matchzoo/dataloader/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .lambda_callback import LambdaCallback 2 | from .histogram import Histogram 3 | from .ngram import Ngram 4 | from .padding import BasicPadding 5 | from .padding import DRMMPadding 6 | from .padding import BertPadding 7 | -------------------------------------------------------------------------------- /matchzoo/dataloader/callbacks/histogram.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matchzoo as mz 4 | from matchzoo.engine.base_callback import BaseCallback 5 | 6 | 7 | class Histogram(BaseCallback): 8 | """ 9 | Generate data with matching histogram. 10 | 11 | :param embedding_matrix: The embedding matrix used to generate the match 12 | histogram. 13 | :param bin_size: The number of bins in the histogram. 14 | :param hist_mode: The mode of the :class:`MatchingHistogramUnit`, one of 15 | `CH`, `NH`, and `LCH`. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | embedding_matrix: np.ndarray, 21 | bin_size: int = 30, 22 | hist_mode: str = 'CH', 23 | ): 24 | """Init.""" 25 | self._match_hist_unit = mz.preprocessors.units.MatchingHistogram( 26 | bin_size=bin_size, 27 | embedding_matrix=embedding_matrix, 28 | normalize=True, 29 | mode=hist_mode 30 | ) 31 | 32 | def on_batch_unpacked(self, x, y): 33 | """Insert `match_histogram` to `x`.""" 34 | x['match_histogram'] = _build_match_histogram(x, self._match_hist_unit) 35 | 36 | 37 | def _trunc_text(input_text: list, length: list) -> list: 38 | """ 39 | Truncate the input text according to the input length. 40 | 41 | :param input_text: The input text to be truncated. 42 | :param length: The lengths used to truncate the text. 43 | :return: The truncated text. 44 | """ 45 | return [row[:length[idx]] for idx, row in enumerate(input_text)] 46 | 47 | 48 | def _build_match_histogram( 49 | x: dict, 50 | match_hist_unit: mz.preprocessors.units.MatchingHistogram 51 | ) -> np.ndarray: 52 | """ 53 | Generate the matching histogram for input. 54 | 55 | :param x: The input `dict`. 56 | :param match_hist_unit: The histogram unit :class:`MatchingHistogramUnit`. 57 | :return: The matching histogram. 58 | """ 59 | match_hist = [] 60 | text_left = x['text_left'].tolist() 61 | text_right = _trunc_text(x['text_right'].tolist(), 62 | x['length_right'].tolist()) 63 | for pair in zip(text_left, text_right): 64 | match_hist.append(match_hist_unit.transform(list(pair))) 65 | return np.asarray(match_hist) 66 | -------------------------------------------------------------------------------- /matchzoo/dataloader/callbacks/lambda_callback.py: -------------------------------------------------------------------------------- 1 | from matchzoo.engine.base_callback import BaseCallback 2 | 3 | 4 | class LambdaCallback(BaseCallback): 5 | """ 6 | LambdaCallback. Just a shorthand for creating a callback class. 7 | 8 | See :class:`matchzoo.engine.base_callback.BaseCallback` for more details.
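`on_batch_data_pack` is called with the raw batch `DataPack` before unpacking, and `on_batch_unpacked` with the unpacked `(x, y)` pair; either hook may be left unset.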
9 | 10 | Example: 11 | 12 | >>> import matchzoo as mz 13 | >>> from matchzoo.dataloader.callbacks import LambdaCallback 14 | >>> data = mz.datasets.toy.load_data() 15 | >>> batch_func = lambda x: print(type(x)) 16 | >>> unpack_func = lambda x, y: print(type(x), type(y)) 17 | >>> callback = LambdaCallback(on_batch_data_pack=batch_func, 18 | ... on_batch_unpacked=unpack_func) 19 | >>> dataset = mz.dataloader.Dataset( 20 | ... data, callbacks=[callback]) 21 | >>> _ = dataset[0] 22 | <class 'matchzoo.data_pack.data_pack.DataPack'> 23 | <class 'dict'> <class 'numpy.ndarray'> 24 | 25 | """ 26 | 27 | def __init__(self, on_batch_data_pack=None, on_batch_unpacked=None): 28 | """Init.""" 29 | self._on_batch_unpacked = on_batch_unpacked 30 | self._on_batch_data_pack = on_batch_data_pack 31 | 32 | def on_batch_data_pack(self, data_pack): 33 | """`on_batch_data_pack`.""" 34 | if self._on_batch_data_pack: 35 | self._on_batch_data_pack(data_pack) 36 | 37 | def on_batch_unpacked(self, x, y): 38 | """`on_batch_unpacked`.""" 39 | if self._on_batch_unpacked: 40 | self._on_batch_unpacked(x, y) 41 | -------------------------------------------------------------------------------- /matchzoo/dataloader/dataloader_builder.py: -------------------------------------------------------------------------------- 1 | import matchzoo as mz 2 | from matchzoo.dataloader import DataLoader 3 | 4 | 5 | class DataLoaderBuilder(object): 6 | """ 7 | DataLoader Builder. In essence a wrapped partial function. 8 | 9 | Example: 10 | >>> import matchzoo as mz 11 | >>> padding_callback = mz.dataloader.callbacks.BasicPadding() 12 | >>> builder = mz.dataloader.DataLoaderBuilder( 13 | ... stage='train', callback=padding_callback 14 | ... ) 15 | >>> data_pack = mz.datasets.toy.load_data() 16 | >>> preprocessor = mz.preprocessors.BasicPreprocessor() 17 | >>> data_processed = preprocessor.fit_transform(data_pack) 18 | >>> dataset = mz.dataloader.Dataset(data_processed, mode='point') 19 | >>> dataloader = builder.build(dataset) 20 | >>> type(dataloader) 21 | <class 'matchzoo.dataloader.dataloader.DataLoader'> 22 | 23 | """ 24 | 25 | def __init__(self, **kwargs): 26 | """Init.""" 27 | self._kwargs = kwargs 28 | 29 | def build(self, dataset, **kwargs) -> DataLoader: 30 | """ 31 | Build a DataLoader. 32 | 33 | :param dataset: Dataset to build upon. 34 | :param kwargs: Additional keyword arguments to override the keyword 35 | arguments passed in `__init__`. 36 | """ 37 | return mz.dataloader.DataLoader( 38 | dataset, **{**self._kwargs, **kwargs} 39 | ) 40 | -------------------------------------------------------------------------------- /matchzoo/dataloader/dataset_builder.py: -------------------------------------------------------------------------------- 1 | import matchzoo as mz 2 | from matchzoo.dataloader import Dataset 3 | 4 | 5 | class DatasetBuilder(object): 6 | """ 7 | Dataset Builder. In essence a wrapped partial function. 8 | 9 | Example: 10 | >>> import matchzoo as mz 11 | >>> builder = mz.dataloader.DatasetBuilder( 12 | ... mode='point' 13 | ... ) 14 | >>> data = mz.datasets.toy.load_data() 15 | >>> gen = builder.build(data) 16 | >>> type(gen) 17 | <class 'matchzoo.dataloader.dataset.Dataset'> 18 | 19 | """ 20 | 21 | def __init__(self, **kwargs): 22 | """Init.""" 23 | self._kwargs = kwargs 24 | 25 | def build(self, data_pack, **kwargs) -> Dataset: 26 | """ 27 | Build a Dataset. 28 | 29 | :param data_pack: DataPack to build upon. 30 | :param kwargs: Additional keyword arguments to override the keyword 31 | arguments passed in `__init__`.
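:return: A :class:`Dataset` instance built from `data_pack` and the merged keyword arguments.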
32 | """ 33 | return mz.dataloader.Dataset( 34 | data_pack, **{**self._kwargs, **kwargs} 35 | ) 36 | -------------------------------------------------------------------------------- /matchzoo/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import toy 2 | from . import wiki_qa 3 | from . import embeddings 4 | from . import snli 5 | from . import quora_qp 6 | from pathlib import Path 7 | 8 | 9 | def list_available(): 10 | return [p.name for p in Path(__file__).parent.iterdir() 11 | if p.is_dir() and not p.name.startswith('_')] 12 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from .load_glove_embedding import load_glove_embedding 3 | from .load_fasttext_embedding import load_fasttext_embedding 4 | 5 | DATA_ROOT = Path(__file__).parent 6 | EMBED_RANK = DATA_ROOT.joinpath('embed_rank.txt') 7 | EMBED_10 = DATA_ROOT.joinpath('embed_10_word2vec.txt') 8 | EMBED_10_GLOVE = DATA_ROOT.joinpath('embed_10_glove.txt') 9 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_10_glove.txt: -------------------------------------------------------------------------------- 1 | A 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 2 | B 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 3 | C 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 4 | D 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 5 | E 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 6 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_10_word2vec.txt: -------------------------------------------------------------------------------- 1 | 5 10 2 | A 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 3 | B 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 4 | C 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 5 | D 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 6 | E 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 7 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_err.txt.gb2312: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo-py/0e5c04e1e948aa9277abd5c85ff99d9950d8527f/matchzoo/datasets/embeddings/embed_err.txt.gb2312 -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_word.txt: -------------------------------------------------------------------------------- 1 | 7 5 2 | asia 1 2 3 4 5 3 | beijing 1 1 1 1 1 4 | hot 2 2 2 2 2 5 | east 3 3 3 3 3 6 | capital 4 4 4 4 4 7 | china 5 5 5 5 5 8 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/load_fasttext_embedding.py: -------------------------------------------------------------------------------- 1 | """FastText embedding data loader.""" 2 | 3 | from pathlib import Path 4 | 5 | import matchzoo as mz 6 | 7 | _fasttext_embedding_url = "https://dl.fbaipublicfiles.com/fasttext/vectors" \ 8 | "-wiki/wiki.{}.vec" 9 | 10 | 11 | def load_fasttext_embedding(language: str = 'en') -> mz.embedding.Embedding: 12 | """ 13 | Return the pretrained fasttext embedding. 14 | 15 | :param language: the language of embedding. 
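For example `'en'` (English) or `'zh'` (Chinese).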
Supported language can be 16 | referred to "https://github.com/facebookresearch/fastText/blob/master" 17 | "/docs/pretrained-vectors.md" 18 | :return: The :class:`mz.embedding.Embedding` object. 19 | """ 20 | file_name = _fasttext_embedding_url.split('/')[-1].format(language) 21 | file_path = (Path(mz.USER_DATA_DIR) / 'fasttext').joinpath(file_name) 22 | if not file_path.exists(): 23 | mz.utils.get_file(file_name, 24 | _fasttext_embedding_url.format(language), 25 | extract=False, 26 | cache_dir=mz.USER_DATA_DIR, 27 | cache_subdir='fasttext') 28 | return mz.embedding.load_from_file(file_path=str(file_path), 29 | mode='fasttext') 30 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/load_glove_embedding.py: -------------------------------------------------------------------------------- 1 | """GloVe Embedding data loader.""" 2 | 3 | from pathlib import Path 4 | 5 | import matchzoo as mz 6 | 7 | _glove_embedding_url = "http://nlp.stanford.edu/data/glove.6B.zip" 8 | 9 | 10 | def load_glove_embedding(dimension: int = 50) -> mz.embedding.Embedding: 11 | """ 12 | Return the pretrained glove embedding. 13 | 14 | :param dimension: the size of embedding dimension, the value can only be 15 | 50, 100, or 300. 16 | :return: The :class:`mz.embedding.Embedding` object. 17 | """ 18 | file_name = 'glove.6B.' + str(dimension) + 'd.txt' 19 | file_path = (Path(mz.USER_DATA_DIR) / 'glove').joinpath(file_name) 20 | if not file_path.exists(): 21 | mz.utils.get_file('glove_embedding', 22 | _glove_embedding_url, 23 | extract=True, 24 | cache_dir=mz.USER_DATA_DIR, 25 | cache_subdir='glove') 26 | return mz.embedding.load_from_file(file_path=str(file_path), mode='glove') 27 | -------------------------------------------------------------------------------- /matchzoo/datasets/quora_qp/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_data 2 | -------------------------------------------------------------------------------- /matchzoo/datasets/quora_qp/load_data.py: -------------------------------------------------------------------------------- 1 | """Quora Question Pairs data loader.""" 2 | 3 | import typing 4 | from pathlib import Path 5 | 6 | import pandas as pd 7 | 8 | import matchzoo 9 | from matchzoo.engine.base_task import BaseTask 10 | 11 | _url = "https://firebasestorage.googleapis.com/v0/b/mtl-sentence" \ 12 | "-representations.appspot.com/o/data%2FQQP.zip?alt=media&" \ 13 | "token=700c6acf-160d-4d89-81d1-de4191d02cb5" 14 | 15 | 16 | def load_data( 17 | stage: str = 'train', 18 | task: typing.Union[str, BaseTask] = 'classification', 19 | return_classes: bool = False, 20 | ) -> typing.Union[matchzoo.DataPack, tuple]: 21 | """ 22 | Load QuoraQP data. 23 | 24 | :param path: `None` for download from quora, specific path for 25 | downloaded data. 26 | :param stage: One of `train`, `dev`, and `test`. 27 | :param task: Could be one of `ranking`, `classification` or a 28 | :class:`matchzoo.engine.BaseTask` instance. 29 | :param return_classes: Whether return classes for classification task. 30 | :return: A DataPack if `ranking`, a tuple of (DataPack, classes) if 31 | `classification`. 32 | """ 33 | if stage not in ('train', 'dev', 'test'): 34 | raise ValueError(f"{stage} is not a valid stage." 
35 | f"Must be one of `train`, `dev`, and `test`.") 36 | 37 | data_root = _download_data() 38 | file_path = data_root.joinpath(f"{stage}.tsv") 39 | data_pack = _read_data(file_path, stage, task) 40 | 41 | if task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking): 42 | return data_pack 43 | elif task == 'classification' or isinstance( 44 | task, matchzoo.tasks.Classification): 45 | if return_classes: 46 | return data_pack, [False, True] 47 | else: 48 | return data_pack 49 | else: 50 | raise ValueError(f"{task} is not a valid task.") 51 | 52 | 53 | def _download_data(): 54 | ref_path = matchzoo.utils.get_file( 55 | 'quora_qp', _url, extract=True, 56 | cache_dir=matchzoo.USER_DATA_DIR, 57 | cache_subdir='quora_qp' 58 | ) 59 | return Path(ref_path).parent.joinpath('QQP') 60 | 61 | 62 | def _read_data(path, stage, task): 63 | data = pd.read_csv(path, sep='\t', error_bad_lines=False, dtype=object) 64 | data = data.dropna(axis=0, how='any').reset_index(drop=True) 65 | if stage in ['train', 'dev']: 66 | df = pd.DataFrame({ 67 | 'id_left': data['qid1'], 68 | 'id_right': data['qid2'], 69 | 'text_left': data['question1'], 70 | 'text_right': data['question2'], 71 | 'label': data['is_duplicate'].astype(int) 72 | }) 73 | else: 74 | df = pd.DataFrame({ 75 | 'text_left': data['question1'], 76 | 'text_right': data['question2'] 77 | }) 78 | return matchzoo.pack(df, task) 79 | -------------------------------------------------------------------------------- /matchzoo/datasets/snli/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_data 2 | -------------------------------------------------------------------------------- /matchzoo/datasets/snli/load_data.py: -------------------------------------------------------------------------------- 1 | """SNLI data loader.""" 2 | 3 | import typing 4 | from pathlib import Path 5 | 6 | import pandas as pd 7 | 8 | import matchzoo 9 | from matchzoo.engine.base_task import BaseTask 10 | 11 | _url = "https://nlp.stanford.edu/projects/snli/snli_1.0.zip" 12 | 13 | 14 | def load_data( 15 | stage: str = 'train', 16 | task: typing.Union[str, BaseTask] = 'classification', 17 | target_label: str = 'entailment', 18 | return_classes: bool = False 19 | ) -> typing.Union[matchzoo.DataPack, tuple]: 20 | """ 21 | Load SNLI data. 22 | 23 | :param stage: One of `train`, `dev`, and `test`. (default: `train`) 24 | :param task: Could be one of `ranking`, `classification` or a 25 | :class:`matchzoo.engine.BaseTask` instance. (default: `classification`) 26 | :param target_label: If `ranking`, chose one of `entailment`, 27 | `contradiction` and `neutral` as the positive label. 28 | (default: `entailment`) 29 | :param return_classes: `True` to return classes for classification task, 30 | `False` otherwise. 31 | 32 | :return: A DataPack unless `task` is `classificiation` and `return_classes` 33 | is `True`: a tuple of `(DataPack, classes)` in that case. 34 | """ 35 | if stage not in ('train', 'dev', 'test'): 36 | raise ValueError(f"{stage} is not a valid stage." 
37 | f"Must be one of `train`, `dev`, and `test`.") 38 | 39 | data_root = _download_data() 40 | file_path = data_root.joinpath(f'snli_1.0_{stage}.txt') 41 | data_pack = _read_data(file_path, task, target_label) 42 | 43 | if task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking): 44 | return data_pack 45 | elif task == 'classification' or isinstance( 46 | task, matchzoo.tasks.Classification): 47 | classes = ['entailment', 'contradiction', 'neutral'] 48 | if return_classes: 49 | return data_pack, classes 50 | else: 51 | return data_pack 52 | else: 53 | raise ValueError(f"{task} is not a valid task." 54 | f"Must be one of `Ranking` and `Classification`.") 55 | 56 | 57 | def _download_data(): 58 | ref_path = matchzoo.utils.get_file( 59 | 'snli', _url, extract=True, 60 | cache_dir=matchzoo.USER_DATA_DIR, 61 | cache_subdir='snli' 62 | ) 63 | return Path(ref_path).parent.joinpath('snli_1.0') 64 | 65 | 66 | def _read_data(path, task, target_label): 67 | table = pd.read_csv(path, sep='\t') 68 | df = pd.DataFrame({ 69 | 'text_left': table['sentence1'], 70 | 'text_right': table['sentence2'], 71 | 'label': table['gold_label'] 72 | }) 73 | df = df.dropna(axis=0, how='any').reset_index(drop=True) 74 | 75 | filter_id = df[df['label'] == '-'].index.tolist() 76 | df.drop(filter_id, inplace=True) 77 | 78 | if task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking): 79 | if target_label not in ['entailment', 'contradiction', 'neutral']: 80 | raise ValueError(f"{target_label} is not a valid target label." 81 | f"Must be one of `entailment`, `contradiction`" 82 | f" and `neutral`") 83 | df['label'] = (df['label'] == target_label) 84 | elif task == 'classification' or isinstance( 85 | task, matchzoo.tasks.Classification): 86 | classes = ['entailment', 'contradiction', 'neutral'] 87 | df['label'] = df['label'].apply(classes.index) 88 | else: 89 | raise ValueError(f"{task} is not a valid task." 90 | f"Must be one of `Ranking` and `Classification`.") 91 | 92 | return matchzoo.pack(df, task) 93 | -------------------------------------------------------------------------------- /matchzoo/datasets/toy/__init__.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | 6 | import matchzoo 7 | from matchzoo.engine.base_task import BaseTask 8 | 9 | 10 | def load_data( 11 | stage: str = 'train', 12 | task: typing.Union[str, BaseTask] = 'ranking', 13 | return_classes: bool = False 14 | ) -> typing.Union[matchzoo.DataPack, typing.Tuple[matchzoo.DataPack, list]]: 15 | """ 16 | Load toy data. 17 | 18 | :param stage: One of `train`, `dev`, and `test`. 19 | :param task: Could be one of `ranking`, `classification` or a 20 | :class:`matchzoo.engine.BaseTask` instance. 21 | :param return_classes: `True` to return classes for classification task, 22 | `False` otherwise. 23 | 24 | :return: A DataPack unless `task` is `classificiation` and `return_classes` 25 | is `True`: a tuple of `(DataPack, classes)` in that case. 26 | 27 | Example: 28 | >>> import matchzoo as mz 29 | >>> stages = 'train', 'dev', 'test' 30 | >>> tasks = 'ranking', 'classification' 31 | >>> for stage in stages: 32 | ... for task in tasks: 33 | ... _ = mz.datasets.toy.load_data(stage, task) 34 | """ 35 | if stage not in ('train', 'dev', 'test'): 36 | raise ValueError(f"{stage} is not a valid stage." 
37 | f"Must be one of `train`, `dev`, and `test`.") 38 | 39 | path = Path(__file__).parent.joinpath(f'{stage}.csv') 40 | data_pack = matchzoo.pack(pd.read_csv(path, index_col=0), task) 41 | 42 | if task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking): 43 | return data_pack 44 | elif task == 'classification' or isinstance( 45 | task, matchzoo.tasks.Classification): 46 | if return_classes: 47 | return data_pack, [False, True] 48 | else: 49 | return data_pack 50 | else: 51 | raise ValueError(f"{task} is not a valid task." 52 | f"Must be one of `Ranking` and `Classification`.") 53 | 54 | 55 | def load_embedding(): 56 | path = Path(__file__).parent.joinpath('embedding.2d.txt') 57 | return matchzoo.embedding.load_from_file(path, mode='glove') 58 | -------------------------------------------------------------------------------- /matchzoo/datasets/toy/dev.csv: -------------------------------------------------------------------------------- 1 | ,id_left,text_left,id_right,text_right,label 2 | 0,Q18,how a rocket engine works,D18-0,RS-68 being tested at NASA's Stennis Space Center.,0.0 3 | 1,Q18,how a rocket engine works,D18-1,"The nearly transparent exhaust is due to this engine's exhaust being mostly superheated steam (water vapor from its propellants, hydrogen and oxygen)",0.0 4 | 2,Q18,how a rocket engine works,D18-2,Viking 5C rocket engine,0.0 5 | 3,Q18,how a rocket engine works,D18-3,"A rocket engine, or simply ""rocket"", is a jet engine that uses only stored propellant mass for forming its high speed propulsive jet .",1.0 6 | 4,Q18,how a rocket engine works,D18-4,Rocket engines are reaction engines and obtain thrust in accordance with Newton's third law .,0.0 7 | 5,Q18,how a rocket engine works,D18-5,"Since they need no external material to form their jet, rocket engines can be used for spacecraft propulsion as well as terrestrial uses, such as missiles .",0.0 8 | 6,Q18,how a rocket engine works,D18-6,"Most rocket engines are internal combustion engines , although non-combusting forms also exist.",0.0 9 | 7,Q18,how a rocket engine works,D18-7,"Rocket engines as a group have the highest exhaust velocities, are by far the lightest, but are the least propellant efficient of all types of jet engines.",0.0 10 | 8,Q19,how old was britney spears when she recorded hit me baby one more time,D19-0,"""...Baby One More Time"" is a song by American recording artist Britney Spears .",0.0 11 | 9,Q19,how old was britney spears when she recorded hit me baby one more time,D19-1,"It served as Spears's debut single and title track from her debut studio album, ...Baby One More Time (1999).",0.0 12 | 10,Q19,how old was britney spears when she recorded hit me baby one more time,D19-2,"Written by Max Martin and produced by Martin and Rami , ""...Baby One More Time"" was released on September 30, 1998, by Jive Records .",0.0 13 | 11,Q19,how old was britney spears when she recorded hit me baby one more time,D19-3,"After recording and sending a demo tape with an unused song from Toni Braxton , Spears signed a multi-album deal with Jive.",0.0 14 | 12,Q19,how old was britney spears when she recorded hit me baby one more time,D19-4,"""...Baby One More Time"" is a teen pop and dance-pop song that refers to a girl who regrets breaking up with her boyfriend.",0.0 15 | 13,Q19,how old was britney spears when she recorded hit me baby one more time,D19-5,"The song received generally favorable reviews from critics, who praised its composition.",0.0 16 | 14,Q19,how old was britney spears when she recorded hit me baby one 
more time,D19-6,"""...Baby One More Time"" attained global success, reaching number one in every country it charted.",0.0 17 | 15,Q19,how old was britney spears when she recorded hit me baby one more time,D19-7,"It also received numerous certifications around the world, and is one of the best-selling singles of all time , with over 10 million copies sold.",0.0 18 | 16,Q19,how old was britney spears when she recorded hit me baby one more time,D19-8,"An accompanying music video, directed by Nigel Dick , portrays Spears as a student from a Catholic high school, who starts to daydream that she is singing and dancing around the school, while watching her love interest from afar.",0.0 19 | 17,Q19,how old was britney spears when she recorded hit me baby one more time,D19-9,"The music video was later referenced in the music video of "" If U Seek Amy "" (2008), where Spears's fictional daughter is dressed with a similar schoolgirl outfit while wearing pink ribbons in her hair.",0.0 20 | 18,Q19,how old was britney spears when she recorded hit me baby one more time,D19-10,"In 2010, the music video for ""...Baby One More Time"" was voted the third most influential video in the history of pop music, in a poll held by Jam! .",0.0 21 | 19,Q19,how old was britney spears when she recorded hit me baby one more time,D19-11,"Spears has performed ""...Baby One More Time"" in a number of live appearances and in all of her concert tours.",0.0 22 | -------------------------------------------------------------------------------- /matchzoo/datasets/toy/test.csv: -------------------------------------------------------------------------------- 1 | ,id_left,text_left,id_right,text_right,label 2 | 0,Q19,how old was britney spears when she recorded hit me baby one more time,D19-12,It was the encore of the ...Baby One More Time Tour (1999) and Dream Within a Dream Tour (2001); Spears also performed remixed versions of the song during the Oops!...,0.0 3 | 1,Q19,how old was britney spears when she recorded hit me baby one more time,D19-13,"I Did It Again World Tour (2000), The Onyx Hotel Tour (2004), The M+M's Tour (2007), The Circus Starring Britney Spears (2009), and the Femme Fatale Tour (2011).",0.0 4 | 2,Q19,how old was britney spears when she recorded hit me baby one more time,D19-14,"""...Baby One More Time"" was nominated for a Grammy Award for Best Female Pop Vocal Performance , and has been included in lists by Blender , Rolling Stone and VH1 .",0.0 5 | 3,Q19,how old was britney spears when she recorded hit me baby one more time,D19-15,It has been noted for redefining the sound of late 1990s music.,0.0 6 | 4,Q19,how old was britney spears when she recorded hit me baby one more time,D19-16,"Spears has named ""...Baby One More Time"" as one of her favorite songs from her career.",0.0 7 | 5,Q19,how old was britney spears when she recorded hit me baby one more time,D19-17,It was also the final song to be played on the BBC 's music programme Top of the Pops .,0.0 8 | 6,Q21,how are cholera and typhus transmitted and prevented,D21-0,Cholera is an infection in the small intestine caused by the bacterium Vibrio cholerae .,0.0 9 | 7,Q21,how are cholera and typhus transmitted and prevented,D21-1,The main symptoms are watery diarrhea and vomiting .,0.0 10 | 8,Q21,how are cholera and typhus transmitted and prevented,D21-2,"Transmission occurs primarily by drinking water or eating food that has been contaminated by the feces (waste product) of an infected person, including one with no apparent symptoms.",1.0 11 | 9,Q21,how are 
cholera and typhus transmitted and prevented,D21-3,"The severity of the diarrhea and vomiting can lead to rapid dehydration and electrolyte imbalance, and death in some cases.",0.0 12 | 10,Q21,how are cholera and typhus transmitted and prevented,D21-4,"The primary treatment is oral rehydration therapy , typically with oral rehydration solution , to replace water and electrolytes.",0.0 13 | 11,Q21,how are cholera and typhus transmitted and prevented,D21-5,"If this is not tolerated or does not provide improvement fast enough, intravenous fluids can also be used.",0.0 14 | 12,Q21,how are cholera and typhus transmitted and prevented,D21-6,Antibacterial drugs are beneficial in those with severe disease to shorten its duration and severity.,0.0 15 | 13,Q21,how are cholera and typhus transmitted and prevented,D21-7,"Worldwide, it affects 3–5 million people and causes 100,000–130,000 deaths a year .",0.0 16 | 14,Q21,how are cholera and typhus transmitted and prevented,D21-8,Cholera was one of the earliest infections to be studied by epidemiological methods.,0.0 17 | 15,Q22,how old is sybil vane in the picture of dorian gray,D22-0,"The Picture of Dorian Gray is the only published novel by Oscar Wilde , appearing as the lead story in Lippincott's Monthly Magazine on 20 June 1890, printed as the July 1890 issue of this magazine.",0.0 18 | 16,Q22,how old is sybil vane in the picture of dorian gray,D22-1,"The magazine's editors feared the story was indecent as submitted, so they censored roughly 500 words, without Wilde's knowledge, before publication.",0.0 19 | 17,Q22,how old is sybil vane in the picture of dorian gray,D22-2,"But even with that, the story was still greeted with outrage by British reviewers, some of whom suggested that Wilde should be prosecuted on moral grounds, leading Wilde to defend the novel aggressively in letters to the British press.",0.0 20 | 18,Q22,how old is sybil vane in the picture of dorian gray,D22-3,"Wilde later revised the story for book publication, making substantial alterations, deleting controversial passages, adding new chapters and including an aphoristic Preface which has since become famous in its own right.",0.0 21 | 19,Q22,how old is sybil vane in the picture of dorian gray,D22-4,"The amended version was published by Ward, Lock and Company in April 1891.",0.0 22 | -------------------------------------------------------------------------------- /matchzoo/datasets/wiki_qa/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_data 2 | -------------------------------------------------------------------------------- /matchzoo/datasets/wiki_qa/load_data.py: -------------------------------------------------------------------------------- 1 | """WikiQA data loader.""" 2 | 3 | import typing 4 | import csv 5 | from pathlib import Path 6 | 7 | import pandas as pd 8 | 9 | import matchzoo 10 | from matchzoo.engine.base_task import BaseTask 11 | 12 | _url = "https://download.microsoft.com/download/E/5/F/" \ 13 | "E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip" 14 | 15 | 16 | def load_data( 17 | stage: str = 'train', 18 | task: typing.Union[str, BaseTask] = 'ranking', 19 | filtered: bool = False, 20 | return_classes: bool = False 21 | ) -> typing.Union[matchzoo.DataPack, tuple]: 22 | """ 23 | Load WikiQA data. 24 | 25 | :param stage: One of `train`, `dev`, and `test`. 26 | :param task: Could be one of `ranking`, `classification` or a 27 | :class:`matchzoo.engine.BaseTask` instance. 
28 | :param filtered: Whether to remove questions without correct answers.
29 | :param return_classes: `True` to return classes for classification task,
30 | `False` otherwise.
31 | 
32 | :return: A DataPack unless `task` is `classification` and `return_classes`
33 | is `True`: a tuple of `(DataPack, classes)` in that case.
34 | """
35 | if stage not in ('train', 'dev', 'test'):
36 | raise ValueError(f"{stage} is not a valid stage. "
37 | f"Must be one of `train`, `dev`, and `test`.")
38 | 
39 | data_root = _download_data()
40 | file_path = data_root.joinpath(f'WikiQA-{stage}.tsv')
41 | data_pack = _read_data(file_path, task)
42 | if filtered and stage in ('dev', 'test'):
43 | ref_path = data_root.joinpath(f'WikiQA-{stage}.ref')
44 | filter_ref_path = data_root.joinpath(f'WikiQA-{stage}-filtered.ref')
45 | with open(filter_ref_path, mode='r') as f:
46 | filtered_ids = set([line.split()[0] for line in f])
47 | filtered_lines = []
48 | with open(ref_path, mode='r') as f:
49 | for idx, line in enumerate(f.readlines()):
50 | if line.split()[0] in filtered_ids:
51 | filtered_lines.append(idx)
52 | data_pack = data_pack[filtered_lines]
53 | 
54 | if task == 'ranking' or isinstance(task, matchzoo.tasks.Ranking):
55 | return data_pack
56 | elif task == 'classification' or isinstance(
57 | task, matchzoo.tasks.Classification):
58 | if return_classes:
59 | return data_pack, [False, True]
60 | else:
61 | return data_pack
62 | else:
63 | raise ValueError(f"{task} is not a valid task. "
64 | f"Must be one of `Ranking` and `Classification`.")
65 | 
66 | 
67 | def _download_data():
68 | ref_path = matchzoo.utils.get_file(
69 | 'wikiqa', _url, extract=True,
70 | cache_dir=matchzoo.USER_DATA_DIR,
71 | cache_subdir='wiki_qa'
72 | )
73 | return Path(ref_path).parent.joinpath('WikiQACorpus')
74 | 
75 | 
76 | def _read_data(path, task):
77 | table = pd.read_csv(path, sep='\t', header=0, quoting=csv.QUOTE_NONE)
78 | df = pd.DataFrame({
79 | 'text_left': table['Question'],
80 | 'text_right': table['Sentence'],
81 | 'id_left': table['QuestionID'],
82 | 'id_right': table['SentenceID'],
83 | 'label': table['Label']
84 | })
85 | return matchzoo.pack(df, task)
86 | 
-------------------------------------------------------------------------------- /matchzoo/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import Embedding
2 | from .embedding import load_from_file
3 | 
-------------------------------------------------------------------------------- /matchzoo/embedding/embedding.py: -------------------------------------------------------------------------------- 1 | """Matchzoo toolkit for token embedding."""
2 | 
3 | import csv
4 | import typing
5 | 
6 | import numpy as np
7 | import pandas as pd
8 | 
9 | import matchzoo as mz
10 | 
11 | 
12 | class Embedding(object):
13 | """
14 | Embedding class.
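
It wraps a term-to-vector mapping (``data``) together with the vector
dimensionality (``output_dim``), and can materialize an embedding matrix
for any given term index (see ``build_matrix`` below).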
15 | 
16 | Examples::
17 | >>> import matchzoo as mz
18 | >>> train_raw = mz.datasets.toy.load_data()
19 | >>> pp = mz.preprocessors.NaivePreprocessor()
20 | >>> train = pp.fit_transform(train_raw, verbose=0)
21 | >>> vocab_unit = mz.build_vocab_unit(train, verbose=0)
22 | >>> term_index = vocab_unit.state['term_index']
23 | >>> embed_path = mz.datasets.embeddings.EMBED_RANK
24 | 
25 | To load from a file:
26 | >>> embedding = mz.embedding.load_from_file(embed_path)
27 | >>> matrix = embedding.build_matrix(term_index)
28 | >>> matrix.shape[0] == len(term_index)
29 | True
30 | 
31 | To build your own:
32 | >>> data = {'A':[0, 1], 'B':[2, 3]}
33 | >>> embedding = mz.Embedding(data, 2)
34 | >>> matrix = embedding.build_matrix({'A': 2, 'B': 1, '_PAD': 0})
35 | >>> matrix.shape == (3, 2)
36 | True
37 | 
38 | """
39 | 
40 | def __init__(self, data: dict, output_dim: int):
41 | """
42 | Embedding.
43 | 
44 | :param data: Dictionary to use as term to vector mapping.
45 | :param output_dim: The dimension of embedding.
46 | """
47 | self._data = data
48 | self._output_dim = output_dim
49 | 
50 | def build_matrix(
51 | self,
52 | term_index: typing.Union[
53 | dict, mz.preprocessors.units.Vocabulary.TermIndex]
54 | ) -> np.ndarray:
55 | """
56 | Build a matrix using `term_index`.
57 | 
58 | :param term_index: A `dict` or `TermIndex` to build with.
59 | Terms that are missing from `data` are initialized by sampling
60 | from a uniform distribution over `(-0.2, 0.2)` (see the fallback
61 | branch in the loop below).
62 | :return: A matrix.
63 | """
64 | input_dim = len(term_index)
65 | matrix = np.empty((input_dim, self._output_dim))
66 | 
67 | valid_keys = self._data.keys()
68 | for term, index in term_index.items():
69 | if term in valid_keys:
70 | matrix[index] = self._data[term]
71 | else:
72 | matrix[index] = np.random.uniform(-0.2, 0.2, size=self._output_dim)
73 | 
74 | return matrix
75 | 
76 | 
77 | def load_from_file(file_path: str, mode: str = 'word2vec') -> Embedding:
78 | """
79 | Load embedding from `file_path`.
80 | 
81 | :param file_path: Path to file.
82 | :param mode: Embedding file format mode, one of 'word2vec', 'fasttext'
83 | or 'glove'. (default: 'word2vec')
84 | :return: An :class:`matchzoo.embedding.Embedding` instance.
85 | """
86 | embedding_data = {}
87 | output_dim = 0
88 | if mode == 'word2vec' or mode == 'fasttext':
89 | with open(file_path, 'r') as f:
90 | output_dim = int(f.readline().strip().split(' ')[-1])
91 | for line in f:
92 | current_line = line.rstrip().split(' ')
93 | embedding_data[current_line[0]] = current_line[1:]
94 | elif mode == 'glove':
95 | with open(file_path, 'r') as f:
96 | output_dim = len(f.readline().rstrip().split(' ')) - 1
97 | f.seek(0)
98 | for line in f:
99 | current_line = line.rstrip().split(' ')
100 | embedding_data[current_line[0]] = current_line[1:]
101 | else:
102 | raise TypeError(f"{mode} is not a supported embedding type. "
103 | f"`word2vec`, `fasttext` or `glove` expected.")
104 | return Embedding(embedding_data, output_dim)
105 | 
-------------------------------------------------------------------------------- /matchzoo/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # `engine` dependencies span across the entire project, so it's better to
2 | # leave this __init__.py empty, and use `from matchzoo.engine.package import
3 | # x` or `from matchzoo.engine import package` instead of `from matchzoo
4 | # import engine`.
5 | 
-------------------------------------------------------------------------------- /matchzoo/engine/base_callback.py: -------------------------------------------------------------------------------- 1 | """Base callback."""
2 | import abc
3 | 
4 | import numpy as np
5 | 
6 | import matchzoo as mz
7 | 
8 | 
9 | class BaseCallback(abc.ABC):
10 | """
11 | DataLoader callback base class.
12 | 
13 | To build your own callbacks, inherit `mz.engine.base_callback.BaseCallback`
14 | and override the corresponding methods.
15 | 
16 | A batch is processed in the following way:
17 | 
18 | - slice data pack based on batch index
19 | - handle `on_batch_data_pack` callbacks
20 | - unpack data pack into x, y
21 | - handle `on_batch_unpacked` callbacks
22 | - return x, y
23 | 
24 | """
25 | 
26 | def on_batch_data_pack(self, data_pack: mz.DataPack):
27 | """
28 | `on_batch_data_pack`.
29 | 
30 | :param data_pack: a sliced DataPack before unpacking.
31 | """
32 | 
33 | @abc.abstractmethod
34 | def on_batch_unpacked(self, x: dict, y: np.ndarray):
35 | """
36 | `on_batch_unpacked`.
37 | 
38 | :param x: unpacked x.
39 | :param y: unpacked y.
40 | """
41 | 
-------------------------------------------------------------------------------- /matchzoo/engine/base_metric.py: -------------------------------------------------------------------------------- 1 | """Metric base class and some related utilities."""
2 | 
3 | import abc
4 | 
5 | import numpy as np
6 | 
7 | 
8 | class BaseMetric(abc.ABC):
9 | """Metric base class."""
10 | 
11 | ALIAS = 'base_metric'
12 | 
13 | @abc.abstractmethod
14 | def __call__(self, y_true: np.array, y_pred: np.array) -> float:
15 | """
16 | Call to compute the metric.
17 | 
18 | :param y_true: An array of ground truth labels.
19 | :param y_pred: An array of predicted values.
20 | :return: Evaluation of the metric.
21 | """
22 | 
23 | @abc.abstractmethod
24 | def __repr__(self):
25 | """:return: Formatted string representation of the metric."""
26 | 
27 | def __eq__(self, other):
28 | """:return: `True` if two metrics are equal, `False` otherwise."""
29 | return (type(self) is type(other)) and (vars(self) == vars(other))
30 | 
31 | def __hash__(self):
32 | """:return: Hashing value using the metric as `str`."""
33 | return str(self).__hash__()
34 | 
35 | 
36 | class RankingMetric(BaseMetric):
37 | """Ranking metric base class."""
38 | 
39 | ALIAS = 'ranking_metric'
40 | 
41 | 
42 | class ClassificationMetric(BaseMetric):
43 | """Classification metric base class."""
44 | 
45 | ALIAS = 'classification_metric'
46 | 
47 | 
48 | def sort_and_couple(labels: np.array, scores: np.array) -> np.array:
49 | """Zip the `labels` with `scores` into a single list."""
50 | couple = list(zip(labels, scores))
51 | return np.array(sorted(couple, key=lambda x: x[1], reverse=True))
52 | 
-------------------------------------------------------------------------------- /matchzoo/engine/base_task.py: -------------------------------------------------------------------------------- 1 | """Base task."""
2 | 
3 | import typing
4 | import abc
5 | 
6 | import torch
7 | from torch import nn
8 | 
9 | from matchzoo.engine import base_metric
10 | from matchzoo.utils import parse_metric, parse_loss
11 | 
12 | 
13 | class BaseTask(abc.ABC):
14 | """Base Task, shouldn't be used directly."""
15 | 
16 | TYPE = 'base'
17 | 
18 | def __init__(self, losses=None, metrics=None):
19 | """
20 | Base task constructor.
21 | 
22 | :param losses: Losses of task.
23 | :param metrics: Metrics for evaluating.
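
Example (an illustrative sketch; it assumes the concrete subclass
:class:`matchzoo.tasks.Ranking` forwards ``losses``/``metrics`` to this
constructor and that ``'mrr'`` is a registered metric alias)::

    >>> import matchzoo as mz
    >>> task = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss())
    >>> task.metrics = ['mrr']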
24 | """ 25 | self._losses = self._convert(losses, parse_loss) 26 | self._metrics = self._convert(metrics, parse_metric) 27 | self._assure_losses() 28 | self._assure_metrics() 29 | 30 | def _convert(self, identifiers, parse): 31 | if not identifiers: 32 | identifiers = [] 33 | elif not isinstance(identifiers, list): 34 | identifiers = [identifiers] 35 | return [ 36 | parse(identifier, self.__class__.TYPE) 37 | for identifier in identifiers 38 | ] 39 | 40 | def _assure_losses(self): 41 | if not self._losses: 42 | first_available = self.list_available_losses()[0] 43 | self._losses = self._convert(first_available, parse_loss) 44 | 45 | def _assure_metrics(self): 46 | if not self._metrics: 47 | first_available = self.list_available_metrics()[0] 48 | self._metrics = self._convert(first_available, parse_metric) 49 | 50 | @property 51 | def losses(self): 52 | """:return: Losses used in the task.""" 53 | return self._losses 54 | 55 | @property 56 | def metrics(self): 57 | """:return: Metrics used in the task.""" 58 | return self._metrics 59 | 60 | @losses.setter 61 | def losses( 62 | self, 63 | new_losses: typing.Union[ 64 | typing.List[str], 65 | typing.List[nn.Module], 66 | str, 67 | nn.Module 68 | ] 69 | ): 70 | self._losses = self._convert(new_losses, parse_loss) 71 | 72 | @metrics.setter 73 | def metrics( 74 | self, 75 | new_metrics: typing.Union[ 76 | typing.List[str], 77 | typing.List[base_metric.BaseMetric], 78 | str, 79 | base_metric.BaseMetric 80 | ] 81 | ): 82 | self._metrics = self._convert(new_metrics, parse_metric) 83 | 84 | @classmethod 85 | @abc.abstractmethod 86 | def list_available_losses(cls) -> list: 87 | """:return: a list of available losses.""" 88 | 89 | @classmethod 90 | @abc.abstractmethod 91 | def list_available_metrics(cls) -> list: 92 | """:return: a list of available metrics.""" 93 | 94 | @property 95 | @abc.abstractmethod 96 | def output_shape(self) -> tuple: 97 | """:return: output shape of a single sample of the task.""" 98 | 99 | @property 100 | @abc.abstractmethod 101 | def output_dtype(self): 102 | """:return: output data type for specific task.""" 103 | -------------------------------------------------------------------------------- /matchzoo/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .rank_cross_entropy_loss import RankCrossEntropyLoss 2 | from .rank_hinge_loss import RankHingeLoss 3 | -------------------------------------------------------------------------------- /matchzoo/losses/rank_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | """The rank cross entropy loss.""" 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class RankCrossEntropyLoss(nn.Module): 8 | """Creates a criterion that measures rank cross entropy loss.""" 9 | 10 | __constants__ = ['num_neg'] 11 | 12 | def __init__(self, num_neg: int = 1): 13 | """ 14 | :class:`RankCrossEntropyLoss` constructor. 15 | 16 | :param num_neg: Number of negative instances in hinge loss. 17 | """ 18 | super().__init__() 19 | self.num_neg = num_neg 20 | 21 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor): 22 | """ 23 | Calculate rank cross entropy loss. 24 | 25 | :param y_pred: Predicted result. 26 | :param y_true: Label. 27 | :return: Rank cross loss. 
28 | """ 29 | logits = y_pred[::(self.num_neg + 1), :] 30 | labels = y_true[::(self.num_neg + 1), :] 31 | for neg_idx in range(self.num_neg): 32 | neg_logits = y_pred[(neg_idx + 1)::(self.num_neg + 1), :] 33 | neg_labels = y_true[(neg_idx + 1)::(self.num_neg + 1), :] 34 | logits = torch.cat((logits, neg_logits), dim=-1) 35 | labels = torch.cat((labels, neg_labels), dim=-1) 36 | return -torch.mean( 37 | torch.sum( 38 | labels * torch.log(F.softmax(logits, dim=-1) + torch.finfo(float).eps), 39 | dim=-1 40 | ) 41 | ) 42 | 43 | @property 44 | def num_neg(self): 45 | """`num_neg` getter.""" 46 | return self._num_neg 47 | 48 | @num_neg.setter 49 | def num_neg(self, value): 50 | """`num_neg` setter.""" 51 | self._num_neg = value 52 | -------------------------------------------------------------------------------- /matchzoo/losses/rank_hinge_loss.py: -------------------------------------------------------------------------------- 1 | """The rank hinge loss.""" 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class RankHingeLoss(nn.Module): 8 | """ 9 | Creates a criterion that measures rank hinge loss. 10 | 11 | Given inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, 12 | and a label 1D mini-batch tensor :math:`y` (containing 1 or -1). 13 | 14 | If :math:`y = 1` then it assumed the first input should be ranked 15 | higher (have a larger value) than the second input, and vice-versa 16 | for :math:`y = -1`. 17 | 18 | The loss function for each sample in the mini-batch is: 19 | 20 | .. math:: 21 | loss_{x, y} = max(0, -y * (x1 - x2) + margin) 22 | """ 23 | 24 | __constants__ = ['num_neg', 'margin', 'reduction'] 25 | 26 | def __init__(self, num_neg: int = 1, margin: float = 1., 27 | reduction: str = 'mean'): 28 | """ 29 | :class:`RankHingeLoss` constructor. 30 | 31 | :param num_neg: Number of negative instances in hinge loss. 32 | :param margin: Margin between positive and negative scores. 33 | Float. Has a default value of :math:`0`. 34 | :param reduction: String. Specifies the reduction to apply to 35 | the output: ``'none'`` | ``'mean'`` | ``'sum'``. 36 | ``'none'``: no reduction will be applied, 37 | ``'mean'``: the sum of the output will be divided by the 38 | number of elements in the output, 39 | ``'sum'``: the output will be summed. 40 | """ 41 | super().__init__() 42 | self.num_neg = num_neg 43 | self.margin = margin 44 | self.reduction = reduction 45 | 46 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor): 47 | """ 48 | Calculate rank hinge loss. 49 | 50 | :param y_pred: Predicted result. 51 | :param y_true: Label. 52 | :return: Hinge loss computed by user-defined margin. 
53 | """ 54 | y_pos = y_pred[::(self.num_neg + 1), :] 55 | y_neg = [] 56 | for neg_idx in range(self.num_neg): 57 | neg = y_pred[(neg_idx + 1)::(self.num_neg + 1), :] 58 | y_neg.append(neg) 59 | y_neg = torch.cat(y_neg, dim=-1) 60 | y_neg = torch.mean(y_neg, dim=-1, keepdim=True) 61 | y_true = torch.ones_like(y_pos) 62 | return F.margin_ranking_loss( 63 | y_pos, y_neg, y_true, 64 | margin=self.margin, 65 | reduction=self.reduction 66 | ) 67 | 68 | @property 69 | def num_neg(self): 70 | """`num_neg` getter.""" 71 | return self._num_neg 72 | 73 | @num_neg.setter 74 | def num_neg(self, value): 75 | """`num_neg` setter.""" 76 | self._num_neg = value 77 | 78 | @property 79 | def margin(self): 80 | """`margin` getter.""" 81 | return self._margin 82 | 83 | @margin.setter 84 | def margin(self, value): 85 | """`margin` setter.""" 86 | self._margin = value 87 | -------------------------------------------------------------------------------- /matchzoo/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .precision import Precision 2 | from .average_precision import AveragePrecision 3 | from .discounted_cumulative_gain import DiscountedCumulativeGain 4 | from .mean_reciprocal_rank import MeanReciprocalRank 5 | from .mean_average_precision import MeanAveragePrecision 6 | from .normalized_discounted_cumulative_gain import \ 7 | NormalizedDiscountedCumulativeGain 8 | 9 | from .accuracy import Accuracy 10 | from .cross_entropy import CrossEntropy 11 | 12 | 13 | def list_available() -> list: 14 | from matchzoo.engine.base_metric import BaseMetric 15 | from matchzoo.utils import list_recursive_concrete_subclasses 16 | return list_recursive_concrete_subclasses(BaseMetric) 17 | -------------------------------------------------------------------------------- /matchzoo/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | """Accuracy metric for Classification.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import ClassificationMetric 5 | 6 | 7 | class Accuracy(ClassificationMetric): 8 | """Accuracy metric.""" 9 | 10 | ALIAS = ['accuracy', 'acc'] 11 | 12 | def __init__(self): 13 | """:class:`Accuracy` constructor.""" 14 | 15 | def __repr__(self) -> str: 16 | """:return: Formated string representation of the metric.""" 17 | return f"{self.ALIAS[0]}" 18 | 19 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 20 | """ 21 | Calculate accuracy. 22 | 23 | Example: 24 | >>> import numpy as np 25 | >>> y_true = np.array([1]) 26 | >>> y_pred = np.array([[0, 1]]) 27 | >>> Accuracy()(y_true, y_pred) 28 | 1.0 29 | 30 | :param y_true: The ground true label of each document. 31 | :param y_pred: The predicted scores of each document. 32 | :return: Accuracy. 33 | """ 34 | y_pred = np.argmax(y_pred, axis=1) 35 | return np.sum(y_pred == y_true) / float(y_true.size) 36 | -------------------------------------------------------------------------------- /matchzoo/metrics/average_precision.py: -------------------------------------------------------------------------------- 1 | """Average precision metric for ranking.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import RankingMetric 5 | from . import Precision 6 | 7 | 8 | class AveragePrecision(RankingMetric): 9 | """Average precision metric.""" 10 | 11 | ALIAS = ['average_precision', 'ap'] 12 | 13 | def __init__(self, threshold: float = 0.): 14 | """ 15 | :class:`AveragePrecision` constructor. 
16 | 
17 | :param threshold: The label threshold of relevance degree.
18 | """
19 | self._threshold = threshold
20 | 
21 | def __repr__(self) -> str:
22 | """:return: Formatted string representation of the metric."""
23 | return f"{self.ALIAS[0]}({self._threshold})"
24 | 
25 | def __call__(self, y_true: np.array, y_pred: np.array) -> float:
26 | """
27 | Calculate average precision (area under PR curve).
28 | 
29 | Example:
30 | >>> y_true = [0, 1]
31 | >>> y_pred = [0.1, 0.6]
32 | >>> round(AveragePrecision()(y_true, y_pred), 2)
33 | 0.75
34 | >>> round(AveragePrecision()([], []), 2)
35 | 0.0
36 | 
37 | :param y_true: The ground true label of each document.
38 | :param y_pred: The predicted scores of each document.
39 | :return: Average precision.
40 | """
41 | precision_metrics = [Precision(k + 1) for k in range(len(y_pred))]
42 | out = [metric(y_true, y_pred) for metric in precision_metrics]
43 | if not out:
44 | return 0.
45 | return np.mean(out).item()
46 | 
-------------------------------------------------------------------------------- /matchzoo/metrics/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """CrossEntropy metric for Classification."""
2 | import numpy as np
3 | 
4 | from matchzoo.engine.base_metric import ClassificationMetric
5 | from matchzoo.utils import one_hot
6 | 
7 | 
8 | class CrossEntropy(ClassificationMetric):
9 | """Cross entropy metric."""
10 | 
11 | ALIAS = ['cross_entropy', 'ce']
12 | 
13 | def __init__(self):
14 | """:class:`CrossEntropy` constructor."""
15 | 
16 | def __repr__(self) -> str:
17 | """:return: Formatted string representation of the metric."""
18 | return f"{self.ALIAS[0]}"
19 | 
20 | def __call__(
21 | self,
22 | y_true: np.array,
23 | y_pred: np.array,
24 | eps: float = 1e-12
25 | ) -> float:
26 | """
27 | Calculate cross entropy.
28 | 
29 | Example:
30 | >>> y_true = [0, 1]
31 | >>> y_pred = [[0.25, 0.25], [0.01, 0.90]]
32 | >>> CrossEntropy()(y_true, y_pred)
33 | 0.7458274358333028
34 | 
35 | :param y_true: The ground true label of each document.
36 | :param y_pred: The predicted scores of each document.
37 | :param eps: The Log loss is undefined for p=0 or p=1,
38 | so probabilities are clipped to max(eps, min(1 - eps, p)).
39 | :return: Cross entropy.
40 | """
41 | y_pred = np.clip(y_pred, eps, 1. - eps)
42 | y_true = [
43 | one_hot(y, num_classes=y_pred.shape[1]) for y in y_true
44 | ]
45 | return -np.sum(y_true * np.log(y_pred + 1e-9)) / y_pred.shape[0]
46 | 
-------------------------------------------------------------------------------- /matchzoo/metrics/discounted_cumulative_gain.py: -------------------------------------------------------------------------------- 1 | """Discounted cumulative gain metric for ranking."""
2 | import math
3 | 
4 | import numpy as np
5 | 
6 | from matchzoo.engine.base_metric import (
7 | BaseMetric, sort_and_couple, RankingMetric
8 | )
9 | 
10 | 
11 | class DiscountedCumulativeGain(RankingMetric):
12 | """Discounted cumulative gain metric."""
13 | 
14 | ALIAS = ['discounted_cumulative_gain', 'dcg']
15 | 
16 | def __init__(self, k: int = 1, threshold: float = 0.):
17 | """
18 | :class:`DiscountedCumulativeGain` constructor.
19 | 
20 | :param k: Number of results to consider.
21 | :param threshold: The label threshold of relevance degree.
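
Note (documenting the implementation in ``__call__`` below): with the
candidates sorted by descending score, only labels above ``threshold``
contribute gain, and

.. math::
    DCG@k = \sum_{i=0}^{k-1} \frac{2^{label_i} - 1}{\ln(i + 2)}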
22 | """ 23 | self._k = k 24 | self._threshold = threshold 25 | 26 | def __repr__(self) -> str: 27 | """:return: Formated string representation of the metric.""" 28 | return f"{self.ALIAS[0]}@{self._k}({self._threshold})" 29 | 30 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 31 | """ 32 | Calculate discounted cumulative gain (dcg). 33 | 34 | Relevance is positive real values or binary values. 35 | 36 | Example: 37 | >>> y_true = [0, 1, 2, 0] 38 | >>> y_pred = [0.4, 0.2, 0.5, 0.7] 39 | >>> DiscountedCumulativeGain(1)(y_true, y_pred) 40 | 0.0 41 | >>> round(DiscountedCumulativeGain(k=-1)(y_true, y_pred), 2) 42 | 0.0 43 | >>> round(DiscountedCumulativeGain(k=2)(y_true, y_pred), 2) 44 | 2.73 45 | >>> round(DiscountedCumulativeGain(k=3)(y_true, y_pred), 2) 46 | 2.73 47 | >>> type(DiscountedCumulativeGain(k=1)(y_true, y_pred)) 48 | 49 | 50 | :param y_true: The ground true label of each document. 51 | :param y_pred: The predicted scores of each document. 52 | 53 | :return: Discounted cumulative gain. 54 | """ 55 | if self._k <= 0: 56 | return 0. 57 | coupled_pair = sort_and_couple(y_true, y_pred) 58 | result = 0. 59 | for i, (label, score) in enumerate(coupled_pair): 60 | if i >= self._k: 61 | break 62 | if label > self._threshold: 63 | result += (math.pow(2., label) - 1.) / math.log(2. + i) 64 | return result 65 | -------------------------------------------------------------------------------- /matchzoo/metrics/mean_average_precision.py: -------------------------------------------------------------------------------- 1 | """Mean average precision metric for ranking.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import ( 5 | BaseMetric, sort_and_couple, RankingMetric 6 | ) 7 | 8 | 9 | class MeanAveragePrecision(RankingMetric): 10 | """Mean average precision metric.""" 11 | 12 | ALIAS = ['mean_average_precision', 'map'] 13 | 14 | def __init__(self, threshold: float = 0.): 15 | """ 16 | :class:`MeanAveragePrecision` constructor. 17 | 18 | :param threshold: The threshold of relevance degree. 19 | """ 20 | self._threshold = threshold 21 | 22 | def __repr__(self): 23 | """:return: Formated string representation of the metric.""" 24 | return f"{self.ALIAS[0]}({self._threshold})" 25 | 26 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 27 | """ 28 | Calculate mean average precision. 29 | 30 | Example: 31 | >>> y_true = [0, 1, 0, 0] 32 | >>> y_pred = [0.1, 0.6, 0.2, 0.3] 33 | >>> MeanAveragePrecision()(y_true, y_pred) 34 | 1.0 35 | 36 | :param y_true: The ground true label of each document. 37 | :param y_pred: The predicted scores of each document. 38 | :return: Mean average precision. 39 | """ 40 | result = 0. 41 | pos = 0 42 | coupled_pair = sort_and_couple(y_true, y_pred) 43 | for idx, (label, score) in enumerate(coupled_pair): 44 | if label > self._threshold: 45 | pos += 1. 46 | result += pos / (idx + 1.) 47 | if pos == 0: 48 | return 0. 
49 | else:
50 | return result / pos
51 | 
-------------------------------------------------------------------------------- /matchzoo/metrics/mean_reciprocal_rank.py: -------------------------------------------------------------------------------- 1 | """Mean reciprocal ranking metric."""
2 | import numpy as np
3 | 
4 | from matchzoo.engine.base_metric import (
5 | BaseMetric, sort_and_couple, RankingMetric
6 | )
7 | 
8 | 
9 | class MeanReciprocalRank(RankingMetric):
10 | """Mean reciprocal rank metric."""
11 | 
12 | ALIAS = ['mean_reciprocal_rank', 'mrr']
13 | 
14 | def __init__(self, threshold: float = 0.):
15 | """
16 | :class:`MeanReciprocalRank` constructor.
17 | 
18 | :param threshold: The label threshold of relevance degree.
19 | """
20 | self._threshold = threshold
21 | 
22 | def __repr__(self) -> str:
23 | """:return: Formatted string representation of the metric."""
24 | return f'{self.ALIAS[0]}({self._threshold})'
25 | 
26 | def __call__(self, y_true: np.array, y_pred: np.array) -> float:
27 | """
28 | Calculate reciprocal of the rank of the first relevant item.
29 | 
30 | Example:
31 | >>> import numpy as np
32 | >>> y_pred = np.asarray([0.2, 0.3, 0.7, 1.0])
33 | >>> y_true = np.asarray([1, 0, 0, 0])
34 | >>> MeanReciprocalRank()(y_true, y_pred)
35 | 0.25
36 | 
37 | :param y_true: The ground true label of each document.
38 | :param y_pred: The predicted scores of each document.
39 | :return: Mean reciprocal rank.
40 | """
41 | coupled_pair = sort_and_couple(y_true, y_pred)
42 | for idx, (label, pred) in enumerate(coupled_pair):
43 | if label > self._threshold:
44 | return 1. / (idx + 1)
45 | return 0.
46 | 
-------------------------------------------------------------------------------- /matchzoo/metrics/normalized_discounted_cumulative_gain.py: -------------------------------------------------------------------------------- 1 | """Normalized discounted cumulative gain metric for ranking."""
2 | import numpy as np
3 | 
4 | from matchzoo.engine.base_metric import (
5 | BaseMetric, sort_and_couple, RankingMetric
6 | )
7 | from .discounted_cumulative_gain import DiscountedCumulativeGain
8 | 
9 | 
10 | class NormalizedDiscountedCumulativeGain(RankingMetric):
11 | """Normalized discounted cumulative gain metric."""
12 | 
13 | ALIAS = ['normalized_discounted_cumulative_gain', 'ndcg']
14 | 
15 | def __init__(self, k: int = 1, threshold: float = 0.):
16 | """
17 | :class:`NormalizedDiscountedCumulativeGain` constructor.
18 | 
19 | :param k: Number of results to consider.
20 | :param threshold: The label threshold of relevance degree.
21 | """
22 | self._k = k
23 | self._threshold = threshold
24 | 
25 | def __repr__(self) -> str:
26 | """:return: Formatted string representation of the metric."""
27 | return f"{self.ALIAS[0]}@{self._k}({self._threshold})"
28 | 
29 | def __call__(self, y_true: np.array, y_pred: np.array) -> float:
30 | """
31 | Calculate normalized discounted cumulative gain (ndcg).
32 | 
33 | Relevance is positive real values or binary values.
34 | 
35 | Example:
36 | >>> y_true = [0, 1, 2, 0]
37 | >>> y_pred = [0.4, 0.2, 0.5, 0.7]
38 | >>> ndcg = NormalizedDiscountedCumulativeGain
39 | >>> ndcg(k=1)(y_true, y_pred)
40 | 0.0
41 | >>> round(ndcg(k=2)(y_true, y_pred), 2)
42 | 0.52
43 | >>> round(ndcg(k=3)(y_true, y_pred), 2)
44 | 0.52
45 | >>> type(ndcg()(y_true, y_pred))
46 | <class 'float'>
47 | 
48 | :param y_true: The ground true label of each document.
49 | :param y_pred: The predicted scores of each document.
50 | 
51 | :return: Normalized discounted cumulative gain.
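
The prediction's DCG is normalized by the ideal DCG, obtained by
ranking the documents with the true labels themselves:

.. math::
    NDCG@k = \frac{DCG@k(\text{predicted ranking})}{DCG@k(\text{ideal ranking})}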
52 | """ 53 | dcg_metric = DiscountedCumulativeGain(k=self._k, 54 | threshold=self._threshold) 55 | idcg_val = dcg_metric(y_true, y_true) 56 | dcg_val = dcg_metric(y_true, y_pred) 57 | return dcg_val / idcg_val if idcg_val != 0 else 0 58 | -------------------------------------------------------------------------------- /matchzoo/metrics/precision.py: -------------------------------------------------------------------------------- 1 | """Precision for ranking.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import ( 5 | BaseMetric, sort_and_couple, RankingMetric 6 | ) 7 | 8 | 9 | class Precision(RankingMetric): 10 | """Precision metric.""" 11 | 12 | ALIAS = 'precision' 13 | 14 | def __init__(self, k: int = 1, threshold: float = 0.): 15 | """ 16 | :class:`PrecisionMetric` constructor. 17 | 18 | :param k: Number of results to consider. 19 | :param threshold: the label threshold of relevance degree. 20 | """ 21 | self._k = k 22 | self._threshold = threshold 23 | 24 | def __repr__(self) -> str: 25 | """:return: Formated string representation of the metric.""" 26 | return f"{self.ALIAS}@{self._k}({self._threshold})" 27 | 28 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 29 | """ 30 | Calculate precision@k. 31 | 32 | Example: 33 | >>> y_true = [0, 0, 0, 1] 34 | >>> y_pred = [0.2, 0.4, 0.3, 0.1] 35 | >>> Precision(k=1)(y_true, y_pred) 36 | 0.0 37 | >>> Precision(k=2)(y_true, y_pred) 38 | 0.0 39 | >>> Precision(k=4)(y_true, y_pred) 40 | 0.25 41 | >>> Precision(k=5)(y_true, y_pred) 42 | 0.2 43 | 44 | :param y_true: The ground true label of each document. 45 | :param y_pred: The predicted scores of each document. 46 | :return: Precision @ k 47 | :raises: ValueError: len(r) must be >= k. 48 | """ 49 | if self._k <= 0: 50 | raise ValueError(f"k must be greater than 0." 51 | f"{self._k} received.") 52 | coupled_pair = sort_and_couple(y_true, y_pred) 53 | precision = 0.0 54 | for idx, (label, score) in enumerate(coupled_pair): 55 | if idx >= self._k: 56 | break 57 | if label > self._threshold: 58 | precision += 1. 
59 | return precision / self._k 60 | -------------------------------------------------------------------------------- /matchzoo/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense_baseline import DenseBaseline 2 | from .dssm import DSSM 3 | from .cdssm import CDSSM 4 | from .drmm import DRMM 5 | from .drmmtks import DRMMTKS 6 | from .esim import ESIM 7 | from .knrm import KNRM 8 | from .conv_knrm import ConvKNRM 9 | from .bimpm import BiMPM 10 | from .matchlstm import MatchLSTM 11 | from .arci import ArcI 12 | from .arcii import ArcII 13 | from .bert import Bert 14 | from .mvlstm import MVLSTM 15 | from .match_pyramid import MatchPyramid 16 | from .anmm import aNMM 17 | from .hbmp import HBMP 18 | from .duet import DUET 19 | from .diin import DIIN 20 | from .match_srnn import MatchSRNN 21 | 22 | 23 | def list_available() -> list: 24 | from matchzoo.engine.base_model import BaseModel 25 | from matchzoo.utils import list_recursive_concrete_subclasses 26 | return list_recursive_concrete_subclasses(BaseModel) 27 | -------------------------------------------------------------------------------- /matchzoo/models/bert.py: -------------------------------------------------------------------------------- 1 | """An implementation of Bert Model.""" 2 | import typing 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_transformers import BertModel 7 | 8 | from matchzoo import preprocessors 9 | from matchzoo.engine.param_table import ParamTable 10 | from matchzoo.engine.param import Param 11 | from matchzoo.engine.base_model import BaseModel 12 | from matchzoo.engine.base_preprocessor import BasePreprocessor 13 | from matchzoo.engine import hyper_spaces 14 | from matchzoo.dataloader import callbacks 15 | from matchzoo.modules import BertModule 16 | 17 | 18 | class Bert(BaseModel): 19 | """Bert Model.""" 20 | 21 | @classmethod 22 | def get_default_params(cls) -> ParamTable: 23 | """:return: model default parameters.""" 24 | params = super().get_default_params() 25 | params.add(Param(name='mode', value='bert-base-uncased', 26 | desc="Pretrained Bert model.")) 27 | params.add(Param( 28 | 'dropout_rate', 0.0, 29 | hyper_space=hyper_spaces.quniform( 30 | low=0.0, high=0.8, q=0.01), 31 | desc="The dropout rate." 
32 | )) 33 | return params 34 | 35 | @classmethod 36 | def get_default_preprocessor( 37 | cls, 38 | mode: str = 'bert-base-uncased' 39 | ) -> BasePreprocessor: 40 | """:return: Default preprocessor.""" 41 | return preprocessors.BertPreprocessor(mode=mode) 42 | 43 | @classmethod 44 | def get_default_padding_callback( 45 | cls, 46 | fixed_length_left: int = None, 47 | fixed_length_right: int = None, 48 | pad_value: typing.Union[int, str] = 0, 49 | pad_mode: str = 'pre' 50 | ): 51 | """:return: Default padding callback.""" 52 | return callbacks.BertPadding( 53 | fixed_length_left=fixed_length_left, 54 | fixed_length_right=fixed_length_right, 55 | pad_value=pad_value, 56 | pad_mode=pad_mode) 57 | 58 | def build(self): 59 | """Build model structure.""" 60 | self.bert = BertModule(mode=self._params['mode']) 61 | self.dropout = nn.Dropout(p=self._params['dropout_rate']) 62 | if 'base' in self._params['mode']: 63 | dim = 768 64 | elif 'large' in self._params['mode']: 65 | dim = 1024 66 | self.out = self._make_output_layer(dim) 67 | 68 | def forward(self, inputs): 69 | """Forward.""" 70 | 71 | input_left, input_right = inputs['text_left'], inputs['text_right'] 72 | 73 | bert_output = self.bert(input_left, input_right)[1] 74 | 75 | out = self.out(self.dropout(bert_output)) 76 | 77 | return out 78 | -------------------------------------------------------------------------------- /matchzoo/models/dense_baseline.py: -------------------------------------------------------------------------------- 1 | """A simple densely connected baseline model.""" 2 | import typing 3 | 4 | import torch 5 | 6 | from matchzoo.engine.base_model import BaseModel 7 | from matchzoo.engine.param_table import ParamTable 8 | from matchzoo.engine import hyper_spaces 9 | 10 | 11 | class DenseBaseline(BaseModel): 12 | """ 13 | A simple densely connected baseline model. 
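
    Each text is embedded and summed into a single vector; the two vectors
    are concatenated and scored by a multi-layer perceptron (see ``build``
    and ``forward`` below).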
14 | 
15 | Examples:
16 | >>> model = DenseBaseline()
17 | >>> model.params['mlp_num_layers'] = 2
18 | >>> model.params['mlp_num_units'] = 300
19 | >>> model.params['mlp_num_fan_out'] = 128
20 | >>> model.params['mlp_activation_func'] = 'relu'
21 | >>> model.guess_and_fill_missing_params(verbose=0)
22 | >>> model.build()
23 | 
24 | """
25 | 
26 | @classmethod
27 | def get_default_params(cls) -> ParamTable:
28 | """:return: model default parameters."""
29 | params = super().get_default_params(
30 | with_embedding=True,
31 | with_multi_layer_perceptron=True
32 | )
33 | params['mlp_num_units'] = 256
34 | params.get('mlp_num_units').hyper_space = \
35 | hyper_spaces.quniform(16, 512)
36 | params.get('mlp_num_layers').hyper_space = \
37 | hyper_spaces.quniform(1, 5)
38 | return params
39 | 
40 | def build(self):
41 | """Build."""
42 | self.embedding = self._make_default_embedding_layer()
43 | self.mlp = self._make_multi_layer_perceptron_layer(
44 | 2 * self._params['embedding_output_dim']
45 | )
46 | self.out = self._make_output_layer(
47 | self._params['mlp_num_fan_out']
48 | )
49 | 
50 | def forward(self, inputs):
51 | """Forward."""
52 | input_left, input_right = inputs['text_left'], inputs['text_right']
53 | input_left = self.embedding(input_left.long()).sum(1)
54 | input_right = self.embedding(input_right.long()).sum(1)
55 | x = torch.cat((input_left, input_right), dim=1)
56 | return self.out(self.mlp(x))
57 | 
-------------------------------------------------------------------------------- /matchzoo/models/drmm.py: -------------------------------------------------------------------------------- 1 | """An implementation of DRMM Model."""
2 | import typing
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from matchzoo.engine.param_table import ParamTable
8 | from matchzoo.engine.param import Param
9 | from matchzoo.engine.base_model import BaseModel
10 | from matchzoo.dataloader import callbacks
11 | from matchzoo.modules import Attention
12 | 
13 | 
14 | class DRMM(BaseModel):
15 | """
16 | DRMM Model.
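
    It applies an MLP to the query-document matching histograms and
    aggregates the per-term outputs with a term-gating attention over the
    query embeddings (see ``forward`` below).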
17 | 18 | Examples: 19 | >>> model = DRMM() 20 | >>> model.params['mlp_num_layers'] = 1 21 | >>> model.params['mlp_num_units'] = 5 22 | >>> model.params['mlp_num_fan_out'] = 1 23 | >>> model.params['mlp_activation_func'] = 'tanh' 24 | >>> model.guess_and_fill_missing_params(verbose=0) 25 | >>> model.build() 26 | 27 | """ 28 | 29 | @classmethod 30 | def get_default_params(cls) -> ParamTable: 31 | """:return: model default parameters.""" 32 | params = super().get_default_params( 33 | with_embedding=True, 34 | with_multi_layer_perceptron=True 35 | ) 36 | params.add(Param(name='mask_value', value=0, 37 | desc="The value to be masked from inputs.")) 38 | params.add(Param(name='hist_bin_size', value=30, 39 | desc="The number of bin size of the histogram.")) 40 | params['mlp_num_fan_out'] = 1 41 | return params 42 | 43 | @classmethod 44 | def get_default_padding_callback( 45 | cls, 46 | fixed_length_left: int = None, 47 | fixed_length_right: int = None, 48 | pad_value: typing.Union[int, str] = 0, 49 | pad_mode: str = 'pre' 50 | ): 51 | """:return: Default padding callback.""" 52 | return callbacks.DRMMPadding( 53 | fixed_length_left=fixed_length_left, 54 | fixed_length_right=fixed_length_right, 55 | pad_value=pad_value, 56 | pad_mode=pad_mode 57 | ) 58 | 59 | def build(self): 60 | """Build model structure.""" 61 | self.embedding = self._make_default_embedding_layer() 62 | self.attention = Attention( 63 | input_size=self._params['embedding_output_dim'] 64 | ) 65 | self.mlp = self._make_multi_layer_perceptron_layer( 66 | self._params['hist_bin_size'] 67 | ) 68 | self.out = self._make_output_layer(1) 69 | 70 | def forward(self, inputs): 71 | """Forward.""" 72 | 73 | # Scalar dimensions referenced here: 74 | # B = batch size (number of sequences) 75 | # D = embedding size 76 | # L = `input_left` sequence length 77 | # R = `input_right` sequence length 78 | # H = histogram size 79 | # K = size of top-k 80 | 81 | # Left input and right input. 82 | # query: shape = [B, L] 83 | # doc: shape = [B, L, H] 84 | # Note here, the doc is the matching histogram between original query 85 | # and original document. 86 | 87 | query, match_hist = inputs['text_left'], inputs['match_histogram'] 88 | 89 | # shape = [B, L] 90 | mask_query = (query == self._params['mask_value']) 91 | 92 | # Process left input. 93 | # shape = [B, L, D] 94 | embed_query = self.embedding(query.long()) 95 | 96 | # shape = [B, L] 97 | attention_probs = self.attention(embed_query, mask_query) 98 | 99 | # shape = [B, L] 100 | dense_output = self.mlp(match_hist).squeeze(dim=-1) 101 | 102 | x = torch.einsum('bl,bl->b', dense_output, attention_probs) 103 | 104 | out = self.out(x.unsqueeze(dim=-1)) 105 | return out 106 | -------------------------------------------------------------------------------- /matchzoo/models/dssm.py: -------------------------------------------------------------------------------- 1 | """An implementation of DSSM, Deep Structured Semantic Model.""" 2 | import typing 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from matchzoo import preprocessors 8 | from matchzoo.engine.param_table import ParamTable 9 | from matchzoo.engine.param import Param 10 | from matchzoo.engine.base_model import BaseModel 11 | from matchzoo.engine.base_preprocessor import BasePreprocessor 12 | 13 | 14 | class DSSM(BaseModel): 15 | """ 16 | Deep structured semantic model. 
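
    Two multi-layer perceptrons map the letter-n-gram representations of
    the two texts into a shared semantic space, where they are compared by
    cosine similarity (see ``forward`` below).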
17 | 
18 | Examples:
19 | >>> model = DSSM()
20 | >>> model.params['mlp_num_layers'] = 3
21 | >>> model.params['mlp_num_units'] = 300
22 | >>> model.params['mlp_num_fan_out'] = 128
23 | >>> model.params['mlp_activation_func'] = 'relu'
24 | >>> model.guess_and_fill_missing_params(verbose=0)
25 | >>> model.build()
26 | 
27 | """
28 | 
29 | @classmethod
30 | def get_default_params(cls) -> ParamTable:
31 | """:return: model default parameters."""
32 | params = super().get_default_params(with_multi_layer_perceptron=True)
33 | params.add(Param(name='vocab_size', value=419,
34 | desc="Size of vocabulary."))
35 | return params
36 | 
37 | @classmethod
38 | def get_default_preprocessor(
39 | cls,
40 | truncated_mode: str = 'pre',
41 | truncated_length_left: typing.Optional[int] = None,
42 | truncated_length_right: typing.Optional[int] = None,
43 | filter_mode: str = 'df',
44 | filter_low_freq: float = 1,
45 | filter_high_freq: float = float('inf'),
46 | remove_stop_words: bool = False,
47 | ngram_size: typing.Optional[int] = 3,
48 | ) -> BasePreprocessor:
49 | """
50 | Model default preprocessor.
51 | 
52 | The preprocessor's transform should produce a correctly shaped data
53 | pack that can be used for training.
54 | 
55 | :return: Default preprocessor.
56 | """
57 | return preprocessors.BasicPreprocessor(
58 | truncated_mode=truncated_mode,
59 | truncated_length_left=truncated_length_left,
60 | truncated_length_right=truncated_length_right,
61 | filter_mode=filter_mode,
62 | filter_low_freq=filter_low_freq,
63 | filter_high_freq=filter_high_freq,
64 | remove_stop_words=remove_stop_words,
65 | ngram_size=ngram_size
66 | )
67 | 
68 | @classmethod
69 | def get_default_padding_callback(cls):
70 | """:return: Default padding callback."""
71 | return None
72 | 
73 | def build(self):
74 | """
75 | Build model structure.
76 | 
77 | DSSM uses a Siamese architecture.
78 | """
79 | self.mlp_left = self._make_multi_layer_perceptron_layer(
80 | self._params['vocab_size']
81 | )
82 | self.mlp_right = self._make_multi_layer_perceptron_layer(
83 | self._params['vocab_size']
84 | )
85 | self.out = self._make_output_layer(1)
86 | 
87 | def forward(self, inputs):
88 | """Forward."""
89 | # Process left & right input.
90 | input_left, input_right = inputs['ngram_left'], inputs['ngram_right']
91 | input_left = self.mlp_left(input_left)
92 | input_right = self.mlp_right(input_right)
93 | 
94 | # Compare the two representations with cosine similarity.
95 | x = F.cosine_similarity(input_left, input_right)
96 | 
97 | out = self.out(x.unsqueeze(dim=1))
98 | return out
99 | 
-------------------------------------------------------------------------------- /matchzoo/models/knrm.py: -------------------------------------------------------------------------------- 1 | """An implementation of KNRM Model."""
2 | import typing
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | from matchzoo.engine.param_table import ParamTable
9 | from matchzoo.engine.param import Param
10 | from matchzoo.engine.base_model import BaseModel
11 | from matchzoo.engine import hyper_spaces
12 | from matchzoo.modules import GaussianKernel
13 | 
14 | 
15 | class KNRM(BaseModel):
16 | """
17 | KNRM Model.
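
    It builds a word-word cosine matching matrix between the two texts and
    pools it with a bank of Gaussian (RBF) kernels into soft match counts,
    which are then combined by the output layer (see ``forward`` below).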
18 | 19 | Examples: 20 | >>> model = KNRM() 21 | >>> model.params['kernel_num'] = 11 22 | >>> model.params['sigma'] = 0.1 23 | >>> model.params['exact_sigma'] = 0.001 24 | >>> model.guess_and_fill_missing_params(verbose=0) 25 | >>> model.build() 26 | 27 | """ 28 | 29 | @classmethod 30 | def get_default_params(cls) -> ParamTable: 31 | """:return: model default parameters.""" 32 | params = super().get_default_params(with_embedding=True) 33 | params.add(Param( 34 | name='kernel_num', 35 | value=11, 36 | hyper_space=hyper_spaces.quniform(low=5, high=20), 37 | desc="The number of RBF kernels." 38 | )) 39 | params.add(Param( 40 | name='sigma', 41 | value=0.1, 42 | hyper_space=hyper_spaces.quniform( 43 | low=0.01, high=0.2, q=0.01), 44 | desc="The `sigma` defines the kernel width." 45 | )) 46 | params.add(Param( 47 | name='exact_sigma', value=0.001, 48 | desc="The `exact_sigma` denotes the `sigma` " 49 | "for exact match." 50 | )) 51 | return params 52 | 53 | def build(self): 54 | """Build model structure.""" 55 | self.embedding = self._make_default_embedding_layer() 56 | 57 | self.kernels = nn.ModuleList() 58 | for i in range(self._params['kernel_num']): 59 | mu = 1. / (self._params['kernel_num'] - 1) + (2. * i) / ( 60 | self._params['kernel_num'] - 1) - 1.0 61 | sigma = self._params['sigma'] 62 | if mu > 1.0: 63 | sigma = self._params['exact_sigma'] 64 | mu = 1.0 65 | self.kernels.append(GaussianKernel(mu=mu, sigma=sigma)) 66 | 67 | self.out = self._make_output_layer(self._params['kernel_num']) 68 | 69 | def forward(self, inputs): 70 | """Forward.""" 71 | 72 | # Scalar dimensions referenced here: 73 | # B = batch size (number of sequences) 74 | # D = embedding size 75 | # L = `input_left` sequence length 76 | # R = `input_right` sequence length 77 | # K = number of kernels 78 | 79 | # Left input and right input. 80 | # shape = [B, L] 81 | # shape = [B, R] 82 | query, doc = inputs['text_left'], inputs['text_right'] 83 | 84 | # Process left input. 85 | # shape = [B, L, D] 86 | embed_query = self.embedding(query.long()) 87 | # shape = [B, R, D] 88 | embed_doc = self.embedding(doc.long()) 89 | 90 | # shape = [B, L, R] 91 | matching_matrix = torch.einsum( 92 | 'bld,brd->blr', 93 | F.normalize(embed_query, p=2, dim=-1), 94 | F.normalize(embed_doc, p=2, dim=-1) 95 | ) 96 | 97 | KM = [] 98 | for kernel in self.kernels: 99 | # shape = [B] 100 | K = torch.log1p(kernel(matching_matrix).sum(dim=-1)).sum(dim=-1) 101 | KM.append(K) 102 | 103 | # shape = [B, K] 104 | phi = torch.stack(KM, dim=1) 105 | 106 | out = self.out(phi) 107 | return out 108 | -------------------------------------------------------------------------------- /matchzoo/models/match_srnn.py: -------------------------------------------------------------------------------- 1 | """An implementation of Match-SRNN Model.""" 2 | import typing 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from matchzoo.engine.param_table import ParamTable 9 | from matchzoo.engine.param import Param 10 | from matchzoo.engine.base_model import BaseModel 11 | from matchzoo.engine import hyper_spaces 12 | from matchzoo.modules import MatchingTensor 13 | from matchzoo.modules import SpatialGRU 14 | 15 | 16 | class MatchSRNN(BaseModel): 17 | """ 18 | Match-SRNN Model. 
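
    It models the word-level interactions of the two texts as a
    multi-channel matching tensor and scans the tensor recursively with a
    spatial GRU to accumulate the matching signal (see ``forward`` below).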
19 | 
20 | Examples:
21 | >>> model = MatchSRNN()
22 | >>> model.params['channels'] = 4
23 | >>> model.params['units'] = 10
24 | >>> model.params['dropout'] = 0.2
25 | >>> model.params['direction'] = 'lt'
26 | >>> model.guess_and_fill_missing_params(verbose=0)
27 | >>> model.build()
28 | 
29 | """
30 | 
31 | @classmethod
32 | def get_default_params(cls) -> ParamTable:
33 | """:return: model default parameters."""
34 | params = super().get_default_params(
35 | with_embedding=True,
36 | with_multi_layer_perceptron=False
37 | )
38 | params.add(Param(name='channels', value=4,
39 | desc="Number of word interaction tensor channels"))
40 | params.add(Param(name='units', value=10,
41 | desc="Number of SpatialGRU units"))
42 | params.add(Param(name='direction', value='lt',
43 | desc="Direction of SpatialGRU scanning"))
44 | params.add(Param(
45 | 'dropout', 0.2,
46 | hyper_space=hyper_spaces.quniform(
47 | low=0.0, high=0.8, q=0.01),
48 | desc="The dropout rate."
49 | ))
50 | return params
51 | 
52 | def build(self):
53 | """Build model structure."""
54 | 
55 | self.embedding = self._make_default_embedding_layer()
56 | self.dropout = nn.Dropout(p=self._params['dropout'])
57 | 
58 | self.matching_tensor = MatchingTensor(
59 | self._params['embedding_output_dim'],
60 | channels=self._params["channels"])
61 | 
62 | self.spatial_gru = SpatialGRU(
63 | units=self._params['units'],
64 | direction=self._params['direction'])
65 | 
66 | self.out = self._make_output_layer(self._params['units'])
67 | 
68 | def forward(self, inputs):
69 | """Forward."""
70 | 
71 | # Scalar dimensions referenced here:
72 | # B = batch size (number of sequences)
73 | # D = embedding size
74 | # L = `input_left` sequence length
75 | # R = `input_right` sequence length
76 | # C = number of channels
77 | 
78 | # Left input and right input
79 | # query = [B, L]
80 | # doc = [B, R]
81 | query, doc = inputs["text_left"].long(), inputs["text_right"].long()
82 | 
83 | # Process left and right input
84 | # query = [B, L, D]
85 | # doc = [B, R, D]
86 | query = self.embedding(query)
87 | doc = self.embedding(doc)
88 | 
89 | # query = [B, L, D]
90 | # doc = [B, R, D]
91 | query = self.dropout(query)
92 | doc = self.dropout(doc)
93 | 
94 | # Get matching tensor
95 | # matching_tensor = [B, C, L, R]
96 | matching_tensor = self.matching_tensor(query, doc)
97 | 
98 | # Apply spatial GRU to the word level interaction tensor
99 | # h_ij = [B, U]
100 | h_ij = self.spatial_gru(matching_tensor)
101 | 
102 | # h_ij = [B, U]
103 | h_ij = self.dropout(h_ij)
104 | 
105 | # Make output layer
106 | out = self.out(h_ij)
107 | 
108 | return out
109 | 
-------------------------------------------------------------------------------- /matchzoo/models/parameter_readme_generator.py: -------------------------------------------------------------------------------- 1 | """matchzoo/models/README.rst generator."""
2 | 
3 | from pathlib import Path
4 | 
5 | import tabulate
6 | import inspect
7 | import pandas as pd
8 | 
9 | import matchzoo
10 | 
11 | 
12 | def _generate():
13 | full = _make_title()
14 | for model_class in matchzoo.models.list_available():
15 | full += _make_model_class_subtitle(model_class)
16 | full += _make_doc_section_subsubtitle()
17 | full += _make_model_doc(model_class)
18 | model = model_class()
19 | full += _make_params_section_subsubtitle()
20 | full += _make_model_params_table(model)
21 | _write_to_files(full)
22 | 
23 | 
24 | def _make_title():
25 | title = 'MatchZoo Model Reference'
26 | line = '*' * len(title)
27 | return line + '\n' + title + '\n' + line + '\n\n'
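# For example, ``_make_title()`` renders:
#
#     ************************
#     MatchZoo Model Reference
#     ************************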
28 |
29 |
30 | def _make_model_class_subtitle(model_class):
31 | subtitle = model_class.__name__
32 | line = '#' * len(subtitle)
33 | return subtitle + '\n' + line + '\n\n'
34 |
35 |
36 | def _make_doc_section_subsubtitle():
37 | subsubtitle = 'Model Documentation'
38 | line = '*' * len(subsubtitle)
39 | return subsubtitle + '\n' + line + '\n\n'
40 |
41 |
42 | def _make_params_section_subsubtitle():
43 | subsubtitle = 'Model Hyper Parameters'
44 | line = '*' * len(subsubtitle)
45 | return subsubtitle + '\n' + line + '\n\n'
46 |
47 |
48 | def _make_model_doc(model_class):
49 | return inspect.getdoc(model_class) + '\n\n'
50 |
51 |
52 | def _make_model_params_table(model):
53 | params = model.get_default_params()
54 | df = params.to_frame()
55 | df = df.rename({
56 | 'Value': 'Default Value',
57 | 'Hyper-Space': 'Default Hyper-Space'
58 | }, axis='columns')
59 | return tabulate.tabulate(df, tablefmt='rst', headers='keys') + '\n\n'
60 |
61 |
62 | def _write_to_files(full):
63 | readme_file_path = Path(__file__).parent.joinpath('README.rst')
64 | doc_file_path = Path(__file__).parent.parent.parent. \
65 | joinpath('docs').joinpath('source').joinpath('model_reference.rst')
66 | for file_path in readme_file_path, doc_file_path:
67 | with open(file_path, 'w', encoding='utf-8') as out_file:
68 | out_file.write(full)
69 |
70 |
71 | if __name__ == '__main__':
72 | _generate()
73 |
-------------------------------------------------------------------------------- /matchzoo/modules/__init__.py: --------------------------------------------------------------------------------
1 | from .attention import Attention
2 | from .attention import BidirectionalAttention
3 | from .attention import MatchModule
4 | from .dropout import RNNDropout
5 | from .stacked_brnn import StackedBRNN
6 | from .gaussian_kernel import GaussianKernel
7 | from .matching import Matching
8 | from .bert_module import BertModule
9 | from .character_embedding import CharacterEmbedding
10 | from .semantic_composite import SemanticComposite
11 | from .dense_net import DenseNet
12 | from .matching_tensor import MatchingTensor
13 | from .spatial_gru import SpatialGRU
-------------------------------------------------------------------------------- /matchzoo/modules/attention.py: --------------------------------------------------------------------------------
1 | """Attention module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Attention(nn.Module):
10 | """
11 | Attention module.
12 |
13 | :param input_size: Size of input.
14 | Positions where the boolean ``x_mask`` passed to ``forward`` is ``True`` are masked out before the softmax.
15 |
16 | Examples:
17 | >>> import torch
18 | >>> attention = Attention(input_size=10)
19 | >>> x = torch.randn(4, 5, 10)
20 | >>> x.shape
21 | torch.Size([4, 5, 10])
22 | >>> x_mask = torch.BoolTensor(4, 5)
23 | >>> attention(x, x_mask).shape
24 | torch.Size([4, 5])
25 |
26 | """
27 |
28 | def __init__(self, input_size: int = 100):
29 | """Attention constructor."""
30 | super().__init__()
31 | self.linear = nn.Linear(input_size, 1, bias=False)
32 |
33 | def forward(self, x, x_mask):
34 | """Perform attention on the input."""
35 | x = self.linear(x).squeeze(dim=-1)
36 | x = x.masked_fill(x_mask, -float('inf'))
37 | return F.softmax(x, dim=-1)
38 |
39 |
40 | class BidirectionalAttention(nn.Module):
41 | """Computing the soft attention between two sequences."""
42 |
43 | def __init__(self):
44 | """Init."""
45 | super().__init__()
46 |
47 | def forward(self, v1, v1_mask, v2, v2_mask):
48 | """Forward."""
49 | similarity_matrix = v1.bmm(v2.transpose(2, 1).contiguous())
50 | # Fill masked positions with a large negative value so they vanish under the softmax.
51 | v2_v1_attn = F.softmax(
52 | similarity_matrix.masked_fill(
53 | v1_mask.unsqueeze(2), -1e7), dim=1)
54 | v1_v2_attn = F.softmax(
55 | similarity_matrix.masked_fill(
56 | v2_mask.unsqueeze(1), -1e7), dim=2)
57 |
58 | attended_v1 = v1_v2_attn.bmm(v2)
59 | attended_v2 = v2_v1_attn.transpose(1, 2).bmm(v1)
60 |
61 | attended_v1.masked_fill_(v1_mask.unsqueeze(2), 0)
62 | attended_v2.masked_fill_(v2_mask.unsqueeze(2), 0)
63 |
64 | return attended_v1, attended_v2
65 |
66 |
67 | class MatchModule(nn.Module):
68 | """
69 | Computing the match representation for Match LSTM.
70 |
71 | :param hidden_size: Size of hidden vectors.
72 | :param dropout_rate: Dropout rate of the projection layer. Defaults to 0.
73 |
74 | Examples:
75 | >>> import torch
76 | >>> attention = MatchModule(hidden_size=10)
77 | >>> v1 = torch.randn(4, 5, 10)
78 | >>> v1.shape
79 | torch.Size([4, 5, 10])
80 | >>> v2 = torch.randn(4, 5, 10)
81 | >>> v2_mask = torch.ones(4, 5).to(dtype=torch.uint8)
82 | >>> attention(v1, v2, v2_mask).shape
83 | torch.Size([4, 5, 20])
84 |
85 | """
86 |
87 | def __init__(self, hidden_size, dropout_rate=0):
88 | """Init."""
89 | super().__init__()
90 | self.v2_proj = nn.Linear(hidden_size, hidden_size)
91 | self.proj = nn.Linear(hidden_size * 4, hidden_size * 2)
92 | self.dropout = nn.Dropout(p=dropout_rate)
93 |
94 | def forward(self, v1, v2, v2_mask):
95 | """Computing attention vectors and projection vectors."""
96 | proj_v2 = self.v2_proj(v2)
97 | similarity_matrix = v1.bmm(proj_v2.transpose(2, 1).contiguous())
98 |
99 | # Mask with a large negative value so padded positions vanish under the softmax.
100 | v1_v2_attn = F.softmax(
101 | similarity_matrix.masked_fill(
102 | v2_mask.unsqueeze(1).bool(), -1e7), dim=2)
103 | v2_wsum = v1_v2_attn.bmm(v2)
104 | fusion = torch.cat([v1, v2_wsum, v1 - v2_wsum, v1 * v2_wsum], dim=2)
105 | match = self.dropout(F.relu(self.proj(fusion)))
106 | return match
107 |
-------------------------------------------------------------------------------- /matchzoo/modules/bert_module.py: --------------------------------------------------------------------------------
1 | """Bert module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_transformers import BertModel
7 |
8 |
9 | class BertModule(nn.Module):
10 | """
11 | Bert module.
12 |
13 | BERT (from Google) released with the paper BERT: Pre-training of Deep
14 | Bidirectional Transformers for Language Understanding by Jacob Devlin,
15 | Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
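In ``forward``, the two inputs are concatenated into a single sequence,
with segment ids 0 for the left text and 1 for the right text, and an
attention mask covering the non-padding (non-zero) tokens.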
16 |
17 | :param mode: String, the supported modes are listed at
18 | https://huggingface.co/pytorch-transformers/pretrained_models.html.
19 |
20 | """
21 |
22 | def __init__(self, mode: str = 'bert-base-uncased'):
23 | """:class:`BertModule` constructor."""
24 | super().__init__()
25 | self.bert = BertModel.from_pretrained(mode)
26 |
27 | def forward(self, x, y):
28 | """Forward."""
29 | input_ids = torch.cat((x, y), dim=-1)
30 | token_type_ids = torch.cat((
31 | torch.zeros_like(x),
32 | torch.ones_like(y)), dim=-1).long()
33 | attention_mask = (input_ids != 0)
34 | return self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
35 | attention_mask=attention_mask)
36 |
-------------------------------------------------------------------------------- /matchzoo/modules/character_embedding.py: --------------------------------------------------------------------------------
1 | """Character embedding module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class CharacterEmbedding(nn.Module):
9 | """
10 | Character embedding module.
11 |
12 | :param char_embedding_input_dim: The input dimension of the character embedding layer.
13 | :param char_embedding_output_dim: The output dimension of the character embedding layer.
14 | :param char_conv_filters: The number of filters of the character convolution layer.
15 | :param char_conv_kernel_size: The kernel size of the character convolution layer.
16 |
17 | Examples:
18 | >>> import torch
19 | >>> character_embedding = CharacterEmbedding()
20 | >>> x = torch.ones(10, 32, 16, dtype=torch.long)
21 | >>> x.shape
22 | torch.Size([10, 32, 16])
23 | >>> character_embedding(x).shape
24 | torch.Size([10, 32, 100])
25 |
26 | """
27 |
28 | def __init__(
29 | self,
30 | char_embedding_input_dim: int = 100,
31 | char_embedding_output_dim: int = 8,
32 | char_conv_filters: int = 100,
33 | char_conv_kernel_size: int = 5
34 | ):
35 | """Init."""
36 | super().__init__()
37 | self.char_embedding = nn.Embedding(
38 | num_embeddings=char_embedding_input_dim,
39 | embedding_dim=char_embedding_output_dim
40 | )
41 | self.conv = nn.Conv1d(
42 | in_channels=char_embedding_output_dim,
43 | out_channels=char_conv_filters,
44 | kernel_size=char_conv_kernel_size
45 | )
46 |
47 | def forward(self, x):
48 | """Forward."""
49 | embed_x = self.char_embedding(x)
50 |
51 | batch_size, seq_len, word_len, embed_dim = embed_x.shape
52 |
53 | embed_x = embed_x.contiguous().view(-1, word_len, embed_dim)
54 |
55 | embed_x = self.conv(embed_x.transpose(1, 2))
56 | embed_x = torch.max(embed_x, dim=-1)[0]
57 |
58 | embed_x = embed_x.view(batch_size, seq_len, -1)
59 | return embed_x
60 |
-------------------------------------------------------------------------------- /matchzoo/modules/dropout.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class RNNDropout(nn.Dropout):
5 | """Dropout for RNN."""
6 |
7 | def forward(self, sequences_batch):
8 | """Mask whole hidden vectors, using the same mask at every time step."""
9 | # B: batch size
10 | # L: sequence length
11 | # D: hidden size
12 |
13 | # sequence_batch: BxLxD
14 | ones = sequences_batch.data.new_ones(sequences_batch.shape[0],
15 | sequences_batch.shape[-1])
16 | dropout_mask = nn.functional.dropout(ones, self.p, self.training,
17 | inplace=False)
18 | return dropout_mask.unsqueeze(1) * sequences_batch
19 |
-------------------------------------------------------------------------------- /matchzoo/modules/gaussian_kernel.py:
--------------------------------------------------------------------------------
1 | """Gaussian kernel module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class GaussianKernel(nn.Module):
9 | """
10 | Gaussian kernel module.
11 |
12 | :param mu: Float, mean of the kernel.
13 | :param sigma: Float, sigma of the kernel.
14 |
15 | Examples:
16 | >>> import torch
17 | >>> kernel = GaussianKernel()
18 | >>> x = torch.randn(4, 5, 10)
19 | >>> x.shape
20 | torch.Size([4, 5, 10])
21 | >>> kernel(x).shape
22 | torch.Size([4, 5, 10])
23 |
24 | """
25 |
26 | def __init__(self, mu: float = 1., sigma: float = 1.):
27 | """Gaussian kernel constructor."""
28 | super().__init__()
29 | self.mu = mu
30 | self.sigma = sigma
31 |
32 | def forward(self, x):
33 | """Forward."""
34 | return torch.exp(
35 | -0.5 * ((x - self.mu) ** 2) / (self.sigma ** 2)
36 | )
37 |
-------------------------------------------------------------------------------- /matchzoo/modules/matching.py: --------------------------------------------------------------------------------
1 | """Matching module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Matching(nn.Module):
10 | """
11 | Module that computes a matching matrix between samples in two tensors.
12 |
13 | :param normalize: Whether to L2-normalize samples along the
14 | dot product axis before taking the dot product.
15 | If set to `True`, then the output of the dot product
16 | is the cosine proximity between the two samples.
17 | :param matching_type: the similarity function for matching
18 |
19 | Examples:
20 | >>> import torch
21 | >>> matching = Matching(matching_type='dot', normalize=True)
22 | >>> x = torch.randn(2, 3, 2)
23 | >>> y = torch.randn(2, 4, 2)
24 | >>> matching(x, y).shape
25 | torch.Size([2, 3, 4])
26 |
27 | """
28 |
29 | def __init__(self, normalize: bool = False, matching_type: str = 'dot'):
30 | """:class:`Matching` constructor."""
31 | super().__init__()
32 | self._normalize = normalize
33 | self._validate_matching_type(matching_type)
34 | self._matching_type = matching_type
35 |
36 | @classmethod
37 | def _validate_matching_type(cls, matching_type: str = 'dot'):
38 | valid_matching_type = ['dot', 'exact', 'mul', 'plus', 'minus', 'concat']
39 | if matching_type not in valid_matching_type:
40 | raise ValueError(f"{matching_type} is not a valid matching type, "
41 | f"{valid_matching_type} expected.")
42 |
43 | def forward(self, x, y):
44 | """Compute the matching matrix of the two inputs."""
45 | length_left = x.shape[1]
46 | length_right = y.shape[1]
47 | if self._matching_type == 'dot':
48 | if self._normalize:
49 | x = F.normalize(x, p=2, dim=-1)
50 | y = F.normalize(y, p=2, dim=-1)
51 | return torch.einsum('bld,brd->blr', x, y)
52 | elif self._matching_type == 'exact':
53 | x = x.unsqueeze(dim=2).repeat(1, 1, length_right)
54 | y = y.unsqueeze(dim=1).repeat(1, length_left, 1)
55 | matching_matrix = (x == y)
56 | x = torch.sum(matching_matrix, dim=2, dtype=torch.float)
57 | y = torch.sum(matching_matrix, dim=1, dtype=torch.float)
58 | return x, y
59 | else:
60 | x = x.unsqueeze(dim=2).repeat(1, 1, length_right, 1)
61 | y = y.unsqueeze(dim=1).repeat(1, length_left, 1, 1)
62 | if self._matching_type == 'mul':
63 | return x * y
64 | elif self._matching_type == 'plus':
65 | return x + y
66 | elif self._matching_type == 'minus':
67 | return x - y
68 | elif self._matching_type == 'concat':
69 | return torch.cat((x, y), dim=3)
70 |
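# A minimal usage sketch of the 'exact' mode (an added illustration; the
# shapes below are assumptions, not part of the original module):
if __name__ == '__main__':
    x = torch.randint(0, 10, (2, 3))  # token ids, [B, L]
    y = torch.randint(0, 10, (2, 4))  # token ids, [B, R]
    match_x, match_y = Matching(matching_type='exact')(x, y)
    # Per-token counts of exact matches against the other text:
    print(match_x.shape, match_y.shape)  # torch.Size([2, 3]) torch.Size([2, 4])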
-------------------------------------------------------------------------------- /matchzoo/modules/matching_tensor.py: --------------------------------------------------------------------------------
1 | """Matching Tensor module."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class MatchingTensor(nn.Module):
10 | """
11 | Module that captures the basic interactions between two tensors.
12 |
13 | :param matching_dim: Word dimension of the two text inputs.
14 | :param channels: Number of word interaction tensor channels.
15 | :param normalize: Whether to L2-normalize samples along the
16 | dot product axis before taking the dot product.
17 | If set to True, then the output of the dot product
18 | is the cosine proximity between the two samples.
19 | :param init_diag: Whether to initialize the diagonal elements
20 | of the matrix.
21 |
22 | Examples:
23 | >>> import matchzoo as mz
24 | >>> matching_dim = 5
25 | >>> matching_tensor = mz.modules.MatchingTensor(
26 | ... matching_dim,
27 | ... channels=4,
28 | ... normalize=True,
29 | ... init_diag=True
30 | ... )
31 |
32 | """
33 |
34 | def __init__(
35 | self,
36 | matching_dim: int,
37 | channels: int = 4,
38 | normalize: bool = True,
39 | init_diag: bool = True
40 | ):
41 | """:class:`MatchingTensor` constructor."""
42 | super().__init__()
43 | self._matching_dim = matching_dim
44 | self._channels = channels
45 | self._normalize = normalize
46 | self._init_diag = init_diag
47 |
48 | self.interaction_matrix = torch.empty(
49 | self._channels, self._matching_dim, self._matching_dim
50 | )
51 | if self._init_diag:
52 | self.interaction_matrix = self.interaction_matrix.uniform_(-0.05, 0.05)
53 | for channel_index in range(self._channels):
54 | self.interaction_matrix[channel_index].fill_diagonal_(0.1)
55 | self.interaction_matrix = nn.Parameter(self.interaction_matrix)
56 | else:
57 | self.interaction_matrix = nn.Parameter(self.interaction_matrix.uniform_())
58 |
59 | def forward(self, x, y):
60 | """
61 | The computation logic of MatchingTensor.
62 |
63 | :param x: Left input tensor.
64 | :param y: Right input tensor.
65 | """
66 |
67 | if self._normalize:
68 | x = F.normalize(x, p=2, dim=-1)
69 | y = F.normalize(y, p=2, dim=-1)
70 |
71 | # output = [b, c, l, r]
72 | output = torch.einsum(
73 | 'bld,cde,bre->bclr',
74 | x, self.interaction_matrix, y
75 | )
76 | return output
77 |
-------------------------------------------------------------------------------- /matchzoo/modules/semantic_composite.py: --------------------------------------------------------------------------------
1 | """Semantic composite module for DIIN model."""
2 | import typing
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class SemanticComposite(nn.Module):
9 | """
10 | SemanticComposite module.
11 |
12 | Apply a self-attention layer and a semantic composite fuse gate to compute the
13 | encoding result of one tensor.
14 |
15 | :param in_features: Feature size of input.
16 | :param dropout_rate: The dropout rate.
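Given the self-attention output ``a``, the fuse gate computes
``z = tanh(W_z [x; a])``, ``r = sigmoid(W_r [x; a])`` and
``f = sigmoid(W_f [x; a])``, and returns ``r * x + f * z``.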
17 |
18 | Examples:
19 | >>> import torch
20 | >>> module = SemanticComposite(in_features=10)
21 | >>> x = torch.randn(4, 5, 10)
22 | >>> x.shape
23 | torch.Size([4, 5, 10])
24 | >>> module(x).shape
25 | torch.Size([4, 5, 10])
26 |
27 | """
28 |
29 | def __init__(self, in_features, dropout_rate: float = 0.0):
30 | """Init."""
31 | super().__init__()
32 | self.att_linear = nn.Linear(3 * in_features, 1, False)
33 | self.z_gate = nn.Linear(2 * in_features, in_features, True)
34 | self.r_gate = nn.Linear(2 * in_features, in_features, True)
35 | self.f_gate = nn.Linear(2 * in_features, in_features, True)
36 |
37 | self.dropout = nn.Dropout(p=dropout_rate)
38 |
39 | def forward(self, x):
40 | """Forward."""
41 | seq_length = x.shape[1]
42 |
43 | x_1 = x.unsqueeze(dim=2).repeat(1, 1, seq_length, 1)
44 | x_2 = x.unsqueeze(dim=1).repeat(1, seq_length, 1, 1)
45 | x_concat = torch.cat([x_1, x_2, x_1 * x_2], dim=-1)
46 |
47 | # Self-attention layer.
48 | x_concat = self.dropout(x_concat)
49 | attn_matrix = self.att_linear(x_concat).squeeze(dim=-1)
50 | attn_weight = torch.softmax(attn_matrix, dim=2)
51 | attn = torch.bmm(attn_weight, x)
52 |
53 | # Semantic composite fuse gate.
54 | x_attn_concat = self.dropout(torch.cat([x, attn], dim=-1))
55 | z = torch.tanh(self.z_gate(x_attn_concat))
56 | r = torch.sigmoid(self.r_gate(x_attn_concat))
57 | f = torch.sigmoid(self.f_gate(x_attn_concat))
58 | encoding = r * x + f * z
59 |
60 | return encoding
61 |
-------------------------------------------------------------------------------- /matchzoo/modules/stacked_brnn.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class StackedBRNN(nn.Module):
7 | """
8 | Stacked Bi-directional RNNs.
9 |
10 | Differs from standard PyTorch library in that it has the option to save
11 | and concat the hidden states between layers. (i.e. the output hidden size
12 | for each sequence input is num_layers * hidden_size).
13 |
14 | Examples:
15 | >>> import torch
16 | >>> rnn = StackedBRNN(
17 | ... input_size=10,
18 | ... hidden_size=10,
19 | ... num_layers=2,
20 | ... dropout_rate=0.2,
21 | ... dropout_output=True,
22 | ... concat_layers=False
23 | ... )
24 | >>> x = torch.randn(2, 5, 10)
25 | >>> x.size()
26 | torch.Size([2, 5, 10])
27 | >>> x_mask = (torch.ones(2, 5) == 1)
28 | >>> rnn(x, x_mask).shape
29 | torch.Size([2, 5, 20])
30 |
31 | """
32 |
33 | def __init__(self, input_size, hidden_size, num_layers,
34 | dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
35 | concat_layers=False):
36 | """Stacked Bidirectional LSTM."""
37 | super().__init__()
38 | self.dropout_output = dropout_output
39 | self.dropout_rate = dropout_rate
40 | self.num_layers = num_layers
41 | self.concat_layers = concat_layers
42 | self.rnns = nn.ModuleList()
43 | for i in range(num_layers):
44 | input_size = input_size if i == 0 else 2 * hidden_size
45 | self.rnns.append(rnn_type(input_size, hidden_size,
46 | num_layers=1,
47 | bidirectional=True))
48 |
49 | def forward(self, x, x_mask):
50 | """Encode either padded or non-padded sequences."""
51 | # Sequences are always encoded with the fast path that ignores
52 | # padding; `x_mask` is accepted for interface compatibility.
53 | output = self._forward_unpadded(x, x_mask)
54 |
55 | return output.contiguous()
56 |
57 | def _forward_unpadded(self, x, x_mask):
58 | """Faster encoding that ignores any padding."""
59 | # Transpose batch and sequence dims
60 | x = x.transpose(0, 1)
61 |
62 | # Encode all layers
63 | outputs = [x]
64 | for i in range(self.num_layers):
65 | rnn_input = outputs[-1]
66 |
67 | # Apply dropout to hidden input
68 | if self.dropout_rate > 0:
69 | rnn_input = F.dropout(rnn_input,
70 | p=self.dropout_rate,
71 | training=self.training)
72 | # Forward
73 | rnn_output = self.rnns[i](rnn_input)[0]
74 | outputs.append(rnn_output)
75 |
76 | # Concat hidden layers
77 | if self.concat_layers:
78 | output = torch.cat(outputs[1:], 2)
79 | else:
80 | output = outputs[-1]
81 |
82 | # Transpose back
83 | output = output.transpose(0, 1)
84 |
85 | # Dropout on output layer
86 | if self.dropout_output and self.dropout_rate > 0:
87 | output = F.dropout(output,
88 | p=self.dropout_rate,
89 | training=self.training)
90 | return output
91 |
-------------------------------------------------------------------------------- /matchzoo/preprocessors/__init__.py: --------------------------------------------------------------------------------
1 | from . import units
2 | from .naive_preprocessor import NaivePreprocessor
3 | from .basic_preprocessor import BasicPreprocessor
4 | from .bert_preprocessor import BertPreprocessor
5 |
6 |
7 | def list_available() -> list:
8 | from matchzoo.engine.base_preprocessor import BasePreprocessor
9 | from matchzoo.utils import list_recursive_concrete_subclasses
10 | return list_recursive_concrete_subclasses(BasePreprocessor)
11 |
-------------------------------------------------------------------------------- /matchzoo/preprocessors/bert_preprocessor.py: --------------------------------------------------------------------------------
1 | """Bert Preprocessor."""
2 |
3 | from pytorch_transformers import BertTokenizer
4 |
5 | from . import units
6 | from matchzoo import DataPack
7 | from matchzoo.engine.base_preprocessor import BasePreprocessor
8 |
9 |
10 | class BertPreprocessor(BasePreprocessor):
11 | """
12 | Bert preprocessor helper.
13 |
14 | :param mode: String, the supported modes are listed at
15 | https://huggingface.co/pytorch-transformers/pretrained_models.html.
16 |
17 | """
18 |
19 | def __init__(self, mode: str = 'bert-base-uncased'):
20 | """Initialization."""
21 | super().__init__()
22 | self._tokenizer = BertTokenizer.from_pretrained(mode)
23 |
24 | def fit(self, data_pack: DataPack, verbose: int = 1):
25 | """A tokenizer is all a :class:`BertPreprocessor` needs; there is nothing to fit."""
26 | return
27 |
28 | def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack:
29 | """
30 | Apply transformation on data.
31 |
32 | :param data_pack: Inputs to be preprocessed.
33 | :param verbose: Verbosity.
34 |
35 | :return: Transformed data as :class:`DataPack` object.
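Example (an added sketch; the pretrained vocabulary is downloaded on
first use, hence the skips):
    >>> import matchzoo as mz
    >>> preprocessor = mz.preprocessors.BertPreprocessor()  # doctest: +SKIP
    >>> processed = preprocessor.transform(
    ...     mz.datasets.toy.load_data(), verbose=0)  # doctest: +SKIP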
36 | """ 37 | data_pack = data_pack.copy() 38 | 39 | data_pack.apply_on_text(self._tokenizer.encode, 40 | mode='both', inplace=True, verbose=verbose) 41 | data_pack.append_text_length(inplace=True, verbose=verbose) 42 | data_pack.drop_empty(inplace=True) 43 | return data_pack 44 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/build_unit_from_data_pack.py: -------------------------------------------------------------------------------- 1 | """Build unit from data pack.""" 2 | 3 | from tqdm import tqdm 4 | 5 | import matchzoo as mz 6 | from .units import StatefulUnit 7 | 8 | 9 | def build_unit_from_data_pack( 10 | unit: StatefulUnit, 11 | data_pack: mz.DataPack, mode: str = 'both', 12 | flatten: bool = True, verbose: int = 1 13 | ) -> StatefulUnit: 14 | """ 15 | Build a :class:`StatefulUnit` from a :class:`DataPack` object. 16 | 17 | :param unit: :class:`StatefulUnit` object to be built. 18 | :param data_pack: The input :class:`DataPack` object. 19 | :param mode: One of 'left', 'right', and 'both', to determine the source 20 | data for building the :class:`VocabularyUnit`. 21 | :param flatten: Flatten the datapack or not. `True` to organize the 22 | :class:`DataPack` text as a list, and `False` to organize 23 | :class:`DataPack` text as a list of list. 24 | :param verbose: Verbosity. 25 | :return: A built :class:`StatefulUnit` object. 26 | 27 | """ 28 | corpus = [] 29 | if flatten: 30 | data_pack.apply_on_text(corpus.extend, mode=mode, verbose=verbose) 31 | else: 32 | data_pack.apply_on_text(corpus.append, mode=mode, verbose=verbose) 33 | if verbose: 34 | description = 'Building ' + unit.__class__.__name__ + \ 35 | ' from a datapack.' 36 | corpus = tqdm(corpus, desc=description) 37 | unit.fit(corpus) 38 | return unit 39 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/build_vocab_unit.py: -------------------------------------------------------------------------------- 1 | from matchzoo.data_pack import DataPack 2 | from .units import Vocabulary 3 | from .build_unit_from_data_pack import build_unit_from_data_pack 4 | 5 | 6 | def build_vocab_unit( 7 | data_pack: DataPack, 8 | mode: str = 'both', 9 | verbose: int = 1 10 | ) -> Vocabulary: 11 | """ 12 | Build a :class:`preprocessor.units.Vocabulary` given `data_pack`. 13 | 14 | The `data_pack` should be preprocessed forehand, and each item in 15 | `text_left` and `text_right` columns of the `data_pack` should be a list 16 | of tokens. 17 | 18 | :param data_pack: The :class:`DataPack` to build vocabulary upon. 19 | :param mode: One of 'left', 'right', and 'both', to determine the source 20 | data for building the :class:`VocabularyUnit`. 21 | :param verbose: Verbosity. 22 | :return: A built vocabulary unit. 23 | 24 | """ 25 | return build_unit_from_data_pack( 26 | unit=Vocabulary(), 27 | data_pack=data_pack, 28 | mode=mode, 29 | flatten=True, verbose=verbose 30 | ) 31 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/chain_transform.py: -------------------------------------------------------------------------------- 1 | """Wrapper function organizes a number of transform functions.""" 2 | import typing 3 | import functools 4 | 5 | from .units.unit import Unit 6 | 7 | 8 | def chain_transform(units: typing.List[Unit]) -> typing.Callable: 9 | """ 10 | Compose unit transformations into a single function. 11 | 12 | :param units: List of :class:`matchzoo.StatelessUnit`. 
13 | """ 14 | 15 | @functools.wraps(chain_transform) 16 | def wrapper(arg): 17 | """Wrapper function of transformations composition.""" 18 | for unit in units: 19 | arg = unit.transform(arg) 20 | return arg 21 | 22 | unit_names = ' => '.join(unit.__class__.__name__ for unit in units) 23 | wrapper.__name__ += ' of ' + unit_names 24 | return wrapper 25 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/naive_preprocessor.py: -------------------------------------------------------------------------------- 1 | """Naive Preprocessor.""" 2 | 3 | from tqdm import tqdm 4 | 5 | from matchzoo.engine.base_preprocessor import BasePreprocessor 6 | from matchzoo import DataPack 7 | from .chain_transform import chain_transform 8 | from .build_vocab_unit import build_vocab_unit 9 | from . import units 10 | 11 | tqdm.pandas() 12 | 13 | 14 | class NaivePreprocessor(BasePreprocessor): 15 | """ 16 | Naive preprocessor. 17 | 18 | Example: 19 | >>> import matchzoo as mz 20 | >>> train_data = mz.datasets.toy.load_data() 21 | >>> test_data = mz.datasets.toy.load_data(stage='test') 22 | >>> preprocessor = mz.preprocessors.NaivePreprocessor() 23 | >>> train_data_processed = preprocessor.fit_transform(train_data, 24 | ... verbose=0) 25 | >>> type(train_data_processed) 26 | 27 | >>> test_data_transformed = preprocessor.transform(test_data, 28 | ... verbose=0) 29 | >>> type(test_data_transformed) 30 | 31 | 32 | """ 33 | 34 | def fit(self, data_pack: DataPack, verbose: int = 1): 35 | """ 36 | Fit pre-processing context for transformation. 37 | 38 | :param data_pack: data_pack to be preprocessed. 39 | :param verbose: Verbosity. 40 | :return: class:`NaivePreprocessor` instance. 41 | """ 42 | func = chain_transform(self._default_units()) 43 | data_pack = data_pack.apply_on_text(func, verbose=verbose) 44 | vocab_unit = build_vocab_unit(data_pack, verbose=verbose) 45 | self._context['vocab_unit'] = vocab_unit 46 | return self 47 | 48 | def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: 49 | """ 50 | Apply transformation on data, create truncated length representation. 51 | 52 | :param data_pack: Inputs to be preprocessed. 53 | :param verbose: Verbosity. 54 | 55 | :return: Transformed data as :class:`DataPack` object. 
56 | """ 57 | data_pack = data_pack.copy() 58 | 59 | units_ = self._default_units() 60 | units_.append(self._context['vocab_unit']) 61 | units_.append( 62 | units.TruncatedLength(text_length=30, truncate_mode='post')) 63 | func = chain_transform(units_) 64 | data_pack.apply_on_text(func, inplace=True, verbose=verbose) 65 | data_pack.append_text_length(inplace=True, verbose=verbose) 66 | data_pack.drop_empty(inplace=True) 67 | return data_pack 68 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/__init__.py: -------------------------------------------------------------------------------- 1 | from .unit import Unit 2 | from .digit_removal import DigitRemoval 3 | from .frequency_filter import FrequencyFilter 4 | from .lemmatization import Lemmatization 5 | from .lowercase import Lowercase 6 | from .matching_histogram import MatchingHistogram 7 | from .ngram_letter import NgramLetter 8 | from .punc_removal import PuncRemoval 9 | from .stateful_unit import StatefulUnit 10 | from .stemming import Stemming 11 | from .stop_removal import StopRemoval 12 | from .tokenize import Tokenize 13 | from .vocabulary import Vocabulary 14 | from .word_hashing import WordHashing 15 | from .character_index import CharacterIndex 16 | from .word_exact_match import WordExactMatch 17 | from .truncated_length import TruncatedLength 18 | 19 | 20 | def list_available() -> list: 21 | from matchzoo.utils import list_recursive_concrete_subclasses 22 | return list_recursive_concrete_subclasses(Unit) 23 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/character_index.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .unit import Unit 4 | 5 | 6 | class CharacterIndex(Unit): 7 | """ 8 | CharacterIndexUnit for DIIN model. 9 | 10 | The input of :class:'CharacterIndexUnit' should be a list of word 11 | character list extracted from a text. The output is the character 12 | index representation of this text. 13 | 14 | :class:`NgramLetterUnit` and :class:`VocabularyUnit` are two 15 | essential prerequisite of :class:`CharacterIndexUnit`. 16 | 17 | Examples: 18 | >>> input_ = [['#', 'a', '#'],['#', 'o', 'n', 'e', '#']] 19 | >>> character_index = CharacterIndex( 20 | ... char_index={ 21 | ... '': 0, '': 1, 'a': 2, 'n': 3, 'e':4, '#':5}) 22 | >>> index = character_index.transform(input_) 23 | >>> index 24 | [[5, 2, 5], [5, 1, 3, 4, 5]] 25 | 26 | """ 27 | 28 | def __init__( 29 | self, 30 | char_index: dict, 31 | ): 32 | """ 33 | Class initialization. 34 | 35 | :param char_index: character-index mapping generated by 36 | :class:'VocabularyUnit'. 37 | """ 38 | self._char_index = char_index 39 | 40 | def transform(self, input_: list) -> list: 41 | """ 42 | Transform list of characters to corresponding indices. 43 | 44 | :param input_: list of characters generated by 45 | :class:'NgramLetterUnit'. 46 | 47 | :return: character index representation of a text. 
48 | """ 49 | idx = [] 50 | for i in range(len(input_)): 51 | current = [ 52 | self._char_index.get(input_[i][j], 1) 53 | for j in range(len(input_[i]))] 54 | idx.append(current) 55 | return idx 56 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/digit_removal.py: -------------------------------------------------------------------------------- 1 | from .unit import Unit 2 | 3 | 4 | class DigitRemoval(Unit): 5 | """Process unit to remove digits.""" 6 | 7 | def transform(self, input_: list) -> list: 8 | """ 9 | Remove digits from list of tokens. 10 | 11 | :param input_: list of tokens to be filtered. 12 | 13 | :return tokens: tokens of tokens without digits. 14 | """ 15 | return [token for token in input_ if not token.isdigit()] 16 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/frequency_filter.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import typing 3 | 4 | import numpy as np 5 | 6 | from .stateful_unit import StatefulUnit 7 | 8 | 9 | class FrequencyFilter(StatefulUnit): 10 | """ 11 | Frequency filter unit. 12 | 13 | :param low: Lower bound, inclusive. 14 | :param high: Upper bound, exclusive. 15 | :param mode: One of `tf` (term frequency), `df` (document frequency), 16 | and `idf` (inverse document frequency). 17 | 18 | Examples:: 19 | >>> import matchzoo as mz 20 | 21 | To filter based on term frequency (tf): 22 | >>> tf_filter = mz.preprocessors.units.FrequencyFilter( 23 | ... low=2, mode='tf') 24 | >>> tf_filter.fit([['A', 'B', 'B'], ['C', 'C', 'C']]) 25 | >>> tf_filter.transform(['A', 'B', 'C']) 26 | ['B', 'C'] 27 | 28 | To filter based on document frequency (df): 29 | >>> tf_filter = mz.preprocessors.units.FrequencyFilter( 30 | ... low=2, mode='df') 31 | >>> tf_filter.fit([['A', 'B'], ['B', 'C']]) 32 | >>> tf_filter.transform(['A', 'B', 'C']) 33 | ['B'] 34 | 35 | To filter based on inverse document frequency (idf): 36 | >>> idf_filter = mz.preprocessors.units.FrequencyFilter( 37 | ... low=1.2, mode='idf') 38 | >>> idf_filter.fit([['A', 'B'], ['B', 'C', 'D']]) 39 | >>> idf_filter.transform(['A', 'B', 'C']) 40 | ['A', 'C'] 41 | 42 | """ 43 | 44 | def __init__(self, low: float = 0, high: float = float('inf'), 45 | mode: str = 'df'): 46 | """Frequency filter unit.""" 47 | super().__init__() 48 | self._low = low 49 | self._high = high 50 | self._mode = mode 51 | 52 | def fit(self, list_of_tokens: typing.List[typing.List[str]]): 53 | """Fit `list_of_tokens` by calculating `mode` states.""" 54 | valid_terms = set() 55 | if self._mode == 'tf': 56 | stats = self._tf(list_of_tokens) 57 | elif self._mode == 'df': 58 | stats = self._df(list_of_tokens) 59 | elif self._mode == 'idf': 60 | stats = self._idf(list_of_tokens) 61 | else: 62 | raise ValueError(f"{self._mode} is not a valid filtering mode." 
63 | f"Mode must be one of `tf`, `df`, and `idf`.") 64 | 65 | for k, v in stats.items(): 66 | if self._low <= v < self._high: 67 | valid_terms.add(k) 68 | 69 | self._context[self._mode] = valid_terms 70 | 71 | def transform(self, input_: list) -> list: 72 | """Transform a list of tokens by filtering out unwanted words.""" 73 | valid_terms = self._context[self._mode] 74 | return list(filter(lambda token: token in valid_terms, input_)) 75 | 76 | @classmethod 77 | def _tf(cls, list_of_tokens: list) -> dict: 78 | stats = collections.Counter() 79 | for tokens in list_of_tokens: 80 | stats.update(tokens) 81 | return stats 82 | 83 | @classmethod 84 | def _df(cls, list_of_tokens: list) -> dict: 85 | stats = collections.Counter() 86 | for tokens in list_of_tokens: 87 | stats.update(set(tokens)) 88 | return stats 89 | 90 | @classmethod 91 | def _idf(cls, list_of_tokens: list) -> dict: 92 | num_docs = len(list_of_tokens) 93 | stats = cls._df(list_of_tokens) 94 | for key, val in stats.most_common(): 95 | stats[key] = np.log((1 + num_docs) / (1 + val)) + 1 96 | return stats 97 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/lemmatization.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | from .unit import Unit 4 | 5 | 6 | class Lemmatization(Unit): 7 | """Process unit for token lemmatization.""" 8 | 9 | def transform(self, input_: list) -> list: 10 | """ 11 | Lemmatization a sequence of tokens. 12 | 13 | :param input_: list of tokens to be lemmatized. 14 | 15 | :return tokens: list of lemmatizd tokens. 16 | """ 17 | lemmatizer = nltk.WordNetLemmatizer() 18 | return [lemmatizer.lemmatize(token, pos='v') for token in input_] 19 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/lowercase.py: -------------------------------------------------------------------------------- 1 | from .unit import Unit 2 | 3 | 4 | class Lowercase(Unit): 5 | """Process unit for text lower case.""" 6 | 7 | def transform(self, input_: list) -> list: 8 | """ 9 | Convert list of tokens to lower case. 10 | 11 | :param input_: list of tokens. 12 | 13 | :return tokens: lower-cased list of tokens. 14 | """ 15 | return [token.lower() for token in input_] 16 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/matching_histogram.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .unit import Unit 4 | 5 | 6 | class MatchingHistogram(Unit): 7 | """ 8 | MatchingHistogramUnit Class. 9 | 10 | :param bin_size: The number of bins of the matching histogram. 11 | :param embedding_matrix: The word embedding matrix applied to calculate 12 | the matching histogram. 13 | :param normalize: Boolean, normalize the embedding or not. 14 | :param mode: The type of the historgram, it should be one of 'CH', 'NG', 15 | or 'LCH'. 
16 | 17 | Examples: 18 | >>> embedding_matrix = np.array([[1.0, -1.0], [1.0, 2.0], [1.0, 3.0]]) 19 | >>> text_left = [0, 1] 20 | >>> text_right = [1, 2] 21 | >>> histogram = MatchingHistogram(3, embedding_matrix, True, 'CH') 22 | >>> histogram.transform([text_left, text_right]) 23 | [[3.0, 1.0, 1.0], [1.0, 2.0, 2.0]] 24 | 25 | """ 26 | 27 | def __init__(self, bin_size: int = 30, embedding_matrix=None, 28 | normalize=True, mode: str = 'LCH'): 29 | """The constructor.""" 30 | self._hist_bin_size = bin_size 31 | self._embedding_matrix = embedding_matrix 32 | if normalize: 33 | self._normalize_embedding() 34 | self._mode = mode 35 | 36 | def _normalize_embedding(self): 37 | """Normalize the embedding matrix.""" 38 | l2_norm = np.sqrt( 39 | (self._embedding_matrix * self._embedding_matrix).sum(axis=1) 40 | ) 41 | self._embedding_matrix = \ 42 | self._embedding_matrix / l2_norm[:, np.newaxis] 43 | 44 | def transform(self, input_: list) -> list: 45 | """Transform the input text.""" 46 | text_left, text_right = input_ 47 | matching_hist = np.ones((len(text_left), self._hist_bin_size), 48 | dtype=np.float32) 49 | embed_left = self._embedding_matrix[text_left] 50 | embed_right = self._embedding_matrix[text_right] 51 | matching_matrix = embed_left.dot(np.transpose(embed_right)) 52 | for (i, j), value in np.ndenumerate(matching_matrix): 53 | bin_index = int((value + 1.) / 2. * (self._hist_bin_size - 1.)) 54 | matching_hist[i][bin_index] += 1.0 55 | if self._mode == 'NH': 56 | matching_sum = matching_hist.sum(axis=1) 57 | matching_hist = matching_hist / matching_sum[:, np.newaxis] 58 | elif self._mode == 'LCH': 59 | matching_hist = np.log(matching_hist) 60 | return matching_hist.tolist() 61 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/ngram_letter.py: -------------------------------------------------------------------------------- 1 | from .unit import Unit 2 | 3 | 4 | class NgramLetter(Unit): 5 | """ 6 | Process unit for n-letter generation. 7 | 8 | Triletter is used in :class:`DSSMModel`. 9 | This processor is expected to execute before `Vocab` 10 | has been created. 11 | 12 | Examples: 13 | >>> triletter = NgramLetter() 14 | >>> rv = triletter.transform(['hello', 'word']) 15 | >>> len(rv) 16 | 9 17 | >>> rv 18 | ['#he', 'hel', 'ell', 'llo', 'lo#', '#wo', 'wor', 'ord', 'rd#'] 19 | >>> triletter = NgramLetter(reduce_dim=False) 20 | >>> rv = triletter.transform(['hello', 'word']) 21 | >>> len(rv) 22 | 2 23 | >>> rv 24 | [['#he', 'hel', 'ell', 'llo', 'lo#'], ['#wo', 'wor', 'ord', 'rd#']] 25 | 26 | """ 27 | 28 | def __init__(self, ngram: int = 3, reduce_dim: bool = True): 29 | """ 30 | Class initialization. 31 | 32 | :param ngram: By default use 3-gram (tri-letter). 33 | :param reduce_dim: Reduce to 1-D list for sentence representation. 34 | """ 35 | self._ngram = ngram 36 | self._reduce_dim = reduce_dim 37 | 38 | def transform(self, input_: list) -> list: 39 | """ 40 | Transform token into tri-letter. 41 | 42 | For example, `word` should be represented as `#wo`, 43 | `wor`, `ord` and `rd#`. 44 | 45 | :param input_: list of tokens to be transformed. 46 | 47 | :return n_letters: generated n_letters. 
48 | """ 49 | n_letters = [] 50 | if len(input_) == 0: 51 | token_ngram = [] 52 | if self._reduce_dim: 53 | n_letters.extend(token_ngram) 54 | else: 55 | n_letters.append(token_ngram) 56 | else: 57 | for token in input_: 58 | token = '#' + token + '#' 59 | token_ngram = [] 60 | while len(token) >= self._ngram: 61 | token_ngram.append(token[:self._ngram]) 62 | token = token[1:] 63 | if self._reduce_dim: 64 | n_letters.extend(token_ngram) 65 | else: 66 | n_letters.append(token_ngram) 67 | return n_letters 68 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/punc_removal.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from .unit import Unit 4 | 5 | 6 | class PuncRemoval(Unit): 7 | """Process unit for remove punctuations.""" 8 | 9 | _MATCH_PUNC = re.compile(r'[^\w\s]') 10 | 11 | def transform(self, input_: list) -> list: 12 | """ 13 | Remove punctuations from list of tokens. 14 | 15 | :param input_: list of toekns. 16 | 17 | :return rv: tokens without punctuation. 18 | """ 19 | return [token for token in input_ if 20 | not self._MATCH_PUNC.search(token)] 21 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/stateful_unit.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | from .unit import Unit 5 | 6 | 7 | class StatefulUnit(Unit, metaclass=abc.ABCMeta): 8 | """ 9 | Unit with inner state. 10 | 11 | Usually need to be fit before transforming. All information gathered in the 12 | fit phrase will be stored into its `context`. 13 | """ 14 | 15 | def __init__(self): 16 | """Initialization.""" 17 | self._context = {} 18 | 19 | @property 20 | def state(self): 21 | """ 22 | Get current context. Same as `unit.context`. 23 | 24 | Deprecated since v2.2.0, and will be removed in the future. 25 | Used `unit.context` instead. 26 | """ 27 | return self._context 28 | 29 | @property 30 | def context(self): 31 | """Get current context. Same as `unit.state`.""" 32 | return self._context 33 | 34 | @abc.abstractmethod 35 | def fit(self, input_: typing.Any): 36 | """Abstract base method, need to be implemented in subclass.""" 37 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/stemming.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | from .unit import Unit 4 | 5 | 6 | class Stemming(Unit): 7 | """ 8 | Process unit for token stemming. 9 | 10 | :param stemmer: stemmer to use, `porter` or `lancaster`. 11 | """ 12 | 13 | def __init__(self, stemmer='porter'): 14 | """Initialization.""" 15 | self.stemmer = stemmer 16 | 17 | def transform(self, input_: list) -> list: 18 | """ 19 | Reducing inflected words to their word stem, base or root form. 20 | 21 | :param input_: list of string to be stemmed. 
22 | """ 23 | if self.stemmer == 'porter': 24 | porter_stemmer = nltk.stem.PorterStemmer() 25 | return [porter_stemmer.stem(token) for token in input_] 26 | elif self.stemmer == 'lancaster' or self.stemmer == 'krovetz': 27 | lancaster_stemmer = nltk.stem.lancaster.LancasterStemmer() 28 | return [lancaster_stemmer.stem(token) for token in input_] 29 | else: 30 | raise ValueError( 31 | 'Not supported supported stemmer type: {}'.format( 32 | self.stemmer)) 33 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/stop_removal.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | from .unit import Unit 4 | 5 | 6 | class StopRemoval(Unit): 7 | """ 8 | Process unit to remove stop words. 9 | 10 | Example: 11 | >>> unit = StopRemoval() 12 | >>> unit.transform(['a', 'the', 'test']) 13 | ['test'] 14 | >>> type(unit.stopwords) 15 | 16 | """ 17 | 18 | def __init__(self, lang: str = 'english'): 19 | """Initialization.""" 20 | self._lang = lang 21 | self._stop = nltk.corpus.stopwords.words(self._lang) 22 | 23 | def transform(self, input_: list) -> list: 24 | """ 25 | Remove stopwords from list of tokenized tokens. 26 | 27 | :param input_: list of tokenized tokens. 28 | :param lang: language code for stopwords. 29 | 30 | :return tokens: list of tokenized tokens without stopwords. 31 | """ 32 | return [token 33 | for token 34 | in input_ 35 | if token not in self._stop] 36 | 37 | @property 38 | def stopwords(self) -> list: 39 | """ 40 | Get stopwords based on language. 41 | 42 | :params lang: language code. 43 | :return: list of stop words. 44 | """ 45 | return self._stop 46 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/tokenize.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | 3 | from .unit import Unit 4 | 5 | 6 | class Tokenize(Unit): 7 | """Process unit for text tokenization.""" 8 | 9 | def transform(self, input_: str) -> list: 10 | """ 11 | Process input data from raw terms to list of tokens. 12 | 13 | :param input_: raw textual input. 14 | 15 | :return tokens: tokenized tokens as a list. 16 | """ 17 | return nltk.word_tokenize(input_) 18 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/truncated_length.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from .unit import Unit 6 | 7 | 8 | class TruncatedLength(Unit): 9 | """ 10 | TruncatedLengthUnit Class. 11 | 12 | Process unit to truncate the text that exceeds the set length. 13 | 14 | Examples: 15 | >>> from matchzoo.preprocessors.units import TruncatedLength 16 | >>> truncatedlen = TruncatedLength(3) 17 | >>> truncatedlen.transform(list(range(1, 6))) == [3, 4, 5] 18 | True 19 | >>> truncatedlen.transform(list(range(2))) == [0, 1] 20 | True 21 | 22 | """ 23 | 24 | def __init__( 25 | self, 26 | text_length: int, 27 | truncate_mode: str = 'pre' 28 | ): 29 | """ 30 | Class initialization. 31 | 32 | :param text_length: the specified maximum length of text. 33 | :param truncate_mode: String, `pre` or `post`: 34 | remove values from sequences larger than :attr:`text_length`, 35 | either at the beginning or at the end of the sequences. 
36 | """ 37 | self._text_length = text_length 38 | self._truncate_mode = truncate_mode 39 | 40 | def transform(self, input_: list) -> list: 41 | """ 42 | Truncate the text that exceeds the specified maximum length. 43 | 44 | :param input_: list of tokenized tokens. 45 | 46 | :return tokens: list of tokenized tokens in fixed length 47 | if its origin length larger than :attr:`text_length`. 48 | """ 49 | if len(input_) <= self._text_length: 50 | truncated_tokens = input_ 51 | else: 52 | if self._truncate_mode == 'pre': 53 | truncated_tokens = input_[-self._text_length:] 54 | elif self._truncate_mode == 'post': 55 | truncated_tokens = input_[:self._text_length] 56 | else: 57 | raise ValueError('{} is not a vaild ' 58 | 'truncate mode.'.format(self._truncate_mode)) 59 | return truncated_tokens 60 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/unit.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | 5 | class Unit(metaclass=abc.ABCMeta): 6 | """Process unit do not persive state (i.e. do not need fit).""" 7 | 8 | @abc.abstractmethod 9 | def transform(self, input_: typing.Any): 10 | """Abstract base method, need to be implemented in subclass.""" 11 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/vocabulary.py: -------------------------------------------------------------------------------- 1 | from .stateful_unit import StatefulUnit 2 | 3 | 4 | class Vocabulary(StatefulUnit): 5 | """ 6 | Vocabulary class. 7 | 8 | :param pad_value: The string value for the padding position. 9 | :param oov_value: The string value for the out-of-vocabulary terms. 10 | 11 | Examples: 12 | >>> vocab = Vocabulary(pad_value='[PAD]', oov_value='[OOV]') 13 | >>> vocab.fit(['A', 'B', 'C', 'D', 'E']) 14 | >>> term_index = vocab.state['term_index'] 15 | >>> term_index # doctest: +SKIP 16 | {'[PAD]': 0, '[OOV]': 1, 'D': 2, 'A': 3, 'B': 4, 'C': 5, 'E': 6} 17 | >>> index_term = vocab.state['index_term'] 18 | >>> index_term # doctest: +SKIP 19 | {0: '[PAD]', 1: '[OOV]', 2: 'D', 3: 'A', 4: 'B', 5: 'C', 6: 'E'} 20 | 21 | >>> term_index['out-of-vocabulary-term'] 22 | 1 23 | >>> index_term[0] 24 | '[PAD]' 25 | >>> index_term[42] 26 | Traceback (most recent call last): 27 | ... 
28 | KeyError: 42
29 | >>> a_index = term_index['A']
30 | >>> c_index = term_index['C']
31 | >>> vocab.transform(['C', 'A', 'C']) == [c_index, a_index, c_index]
32 | True
33 | >>> vocab.transform(['C', 'A', '[OOV]']) == [c_index, a_index, 1]
34 | True
35 | >>> indices = vocab.transform(list('ABCDDZZZ'))
36 | >>> ' '.join(vocab.state['index_term'][i] for i in indices)
37 | 'A B C D D [OOV] [OOV] [OOV]'
38 |
39 | """
40 |
41 | def __init__(self, pad_value: str = '<PAD>', oov_value: str = '<OOV>'):
42 | """Vocabulary unit initializer."""
43 | super().__init__()
44 | self._pad = pad_value
45 | self._oov = oov_value
46 | self._context['term_index'] = self.TermIndex()
47 | self._context['index_term'] = dict()
48 |
49 | class TermIndex(dict):
50 | """Map term to index."""
51 |
52 | def __missing__(self, key):
53 | """Map out-of-vocabulary terms to index 1."""
54 | return 1
55 |
56 | def fit(self, tokens: list):
57 | """Build a :class:`TermIndex` and a :class:`IndexTerm`."""
58 | self._context['term_index'][self._pad] = 0
59 | self._context['term_index'][self._oov] = 1
60 | self._context['index_term'][0] = self._pad
61 | self._context['index_term'][1] = self._oov
62 |
63 | terms = sorted(set(tokens))
64 | for index, term in enumerate(terms):
65 | self._context['term_index'][term] = index + 2
66 | self._context['index_term'][index + 2] = term
67 |
68 | def transform(self, input_: list) -> list:
69 | """Transform a list of tokens to corresponding indices."""
70 | return [self._context['term_index'][token] for token in input_]
71 |
-------------------------------------------------------------------------------- /matchzoo/preprocessors/units/word_exact_match.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .unit import Unit
4 |
5 |
6 | class WordExactMatch(Unit):
7 | """
8 | WordExactUnit Class.
9 |
10 | Process unit to get a binary match list of two word index lists. The
11 | word index list is the word representation of a text.
12 |
13 | Examples:
14 | >>> import pandas
15 | >>> input_ = pandas.DataFrame({
16 | ... 'text_left':[[1, 2, 3],[4, 5, 7, 9]],
17 | ... 'text_right':[[5, 3, 2, 7],[2, 3, 5]]}
18 | ... )
19 | >>> left_word_exact_match = WordExactMatch(
20 | ... match='text_left', to_match='text_right'
21 | ... )
22 | >>> left_out = input_.apply(left_word_exact_match.transform, axis=1)
23 | >>> left_out[0]
24 | [0, 1, 1]
25 | >>> left_out[1]
26 | [0, 1, 0, 0]
27 | >>> right_word_exact_match = WordExactMatch(
28 | ... match='text_right', to_match='text_left'
29 | ... )
30 | >>> right_out = input_.apply(right_word_exact_match.transform, axis=1)
31 | >>> right_out[0]
32 | [0, 1, 1, 0]
33 | >>> right_out[1]
34 | [0, 0, 1]
35 |
36 | """
37 |
38 | def __init__(
39 | self,
40 | match: str,
41 | to_match: str
42 | ):
43 | """
44 | Class initialization.
45 |
46 | :param match: the 'match' column name.
47 | :param to_match: the 'to_match' column name.
48 | """
49 | self._match = match
50 | self._to_match = to_match
51 |
52 | def transform(self, input_) -> list:
53 | """
54 | Transform two word index lists into a binary match list.
55 |
56 | :param input_: a dataframe include 'match' column and
57 | 'to_match' column.
58 |
59 | :return: a binary match result list of two word index lists.
60 | """ 61 | match_binary = [] 62 | for i in range(len(input_[self._match])): 63 | if input_[self._match][i] in set(input_[self._to_match]): 64 | match_binary.append(1) 65 | else: 66 | match_binary.append(0) 67 | 68 | return match_binary 69 | -------------------------------------------------------------------------------- /matchzoo/preprocessors/units/word_hashing.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | 5 | from .unit import Unit 6 | 7 | 8 | class WordHashing(Unit): 9 | """ 10 | Word-hashing layer for DSSM-based models. 11 | 12 | The input of :class:`WordHashingUnit` should be a list of word 13 | sub-letter list extracted from one document. The output of is 14 | the word-hashing representation of this document. 15 | 16 | :class:`NgramLetterUnit` and :class:`VocabularyUnit` are two 17 | essential prerequisite of :class:`WordHashingUnit`. 18 | 19 | Examples: 20 | >>> letters = [['#te', 'tes','est', 'st#'], ['oov']] 21 | >>> word_hashing = WordHashing( 22 | ... term_index={ 23 | ... '_PAD': 0, 'OOV': 1, 'st#': 2, '#te': 3, 'est': 4, 'tes': 5 24 | ... }) 25 | >>> hashing = word_hashing.transform(letters) 26 | >>> hashing[0] 27 | [0.0, 0.0, 1.0, 1.0, 1.0, 1.0] 28 | >>> hashing[1] 29 | [0.0, 1.0, 0.0, 0.0, 0.0, 0.0] 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | term_index: dict, 36 | ): 37 | """ 38 | Class initialization. 39 | 40 | :param term_index: term-index mapping generated by 41 | :class:`VocabularyUnit`. 42 | :param dim_triletter: dimensionality of tri_leltters. 43 | """ 44 | self._term_index = term_index 45 | 46 | def transform(self, input_: list) -> list: 47 | """ 48 | Transform list of :attr:`letters` into word hashing layer. 49 | 50 | :param input_: list of `tri_letters` generated by 51 | :class:`NgramLetterUnit`. 52 | :return: Word hashing representation of `tri-letters`. 53 | """ 54 | if any([isinstance(elem, list) for elem in input_]): 55 | # The input shape for CDSSM is 56 | # [[word1 ngram, ngram], [word2, ngram, ngram], ...]. 57 | hashing = np.zeros((len(input_), len(self._term_index))) 58 | for idx, word in enumerate(input_): 59 | counted_letters = collections.Counter(word) 60 | for key, value in counted_letters.items(): 61 | letter_id = self._term_index.get(key, 1) 62 | hashing[idx, letter_id] = value 63 | else: 64 | # The input shape for DSSM model [ngram, ngram, ...]. 65 | hashing = np.zeros(len(self._term_index)) 66 | counted_letters = collections.Counter(input_) 67 | for key, value in counted_letters.items(): 68 | letter_id = self._term_index.get(key, 1) 69 | hashing[letter_id] = value 70 | 71 | return hashing.tolist() 72 | -------------------------------------------------------------------------------- /matchzoo/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import Classification 2 | from .ranking import Ranking 3 | -------------------------------------------------------------------------------- /matchzoo/tasks/classification.py: -------------------------------------------------------------------------------- 1 | """Classification task.""" 2 | 3 | from matchzoo.engine.base_task import BaseTask 4 | 5 | 6 | class Classification(BaseTask): 7 | """Classification task. 
8 |
9 | Examples:
10 | >>> classification_task = Classification(num_classes=2)
11 | >>> classification_task.metrics = ['acc']
12 | >>> classification_task.num_classes
13 | 2
14 | >>> classification_task.output_shape
15 | (2,)
16 | >>> classification_task.output_dtype
17 | <class 'int'>
18 | >>> print(classification_task)
19 | Classification Task with 2 classes
20 |
21 | """
22 |
23 | TYPE = 'classification'
24 |
25 | def __init__(self, num_classes: int = 2, **kwargs):
26 | """Classification task."""
27 | super().__init__(**kwargs)
28 | if not isinstance(num_classes, int):
29 | raise TypeError("Number of classes must be an integer.")
30 | if num_classes < 2:
31 | raise ValueError("Number of classes can't be smaller than 2")
32 | self._num_classes = num_classes
33 |
34 | @property
35 | def num_classes(self) -> int:
36 | """:return: number of classes to classify."""
37 | return self._num_classes
38 |
39 | @classmethod
40 | def list_available_losses(cls) -> list:
41 | """:return: a list of available losses."""
42 | return ['cross_entropy']
43 |
44 | @classmethod
45 | def list_available_metrics(cls) -> list:
46 | """:return: a list of available metrics."""
47 | return ['acc']
48 |
49 | @property
50 | def output_shape(self) -> tuple:
51 | """:return: output shape of a single sample of the task."""
52 | return self._num_classes,
53 |
54 | @property
55 | def output_dtype(self):
56 | """:return: target data type, expect `int` as output."""
57 | return int
58 |
59 | def __str__(self):
60 | """:return: Task name as string."""
61 | return f'Classification Task with {self._num_classes} classes'
62 |
-------------------------------------------------------------------------------- /matchzoo/tasks/ranking.py: --------------------------------------------------------------------------------
1 | """Ranking task."""
2 |
3 | from matchzoo.engine import base_task
4 |
5 |
6 | class Ranking(base_task.BaseTask):
7 | """Ranking Task.
8 | 9 | Examples: 10 | >>> ranking_task = Ranking() 11 | >>> ranking_task.metrics = ['map', 'ndcg'] 12 | >>> ranking_task.output_shape 13 | (1,) 14 | >>> ranking_task.output_dtype 15 | <class 'float'> 16 | >>> print(ranking_task) 17 | Ranking Task 18 | 19 | """ 20 | 21 | TYPE = 'ranking' 22 | 23 | @classmethod 24 | def list_available_losses(cls) -> list: 25 | """:return: a list of available losses.""" 26 | return ['mse'] 27 | 28 | @classmethod 29 | def list_available_metrics(cls) -> list: 30 | """:return: a list of available metrics.""" 31 | return ['map'] 32 | 33 | @property 34 | def output_shape(self) -> tuple: 35 | """:return: output shape of a single sample of the task.""" 36 | return 1, 37 | 38 | @property 39 | def output_dtype(self): 40 | """:return: target data type, expect `float` as output.""" 41 | return float 42 | 43 | def __str__(self): 44 | """:return: Task name as string.""" 45 | return 'Ranking Task' 46 | -------------------------------------------------------------------------------- /matchzoo/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /matchzoo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .one_hot import one_hot 2 | from .tensor_type import TensorType 3 | from .list_recursive_subclasses import list_recursive_concrete_subclasses 4 | from .parse import parse_loss, parse_activation, parse_metric, parse_optimizer 5 | from .average_meter import AverageMeter 6 | from .timer import Timer 7 | from .early_stopping import EarlyStopping 8 | from .get_file import get_file, _hash_file 9 | -------------------------------------------------------------------------------- /matchzoo/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | """Average meter.""" 2 | 3 | 4 | class AverageMeter(object): 5 | """ 6 | Computes and stores the average and current value. 7 | 8 | Examples: 9 | >>> am = AverageMeter() 10 | >>> am.update(1) 11 | >>> am.avg 12 | 1.0 13 | >>> am.update(val=2.5, n=2) 14 | >>> am.avg 15 | 2.0 16 | 17 | """ 18 | 19 | def __init__(self): 20 | """Average meter constructor.""" 21 | self.reset() 22 | 23 | def reset(self): 24 | """Reset AverageMeter.""" 25 | self._val = 0. 26 | self._avg = 0. 27 | self._sum = 0. 28 | self._count = 0. 29 | 30 | def update(self, val, n=1): 31 | """Update value.""" 32 | self._val = val 33 | self._sum += val * n 34 | self._count += n 35 | self._avg = self._sum / self._count 36 | 37 | @property 38 | def avg(self): 39 | """Get avg.""" 40 | return self._avg 41 | -------------------------------------------------------------------------------- /matchzoo/utils/early_stopping.py: -------------------------------------------------------------------------------- 1 | """Early stopping.""" 2 | 3 | import typing 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class EarlyStopping: 10 | """ 11 | EarlyStopping stops training if there is no improvement after a given patience. 12 | 13 | :param patience: Number of epochs to wait if no improvement and then 14 | stop the training. 15 | :param should_decrease: Whether a lower metric value counts as an improvement. 16 | :param key: Key of the metric to be compared.
17 | """ 18 | 19 | def __init__( 20 | self, 21 | patience: typing.Optional[int] = None, 22 | should_decrease: bool = None, 23 | key: typing.Any = None 24 | ): 25 | """Early stopping Constructor.""" 26 | self._patience = patience 27 | self._key = key 28 | self._best_so_far = 0 29 | self._epochs_with_no_improvement = 0 30 | self._is_best_so_far = False 31 | self._early_stop = False 32 | 33 | def state_dict(self) -> typing.Dict[str, typing.Any]: 34 | """A `Trainer` can use this to serialize the state.""" 35 | return { 36 | 'patience': self._patience, 37 | 'best_so_far': self._best_so_far, 38 | 'is_best_so_far': self._is_best_so_far, 39 | 'epochs_with_no_improvement': self._epochs_with_no_improvement, 40 | } 41 | 42 | def load_state_dict( 43 | self, 44 | state_dict: typing.Dict[str, typing.Any] 45 | ) -> None: 46 | """Hydrate a early stopping from a serialized state.""" 47 | self._patience = state_dict["patience"] 48 | self._is_best_so_far = state_dict["is_best_so_far"] 49 | self._best_so_far = state_dict["best_so_far"] 50 | self._epochs_with_no_improvement = \ 51 | state_dict["epochs_with_no_improvement"] 52 | 53 | def update(self, result: list): 54 | """Call function.""" 55 | score = result[self._key] 56 | if score > self._best_so_far: 57 | self._best_so_far = score 58 | self._is_best_so_far = True 59 | self._epochs_with_no_improvement = 0 60 | else: 61 | self._is_best_so_far = False 62 | self._epochs_with_no_improvement += 1 63 | 64 | @property 65 | def best_so_far(self) -> bool: 66 | """Returns best so far.""" 67 | return self._best_so_far 68 | 69 | @property 70 | def is_best_so_far(self) -> bool: 71 | """Returns true if it is the best so far.""" 72 | return self._is_best_so_far 73 | 74 | @property 75 | def should_stop_early(self) -> bool: 76 | """Returns true if improvement has stopped for long enough.""" 77 | if not self._patience: 78 | return False 79 | else: 80 | return self._epochs_with_no_improvement >= self._patience 81 | -------------------------------------------------------------------------------- /matchzoo/utils/list_recursive_subclasses.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | 4 | def list_recursive_concrete_subclasses(base): 5 | """List all concrete subclasses of `base` recursively.""" 6 | return _filter_concrete(_bfs(base)) 7 | 8 | 9 | def _filter_concrete(classes): 10 | return list(filter(lambda c: not inspect.isabstract(c), classes)) 11 | 12 | 13 | def _bfs(base): 14 | return base.__subclasses__() + sum([ 15 | _bfs(subclass) 16 | for subclass in base.__subclasses__() 17 | ], []) 18 | -------------------------------------------------------------------------------- /matchzoo/utils/one_hot.py: -------------------------------------------------------------------------------- 1 | """One hot vectors.""" 2 | import numpy as np 3 | 4 | 5 | def one_hot(indices: int, num_classes: int) -> np.ndarray: 6 | """:return: A one-hot encoded vector.""" 7 | vec = np.zeros((num_classes,), dtype=np.int64) 8 | vec[indices] = 1 9 | return vec 10 | -------------------------------------------------------------------------------- /matchzoo/utils/tensor_type.py: -------------------------------------------------------------------------------- 1 | """Define Keras tensor type.""" 2 | import typing 3 | 4 | TensorType = typing.Any 5 | -------------------------------------------------------------------------------- /matchzoo/utils/timer.py: -------------------------------------------------------------------------------- 1 | """Timer.""" 
2 | 3 | import time 4 | 5 | 6 | class Timer(object): 7 | """Computes elapsed time.""" 8 | 9 | def __init__(self): 10 | """Timer constructor.""" 11 | self.reset() 12 | 13 | def reset(self): 14 | """Reset timer.""" 15 | self.running = True 16 | self.total = 0 17 | self.start = time.time() 18 | 19 | def resume(self): 20 | """Resume the timer.""" 21 | if not self.running: 22 | self.running = True 23 | self.start = time.time() 24 | return self 25 | 26 | def stop(self): 27 | """Stop the timer.""" 28 | if self.running: 29 | self.running = False 30 | self.total += time.time() - self.start 31 | return self 32 | 33 | @property 34 | def time(self): 35 | """Return the elapsed time.""" 36 | if self.running: 37 | return self.total + time.time() - self.start 38 | return self.total 39 | -------------------------------------------------------------------------------- /matchzoo/version.py: -------------------------------------------------------------------------------- 1 | """MatchZoo version file.""" 2 | 3 | __version__ = '1.1.1' 4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | cron: marks tests as cron (deselect with '-m "not cron"') 4 | slow: marks tests as slow (deselect with '-m "not slow"') 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.2.0 2 | pytorch-transformers >= 1.1.0 3 | tabulate >= 0.8.3 4 | nltk >= 3.4.3 5 | numpy >= 1.16.4 6 | tqdm == 4.38.0 7 | dill >= 0.2.9 8 | hyperopt == 0.1.2 9 | pandas == 0.24.2 10 | networkx >= 2.3 11 | h5py >= 2.9.0 12 | coverage >= 4.5.3 13 | codecov >= 2.0.15 14 | pytest >= 4.6.3 15 | pytest-cov >= 2.7.1 16 | flake8 >= 3.7.7 17 | flake8_docstrings >= 1.3.0 18 | pydocstyle == 2.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import setup, find_packages 5 | 6 | 7 | here = os.path.abspath(os.path.dirname(__file__)) 8 | 9 | # Avoids IDE errors, but actual version is read from version.py 10 | __version__ = None 11 | exec(open('matchzoo/version.py').read()) 12 | 13 | short_description = 'Facilitating the design, comparison and sharing ' \ 14 | 'of deep text matching models.'
15 | 16 | # Get the long description from the README file 17 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 18 | long_description = f.read() 19 | 20 | install_requires = [ 21 | 'torch >= 1.2.0', 22 | 'pytorch-transformers >= 1.1.0', 23 | 'nltk >= 3.4.3', 24 | 'numpy >= 1.16.4', 25 | 'tqdm == 4.38.0', 26 | 'dill >= 0.2.9', 27 | 'pandas == 0.24.2', 28 | 'networkx >= 2.3', 29 | 'h5py >= 2.9.0', 30 | 'hyperopt == 0.1.2' 31 | ] 32 | 33 | extras_requires = { 34 | 'tests': [ 35 | 'coverage >= 4.5.3', 36 | 'codecov >= 2.0.15', 37 | 'pytest >= 4.6.3', 38 | 'pytest-cov >= 2.7.1', 39 | 'flake8 >= 3.7.7', 40 | 'flake8_docstrings >= 1.3.0'], 41 | } 42 | 43 | 44 | setup( 45 | name="matchzoo-py", 46 | version=__version__, 47 | author="MatchZoo-py Authors", 48 | author_email="fanyixing@ict.ac.cn", 49 | description=short_description, 50 | license="Apache 2.0", 51 | keywords="text matching models", 52 | url="https://github.com/NTMC-Community/MatchZoo-py", 53 | packages=find_packages(), 54 | include_package_data=True, 55 | long_description=long_description, 56 | long_description_content_type='text/markdown', 57 | classifiers=[ 58 | "Development Status :: 3 - Alpha", 59 | 'Environment :: Console', 60 | 'Operating System :: POSIX :: Linux', 61 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 62 | "License :: OSI Approved :: Apache Software License", 63 | 'Programming Language :: Python :: 3.6' 64 | ], 65 | install_requires=install_requires, 66 | extras_require=extras_requires 67 | ) 68 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo-py/0e5c04e1e948aa9277abd5c85ff99d9950d8527f/tests/__init__.py -------------------------------------------------------------------------------- /tests/data_pack/test_datapack.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | from matchzoo import DataPack, load_data_pack 7 | 8 | 9 | @pytest.fixture 10 | def data_pack(): 11 | relation = [['qid0', 'did0', 1], ['qid1', 'did1', 0]] 12 | left = [['qid0', [1, 2]], ['qid1', [2, 3]]] 13 | right = [['did0', [2, 3, 4]], ['did1', [3, 4, 5]]] 14 | relation = pd.DataFrame(relation, columns=['id_left', 'id_right', 'label']) 15 | left = pd.DataFrame(left, columns=['id_left', 'text_left']) 16 | left.set_index('id_left', inplace=True) 17 | right = pd.DataFrame(right, columns=['id_right', 'text_right']) 18 | right.set_index('id_right', inplace=True) 19 | return DataPack(relation=relation, 20 | left=left, 21 | right=right) 22 | 23 | 24 | def test_length(data_pack): 25 | num_examples = 2 26 | assert len(data_pack) == num_examples 27 | 28 | 29 | def test_getter(data_pack): 30 | assert data_pack.relation.iloc[0].values.tolist() == ['qid0', 'did0', 1] 31 | assert data_pack.relation.iloc[1].values.tolist() == ['qid1', 'did1', 0] 32 | assert data_pack.left.loc['qid0', 'text_left'] == [1, 2] 33 | assert data_pack.right.loc['did1', 'text_right'] == [3, 4, 5] 34 | 35 | 36 | def test_save_load(data_pack): 37 | dirpath = '.tmpdir' 38 | data_pack.save(dirpath) 39 | dp = load_data_pack(dirpath) 40 | assert len(data_pack) == 2 41 | assert len(dp) == 2 42 | shutil.rmtree(dirpath) 43 | -------------------------------------------------------------------------------- /tests/dataloader/test_callbacks.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matchzoo as mz 4 | from matchzoo import preprocessors 5 | from matchzoo.dataloader import callbacks 6 | from matchzoo.dataloader import Dataset, DataLoader 7 | from matchzoo.datasets import embeddings 8 | from matchzoo.embedding import load_from_file 9 | 10 | 11 | @pytest.fixture(scope='module') 12 | def train_raw(): 13 | return mz.datasets.toy.load_data('test', task='ranking')[:5] 14 | 15 | 16 | def test_basic_padding(train_raw): 17 | preprocessor = preprocessors.BasicPreprocessor() 18 | data_preprocessed = preprocessor.fit_transform(train_raw, verbose=0) 19 | dataset = Dataset(data_preprocessed, batch_size=5, mode='point') 20 | 21 | pre_fixed_padding = callbacks.BasicPadding( 22 | fixed_length_left=5, fixed_length_right=5, pad_word_mode='pre', with_ngram=False) 23 | dataloader = DataLoader(dataset, callback=pre_fixed_padding) 24 | for batch in dataloader: 25 | assert batch[0]['text_left'].shape == (5, 5) 26 | assert batch[0]['text_right'].shape == (5, 5) 27 | 28 | post_padding = callbacks.BasicPadding(pad_word_mode='post', with_ngram=False) 29 | dataloader = DataLoader(dataset, callback=post_padding) 30 | for batch in dataloader: 31 | max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) 32 | max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) 33 | assert batch[0]['text_left'].shape == (5, max_left_len) 34 | assert batch[0]['text_right'].shape == (5, max_right_len) 35 | 36 | 37 | def test_drmm_padding(train_raw): 38 | preprocessor = preprocessors.BasicPreprocessor() 39 | data_preprocessed = preprocessor.fit_transform(train_raw, verbose=0) 40 | 41 | embedding_matrix = load_from_file(embeddings.EMBED_10_GLOVE, mode='glove') 42 | term_index = preprocessor.context['vocab_unit'].state['term_index'] 43 | embedding_matrix = embedding_matrix.build_matrix(term_index) 44 | histogram_callback = callbacks.Histogram( 45 | embedding_matrix=embedding_matrix, bin_size=30, hist_mode='LCH') 46 | dataset = Dataset( 47 | data_preprocessed, mode='point', batch_size=5, callbacks=[histogram_callback]) 48 | 49 | pre_fixed_padding = callbacks.DRMMPadding( 50 | fixed_length_left=5, fixed_length_right=5, pad_mode='pre') 51 | dataloader = DataLoader(dataset, callback=pre_fixed_padding) 52 | for batch in dataloader: 53 | assert batch[0]['text_left'].shape == (5, 5) 54 | assert batch[0]['text_right'].shape == (5, 5) 55 | assert batch[0]['match_histogram'].shape == (5, 5, 30) 56 | 57 | post_padding = callbacks.DRMMPadding(pad_mode='post') 58 | dataloader = DataLoader(dataset, callback=post_padding) 59 | for batch in dataloader: 60 | max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) 61 | max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) 62 | assert batch[0]['text_left'].shape == (5, max_left_len) 63 | assert batch[0]['text_right'].shape == (5, max_right_len) 64 | assert batch[0]['match_histogram'].shape == (5, max_left_len, 30) 65 | 66 | 67 | def test_bert_padding(train_raw): 68 | preprocessor = preprocessors.BertPreprocessor() 69 | data_preprocessed = preprocessor.transform(train_raw, verbose=0) 70 | dataset = Dataset(data_preprocessed, mode='point', batch_size=5) 71 | 72 | pre_fixed_padding = callbacks.BertPadding( 73 | fixed_length_left=5, fixed_length_right=5, pad_mode='pre') 74 | dataloader = DataLoader(dataset, callback=pre_fixed_padding) 75 | for batch in dataloader: 76 | assert batch[0]['text_left'].shape == (5, 7) 77 | assert
batch[0]['text_right'].shape == (5, 6) 78 | 79 | post_padding = callbacks.BertPadding(pad_mode='post') 80 | dataloader = DataLoader(dataset, callback=post_padding) 81 | for batch in dataloader: 82 | max_left_len = max(batch[0]['length_left'].detach().cpu().numpy()) 83 | max_right_len = max(batch[0]['length_right'].detach().cpu().numpy()) 84 | assert batch[0]['text_left'].shape == (5, max_left_len + 2) 85 | assert batch[0]['text_right'].shape == (5, max_right_len + 1) 86 | -------------------------------------------------------------------------------- /tests/dataloader/test_dataset.py: -------------------------------------------------------------------------------- 1 | import matchzoo as mz 2 | from matchzoo import preprocessors 3 | from matchzoo.dataloader import Dataset 4 | 5 | 6 | def test_dataset(): 7 | data_pack = mz.datasets.toy.load_data('train', task='ranking') 8 | preprocessor = mz.preprocessors.BasicPreprocessor() 9 | data_processed = preprocessor.fit_transform(data_pack) 10 | 11 | dataset_point = mz.dataloader.Dataset( 12 | data_processed, 13 | mode='point', 14 | batch_size=1, 15 | resample=False, 16 | shuffle=True, 17 | sort=False 18 | ) 19 | dataset_point.batch_size = 10 20 | dataset_point.shuffle = not dataset_point.shuffle 21 | dataset_point.sort = not dataset_point.sort 22 | assert len(dataset_point.batch_indices) == 10 23 | 24 | dataset_pair = mz.dataloader.Dataset( 25 | data_processed, 26 | mode='pair', 27 | num_dup=1, 28 | num_neg=1, 29 | batch_size=1, 30 | resample=True, 31 | shuffle=False, 32 | sort=False 33 | ) 34 | assert len(dataset_pair) == 5 35 | dataset_pair.num_dup = dataset_pair.num_dup + 1 36 | assert len(dataset_pair) == 10 37 | dataset_pair.num_neg = dataset_pair.num_neg + 2 38 | assert len(dataset_pair) == 10 39 | dataset_pair.batch_size = dataset_pair.batch_size + 1 40 | assert len(dataset_pair) == 5 41 | dataset_pair.resample = not dataset_pair.resample 42 | assert len(dataset_pair) == 5 43 | -------------------------------------------------------------------------------- /tests/engine/test_base_preprocessor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import shutil 3 | 4 | import matchzoo as mz 5 | from matchzoo.engine.base_preprocessor import BasePreprocessor 6 | 7 | 8 | @pytest.fixture 9 | def base_preprocessor(): 10 | BasePreprocessor.__abstractmethods__ = set() 11 | base_processor = BasePreprocessor() 12 | return base_processor 13 | 14 | 15 | def test_save_load(base_preprocessor): 16 | dirpath = '.tmpdir' 17 | base_preprocessor.save(dirpath) 18 | assert mz.load_preprocessor(dirpath) 19 | shutil.rmtree(dirpath) 20 | -------------------------------------------------------------------------------- /tests/engine/test_base_task.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from matchzoo.engine.base_task import BaseTask 3 | 4 | 5 | def test_base_task_instantiation(): 6 | with pytest.raises(TypeError): 7 | BaseTask() 8 | -------------------------------------------------------------------------------- /tests/engine/test_hyper_spaces.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import hyperopt.pyll.base 3 | 4 | from matchzoo.engine import hyper_spaces 5 | 6 | 7 | @pytest.fixture(scope='module', params=[ 8 | lambda x: x + 2, 9 | lambda x: x - 2, 10 | lambda x: x * 2, 11 | lambda x: x / 2, 12 | lambda x: x // 2, 13 | lambda x: x ** 2, 14 | lambda x: 2 + x, 15 | lambda x: 2 - x, 
16 | lambda x: 2 * x, 17 | lambda x: 2 / x, 18 | lambda x: 2 // x, 19 | lambda x: 2 ** x, 20 | lambda x: -x 21 | ]) 22 | def op(request): 23 | return request.param 24 | 25 | 26 | @pytest.fixture(scope='module', params=[ 27 | hyper_spaces.choice(options=[0, 1]), 28 | hyper_spaces.uniform(low=0, high=10), 29 | hyper_spaces.quniform(low=0, high=10, q=2) 30 | ]) 31 | def proxy(request): 32 | return request.param 33 | 34 | 35 | def test_init(proxy): 36 | assert isinstance(proxy.convert('label'), hyperopt.pyll.base.Apply) 37 | 38 | 39 | def test_op(proxy, op): 40 | assert isinstance(op(proxy).convert('label'), hyperopt.pyll.base.Apply) 41 | 42 | 43 | def test_str(proxy): 44 | assert isinstance(str(proxy), str) 45 | -------------------------------------------------------------------------------- /tests/engine/test_param_table.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo.engine.param import Param 4 | from matchzoo.engine.param_table import ParamTable 5 | from matchzoo.engine.hyper_spaces import quniform 6 | 7 | 8 | @pytest.fixture 9 | def param_table(): 10 | params = ParamTable() 11 | params.add(Param('ham', 'Parma Ham')) 12 | return params 13 | 14 | 15 | def test_get(param_table): 16 | assert param_table['ham'] == 'Parma Ham' 17 | 18 | 19 | def test_set(param_table): 20 | new_param = Param('egg', 'Over Easy') 21 | param_table.set('egg', new_param) 22 | assert 'egg' in param_table.keys() 23 | 24 | 25 | def test_keys(param_table): 26 | assert 'ham' in param_table.keys() 27 | 28 | 29 | def test_hyper_space(param_table): 30 | new_param = Param( 31 | name='my_param', 32 | value=1, 33 | hyper_space=quniform(low=1, high=5) 34 | ) 35 | param_table.add(new_param) 36 | hyper_space = param_table.hyper_space 37 | assert hyper_space 38 | -------------------------------------------------------------------------------- /tests/models/test_base_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo.engine.base_model import BaseModel 4 | 5 | 6 | def test_base_model_abstract_instantiation(): 7 | with pytest.raises(TypeError): 8 | model = BaseModel(BaseModel.get_default_params()) 9 | assert model 10 | 11 | 12 | def test_base_model_concrete_instantiation(): 13 | class MyBaseModel(BaseModel): 14 | def build(self): 15 | self.a, self.b = 1, 2 16 | def forward(self): 17 | return self.a + self.b 18 | 19 | model = MyBaseModel() 20 | assert model.params 21 | model.guess_and_fill_missing_params() 22 | model.build() 23 | assert model.params.completed(exclude=['out_activation_func']) 24 | -------------------------------------------------------------------------------- /tests/models/test_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | These tests are simplified because the original version takes too much time to 3 | run, making CI fail as it reaches the time limit.
4 | """ 5 | import torch 6 | import pytest 7 | from pathlib import Path 8 | import shutil 9 | 10 | import matchzoo as mz 11 | 12 | 13 | @pytest.fixture(scope='module', params=[ 14 | mz.tasks.Ranking(losses=mz.losses.RankCrossEntropyLoss(num_neg=2)), 15 | mz.tasks.Classification(num_classes=2), 16 | ]) 17 | def task(request): 18 | return request.param 19 | 20 | 21 | @pytest.fixture(scope='module') 22 | def train_raw(task): 23 | return mz.datasets.toy.load_data('train', task)[:10] 24 | 25 | 26 | @pytest.fixture(scope='module', params=mz.models.list_available()) 27 | def model_class(request): 28 | return request.param 29 | 30 | 31 | @pytest.fixture(scope='module') 32 | def embedding(): 33 | return mz.datasets.toy.load_embedding() 34 | 35 | 36 | @pytest.fixture(scope='module') 37 | def setup(task, model_class, train_raw, embedding): 38 | return mz.auto.prepare( 39 | task=task, 40 | model_class=model_class, 41 | data_pack=train_raw, 42 | embedding=embedding 43 | ) 44 | 45 | 46 | @pytest.fixture(scope='module') 47 | def model(setup): 48 | return setup[0] 49 | 50 | 51 | @pytest.fixture(scope='module') 52 | def preprocessor(setup): 53 | return setup[1] 54 | 55 | 56 | @pytest.fixture(scope='module') 57 | def dataset_builder(setup): 58 | return setup[2] 59 | 60 | 61 | @pytest.fixture(scope='module') 62 | def dataloader_builder(setup): 63 | return setup[3] 64 | 65 | 66 | @pytest.fixture(scope='module') 67 | def dataloader(train_raw, preprocessor, dataset_builder, dataloader_builder): 68 | return dataloader_builder.build( 69 | dataset_builder.build(preprocessor.transform(train_raw))) 70 | 71 | 72 | @pytest.fixture(scope='module') 73 | def optimizer(model): 74 | return torch.optim.Adam(model.parameters()) 75 | 76 | 77 | @pytest.fixture(scope='module') 78 | def save_dir(): 79 | return Path('.matchzoo_test_save_load_tmpdir') 80 | 81 | 82 | @pytest.mark.slow 83 | def test_model_fit_eval_predict(model, optimizer, dataloader, save_dir): 84 | trainer = mz.trainers.Trainer( 85 | model=model, 86 | optimizer=optimizer, 87 | trainloader=dataloader, 88 | validloader=dataloader, 89 | epochs=2, 90 | save_dir=save_dir, 91 | verbose=0 92 | ) 93 | trainer.run() 94 | 95 | if save_dir.exists(): 96 | shutil.rmtree(save_dir) 97 | -------------------------------------------------------------------------------- /tests/modules/test_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from matchzoo.modules import Matching 5 | 6 | 7 | def test_matching(): 8 | x = torch.randn(2, 3, 2) 9 | y = torch.randn(2, 4, 2) 10 | z = torch.randn(2, 3, 3) 11 | for matching_type in ['dot', 'mul', 'plus', 'minus', 'concat']: 12 | Matching(matching_type=matching_type)(x, y) 13 | with pytest.raises(ValueError): 14 | Matching(matching_type='error') 15 | with pytest.raises(RuntimeError): 16 | Matching()(x, z) 17 | -------------------------------------------------------------------------------- /tests/tasks/test_tasks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo import tasks 4 | 5 | 6 | @pytest.mark.parametrize("task_type", [ 7 | tasks.Ranking, tasks.Classification 8 | ]) 9 | def test_task_listings(task_type): 10 | assert task_type.list_available_losses() 11 | assert task_type.list_available_metrics() 12 | 13 | 14 | @pytest.mark.parametrize("arg", [None, -1, 0, 1]) 15 | def test_classification_instantiation_failure(arg): 16 | with pytest.raises(Exception): 17 | 
tasks.Classification(num_classes=arg) 18 | 19 | 20 | @pytest.mark.parametrize("arg", [2, 10, 2048]) 21 | def test_classification_num_classes(arg): 22 | task = tasks.Classification(num_classes=arg) 23 | assert task.num_classes == arg 24 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matchzoo as mz 4 | 5 | 6 | @pytest.mark.cron 7 | def test_load_data(): 8 | train_data = mz.datasets.wiki_qa.load_data('train', task='ranking') 9 | assert len(train_data) == 20360 10 | train_data, _ = mz.datasets.wiki_qa.load_data('train', 11 | task='classification', 12 | return_classes=True) 13 | assert len(train_data) == 20360 14 | 15 | dev_data = mz.datasets.wiki_qa.load_data('dev', task='ranking', 16 | filtered=False) 17 | assert len(dev_data) == 2733 18 | dev_data, tag = mz.datasets.wiki_qa.load_data('dev', task='classification', 19 | filtered=True, 20 | return_classes=True) 21 | assert len(dev_data) == 1126 22 | assert tag == [False, True] 23 | 24 | test_data = mz.datasets.wiki_qa.load_data('test', task='ranking', 25 | filtered=False) 26 | assert len(test_data) == 6165 27 | test_data, tag = mz.datasets.wiki_qa.load_data('test', 28 | task='classification', 29 | filtered=True, 30 | return_classes=True) 31 | assert len(test_data) == 2341 32 | assert tag == [False, True] 33 | 34 | 35 | @pytest.mark.cron 36 | def test_load_snli(): 37 | train_data, classes = mz.datasets.snli.load_data('train', 38 | 'classification', 39 | return_classes=True) 40 | num_samples = 549361 41 | assert len(train_data) == num_samples 42 | x, y = train_data.unpack() 43 | assert len(x['text_left']) == num_samples 44 | assert len(x['text_right']) == num_samples 45 | assert y.shape == (num_samples, 1) 46 | assert classes == ['entailment', 'contradiction', 'neutral'] 47 | dev_data, classes = mz.datasets.snli.load_data('dev', 'classification', 48 | return_classes=True) 49 | assert len(dev_data) == 9842 50 | assert classes == ['entailment', 'contradiction', 'neutral'] 51 | test_data, classes = mz.datasets.snli.load_data('test', 'classification', 52 | return_classes=True) 53 | assert len(test_data) == 9824 54 | assert classes == ['entailment', 'contradiction', 'neutral'] 55 | 56 | train_data = mz.datasets.snli.load_data('train', 'ranking') 57 | x, y = train_data.unpack() 58 | assert len(x['text_left']) == num_samples 59 | assert len(x['text_right']) == num_samples 60 | assert y.shape == (num_samples, 1) 61 | 62 | 63 | @pytest.mark.cron 64 | def test_load_quora_qp(): 65 | train_data = mz.datasets.quora_qp.load_data(task='classification') 66 | assert len(train_data) == 363177 67 | 68 | dev_data, tag = mz.datasets.quora_qp.load_data( 69 | 'dev', 70 | task='classification', 71 | return_classes=True) 72 | assert tag == [False, True] 73 | assert len(dev_data) == 40371 74 | x, y = dev_data.unpack() 75 | assert len(x['text_left']) == 40371 76 | assert len(x['text_right']) == 40371 77 | assert y.shape == (40371, 1) 78 | 79 | test_data = mz.datasets.quora_qp.load_data('test') 80 | assert len(test_data) == 390965 81 | 82 | dev_data = mz.datasets.quora_qp.load_data('dev', 'ranking') 83 | x, y = dev_data.unpack() 84 | assert y.shape == (40371, 1) 85 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matchzoo as mz 4 | 
5 | 6 | @pytest.fixture 7 | def term_index(): 8 | return {'G': 1, 'C': 2, 'D': 3, 'A': 4, '_PAD': 0} 9 | 10 | 11 | def test_embedding(term_index): 12 | embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_RANK) 13 | matrix = embed.build_matrix(term_index) 14 | assert matrix.shape == (len(term_index), 50) 15 | embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_10_GLOVE, 16 | mode='glove') 17 | matrix = embed.build_matrix(term_index) 18 | assert matrix.shape == (len(term_index), 10) 19 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from matchzoo import losses 5 | 6 | 7 | def test_hinge_loss(): 8 | true_value = torch.Tensor([[1.2], [1], [1], [1]]) 9 | pred_value = torch.Tensor([[1.2], [0.1], [0], [-0.3]]) 10 | expected_loss = torch.Tensor([(0 + 1 - 0.3 + 0) / 2.0]) 11 | loss = losses.RankHingeLoss()(pred_value, true_value) 12 | assert torch.isclose(expected_loss, loss) 13 | expected_loss = torch.Tensor( 14 | [(2 + 0.1 - 1.2 + 2 - 0.3 + 0) / 2.0]) 15 | loss = losses.RankHingeLoss(margin=2)(pred_value, true_value) 16 | assert torch.isclose(expected_loss, loss) 17 | true_value = torch.Tensor( 18 | [[1.2], [1], [0.8], [1], [1], [0.8]]) 19 | pred_value = torch.Tensor( 20 | [[1.2], [0.1], [-0.5], [0], [0], [-0.3]]) 21 | expected_loss = torch.Tensor( 22 | [(0 + 1 - 0.15) / 2.0]) 23 | loss = losses.RankHingeLoss(num_neg=2, margin=1)( 24 | pred_value, true_value) 25 | assert torch.isclose(expected_loss, loss) 26 | 27 | 28 | def test_rank_crossentropy_loss(): 29 | losses.neg_num = 1 30 | 31 | def softmax(x): 32 | return np.exp(x) / np.sum(np.exp(x), axis=0) 33 | 34 | true_value = torch.Tensor([[1.], [0.], [0.], [1.]]) 35 | pred_value = torch.Tensor([[0.8], [0.1], [0.8], [0.1]]) 36 | expected_loss = torch.Tensor( 37 | [(-np.log(softmax([0.8, 0.1])[0]) - np.log( 38 | softmax([0.8, 0.1])[1])) / 2]) 39 | loss = losses.RankCrossEntropyLoss()(pred_value, true_value) 40 | assert torch.isclose(expected_loss, loss) 41 | true_value = torch.Tensor([[1.], [0.], [0.], [0.], [1.], [0.]]) 42 | pred_value = torch.Tensor([[0.8], [0.1], [0.1], [0.8], [0.1], [0.1]]) 43 | expected_loss = torch.Tensor( 44 | [(-np.log(softmax([0.8, 0.1, 0.1])[0]) - np.log( 45 | softmax([0.8, 0.1, 0.1])[1])) / 2]) 46 | loss = losses.RankCrossEntropyLoss(num_neg=2)( 47 | pred_value, true_value) 48 | assert torch.isclose(expected_loss, loss) 49 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from matchzoo.engine.base_metric import sort_and_couple 4 | from matchzoo import metrics 5 | 6 | 7 | def test_sort_and_couple(): 8 | l = [0, 1, 2] 9 | s = [0.1, 0.4, 0.2] 10 | c = sort_and_couple(l, s) 11 | assert (c == np.array([(1, 0.4), (2, 0.2), (0, 0.1)])).all() 12 | 13 | 14 | def test_mean_reciprocal_rank(): 15 | label = [0, 1, 2] 16 | score = [0.1, 0.4, 0.2] 17 | assert metrics.MeanReciprocalRank()(label, score) == 1 18 | 19 | 20 | def test_precision_at_k(): 21 | label = [0, 1, 2] 22 | score = [0.1, 0.4, 0.2] 23 | assert metrics.Precision(k=1)(label, score) == 1. 24 | assert metrics.Precision(k=2)(label, score) == 1. 
25 | assert round(metrics.Precision(k=3)(label, score), 2) == 0.67 26 | 27 | 28 | def test_average_precision(): 29 | label = [0, 1, 2] 30 | score = [0.1, 0.4, 0.2] 31 | assert round(metrics.AveragePrecision()(label, score), 2) == 0.89 32 | 33 | 34 | def test_mean_average_precision(): 35 | label = [0, 1, 2] 36 | score = [0.1, 0.4, 0.2] 37 | assert metrics.MeanAveragePrecision()(label, score) == 1. 38 | 39 | 40 | def test_dcg_at_k(): 41 | label = [0, 1, 2] 42 | score = [0.1, 0.4, 0.2] 43 | dcg = metrics.DiscountedCumulativeGain 44 | assert round(dcg(k=1)(label, score), 2) == 1.44 45 | assert round(dcg(k=2)(label, score), 2) == 4.17 46 | assert round(dcg(k=3)(label, score), 2) == 4.17 47 | 48 | 49 | def test_ndcg_at_k(): 50 | label = [0, 1, 2] 51 | score = [0.1, 0.4, 0.2] 52 | ndcg = metrics.NormalizedDiscountedCumulativeGain 53 | assert round(ndcg(k=1)(label, score), 2) == 0.33 54 | assert round(ndcg(k=2)(label, score), 2) == 0.80 55 | assert round(ndcg(k=3)(label, score), 2) == 0.80 56 | 57 | 58 | def test_accuracy(): 59 | label = np.array([1]) 60 | score = np.array([[0, 1]]) 61 | assert metrics.Accuracy()(label, score) == 1 62 | 63 | 64 | def test_cross_entropy(): 65 | label = [0, 1] 66 | score = [[0.25, 0.25], [0.01, 0.90]] 67 | assert round(metrics.CrossEntropy()(label, score), 2) == 0.75 68 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import matchzoo 6 | from matchzoo import utils 7 | from matchzoo.engine.base_model import BaseModel 8 | 9 | 10 | def test_timer(): 11 | timer = utils.Timer() 12 | start = timer.time 13 | timer.stop() 14 | assert timer.time 15 | timer.resume() 16 | assert timer.time > start 17 | 18 | 19 | def test_list_recursive_subclasses(): 20 | assert utils.list_recursive_concrete_subclasses( 21 | BaseModel 22 | ) 23 | 24 | 25 | def test_average_meter(): 26 | am = utils.AverageMeter() 27 | am.update(1) 28 | assert am.avg == 1.0 29 | am.update(val=2.5, n=2) 30 | assert am.avg == 2.0 31 | 32 | 33 | def test_early_stopping(): 34 | es = utils.EarlyStopping( 35 | patience=1, 36 | key='key', 37 | ) 38 | result = {'key': 1.0} 39 | es.update(result) 40 | assert es.should_stop_early is False 41 | es.update(result) 42 | assert es.should_stop_early is True 43 | state = es.state_dict() 44 | new_es = utils.EarlyStopping() 45 | assert new_es.should_stop_early is False 46 | new_es.load_state_dict(state) 47 | assert new_es.best_so_far == 1.0 48 | assert new_es.is_best_so_far is False 49 | assert new_es.should_stop_early is True 50 | 51 | 52 | def test_get_file(): 53 | _url = "https://raw.githubusercontent.com/NTMC-Community/" \ 54 | "MatchZoo-py/master/LICENSE" 55 | file_path = utils.get_file( 56 | 'LICENSE', _url, extract=True, 57 | cache_dir=matchzoo.USER_DATA_DIR, 58 | cache_subdir='LICENSE', 59 | verbose=1 60 | ) 61 | num_lines = 203 62 | assert len(open(file_path, 'rb').readlines()) == num_lines 63 | file_hash = utils._hash_file(file_path, algorithm='md5') 64 | 65 | file_path2 = utils.get_file( 66 | 'LICENSE', _url, extract=False, 67 | md5_hash=file_hash, 68 | cache_dir=matchzoo.USER_DATA_DIR, 69 | cache_subdir='LICENSE', 70 | verbose=1 71 | ) 72 | file_hash2 = utils._hash_file(file_path2, algorithm='md5') 73 | assert file_hash == file_hash2 74 | 75 | file_dir = matchzoo.USER_DATA_DIR.joinpath('LICENSE') 76 | if os.path.exists(file_dir): 77 | shutil.rmtree(file_dir) 78 | 
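For orientation, here is how the three utilities exercised above fit together in a training loop. This is a minimal illustrative sketch rather than repository code: the `evaluate_epoch` helper and its 'map' scores are hypothetical stand-ins for a real validation routine, while the `Timer`, `AverageMeter`, and `EarlyStopping` calls follow the APIs defined in matchzoo/utils/ earlier in this listing.

from matchzoo import utils


def evaluate_epoch(epoch):
    """Hypothetical validation step returning a metric dict."""
    return {'map': min(1.0, 0.3 + 0.1 * epoch)}


timer = utils.Timer()
loss_meter = utils.AverageMeter()
stopper = utils.EarlyStopping(patience=2, key='map')

for epoch in range(20):
    loss_meter.update(1.0 / (epoch + 1))   # stand-in for a per-batch loss
    stopper.update(evaluate_epoch(epoch))  # tracks the best 'map' so far
    if stopper.should_stop_early:          # no improvement for `patience` epochs
        break

print(f'avg loss: {loss_meter.avg:.4f}, elapsed: {timer.time:.2f}s')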
-------------------------------------------------------------------------------- /tests/trainer/test_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from pathlib import Path 4 | import shutil 5 | 6 | import matchzoo as mz 7 | 8 | 9 | @pytest.fixture(scope='module') 10 | def task(): 11 | return mz.tasks.Ranking(losses=mz.losses.RankCrossEntropyLoss()) 12 | 13 | 14 | @pytest.fixture(scope='module') 15 | def train_raw(task): 16 | return mz.datasets.toy.load_data('train', task)[:10] 17 | 18 | 19 | @pytest.fixture(scope='module') 20 | def model_class(): 21 | return mz.models.DenseBaseline 22 | 23 | 24 | @pytest.fixture(scope='module') 25 | def embedding(): 26 | return mz.datasets.toy.load_embedding() 27 | 28 | 29 | @pytest.fixture(scope='module') 30 | def setup(task, model_class, train_raw, embedding): 31 | return mz.auto.prepare( 32 | task=task, 33 | model_class=model_class, 34 | data_pack=train_raw, 35 | embedding=embedding 36 | ) 37 | 38 | 39 | @pytest.fixture(scope='module') 40 | def model(setup): 41 | return setup[0] 42 | 43 | 44 | @pytest.fixture(scope='module') 45 | def preprocessor(setup): 46 | return setup[1] 47 | 48 | 49 | @pytest.fixture(scope='module') 50 | def dataset_builder(setup): 51 | return setup[2] 52 | 53 | 54 | @pytest.fixture(scope='module') 55 | def dataloader_builder(setup): 56 | return setup[3] 57 | 58 | 59 | @pytest.fixture(scope='module') 60 | def dataloader(train_raw, preprocessor, dataset_builder, dataloader_builder): 61 | return dataloader_builder.build( 62 | dataset_builder.build(preprocessor.transform(train_raw))) 63 | 64 | 65 | @pytest.fixture(scope='module') 66 | def optimizer(model): 67 | return torch.optim.Adam(model.parameters()) 68 | 69 | 70 | @pytest.fixture(scope='module') 71 | def scheduler(optimizer): 72 | return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) 73 | 74 | 75 | @pytest.fixture(scope='module') 76 | def save_dir(): 77 | return Path('.matchzoo_test_save_load_tmpdir') 78 | 79 | 80 | @pytest.fixture(scope='module') 81 | def trainer( 82 | model, optimizer, dataloader, scheduler, save_dir 83 | ): 84 | return mz.trainers.Trainer( 85 | model=model, 86 | optimizer=optimizer, 87 | trainloader=dataloader, 88 | validloader=dataloader, 89 | epochs=4, 90 | validate_interval=2, 91 | patience=1, 92 | scheduler=scheduler, 93 | clip_norm=10, 94 | save_dir=save_dir, 95 | save_all=True, 96 | verbose=1, 97 | ) 98 | 99 | 100 | @pytest.mark.slow 101 | def test_trainer(trainer, dataloader, save_dir): 102 | trainer.run() 103 | assert trainer.evaluate(dataloader) 104 | assert trainer.predict(dataloader) is not None 105 | 106 | # Save model 107 | model_checkpoint = save_dir.joinpath('model.pt') 108 | trainer.save_model() 109 | trainer.restore_model(model_checkpoint) 110 | 111 | # Save trainer 112 | trainer_checkpoint = save_dir.joinpath('trainer.pt') 113 | trainer.save() 114 | trainer.restore(trainer_checkpoint) 115 | 116 | if save_dir.exists(): 117 | shutil.rmtree(save_dir) 118 | -------------------------------------------------------------------------------- /tutorials/classification/init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "matchzoo version 1.0\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import torch\n", 18 | "import numpy as
np\n", 19 | "import pandas as pd\n", 20 | "import matchzoo as mz\n", 21 | "print('matchzoo version', mz.__version__)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "`classification_task` initialized with metrics [accuracy]\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "classification_task = mz.tasks.Classification(num_classes=2)\n", 39 | "classification_task.metrics = ['acc']\n", 40 | "print(\"`classification_task` initialized with metrics\", classification_task.metrics)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "data loading ...\n", 53 | "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "print('data loading ...')\n", 59 | "train_pack_raw = mz.datasets.wiki_qa.load_data('train', task=classification_task)\n", 60 | "dev_pack_raw = mz.datasets.wiki_qa.load_data('dev', task=classification_task)\n", 61 | "test_pack_raw = mz.datasets.wiki_qa.load_data('test', task=classification_task)\n", 62 | "print('data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`')" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "Python 3", 69 | "language": "python", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.6.8" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 2 87 | } 88 | -------------------------------------------------------------------------------- /tutorials/ranking/init.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-03-20T09:24:32.779551Z", 9 | "start_time": "2019-03-20T09:24:30.316404Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stdout", 15 | "output_type": "stream", 16 | "text": [ 17 | "matchzoo version 1.0\n" 18 | ] 19 | } 20 | ], 21 | "source": [ 22 | "import torch\n", 23 | "import numpy as np\n", 24 | "import pandas as pd\n", 25 | "import matchzoo as mz\n", 26 | "print('matchzoo version', mz.__version__)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "ExecuteTime": { 34 | "end_time": "2019-03-20T09:24:33.370082Z", 35 | "start_time": "2019-03-20T09:24:33.365067Z" 36 | } 37 | }, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "ranking_task = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss())\n", 49 | "ranking_task.metrics = [\n", 50 | " mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n", 51 | " mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n", 52 | " mz.metrics.MeanAveragePrecision()\n", 53 | "]\n", 54 | "print(\"`ranking_task` initialized with metrics\", ranking_task.metrics)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | 
"execution_count": 3, 60 | "metadata": { 61 | "ExecuteTime": { 62 | "end_time": "2019-03-20T09:24:33.363273Z", 63 | "start_time": "2019-03-20T09:24:32.781793Z" 64 | } 65 | }, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "data loading ...\n", 72 | "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "print('data loading ...')\n", 78 | "train_pack_raw = mz.datasets.wiki_qa.load_data('train', task=ranking_task)\n", 79 | "dev_pack_raw = mz.datasets.wiki_qa.load_data('dev', task=ranking_task, filtered=True)\n", 80 | "test_pack_raw = mz.datasets.wiki_qa.load_data('test', task=ranking_task, filtered=True)\n", 81 | "print('data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`')" 82 | ] 83 | } 84 | ], 85 | "metadata": { 86 | "hide_input": false, 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.6.8" 103 | }, 104 | "toc": { 105 | "nav_menu": {}, 106 | "number_sections": true, 107 | "sideBar": true, 108 | "skip_h1_title": false, 109 | "toc_cell": false, 110 | "toc_position": {}, 111 | "toc_section_display": "block", 112 | "toc_window_display": false 113 | }, 114 | "varInspector": { 115 | "cols": { 116 | "lenName": 16, 117 | "lenType": 16, 118 | "lenVar": 40 119 | }, 120 | "kernels_config": { 121 | "python": { 122 | "delete_cmd_postfix": "", 123 | "delete_cmd_prefix": "del ", 124 | "library": "var_list.py", 125 | "varRefreshCmd": "print(var_dic_list())" 126 | }, 127 | "r": { 128 | "delete_cmd_postfix": ") ", 129 | "delete_cmd_prefix": "rm(", 130 | "library": "var_list.r", 131 | "varRefreshCmd": "cat(var_dic_list()) " 132 | } 133 | }, 134 | "types_to_exclude": [ 135 | "module", 136 | "function", 137 | "builtin_function_or_method", 138 | "instance", 139 | "_Feature" 140 | ], 141 | "window_display": false 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | --------------------------------------------------------------------------------