├── .coveragerc ├── .flake8 ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── usage-question.md ├── .gitignore ├── .travis.yml ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── artworks ├── matchzoo-logo.png ├── matchzoo_github_qr.png └── matchzoo_github_qr_black.png ├── docs ├── DOCCHECK.md ├── Makefile ├── Readme.md ├── _build │ ├── doctrees │ │ ├── environment.pickle │ │ ├── index.doctree │ │ ├── matchzoo.doctree │ │ ├── matchzoo.engine.doctree │ │ ├── matchzoo.models.doctree │ │ ├── matchzoo.tasks.doctree │ │ └── modules.doctree │ └── html │ │ ├── .buildinfo │ │ ├── _images │ │ └── matchzoo-logo.png │ │ ├── _sources │ │ ├── index.rst.txt │ │ ├── matchzoo.engine.rst.txt │ │ ├── matchzoo.models.rst.txt │ │ ├── matchzoo.rst.txt │ │ ├── matchzoo.tasks.rst.txt │ │ └── modules.rst.txt │ │ ├── _static │ │ ├── ajax-loader.gif │ │ ├── basic.css │ │ ├── comment-bright.png │ │ ├── comment-close.png │ │ ├── comment.png │ │ ├── css │ │ │ ├── badge_only.css │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── down-pressed.png │ │ ├── down.png │ │ ├── file.png │ │ ├── fonts │ │ │ ├── Lato │ │ │ │ ├── lato-bold.eot │ │ │ │ ├── lato-bold.ttf │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ ├── lato-bolditalic.woff2 │ │ │ │ ├── lato-italic.eot │ │ │ │ ├── lato-italic.ttf │ │ │ │ ├── lato-italic.woff │ │ │ │ ├── lato-italic.woff2 │ │ │ │ ├── lato-regular.eot │ │ │ │ ├── lato-regular.ttf │ │ │ │ ├── lato-regular.woff │ │ │ │ └── lato-regular.woff2 │ │ │ ├── RobotoSlab │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ └── 
roboto-slab-v7-regular.woff2 │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.svg │ │ │ ├── fontawesome-webfont.ttf │ │ │ ├── fontawesome-webfont.woff │ │ │ └── fontawesome-webfont.woff2 │ │ ├── jquery-3.2.1.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── modernizr.min.js │ │ │ └── theme.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── underscore-1.3.1.js │ │ ├── underscore.js │ │ ├── up-pressed.png │ │ ├── up.png │ │ └── websupport.js │ │ ├── genindex.html │ │ ├── index.html │ │ ├── matchzoo.engine.html │ │ ├── matchzoo.html │ │ ├── matchzoo.models.html │ │ ├── matchzoo.tasks.html │ │ ├── modules.html │ │ ├── objects.inv │ │ ├── py-modindex.html │ │ ├── search.html │ │ └── searchindex.js ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── index.rst │ ├── matchzoo.auto.preparer.rst │ ├── matchzoo.auto.rst │ ├── matchzoo.auto.tuner.callbacks.rst │ ├── matchzoo.auto.tuner.rst │ ├── matchzoo.contrib.layers.rst │ ├── matchzoo.contrib.models.rst │ ├── matchzoo.contrib.rst │ ├── matchzoo.data_generator.callbacks.rst │ ├── matchzoo.data_generator.rst │ ├── matchzoo.data_pack.rst │ ├── matchzoo.datasets.embeddings.rst │ ├── matchzoo.datasets.quora_qp.rst │ ├── matchzoo.datasets.rst │ ├── matchzoo.datasets.snli.rst │ ├── matchzoo.datasets.toy.rst │ ├── matchzoo.datasets.wiki_qa.rst │ ├── matchzoo.embedding.rst │ ├── matchzoo.engine.rst │ ├── matchzoo.layers.rst │ ├── matchzoo.losses.rst │ ├── matchzoo.metrics.rst │ ├── matchzoo.models.rst │ ├── matchzoo.preprocessors.rst │ ├── matchzoo.preprocessors.units.rst │ ├── matchzoo.processor_units.rst │ ├── matchzoo.rst │ ├── matchzoo.tasks.rst │ ├── matchzoo.utils.rst │ ├── model_reference.rst │ └── modules.rst ├── matchzoo ├── __init__.py ├── auto │ ├── __init__.py │ ├── preparer │ │ ├── __init__.py │ │ ├── prepare.py │ │ └── preparer.py │ └── tuner │ │ ├── __init__.py │ │ ├── callbacks │ │ ├── __init__.py │ │ ├── callback.py │ │ ├── lambda_callback.py │ │ ├── 
load_embedding_matrix.py │ │ └── save_model.py │ │ ├── tune.py │ │ └── tuner.py ├── contrib │ ├── README.md │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention_layer.py │ │ ├── decaying_dropout_layer.py │ │ ├── matching_tensor_layer.py │ │ ├── multi_perspective_layer.py │ │ ├── semantic_composite_layer.py │ │ └── spatial_gru.py │ ├── legacy_data_generator.py │ └── models │ │ ├── __init__.py │ │ ├── bimpm.py │ │ ├── diin.py │ │ ├── esim.py │ │ ├── hbmp.py │ │ ├── match_lstm.py │ │ └── match_srnn.py ├── data_generator │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── callback.py │ │ ├── dynamic_pooling.py │ │ ├── histogram.py │ │ └── lambda_callback.py │ ├── data_generator.py │ └── data_generator_builder.py ├── data_pack │ ├── __init__.py │ ├── data_pack.py │ └── pack.py ├── datasets │ ├── __init__.py │ ├── bert_resources │ │ └── uncased_vocab_100.txt │ ├── cqa_ql_16 │ │ ├── __init__.py │ │ └── load_data.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── embed_10_glove.txt │ │ ├── embed_10_word2vec.txt │ │ ├── embed_err.txt.gb2312 │ │ ├── embed_rank.txt │ │ ├── embed_word.txt │ │ └── load_glove_embedding.py │ ├── quora_qp │ │ ├── __init__.py │ │ └── load_data.py │ ├── snli │ │ ├── __init__.py │ │ └── load_data.py │ ├── toy │ │ ├── __init__.py │ │ ├── dev.csv │ │ ├── embedding.2d.txt │ │ ├── test.csv │ │ └── train.csv │ └── wiki_qa │ │ ├── __init__.py │ │ └── load_data.py ├── embedding │ ├── __init__.py │ └── embedding.py ├── engine │ ├── __init__.py │ ├── base_metric.py │ ├── base_model.py │ ├── base_preprocessor.py │ ├── base_task.py │ ├── callbacks.py │ ├── hyper_spaces.py │ ├── param.py │ ├── param_table.py │ └── parse_metric.py ├── layers │ ├── __init__.py │ ├── dynamic_pooling_layer.py │ └── matching_layer.py ├── losses │ ├── __init__.py │ ├── rank_cross_entropy_loss.py │ └── rank_hinge_loss.py ├── metrics │ ├── __init__.py │ ├── average_precision.py │ ├── discounted_cumulative_gain.py │ ├── mean_average_precision.py │ ├── 
mean_reciprocal_rank.py │ ├── normalized_discounted_cumulative_gain.py │ └── precision.py ├── models │ ├── README.rst │ ├── __init__.py │ ├── anmm.py │ ├── arci.py │ ├── arcii.py │ ├── cdssm.py │ ├── conv_knrm.py │ ├── dense_baseline.py │ ├── drmm.py │ ├── drmmtks.py │ ├── dssm.py │ ├── duet.py │ ├── knrm.py │ ├── match_pyramid.py │ ├── mvlstm.py │ ├── naive.py │ └── parameter_readme_generator.py ├── preprocessors │ ├── __init__.py │ ├── basic_preprocessor.py │ ├── bert_preprocessor.py │ ├── build_unit_from_data_pack.py │ ├── build_vocab_unit.py │ ├── cdssm_preprocessor.py │ ├── chain_transform.py │ ├── diin_preprocessor.py │ ├── dssm_preprocessor.py │ ├── naive_preprocessor.py │ └── units │ │ ├── __init__.py │ │ ├── bert_clean.py │ │ ├── character_index.py │ │ ├── digit_removal.py │ │ ├── fixed_length.py │ │ ├── frequency_filter.py │ │ ├── lemmatization.py │ │ ├── lowercase.py │ │ ├── matching_histogram.py │ │ ├── ngram_letter.py │ │ ├── punc_removal.py │ │ ├── stateful_unit.py │ │ ├── stemming.py │ │ ├── stop_removal.py │ │ ├── tokenize.py │ │ ├── unit.py │ │ ├── vocabulary.py │ │ ├── word_exact_match.py │ │ └── word_hashing.py ├── tasks │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── utils │ ├── __init__.py │ ├── bert_utils.py │ ├── list_recursive_subclasses.py │ ├── make_keras_optimizer_picklable.py │ ├── one_hot.py │ └── tensor_type.py └── version.py ├── readthedocs.yml ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── inte_test │ └── __init__.py └── unit_test │ ├── __init__.py │ ├── data_pack │ └── test_datapack.py │ ├── engine │ ├── test_base_preprocessor.py │ ├── test_base_task.py │ ├── test_hyper_spaces.py │ └── test_param_table.py │ ├── models │ ├── test_base_model.py │ └── test_models.py │ ├── processor_units │ └── test_processor_units.py │ ├── tasks │ └── test_tasks.py │ ├── test_data_generator.py │ ├── test_datasets.py │ ├── test_embedding.py │ ├── test_layers.py │ ├── test_losses.py │ ├── test_metrics.py │ ├── 
test_tuner.py │ └── test_utils.py └── tutorials ├── data_handling.ipynb ├── model_tuning.ipynb ├── models.ipynb ├── quick_start.ipynb ├── quick_start_chart.png ├── quora └── esim.ipynb └── wikiqa ├── README.rst ├── arci.ipynb ├── arcii.ipynb ├── cdssm.ipynb ├── conv_knrm.ipynb ├── drmm.ipynb ├── drmmtks.ipynb ├── dssm.ipynb ├── duet.ipynb ├── esim.ipynb ├── init.ipynb ├── knrm.ipynb ├── match_lstm.ipynb ├── matchpyramid.ipynb └── mvlstm.ipynb /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [report] 3 | # regexes for lines to exclude from consideration 4 | exclude_lines = 5 | if __name__ == .__main__.: 6 | ValueError 7 | TypeError 8 | NotImplementedError 9 | omit = 10 | matchzoo/__init__.py 11 | matchzoo/version.py 12 | matchzoo/models/parameter_readme_generator.py 13 | matchzoo/*/__init__.py 14 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # D401 First line should be in imperative mood 4 | D401, 5 | # D202 No blank lines allowed after function docstring 6 | D202, 7 | 8 | # For doctests: 9 | # D207 Docstring is under-indented 10 | D207, 11 | # D301 Use r""" if any backslashes in a docstring 12 | D301, 13 | # F401 'blah blah' imported but unused 14 | F401, 15 | 16 | # D100 Missing docstring in public module 17 | D100, -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/* linguist-vendored -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '' 5 | 
labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Describe the bug 11 | Please provide a clear and concise description of what the bug is. If applicable, add screenshots to help explain your problem, especially for visualization related problems. 12 | 13 | ### To Reproduce 14 | Please provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) here. We hope we can simply copy/paste/run it. It is also nice to share a hosted runnable script (e.g. Google Colab), especially for hardware-related problems. 15 | 16 | ### Describe your attempts 17 | - [ ] I checked the documentation and found no answer 18 | - [ ] I checked to make sure that this is not a duplicate issue 19 | 20 | You should also provide code snippets you tried as a workaround, StackOverflow solution that you have walked through, or your best guess of the cause that you can't locate (e.g. cosmic radiation). 21 | 22 | ### Context 23 | - **OS** [e.g. Windows 10, macOS 10.14]: 24 | - **Hardware** [e.g. CPU only, GTX 1080 Ti]: 25 | 26 | In addition, figure out your MatchZoo version by running `import matchzoo; matchzoo.__version__`. If this gives you an error, then you're probably using `1.0`, and `1.0` is no longer supported. Then attach the corresponding label on the issue. 27 | 28 | ### Additional Information 29 | Other things you want the developers to know. 
30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] I checked to make sure that this is not a duplicate issue 11 | - [ ] I'm submitting the request to the correct repository (for model requests, see [here](https://github.com/NTMC-Community/awaresome-neural-models-for-semantic-match)) 12 | 13 | ### Is your feature request related to a problem? Please describe. 14 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 15 | 16 | ### Describe the solution you'd like 17 | A clear and concise description of what you want to happen. 18 | 19 | ### Describe alternatives you've considered 20 | A clear and concise description of any alternative solutions or features you've considered. 21 | 22 | ### Additional Information 23 | Other things you want the developers to know. 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/usage-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Usage Question 3 | about: Ask a question about MatchZoo usage 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Describe the Question 11 | Please provide a clear and concise description of what the question is. 12 | 13 | ### Describe your attempts 14 | - [ ] I walked through the tutorials 15 | - [ ] I checked the documentation 16 | - [ ] I checked to make sure that this is not a duplicate question 17 | 18 | You may also provide a [Minimal, Complete, and Verifiable example](https://stackoverflow.com/help/mcve) you tried as a workaround, or StackOverflow solution that you have walked through. (e.g. 
cosmic radiation). 19 | 20 | In addition, figure out your MatchZoo version by running `import matchzoo; matchzoo.__version__`. If this gives you an error, then you're probably using `1.0`, and `1.0` is no longer supported. Then attach the corresponding label on the issue. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.swp 4 | *.bak 5 | *.weights 6 | *.trec 7 | *.ranklist 8 | *.DS_Store 9 | .vscode 10 | .coverage 11 | .ipynb_checkpoints/ 12 | predict.* 13 | build/ 14 | dist/ 15 | data/ 16 | log/* 17 | .ipynb_checkpoints/ 18 | matchzoo/log/* 19 | matchzoo/querydecision/ 20 | log/* 21 | .idea/ 22 | .pytest_cache/ 23 | MatchZoo.egg-info/ 24 | notebooks/wikiqa/.ipynb_checkpoints/* 25 | .cache 26 | .tmpdir 27 | htmlcov/ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | cache: pip 4 | 5 | sudo: true 6 | 7 | env: 8 | global: 9 | - PYTHONPATH=$PYTHONPATH:$TRAVIS_BUILD_DIR/tests:$TRAVIS_BUILD_DIR/matchzoo 10 | 11 | matrix: 12 | allow_failures: 13 | - os: osx 14 | include: 15 | - os: linux 16 | dist: trusty 17 | python: 3.6 18 | - os: osx 19 | osx_image: xcode10.2 20 | language: shell 21 | 22 | install: 23 | - pip3 install -U pip 24 | - pip3 install -r requirements.txt 25 | - python3 -m nltk.downloader punkt 26 | - python3 -m nltk.downloader wordnet 27 | - python3 -m nltk.downloader stopwords 28 | 29 | script: 30 | - stty cols 80 31 | - export COLUMNS=80 32 | - if [ "$TRAVIS_EVENT_TYPE" == "pull_request" ]; then make push; fi 33 | - if [ "$TRAVIS_EVENT_TYPE" == "push" ]; then make push; fi 34 | - if [ "$TRAVIS_EVENT_TYPE" == "cron" ]; then make cron; fi 35 | 36 | 37 | after_success: 38 | - codecov 39 | 
-------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Watchers and contributors to MatchZoo repo directories/packages/files 2 | # Please see documentation of use of CODEOWNERS file at 3 | # https://help.github.com/articles/about-codeowners/ and 4 | # https://github.com/blog/2392-introducing-code-owners 5 | # 6 | # Anybody can add themselves or a team as additional watcher or contributor 7 | # to get notified about changes in a specific package. 8 | # See https://help.github.com/articles/about-teams for how to set up teams. 9 | 10 | # Define individuals or teams that are responsible for code in a repository. 11 | 12 | # global owner. 13 | * @faneshion 14 | 15 | # third-party & project configuration 16 | .coveragerc @bwanglzu 17 | .gitignore @faneshion 18 | .travis.yml @bwanglzu 19 | CONTRIBUTING.md @bwanglzu 20 | Makefile @uduse @bwanglzu 21 | README.md @faneshion @pl8787 22 | readthedocs.yml @wqh17101 @bwanglzu 23 | requirements.txt @faneshion @pl8787 24 | setup.py @faneshion @pl8787 25 | 26 | # artworks 27 | /artworks/ @faneshion 28 | 29 | # tutorials 30 | /tutorials/ @uduse @faneshion 31 | 32 | # docs 33 | /docs/ @wqh17101 @bwanglzu 34 | 35 | # tests 36 | /tests/ @faneshion @uduse @bwanglzu 37 | 38 | # matchzoo 39 | /matchzoo/engine/ @faneshion @bwanglzu @uduse @pl8787 40 | /matchzoo/auto/ @uduse @bwanglzu 41 | /matchzoo/models/ @faneshion @pl8787 @bwanglzu @uduse 42 | /matchzoo/preprocessors/ @uduse @faneshion @pl8787 43 | /matchzoo/tasks/ @uduse @bwanglzu 44 | /matchzoo/data_generator/ @faneshion @uduse @pl8787 45 | /matchzoo/data_pack/ @faneshion @uduse 46 | /matchzoo/metrics/ @faneshion @pl8787 @uduse 47 | /matchzoo/losses/ @faneshion @pl8787 @bwanglzu 48 | /matchzoo/layers/ @uduse @yangliuy 49 | /matchzoo/* @faneshion @uduse @bwanglzu 50 | -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | # Usages: 2 | # 3 | # to install matchzoo dependencies: 4 | # $ make init 5 | # 6 | # to run all matchzoo tests, recommended for big PRs and new versions: 7 | # $ make test 8 | # 9 | # there are three kinds of tests: 10 | # 11 | # 1. "quick" tests 12 | # - run in seconds 13 | # - include all unit tests without marks and all doctests 14 | # - for rapid prototyping 15 | # - CI run this for all PRs 16 | # 17 | # 2. "slow" tests 18 | # - run in minutes 19 | # - include all unit tests marked "slow" 20 | # - CI run this for all PRs 21 | # 22 | # 3. "cron" tests 23 | # - run in minutes 24 | # - involves non-deterministic behaviors (e.g. network connection) 25 | # - include all unit tests marked "cron" 26 | # - CI run this on a daily basis 27 | # 28 | # to run quick tests, excluding time consuming tests and crons: 29 | # $ make quick 30 | # 31 | # to run slow tests, excluding normal tests and crons: 32 | # $ make slow 33 | # 34 | # to run crons: 35 | # $ make cron 36 | # 37 | # to run all tests: 38 | # $ make test 39 | # 40 | # to run CI push/PR tests: 41 | # $ make push 42 | # 43 | # to run docstring style check: 44 | # $ make flake 45 | 46 | init: 47 | pip install -r requirements.txt 48 | 49 | TEST_ARGS = -v --full-trace -l --doctest-modules --doctest-continue-on-failure --cov matchzoo/ --cov-report term-missing --cov-report html --cov-config .coveragerc matchzoo/ tests/ -W ignore::DeprecationWarning --ignore=matchzoo/contrib 50 | FLAKE_ARGS = ./matchzoo --exclude=__init__.py,matchzoo/contrib 51 | 52 | test: 53 | pytest $(TEST_ARGS) 54 | flake8 $(FLAKE_ARGS) 55 | 56 | push: 57 | pytest -m 'not cron' $(TEST_ARGS) ${ARGS} 58 | flake8 $(FLAKE_ARGS) 59 | 60 | quick: 61 | pytest -m 'not slow and not cron' $(TEST_ARGS) ${ARGS} 62 | 63 | slow: 64 | pytest -m 'slow and not cron' $(TEST_ARGS) ${ARGS} 65 | 66 | cron: 67 | pytest -m 'cron' $(TEST_ARGS) ${ARGS} 68 | 69 | flake: 70 | flake8 $(FLAKE_ARGS) 
${ARGS} 71 | -------------------------------------------------------------------------------- /artworks/matchzoo-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/artworks/matchzoo-logo.png -------------------------------------------------------------------------------- /artworks/matchzoo_github_qr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/artworks/matchzoo_github_qr.png -------------------------------------------------------------------------------- /artworks/matchzoo_github_qr_black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/artworks/matchzoo_github_qr_black.png -------------------------------------------------------------------------------- /docs/DOCCHECK.md: -------------------------------------------------------------------------------- 1 | Documentation Checking Process(Only for the developers) 2 | ========================================================== 3 | 4 | # Why 5 | 6 | It is necessary for all the developers to generate the rst files which can help us check the documents. 7 | 8 | # When 9 | 10 | 1. You add a new function to one of the scripts in the {MatchZoo/matchzoo} or its subdirs 11 | 1. You add a new script to {MatchZoo/matchzoo} or its subdirs 12 | 1. You add a new directory to {MatchZoo/matchzoo} or its subdirs 13 | 14 | # How 15 | ## Make sure you have installed sphinx 16 | 17 | 1. Enter the docs directory 18 | 19 | ``` 20 | cd {MatchZoo/docs} 21 | ``` 22 | 23 | 2. Generate the rst files 24 | 25 | ``` 26 | sphinx-apidoc -f -o source ../matchzoo 27 | ``` 28 | 29 | 3. 
Commit 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = MatchZoo 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/Readme.md: -------------------------------------------------------------------------------- 1 | ## Build Documentation: 2 | 3 | 4 | 5 | #### Install Requirements 6 | 7 | ```python 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | 12 | 13 | #### Build Documentation 14 | 15 | ```python 16 | # Enter docs folder. 17 | cd docs 18 | # Use sphinx-apidoc to generate rst. 19 | # usage: sphinx-apidoc [OPTIONS] -o OUTPUT_PATH MODULE_PATH [EXCLUDE_PATTERN,...] 
20 | sphinx-apidoc -o source/ ../matchzoo/ ../matchzoo/contrib 21 | # Generate html from rst 22 | make clean 23 | make html 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /docs/_build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/_build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/matchzoo.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/matchzoo.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/matchzoo.engine.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/matchzoo.engine.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/matchzoo.models.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/matchzoo.models.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/matchzoo.tasks.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/matchzoo.tasks.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/modules.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/doctrees/modules.doctree -------------------------------------------------------------------------------- /docs/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 14607cdb85fbc503df4ae80dc1192ccd 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/_build/html/_images/matchzoo-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_images/matchzoo-logo.png -------------------------------------------------------------------------------- /docs/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. MatchZoo documentation master file, created by 2 | sphinx-quickstart on Mon May 28 16:40:41 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to MatchZoo's documentation! 7 | ==================================== 8 | 9 | 10 | .. 
image:: https://travis-ci.org/faneshion/MatchZoo.svg?branch=master 11 | :alt: ci 12 | :target: https://travis-ci.org/faneshion/MatchZoo/ 13 | 14 | 15 | .. image:: ../../artworks/matchzoo-logo.png 16 | :alt: logo 17 | :align: center 18 | 19 | 20 | MatchZoo is a toolkit for text matching. It was developed with a focus on facilitating the designing, comparing and sharing of deep text matching models. There are a number of deep matching methods, such as DRMM, MatchPyramid, MV-LSTM, aNMM, DUET, ARC-I, ARC-II, DSSM, and CDSSM, designed with a unified interface. Potential tasks related to MatchZoo include document retrieval, question answering, conversational response ranking, paraphrase identification, etc. We are always happy to receive any code contributions, suggestions, comments from all our MatchZoo users. 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Contents: 26 | 27 | modules 28 | model_reference 29 | 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | * :ref:`search` 37 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/matchzoo.engine.rst.txt: -------------------------------------------------------------------------------- 1 | matchzoo.engine package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.engine.base\_metric module 8 | ----------------------------------- 9 | 10 | .. automodule:: matchzoo.engine.base_metric 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.engine.base\_model module 16 | ---------------------------------- 17 | 18 | .. automodule:: matchzoo.engine.base_model 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.engine.base\_preprocessor module 24 | ----------------------------------------- 25 | 26 | .. 
automodule:: matchzoo.engine.base_preprocessor 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.engine.base\_task module 32 | --------------------------------- 33 | 34 | .. automodule:: matchzoo.engine.base_task 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | matchzoo.engine.callbacks module 40 | -------------------------------- 41 | 42 | .. automodule:: matchzoo.engine.callbacks 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | matchzoo.engine.hyper\_spaces module 48 | ------------------------------------ 49 | 50 | .. automodule:: matchzoo.engine.hyper_spaces 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | matchzoo.engine.param module 56 | ---------------------------- 57 | 58 | .. automodule:: matchzoo.engine.param 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | matchzoo.engine.param\_table module 64 | ----------------------------------- 65 | 66 | .. automodule:: matchzoo.engine.param_table 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | 72 | Module contents 73 | --------------- 74 | 75 | .. automodule:: matchzoo.engine 76 | :members: 77 | :undoc-members: 78 | :show-inheritance: 79 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/matchzoo.models.rst.txt: -------------------------------------------------------------------------------- 1 | matchzoo.models package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.models.anmm module 8 | --------------------------- 9 | 10 | .. automodule:: matchzoo.models.anmm 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.models.arci module 16 | --------------------------- 17 | 18 | .. automodule:: matchzoo.models.arci 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.models.arcii module 24 | ---------------------------- 25 | 26 | .. 
automodule:: matchzoo.models.arcii 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.models.cdssm module 32 | ---------------------------- 33 | 34 | .. automodule:: matchzoo.models.cdssm 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | matchzoo.models.conv\_knrm module 40 | --------------------------------- 41 | 42 | .. automodule:: matchzoo.models.conv_knrm 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | matchzoo.models.dense\_baseline\_model module 48 | --------------------------------------------- 49 | 50 | .. automodule:: matchzoo.models.dense_baseline_model 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | matchzoo.models.drmm module 56 | --------------------------- 57 | 58 | .. automodule:: matchzoo.models.drmm 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | matchzoo.models.drmmtks module 64 | ------------------------------ 65 | 66 | .. automodule:: matchzoo.models.drmmtks 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | matchzoo.models.dssm module 72 | --------------------------- 73 | 74 | .. automodule:: matchzoo.models.dssm 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | matchzoo.models.duet module 80 | --------------------------- 81 | 82 | .. automodule:: matchzoo.models.duet 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | matchzoo.models.knrm module 88 | --------------------------- 89 | 90 | .. automodule:: matchzoo.models.knrm 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | matchzoo.models.match\_pyramid module 96 | ------------------------------------- 97 | 98 | .. automodule:: matchzoo.models.match_pyramid 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | matchzoo.models.mvlstm module 104 | ----------------------------- 105 | 106 | .. 
automodule:: matchzoo.models.mvlstm 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | matchzoo.models.naive\_model module 112 | ----------------------------------- 113 | 114 | .. automodule:: matchzoo.models.naive_model 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | matchzoo.models.parameter\_readme\_generator module 120 | --------------------------------------------------- 121 | 122 | .. automodule:: matchzoo.models.parameter_readme_generator 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | 128 | Module contents 129 | --------------- 130 | 131 | .. automodule:: matchzoo.models 132 | :members: 133 | :undoc-members: 134 | :show-inheritance: 135 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/matchzoo.rst.txt: -------------------------------------------------------------------------------- 1 | matchzoo package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.auto 10 | matchzoo.contrib 11 | matchzoo.data_generator 12 | matchzoo.data_pack 13 | matchzoo.datasets 14 | matchzoo.engine 15 | matchzoo.layers 16 | matchzoo.losses 17 | matchzoo.metrics 18 | matchzoo.models 19 | matchzoo.preprocessors 20 | matchzoo.processor_units 21 | matchzoo.tasks 22 | matchzoo.utils 23 | 24 | Submodules 25 | ---------- 26 | 27 | matchzoo.embedding module 28 | ------------------------- 29 | 30 | .. automodule:: matchzoo.embedding 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | matchzoo.logger module 36 | ---------------------- 37 | 38 | .. automodule:: matchzoo.logger 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | 43 | matchzoo.version module 44 | ----------------------- 45 | 46 | .. automodule:: matchzoo.version 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | 51 | 52 | Module contents 53 | --------------- 54 | 55 | .. 
automodule:: matchzoo 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/matchzoo.tasks.rst.txt: -------------------------------------------------------------------------------- 1 | matchzoo.tasks package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.tasks.classification module 8 | ------------------------------------ 9 | 10 | .. automodule:: matchzoo.tasks.classification 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.tasks.ranking module 16 | ----------------------------- 17 | 18 | .. automodule:: matchzoo.tasks.ranking 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.tasks.utils module 24 | --------------------------- 25 | 26 | .. automodule:: matchzoo.tasks.utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: matchzoo.tasks 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/modules.rst.txt: -------------------------------------------------------------------------------- 1 | matchzoo 2 | ======== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | matchzoo 8 | -------------------------------------------------------------------------------- /docs/_build/html/_static/ajax-loader.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/ajax-loader.gif -------------------------------------------------------------------------------- /docs/_build/html/_static/comment-bright.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/comment-bright.png -------------------------------------------------------------------------------- /docs/_build/html/_static/comment-close.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/comment-close.png -------------------------------------------------------------------------------- /docs/_build/html/_static/comment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/comment.png -------------------------------------------------------------------------------- /docs/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") 
format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%}.rst-versions.shift-up 
.rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /docs/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '2.0', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | FILE_SUFFIX: '.html', 7 | HAS_SOURCE: true, 8 | SOURCELINK_SUFFIX: '.txt', 9 | NAVIGATION_WITH_KEYS: false, 10 | }; -------------------------------------------------------------------------------- /docs/_build/html/_static/down-pressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/down-pressed.png 
-------------------------------------------------------------------------------- /docs/_build/html/_static/down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/down.png -------------------------------------------------------------------------------- /docs/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-regular.ttf 
-------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff 
-------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/minus.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/up-pressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/up-pressed.png -------------------------------------------------------------------------------- /docs/_build/html/_static/up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/_static/up.png -------------------------------------------------------------------------------- /docs/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/docs/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=MatchZoo 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx >= 1.7.5 2 | sphinx_rtd_theme >= 0.4.0 3 | keras >= 2.0.5 4 | nltk >= 3.2.3 5 | numpy >= 1.12.1 6 | h5py >= 2.7.0 7 | dill >= 0.2.7.1 8 | hyperopt >= 0.1 9 | pandas >= 0.23.1 10 | sphinx_autodoc_typehints>=1.6.0 11 | tensorflow 12 | tabulate >= 0.8.2 13 | nbsphinx -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. MatchZoo documentation master file, created by 2 | sphinx-quickstart on Mon May 28 16:40:41 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to MatchZoo's documentation! 7 | ==================================== 8 | 9 | 10 | .. image:: https://travis-ci.org/faneshion/MatchZoo.svg?branch=master 11 | :alt: ci 12 | :target: https://travis-ci.org/faneshion/MatchZoo/ 13 | 14 | 15 | .. image:: ../../artworks/matchzoo-logo.png 16 | :alt: logo 17 | :align: center 18 | 19 | 20 | MatchZoo is a toolkit for text matching. It was developed with a focus on facilitating the designing, comparing and sharing of deep text matching models. 
MatchZoo includes a number of deep matching methods, such as DRMM, MatchPyramid, MV-LSTM, aNMM, DUET, ARC-I, ARC-II, DSSM, and CDSSM, designed with a unified interface. Potential tasks related to MatchZoo include document retrieval, question answering, conversational response ranking, paraphrase identification, etc. We are always happy to receive code contributions, suggestions, and comments from all MatchZoo users. 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Contents: 26 | 27 | modules 28 | model_reference 29 | 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | * :ref:`search` 37 | -------------------------------------------------------------------------------- /docs/source/matchzoo.auto.preparer.rst: -------------------------------------------------------------------------------- 1 | matchzoo.auto.preparer package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.auto.preparer.prepare module 8 | ------------------------------------- 9 | 10 | .. automodule:: matchzoo.auto.preparer.prepare 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.auto.preparer.preparer module 16 | -------------------------------------- 17 | 18 | .. automodule:: matchzoo.auto.preparer.preparer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: matchzoo.auto.preparer 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.auto.rst: -------------------------------------------------------------------------------- 1 | matchzoo.auto package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.auto.preparer 10 | matchzoo.auto.tuner 11 | 12 | Module contents 13 | --------------- 14 | 15 | ..
automodule:: matchzoo.auto 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/matchzoo.auto.tuner.callbacks.rst: -------------------------------------------------------------------------------- 1 | matchzoo.auto.tuner.callbacks package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.auto.tuner.callbacks.callback module 8 | --------------------------------------------- 9 | 10 | .. automodule:: matchzoo.auto.tuner.callbacks.callback 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.auto.tuner.callbacks.lambda\_callback module 16 | ----------------------------------------------------- 17 | 18 | .. automodule:: matchzoo.auto.tuner.callbacks.lambda_callback 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.auto.tuner.callbacks.load\_embedding\_matrix module 24 | ------------------------------------------------------------ 25 | 26 | .. automodule:: matchzoo.auto.tuner.callbacks.load_embedding_matrix 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.auto.tuner.callbacks.save\_model module 32 | ------------------------------------------------ 33 | 34 | .. automodule:: matchzoo.auto.tuner.callbacks.save_model 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: matchzoo.auto.tuner.callbacks 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/matchzoo.auto.tuner.rst: -------------------------------------------------------------------------------- 1 | matchzoo.auto.tuner package 2 | =========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | 9 | matchzoo.auto.tuner.callbacks 10 | 11 | Submodules 12 | ---------- 13 | 14 | matchzoo.auto.tuner.tune module 15 | ------------------------------- 16 | 17 | .. automodule:: matchzoo.auto.tuner.tune 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | matchzoo.auto.tuner.tuner module 23 | -------------------------------- 24 | 25 | .. automodule:: matchzoo.auto.tuner.tuner 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: matchzoo.auto.tuner 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/matchzoo.contrib.layers.rst: -------------------------------------------------------------------------------- 1 | matchzoo.contrib.layers package 2 | =============================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: matchzoo.contrib.layers 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/matchzoo.contrib.models.rst: -------------------------------------------------------------------------------- 1 | matchzoo.contrib.models package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.contrib.models.match\_lstm module 8 | ------------------------------------------ 9 | 10 | .. automodule:: matchzoo.contrib.models.match_lstm 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. 
automodule:: matchzoo.contrib.models 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.contrib.rst: -------------------------------------------------------------------------------- 1 | matchzoo.contrib package 2 | ======================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.contrib.layers 10 | matchzoo.contrib.models 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: matchzoo.contrib 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/matchzoo.data_generator.callbacks.rst: -------------------------------------------------------------------------------- 1 | matchzoo.data\_generator.callbacks package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.data\_generator.callbacks.callback module 8 | -------------------------------------------------- 9 | 10 | .. automodule:: matchzoo.data_generator.callbacks.callback 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.data\_generator.callbacks.dynamic\_pooling module 16 | ---------------------------------------------------------- 17 | 18 | .. automodule:: matchzoo.data_generator.callbacks.dynamic_pooling 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.data\_generator.callbacks.histogram module 24 | --------------------------------------------------- 25 | 26 | .. automodule:: matchzoo.data_generator.callbacks.histogram 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.data\_generator.callbacks.lambda\_callback module 32 | ---------------------------------------------------------- 33 | 34 | .. 
automodule:: matchzoo.data_generator.callbacks.lambda_callback 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: matchzoo.data_generator.callbacks 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/matchzoo.data_generator.rst: -------------------------------------------------------------------------------- 1 | matchzoo.data\_generator package 2 | ================================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.data_generator.callbacks 10 | 11 | Submodules 12 | ---------- 13 | 14 | matchzoo.data\_generator.data\_generator module 15 | ----------------------------------------------- 16 | 17 | .. automodule:: matchzoo.data_generator.data_generator 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | matchzoo.data\_generator.data\_generator\_builder module 23 | -------------------------------------------------------- 24 | 25 | .. automodule:: matchzoo.data_generator.data_generator_builder 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: matchzoo.data_generator 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/matchzoo.data_pack.rst: -------------------------------------------------------------------------------- 1 | matchzoo.data\_pack package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.data\_pack.data\_pack module 8 | ------------------------------------- 9 | 10 | .. automodule:: matchzoo.data_pack.data_pack 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.data\_pack.pack module 16 | ------------------------------- 17 | 18 | .. 
automodule:: matchzoo.data_pack.pack 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: matchzoo.data_pack 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.embeddings.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets.embeddings package 2 | ==================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.datasets.embeddings.load\_glove\_embedding module 8 | ---------------------------------------------------------- 9 | 10 | .. automodule:: matchzoo.datasets.embeddings.load_glove_embedding 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: matchzoo.datasets.embeddings 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.quora_qp.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets.quora\_qp package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.datasets.quora\_qp.load\_data module 8 | --------------------------------------------- 9 | 10 | .. automodule:: matchzoo.datasets.quora_qp.load_data 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. 
automodule:: matchzoo.datasets.quora_qp 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.datasets.embeddings 10 | matchzoo.datasets.quora_qp 11 | matchzoo.datasets.snli 12 | matchzoo.datasets.toy 13 | matchzoo.datasets.wiki_qa 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: matchzoo.datasets 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.snli.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets.snli package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.datasets.snli.load\_data module 8 | ---------------------------------------- 9 | 10 | .. automodule:: matchzoo.datasets.snli.load_data 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: matchzoo.datasets.snli 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.toy.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets.toy package 2 | ============================= 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. 
automodule:: matchzoo.datasets.toy 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/matchzoo.datasets.wiki_qa.rst: -------------------------------------------------------------------------------- 1 | matchzoo.datasets.wiki\_qa package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.datasets.wiki\_qa.load\_data module 8 | -------------------------------------------- 9 | 10 | .. automodule:: matchzoo.datasets.wiki_qa.load_data 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: matchzoo.datasets.wiki_qa 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.embedding.rst: -------------------------------------------------------------------------------- 1 | matchzoo.embedding package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.embedding.embedding module 8 | ----------------------------------- 9 | 10 | .. automodule:: matchzoo.embedding.embedding 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: matchzoo.embedding 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/matchzoo.engine.rst: -------------------------------------------------------------------------------- 1 | matchzoo.engine package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.engine.base\_metric module 8 | ----------------------------------- 9 | 10 | .. 
automodule:: matchzoo.engine.base_metric 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.engine.base\_model module 16 | ---------------------------------- 17 | 18 | .. automodule:: matchzoo.engine.base_model 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.engine.base\_preprocessor module 24 | ----------------------------------------- 25 | 26 | .. automodule:: matchzoo.engine.base_preprocessor 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.engine.base\_task module 32 | --------------------------------- 33 | 34 | .. automodule:: matchzoo.engine.base_task 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | matchzoo.engine.callbacks module 40 | -------------------------------- 41 | 42 | .. automodule:: matchzoo.engine.callbacks 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | matchzoo.engine.hyper\_spaces module 48 | ------------------------------------ 49 | 50 | .. automodule:: matchzoo.engine.hyper_spaces 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | matchzoo.engine.param module 56 | ---------------------------- 57 | 58 | .. automodule:: matchzoo.engine.param 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | matchzoo.engine.param\_table module 64 | ----------------------------------- 65 | 66 | .. automodule:: matchzoo.engine.param_table 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | matchzoo.engine.parse\_metric module 72 | ------------------------------------ 73 | 74 | .. automodule:: matchzoo.engine.parse_metric 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | 80 | Module contents 81 | --------------- 82 | 83 | .. 
automodule:: matchzoo.engine 84 | :members: 85 | :undoc-members: 86 | :show-inheritance: 87 | -------------------------------------------------------------------------------- /docs/source/matchzoo.layers.rst: -------------------------------------------------------------------------------- 1 | matchzoo.layers package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.layers.dynamic\_pooling\_layer module 8 | ---------------------------------------------- 9 | 10 | .. automodule:: matchzoo.layers.dynamic_pooling_layer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.layers.matching\_layer module 16 | -------------------------------------- 17 | 18 | .. automodule:: matchzoo.layers.matching_layer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: matchzoo.layers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.losses.rst: -------------------------------------------------------------------------------- 1 | matchzoo.losses package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.losses.rank\_cross\_entropy\_loss module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: matchzoo.losses.rank_cross_entropy_loss 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.losses.rank\_hinge\_loss module 16 | ---------------------------------------- 17 | 18 | .. automodule:: matchzoo.losses.rank_hinge_loss 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. 
automodule:: matchzoo.losses 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.metrics.rst: -------------------------------------------------------------------------------- 1 | matchzoo.metrics package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.metrics.average\_precision module 8 | ------------------------------------------ 9 | 10 | .. automodule:: matchzoo.metrics.average_precision 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.metrics.discounted\_cumulative\_gain module 16 | ---------------------------------------------------- 17 | 18 | .. automodule:: matchzoo.metrics.discounted_cumulative_gain 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.metrics.mean\_average\_precision module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: matchzoo.metrics.mean_average_precision 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.metrics.mean\_reciprocal\_rank module 32 | ---------------------------------------------- 33 | 34 | .. automodule:: matchzoo.metrics.mean_reciprocal_rank 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | matchzoo.metrics.normalized\_discounted\_cumulative\_gain module 40 | ---------------------------------------------------------------- 41 | 42 | .. automodule:: matchzoo.metrics.normalized_discounted_cumulative_gain 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | matchzoo.metrics.precision module 48 | --------------------------------- 49 | 50 | .. automodule:: matchzoo.metrics.precision 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | 56 | Module contents 57 | --------------- 58 | 59 | .. 
automodule:: matchzoo.metrics 60 | :members: 61 | :undoc-members: 62 | :show-inheritance: 63 | -------------------------------------------------------------------------------- /docs/source/matchzoo.models.rst: -------------------------------------------------------------------------------- 1 | matchzoo.models package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.models.anmm module 8 | --------------------------- 9 | 10 | .. automodule:: matchzoo.models.anmm 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.models.arci module 16 | --------------------------- 17 | 18 | .. automodule:: matchzoo.models.arci 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.models.arcii module 24 | ---------------------------- 25 | 26 | .. automodule:: matchzoo.models.arcii 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.models.cdssm module 32 | ---------------------------- 33 | 34 | .. automodule:: matchzoo.models.cdssm 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | matchzoo.models.conv\_knrm module 40 | --------------------------------- 41 | 42 | .. automodule:: matchzoo.models.conv_knrm 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | matchzoo.models.dense\_baseline module 48 | -------------------------------------- 49 | 50 | .. automodule:: matchzoo.models.dense_baseline 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | matchzoo.models.drmm module 56 | --------------------------- 57 | 58 | .. automodule:: matchzoo.models.drmm 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | matchzoo.models.drmmtks module 64 | ------------------------------ 65 | 66 | .. automodule:: matchzoo.models.drmmtks 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | matchzoo.models.dssm module 72 | --------------------------- 73 | 74 | .. 
automodule:: matchzoo.models.dssm 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | matchzoo.models.duet module 80 | --------------------------- 81 | 82 | .. automodule:: matchzoo.models.duet 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | matchzoo.models.knrm module 88 | --------------------------- 89 | 90 | .. automodule:: matchzoo.models.knrm 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | matchzoo.models.match\_pyramid module 96 | ------------------------------------- 97 | 98 | .. automodule:: matchzoo.models.match_pyramid 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | matchzoo.models.mvlstm module 104 | ----------------------------- 105 | 106 | .. automodule:: matchzoo.models.mvlstm 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | matchzoo.models.naive module 112 | ---------------------------- 113 | 114 | .. automodule:: matchzoo.models.naive 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | matchzoo.models.parameter\_readme\_generator module 120 | --------------------------------------------------- 121 | 122 | .. automodule:: matchzoo.models.parameter_readme_generator 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | 128 | Module contents 129 | --------------- 130 | 131 | .. automodule:: matchzoo.models 132 | :members: 133 | :undoc-members: 134 | :show-inheritance: 135 | -------------------------------------------------------------------------------- /docs/source/matchzoo.preprocessors.rst: -------------------------------------------------------------------------------- 1 | matchzoo.preprocessors package 2 | ============================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.preprocessors.units 10 | 11 | Submodules 12 | ---------- 13 | 14 | matchzoo.preprocessors.basic\_preprocessor module 15 | ------------------------------------------------- 16 | 17 | .. 
automodule:: matchzoo.preprocessors.basic_preprocessor 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | matchzoo.preprocessors.build\_unit\_from\_data\_pack module 23 | ----------------------------------------------------------- 24 | 25 | .. automodule:: matchzoo.preprocessors.build_unit_from_data_pack 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | matchzoo.preprocessors.build\_vocab\_unit module 31 | ------------------------------------------------ 32 | 33 | .. automodule:: matchzoo.preprocessors.build_vocab_unit 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | matchzoo.preprocessors.cdssm\_preprocessor module 39 | ------------------------------------------------- 40 | 41 | .. automodule:: matchzoo.preprocessors.cdssm_preprocessor 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | matchzoo.preprocessors.chain\_transform module 47 | ---------------------------------------------- 48 | 49 | .. automodule:: matchzoo.preprocessors.chain_transform 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | matchzoo.preprocessors.dssm\_preprocessor module 55 | ------------------------------------------------ 56 | 57 | .. automodule:: matchzoo.preprocessors.dssm_preprocessor 58 | :members: 59 | :undoc-members: 60 | :show-inheritance: 61 | 62 | matchzoo.preprocessors.naive\_preprocessor module 63 | ------------------------------------------------- 64 | 65 | .. automodule:: matchzoo.preprocessors.naive_preprocessor 66 | :members: 67 | :undoc-members: 68 | :show-inheritance: 69 | 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. 
automodule:: matchzoo.preprocessors 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/source/matchzoo.processor_units.rst: -------------------------------------------------------------------------------- 1 | matchzoo.processor\_units package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.processor\_units.chain\_transform module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: matchzoo.processor_units.chain_transform 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.processor\_units.processor\_units module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: matchzoo.processor_units.processor_units 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: matchzoo.processor_units 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.rst: -------------------------------------------------------------------------------- 1 | matchzoo package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | matchzoo.auto 10 | matchzoo.data_generator 11 | matchzoo.data_pack 12 | matchzoo.datasets 13 | matchzoo.embedding 14 | matchzoo.engine 15 | matchzoo.layers 16 | matchzoo.losses 17 | matchzoo.metrics 18 | matchzoo.models 19 | matchzoo.preprocessors 20 | matchzoo.tasks 21 | matchzoo.utils 22 | 23 | Submodules 24 | ---------- 25 | 26 | matchzoo.version module 27 | ----------------------- 28 | 29 | .. automodule:: matchzoo.version 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | 35 | Module contents 36 | --------------- 37 | 38 | .. 
automodule:: matchzoo 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | -------------------------------------------------------------------------------- /docs/source/matchzoo.tasks.rst: -------------------------------------------------------------------------------- 1 | matchzoo.tasks package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.tasks.classification module 8 | ------------------------------------ 9 | 10 | .. automodule:: matchzoo.tasks.classification 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.tasks.ranking module 16 | ----------------------------- 17 | 18 | .. automodule:: matchzoo.tasks.ranking 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: matchzoo.tasks 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/matchzoo.utils.rst: -------------------------------------------------------------------------------- 1 | matchzoo.utils package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | matchzoo.utils.list\_recursive\_subclasses module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: matchzoo.utils.list_recursive_subclasses 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | matchzoo.utils.make\_keras\_optimizer\_picklable module 16 | ------------------------------------------------------- 17 | 18 | .. automodule:: matchzoo.utils.make_keras_optimizer_picklable 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | matchzoo.utils.one\_hot module 24 | ------------------------------ 25 | 26 | .. automodule:: matchzoo.utils.one_hot 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | matchzoo.utils.tensor\_type module 32 | ---------------------------------- 33 | 34 | .. 
automodule:: matchzoo.utils.tensor_type 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: matchzoo.utils 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | matchzoo 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | matchzoo 8 | -------------------------------------------------------------------------------- /matchzoo/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | USER_DIR = Path.expanduser(Path('~')).joinpath('.matchzoo') 4 | if not USER_DIR.exists(): 5 | USER_DIR.mkdir() 6 | USER_DATA_DIR = USER_DIR.joinpath('datasets') 7 | if not USER_DATA_DIR.exists(): 8 | USER_DATA_DIR.mkdir() 9 | USER_TUNED_MODELS_DIR = USER_DIR.joinpath('tuned_models') 10 | 11 | from .version import __version__ 12 | 13 | from .data_pack import DataPack 14 | from .data_pack import pack 15 | from .data_pack import load_data_pack 16 | 17 | from . import metrics 18 | from . import tasks 19 | 20 | from . import preprocessors 21 | from . import data_generator 22 | from .data_generator import DataGenerator 23 | from .data_generator import DataGeneratorBuilder 24 | 25 | from .preprocessors.chain_transform import chain_transform 26 | 27 | from . import metrics 28 | from . import losses 29 | from . import engine 30 | from . import models 31 | from . import embedding 32 | from . import datasets 33 | from . import layers 34 | from . import auto 35 | from . 
import contrib 36 | 37 | from .engine import hyper_spaces 38 | from .engine.base_model import load_model 39 | from .engine.base_preprocessor import load_preprocessor 40 | from .engine import callbacks 41 | from .engine.param import Param 42 | from .engine.param_table import ParamTable 43 | 44 | from .embedding.embedding import Embedding 45 | 46 | from .utils import one_hot, make_keras_optimizer_picklable 47 | from .preprocessors.build_unit_from_data_pack import build_unit_from_data_pack 48 | from .preprocessors.build_vocab_unit import build_vocab_unit 49 | 50 | # deprecated, should be removed in v2.2 51 | from .contrib.legacy_data_generator import DPoolDataGenerator 52 | from .contrib.legacy_data_generator import DPoolPairDataGenerator 53 | from .contrib.legacy_data_generator import HistogramDataGenerator 54 | from .contrib.legacy_data_generator import HistogramPairDataGenerator 55 | from .contrib.legacy_data_generator import DynamicDataGenerator 56 | from .contrib.legacy_data_generator import PairDataGenerator 57 | -------------------------------------------------------------------------------- /matchzoo/auto/__init__.py: -------------------------------------------------------------------------------- 1 | from .preparer import prepare 2 | from .preparer import Preparer 3 | 4 | from .tuner import Tuner 5 | from .tuner import tune 6 | 7 | # mz.auto.tuner.callbacks 8 | from . 
def prepare(
    task: BaseTask,
    model_class: typing.Type[BaseModel],
    data_pack: mz.DataPack,
    preprocessor: typing.Optional[BasePreprocessor] = None,
    embedding: typing.Optional['mz.Embedding'] = None,
    config: typing.Optional[dict] = None,
):
    """
    A simple shorthand for using :class:`matchzoo.auto.Preparer`.

    A `Preparer` is created from `task` and `config`, and its `prepare`
    method is called with the remaining arguments.

    `config` is used to control specific behaviors. The default `config`
    will be updated accordingly if a `config` dictionary is passed. e.g. to
    override the default `bin_size`, pass `config={'bin_size': 15}`.

    :param task: Task.
    :param model_class: Model class.
    :param data_pack: DataPack used to fit the preprocessor.
    :param preprocessor: Preprocessor used to fit the `data_pack`.
        (default: the default preprocessor of `model_class`)
    :param embedding: Embedding to build an embedding matrix. If not set,
        then a correctly shaped randomized matrix will be built.
    :param config: Configuration of specific behaviors. (default: return
        value of `mz.auto.Preparer.get_default_config()`)

    :return: A tuple of `(model, preprocessor, data_generator_builder,
        embedding_matrix)`.

    """
    # The preparer only needs the task and the behavior overrides up front;
    # everything model/data specific is handed to `prepare` itself.
    return Preparer(task=task, config=config).prepare(
        model_class=model_class,
        data_pack=data_pack,
        preprocessor=preprocessor,
        embedding=embedding,
    )
class LambdaCallback(Callback):
    """
    LambdaCallback. Just a shorthand for creating a callback class.

    Wraps up to three plain callables so a tuner callback can be created
    without writing a subclass. A hook that is left as `None` (or any other
    falsy value) is simply skipped.

    See :class:`matchzoo.auto.tuner.callbacks.Callback` for more details.

    :param on_run_start: Called as `on_run_start(tuner, sample)`.
    :param on_build_end: Called as `on_build_end(tuner, model)`.
    :param on_run_end: Called as `on_run_end(tuner, model, result)`.

    Example:

        >>> import matchzoo as mz
        >>> model = mz.models.Naive()
        >>> model.guess_and_fill_missing_params(verbose=0)
        >>> data = mz.datasets.toy.load_data()
        >>> data = model.get_default_preprocessor().fit_transform(
        ...     data, verbose=0)
        >>> def show_inputs(*args):
        ...     print(' '.join(map(str, map(type, args))))
        >>> callback = mz.auto.tuner.callbacks.LambdaCallback(
        ...     on_run_start=show_inputs,
        ...     on_build_end=show_inputs,
        ...     on_run_end=show_inputs
        ... )
        >>> _ = mz.auto.tune(
        ...     params=model.params,
        ...     train_data=data,
        ...     test_data=data,
        ...     num_runs=1,
        ...     callbacks=[callback],
        ...     verbose=0,
        ... ) # noqa: E501
        <class 'matchzoo.auto.tuner.tuner.Tuner'> <class 'dict'>
        <class 'matchzoo.auto.tuner.tuner.Tuner'> <class 'matchzoo.models.naive.Naive'>
        <class 'matchzoo.auto.tuner.tuner.Tuner'> <class 'matchzoo.models.naive.Naive'> <class 'dict'>

    """

    def __init__(
        self,
        on_run_start=None,
        on_build_end=None,
        on_run_end=None
    ):
        """Init."""
        self._on_run_start = on_run_start
        self._on_build_end = on_build_end
        self._on_run_end = on_run_end

    def on_run_start(self, tuner, sample: dict):
        """
        `on_run_start`.

        Forwards `(tuner, sample)` to the wrapped callable, if one was given.
        """
        if self._on_run_start:
            self._on_run_start(tuner, sample)

    def on_build_end(self, tuner, model: BaseModel):
        """
        `on_build_end`.

        Forwards `(tuner, model)` to the wrapped callable, if one was given.
        """
        if self._on_build_end:
            self._on_build_end(tuner, model)

    def on_run_end(self, tuner, model: BaseModel, result: dict):
        """
        `on_run_end`.

        Forwards `(tuner, model, result)` to the wrapped callable, if one was
        given.
        """
        if self._on_run_end:
            self._on_run_end(tuner, model, result)
class SaveModel(Callback):
    """
    Save trained model.

    For each trained model, a UUID will be generated as the `model_id`, the
    model will be saved under the `dir_path/model_id`. A `model_id` key will
    also be inserted into the result, which will be visible in the return
    value of the `tune` method.

    :param dir_path: Path to save the models to, given either as a string or
        as a :class:`pathlib.Path`. (default:
        `matchzoo.USER_TUNED_MODELS_DIR`)

    """

    def __init__(
        self,
        dir_path: typing.Union[str, Path] = mz.USER_TUNED_MODELS_DIR
    ):
        """Init."""
        # The signature accepts plain strings, but `on_run_end` relies on
        # `Path.joinpath`. Normalise once here so both spellings work
        # instead of failing with AttributeError at the end of a run.
        self._dir_path = Path(dir_path)

    def on_run_end(self, tuner, model: BaseModel, result: dict):
        """
        Save model on run end.

        :param tuner: Tuner. Not used by this callback.
        :param model: The model to save; it is written to
            `dir_path/model_id` through `model.save`.
        :param result: Result of the run. The generated id is stored under
            the `'model_id'` key.
        """
        model_id = str(uuid.uuid4())
        model.save(self._dir_path.joinpath(model_id))
        result['model_id'] = model_id
import models 7 | -------------------------------------------------------------------------------- /matchzoo/contrib/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention_layer import AttentionLayer 2 | from .multi_perspective_layer import MultiPerspectiveLayer 3 | from .matching_tensor_layer import MatchingTensorLayer 4 | from .spatial_gru import SpatialGRU 5 | from .decaying_dropout_layer import DecayingDropoutLayer 6 | from .semantic_composite_layer import EncodingLayer 7 | 8 | layer_dict = { 9 | "MatchingTensorLayer": MatchingTensorLayer, 10 | "SpatialGRU": SpatialGRU, 11 | "DecayingDropoutLayer": DecayingDropoutLayer, 12 | "EncodingLayer": EncodingLayer 13 | } 14 | -------------------------------------------------------------------------------- /matchzoo/contrib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .match_lstm import MatchLSTM 2 | from .match_srnn import MatchSRNN 3 | from .hbmp import HBMP 4 | from .esim import ESIM 5 | from .bimpm import BiMPM 6 | from .diin import DIIN 7 | -------------------------------------------------------------------------------- /matchzoo/data_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
class Callback(object):
    """
    DataGenerator callback base class.

    To build your own callbacks, inherit `mz.data_generator.callbacks.Callback`
    and override the corresponding methods. Both hooks defined here have an
    empty body, so a subclass only needs to implement the ones it cares about.

    A batch is processed in the following way:

    - slice data pack based on batch index
    - handle `on_batch_data_pack` callbacks
    - unpack data pack into x, y
    - handle `on_batch_unpacked` callbacks
    - return x, y

    """

    def on_batch_data_pack(self, data_pack: mz.DataPack):
        """
        `on_batch_data_pack`.

        :param data_pack: a sliced DataPack before unpacking.
        """

    def on_batch_unpacked(self, x: dict, y: np.ndarray):
        """
        `on_batch_unpacked`.

        :param x: unpacked x.
        :param y: unpacked y.
        """
def _trunc_text(input_text: list, length: list) -> list:
    """
    Cut every row of `input_text` down to its matching entry in `length`.

    :param input_text: Rows (sliceable sequences) to be truncated.
    :param length: Per-row maximum length; `length[i]` is applied to
        `input_text[i]`.
    :return: A new list holding the truncated rows, in the original order.
    """
    truncated = []
    for position, tokens in enumerate(input_text):
        limit = length[position]
        truncated.append(tokens[:limit])
    return truncated
class LambdaCallback(Callback):
    """
    Callback assembled from plain callables.

    A shorthand that spares writing a dedicated subclass: each hook simply
    forwards to the callable supplied for it, and does nothing when no
    callable was supplied.

    See :class:`matchzoo.data_generator.callbacks.Callback` for more details.

    Example:

        >>> import matchzoo as mz
        >>> from matchzoo.data_generator.callbacks import LambdaCallback
        >>> data = mz.datasets.toy.load_data()
        >>> batch_func = lambda x: print(type(x))
        >>> unpack_func = lambda x, y: print(type(x), type(y))
        >>> callback = LambdaCallback(on_batch_data_pack=batch_func,
        ...                           on_batch_unpacked=unpack_func)
        >>> data_gen = mz.DataGenerator(
        ...     data, batch_size=len(data), callbacks=[callback])
        >>> _ = data_gen[0]

    """

    def __init__(self, on_batch_data_pack=None, on_batch_unpacked=None):
        """Remember the optional handler for each hook."""
        self._on_batch_data_pack = on_batch_data_pack
        self._on_batch_unpacked = on_batch_unpacked

    def on_batch_data_pack(self, data_pack):
        """Forward `data_pack` to the `on_batch_data_pack` handler, if any."""
        handler = self._on_batch_data_pack
        if not handler:
            return
        handler(data_pack)

    def on_batch_unpacked(self, x, y):
        """Forward `x` and `y` to the `on_batch_unpacked` handler, if any."""
        handler = self._on_batch_unpacked
        if not handler:
            return
        handler(x, y)
def pack(df: pd.DataFrame) -> 'matchzoo.DataPack':
    """
    Build a :class:`DataPack` out of the data frame `df`.

    `df` is required to carry the `text_left` and `text_right` columns. The
    `id_left` / `id_right` columns, which index `text_left` / `text_right`
    respectively, are optional: whichever is absent is generated from the
    distinct texts of the corresponding column. All remaining columns are
    carried over into the relation table.

    :param df: Input :class:`pandas.DataFrame` to use.
    :raises ValueError: If `text_left` or `text_right` is missing from `df`.

    Examples::
        >>> import matchzoo as mz
        >>> import pandas as pd
        >>> df = pd.DataFrame(data={'text_left': list('AABC'),
        ...                         'text_right': list('abbc'),
        ...                         'label': [0, 1, 1, 0]})
        >>> mz.pack(df).frame()
          id_left text_left id_right text_right  label
        0     L-0         A      R-0          a      0
        1     L-0         A      R-1          b      1
        2     L-1         B      R-1          b      1
        3     L-2         C      R-2          c      0

    """
    if not ('text_left' in df and 'text_right' in df):
        raise ValueError(
            'Input data frame must have `text_left` and `text_right`.')

    # Use the supplied IDs where present, otherwise derive them from texts.
    id_left = (df['id_left'] if 'id_left' in df
               else _gen_ids(df, 'text_left', 'L-'))
    id_right = (df['id_right'] if 'id_right' in df
                else _gen_ids(df, 'text_right', 'R-'))

    # Relation table: the ID pair plus every column that is neither an ID
    # nor a text column (e.g. `label`).
    relation = pd.DataFrame(data={'id_left': id_left, 'id_right': id_right})
    reserved = ('id_left', 'id_right', 'text_left', 'text_right')
    for column in df:
        if column in reserved:
            continue
        relation[column] = df[column]

    # Left and right tables: one row per distinct ID.
    left_table = _merge(df, id_left, 'text_left', 'id_left')
    right_table = _merge(df, id_right, 'text_right', 'id_right')
    return matchzoo.DataPack(relation, left_table, right_table)


def _merge(data: pd.DataFrame, ids: typing.Union[list, np.array],
           text_label: str, id_label: str):
    """Pair texts with `ids`, keep the first row per ID, index by the ID."""
    paired = pd.DataFrame(data={
        text_label: data[text_label], id_label: ids
    })
    return paired.drop_duplicates(id_label).set_index(id_label)


def _gen_ids(data: pd.DataFrame, col: str, prefix: str):
    """Map each distinct value of `data[col]` to `prefix` + a running index."""
    lookup = {
        text: prefix + str(number)
        for number, text in enumerate(data[col].unique())
    }
    return data[col].map(lookup)
def list_available():
    """
    List the sub-directories that sit next to this module.

    :return: Names of every directory in this module's parent directory
        whose name does not start with an underscore.
    """
    package_dir = Path(__file__).parent
    names = []
    for entry in package_dir.iterdir():
        if not entry.is_dir():
            continue
        if entry.name.startswith('_'):
            continue
        names.append(entry.name)
    return names
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from .load_glove_embedding import load_glove_embedding 3 | 4 | DATA_ROOT = Path(__file__).parent 5 | EMBED_RANK = DATA_ROOT.joinpath('embed_rank.txt') 6 | EMBED_10 = DATA_ROOT.joinpath('embed_10_word2vec.txt') 7 | EMBED_10_GLOVE = DATA_ROOT.joinpath('embed_10_glove.txt') 8 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_10_glove.txt: -------------------------------------------------------------------------------- 1 | A 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 2 | B 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 3 | C 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 4 | D 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 5 | E 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 6 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_10_word2vec.txt: -------------------------------------------------------------------------------- 1 | 5 10 2 | A 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 3 | B 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 4 | C 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 5 | D 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 6 | E 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 7 | -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_err.txt.gb2312: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/matchzoo/datasets/embeddings/embed_err.txt.gb2312 -------------------------------------------------------------------------------- /matchzoo/datasets/embeddings/embed_word.txt: -------------------------------------------------------------------------------- 1 | 7 5 2 | asia 1 2 3 4 5 3 | beijing 1 1 1 1 1 4 | hot 2 2 2 2 2 5 | east 3 3 3 3 3 6 | capital 4 4 4 4 4 7 | china 5 5 5 5 5 8 | 
def load_data(
    stage: str = 'train',
    task: str = 'classification',
    return_classes: bool = False,
) -> typing.Union[matchzoo.DataPack, tuple]:
    """
    Load QuoraQP data.

    The data is fetched through `_download_data` and the `{stage}.tsv` file
    found there is read.

    :param stage: One of `train`, `dev`, and `test`.
    :param task: Could be one of `ranking`, `classification` or a
        :class:`matchzoo.engine.BaseTask` instance.
    :param return_classes: Whether return classes for classification task.
    :return: A DataPack unless `task` is `classification` and
        `return_classes` is `True`: a tuple of `(DataPack, classes)` in that
        case.
    :raises ValueError: If `stage` or `task` is not one of the accepted
        values.
    """
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage. "
                         f"Must be one of `train`, `dev`, and `test`.")

    data_root = _download_data()
    file_path = data_root.joinpath(f"{stage}.tsv")
    data_pack = _read_data(file_path, stage)

    if task == 'ranking':
        task = matchzoo.tasks.Ranking()
    elif task == 'classification':
        task = matchzoo.tasks.Classification()

    if isinstance(task, matchzoo.tasks.Ranking):
        return data_pack
    elif isinstance(task, matchzoo.tasks.Classification):
        # The `test` stage is read without a `label` column, so there is
        # nothing to one-hot encode for it.
        if stage != 'test':
            data_pack.one_hot_encode_label(num_classes=2, inplace=True)
        if return_classes:
            return data_pack, [False, True]
        else:
            return data_pack
    else:
        raise ValueError(f"{task} is not a valid task. "
                         f"Must be one of `Ranking` and `Classification`.")
def load_data(
    stage: str = 'train',
    task: str = 'classification',
    target_label: str = 'entailment',
    return_classes: bool = False
) -> typing.Union[matchzoo.DataPack, tuple]:
    """
    Load SNLI data.

    :param stage: One of `train`, `dev`, and `test`. (default: `train`)
    :param task: Could be one of `ranking`, `classification` or a
        :class:`matchzoo.engine.BaseTask` instance.
        (default: `classification`)
    :param target_label: If `ranking`, chose one of `entailment`,
        `contradiction`, `neutral`, and `-` as the positive label; rows with
        that label get `1.0`, all other rows `0.0`. Ignored for
        classification. (default: `entailment`)
    :param return_classes: `True` to return classes for classification task,
        `False` otherwise.

    :return: A DataPack unless `task` is `classification` and `return_classes`
        is `True`: a tuple of `(DataPack, classes)` in that case.
    :raises ValueError: If `stage`, `task`, or (for ranking) `target_label`
        is not one of the accepted values.
    """
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage. "
                         f"Must be one of `train`, `dev`, and `test`.")

    data_root = _download_data()
    file_path = data_root.joinpath(f'snli_1.0_{stage}.txt')
    data_pack = _read_data(file_path)

    if task == 'ranking':
        task = matchzoo.tasks.Ranking()
    if task == 'classification':
        task = matchzoo.tasks.Classification()

    if isinstance(task, matchzoo.tasks.Ranking):
        if target_label not in ['entailment', 'contradiction', 'neutral', '-']:
            raise ValueError(f"{target_label} is not a valid target label. "
                             f"Must be one of `entailment`, `contradiction`, "
                             f"`neutral` and `-`.")
        # Collapse the textual labels into a binary relevance signal.
        binary = (data_pack.relation['label'] == target_label).astype(float)
        data_pack.relation['label'] = binary
        return data_pack
    elif isinstance(task, matchzoo.tasks.Classification):
        classes = ['entailment', 'contradiction', 'neutral', '-']
        # Replace each textual label by its position in `classes` before
        # one-hot encoding.
        label = data_pack.relation['label'].apply(classes.index)
        data_pack.relation['label'] = label
        data_pack.one_hot_encode_label(num_classes=4, inplace=True)
        if return_classes:
            return data_pack, classes
        else:
            return data_pack
    else:
        raise ValueError(f"{task} is not a valid task. "
                         f"Must be one of `Ranking` and `Classification`.")
def load_data(
    stage: str = 'train',
    task: str = 'ranking',
    filtered: bool = False,
    return_classes: bool = False
) -> typing.Union[matchzoo.DataPack, tuple]:
    """
    Load WikiQA data.

    :param stage: One of `train`, `dev`, and `test`.
    :param task: Could be one of `ranking`, `classification` or a
        :class:`matchzoo.engine.BaseTask` instance.
    :param filtered: Whether remove the questions without correct answers.
        Only takes effect for the `dev` and `test` stages; it is ignored for
        `train`.
    :param return_classes: `True` to return classes for classification task,
        `False` otherwise.

    :return: A DataPack unless `task` is `classification` and `return_classes`
        is `True`: a tuple of `(DataPack, classes)` in that case.
    :raises ValueError: If `stage` or `task` is not one of the accepted
        values.
    """
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage. "
                         f"Must be one of `train`, `dev`, and `test`.")

    data_root = _download_data()
    file_path = data_root.joinpath(f'WikiQA-{stage}.tsv')
    data_pack = _read_data(file_path)
    if filtered and stage in ('dev', 'test'):
        ref_path = data_root.joinpath(f'WikiQA-{stage}.ref')
        filter_ref_path = data_root.joinpath(f'WikiQA-{stage}-filtered.ref')
        # IDs to keep: the first whitespace-separated field of every line
        # of the filtered reference file.
        with open(filter_ref_path, mode='r') as f:
            filtered_ids = {line.split()[0] for line in f}
        # Positions of the lines in the full reference file whose leading
        # field is among the IDs to keep; they are used to slice the pack.
        with open(ref_path, mode='r') as f:
            filtered_lines = [
                idx for idx, line in enumerate(f)
                if line.split()[0] in filtered_ids
            ]
        data_pack = data_pack[filtered_lines]

    if task == 'ranking':
        task = matchzoo.tasks.Ranking()
    if task == 'classification':
        task = matchzoo.tasks.Classification()

    if isinstance(task, matchzoo.tasks.Ranking):
        return data_pack
    elif isinstance(task, matchzoo.tasks.Classification):
        data_pack.one_hot_encode_label(task.num_classes, inplace=True)
        if return_classes:
            return data_pack, [False, True]
        else:
            return data_pack
    else:
        raise ValueError(f"{task} is not a valid task. "
                         f"Must be one of `Ranking` and `Classification`.")
def sort_and_couple(labels: np.array, scores: np.array) -> np.array:
    """
    Pair every label with its score and order the pairs by score.

    :param labels: Labels, aligned element-wise with `scores`.
    :param scores: Scores, aligned element-wise with `labels`.
    :return: Array of `(label, score)` rows sorted by descending score;
        rows with equal scores keep their input order.
    """
    pairs = [(label, score) for label, score in zip(labels, scores)]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return np.array(pairs)
18 | :param metrics: 19 | """ 20 | self._loss = loss 21 | self._metrics = self._convert_metrics(metrics) 22 | self._assure_loss() 23 | self._assure_metrics() 24 | 25 | def _convert_metrics(self, metrics): 26 | if not metrics: 27 | metrics = [] 28 | elif not isinstance(metrics, list): 29 | metrics = [metrics] 30 | return [ 31 | parse_metric.parse_metric(metric, self) for metric in metrics 32 | ] 33 | 34 | def _assure_loss(self): 35 | if not self._loss: 36 | self._loss = self.list_available_losses()[0] 37 | 38 | def _assure_metrics(self): 39 | if not self._metrics: 40 | first_available = self.list_available_metrics()[0] 41 | self._metrics = self._convert_metrics(first_available) 42 | 43 | @property 44 | def loss(self): 45 | """:return: Loss used in the task.""" 46 | return self._loss 47 | 48 | @property 49 | def metrics(self): 50 | """:return: Metrics used in the task.""" 51 | return self._metrics 52 | 53 | @metrics.setter 54 | def metrics( 55 | self, 56 | new_metrics: typing.Union[ 57 | typing.List[str], 58 | typing.List[base_metric.BaseMetric], 59 | str, 60 | base_metric.BaseMetric 61 | ] 62 | ): 63 | self._metrics = self._convert_metrics(new_metrics) 64 | 65 | @classmethod 66 | @abc.abstractmethod 67 | def list_available_losses(cls) -> list: 68 | """:return: a list of available losses.""" 69 | 70 | @classmethod 71 | @abc.abstractmethod 72 | def list_available_metrics(cls) -> list: 73 | """:return: a list of available metrics.""" 74 | 75 | @property 76 | @abc.abstractmethod 77 | def output_shape(self) -> tuple: 78 | """:return: output shape of a single sample of the task.""" 79 | 80 | @property 81 | @abc.abstractmethod 82 | def output_dtype(self): 83 | """:return: output data type for specific task.""" 84 | -------------------------------------------------------------------------------- /matchzoo/engine/callbacks.py: -------------------------------------------------------------------------------- 1 | """Callbacks.""" 2 | import typing 3 | from pathlib import Path 4 
class EvaluateAllMetrics(keras.callbacks.Callback):
    """
    Callback to evaluate all metrics.

    MatchZoo metrics can not be evaluated batch-wise since they require
    dataset-level information. As a result, MatchZoo metrics are not
    evaluated automatically when a Model `fit`. When this callback is used,
    all metrics, including MatchZoo metrics and Keras metrics, are evaluated
    once every `once_every` epochs.

    :param model: Model to evaluate.
    :param x: X.
    :param y: y.
    :param once_every: Evaluation only triggers when
        `(epoch + 1) % once_every == 0`, `epoch` being the index handed to
        :meth:`on_epoch_end`.
        (default: 1, i.e. evaluate on every epoch's end)
    :param batch_size: Number of samples per evaluation. This only affects the
        evaluation of Keras metrics, since MatchZoo metrics are always
        evaluated using the full data.
    :param model_save_path: Path prefix used to save the model after each
        evaluation. `epoch + 1` followed by `/` is appended to it verbatim
        (no separator is inserted), so pass a value ending in `/` to get
        one sub-directory per epoch. (default: None, i.e., no saving.)
    :param verbose: Verbosity.
    """

    def __init__(
        self,
        model: 'BaseModel',
        x: typing.Union[np.ndarray, typing.List[np.ndarray]],
        y: np.ndarray,
        once_every: int = 1,
        batch_size: int = 128,
        model_save_path: str = None,
        verbose=1
    ):
        """Initializer."""
        super().__init__()
        self._model = model
        self._dev_x = x
        self._dev_y = y
        self._valid_steps = once_every
        self._batch_size = batch_size
        self._model_save_path = model_save_path
        self._verbose = verbose

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Called at the end of an epoch.

        Evaluates the model on the held-out data, optionally prints the
        results, copies them into `logs` and optionally saves the model.

        :param epoch: integer, index of epoch.
        :param logs: dictionary of logs. When given, it is updated in place
            with the evaluation results; `None` is accepted and simply
            means the results are not recorded anywhere.
        """
        if (epoch + 1) % self._valid_steps != 0:
            return

        val_logs = self._model.evaluate(self._dev_x, self._dev_y,
                                        self._batch_size)
        if self._verbose:
            print('Validation: ' + ' - '.join(
                f'{k}: {v}' for k, v in val_logs.items()))

        # `logs` defaults to None in the signature, so guard before writing
        # into it instead of failing with a TypeError.
        if logs is not None:
            for k, v in val_logs.items():
                logs[k] = v

        if self._model_save_path:
            # Plain concatenation is kept on purpose: callers decide
            # whether the prefix ends with a path separator.
            curr_path = self._model_save_path + '%d/' % (epoch + 1)
            self._model.save(curr_path)
def _remap_keras_metric(metric: str, task) -> str:
    """
    Map a short Keras metric alias to its task-specific metric name.

    :param metric: Metric name in `str` form (e.g. `'acc'`, `'ce'`).
    :param task: Task instance used to pick the alias table.
    :return: The task-specific Keras metric name if `metric` is a known
        alias for `task`, otherwise `metric` unchanged.
    """
    # we do not support sparse label in classification.
    lookup = {
        matchzoo.tasks.Ranking: {
            'acc': 'binary_accuracy',
            'accuracy': 'binary_accuracy',
            'crossentropy': 'binary_crossentropy',
            'ce': 'binary_crossentropy',
        },
        matchzoo.tasks.Classification: {
            'acc': 'categorical_accuracy',
            'accuracy': 'categorical_accuracy',
            'crossentropy': 'categorical_crossentropy',
            'ce': 'categorical_crossentropy',
        }
    }
    aliases = lookup.get(type(task))
    if aliases is None:
        # Not an exact Ranking/Classification instance: honour subclasses,
        # and leave the metric untouched for unrelated task types instead
        # of failing with a KeyError.
        aliases = next(
            (table for task_type, table in lookup.items()
             if isinstance(task, task_type)),
            {}
        )
    return aliases.get(metric, metric)
13 | 14 | Examples: 15 | >>> from keras import backend as K 16 | >>> softmax = lambda x: np.exp(x)/np.sum(np.exp(x), axis=0) 17 | >>> x_pred = K.variable(np.array([[1.0], [1.2], [0.8]])) 18 | >>> x_true = K.variable(np.array([[1], [0], [0]])) 19 | >>> expect = -np.log(softmax(np.array([[1.0], [1.2], [0.8]]))) 20 | >>> loss = K.eval(RankCrossEntropyLoss(num_neg=2)(x_true, x_pred)) 21 | >>> np.isclose(loss, expect[0]).all() 22 | True 23 | 24 | """ 25 | 26 | def __init__(self, num_neg: int = 1): 27 | """ 28 | :class:`RankCrossEntropyLoss` constructor. 29 | 30 | :param num_neg: number of negative instances in cross entropy loss. 31 | """ 32 | super().__init__(reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE, 33 | name="rank_crossentropy") 34 | self._num_neg = num_neg 35 | 36 | def call(self, y_true: np.array, y_pred: np.array, 37 | sample_weight=None) -> np.array: 38 | """ 39 | Calculate rank cross entropy loss. 40 | 41 | :param y_true: Label. 42 | :param y_pred: Predicted result. 43 | :return: Crossentropy loss computed by user-defined negative number. 
class RankHingeLoss(Loss):
    """
    Rank hinge loss.

    Rows of the prediction are read in groups of `num_neg + 1`: the first
    row of every group is the positive score and the remaining rows are
    the negative scores it is compared against.

    Examples:
        >>> from keras import backend as K
        >>> x_pred = K.variable(np.array([[1.0], [1.2], [0.8], [1.4]]))
        >>> x_true = K.variable(np.array([[1], [0], [1], [0]]))
        >>> expect = ((1.0 + 1.2 - 1.0) + (1.0 + 1.4 - 0.8)) / 2
        >>> expect
        1.4
        >>> loss = K.eval(RankHingeLoss(num_neg=1, margin=1.0)(x_true, x_pred))
        >>> np.isclose(loss, expect)
        True

    """

    def __init__(self, num_neg: int = 1, margin: float = 1.0):
        """
        :class:`RankHingeLoss` constructor.

        :param num_neg: number of negative instances in hinge loss.
        :param margin: the margin between positive and negative scores.
        """
        super().__init__(reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                         name="rank_hinge")

        self._num_neg = num_neg
        self._margin = margin

    def call(self, y_true: np.array, y_pred: np.array,
             sample_weight=None) -> np.array:
        """
        Calculate rank hinge loss.

        :param y_true: Label.
        :param y_pred: Predicted result.
        :param sample_weight: Forwarded to
            `losses_utils.compute_weighted_loss`.
        :return: Hinge loss computed by user-defined margin.
        """
        group_size = self._num_neg + 1

        def rows_from(offset):
            # Every `group_size`-th row of the prediction, starting at
            # `offset`.
            picker = layers.Lambda(
                lambda a: a[offset::group_size, :],
                output_shape=(1,))
            return picker(y_pred)

        positive = rows_from(0)
        negatives = [rows_from(shift) for shift in range(1, group_size)]
        # Average the negatives belonging to each positive.
        mean_negative = tf.reduce_mean(
            tf.concat(negatives, axis=-1), axis=-1, keepdims=True)
        hinge = tf.maximum(0., self._margin + mean_negative - positive)
        return losses_utils.compute_weighted_loss(
            hinge, sample_weight, reduction=self.reduction)

    @property
    def num_neg(self):
        """`num_neg` getter."""
        return self._num_neg

    @property
    def margin(self):
        """`margin` getter."""
        return self._margin
class AveragePrecision(base_metric.BaseMetric):
    """Average precision metric."""

    ALIAS = ['average_precision', 'ap']

    def __init__(self, threshold: float = 0.):
        """
        :class:`AveragePrecision` constructor.

        :param threshold: The label threshold of relevance degree.
        """
        self._threshold = threshold

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        return f"{self.ALIAS[0]}({self._threshold})"

    def __call__(self, y_true: np.array, y_pred: np.array) -> float:
        """
        Calculate average precision (area under PR curve).

        The result is the mean of precision@k for every `k` from 1 to
        `len(y_pred)`, each computed with this metric's threshold.

        Example:
            >>> y_true = [0, 1]
            >>> y_pred = [0.1, 0.6]
            >>> round(AveragePrecision()(y_true, y_pred), 2)
            0.75
            >>> round(AveragePrecision()([], []), 2)
            0.0

        :param y_true: The ground true label of each document.
        :param y_pred: The predicted scores of each document.
        :return: Average precision, `0.` when there are no predictions.
        """
        # Forward the configured threshold; previously it was stored but
        # every Precision fell back to its own default.
        out = [
            Precision(k=k + 1, threshold=self._threshold)(y_true, y_pred)
            for k in range(len(y_pred))
        ]
        if not out:
            return 0.
        # `np.asscalar` no longer exists in current NumPy; `float()` gives
        # the same Python scalar.
        return float(np.mean(out))
57 | for i, (label, score) in enumerate(coupled_pair): 58 | if i >= self._k: 59 | break 60 | if label > self._threshold: 61 | result += (math.pow(2., label) - 1.) / math.log(2. + i) 62 | return result 63 | -------------------------------------------------------------------------------- /matchzoo/metrics/mean_average_precision.py: -------------------------------------------------------------------------------- 1 | """Mean average precision metric for ranking.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import BaseMetric, sort_and_couple 5 | 6 | 7 | class MeanAveragePrecision(BaseMetric): 8 | """Mean average precision metric.""" 9 | 10 | ALIAS = ['mean_average_precision', 'map'] 11 | 12 | def __init__(self, threshold: float = 0.): 13 | """ 14 | :class:`MeanAveragePrecision` constructor. 15 | 16 | :param threshold: The threshold of relevance degree. 17 | """ 18 | self._threshold = threshold 19 | 20 | def __repr__(self): 21 | """:return: Formated string representation of the metric.""" 22 | return f"{self.ALIAS[0]}({self._threshold})" 23 | 24 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 25 | """ 26 | Calculate mean average precision. 27 | 28 | Example: 29 | >>> y_true = [0, 1, 0, 0] 30 | >>> y_pred = [0.1, 0.6, 0.2, 0.3] 31 | >>> MeanAveragePrecision()(y_true, y_pred) 32 | 1.0 33 | 34 | :param y_true: The ground true label of each document. 35 | :param y_pred: The predicted scores of each document. 36 | :return: Mean average precision. 37 | """ 38 | result = 0. 39 | pos = 0 40 | coupled_pair = sort_and_couple(y_true, y_pred) 41 | for idx, (label, score) in enumerate(coupled_pair): 42 | if label > self._threshold: 43 | pos += 1. 44 | result += pos / (idx + 1.) 45 | if pos == 0: 46 | return 0. 
class MeanReciprocalRank(BaseMetric):
    """Mean reciprocal rank metric."""

    ALIAS = ['mean_reciprocal_rank', 'mrr']

    def __init__(self, threshold: float = 0.):
        """
        :class:`MeanReciprocalRank` constructor.

        :param threshold: The label threshold of relevance degree; a label
            is relevant when it is strictly greater than this value.
        """
        self._threshold = threshold

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        alias = self.ALIAS[0]
        return f'{alias}({self._threshold})'

    def __call__(self, y_true: np.array, y_pred: np.array) -> float:
        """
        Calculate reciprocal of the rank of the first relevant item.

        Ranks are 1-based positions in the sequence produced by
        `sort_and_couple(y_true, y_pred)`.

        Example:
            >>> import numpy as np
            >>> y_pred = np.asarray([0.2, 0.3, 0.7, 1.0])
            >>> y_true = np.asarray([1, 0, 0, 0])
            >>> MeanReciprocalRank()(y_true, y_pred)
            0.25

        :param y_true: The ground true label of each document.
        :param y_pred: The predicted scores of each document.
        :return: Mean reciprocal rank, `0.` if no item is relevant.
        """
        ranked = enumerate(sort_and_couple(y_true, y_pred), start=1)
        return next(
            (1. / rank for rank, (label, _) in ranked
             if label > self._threshold),
            0.
        )
44 | -------------------------------------------------------------------------------- /matchzoo/metrics/normalized_discounted_cumulative_gain.py: -------------------------------------------------------------------------------- 1 | """Normalized discounted cumulative gain metric for ranking.""" 2 | import numpy as np 3 | 4 | from matchzoo.engine.base_metric import BaseMetric, sort_and_couple 5 | from .discounted_cumulative_gain import DiscountedCumulativeGain 6 | 7 | 8 | class NormalizedDiscountedCumulativeGain(BaseMetric): 9 | """Normalized discounted cumulative gain metric.""" 10 | 11 | ALIAS = ['normalized_discounted_cumulative_gain', 'ndcg'] 12 | 13 | def __init__(self, k: int = 1, threshold: float = 0.): 14 | """ 15 | :class:`NormalizedDiscountedCumulativeGain` constructor. 16 | 17 | :param k: Number of results to consider 18 | :param threshold: the label threshold of relevance degree. 19 | """ 20 | self._k = k 21 | self._threshold = threshold 22 | 23 | def __repr__(self) -> str: 24 | """:return: Formated string representation of the metric.""" 25 | return f"{self.ALIAS[0]}@{self._k}({self._threshold})" 26 | 27 | def __call__(self, y_true: np.array, y_pred: np.array) -> float: 28 | """ 29 | Calculate normalized discounted cumulative gain (ndcg). 30 | 31 | Relevance is positive real values or binary values. 32 | 33 | Example: 34 | >>> y_true = [0, 1, 2, 0] 35 | >>> y_pred = [0.4, 0.2, 0.5, 0.7] 36 | >>> ndcg = NormalizedDiscountedCumulativeGain 37 | >>> ndcg(k=1)(y_true, y_pred) 38 | 0.0 39 | >>> round(ndcg(k=2)(y_true, y_pred), 2) 40 | 0.52 41 | >>> round(ndcg(k=3)(y_true, y_pred), 2) 42 | 0.52 43 | >>> type(ndcg()(y_true, y_pred)) 44 | 45 | 46 | :param y_true: The ground true label of each document. 47 | :param y_pred: The predicted scores of each document. 48 | 49 | :return: Normalized discounted cumulative gain. 
class Precision(BaseMetric):
    """Precision metric."""

    ALIAS = 'precision'

    def __init__(self, k: int = 1, threshold: float = 0.):
        """
        :class:`Precision` constructor.

        :param k: Number of results to consider.
        :param threshold: the label threshold of relevance degree.
        """
        self._k = k
        self._threshold = threshold

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        return f"{self.ALIAS}@{self._k}({self._threshold})"

    def __call__(self, y_true: np.array, y_pred: np.array) -> float:
        """
        Calculate precision@k.

        Counts the labels above the threshold among the first `k` entries
        of `sort_and_couple(y_true, y_pred)` and divides by `k`. The
        divisor is always `k`, even when fewer than `k` items exist (see
        the `k=5` example below).

        Example:
            >>> y_true = [0, 0, 0, 1]
            >>> y_pred = [0.2, 0.4, 0.3, 0.1]
            >>> Precision(k=1)(y_true, y_pred)
            0.0
            >>> Precision(k=2)(y_true, y_pred)
            0.0
            >>> Precision(k=4)(y_true, y_pred)
            0.25
            >>> Precision(k=5)(y_true, y_pred)
            0.2

        :param y_true: The ground true label of each document.
        :param y_pred: The predicted scores of each document.
        :return: Precision @ k
        :raises ValueError: if `k` is not greater than 0.
        """
        if self._k <= 0:
            # The two message fragments used to be joined without a
            # space, producing e.g. "...than 0.-1 received.".
            raise ValueError(f"k must be greater than 0. "
                             f"{self._k} received.")
        coupled_pair = sort_and_couple(y_true, y_pred)
        precision = 0.0
        for idx, (label, score) in enumerate(coupled_pair):
            if idx >= self._k:
                break
            if label > self._threshold:
                precision += 1.
        return precision / self._k
16 | 17 | Examples: 18 | >>> model = ANMM() 19 | >>> model.guess_and_fill_missing_params(verbose=0) 20 | >>> model.build() 21 | 22 | """ 23 | 24 | @classmethod 25 | def get_default_params(cls) -> ParamTable: 26 | """:return: model default parameters.""" 27 | params = super().get_default_params(with_embedding=True) 28 | params.add(Param( 29 | name='dropout_rate', value=0.1, 30 | desc="The dropout rate.", 31 | hyper_space=hyper_spaces.quniform(0, 1, 0.05) 32 | )) 33 | params.add(Param( 34 | name='num_layers', value=2, 35 | desc="Number of hidden layers in the MLP " 36 | "layer." 37 | )) 38 | params.add(Param( 39 | name='hidden_sizes', value=[30, 30], 40 | desc="Number of hidden size for each hidden" 41 | " layer" 42 | )) 43 | return params 44 | 45 | def build(self): 46 | """ 47 | Build model structure. 48 | 49 | aNMM model based on bin weighting and query term attentions 50 | """ 51 | # query is [batch_size, left_text_len] 52 | # doc is [batch_size, right_text_len, bin_num] 53 | query, doc = self._make_inputs() 54 | embedding = self._make_embedding_layer() 55 | 56 | q_embed = embedding(query) 57 | q_attention = keras.layers.Dense( 58 | 1, kernel_initializer=RandomUniform(), use_bias=False)(q_embed) 59 | q_text_len = self._params['input_shapes'][0][0] 60 | 61 | q_attention = keras.layers.Lambda( 62 | lambda x: softmax(x, axis=1), 63 | output_shape=(q_text_len,) 64 | )(q_attention) 65 | d_bin = keras.layers.Dropout( 66 | rate=self._params['dropout_rate'])(doc) 67 | for layer_id in range(self._params['num_layers'] - 1): 68 | d_bin = keras.layers.Dense( 69 | self._params['hidden_sizes'][layer_id], 70 | kernel_initializer=RandomUniform())(d_bin) 71 | d_bin = keras.layers.Activation('tanh')(d_bin) 72 | d_bin = keras.layers.Dense( 73 | self._params['hidden_sizes'][self._params['num_layers'] - 1])( 74 | d_bin) 75 | d_bin = keras.layers.Reshape((q_text_len,))(d_bin) 76 | q_attention = keras.layers.Reshape((q_text_len,))(q_attention) 77 | score = keras.layers.Dot(axes=[1, 
class DenseBaseline(BaseModel):
    """
    A simple densely connected baseline model.

    Both inputs are concatenated and passed through a multi-layer
    perceptron followed by the task's output layer.

    Examples:
        >>> model = DenseBaseline()
        >>> model.params['mlp_num_layers'] = 2
        >>> model.params['mlp_num_units'] = 300
        >>> model.params['mlp_num_fan_out'] = 128
        >>> model.params['mlp_activation_func'] = 'relu'
        >>> model.guess_and_fill_missing_params(verbose=0)
        >>> model.build()
        >>> model.compile()

    """

    @classmethod
    def get_default_params(cls) -> ParamTable:
        """:return: model default parameters."""
        table = super().get_default_params(with_multi_layer_perceptron=True)
        table['mlp_num_units'] = 256
        # Search ranges attached to the MLP size parameters.
        search_ranges = (
            ('mlp_num_units', 16, 512),
            ('mlp_num_layers', 1, 5),
        )
        for name, low, high in search_ranges:
            table.get(name).hyper_space = hyper_spaces.quniform(low, high)
        return table

    def build(self):
        """Model structure."""
        inputs = self._make_inputs()
        merged = keras.layers.concatenate(inputs)
        hidden = self._make_multi_layer_perceptron_layer()(merged)
        outputs = self._make_output_layer()(hidden)
        self._backend = keras.models.Model(inputs=inputs, outputs=outputs)
class DSSM(BaseModel):
    """
    Deep structured semantic model.

    Examples:
        >>> model = DSSM()
        >>> model.params['mlp_num_layers'] = 3
        >>> model.params['mlp_num_units'] = 300
        >>> model.params['mlp_num_fan_out'] = 128
        >>> model.params['mlp_activation_func'] = 'relu'
        >>> model.guess_and_fill_missing_params(verbose=0)
        >>> model.build()

    """

    @classmethod
    def get_default_params(cls) -> ParamTable:
        """:return: model default parameters."""
        return super().get_default_params(with_multi_layer_perceptron=True)

    def build(self):
        """
        Build model structure.

        DSSM uses a Siamese architecture: one shared multi-layer
        perceptron encodes both inputs, and the two encodings are compared
        with a normalized dot product.
        """
        # Both inputs share the shape taken from the first input shape.
        shape = (self._params['input_shapes'][0][0],)
        shared_encoder = self._make_multi_layer_perceptron_layer()
        # Left input and right input.
        text_inputs = [
            Input(name=side, shape=shape)
            for side in ('text_left', 'text_right')
        ]
        # The same network processes left & right input.
        encoded = [shared_encoder(tensor) for tensor in text_inputs]
        # Dot product with cosine similarity.
        similarity = Dot(axes=[1, 1], normalize=True)(encoded)
        prediction = self._make_output_layer()(similarity)
        self._backend = Model(inputs=text_inputs, outputs=prediction)

    @classmethod
    def get_default_preprocessor(cls):
        """:return: Default preprocessor."""
        return preprocessors.DSSMPreprocessor()
44 | )) 45 | params['optimizer'] = 'adam' 46 | return params 47 | 48 | def build(self): 49 | """Build model structure.""" 50 | query, doc = self._make_inputs() 51 | 52 | # Embedding layer 53 | embedding = self._make_embedding_layer(mask_zero=True) 54 | embed_query = embedding(query) 55 | embed_doc = embedding(doc) 56 | 57 | # Bi-directional LSTM layer 58 | rep_query = keras.layers.Bidirectional(keras.layers.LSTM( 59 | self._params['lstm_units'], 60 | return_sequences=True, 61 | dropout=self._params['dropout_rate'] 62 | ))(embed_query) 63 | rep_doc = keras.layers.Bidirectional(keras.layers.LSTM( 64 | self._params['lstm_units'], 65 | return_sequences=True, 66 | dropout=self._params['dropout_rate'] 67 | ))(embed_doc) 68 | 69 | # Top-k matching layer 70 | matching_matrix = keras.layers.Dot( 71 | axes=[2, 2], normalize=False)([rep_query, rep_doc]) 72 | matching_signals = keras.layers.Reshape((-1,))(matching_matrix) 73 | matching_topk = keras.layers.Lambda( 74 | lambda x: tf.nn.top_k(x, k=self._params['top_k'], sorted=True)[0] 75 | )(matching_signals) 76 | 77 | # Multilayer perceptron layer. 78 | mlp = self._make_multi_layer_perceptron_layer()(matching_topk) 79 | mlp = keras.layers.Dropout( 80 | rate=self._params['dropout_rate'])(mlp) 81 | 82 | x_out = self._make_output_layer()(mlp) 83 | self._backend = keras.Model(inputs=[query, doc], outputs=x_out) 84 | -------------------------------------------------------------------------------- /matchzoo/models/naive.py: -------------------------------------------------------------------------------- 1 | """Naive model with a simplest structure for testing purposes.""" 2 | 3 | import keras 4 | 5 | from matchzoo.engine.base_model import BaseModel 6 | from matchzoo.engine import hyper_spaces 7 | 8 | 9 | class Naive(BaseModel): 10 | """ 11 | Naive model with a simplest structure for testing purposes. 12 | 13 | Bare minimum functioning model. The best choice to get things rolling. 14 | The worst choice to fit and evaluate performance. 
def _generate():
    """Render the reference for every available model and write it out."""
    sections = [_make_title()]
    for model_class in matchzoo.models.list_available():
        sections.append(_make_model_class_subtitle(model_class))
        sections.append(_make_doc_section_subsubtitle())
        sections.append(_make_model_doc(model_class))
        model = model_class()
        sections.append(_make_params_section_subsubtitle())
        sections.append(_make_model_params_table(model))
    _write_to_files(''.join(sections))


def _underlined(text, char, overline=False):
    """Return `text` as a heading ruled with `char`, plus a blank line."""
    rule = char * len(text)
    top = rule + '\n' if overline else ''
    return top + text + '\n' + rule + '\n\n'


def _make_title():
    """:return: The document title, ruled above and below with `*`."""
    return _underlined('MatchZoo Model Reference', '*', overline=True)


def _make_model_class_subtitle(model_class):
    """:return: A `#`-underlined heading holding the class name."""
    return _underlined(model_class.__name__, '#')


def _make_doc_section_subsubtitle():
    """:return: The `*`-underlined 'Model Documentation' heading."""
    return _underlined('Model Documentation', '*')


def _make_params_section_subsubtitle():
    """:return: The `*`-underlined 'Model Hyper Parameters' heading."""
    return _underlined('Model Hyper Parameters', '*')


def _make_model_doc(model_class):
    """:return: The cleaned-up class docstring followed by a blank line."""
    # `inspect.getdoc` returns None when no docstring can be found; emit
    # an empty section instead of failing on `None + str`.
    return (inspect.getdoc(model_class) or '') + '\n\n'


def _make_model_params_table(model):
    """:return: The model's default parameters rendered as an RST table."""
    params = model.get_default_params()
    df = params.to_frame()
    df = df.rename({
        'Value': 'Default Value',
        'Hyper-Space': 'Default Hyper-Space'
    }, axis='columns')
    return tabulate.tabulate(df, tablefmt='rst', headers='keys') + '\n\n'


def _write_to_files(full):
    """Write `full` to the package README and the docs model reference."""
    package_dir = Path(__file__).parent
    readme_file_path = package_dir.joinpath('README.rst')
    doc_file_path = package_dir.parent.parent. \
        joinpath('docs').joinpath('source').joinpath('model_reference.rst')
    for file_path in readme_file_path, doc_file_path:
        with open(file_path, 'w', encoding='utf-8') as out_file:
            out_file.write(full)


if __name__ == '__main__':
    _generate()
def build_unit_from_data_pack(
    unit: StatefulUnit,
    data_pack: mz.DataPack, mode: str = 'both',
    flatten: bool = True, verbose: int = 1
) -> StatefulUnit:
    """
    Fit `unit` on the text collected from `data_pack`.

    Every text handed out by `data_pack.apply_on_text` is gathered into one
    corpus list, which is then passed to `unit.fit`.

    :param unit: :class:`StatefulUnit` object to be fitted.
    :param data_pack: The input :class:`DataPack` object.
    :param mode: One of 'left', 'right', and 'both'; forwarded unchanged to
        `data_pack.apply_on_text`.
    :param flatten: `True` extends the corpus with the items of each text
        (one flat list); `False` appends each text as a single element
        (a list of lists).
    :param verbose: Verbosity; forwarded to `apply_on_text`, and any truthy
        value also wraps the corpus in a `tqdm` progress bar while fitting.
    :return: The same `unit` object, after fitting.

    """
    collected = []
    # Pick the collector once: `extend` flattens, `append` keeps nesting.
    collect = collected.extend if flatten else collected.append
    data_pack.apply_on_text(collect, mode=mode, verbose=verbose)

    corpus = collected
    if verbose:
        corpus = tqdm(
            collected,
            desc='Building ' + unit.__class__.__name__ + ' from a datapack.'
        )
    unit.fit(corpus)
    return unit
def chain_transform(units: typing.List[Unit]) -> typing.Callable:
    """
    Fold a list of units into one callable.

    The returned function feeds its argument to the `transform` method of
    the first unit, passes that result to the second unit, and so on, then
    returns the final value. Its `__name__` is ``chain_transform of``
    followed by the unit class names joined with `` => ``.

    :param units: List of units, applied in list order.
    :return: The composed transformation function.
    """

    @functools.wraps(chain_transform)
    def wrapper(arg):
        # `units` is read at call time, exactly like an explicit loop would.
        return functools.reduce(
            lambda value, unit: unit.transform(value), units, arg
        )

    wrapper.__name__ = '{} of {}'.format(
        wrapper.__name__,
        ' => '.join(unit.__class__.__name__ for unit in units)
    )
    return wrapper
def fit(self, data_pack: DataPack, verbose: int = 1):
    """
    Fit the vocabulary used later by `transform`.

    The default units are chained and applied to the text of `data_pack`;
    a vocabulary unit is built from the result and stored in the context
    under the key ``'vocab_unit'``.

    :param data_pack: data_pack to be preprocessed.
    :param verbose: Verbosity.
    :return: class:`NaivePreprocessor` instance.
    """
    pipeline = chain_transform(self._default_units())
    prepared = data_pack.apply_on_text(pipeline, verbose=verbose)
    self._context['vocab_unit'] = build_vocab_unit(prepared, verbose=verbose)
    return self

def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack:
    """
    Apply the fitted pipeline to the text of `data_pack`.

    The pipeline is the default units, then the vocabulary unit stored by
    `fit`, then a :class:`FixedLength` unit with length 30 and `post`
    padding.

    :param data_pack: Inputs to be preprocessed.
    :param verbose: Verbosity.

    :return: Transformed data as :class:`DataPack` object.
    """
    steps = self._default_units()
    steps.extend((
        self._context['vocab_unit'],
        units.FixedLength(text_length=30, pad_mode='post'),
    ))
    return data_pack.apply_on_text(chain_transform(steps), verbose=verbose)
class CharacterIndex(Unit):
    """
    Character index unit for the DIIN model.

    The input of :class:`CharacterIndex` is a list of words, each given as
    a list of characters. The output is a fixed-size grid of character
    indices: one row per word, one column per character position.
    Positions beyond the end of a word or of the text stay 0, and
    characters missing from the mapping get index 1.

    Examples:
        >>> input_ = [['#', 'a', '#'],['#', 'o', 'n', 'e', '#']]
        >>> character_index = CharacterIndex(
        ...     char_index={
        ...      '<PAD>': 0, '<OOV>': 1, 'a': 2, 'n': 3, 'e':4, '#':5},
        ...     fixed_length_text=2,
        ...     fixed_length_word=5)
        >>> index = character_index.transform(input_)
        >>> index
        [[5.0, 2.0, 5.0, 0.0, 0.0], [5.0, 1.0, 3.0, 4.0, 5.0]]

    """

    def __init__(
        self,
        char_index: dict,
        fixed_length_text: int,
        fixed_length_word: int
    ):
        """
        Class initialization.

        :param char_index: character-to-index mapping.
        :param fixed_length_text: maximum number of words kept per text.
        :param fixed_length_word: maximum number of characters kept per
            word.
        """
        self._char_index = char_index
        self._fixed_length_text = fixed_length_text
        self._fixed_length_word = fixed_length_word

    def transform(self, input_: list) -> list:
        """
        Transform lists of characters into their indices.

        :param input_: list of words, each a list of characters.

        :return: nested list of shape
            `fixed_length_text` x `fixed_length_word` holding the indices
            as floats.
        """
        max_words = self._fixed_length_text
        max_chars = self._fixed_length_word
        grid = np.zeros((max_words, max_chars))
        lookup = self._char_index.get
        # Slicing drops the words and characters that do not fit the grid.
        for row, word in enumerate(input_[:max_words]):
            for col, char in enumerate(word[:max_chars]):
                grid[row, col] = lookup(char, 1)

        return grid.tolist()
class FixedLength(Unit):
    """
    FixedLengthUnit Class.

    Process unit to get the fixed length text: longer inputs are truncated
    and shorter inputs are padded with the pad value.

    Examples:
        >>> from matchzoo.preprocessors.units import FixedLength
        >>> fixedlen = FixedLength(3)
        >>> fixedlen.transform(list(range(1, 6))) == [3, 4, 5]
        True
        >>> fixedlen.transform(list(range(1, 3))) == [0, 1, 2]
        True
        >>> FixedLength(3, pad_value='<PAD>').transform(['a'])
        ['<PAD>', '<PAD>', 'a']
        >>> FixedLength(0).transform([1, 2])
        []

    """

    def __init__(
        self,
        text_length: int,
        pad_value: typing.Union[int, str] = 0,
        pad_mode: str = 'pre',
        truncate_mode: str = 'pre'
    ):
        """
        Class initialization.

        :param text_length: fixed length of the text.
        :param pad_value: if text length is smaller than :attr:`text_length`,
            filling text with :attr:`pad_value`.
        :param pad_mode: String, `pre` or `post`:
            pad either before or after each sequence.
        :param truncate_mode: String, `pre` or `post`:
            remove values from sequences larger than :attr:`text_length`,
            either at the beginning or at the end of the sequences.
        """
        self._text_length = text_length
        self._pad_value = pad_value
        self._pad_mode = pad_mode
        self._truncate_mode = truncate_mode

    def transform(self, input_: list) -> list:
        """
        Transform list of tokenized tokens into the fixed length text.

        :param input_: list of tokenized tokens.

        :return tokens: list of tokenized tokens in fixed length.
        :raises ValueError: if the truncate mode or the pad mode is neither
            `pre` nor `post`.
        """
        # padding process can not handle empty list as input
        if len(input_) == 0:
            input_ = [self._pad_value]

        dtype = np.array(input_).dtype
        pad_dtype = np.array(self._pad_value).dtype
        if dtype.kind in 'US' and pad_dtype.kind == dtype.kind:
            # Fixed-width string dtypes silently cut longer values, so the
            # buffer must be wide enough for the pad value as well.
            dtype = np.promote_types(dtype, pad_dtype)
        fixed_tokens = np.full([self._text_length], self._pad_value,
                               dtype=dtype)

        if self._truncate_mode == 'pre':
            # `input_[-0:]` is the whole sequence, not an empty one.
            if self._text_length > 0:
                trunc_tokens = input_[-self._text_length:]
            else:
                trunc_tokens = input_[:0]
        elif self._truncate_mode == 'post':
            trunc_tokens = input_[:self._text_length]
        else:
            raise ValueError('{} is not a vaild '
                             'truncate mode.'.format(self._truncate_mode))

        kept = len(trunc_tokens)
        if self._pad_mode == 'post':
            fixed_tokens[:kept] = trunc_tokens
        elif self._pad_mode == 'pre':
            # With nothing kept, `[-0:]` would address the whole buffer.
            if kept:
                fixed_tokens[-kept:] = trunc_tokens
        else:
            raise ValueError('{} is not a vaild '
                             'pad mode.'.format(self._pad_mode))

        return fixed_tokens.tolist()
class MatchingHistogram(Unit):
    """
    MatchingHistogramUnit Class.

    :param bin_size: The number of bins of the matching histogram.
    :param embedding_matrix: The word embedding matrix applied to calculate
        the matching histogram.
    :param normalize: Boolean, normalize the embedding or not.
    :param mode: The type of the histogram, it should be one of 'CH', 'NH',
        or 'LCH'.

    Examples:
        >>> embedding_matrix = np.array([[1.0, -1.0], [1.0, 2.0], [1.0, 3.0]])
        >>> text_left = [0, 1]
        >>> text_right = [1, 2]
        >>> histogram = MatchingHistogram(3, embedding_matrix, True, 'CH')
        >>> histogram.transform([text_left, text_right])
        [[3.0, 1.0, 1.0], [1.0, 2.0, 2.0]]

    """

    def __init__(self, bin_size: int = 30, embedding_matrix=None,
                 normalize=True, mode: str = 'LCH'):
        """The constructor."""
        self._hist_bin_size = bin_size
        self._embedding_matrix = embedding_matrix
        if normalize:
            self._normalize_embedding()
        self._mode = mode

    def _normalize_embedding(self):
        """Scale every row of the embedding matrix to unit L2 norm."""
        l2_norm = np.sqrt(
            (self._embedding_matrix * self._embedding_matrix).sum(axis=1)
        )
        # An all-zero row has norm 0; dividing by it would yield NaN and
        # later break the bin computation. Dividing by 1 keeps it zero.
        l2_norm[l2_norm == 0] = 1.
        self._embedding_matrix = \
            self._embedding_matrix / l2_norm[:, np.newaxis]

    def transform(self, input_: list) -> list:
        """
        Transform the input text into matching histograms.

        :param input_: pair `[text_left, text_right]` of word index lists.
        :return: one histogram (list of `bin_size` floats) per left word.
            Every bin starts at 1; 'NH' divides each row by its sum and
            'LCH' takes the logarithm of the counts.
        """
        text_left, text_right = input_
        matching_hist = np.ones((len(text_left), self._hist_bin_size),
                                dtype=np.float32)
        embed_left = self._embedding_matrix[text_left]
        embed_right = self._embedding_matrix[text_right]
        matching_matrix = embed_left.dot(np.transpose(embed_right))
        last_bin = self._hist_bin_size - 1
        for (i, _), value in np.ndenumerate(matching_matrix):
            bin_index = int((value + 1.) / 2. * (self._hist_bin_size - 1.))
            # Similarities outside [-1, 1] (unnormalized embeddings) would
            # wrap around to a wrong bin or overflow; clamp to the ends.
            bin_index = min(max(bin_index, 0), last_bin)
            matching_hist[i][bin_index] += 1.0
        if self._mode == 'NH':
            matching_sum = matching_hist.sum(axis=1)
            matching_hist = matching_hist / matching_sum[:, np.newaxis]
        elif self._mode == 'LCH':
            matching_hist = np.log(matching_hist)
        return matching_hist.tolist()
class PuncRemoval(Unit):
    """Process unit that strips punctuation characters from tokens."""

    def transform(self, input_: list) -> list:
        """
        Remove punctuation characters from every token.

        Every character of `string.punctuation` is deleted from each
        token. The list keeps its length, so a token made only of
        punctuation becomes an empty string.

        :param input_: list of tokens.

        :return rv: tokens without punctuation.
        """
        strip_punctuation = str.maketrans('', '', string.punctuation)
        return [token.translate(strip_punctuation) for token in input_]
class Stemming(Unit):
    """
    Process unit for token stemming.

    :param stemmer: stemmer to use, `porter` or `lancaster`. The value
        `krovetz` is accepted as well and selects the Lancaster stemmer.
    """

    def __init__(self, stemmer='porter'):
        """Store the name of the stemmer to use."""
        self.stemmer = stemmer

    def transform(self, input_: list) -> list:
        """
        Reduce inflected words to their word stem, base or root form.

        :param input_: list of string to be stemmed.
        :return: list with the stem of every token.
        :raises ValueError: if the configured stemmer name is unknown.
        """
        kind = self.stemmer
        if kind == 'porter':
            engine = nltk.stem.PorterStemmer()
        elif kind in ('lancaster', 'krovetz'):
            engine = nltk.stem.lancaster.LancasterStemmer()
        else:
            raise ValueError(
                'Not supported supported stemmer type: {}'.format(kind))
        return [engine.stem(token) for token in input_]
class WordExactMatch(Unit):
    """
    WordExactUnit Class.

    Process unit to get a binary match list of two word index lists. The
    word index list is the word representation of a text.

    Examples:
        >>> input_ = pandas.DataFrame({
        ...  'text_left':[[1, 2, 3],[4, 5, 7, 9]],
        ...  'text_right':[[5, 3, 2, 7],[2, 3, 5]]}
        ... )
        >>> left_word_exact_match = WordExactMatch(
        ...     fixed_length_text=5,
        ...     match='text_left', to_match='text_right'
        ... )
        >>> left_out = input_.apply(left_word_exact_match.transform, axis=1)
        >>> left_out[0]
        [0.0, 1.0, 1.0, 0.0, 0.0]
        >>> left_out[1]
        [0.0, 1.0, 0.0, 0.0, 0.0]
        >>> right_word_exact_match = WordExactMatch(
        ...     fixed_length_text=5,
        ...     match='text_right', to_match='text_left'
        ... )
        >>> right_out = input_.apply(right_word_exact_match.transform, axis=1)
        >>> right_out[0]
        [0.0, 1.0, 1.0, 0.0, 0.0]
        >>> right_out[1]
        [0.0, 0.0, 1.0, 0.0, 0.0]

    """

    def __init__(
        self,
        fixed_length_text: int,
        match: str,
        to_match: str
    ):
        """
        Class initialization.

        :param fixed_length_text: fixed length of the text.
        :param match: the 'match' column name.
        :param to_match: the 'to_match' column name.
        """
        self._fixed_length_text = fixed_length_text
        self._match = match
        self._to_match = to_match

    def transform(self, input_) -> list:
        """
        Transform two word index lists into a binary match list.

        :param input_: a row (or mapping) that holds the 'match' column and
            the 'to_match' column.

        :return: list of `fixed_length_text` floats; position `i` is 1.0
            when the `i`-th 'match' word also occurs in 'to_match', else
            0.0. Positions past the end of the 'match' list stay 0.0.
        """
        match_words = input_[self._match]
        match_binary = np.zeros((self._fixed_length_text))
        limit = min(self._fixed_length_text, len(match_words))
        if limit > 0:
            # Build the lookup set once instead of once per compared word.
            candidates = set(input_[self._to_match])
            for i in range(limit):
                if match_words[i] in candidates:
                    match_binary[i] = 1

        return match_binary.tolist()
class Classification(BaseTask):
    """Classification task.

    Examples:
        >>> classification_task = Classification(num_classes=2)
        >>> classification_task.metrics = ['precision']
        >>> classification_task.num_classes
        2
        >>> classification_task.output_shape
        (2,)
        >>> classification_task.output_dtype
        <class 'int'>
        >>> print(classification_task)
        Classification Task with 2 classes

    """

    def __init__(self, num_classes: int = 2, **kwargs):
        """
        Create a classification task.

        :param num_classes: number of target classes, an integer of at
            least 2.
        :param kwargs: forwarded unchanged to the base task.
        :raises TypeError: if `num_classes` is not an integer.
        :raises ValueError: if `num_classes` is smaller than 2.
        """
        super().__init__(**kwargs)
        if not isinstance(num_classes, int):
            raise TypeError("Number of classes must be an integer.")
        if num_classes < 2:
            raise ValueError("Number of classes can't be smaller than 2")
        self._num_classes = num_classes

    @property
    def num_classes(self) -> int:
        """:return: number of classes to classify."""
        return self._num_classes

    @classmethod
    def list_available_losses(cls) -> list:
        """:return: a list of available losses."""
        return ['categorical_crossentropy']

    @classmethod
    def list_available_metrics(cls) -> list:
        """:return: a list of available metrics."""
        return ['acc']

    @property
    def output_shape(self) -> tuple:
        """:return: output shape of a single sample of the task."""
        return (self._num_classes,)

    @property
    def output_dtype(self):
        """:return: target data type, expect `int` as output."""
        return int

    def __str__(self):
        """:return: Task name as string."""
        return 'Classification Task with {} classes'.format(
            self._num_classes)
def list_recursive_concrete_subclasses(base):
    """
    List all concrete subclasses of `base` recursively.

    :param base: class whose subclass tree is traversed; `base` itself is
        never part of the result.
    :return: list of non-abstract subclasses, each listed once, in the order
        they are first discovered (direct subclasses first).
    """
    return _filter_concrete(_bfs(base))


def _filter_concrete(classes):
    """Return the members of `classes` that are not abstract."""
    return [cls for cls in classes if not inspect.isabstract(cls)]


def _bfs(base):
    """Collect every direct and indirect subclass of `base` exactly once."""
    direct = base.__subclasses__()
    collected = list(direct)
    for subclass in direct:
        collected.extend(_bfs(subclass))

    # With multiple inheritance the same class is reachable through several
    # parents; keep only its first occurrence so callers see it once.
    seen = set()
    unique = []
    for cls in collected:
        if cls not in seen:
            seen.add(cls)
            unique.append(cls)
    return unique
def one_hot(indices: int, num_classes: int) -> np.ndarray:
    """
    Build a one-hot vector.

    :param indices: position in the vector that is set to 1.
    :param num_classes: length of the returned vector.
    :return: A one-hot encoded vector.
    """
    encoded = np.zeros(shape=(num_classes,), dtype=np.int64)
    encoded[indices] = 1
    return encoded
import io
import os

from setuptools import setup, find_packages


here = os.path.abspath(os.path.dirname(__file__))

# Avoids IDE errors, but actual version is read from version.py
__version__ = None
# Resolve the path from `here` so the script works from any working
# directory, and use a context manager so the file handle is closed.
with io.open(os.path.join(here, 'matchzoo', 'version.py'),
             encoding='utf-8') as f:
    exec(f.read())

short_description = 'Facilitating the design, comparison and sharing of deep text matching models.'

# Get the long description from the README file
with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

install_requires = [
    'keras >= 2.3.0',
    'nltk >= 3.2.3',
    'numpy >= 1.14',
    'tqdm >= 4.19.4',
    'dill >= 0.2.7.1',
    'pandas >= 0.23.1',
    'networkx >= 2.1',
    'h5py >= 2.8.0',
    'hyperopt >= 0.1.1'
]

extras_requires = {
    'tests': [
        'coverage >= 4.3.4',
        'codecov >= 2.0.15',
        'pytest >= 3.0.3',
        'pytest-cov >= 2.4.0',
        'flake8 >= 3.6.0',
        'flake8_docstrings >= 1.0.2'],
}


setup(
    name="MatchZoo",
    version=__version__,
    author="Yixing Fan, Bo Wang, Zeyi Wang, Liang Pang, Liu Yang, Qinghua Wang, etc.",
    author_email="fanyixing@ict.ac.cn",
    description=(short_description),
    license="Apache 2.0",
    keywords="text matching models",
    url="https://github.com/NTMC-Community/MatchZoo",
    packages=find_packages(),
    long_description=long_description,
    long_description_content_type='text/markdown',
    classifiers=[
        "Development Status :: 3 - Alpha",
        'Environment :: Console',
        'Operating System :: POSIX :: Linux',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        "License :: OSI Approved :: Apache Software License",
        'Programming Language :: Python :: 3.6'
    ],
    install_requires=install_requires,
    extras_require=extras_requires
)
def test_save_load(data_pack):
    """Round-trip a DataPack through save/load and re-save onto the dir."""
    dirpath = '.tmpdir'
    # A directory left behind by an earlier aborted run would make the first
    # save fail for a reason unrelated to what this test checks.
    shutil.rmtree(dirpath, ignore_errors=True)
    try:
        data_pack.save(dirpath)
        dp = load_data_pack(dirpath)
        # Saving onto an existing directory must be refused.
        with pytest.raises(FileExistsError):
            data_pack.save(dirpath)
        assert len(data_pack) == 2
        assert len(dp) == 2
    finally:
        # Clean up even when an assertion above fails.
        shutil.rmtree(dirpath, ignore_errors=True)
]) 22 | def op(request): 23 | return request.param 24 | 25 | 26 | @pytest.fixture(scope='module', params=[ 27 | hyper_spaces.choice(options=[0, 1]), 28 | hyper_spaces.uniform(low=0, high=10), 29 | hyper_spaces.quniform(low=0, high=10, q=2) 30 | ]) 31 | def proxy(request): 32 | return request.param 33 | 34 | 35 | def test_init(proxy): 36 | assert isinstance(proxy.convert('label'), hyperopt.pyll.base.Apply) 37 | 38 | 39 | def test_op(proxy, op): 40 | assert isinstance(op(proxy).convert('label'), hyperopt.pyll.base.Apply) 41 | 42 | 43 | def test_str(proxy): 44 | assert isinstance(str(proxy), str) 45 | -------------------------------------------------------------------------------- /tests/unit_test/engine/test_param_table.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo.engine.param import Param 4 | from matchzoo.engine.param_table import ParamTable 5 | from matchzoo.engine.hyper_spaces import quniform 6 | 7 | 8 | @pytest.fixture 9 | def param_table(): 10 | params = ParamTable() 11 | params.add(Param('ham', 'Parma Ham')) 12 | return params 13 | 14 | 15 | def test_get(param_table): 16 | assert param_table['ham'] == 'Parma Ham' 17 | 18 | 19 | def test_set(param_table): 20 | new_param = Param('egg', 'Over Easy') 21 | param_table.set('egg', new_param) 22 | assert 'egg' in param_table.keys() 23 | 24 | 25 | def test_keys(param_table): 26 | assert 'ham' in param_table.keys() 27 | 28 | 29 | def test_hyper_space(param_table): 30 | new_param = Param( 31 | name='my_param', 32 | value=1, 33 | hyper_space=quniform(low=1, high=5) 34 | ) 35 | param_table.add(new_param) 36 | hyper_space = param_table.hyper_space 37 | assert hyper_space 38 | -------------------------------------------------------------------------------- /tests/unit_test/models/test_base_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo.engine.base_model import 
"""
These tests are simplified because the original version takes too much time to
run, making CI fail as it reaches the time limit.
"""
return setup[0] 51 | 52 | 53 | @pytest.fixture(scope='module') 54 | def preprocessor(setup): 55 | return setup[1] 56 | 57 | 58 | @pytest.fixture(scope='module') 59 | def gen_builder(setup): 60 | return setup[2] 61 | 62 | 63 | @pytest.fixture(scope='module') 64 | def embedding_matrix(setup): 65 | return setup[3] 66 | 67 | 68 | @pytest.fixture(scope='module') 69 | def data(train_raw, preprocessor, gen_builder): 70 | return gen_builder.build(preprocessor.transform(train_raw))[0] 71 | 72 | 73 | @pytest.mark.slow 74 | def test_model_fit_eval_predict(model, data): 75 | x, y = data 76 | batch_size = len(x['id_left']) 77 | assert model.fit(x, y, batch_size=batch_size, verbose=0) 78 | assert model.evaluate(x, y, batch_size=batch_size) 79 | assert model.predict(x, batch_size=batch_size) is not None 80 | 81 | 82 | @pytest.mark.cron 83 | def test_save_load_model(model): 84 | tmpdir = '.matchzoo_test_save_load_tmpdir' 85 | 86 | if Path(tmpdir).exists(): 87 | shutil.rmtree(tmpdir) 88 | 89 | try: 90 | model.save(tmpdir) 91 | assert mz.load_model(tmpdir) 92 | with pytest.raises(FileExistsError): 93 | model.save(tmpdir) 94 | finally: 95 | if Path(tmpdir).exists(): 96 | shutil.rmtree(tmpdir) 97 | 98 | 99 | @pytest.mark.cron 100 | def test_hyper_space(model): 101 | for _ in range(2): 102 | new_params = copy.deepcopy(model.params) 103 | sample = mz.hyper_spaces.sample(new_params.hyper_space) 104 | for key, value in sample.items(): 105 | new_params[key] = value 106 | new_model = new_params['model_class'](params=new_params) 107 | new_model.build() 108 | new_model.compile() 109 | -------------------------------------------------------------------------------- /tests/unit_test/tasks/test_tasks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from matchzoo import tasks 4 | 5 | 6 | @pytest.mark.parametrize("task_type", [ 7 | tasks.Ranking, tasks.Classification 8 | ]) 9 | def test_task_listings(task_type): 10 | assert 
@pytest.mark.parametrize("arg", [None, -1, 0, 1])
def test_classification_instantiation_failure(arg):
    """Invalid `num_classes` values must be rejected at construction."""
    # Expect only the validation errors (wrong type or value too small), so
    # an unrelated failure inside the constructor cannot pass this test.
    with pytest.raises((TypeError, ValueError)):
        tasks.Classification(num_classes=arg)
self._data_gen.batch_indices 56 | if not self._orig_indices: 57 | self._orig_indices = copy.deepcopy(curr_indices) 58 | else: 59 | self._flags.append(self._orig_indices != curr_indices) 60 | self._orig_indices = curr_indices 61 | 62 | check_resample = CheckResample(data_gen) 63 | model.fit_generator(data_gen, epochs=5, callbacks=[check_resample]) 64 | assert check_resample._flags 65 | assert all(check_resample._flags) 66 | -------------------------------------------------------------------------------- /tests/unit_test/test_embedding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matchzoo as mz 4 | 5 | 6 | @pytest.fixture 7 | def term_index(): 8 | return {'G': 1, 'C': 2, 'D': 3, 'A': 4, '_PAD': 0} 9 | 10 | 11 | def test_embedding(term_index): 12 | embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_RANK) 13 | matrix = embed.build_matrix(term_index) 14 | assert matrix.shape == (len(term_index), 50) 15 | embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_10_GLOVE, 16 | mode='glove') 17 | matrix = embed.build_matrix(term_index) 18 | assert matrix.shape == (len(term_index), 10) 19 | assert embed.input_dim == 5 20 | -------------------------------------------------------------------------------- /tests/unit_test/test_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from keras import backend as K 4 | 5 | from matchzoo import layers 6 | from matchzoo.contrib.layers import SpatialGRU 7 | from matchzoo.contrib.layers import MatchingTensorLayer 8 | 9 | 10 | def test_matching_layers(): 11 | s1_value = np.array([[[1, 2], [2, 3], [3, 4]], 12 | [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]] 13 | ]) 14 | s2_value = np.array([[[1, 2], [2, 3]], 15 | [[0.1, 0.2], [0.2, 0.3]] 16 | ]) 17 | s3_value = np.array([[[1, 2], [2, 3]], 18 | [[0.1, 0.2], [0.2, 0.3]], 19 | [[0.1, 0.2], [0.2, 0.3]] 20 | ]) 21 | s1_tensor = 
K.variable(s1_value) 22 | s2_tensor = K.variable(s2_value) 23 | s3_tensor = K.variable(s3_value) 24 | for matching_type in ['dot', 'mul', 'plus', 'minus', 'concat']: 25 | model = layers.MatchingLayer(matching_type=matching_type)([s1_tensor, s2_tensor]) 26 | ret = K.eval(model) 27 | with pytest.raises(ValueError): 28 | layers.MatchingLayer(matching_type='error') 29 | with pytest.raises(ValueError): 30 | layers.MatchingLayer()([s1_tensor, s3_tensor]) 31 | 32 | 33 | def test_spatial_gru(): 34 | s_value = K.variable(np.array([[[[1, 2], [2, 3], [3, 4]], 35 | [[4, 5], [5, 6], [6, 7]]], 36 | [[[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]], 37 | [[0.4, 0.5], [0.5, 0.6], [0.6, 0.7]]]])) 38 | for direction in ['lt', 'rb']: 39 | model = SpatialGRU(direction=direction) 40 | _ = K.eval(model(s_value)) 41 | with pytest.raises(ValueError): 42 | SpatialGRU(direction='lr')(s_value) 43 | 44 | 45 | def test_matching_tensor_layer(): 46 | s1_value = np.array([[[1, 2], [2, 3], [3, 4]], 47 | [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]]) 48 | s2_value = np.array([[[1, 2], [2, 3]], 49 | [[0.1, 0.2], [0.2, 0.3]]]) 50 | s3_value = np.array([[[1, 2], [2, 3]], 51 | [[0.1, 0.2], [0.2, 0.3]], 52 | [[0.1, 0.2], [0.2, 0.3]]]) 53 | s1_tensor = K.variable(s1_value) 54 | s2_tensor = K.variable(s2_value) 55 | s3_tensor = K.variable(s3_value) 56 | for init_diag in [True, False]: 57 | model = MatchingTensorLayer(init_diag=init_diag) 58 | _ = K.eval(model([s1_tensor, s2_tensor])) 59 | with pytest.raises(ValueError): 60 | MatchingTensorLayer()([s1_tensor, s3_tensor]) 61 | -------------------------------------------------------------------------------- /tests/unit_test/test_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend as K 3 | 4 | from matchzoo import losses 5 | 6 | 7 | def test_hinge_loss(): 8 | true_value = K.variable(np.array([[1.2], [1], 9 | [1], [1]])) 10 | pred_value = K.variable(np.array([[1.2], [0.1], 11 | [0], 
[-0.3]])) 12 | expected_loss = (0 + 1 - 0.3 + 0) / 2.0 13 | loss = K.eval(losses.RankHingeLoss()(true_value, pred_value)) 14 | assert np.isclose(expected_loss, loss) 15 | expected_loss = (2 + 0.1 - 1.2 + 2 - 0.3 + 0) / 2.0 16 | loss = K.eval(losses.RankHingeLoss(margin=2)(true_value, pred_value)) 17 | assert np.isclose(expected_loss, loss) 18 | true_value = K.variable(np.array([[1.2], [1], [0.8], 19 | [1], [1], [0.8]])) 20 | pred_value = K.variable(np.array([[1.2], [0.1], [-0.5], 21 | [0], [0], [-0.3]])) 22 | expected_loss = (0 + 1 - 0.15) / 2.0 23 | loss = K.eval(losses.RankHingeLoss(num_neg=2, margin=1)( 24 | true_value, pred_value)) 25 | assert np.isclose(expected_loss, loss) 26 | 27 | 28 | def test_rank_crossentropy_loss(): 29 | losses.neg_num = 1 30 | 31 | def softmax(x): 32 | return np.exp(x) / np.sum(np.exp(x), axis=0) 33 | 34 | true_value = K.variable(np.array([[1.], [0.], 35 | [0.], [1.]])) 36 | pred_value = K.variable(np.array([[0.8], [0.1], 37 | [0.8], [0.1]])) 38 | expected_loss = (-np.log(softmax([0.8, 0.1])[0]) - np.log( 39 | softmax([0.8, 0.1])[1])) / 2 40 | loss = K.eval(losses.RankCrossEntropyLoss()(true_value, pred_value)) 41 | assert np.isclose(expected_loss, loss) 42 | true_value = K.variable(np.array([[1.], [0.], [0.], 43 | [0.], [1.], [0.]])) 44 | pred_value = K.variable(np.array([[0.8], [0.1], [0.1], 45 | [0.8], [0.1], [0.1]])) 46 | expected_loss = (-np.log(softmax([0.8, 0.1, 0.1])[0]) - np.log( 47 | softmax([0.8, 0.1, 0.1])[1])) / 2 48 | loss = K.eval(losses.RankCrossEntropyLoss(num_neg=2)( 49 | true_value, pred_value)) 50 | assert np.isclose(expected_loss, loss) 51 | -------------------------------------------------------------------------------- /tests/unit_test/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from matchzoo.engine.base_metric import sort_and_couple 4 | from matchzoo import metrics 5 | 6 | 7 | def test_sort_and_couple(): 8 | l = [0, 1, 2] 9 | s 
= [0.1, 0.4, 0.2] 10 | c = sort_and_couple(l, s) 11 | assert (c == np.array([(1, 0.4), (2, 0.2), (0, 0.1)])).all() 12 | 13 | 14 | def test_mean_reciprocal_rank(): 15 | label = [0, 1, 2] 16 | score = [0.1, 0.4, 0.2] 17 | assert metrics.MeanReciprocalRank()(label, score) == 1 18 | 19 | 20 | def test_precision_at_k(): 21 | label = [0, 1, 2] 22 | score = [0.1, 0.4, 0.2] 23 | assert metrics.Precision(k=1)(label, score) == 1. 24 | assert metrics.Precision(k=2)(label, score) == 1. 25 | assert round(metrics.Precision(k=3)(label, score), 2) == 0.67 26 | 27 | 28 | def test_average_precision(): 29 | label = [0, 1, 2] 30 | score = [0.1, 0.4, 0.2] 31 | assert round(metrics.AveragePrecision()(label, score), 2) == 0.89 32 | 33 | 34 | def test_mean_average_precision(): 35 | label = [0, 1, 2] 36 | score = [0.1, 0.4, 0.2] 37 | assert metrics.MeanAveragePrecision()(label, score) == 1. 38 | 39 | 40 | def test_dcg_at_k(): 41 | label = [0, 1, 2] 42 | score = [0.1, 0.4, 0.2] 43 | dcg = metrics.DiscountedCumulativeGain 44 | assert round(dcg(k=1)(label, score), 2) == 1.44 45 | assert round(dcg(k=2)(label, score), 2) == 4.17 46 | assert round(dcg(k=3)(label, score), 2) == 4.17 47 | 48 | 49 | def test_ndcg_at_k(): 50 | label = [0, 1, 2] 51 | score = [0.1, 0.4, 0.2] 52 | ndcg = metrics.NormalizedDiscountedCumulativeGain 53 | assert round(ndcg(k=1)(label, score), 2) == 0.33 54 | assert round(ndcg(k=2)(label, score), 2) == 0.80 55 | assert round(ndcg(k=3)(label, score), 2) == 0.80 56 | -------------------------------------------------------------------------------- /tests/unit_test/test_tuner.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matchzoo as mz 4 | 5 | 6 | @pytest.fixture(scope='module') 7 | def tuner(): 8 | model = mz.models.DenseBaseline() 9 | prpr = model.get_default_preprocessor() 10 | train_raw = mz.datasets.toy.load_data('train') 11 | dev_raw = mz.datasets.toy.load_data('dev') 12 | prpr.fit(train_raw) 13 | 
model.params.update(prpr.context) 14 | model.guess_and_fill_missing_params() 15 | return mz.auto.Tuner( 16 | params=model.params, 17 | train_data=prpr.transform(train_raw, verbose=0), 18 | test_data=prpr.transform(dev_raw, verbose=0) 19 | ) 20 | 21 | 22 | @pytest.mark.parametrize('attr', [ 23 | 'params', 24 | 'train_data', 25 | 'test_data', 26 | 'fit_kwargs', 27 | 'evaluate_kwargs', 28 | 'metric', 29 | 'mode', 30 | 'num_runs', 31 | 'callbacks', 32 | 'verbose' 33 | ]) 34 | def test_getters_setters(tuner, attr): 35 | val = getattr(tuner, attr) 36 | setattr(tuner, attr, val) 37 | assert getattr(tuner, attr) is val 38 | 39 | 40 | def test_tuning(tuner): 41 | tuner.num_runs = 1 42 | assert tuner.tune() 43 | -------------------------------------------------------------------------------- /tests/unit_test/test_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/tests/unit_test/test_utils.py -------------------------------------------------------------------------------- /tutorials/models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Under Construction" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Refer to 'tutorials/wikiqa' for model walkthroughs. 
" 15 | ] 16 | } 17 | ], 18 | "metadata": { 19 | "hide_input": false, 20 | "kernelspec": { 21 | "display_name": "matchzoo", 22 | "language": "python", 23 | "name": "matchzoo" 24 | }, 25 | "language_info": { 26 | "codemirror_mode": { 27 | "name": "ipython", 28 | "version": 3 29 | }, 30 | "file_extension": ".py", 31 | "mimetype": "text/x-python", 32 | "name": "python", 33 | "nbconvert_exporter": "python", 34 | "pygments_lexer": "ipython3", 35 | "version": "3.6.3" 36 | }, 37 | "toc": { 38 | "nav_menu": {}, 39 | "number_sections": true, 40 | "sideBar": true, 41 | "skip_h1_title": false, 42 | "toc_cell": false, 43 | "toc_position": {}, 44 | "toc_section_display": "block", 45 | "toc_window_display": false 46 | }, 47 | "varInspector": { 48 | "cols": { 49 | "lenName": 16, 50 | "lenType": 16, 51 | "lenVar": 40 52 | }, 53 | "kernels_config": { 54 | "python": { 55 | "delete_cmd_postfix": "", 56 | "delete_cmd_prefix": "del ", 57 | "library": "var_list.py", 58 | "varRefreshCmd": "print(var_dic_list())" 59 | }, 60 | "r": { 61 | "delete_cmd_postfix": ") ", 62 | "delete_cmd_prefix": "rm(", 63 | "library": "var_list.r", 64 | "varRefreshCmd": "cat(var_dic_list()) " 65 | } 66 | }, 67 | "types_to_exclude": [ 68 | "module", 69 | "function", 70 | "builtin_function_or_method", 71 | "instance", 72 | "_Feature" 73 | ], 74 | "window_display": false 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 2 79 | } 80 | -------------------------------------------------------------------------------- /tutorials/quick_start_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NTMC-Community/MatchZoo/8a487ee5a574356fc91e4f48e219253dc11bcff2/tutorials/quick_start_chart.png --------------------------------------------------------------------------------