├── .bumpversion.cfg ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── python-app-py3.yml │ └── python-publish.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── RELEASING.md ├── detext_model_architecture.png ├── pypi_release.sh ├── setup.cfg ├── setup.py ├── src ├── detext │ ├── __init__.py │ ├── args.py │ ├── examples │ │ ├── __init__.py │ │ ├── embedding_layer_example.py │ │ └── vocab_layer_example.py │ ├── layers │ │ ├── __init__.py │ │ ├── bert_layer.py │ │ ├── cnn_layer.py │ │ ├── embedding_layer.py │ │ ├── feature_grouper.py │ │ ├── feature_name_type_converter.py │ │ ├── feature_normalizer.py │ │ ├── feature_rescaler.py │ │ ├── id_embed_layer.py │ │ ├── interaction_layer.py │ │ ├── lstm_layer.py │ │ ├── multi_layer_perceptron.py │ │ ├── output_transform_layer.py │ │ ├── representation_layer.py │ │ ├── scoring_layer.py │ │ ├── shallow_tower_layer.py │ │ ├── sparse_embedding_layer.py │ │ └── vocab_layer.py │ ├── metaclass.py │ ├── run_detext.py │ ├── train │ │ ├── __init__.py │ │ ├── constant.py │ │ ├── data_fn.py │ │ ├── loss.py │ │ ├── metrics.py │ │ ├── model.py │ │ ├── optimization.py │ │ ├── train.py │ │ ├── train_flow_helper.py │ │ └── train_model_helper.py │ └── utils │ │ ├── __init__.py │ │ ├── distributed_utils.py │ │ ├── layer_utils.py │ │ ├── parsing_utils.py │ │ ├── testing │ │ ├── __init__.py │ │ ├── data_setup.py │ │ └── testing_utils.py │ │ └── vocab_utils.py ├── libert │ ├── __init__.py │ └── preprocess.py └── smart_compose │ ├── __init__.py │ ├── args.py │ ├── layers │ ├── __init__.py │ ├── beam_search.py │ ├── embedding_layer.py │ ├── prefix_search.py │ └── vocab_layer.py │ ├── run_smart_compose.py │ ├── train │ ├── __init__.py │ ├── data_fn.py │ ├── losses.py │ ├── metrics.py │ ├── model.py │ ├── optimization.py │ ├── train.py │ ├── train_flow_helper.py │ └── train_model_helper.py │ └── utils │ ├── __init__.py │ ├── distributed_utils.py │ ├── layer_utils.py │ ├── parsing_utils.py │ ├── testing │ ├── __init__.py │ ├── data_setup.py │ ├── test_case.py │ └── testing_utils.py │ └── vocab_utils.py ├── test ├── __init__.py ├── detext │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── test_bert_layer.py │ │ ├── test_cnn_layer.py │ │ ├── test_embedding_layer.py │ │ ├── test_feature_grouper.py │ │ ├── test_feature_name_type_converter.py │ │ ├── test_feature_normalizer.py │ │ ├── test_feature_rescaler.py │ │ ├── test_id_embed_layer.py │ │ ├── test_interaction_layer.py │ │ ├── test_lstm_layer.py │ │ ├── test_multi_layer_perceptron.py │ │ ├── test_representation_layer.py │ │ ├── test_scoring_layer.py │ │ ├── test_shallow_tower_layer.py │ │ ├── test_sparse_embedding_layer.py │ │ └── test_vocab_layer.py │ ├── resources │ │ ├── bert-hub │ │ │ ├── assets │ │ │ │ └── uncased_vocab.txt │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ ├── bert_config.json │ │ ├── embedding_layer_hub │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ ├── libert-sp-hub │ │ │ ├── assets │ │ │ │ ├── __tokenizer_type__ │ │ │ │ └── spbpe.model │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ ├── libert-space-hub │ │ │ ├── assets │ │ │ │ ├── __tokenizer_type__ │ │ │ │ └── vocab.txt │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ ├── multilingual_vocab.txt.gz │ │ ├── run_detext.sh │ │ ├── 
sample_data │ │ │ └── hc_examples.tfrecord │ │ ├── train │ │ │ ├── binary_classification │ │ │ │ └── tfrecord │ │ │ │ │ └── test.tfrecord │ │ │ ├── classification │ │ │ │ └── tfrecord │ │ │ │ │ └── test.tfrecord │ │ │ ├── dataset │ │ │ │ └── tfrecord │ │ │ │ │ └── test.tfrecord │ │ │ ├── multitask │ │ │ │ └── tfrecord │ │ │ │ │ └── test.tfrecord │ │ │ └── ranking │ │ │ │ └── tfrecord │ │ │ │ └── test.tfrecord │ │ ├── vocab.txt │ │ ├── vocab_layer_hub │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ └── we.pkl │ ├── test_metaclass.py │ ├── test_run_detext.py │ ├── test_tf2.py │ ├── train │ │ ├── __init__.py │ │ ├── test_data_fn.py │ │ ├── test_loss.py │ │ ├── test_metrics.py │ │ ├── test_model.py │ │ ├── test_optimization.py │ │ ├── test_train_flow_helper.py │ │ └── test_train_model_helper.py │ └── utils │ │ ├── __init__.py │ │ └── test_parsing_utils.py └── smart_compose │ ├── __init__.py │ ├── layers │ ├── __init__.py │ ├── test_embedding_layer.py │ ├── test_prefix_search.py │ └── test_vocab_layer.py │ ├── resources │ ├── embedding_layer_hub │ │ ├── saved_model.pb │ │ └── variables │ │ │ ├── variables.data-00000-of-00001 │ │ │ └── variables.index │ ├── train │ │ └── dataset │ │ │ └── tfrecord │ │ │ └── test.tfrecord │ ├── vocab.30k.txt │ ├── vocab.txt │ └── vocab_layer_hub │ │ ├── saved_model.pb │ │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index │ ├── test_run_smart_compose.py │ ├── train │ ├── __init__.py │ ├── test_data_fn.py │ ├── test_losses.py │ ├── test_metrics.py │ ├── test_model.py │ └── test_optimization.py │ └── utils │ ├── __init__.py │ ├── test_layer_utils.py │ ├── test_parsing_utils.py │ └── test_vocab_utils.py ├── thumbnail_DeText.png └── user_guide ├── TRAINING.md └── notebooks ├── autocompletion_demo.ipynb └── text_classification_demo.ipynb /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 3.1.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | Please delete options that are not relevant. 10 | 11 | - [ ] Bug fix (non-breaking change which fixes an issue) 12 | - [ ] New feature (non-breaking change which adds functionality) 13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 14 | 15 | ## List all changes 16 | Please list all changes in the commit. 17 | * change1 18 | * change2 19 | 20 | # Testing 21 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. 
Please also list any relevant details for your test configuration 22 | 23 | 24 | **Test Configuration**: 25 | * Firmware version: 26 | * Hardware: 27 | * Toolchain: 28 | * SDK: 29 | 30 | # Checklist 31 | 32 | - [ ] My code follows the style guidelines of this project 33 | - [ ] I have performed a self-review of my own code 34 | - [ ] I have commented my code, particularly in hard-to-understand areas 35 | - [ ] I have made corresponding changes to the documentation 36 | - [ ] My changes generate no new warnings 37 | - [ ] I have added tests that prove my fix is effective or that my feature works 38 | - [ ] New and existing unit tests pass locally with my changes 39 | - [ ] Any dependent changes have been merged and published in downstream modules 40 | -------------------------------------------------------------------------------- /.github/workflows/python-app-py3.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python 3 application 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | strategy: 15 | matrix: 16 | python-version: [3.7] 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install --upgrade setuptools 29 | pip install . 30 | - name: Lint with flake8 31 | run: | 32 | pip install -U flake8 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . 
--count --show-source --statistics 35 | - name: Test with pytest 36 | run: | 37 | pip install pytest 38 | pytest 39 | 40 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.7' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: xwli 28 | TWINE_PASSWORD: ${{ secrets.pypi_password }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* --verbose 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /out 2 | *.egg 3 | *.egg-info/ 4 | *.iml 5 | *.ipr 6 | *.iws 7 | *.pyc 8 | *.pyo 9 | *.sublime-* 10 | .*.swo 11 | .*.swp 12 | .cache/ 13 | .coverage 14 | .direnv/ 15 | .env 16 | .envrc 17 | .gradle/ 18 | .idea/ 19 | .tox* 20 | .venv* 21 | venv 22 | /*/*pinned.txt 23 | /*/MANIFEST 24 | /*/activate 25 | /*/build/ 26 | /*/config 27 | /*/coverage.xml 28 | /*/dist/ 29 | /*/htmlcov/ 30 | /*/product-spec.json 31 | /build/ 32 | /config/ 33 | /dist/ 34 | /ligradle/ 35 | TEST-*.xml 36 | __pycache__/ 37 | /*/build 38 | .DS_Store 39 | 40 | gradle/ 41 | gradlew 42 | gradlew.bat 43 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contribution Agreement 2 | ====================== 3 | 4 | As a contributor, you represent that the code you submit is your original work or 5 | that of your employer (in which case you represent you have the right to bind your 6 | employer). By submitting code, you (and, if applicable, your employer) are 7 | licensing the submitted code to LinkedIn and the open source community subject 8 | to the BSD 2-Clause license. 9 | 10 | Responsible Disclosure of Security Vulnerabilities 11 | ================================================== 12 | 13 | **Do not file an issue on Github for security issues.** Please review 14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should 15 | be encrypted using PGP ([public key][pubkey]) and sent to 16 | [security@linkedin.com][disclosure_email] preferably with the title 17 | "Vulnerability in Github LinkedIn/detext - <short summary>". 18 | 19 | Tips for Getting Your Pull Request Accepted 20 | =========================================== 21 | 22 | 1. Make sure all new features are tested and the tests pass. 23 | 2. Bug fixes must include a test case demonstrating the error that it fixes. 24 | 3. Open an issue first and seek advice for your change before submitting 25 | a pull request. Large features which have never been discussed are 26 | unlikely to be accepted. 
**You have been warned.** 27 | 28 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924 29 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676 30 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/detext%20-%20%3Csummary%3E 31 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 2-CLAUSE LICENSE 2 | Copyright 2019 LinkedIn Corporation 3 | All Rights Reserved. 4 | 5 | Redistribution and use in source and binary forms, with or 6 | without modification, are permitted provided that the following 7 | conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 2. Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 20 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /RELEASING.md: -------------------------------------------------------------------------------- 1 | Release a New Version of DeText 2 | ========= 3 | NOTE: this guide is for DeText owners to publish new versions to PyPi 4 | 5 | Make sure you have the correct permission to release DeText packages. Only owners and maintainers can upload new releases for DeText. Check more details at https://pypi.org/project/detext/. 6 | 7 | 8 | ## When to release 9 | Release a new version when adding new features or bug fixes. Minor updates to the text (e.g. typos in README.md) should be combined with a later release. 10 | 11 | 12 | ## How to release 13 | 14 | ### Step 1: increment the version and publish to PyPi 15 | Releasing the package involves: 16 | * Incrementing the version of DeText 17 | * Publishing to pypi: Note that this prepares and uploads two packages (`detext` and `detext-nodep`) with the same version. `detext` is the OSS package for public use. `detext-nodep` is for LI internal use only, without any dependencies such as `tensorflow` pulled in. 18 | 19 | Please ensure all changes are merged before releasing. Use the following command to release a new package: 20 | ```shell script 21 | bash pypi_release.sh <part> 22 | ``` 23 | Where `part` is a required argument. It specifies the part of the version to increase, used in `bump2version`. Valid values are: `patch`, `minor`, and `major`. See `pypi_release.sh` for more details. 24 | 25 | #### Examples: 26 | 27 | Assume the current version in `.bumpversion.cfg` is 0.0.1.
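For reference, a minimal sketch of what `.bumpversion.cfg` would contain at this starting point, before running the commands below (same fields as the config shown at the top of this repo, with the version swapped to 0.0.1):

```
[bumpversion]
current_version = 0.0.1
commit = True
tag = True

[bumpversion:file:setup.py]
```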
28 | 29 | * 0.0.1 -> 0.0.2: 30 | ```shell script 31 | bash pypi_release.sh patch 32 | ``` 33 | * 0.0.1 -> 0.1.0: 34 | ```shell script 35 | bash pypi_release.sh minor 36 | ``` 37 | * 0.0.1 -> 1.0.0: 38 | ```shell script 39 | bash pypi_release.sh major 40 | ``` 41 | 42 | The `.bumpversion.cfg` is the single source of truth for versioning DeText. You do not need to manually update the version number. Both `.bumpversion.cfg` and `setup.py` will be updated automatically. More about `bump2version`: https://github.com/c4urself/bump2version. 43 | #### Best practices for versioning 44 | * Breaking changes are indicated by increasing the major number (high risk) 45 | * New non-breaking features increment the minor number (medium risk) 46 | * All other non-breaking changes increment the patch number (lowest risk). 47 | 48 | ### Step 2: merge version changes 49 | Running the releasing script automatically creates a new commit that includes the version update and a new tag. 50 | 51 | You can verify the new releases at https://pypi.org/project/detext/ and https://pypi.org/project/detext-nodep/. If the packages are successfully published, create a PR and merge to master. 52 | 53 | 54 | ### Step 3: add a Tag 55 | 56 | Once the x.x.x version is released to PyPi, please add a tag in the `release` section of the repo home page. The tag should have the same version name `vx.x.x` (e.g. `v1.0.12`) as in the released PyPi package. 57 | -------------------------------------------------------------------------------- /detext_model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/detext_model_architecture.png -------------------------------------------------------------------------------- /pypi_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Make sure all changes are committed before running this script for releasing. 3 | # See RELEASING.md for more instructions on running the script for releasing. 4 | 5 | # Usage: 6 | # bash pypi_release.sh <part> 7 | # <part> must be "patch", "minor", or "major". 8 | # e.g.: 9 | # 0.0.1 -> 0.0.2: 10 | # bash pypi_release.sh patch 11 | # 0.0.1 -> 0.1.0: 12 | # bash pypi_release.sh minor 13 | # 0.0.1 -> 1.0.0: 14 | # bash pypi_release.sh major 15 | 16 | # Please follow these best practices for versioning: 17 | # breaking changes are indicated by increasing the major number (high risk), 18 | # new non-breaking features increment the minor number (medium risk) 19 | # all other non-breaking changes increment the patch number (lowest risk). 20 | 21 | # Exit when any command fails 22 | set -e 23 | # Cleaning up dist directory for old releases 24 | rm -rf dist/ 25 | 26 | # Check input argument 27 | if [ "$1" != "patch" ] && [ "$1" != "minor" ] && [ "$1" != "major" ]; then 28 | echo "Must include correct argument. E.g., bash pypi_release.sh patch" 29 | exit 1 30 | fi 31 | 32 | # Install/upgrade needed pypi packages 33 | pip install -U bump2version twine 34 | 35 | # Increment version with bumpversion. Version format: {major}.{minor}.{patch} 36 | echo "Incrementing DeText $1 version." 37 | bump2version "$1" 38 | 39 | # Build the source distribution 40 | echo "******** Preparing pypi package..." 41 | python setup.py sdist 42 | 43 | # Build the source distribution without dependencies added for LI internal use. 44 | echo "******** Preparing pypi package without dependencies..."
45 | # Temporarily save setup.py for recover 46 | cp setup.py setup.py.tmp 47 | 48 | # Rename the pypi package name and install_requires entries 49 | if [[ "$OSTYPE" == "darwin"* ]]; then 50 | sed -i "" "s/name='detext'/name='detext-nodep'/" setup.py 51 | sed -i "" "s/install_requires=.*/install_requires=[],/g" setup.py 52 | else 53 | sed -i "s/name='detext'/name='detext-nodep'/" setup.py 54 | sed -i "s/install_requires=.*/install_requires=[],/g" setup.py 55 | fi 56 | 57 | python setup.py sdist 58 | # Recover original setup.py 59 | rm setup.py 60 | mv setup.py.tmp setup.py 61 | 62 | # Upload to pypi, username and password required (make sure you have permission for releasing detext packages) 63 | echo "******** Uploading all sdist under dist/" 64 | twine upload dist/* 65 | 66 | echo "******** Pypi package releasing succeeded!" 67 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E121,E123,E226,W292,E402,W504,E126,E275 3 | max-line-length = 160 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. 2 | # See LICENSE in the project root for license information. 3 | import setuptools 4 | from os import path 5 | 6 | this_directory = path.abspath(path.dirname(__file__)) 7 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 8 | long_description = f.read() 9 | 10 | TF_VERSION_QUANTIFIER = '>=2.4,<2.5' 11 | PACKAGES = ['smart-arg==0.4', 'bump2version', 'twine==3.2.0', f'tf-models-official{TF_VERSION_QUANTIFIER}', 12 | f'tensorflow{TF_VERSION_QUANTIFIER}', f'tensorflow-text{TF_VERSION_QUANTIFIER}', 'tensorflow_ranking', 13 | 'future<0.14'] 14 | 15 | setuptools.setup( 16 | name='detext', 17 | long_description=long_description, 18 | long_description_content_type='text/markdown', 19 | classifiers=["Programming Language :: Python :: 3.7", 20 | "Programming Language :: Python :: 3.8", 21 | "Programming Language :: Python :: 3.9", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "Topic :: Software Development :: Libraries", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Developers", 26 | "License :: OSI Approved"], 27 | license='BSD-2-CLAUSE', 28 | # DO NOT CHANGE: version should be incremented by bump2version when releasing. 
See pypi_release.sh 29 | version='3.1.0', 30 | package_dir={'': 'src'}, 31 | packages=setuptools.find_packages('src'), 32 | include_package_data=True, 33 | install_requires=PACKAGES, 34 | tests_require=[ 35 | 'pytest', 36 | ]) 37 | -------------------------------------------------------------------------------- /src/detext/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | __import__('pkg_resources').declare_namespace(__name__) 3 | except ImportError: 4 | from pkgutil import extend_path 5 | __path__ = extend_path(__path__, __name__) 6 | -------------------------------------------------------------------------------- /src/detext/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/detext/examples/__init__.py -------------------------------------------------------------------------------- /src/detext/examples/embedding_layer_example.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | from dataclasses import dataclass 4 | 5 | import tensorflow as tf 6 | import tensorflow_hub as hub 7 | from absl import logging 8 | from smart_arg import arg_suite 9 | 10 | from detext.layers import embedding_layer 11 | from detext.utils.layer_utils import get_sorted_dict 12 | from detext.utils.parsing_utils import InternalFtrType 13 | 14 | 15 | @arg_suite 16 | @dataclass 17 | class Args: 18 | vocab_hub_url: str # TF hub URL to vocab layer 19 | embedding_file: str # Embedding matrix in pickle format. Shape=[vocab_size, num_units] 20 | num_units: int # Dimension of embedding file 21 | trainable: bool # Whether embedding is trainable 22 | output_file: str # Output location of embedding layer 23 | 24 | 25 | def init_word_embedding(vocab_size, num_units, we_trainable, we_file=None, name_prefix="w"): 26 | """Initialize word embeddings from random initialization or pretrained word embedding. 27 | 28 | This function is only used by encoding models other than BERT 29 | """ 30 | 31 | if not we_file: 32 | embedding_name = "{}_pretrained_embedding".format(name_prefix) 33 | # Random initialization 34 | embedding = tf.compat.v1.get_variable( 35 | embedding_name, [vocab_size, num_units], dtype=tf.float32, trainable=we_trainable) 36 | else: 37 | # Initialize by pretrained word embedding 38 | embedding_name = "{}_embedding".format(name_prefix) 39 | we = pickle.load(tf.io.gfile.GFile(we_file, 'rb')) 40 | assert vocab_size == we.shape[0] and num_units == we.shape[1] 41 | embedding = tf.compat.v1.get_variable(name=embedding_name, 42 | shape=[vocab_size, num_units], 43 | dtype=tf.float32, 44 | initializer=tf.compat.v1.constant_initializer(we), 45 | trainable=we_trainable 46 | ) 47 | return embedding 48 | 49 | 50 | class ExampleEmbeddingLayer(embedding_layer.EmbeddingLayerBase): 51 | """Example embedding layer accepted by DeText 52 | 53 | To check whether a given vocab layer conforms to the DeText API, follow the test function in detext unit tests: 54 | test/layers/test_embedding_layer.testEmbeddingLayerApi() 55 | """ 56 | 57 | def __init__(self, vocab_hub_url, we_file, we_trainable, num_units, name_prefix='w'): 58 | """ Initializes the embedding layer 59 | 60 | :param vocab_hub_url Url to saved vocabulary layer. 
If empty string or None, no vocab layer will be loaded 61 | :param we_file Path to pretrained word embedding 62 | :param we_trainable Whether word embedding is trainable 63 | :param num_units Dimension of embedding 64 | :param name_prefix Prefix of embedding variables 65 | """ 66 | super().__init__() 67 | self.vocab_layer = hub.load(vocab_hub_url) # A vocab layer accepted by DeText. Check example/vocab_layer_example for an example 68 | self._num_units = num_units 69 | self._vocab_size = self.vocab_layer.vocab_size() 70 | 71 | self.embedding = init_word_embedding(self._vocab_size, num_units, we_trainable, we_file, name_prefix) 72 | 73 | @tf.function 74 | def tokenize_to_indices(self, inputs): 75 | """Tokenize given inputs and convert to indices 76 | 77 | Example: tokenize_to_indices(['hello world', 'sentence 1 token']) -> {RESULT: [[20, 10, pad_id], [4, 5, 6]], LENGTH: [2, 3]} 78 | 79 | :param inputs tf.Tensor(dtype=string) Shape=[batch_size] 80 | :return A dictionary containing the following key values: 81 | RESULT: tf.Tensor(dtype=int) Shape=[batch_size, sentence_len]. Tokenization and lookup result 82 | LENGTH: tf.Tensor(dtype=int) Shape=[batch_size]. Sentence lengths 83 | """ 84 | return self.vocab_layer(inputs) 85 | 86 | @tf.function(input_signature=[]) 87 | def vocab_size(self): 88 | """Returns the vocabulary size of the vocab paired with the embedding 89 | 90 | :return int/Tensor(dtype=int) 91 | """ 92 | return self._vocab_size 93 | 94 | @tf.function(input_signature=[]) 95 | def num_units(self): 96 | """Returns the number of units (embedding size) 97 | 98 | :return int/tf.Tensor(dtype=int) 99 | """ 100 | return self._num_units 101 | 102 | @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.dtypes.int32)]) 103 | def embedding_lookup(self, inputs): 104 | """Returns the embedding of the inputs 105 | 106 | :param inputs Tensor(dtype=int) Shape=[batch_size, sentence_len] 107 | :return Tensor(dtype=float) Shape[batch_size, sentence_len, num_units] 108 | """ 109 | return tf.nn.embedding_lookup(params=self.embedding, ids=inputs) 110 | 111 | 112 | SENTENCES = tf.constant(['hello sent1', 'build sent2']) 113 | NUM_CLS = tf.constant(0, dtype=tf.dtypes.int32) 114 | NUM_SEP = tf.constant(0, dtype=tf.dtypes.int32) 115 | MIN_LEN = tf.constant(0, dtype=tf.dtypes.int32) 116 | MAX_LEN = tf.constant(5, dtype=tf.dtypes.int32) 117 | INPUTS = get_sorted_dict({InternalFtrType.SENTENCES: SENTENCES, 118 | InternalFtrType.NUM_CLS: NUM_CLS, 119 | InternalFtrType.NUM_SEP: NUM_SEP, 120 | InternalFtrType.MIN_LEN: MIN_LEN, 121 | InternalFtrType.MAX_LEN: MAX_LEN}) 122 | TOKENIZED_RESULTS = tf.constant([[1, 2], [0, 1]]) 123 | 124 | 125 | def check_embedding_layer_api(layer: embedding_layer.EmbeddingLayerBase): 126 | """Checks whether the embedding layer has APIs required by DeText 127 | 128 | Layer that does not pass this check is definitely DeText incompatible, but passing this check does not fully guarantee DeText compatibility 129 | """ 130 | layer.num_units() 131 | layer.vocab_size() 132 | layer.tokenize_to_indices(INPUTS) 133 | layer.embedding_lookup(TOKENIZED_RESULTS) 134 | layer(INPUTS) 135 | 136 | 137 | def build_embedding_layer(vocab_hub_url, embedding_file, trainable, num_units): 138 | return ExampleEmbeddingLayer(vocab_hub_url, embedding_file, trainable, num_units) 139 | 140 | 141 | def main(argv): 142 | argument = Args.__from_argv__(argv[1:], error_on_unknown=True) 143 | argument: Args 144 | 145 | logging.info("Building embedding layer") 146 | layer = 
build_embedding_layer(argument.vocab_hub_url, argument.embedding_file, argument.trainable, argument.num_units) 147 | 148 | logging.info("Checking embedding layer api") 149 | check_embedding_layer_api(layer) 150 | 151 | tf.saved_model.save(layer, argument.output_file) 152 | logging.info(f"Layer saved to {argument.output_file}") 153 | 154 | 155 | if __name__ == '__main__': 156 | main(sys.argv) 157 | -------------------------------------------------------------------------------- /src/detext/examples/vocab_layer_example.py: -------------------------------------------------------------------------------- 1 | import string 2 | import sys 3 | from dataclasses import dataclass 4 | 5 | import tensorflow as tf 6 | from absl import logging 7 | from detext.layers import vocab_layer 8 | from detext.utils.layer_utils import get_sorted_dict 9 | from detext.utils.parsing_utils import InternalFtrType 10 | from smart_arg import arg_suite 11 | 12 | 13 | @arg_suite 14 | @dataclass 15 | class Args: 16 | vocab_file: str # Path of the vocabulary file which contains one token each line 17 | output_file: str # Path of the output layer 18 | 19 | CLS: str = '[CLS]' # Start of sentence token 20 | SEP: str = '[SEP]' # End of sentence token 21 | PAD: str = '[PAD]' # Padding token 22 | UNK: str = '[UNK]' # Unknown token 23 | 24 | 25 | def read_vocab(input_file: str): 26 | """Read vocabulary file and return a dict 27 | 28 | :param input_file Path to input vocab file in txt format 29 | """ 30 | vocab = {} 31 | fin = tf.io.gfile.GFile(input_file, 'r') 32 | for line in fin: 33 | line = line.strip(string.whitespace) 34 | word = line.split()[0] 35 | vocab[word] = len(vocab) 36 | fin.close() 37 | return vocab 38 | 39 | 40 | def read_tf_vocab(input_file: str, UNK: str): 41 | """Read vocabulary and return a tf hashtable 42 | 43 | :param input_file Path to input vocab file in txt format 44 | :param token for unknown words 45 | """ 46 | keys, values = [], [] 47 | fin = tf.io.gfile.GFile(input_file, 'r') 48 | for line in fin: 49 | line = line.strip(string.whitespace) 50 | word = line.split()[0] 51 | keys.append(word) 52 | values.append(len(values)) 53 | fin.close() 54 | UNK_ID = keys.index(UNK) 55 | 56 | initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(keys), tf.constant(values)) 57 | vocab_table = tf.lookup.StaticHashTable(initializer, UNK_ID) 58 | return initializer, vocab_table 59 | 60 | 61 | class ExampleVocabLayer(vocab_layer.VocabLayerBase): 62 | """Example vocabulary layer accepted by DeText 63 | 64 | To check whether a given vocab layer conforms to the DeText API, follow the test function in detext unit tests: 65 | test/layers/test_vocab_layer.testVocabLayerApi() 66 | 67 | Input text will be tokenized and convert to indices by this layer. 
Whitespace split is used as tokenization 68 | """ 69 | 70 | def __init__(self, CLS: str, SEP: str, PAD: str, UNK: str, vocab_file: str): 71 | """ Initializes the vocabulary layer 72 | 73 | :param CLS Token that represents the start of a sentence 74 | :param SEP Token that represents the end of a segment 75 | :param PAD Token that represents padding 76 | :param UNK Token that represents unknown tokens 77 | :param vocab_file Path to the vocabulary file 78 | """ 79 | super().__init__() 80 | self._vocab_table_initializer, self.vocab_table = read_tf_vocab(vocab_file, UNK) 81 | 82 | self._CLS = CLS 83 | self._SEP = SEP 84 | self._PAD = PAD 85 | 86 | py_vocab_table = read_vocab(vocab_file) 87 | self._pad_id = py_vocab_table[PAD] 88 | self._cls_id = py_vocab_table[CLS] if CLS else -1 89 | self._sep_id = py_vocab_table[SEP] if SEP else -1 90 | self._vocab_size = len(py_vocab_table) 91 | 92 | @tf.function(input_signature=[]) 93 | def pad_id(self): 94 | """Returns the index of the padding token 95 | 96 | :return int/tf.Tensor(dtype=int) 97 | """ 98 | return self._pad_id 99 | 100 | @tf.function(input_signature=[]) 101 | def vocab_size(self): 102 | """Returns the vocabulary size 103 | 104 | :return int/tf.Tensor(dtype=int) 105 | """ 106 | return self._vocab_size 107 | 108 | def cls_id(self): 109 | """Returns the index of CLS token 110 | 111 | :return int/tf.Tensor(dtype=int) 112 | """ 113 | return self._cls_id 114 | 115 | def sep_id(self): 116 | """Returns the index of SEP token 117 | 118 | :return int/tf.Tensor(dtype=int) 119 | """ 120 | return self._sep_id 121 | 122 | def _vocab_lookup(self, inputs): 123 | """Converts given input tokens into indices 124 | 125 | :param inputs Tensor(dtype=string) Shape=[batch_size, sentence_len]. This is the output from _tokenize() method 126 | :return Tensor(dtype=int) Shape=[batch_size, sentence_len] 127 | """ 128 | return self.vocab_table.lookup(inputs) 129 | 130 | def _tokenize(self, inputs): 131 | """Converts given input into tokens 132 | 133 | :param inputs Tensor(dtype=string) Shape=[batch_size] 134 | :return Tensor(dtype=string) Shape=[batch_size, sentence_len]. Output should be either dense or sparse. 
Ragged tensor is not supported for now 135 | """ 136 | return tf.strings.split(inputs).to_sparse() 137 | 138 | 139 | SENTENCES = tf.constant(['hello sent1', 'build sent2']) 140 | NUM_CLS = tf.constant(0, dtype=tf.dtypes.int32) 141 | NUM_SEP = tf.constant(0, dtype=tf.dtypes.int32) 142 | MIN_LEN = tf.constant(0, dtype=tf.dtypes.int32) 143 | MAX_LEN = tf.constant(5, dtype=tf.dtypes.int32) 144 | INPUTS = get_sorted_dict({InternalFtrType.SENTENCES: SENTENCES, 145 | InternalFtrType.NUM_CLS: NUM_CLS, 146 | InternalFtrType.NUM_SEP: NUM_SEP, 147 | InternalFtrType.MIN_LEN: MIN_LEN, 148 | InternalFtrType.MAX_LEN: MAX_LEN}) 149 | 150 | 151 | def check_vocab_layer_api(layer: vocab_layer.VocabLayerBase): 152 | """Checks whether the vocab layer has APIs required by DeText 153 | 154 | Layer that does not pass this check is definitely DeText incompatible, but passing this check does not fully guarantee DeText compatibility 155 | """ 156 | layer(INPUTS) 157 | 158 | 159 | def build_vocab_layer(CLS: str, SEP: str, PAD: str, UNK: str, vocab_file: str): 160 | return ExampleVocabLayer(CLS, SEP, PAD, UNK, vocab_file) 161 | 162 | 163 | def main(argv): 164 | argument = Args.__from_argv__(argv[1:], error_on_unknown=True) 165 | argument: Args 166 | 167 | logging.info("Building vocab layer") 168 | layer = build_vocab_layer(CLS=argument.CLS, SEP=argument.SEP, PAD=argument.PAD, UNK=argument.UNK, vocab_file=argument.vocab_file) 169 | 170 | logging.info("Checking vocab layer api") 171 | check_vocab_layer_api(layer) 172 | 173 | tf.saved_model.save(layer, argument.output_file) 174 | logging.info(f"Layer saved to {argument.output_file}") 175 | 176 | 177 | if __name__ == '__main__': 178 | main(sys.argv) 179 | -------------------------------------------------------------------------------- /src/detext/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/detext/layers/__init__.py -------------------------------------------------------------------------------- /src/detext/layers/feature_grouper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.utils.parsing_utils import InputFtrType 4 | 5 | 6 | class FeatureGrouper(tf.keras.layers.Layer): 7 | """Feature grouper 8 | 9 | Features of the same type (e.g. dense numeric features) will be concatenated into one vector 10 | """ 11 | 12 | def __init__(self): 13 | super(FeatureGrouper, self).__init__() 14 | 15 | def process_single_ftr(self, inputs, ftr_type, func): 16 | """Apply function on the input wrt the given feature type """ 17 | if ftr_type in inputs: 18 | inputs[ftr_type] = func(inputs[ftr_type]) 19 | return inputs 20 | 21 | def process_list_ftr(self, inputs, ftr_type, func): 22 | """Applies function on every element of the input list wrt the given feature types 23 | 24 | The function is applied on each feature tensor (corresponding to one feature name) 25 | """ 26 | if ftr_type in inputs: 27 | result = [] 28 | for tensor in inputs[ftr_type]: 29 | result.append(func(tensor)) 30 | inputs[ftr_type] = result 31 | return inputs 32 | 33 | def call(self, inputs, *args, **kwargs): 34 | """Processes input features """ 35 | inputs = inputs.copy() 36 | # Concatenate features that supports a list of inputs 37 | # E.g., users may have two arrays of dense features, one named "demographics", one named "professional". 
Since DeText treat them as dense features, 38 | # we concatenate them into one array 39 | self.process_single_ftr(inputs, InputFtrType.DENSE_FTRS_COLUMN_NAMES, concat_on_last_axis_dense) 40 | 41 | return inputs 42 | 43 | 44 | def concat_on_last_axis_sparse(tensor_list): 45 | """Concatenates list of sparse tensors on the last axis""" 46 | return tf.sparse.concat(sp_inputs=tensor_list, axis=-1) 47 | 48 | 49 | def concat_on_last_axis_dense(tensor_list): 50 | """Concatenates list of dense tensors on the last axis""" 51 | return tf.concat(tensor_list, axis=-1) 52 | -------------------------------------------------------------------------------- /src/detext/layers/feature_normalizer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class FeatureNormalizer(tf.keras.layers.Layer): 5 | """Feature normalizer to normalize dense features 6 | 7 | This layer improves numeric stability and is useful for network convergence 8 | """ 9 | 10 | def __init__(self, ftr_mean, ftr_std): 11 | super(FeatureNormalizer, self).__init__() 12 | self.ftr_mean = tf.constant(ftr_mean, dtype=tf.dtypes.float32) 13 | self.ftr_std = tf.constant(ftr_std, dtype=tf.dtypes.float32) 14 | 15 | def call(self, inputs, **kwargs): 16 | """ Normalizes inputs to (inputs - self.ftr_mean) / self.ftr_std 17 | 18 | :param inputs: Tensor(tf.float32). Shape=[..., num_ftrs] 19 | :param kwargs: Dummy args for suppress warning for method overriding 20 | :return: Normalized input 21 | """ 22 | return (inputs - self.ftr_mean) / self.ftr_std 23 | -------------------------------------------------------------------------------- /src/detext/layers/feature_rescaler.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class FeatureRescaler(tf.keras.layers.Layer): 5 | """Feature rescaler to rescale dense features 6 | 7 | This layer improves numeric stability and is useful for network convergence 8 | """ 9 | 10 | def __init__(self, num_ftrs, prefix=''): 11 | super(FeatureRescaler, self).__init__() 12 | self._num_ftrs = num_ftrs 13 | self._initial_w = 1.0 14 | self._initial_b = 0.0 15 | 16 | self.norm_w = tf.compat.v1.get_variable(f"{prefix}norm_w", [num_ftrs], dtype=tf.float32, 17 | initializer=tf.compat.v1.constant_initializer(self._initial_w)) 18 | self.norm_b = tf.compat.v1.get_variable(f"{prefix}norm_b", [num_ftrs], dtype=tf.float32, 19 | initializer=tf.compat.v1.constant_initializer(self._initial_b)) 20 | 21 | def call(self, inputs, **kwargs): 22 | """ Rescales inputs to tf.tanh(inputs * self.norm_w + self.norm_b) 23 | 24 | :param inputs: Tensor(tf.float32). 
Shape=[..., num_ftrs] 25 | :param kwargs: Dummy args for suppress warning for method overriding 26 | :return: Rescaled input 27 | """ 28 | return tf.tanh(inputs * self.norm_w + self.norm_b) 29 | -------------------------------------------------------------------------------- /src/detext/layers/id_embed_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers.embedding_layer import create_embedding_layer 4 | from detext.utils.parsing_utils import InputFtrType, InternalFtrType 5 | 6 | DEFAULT_MIN_LEN = 1 7 | DEFAULT_MAX_LEN = 100 8 | 9 | 10 | class IdEmbedLayer(tf.keras.layers.Layer): 11 | """ ID embedding layer""" 12 | 13 | def __init__(self, num_id_fields, embedding_layer_param, embedding_hub_url_for_id_ftr): 14 | """ Initializes the layer 15 | 16 | For more details on parameters, check args.py 17 | """ 18 | super(IdEmbedLayer, self).__init__() 19 | self._num_id_fields = num_id_fields 20 | 21 | self.min_len = DEFAULT_MIN_LEN 22 | self.max_len = DEFAULT_MAX_LEN 23 | self.num_cls_sep = 0 24 | 25 | if num_id_fields: 26 | self.embedding = create_embedding_layer(embedding_layer_param, embedding_hub_url_for_id_ftr) 27 | self.id_ftr_size = self.embedding.num_units() 28 | 29 | def call(self, inputs, **kwargs): 30 | """ Applies ID embedding lookup and summation on document and user fields 31 | 32 | :param inputs: Dict A mapping that contains the following key: 33 | doc_id_fields: list(Tensor(dtype=string)) List of document fields. Each has shape=[batch_size, max_group_size] 34 | user_id_fields: list(Tensor(dtype=string)) List of user fields. Each has shape=[batch_size] 35 | :return: doc_ftrs, user_ftrs 36 | """ 37 | doc_id_fields = inputs.get(InputFtrType.DOC_ID_COLUMN_NAMES, None) 38 | user_id_fields = inputs.get(InputFtrType.USER_ID_COLUMN_NAMES, None) 39 | 40 | if self._num_id_fields == 0: 41 | assert doc_id_fields is None and user_id_fields is None, "Document ID fields and user ID fields must be None when there's no id field" 42 | 43 | user_ftrs = self.apply_embed_on_user_id(user_id_fields) if user_id_fields is not None else None 44 | doc_ftrs = self.apply_embed_on_doc_id(doc_id_fields) if doc_id_fields is not None else None 45 | return doc_ftrs, user_ftrs 46 | 47 | def apply_embedding(self, inputs): 48 | """Applies embedding on give inputs 49 | 50 | :param inputs Tensor(dtype=string) Shape=[batch_size] 51 | :return Tensor(dtype=string) Shape=[batch_size, sentence_len, num_units_for_id_ftr] 52 | """ 53 | embedding_result = self.embedding({ 54 | InternalFtrType.SENTENCES: inputs, 55 | InternalFtrType.NUM_CLS: self.num_cls_sep, 56 | InternalFtrType.NUM_SEP: self.num_cls_sep, 57 | InternalFtrType.MIN_LEN: self.min_len, 58 | InternalFtrType.MAX_LEN: self.max_len, 59 | }) 60 | 61 | seq_length = embedding_result[InternalFtrType.LENGTH] 62 | max_seq_len = tf.math.reduce_max(seq_length) 63 | seq_mask = tf.expand_dims(tf.sequence_mask(seq_length, max_seq_len, dtype=tf.float32), axis=-1) 64 | seq_length = tf.expand_dims(tf.cast(seq_length, dtype=tf.dtypes.float32), axis=-1) 65 | 66 | user_id_embeddings = embedding_result[InternalFtrType.EMBEDDED] 67 | sum_user_id_embedding = tf.reduce_sum( 68 | input_tensor=user_id_embeddings * seq_mask, axis=1) # [batch_size, num_units_for_id_ftr] 69 | user_id_avg_embedding = tf.math.divide_no_nan(sum_user_id_embedding, seq_length) # [batch_size, num_units_for_id_ftr] 70 | return user_id_avg_embedding 71 | 72 | def apply_embed_on_user_id(self, user_id_fields): 73 | """Applies 
embedding lookup and averaging for user id features 74 | 75 | :return Tensor Shape=[batch_size, num_user_id_fields, num_units_for_id_ftr] 76 | """ 77 | user_ftrs = [] 78 | for i, user_field in enumerate(user_id_fields): 79 | user_id_avg_embedding = self.apply_embedding(user_field) 80 | user_ftrs.append(user_id_avg_embedding) 81 | return tf.stack(user_ftrs, axis=1) 82 | 83 | def apply_embed_on_doc_id(self, doc_id_fields): 84 | """Applies embedding lookup and averaging for doc id features 85 | 86 | :return Tensor Shape=[batch_size, max_group_size, num_doc_id_fields, num_units_for_id_ftr] 87 | """ 88 | doc_ftrs = [] 89 | for i, doc_field in enumerate(doc_id_fields): 90 | doc_field_shape = tf.shape(doc_field) 91 | reshape_doc_field = tf.reshape(doc_field, shape=[doc_field_shape[0] * doc_field_shape[1]]) 92 | doc_id_avg_embedding = self.apply_embedding(reshape_doc_field) 93 | doc_id_avg_embedding = tf.reshape(doc_id_avg_embedding, shape=[doc_field_shape[0], doc_field_shape[1], self.id_ftr_size]) 94 | doc_ftrs.append(doc_id_avg_embedding) 95 | return tf.stack(doc_ftrs, axis=2) 96 | -------------------------------------------------------------------------------- /src/detext/layers/multi_layer_perceptron.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class MultiLayerPerceptron(tf.keras.layers.Layer): 7 | """ A multi layer perceptron """ 8 | 9 | def __init__(self, num_hidden: List[int], activations: List, prefix: str = ''): 10 | """ Initializes the layer 11 | 12 | :param num_hidden: list of hidden layer sizes 13 | :param activations: list of activations for dense layer 14 | :param prefix: prefix of hidden layer name 15 | """ 16 | super(MultiLayerPerceptron, self).__init__() 17 | assert len(num_hidden) == len(activations), "num hidden and activations must contain the same number of elements" 18 | 19 | self.mlp = [] 20 | for i, (hidden_size, activation) in enumerate(zip(num_hidden, activations)): 21 | if hidden_size == 0: 22 | continue 23 | layer = tf.keras.layers.Dense(units=hidden_size, use_bias=True, activation=activation, 24 | name=f'{prefix}hidden_projection_{str(i)}') 25 | self.mlp.append(layer) 26 | 27 | def call(self, inputs, **kwargs): 28 | """ Applies multi-layer perceptron on given inputs 29 | 30 | :return output Shape=inputs.shape[:-1] + [num_hidden[-1]] 31 | """ 32 | x = inputs 33 | for layer in self.mlp: 34 | x = layer(x) 35 | return x 36 | -------------------------------------------------------------------------------- /src/detext/layers/output_transform_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.utils.parsing_utils import TaskType, OutputFtrType 4 | 5 | 6 | def _ranking_output_transform(inputs): 7 | """ 8 | Transforms the outputs for DeText ranking task. 9 | :param inputs: Tensor with shape [batch_size, list_size, 1]. 10 | :return: final output for ranking, with shape [batch_size, list_size] 11 | """ 12 | # shape: [batch_size, list_size] 13 | outputs = tf.squeeze(inputs, axis=-1) 14 | return {OutputFtrType.DETEXT_RANKING_SCORES: outputs} 15 | 16 | 17 | def _classification_output_transform(inputs): 18 | """ 19 | Transforms the outputs for DeText classification task. 20 | :param inputs: Tensor with shape [batch_size, 1, num_classes]. 
21 | :return: final output for classification, with shape [batch_size, num_classes] 22 | """ 23 | # shape: [batch_size, num_classes] 24 | inputs = tf.squeeze(inputs, axis=-2) 25 | # Return logits, softmax, and label predictions for classification 26 | return {OutputFtrType.DETEXT_CLS_PROBABILITIES: tf.nn.softmax(inputs), 27 | OutputFtrType.DETEXT_CLS_LOGITS: inputs, 28 | OutputFtrType.DETEXT_CLS_PREDICTED_LABEL: tf.argmax(inputs, axis=-1)} 29 | 30 | 31 | def _multilabel_classification_output_transform(inputs): 32 | """ 33 | Transforms the outputs for DeText multi-label classification task. 34 | :param inputs: Tensor with shape [batch_size, 1, num_classes]. 35 | :return: final output for classification, with shape [batch_size, num_classes] 36 | """ 37 | # shape: [batch_size, num_classes] 38 | inputs = tf.squeeze(inputs, axis=-2) 39 | # Return probabilities, logits, and label predictions (as bool tensor) for multi-label classification 40 | return {OutputFtrType.DETEXT_CLS_PROBABILITIES: tf.nn.sigmoid(inputs), 41 | OutputFtrType.DETEXT_CLS_LOGITS: inputs, 42 | # 0.'s Tensor with 1.'s to represent predicted label: 43 | OutputFtrType.DETEXT_CLS_PREDICTED_LABELS: tf.cast(tf.math.greater(inputs, tf.constant(0.0)), dtype=tf.float32)} 44 | 45 | 46 | def _binary_classification_output_transform(inputs): 47 | """ 48 | Transforms the outputs for DeText binary classification task. 49 | :param inputs: Tensor with shape [batch_size, 1, num_classes]. 50 | :return: final output for classification, with shape [batch_size] 51 | """ 52 | # shape: [batch_size] 53 | inputs = tf.squeeze(inputs, axis=[1, 2]) 54 | # Return logits and sigmoid predictions for classification 55 | return {OutputFtrType.DETEXT_CLS_PROBABILITIES: tf.nn.sigmoid(inputs), 56 | OutputFtrType.DETEXT_CLS_LOGITS: inputs} 57 | 58 | 59 | class OutputTransformLayer(tf.keras.layers.Layer): 60 | """ Output transform layer that prepares the final output based on task types (classification or ranking) """ 61 | 62 | def __init__(self, task_type): 63 | super().__init__() 64 | self.task_type = task_type 65 | 66 | def call(self, inputs): 67 | # Handle outputs for different task types 68 | task_type_to_output_transforms = { 69 | TaskType.RANKING: _ranking_output_transform, 70 | TaskType.CLASSIFICATION: _classification_output_transform, 71 | TaskType.BINARY_CLASSIFICATION: _binary_classification_output_transform, 72 | TaskType.MULTILABEL_CLASSIFICATION: _multilabel_classification_output_transform 73 | } 74 | return task_type_to_output_transforms[self.task_type](inputs) 75 | -------------------------------------------------------------------------------- /src/detext/layers/scoring_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import interaction_layer 4 | from detext.utils.parsing_utils import InputFtrType 5 | 6 | 7 | class ScoringLayer(tf.keras.layers.Layer): 8 | """Scoring layer that performs linear projection from interaction features to scalar scores""" 9 | 10 | def __init__(self, task_ids, num_classes): 11 | super(ScoringLayer, self).__init__() 12 | if task_ids is None: 13 | task_ids = [0] 14 | 15 | self._task_ids = task_ids 16 | self._num_classes = num_classes 17 | self.final_projections = self.create_final_projections(task_ids, num_classes) 18 | 19 | def compute_final_scores(self, ftrs_to_score, task_id): 20 | """Returns final scores given interaction outputs wrt the given task""" 21 | if len(self._task_ids) <= 1 or task_id is None: 22 | return 
self.final_projections[0](ftrs_to_score[self.get_scoring_ftrs_key(0)]) 23 | 24 | # Multitask 25 | data_shape = tf.shape(ftrs_to_score[self.get_scoring_ftrs_key(0)]) 26 | batch_size = data_shape[0] 27 | max_group_size = data_shape[1] 28 | 29 | score_shape = [batch_size, tf.maximum(max_group_size, 1), tf.maximum(1, self._num_classes)] 30 | scores = tf.zeros(shape=score_shape, dtype="float32") 31 | for i, ith_task_id in enumerate(self._task_ids): 32 | task_score = self.final_projections[i](ftrs_to_score[self.get_scoring_ftrs_key(i)]) 33 | task_mask = tf.cast(tf.equal(task_id, int(ith_task_id)), dtype=tf.float32) 34 | # Broadcast task_mask for compatible tensor shape with scores tensor 35 | task_mask = tf.transpose(a=tf.broadcast_to(task_mask, score_shape[::-1])) 36 | scores += task_mask * task_score 37 | return scores 38 | 39 | @staticmethod 40 | def create_final_projections(task_ids, num_classes): 41 | """Returns a list of final projection layers for given task_ids """ 42 | final_projections = [] 43 | # When task_ids is None, treat it as a special case of multitask learning when there's only one task 44 | if task_ids is None: 45 | task_ids = [0] 46 | # Set up layers for each task 47 | for task_id in task_ids: 48 | final_projections.append(tf.keras.layers.Dense(num_classes, name=f"task_{task_id}_final_projection")) 49 | return final_projections 50 | 51 | @staticmethod 52 | def get_scoring_ftrs_key(i): 53 | return interaction_layer.InteractionLayer.get_interaction_ftrs_key(i) 54 | 55 | def get_ftrs_to_score(self, inputs): 56 | return interaction_layer.InteractionLayer.get_interaction_ftrs(inputs, self._task_ids) 57 | 58 | def call(self, inputs, **kwargs): 59 | """ Projects features linearly to scores (scalar) 60 | 61 | :param inputs: Map { 62 | InternalFtrType.FTRS_TO_SCORE: Tensor(dtype=float32, shape=[batch_size, list_size, num_features]) 63 | InternalFtrType.TASK_ID: Tensor(dtype=int32, shape=[batch_size]) 64 | } 65 | :param kwargs: 66 | :return: scores. 
Tensor(dtype=float32, shape=[batch_size, list_size]) 67 | """ 68 | ftrs_to_score = self.get_ftrs_to_score(inputs) 69 | task_id = inputs.get(InputFtrType.TASK_ID_COLUMN_NAME, None) 70 | 71 | return self.compute_final_scores(ftrs_to_score, task_id) 72 | -------------------------------------------------------------------------------- /src/detext/layers/shallow_tower_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .sparse_embedding_layer import SparseEmbeddingLayer 4 | from detext.utils.parsing_utils import InputFtrType 5 | 6 | 7 | class ShallowTowerLayer(tf.keras.layers.Layer): 8 | """Shallow tower layer including only linear combination of features """ 9 | 10 | def __init__(self, nums_shallow_tower_sparse_ftrs, num_classes, initializer='glorot_uniform'): 11 | super(ShallowTowerLayer, self).__init__() 12 | self.sparse_linear = SparseEmbeddingLayer(num_classes, nums_shallow_tower_sparse_ftrs, initializer, 'sum', 'sum') 13 | 14 | def call(self, inputs, **kwargs): 15 | sparse_ftrs = inputs[InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES] 16 | return self.sparse_linear({InputFtrType.SPARSE_FTRS_COLUMN_NAMES: sparse_ftrs}) 17 | -------------------------------------------------------------------------------- /src/detext/layers/sparse_embedding_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.utils.parsing_utils import InputFtrType 4 | 5 | 6 | class SparseEmbeddingLayer(tf.keras.layers.Layer): 7 | """Sparse embedding layer that accepts a sparse tensor as input, looks up embeddings and combines embeddings using given combiner""" 8 | 9 | def __init__(self, sparse_embedding_size, nums_sparse_ftrs, initializer, sparse_embedding_cross_ftr_combiner, sparse_embedding_same_ftr_combiner): 10 | """ Initializes SparseEmbeddingLayer 11 | 12 | :param sparse_embedding_size: size of sparse embedding 13 | :param nums_sparse_ftrs: numbers of sparse features for each column 14 | :param initializer: initializer for embeddings 15 | :param sparse_embedding_cross_ftr_combiner: how to combine the column embeddings. E.g., sum 16 | :param sparse_embedding_same_ftr_combiner: how to combine the look up embeddings within the column. 
E.g., sum/mean 17 | """ 18 | super().__init__() 19 | self._sparse_embedding_cross_ftr_combiner = sparse_embedding_cross_ftr_combiner 20 | self._sparse_embedding_same_ftr_combiner = sparse_embedding_same_ftr_combiner 21 | self._nums_sparse_ftrs = nums_sparse_ftrs 22 | 23 | combiner2sparse_embedding_size = { 24 | 'sum': sparse_embedding_size, 25 | 'concat': sparse_embedding_size * len(nums_sparse_ftrs) 26 | } 27 | self._sparse_embedding_size = combiner2sparse_embedding_size[sparse_embedding_cross_ftr_combiner] 28 | self._combiner_fn = self._get_embedding_combiner_fn(sparse_embedding_cross_ftr_combiner) 29 | 30 | for i, num_sparse_ftrs in enumerate(nums_sparse_ftrs): 31 | setattr(self, self._get_embedding_weights_name(i), self.add_weight( 32 | name=self._get_embedding_weights_name(i), 33 | shape=[num_sparse_ftrs, sparse_embedding_size], 34 | dtype=tf.dtypes.float32, 35 | initializer=initializer, trainable=True 36 | )) 37 | 38 | def _get_embedding_combiner_fn(self, sparse_embedding_combiner): 39 | def concat_embedding(lst): 40 | return tf.concat(lst, axis=-1) 41 | 42 | def sum_embedding(lst): 43 | return tf.reduce_sum(lst, axis=0) 44 | 45 | combiner2embedding_combiner_fn = { 46 | 'sum': sum_embedding, 47 | 'concat': concat_embedding 48 | } 49 | 50 | return combiner2embedding_combiner_fn[sparse_embedding_combiner] 51 | 52 | def _get_embedding_weights_name(self, i): 53 | return f"sparse_embedding_weights_{i}" 54 | 55 | def _get_embedding_weights(self, i): 56 | return getattr(self, self._get_embedding_weights_name(i)) 57 | 58 | def call(self, inputs, **kwargs): 59 | """Looks up and combines embeddings corresponding to given sparse features 60 | 61 | :param inputs: Map containing { 62 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: List(tf.SparseFeature) 63 | }. The last dimension of the sparse feature should be <= total_num_sparse_ftrs 64 | """ 65 | sparse_ftrs = inputs[InputFtrType.SPARSE_FTRS_COLUMN_NAMES] 66 | sparse_embeddings = [] 67 | for i, sparse_ftr in enumerate(sparse_ftrs): 68 | dense_shape = tf.shape(sparse_ftr, out_type=tf.dtypes.int64) 69 | values = sparse_ftr.indices[:, -1] 70 | sparse_ids = tf.sparse.SparseTensor(indices=sparse_ftr.indices, values=values, 71 | dense_shape=dense_shape) 72 | sparse_weights = sparse_ftr 73 | sparse_embedding = tf.nn.safe_embedding_lookup_sparse(self._get_embedding_weights(i), sparse_ids=sparse_ids, 74 | sparse_weights=sparse_weights, combiner=self._sparse_embedding_same_ftr_combiner) 75 | sparse_embeddings.append(sparse_embedding) 76 | 77 | return self._combiner_fn(sparse_embeddings) 78 | -------------------------------------------------------------------------------- /src/detext/metaclass.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | 4 | class SingletonMeta(type): 5 | _instances = {} 6 | _instance_lock = threading.Lock() 7 | 8 | def __call__(cls, *args, **kargs): 9 | # move lock inside can reduce the race contention. 10 | if cls not in cls._instances: 11 | with cls._instance_lock: 12 | if cls not in cls._instances: 13 | cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kargs) 14 | return cls._instances[cls] 15 | -------------------------------------------------------------------------------- /src/detext/run_detext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Overall pipeline to train the model. It parses arguments, and trains a DeText model. 
3 | """ 4 | 5 | import tempfile 6 | from dataclasses import dataclass, asdict 7 | 8 | import sys 9 | import tensorflow as tf 10 | from absl import logging 11 | from official.utils.misc import distribution_utils 12 | from smart_arg import arg_suite 13 | 14 | from detext.args import DatasetArg, FeatureArg, NetworkArg, OptimizationArg 15 | from detext.train import train 16 | from detext.utils import parsing_utils, distributed_utils 17 | 18 | 19 | @arg_suite 20 | @dataclass 21 | class DetextArg(DatasetArg, FeatureArg, NetworkArg, OptimizationArg): 22 | """ 23 | DeText: a Deep Text understanding framework for NLP related ranking, classification, and language generation tasks. 24 | 25 | It leverages semantic matching using deep neural networks to understand member intents in search and recommender systems. 26 | As a general NLP framework, currently DeText can be applied to many tasks, including search & recommendation ranking, 27 | multi-class classification and query understanding tasks. 28 | """ 29 | 30 | def __post_init__(self): 31 | """ Post initializes fields 32 | 33 | This method is automatically called by smart-arg once the argument is created by parsing cli or the constructor 34 | """ 35 | logging.info(f"Start __post_init__ the argument now: {self}") 36 | super().__post_init__() 37 | 38 | 39 | def main(argv): 40 | """ This is the main method for training the model. 41 | 42 | :param argv: training parameters 43 | """ 44 | 45 | argument = DetextArg.__from_argv__(argv[1:], error_on_unknown=False) 46 | run_detext(argument) 47 | 48 | 49 | def run_detext(argument): 50 | """ Launches DeText training program""" 51 | logging.set_verbosity(logging.INFO) 52 | logging.info(f"Args:\n {argument}") 53 | 54 | hparams = parsing_utils.HParams(**asdict(argument)) 55 | 56 | strategy = distribution_utils.get_distribution_strategy(hparams.distribution_strategy, num_gpus=hparams.num_gpu, all_reduce_alg=hparams.all_reduce_alg) 57 | logging.info(f"***********Num replica: {strategy.num_replicas_in_sync}***********") 58 | create_output_dir(hparams.resume_training, hparams.out_dir, strategy) 59 | save_hparams(hparams.out_dir, hparams, strategy) 60 | 61 | logging.info("***********DeText Training***********") 62 | train.train(strategy, hparams) 63 | 64 | 65 | def save_hparams(out_dir, hparams, strategy): 66 | """Saves hparams to out_dir""" 67 | is_chief = distributed_utils.is_chief(strategy) 68 | if not is_chief: 69 | out_dir = tempfile.mkdtemp() 70 | 71 | parsing_utils.save_hparams(out_dir, hparams) 72 | 73 | if not is_chief: 74 | tf.io.gfile.remove(parsing_utils._get_hparam_path(out_dir)) 75 | 76 | 77 | def create_output_dir(resume_training, out_dir, strategy): 78 | """Creates output directory if not exists""" 79 | is_chief = distributed_utils.is_chief(strategy) 80 | if not is_chief: 81 | out_dir = tempfile.mkdtemp() 82 | 83 | if not resume_training: 84 | if tf.io.gfile.exists(out_dir): 85 | logging.info("Removing previous output directory...") 86 | tf.io.gfile.rmtree(out_dir) 87 | 88 | # If output directory deleted or does not exist, create the directory. 
89 | if not tf.io.gfile.exists(out_dir): 90 | logging.info('Creating dirs recursively at: {0}'.format(out_dir)) 91 | tf.io.gfile.makedirs(out_dir) 92 | 93 | if not is_chief: 94 | tf.io.gfile.rmtree(out_dir) 95 | 96 | 97 | if __name__ == '__main__': 98 | main(sys.argv) 99 | -------------------------------------------------------------------------------- /src/detext/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/detext/train/__init__.py -------------------------------------------------------------------------------- /src/detext/train/constant.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_ranking as tfr 3 | 4 | from detext.utils.parsing_utils import InputFtrType 5 | from detext.metaclass import SingletonMeta 6 | 7 | 8 | class Constant(metaclass=SingletonMeta): 9 | def __init__(self): 10 | self._LABEL_PADDING = tfr.data._PADDING_LABEL 11 | self._DEFAULT_WEIGHT_FTR_NAME = InputFtrType.WEIGHT_COLUMN_NAME 12 | self._DEFAULT_UID_FTR_NAME = InputFtrType.UID_COLUMN_NAME 13 | 14 | self._FTR_TYPE2PADDED_SHAPE = { 15 | InputFtrType.QUERY_COLUMN_NAME: tf.TensorShape([]), 16 | InputFtrType.USER_TEXT_COLUMN_NAMES: tf.TensorShape([]), 17 | InputFtrType.DOC_TEXT_COLUMN_NAMES: tf.TensorShape([None]), 18 | InputFtrType.USER_ID_COLUMN_NAMES: tf.TensorShape([]), 19 | InputFtrType.DOC_ID_COLUMN_NAMES: tf.TensorShape([None]), 20 | InputFtrType.TASK_ID_COLUMN_NAME: tf.TensorShape([]), 21 | } 22 | 23 | self._FTR_TYPE2PADDED_VALUE = { 24 | InputFtrType.QUERY_COLUMN_NAME: '', 25 | InputFtrType.USER_TEXT_COLUMN_NAMES: '', 26 | InputFtrType.DOC_TEXT_COLUMN_NAMES: '', 27 | InputFtrType.USER_ID_COLUMN_NAMES: '', 28 | InputFtrType.DOC_ID_COLUMN_NAMES: '', 29 | InputFtrType.TASK_ID_COLUMN_NAME: tf.cast(0, tf.int64), 30 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: 0.0, 31 | } 32 | 33 | self._RANKING_FTR_TYPE_TO_DENSE_DEFAULT_VAL = { 34 | InputFtrType.QUERY_COLUMN_NAME: '', 35 | InputFtrType.USER_TEXT_COLUMN_NAMES: '', 36 | InputFtrType.USER_ID_COLUMN_NAMES: '', 37 | InputFtrType.DOC_TEXT_COLUMN_NAMES: '', 38 | InputFtrType.DOC_ID_COLUMN_NAMES: '', 39 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: 0.0, 40 | InputFtrType.LABEL_COLUMN_NAME: self._LABEL_PADDING, 41 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: 0 42 | } 43 | 44 | self._CLASSIFICATION_FTR_TYPE_TO_DENSE_DEFAULT_VAL = { 45 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: 0.0, 46 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: 0 47 | } 48 | 49 | self._RANKING_FTR_TYPE_TO_SCHEMA = { 50 | InputFtrType.WEIGHT_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32), 51 | InputFtrType.UID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 52 | InputFtrType.TASK_ID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 53 | 54 | InputFtrType.LABEL_COLUMN_NAME: tf.io.VarLenFeature(dtype=tf.float32), 55 | 56 | InputFtrType.QUERY_COLUMN_NAME: tf.io.VarLenFeature(dtype=tf.string), 57 | InputFtrType.USER_TEXT_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.string), 58 | InputFtrType.USER_ID_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.string), 59 | 60 | InputFtrType.DOC_TEXT_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.string), 61 | InputFtrType.DOC_ID_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.string), 62 | 63 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.float32), 64 | } 65 | 66 | self._CLASSIFICATION_FTR_TYPE_TO_SCHEMA = { 67 | 
InputFtrType.WEIGHT_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32), 68 | InputFtrType.UID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 69 | InputFtrType.TASK_ID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 70 | 71 | InputFtrType.LABEL_COLUMN_NAME: tf.io.FixedLenFeature(shape=[], dtype=tf.float32), 72 | 73 | InputFtrType.QUERY_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), # not used 74 | InputFtrType.USER_TEXT_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 75 | InputFtrType.USER_ID_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 76 | 77 | InputFtrType.DOC_TEXT_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 78 | InputFtrType.DOC_ID_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 79 | 80 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.float32), 81 | } 82 | 83 | self._MULTILABEL_CLASSIFICATION_FTR_TYPE_TO_SCHEMA = { 84 | InputFtrType.WEIGHT_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32), 85 | InputFtrType.UID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 86 | InputFtrType.TASK_ID_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), 87 | # For MULTILABEL_CLASSIFICATION, label_column is multi-hot encoded (shape[batch_size][num_classes], defined in data_fn.py) 88 | InputFtrType.LABEL_COLUMN_NAME: tf.io.FixedLenFeature(shape=[], dtype=tf.float32), 89 | 90 | InputFtrType.QUERY_COLUMN_NAME: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), # not used 91 | InputFtrType.USER_TEXT_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 92 | InputFtrType.USER_ID_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 93 | 94 | InputFtrType.DOC_TEXT_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 95 | InputFtrType.DOC_ID_COLUMN_NAMES: tf.io.FixedLenFeature(shape=[1], dtype=tf.string), 96 | 97 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: tf.io.VarLenFeature(dtype=tf.float32), 98 | } 99 | -------------------------------------------------------------------------------- /src/detext/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/detext/utils/__init__.py -------------------------------------------------------------------------------- /src/detext/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import tensorflow as tf 5 | from absl import logging 6 | 7 | 8 | def should_export_summary(strategy): 9 | """Returns whether the summary should be exported given current strategy""" 10 | return (not strategy) or strategy.extended.should_save_summary 11 | 12 | 13 | def should_export_checkpoint(strategy): 14 | """Returns whether the checkpoint should be exported given current strategy""" 15 | return (not strategy) or strategy.extended.should_checkpoint 16 | 17 | 18 | def is_chief(strategy): 19 | """Returns whether the current node is chief node""" 20 | tf_config = os.environ.get("TF_CONFIG") 21 | # With multiworker training, tf_config contains tasks in the cluster, and each task's type in the cluster. 22 | # To get ther worker/evaluator status, need to fetch corresponding fields in the config (json format). 
23 | # Read more at https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable 24 | if tf_config: 25 | tf_config_json = json.loads(tf_config) 26 | # Logging the status of current worker/evaluator 27 | logging.info("Running with TF_CONFIG: {}".format(tf_config_json)) 28 | task = tf_config_json.get('task', {}) 29 | task_type = task.get('type', None) 30 | task_id = task.get('index', None) 31 | logging.info(f"=========== Current executor task type: {task_type}, task id: {task_id} ==========") 32 | return _is_chief(task_type, task_id, strategy) 33 | else: 34 | logging.info("=========== No TF_CONFIG found. Running local mode. ==========") 35 | return True 36 | 37 | 38 | def _is_chief(task_type, task_id, strategy): 39 | # If `task_type` is None, this may be operating as single worker, which works 40 | # effectively as chief. 41 | if task_type is None: 42 | return True 43 | if isinstance(strategy, tf.distribute.experimental.MultiWorkerMirroredStrategy): 44 | return task_type == 'worker' and task_id == 0 45 | return task_type == 'chief' 46 | -------------------------------------------------------------------------------- /src/detext/utils/layer_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import tensorflow as tf 4 | from absl import logging 5 | 6 | 7 | def init_word_embedding(vocab_size, num_units, we_trainable, we_file=None, name_prefix="w"): 8 | """Initialize word embeddings from random initialization or pretrained word embedding. 9 | 10 | This function is only used by encoding models other than BERT 11 | """ 12 | 13 | if not we_file: 14 | embedding_name = "{}_pretrained_embedding".format(name_prefix) 15 | # Random initialization 16 | embedding = tf.compat.v1.get_variable( 17 | embedding_name, [vocab_size, num_units], dtype=tf.float32, trainable=we_trainable) 18 | logging.info(f'Initializing embedding {embedding_name}') 19 | else: 20 | # Initialize by pretrained word embedding 21 | embedding_name = "{}_embedding".format(name_prefix) 22 | we = pickle.load(tf.io.gfile.GFile(we_file, 'rb')) 23 | assert vocab_size == we.shape[0] and num_units == we.shape[1] 24 | embedding = tf.compat.v1.get_variable(name=embedding_name, 25 | shape=[vocab_size, num_units], 26 | dtype=tf.float32, 27 | initializer=tf.compat.v1.constant_initializer(we), 28 | trainable=we_trainable) 29 | logging.info(f'Loading pretrained embedding {embedding_name} from {we_file}') 30 | return embedding 31 | 32 | 33 | def get_sorted_dict(dct: dict): 34 | """Returns dictionary in sorted order""" 35 | return dict(sorted(dct.items())) 36 | -------------------------------------------------------------------------------- /src/detext/utils/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/detext/utils/testing/__init__.py -------------------------------------------------------------------------------- /src/detext/utils/testing/data_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from detext.utils.parsing_utils import HParams 6 | from detext.utils.vocab_utils import read_vocab 7 | 8 | 9 | class DataSetup: 10 | """Class containing common setup on file paths, layer params used in unit tests""" 11 | resource_dir = os.path.join(os.getcwd(), 'test', 'detext', 'resources') 12 | we_file = 
os.path.join(resource_dir, 'we.pkl') 13 | vocab_file = os.path.join(resource_dir, 'vocab.txt') 14 | vocab_file_for_id_ftr = vocab_file 15 | 16 | vocab_layer_dir = os.path.join(resource_dir, 'vocab_layer') 17 | embedding_layer_dir = os.path.join(resource_dir, 'embedding_layer') 18 | 19 | bert_hub_url = os.path.join(resource_dir, 'bert-hub') 20 | libert_sp_hub_url = os.path.join(resource_dir, 'libert-sp-hub') 21 | libert_space_hub_url = os.path.join(resource_dir, 'libert-space-hub') 22 | vocab_hub_url = os.path.join(resource_dir, 'vocab_layer_hub') 23 | embedding_hub_url = os.path.join(resource_dir, 'embedding_layer_hub') 24 | 25 | out_dir = os.path.join(resource_dir, "output") 26 | data_dir = os.path.join(resource_dir, "train", "dataset", "tfrecord") 27 | multitask_data_dir = os.path.join(resource_dir, "train", "multitask", "tfrecord") 28 | cls_data_dir = os.path.join(resource_dir, "train", "classification", "tfrecord") 29 | binary_cls_data_dir = os.path.join(resource_dir, "train", "binary_classification", "tfrecord") 30 | multilabel_cls_data_dir = os.path.join(resource_dir, "train", "multilabel_classification", "tfrecord") 31 | ranking_data_dir = os.path.join(resource_dir, "train", "ranking", "tfrecord") 32 | 33 | vocab_table_py = read_vocab(vocab_file) 34 | vocab_size = len(vocab_table_py) 35 | 36 | CLS = '[CLS]' 37 | PAD = '[PAD]' 38 | SEP = '[SEP]' 39 | UNK = '[UNK]' 40 | CLS_ID = vocab_table_py[CLS] 41 | PAD_ID = vocab_table_py[PAD] 42 | SEP_ID = vocab_table_py[SEP] 43 | UNK_ID = vocab_table_py[UNK] 44 | 45 | PAD_FOR_ID_FTR = PAD 46 | UNK_FOR_ID_FTR = UNK 47 | 48 | query = tf.constant(['batch1', 49 | 'batch 2 query build'], dtype=tf.dtypes.string) 50 | query_length = [1, 4] 51 | 52 | user_id_field1 = query 53 | user_id_field2 = query 54 | 55 | cls_doc_field1 = ['same content build', 'batch 2 field 1 word'] 56 | cls_doc_field2 = ['same content build', 'batch 2 field 2 word'] 57 | 58 | ranking_doc_field1 = [['same content build', 59 | 'batch 1 doc 2 field able', 60 | 'batch 1 doc 3 field 1'], 61 | ['batch 2 doc 1 field word', 62 | 'batch 2 doc 2 field 1', 63 | 'batch 2 doc 3 field test']] 64 | ranking_doc_id_field1 = ranking_doc_field1 65 | 66 | ranking_doc_field2 = [['same content build', 67 | 'batch 1 doc 2 field test', 68 | 'batch 1 doc 3 field 2'], 69 | ['batch 2 doc 1 field test', 70 | 'batch 2 doc 2 field 2', 71 | 'batch 2 doc 3 field word']] 72 | ranking_doc_id_field2 = ranking_doc_field2 73 | 74 | cls_sparse_features_1 = [[1.0, 2.0, 4.0], [0.0, -1.0, 4.0]] 75 | cls_sparse_features_2 = [[2.0, 0.0, 4.0], [2.0, 2.0, 4.0]] 76 | cls_sparse_features = [tf.sparse.from_dense(tf.constant(cls_sparse_features_1)), 77 | tf.sparse.from_dense(tf.constant(cls_sparse_features_2))] 78 | 79 | ranking_sparse_features_1 = [[[1.0, 2.0, 4.0], 80 | [1.0, 2.0, 4.0], 81 | [1.0, 2.0, 4.0]], 82 | [[0.0, -1.0, 4.0], 83 | [0.0, -1.0, 4.0], 84 | [0.0, -1.0, 4.0]]] 85 | ranking_sparse_features_2 = [[[1.0, 2.0, 4.0], 86 | [1.0, 2.0, 4.0], 87 | [1.0, 2.0, 4.0]], 88 | [[0.0, -1.0, 4.0], 89 | [0.0, -1.0, 4.0], 90 | [0.0, -1.0, 4.0]]] 91 | ranking_sparse_features = [tf.sparse.from_dense(tf.constant(ranking_sparse_features_1)), 92 | tf.sparse.from_dense(tf.constant(ranking_sparse_features_2))] 93 | nums_sparse_ftrs = [3] 94 | total_num_sparse_ftrs = sum(nums_sparse_ftrs) 95 | sparse_embedding_size = 33 96 | 97 | num_user_fields = 2 98 | user_fields = [tf.constant(query, dtype=tf.dtypes.string), tf.constant(query, dtype=tf.dtypes.string)] 99 | 100 | num_doc_fields = 2 101 | ranking_doc_fields = 
[tf.constant(ranking_doc_field1, dtype=tf.dtypes.string), tf.constant(ranking_doc_field2, dtype=tf.dtypes.string)] 102 | cls_doc_fields = [tf.constant(cls_doc_field1, dtype=tf.dtypes.string), tf.constant(cls_doc_field2, dtype=tf.dtypes.string)] 103 | 104 | num_user_id_fields = 2 105 | user_id_fields = user_fields 106 | 107 | num_doc_id_fields = 2 108 | ranking_doc_id_fields = ranking_doc_fields 109 | cls_doc_id_fields = cls_doc_fields 110 | 111 | num_id_fields = num_user_id_fields + num_doc_id_fields 112 | 113 | num_units = 6 114 | num_units_for_id_ftr = num_units 115 | 116 | vocab_layer_param = {'CLS': CLS, 117 | 'SEP': SEP, 118 | 'PAD': PAD, 119 | 'UNK': UNK, 120 | 'vocab_file': vocab_file} 121 | 122 | embedding_layer_param = {'vocab_layer_param': vocab_layer_param, 123 | 'vocab_hub_url': '', 124 | 'we_file': '', 125 | 'we_trainable': True, 126 | 'num_units': num_units} 127 | 128 | min_len = 3 129 | max_len = 7 130 | filter_window_sizes = [1, 2, 3] 131 | num_filters = 5 132 | 133 | cnn_param = HParams( 134 | filter_window_sizes=filter_window_sizes, 135 | num_filters=num_filters, num_doc_fields=num_doc_fields, num_user_fields=num_user_fields, 136 | min_len=min_len, max_len=max_len, 137 | embedding_layer_param=embedding_layer_param, embedding_hub_url=None) 138 | 139 | id_encoder_param = HParams(num_id_fields=num_id_fields, 140 | embedding_layer_param=embedding_layer_param, embedding_hub_url_for_id_ftr=None) 141 | rep_layer_param = HParams(ftr_ext='cnn', 142 | num_doc_fields=num_doc_fields, num_user_fields=num_user_fields, 143 | num_doc_id_fields=num_doc_id_fields, num_user_id_fields=num_user_id_fields, 144 | add_doc_projection=False, add_user_projection=False, 145 | text_encoder_param=cnn_param, id_encoder_param=id_encoder_param) 146 | -------------------------------------------------------------------------------- /src/detext/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility function for vocabulary 3 | """ 4 | import gzip 5 | import os 6 | import re 7 | import string 8 | 9 | import six 10 | import tensorflow as tf 11 | from tensorflow.python.ops import lookup_ops 12 | 13 | 14 | def convert_to_unicode(text): 15 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 16 | if six.PY3: 17 | if isinstance(text, str): 18 | return text 19 | elif isinstance(text, bytes): 20 | return text.decode("utf-8", "ignore") 21 | else: 22 | raise ValueError("Unsupported string type: %s" % (type(text))) 23 | elif six.PY2: 24 | if isinstance(text, str): 25 | return text.decode("utf-8", "ignore") 26 | # In py2, six.text_type = unicode. 
We use six.text_type to pass the pyflake checking 27 | elif isinstance(text, six.text_type): 28 | return text 29 | else: 30 | raise ValueError("Unsupported string type: %s" % (type(text))) 31 | else: 32 | raise ValueError("Not running on Python2 or Python 3?") 33 | 34 | 35 | def strip(s): 36 | """Strips ascii whitespace characters off string s.""" 37 | return s.strip(string.whitespace) 38 | 39 | 40 | def split(s): 41 | """Split string s by whitespace characters.""" 42 | whitespace_lst = [re.escape(ws) for ws in string.whitespace] 43 | pattern = re.compile('|'.join(whitespace_lst)) 44 | return pattern.split(s) 45 | 46 | 47 | def read_vocab(input_file): 48 | """Read vocabulary file and return a dict""" 49 | if input_file is None: 50 | return None 51 | 52 | vocab = {} 53 | if input_file.endswith('.gz'): 54 | f = tf.io.gfile.GFile(input_file, 'r') 55 | fin = gzip.GzipFile(fileobj=f) 56 | else: 57 | fin = tf.io.gfile.GFile(input_file, 'r') 58 | for line in fin: 59 | word = split(strip(line))[0] 60 | vocab[word] = len(vocab) 61 | fin.close() 62 | return vocab 63 | 64 | 65 | def read_tf_vocab(input_file, UNK): 66 | """Read vocabulary and return a tf hashtable""" 67 | if input_file is None: 68 | return None 69 | 70 | keys, values = [], [] 71 | if input_file.endswith('.gz'): 72 | f = tf.io.gfile.GFile(input_file, 'r') 73 | fin = gzip.GzipFile(fileobj=f) 74 | else: 75 | fin = tf.io.gfile.GFile(input_file, 'r') 76 | for line in fin: 77 | word = split(strip(line))[0] 78 | keys.append(word) 79 | values.append(len(values)) 80 | fin.close() 81 | UNK_ID = keys.index(UNK) 82 | 83 | initializer = lookup_ops.KeyValueTensorInitializer(tf.constant(keys), tf.constant(values)) 84 | vocab_table = lookup_ops.HashTable(initializer, UNK_ID) 85 | return initializer, vocab_table 86 | 87 | 88 | def extract_text_data(input_dir, output_file, text_fields): 89 | """ 90 | Extract text data from tfrecords. The data will be used for word embedding pretraining. 
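For example (hypothetical paths and field names): extract_text_data('/data/train_tfrecords', '/tmp/text_corpus.txt', ['query', 'doc_title']) writes every multi-word value of the 'query' and 'doc_title' features to /tmp/text_corpus.txt, one text per line; single-token values are skipped.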
91 | """ 92 | with tf.io.gfile.GFile(output_file, 'w') as fout: 93 | for file in tf.io.gfile.listdir(input_dir): 94 | input_file = os.path.join(input_dir, file) 95 | print(input_file) 96 | for example in tf.compat.v1.python_io.tf_record_iterator(input_file): 97 | result = tf.train.Example.FromString(example) 98 | for field in text_fields: 99 | text_values = result.features.feature[field].bytes_list.value 100 | for text in text_values: 101 | text = convert_to_unicode(text) 102 | text = text.strip() 103 | if ' ' not in text: 104 | continue 105 | fout.write(text + '\n') 106 | -------------------------------------------------------------------------------- /src/libert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/libert/__init__.py -------------------------------------------------------------------------------- /src/smart_compose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/smart_compose/__init__.py -------------------------------------------------------------------------------- /src/smart_compose/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/smart_compose/layers/__init__.py -------------------------------------------------------------------------------- /src/smart_compose/run_smart_compose.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tempfile 3 | from dataclasses import asdict 4 | 5 | import tensorflow as tf 6 | from absl import logging 7 | from official.utils.misc import distribution_utils 8 | 9 | from smart_compose.args import SmartComposeArg 10 | from smart_compose.train import train 11 | from smart_compose.utils import distributed_utils, parsing_utils 12 | 13 | 14 | def main(argv): 15 | """ This is the main method for training the model. 
16 | 17 | :param argv: training parameters 18 | """ 19 | 20 | argument = SmartComposeArg.__from_argv__(argv[1:], error_on_unknown=False) 21 | logging.set_verbosity(logging.INFO) 22 | logging.info(f"Args:\n {argument}") 23 | 24 | hparams = argument 25 | 26 | strategy = distribution_utils.get_distribution_strategy(hparams.distribution_strategy, num_gpus=hparams.num_gpu, all_reduce_alg=hparams.all_reduce_alg) 27 | logging.info(f"***********Num replica: {strategy.num_replicas_in_sync}***********") 28 | create_output_dir(hparams.resume_training, hparams.out_dir, strategy) 29 | 30 | save_hparams(hparams.out_dir, parsing_utils.HParams(**asdict(argument)), strategy) 31 | 32 | logging.info("***********Smart Compose Training***********") 33 | return train.train(strategy, hparams) 34 | 35 | 36 | def save_hparams(out_dir, hparams, strategy): 37 | """Saves hparams to out_dir""" 38 | is_chief = distributed_utils.is_chief(strategy) 39 | if not is_chief: 40 | out_dir = tempfile.mkdtemp() 41 | 42 | parsing_utils.save_hparams(out_dir, hparams) 43 | 44 | if not is_chief: 45 | tf.io.gfile.remove(parsing_utils._get_hparam_path(out_dir)) 46 | 47 | 48 | def create_output_dir(resume_training, out_dir, strategy): 49 | """Creates output directory if not exists""" 50 | is_chief = distributed_utils.is_chief(strategy) 51 | if not is_chief: 52 | out_dir = tempfile.mkdtemp() 53 | 54 | if not resume_training: 55 | if tf.io.gfile.exists(out_dir): 56 | logging.info("Removing previous output directory...") 57 | tf.io.gfile.rmtree(out_dir) 58 | 59 | # If output directory deleted or does not exist, create the directory. 60 | if not tf.io.gfile.exists(out_dir): 61 | logging.info('Creating dirs recursively at: {0}'.format(out_dir)) 62 | tf.io.gfile.makedirs(out_dir) 63 | 64 | if not is_chief: 65 | tf.io.gfile.rmtree(out_dir) 66 | 67 | 68 | if __name__ == '__main__': 69 | main(sys.argv) 70 | -------------------------------------------------------------------------------- /src/smart_compose/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/smart_compose/train/__init__.py -------------------------------------------------------------------------------- /src/smart_compose/train/data_fn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from functools import partial 3 | 4 | from smart_compose.utils.parsing_utils import get_input_files, InputFtrType, iterate_items_with_list_val 5 | 6 | 7 | def _read_specified_features(inputs, feature_type2name): 8 | """Only reads in features specified in the DeText arguments""" 9 | required_inputs = {} 10 | for _, ftr_name_list in iterate_items_with_list_val(feature_type2name): 11 | for ftr_name in ftr_name_list: 12 | required_inputs[ftr_name] = inputs[ftr_name] 13 | return required_inputs 14 | 15 | 16 | _FTR_TYPE_TO_SCHEMA = { 17 | InputFtrType.TARGET_COLUMN_NAME: tf.io.FixedLenFeature(shape=[], dtype=tf.string) 18 | } 19 | 20 | 21 | def _get_tfrecord_feature_parsing_schema(feature_type_2_name: dict): 22 | """Returns parsing schema for input TFRecord 23 | 24 | :param feature_type_2_name: Features mapping from feature types to feature names 25 | """ 26 | ftr_name_2_schema = dict() 27 | for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type_2_name): 28 | for ftr_name in ftr_name_lst: 29 | ftr_name_2_schema[ftr_name] = _FTR_TYPE_TO_SCHEMA[ftr_type] 30 | 31 | return 
ftr_name_2_schema 32 | 33 | 34 | def _cast_features_to_smaller_dtype(example, feature_type_2_names: dict): 35 | """Casts tensor to smaller storage dtype. int64 -> int32, float64 -> float32""" 36 | 37 | def _cast_to_dtype_of_smaller_size(t): 38 | if t.dtype == tf.int64: 39 | return tf.cast(t, dtype=tf.int32) 40 | elif t.dtype == tf.float64: 41 | return tf.cast(t, dtype=tf.float32) 42 | else: 43 | return t 44 | 45 | for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type_2_names): 46 | for ftr_name in ftr_name_lst: 47 | example[ftr_name] = _cast_to_dtype_of_smaller_size(example[ftr_name]) 48 | return example 49 | 50 | 51 | _FTR_TYPE_TO_DENSE_DEFAULT_VAL = { 52 | InputFtrType.TARGET_COLUMN_NAME: '', 53 | } 54 | 55 | 56 | def input_fn_tfrecord(input_pattern, 57 | batch_size, 58 | mode, 59 | feature_type_2_name: dict, 60 | block_length=100, 61 | prefetch_size=tf.data.experimental.AUTOTUNE, 62 | num_parallel_calls=tf.data.experimental.AUTOTUNE, 63 | input_pipeline_context=None): 64 | """ 65 | Data input function for training given TFRecord 66 | """ 67 | output_buffer_size = 1000 68 | 69 | input_files = get_input_files(input_pattern) 70 | feature_type_2_name = feature_type_2_name.copy() 71 | if len(input_files) > 1: # Multiple input files 72 | # Preprocess files concurrently, and interleave blocks of block_length records from each file 73 | dataset = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 74 | # Shard input when using distributed training strategy 75 | if mode == tf.estimator.ModeKeys.TRAIN and input_pipeline_context and input_pipeline_context.num_input_pipelines > 1: 76 | dataset = dataset.shard(input_pipeline_context.num_input_pipelines, 77 | input_pipeline_context.input_pipeline_id) 78 | 79 | dataset = dataset.shuffle(buffer_size=len(input_files)) 80 | 81 | dataset = dataset.interleave(tf.data.TFRecordDataset, block_length=block_length, 82 | num_parallel_calls=num_parallel_calls) 83 | else: 84 | dataset = tf.data.TFRecordDataset(input_files[0]) 85 | 86 | # Parse and preprocess data 87 | dataset = tfrecord_transform_fn(dataset, 88 | batch_size, 89 | mode, 90 | feature_type_2_name, 91 | output_buffer_size, 92 | prefetch_size) 93 | return dataset 94 | 95 | 96 | def _split_features_and_labels(example, feature_type_2_name: dict): 97 | """Split inputs into two parts: features and label""" 98 | target_ftr_name = feature_type_2_name[InputFtrType.TARGET_COLUMN_NAME] 99 | labels = { 100 | target_ftr_name: example.pop(target_ftr_name) 101 | } 102 | 103 | return example, labels 104 | 105 | 106 | def tfrecord_transform_fn(dataset, 107 | batch_size, 108 | mode, 109 | feature_type_2_name, 110 | output_buffer_size, 111 | prefetch_size=tf.data.experimental.AUTOTUNE, 112 | num_parallel_calls=tf.data.experimental.AUTOTUNE): 113 | """ Preprocesses datasets including 114 | 1. dataset shuffling 115 | 2. record parsing 116 | 3. 
padding and batching 117 | """ 118 | if mode == tf.estimator.ModeKeys.TRAIN: 119 | dataset = dataset.shuffle(output_buffer_size) 120 | dataset = dataset.repeat() 121 | 122 | def _process_data(record, features_schema): 123 | example = tf.io.parse_single_example(serialized=record, features=features_schema) 124 | example = _cast_features_to_smaller_dtype(example, feature_type_2_name) 125 | features, labels = _split_features_and_labels(example, feature_type_2_name) 126 | return features, labels 127 | 128 | features_schema = _get_tfrecord_feature_parsing_schema(feature_type_2_name) 129 | dataset = dataset.map(partial(_process_data, features_schema=features_schema), 130 | num_parallel_calls=num_parallel_calls) 131 | 132 | dataset = (dataset 133 | .batch(batch_size, drop_remainder=True) 134 | .prefetch(prefetch_size)) 135 | return dataset 136 | -------------------------------------------------------------------------------- /src/smart_compose/train/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def compute_regularization_penalty(l1, l2, trainable_vars): 5 | """ Returns the regularization penalty specified in hparams """ 6 | l1 = l1 if l1 is not None else 0 7 | l2 = l2 if l2 is not None else 0 8 | regularizer = tf.keras.regularizers.L1L2(l1=l1, l2=l2) 9 | 10 | penalty = 0.0 11 | for weight in trainable_vars: 12 | penalty += regularizer(weight) 13 | return penalty 14 | 15 | 16 | def compute_text_generation_loss(logits, labels, lengths): 17 | """ Returns categorical crossentropy for given sequence 18 | 19 | :param logits: Shape=[batch_size, max_sentence_length, vocab_size] 20 | :param labels: Shape=[batch_size, max_sentence_length] 21 | :param lengths: Shape=[batch_size] 22 | """ 23 | loss_val = tf.keras.losses.sparse_categorical_crossentropy(y_true=labels, y_pred=logits, from_logits=True) # [batch_size, max_sentence_length] 24 | 25 | mask = tf.sequence_mask(lengths, maxlen=tf.shape(labels)[1], dtype=tf.dtypes.float32) # [batch_size, max_sentence_length] 26 | return tf.reduce_mean(loss_val * mask) 27 | 28 | 29 | def compute_loss(l1, l2, logits, labels, lengths, trainable_vars): 30 | """ Computes loss with regularization """ 31 | return compute_text_generation_loss(logits, labels, lengths) + compute_regularization_penalty(l1, l2, trainable_vars) 32 | -------------------------------------------------------------------------------- /src/smart_compose/train/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | 5 | from .losses import compute_text_generation_loss 6 | 7 | 8 | class NegativePerplexity(tf.keras.metrics.Metric): 9 | """ Metric for computing perplexity""" 10 | 11 | def __init__(self, **kwargs): 12 | super(NegativePerplexity, self).__init__(**kwargs) 13 | self.cross_entropy_loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 14 | self.cross_entropy_loss_val = self.add_weight(name='total_perplexity', initializer='zeros') 15 | self.num_samples = self.add_weight(name='sample_count', initializer='zeros') 16 | 17 | def reset_states(self): 18 | for s in self.variables: 19 | s.assign(tf.zeros(shape=s.shape)) 20 | 21 | def result(self): 22 | return -tf.math.exp(self.cross_entropy_loss_val) 23 | 24 | def update_state(self, labels, logits, lengths, sample_weight=None): 25 | """ Accumulates metric statistics 26 | 27 | :param logits: Tensor Predicted scores. 
Shape=[batch_size, max_sentence_length, vocab_size] 28 | :param labels: Tensor Labels. Shape=[batch_size, max_sentence_length] 29 | :param sample_weight: Sample weight. Check the inherited class method for more detail 30 | """ 31 | loss = compute_text_generation_loss(logits=logits, labels=labels, lengths=lengths) 32 | batch_size = tf.cast(tf.shape(labels)[0], dtype=tf.dtypes.float32) 33 | 34 | self.cross_entropy_loss_val.assign( 35 | self.cross_entropy_loss_val * (self.num_samples / (self.num_samples + batch_size)) + 36 | loss * batch_size / (self.num_samples + batch_size)) 37 | self.num_samples.assign_add(batch_size) 38 | 39 | 40 | def get_metric_fn(metric_name): 41 | """ Returns the corresponding metric_fn according to metric name""" 42 | metrics = {'perplexity': lambda: NegativePerplexity(name=metric_name)} 43 | 44 | # Metric not found in ranking metric. Switch to classification metric matching 45 | for clf_metric_name, metric_fn in metrics.items(): 46 | if metric_name == clf_metric_name: 47 | return metric_fn 48 | 49 | raise ValueError(f'Unsupported metric name: {metric_name}') 50 | 51 | 52 | def get_metric_fn_lst(all_metrics: List[str]): 53 | """ Returns a list of metric_fn from the given metrics 54 | 55 | :param all_metrics A list of metrics supported by Smart Compose 56 | """ 57 | metric_fn_lst = [] 58 | 59 | for metric_name in all_metrics: 60 | metric_fn_lst.append(get_metric_fn(metric_name)) 61 | 62 | return metric_fn_lst 63 | -------------------------------------------------------------------------------- /src/smart_compose/train/train_model_helper.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | from smart_compose.args import SmartComposeArg 6 | from smart_compose.train.data_fn import input_fn_tfrecord 7 | from smart_compose.train.losses import compute_loss 8 | from smart_compose.train.model import create_smart_compose_model 9 | from smart_compose.train.optimization import create_optimizer 10 | from smart_compose.utils.parsing_utils import HParams 11 | 12 | LR_BERT = 0.001 13 | 14 | 15 | def get_model_input(inputs, feature_type_2_name: dict): 16 | """ Returns Smart Compose model inputs in the format of OrderedDict 17 | 18 | Unrelated features are filtered out 19 | """ 20 | return OrderedDict(sorted([(ftr_name, inputs[ftr_name]) for ftr_type, ftr_name in feature_type_2_name.items()])) 21 | 22 | 23 | def get_input_fn_common(pattern, batch_size, mode, hparams: SmartComposeArg): 24 | """ Returns the common input function used in Smart Compose training and evaluation""" 25 | return _get_input_fn_common(pattern, batch_size, mode, 26 | **_get_func_param_from_hparams(_get_input_fn_common, hparams, ('pattern', 'batch_size', 'mode'))) 27 | 28 | 29 | def _get_input_fn_common(pattern, batch_size, mode, feature_type_2_name): 30 | """ Returns the common input function used in Smart Compose training and evaluation""" 31 | return lambda ctx: input_fn_tfrecord( 32 | input_pattern=pattern, batch_size=batch_size, mode=mode, 33 | feature_type_2_name=feature_type_2_name, 34 | input_pipeline_context=ctx 35 | ) 36 | 37 | 38 | def _get_vocab_layer_param(hparams: SmartComposeArg): 39 | """ Extracts required parameters for VocabLayer by function signature from hparams """ 40 | from smart_compose.layers.vocab_layer import VocabLayerFromPath 41 | param_dct = _get_func_param_from_hparams(VocabLayerFromPath.__init__, hparams) 42 | return HParams(**param_dct) 43 | 44 | 45 | 
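# Illustrative note: the lambda returned by _get_input_fn_common above takes a tf.distribute.InputContext,
# which is the signature tf.distribute expects when building per-replica datasets. A minimal, assumed usage
# sketch (not necessarily how train.py consumes it):
#
#     input_fn = _get_input_fn_common(pattern, batch_size, mode, feature_type_2_name)
#     dist_dataset = strategy.distribute_datasets_from_function(input_fn)  # experimental_* variant on older TF 2.x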
def _get_embedding_layer_param(hparams: SmartComposeArg): 46 | """ Extracts required parameters for EmbeddingLayer by function signature from hparams """ 47 | from smart_compose.layers.embedding_layer import EmbeddingLayer 48 | param_dct = _get_func_param_from_hparams(EmbeddingLayer.__init__, hparams, exclude_lst=('self', 'vocab_layer_param', 'name_prefix')) 49 | return HParams(**param_dct, vocab_layer_param=_get_vocab_layer_param(hparams)) 50 | 51 | 52 | def _get_func_param_from_hparams(func, hparams: HParams, exclude_lst=('self', 'args', 'kwargs')) -> dict: 53 | """ Extracts required parameters by the function signature from hparams. Used by Smart Compose dev only 54 | 55 | This function saves the trouble in specifying the long list of params required by Smart Compose layers, models, and functions. Note that names of the given 56 | function should be attributes of hparams (i.e., if there's a param in func named "p1", hparams.p1 must exist) 57 | :param func Target function 58 | :param hparams Parameter holder 59 | :param exclude_lst List of parameters to exclude. There are two cases to use customized setting for this parameter -- 60 | 1. parameters that are not directly accessible from hparams. Such as 61 | 1) rep_layer_param in DeepMatch 62 | 2) text_encoder_param and id_encoder_param in RepLayer 63 | 2. parameters to be exposed when creating a partial function. Such as 64 | 1) get_loss_fn. For this function, we want to emphasize the parameters taken by the partial functions, which are 65 | 'scores', 'labels', 'weight', 'trainable_vars', as specified in exclude_lst 66 | """ 67 | param_dct = dict() 68 | param_lst = inspect.signature(func) 69 | for param in param_lst.parameters: 70 | if param in exclude_lst: 71 | continue 72 | param_dct[param] = getattr(hparams, param) 73 | 74 | return param_dct 75 | 76 | 77 | def get_model_fn(hparams: SmartComposeArg): 78 | """Returns a lambda function that creates a Smart Compose model from hparams""" 79 | return lambda: create_smart_compose_model(_get_embedding_layer_param(hparams), 80 | **_get_func_param_from_hparams(create_smart_compose_model, 81 | hparams, 82 | exclude_lst=('self', 'embedding_layer_param'))) 83 | 84 | 85 | def get_optimizer_fn(hparams): 86 | """ Returns function that creates an optimizer for non bert parameters""" 87 | return lambda: create_optimizer(hparams.learning_rate, hparams.num_train_steps, hparams.num_warmup_steps, hparams.optimizer, 88 | hparams.use_bias_correction_for_adamw) 89 | 90 | 91 | def get_bert_optimizer_fn(hparams): 92 | """ Returns function that creates an optimizer for bert parameters """ 93 | return lambda: create_optimizer(LR_BERT, hparams.num_train_steps, hparams.num_warmup_steps, hparams.optimizer, 94 | hparams.use_bias_correction_for_adamw) 95 | 96 | 97 | def get_loss_fn(hparams: SmartComposeArg): 98 | """ Returns a partial function that returns the loss function""" 99 | loss_fn = partial(compute_loss, l1=hparams.l1, l2=hparams.l2) 100 | return loss_fn 101 | 102 | 103 | def load_model_with_ckpt(hparams, path): 104 | """Returns loaded model with given checkpoint path""" 105 | model = get_model_fn(hparams)() 106 | model.load_weights(path) 107 | return model 108 | -------------------------------------------------------------------------------- /src/smart_compose/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/smart_compose/utils/__init__.py 
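The signature-driven parameter extraction in _get_func_param_from_hparams above is easiest to see with a toy case. The sketch below is illustrative only: ToyLayer and the hparams values are made up, while HParams and _get_func_param_from_hparams are the objects defined in this package.

from smart_compose.train.train_model_helper import _get_func_param_from_hparams
from smart_compose.utils.parsing_utils import HParams


class ToyLayer:
    def __init__(self, num_units, dropout=0.0):
        self.num_units = num_units
        self.dropout = dropout


hparams = HParams(num_units=64, dropout=0.1, learning_rate=1e-3)  # extra fields such as learning_rate are simply ignored
params = _get_func_param_from_hparams(ToyLayer.__init__, hparams)  # inspects the signature and skips 'self'
assert params == {'num_units': 64, 'dropout': 0.1}
layer = ToyLayer(**params)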
-------------------------------------------------------------------------------- /src/smart_compose/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """ Distributed training utilities """ 2 | import json 3 | import os 4 | 5 | import tensorflow as tf 6 | from absl import logging 7 | 8 | 9 | def should_export_summary(strategy): 10 | """Returns whether the summary should be exported given current strategy""" 11 | return (not strategy) or strategy.extended.should_save_summary 12 | 13 | 14 | def should_export_checkpoint(strategy): 15 | """Returns whether the checkpoint should be exported given current strategy""" 16 | return (not strategy) or strategy.extended.should_checkpoint 17 | 18 | 19 | def is_chief(strategy): 20 | """Returns whether the current node is the chief node""" 21 | tf_config = os.environ.get("TF_CONFIG") 22 | # With multiworker training, tf_config contains the tasks in the cluster, and each task's type in the cluster. 23 | # To get the worker/evaluator status, we need to fetch the corresponding fields from the config (JSON format). 24 | # Read more at https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable 25 | if tf_config: 26 | tf_config_json = json.loads(tf_config) 27 | # Logging the status of current worker/evaluator 28 | logging.info("Running with TF_CONFIG: {}".format(tf_config_json)) 29 | task = tf_config_json.get('task', {}) 30 | task_type = task.get('type', None) 31 | task_id = task.get('index', None) 32 | logging.info(f"=========== Current executor task type: {task_type}, task id: {task_id} ==========") 33 | return _is_chief(task_type, task_id, strategy) 34 | else: 35 | logging.info("=========== No TF_CONFIG found. Running local mode. ==========") 36 | return True 37 | 38 | 39 | def _is_chief(task_type, task_id, strategy): 40 | # If `task_type` is None, this may be operating as single worker, which works 41 | # effectively as chief.
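# For example (illustrative TF_CONFIG, not taken from this repo): with
#   {"cluster": {"worker": ["host0:2222", "host1:2222"]}, "task": {"type": "worker", "index": 0}}
# MultiWorkerMirroredStrategy has no dedicated 'chief' task, so worker 0 acts as chief and this
# function returns True; for other strategies only a task whose type is 'chief' qualifies.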
42 | if task_type is None: 43 | return True 44 | if isinstance(strategy, tf.distribute.experimental.MultiWorkerMirroredStrategy): 45 | return task_type == 'worker' and task_id == 0 46 | return task_type == 'chief' 47 | -------------------------------------------------------------------------------- /src/smart_compose/utils/parsing_utils.py: -------------------------------------------------------------------------------- 1 | """ Parsing and IO utilities""" 2 | import codecs 3 | import json 4 | import os 5 | from collections import Mapping 6 | 7 | import six 8 | import tensorflow as tf 9 | from absl import logging 10 | 11 | _MAX_FILE = 50 12 | _RANDOM_SEED = 4321 13 | 14 | _HPARAM_FILE = 'hparams' 15 | 16 | 17 | class InputFtrType: 18 | """ Input feature types 19 | 20 | These are the feature types directly read from user inputs 21 | """ 22 | TARGET_COLUMN_NAME = 'target_column_name' 23 | 24 | 25 | class OutputFtrType: 26 | """Output feature types 27 | 28 | These are the feature types that will be provided in the model output 29 | """ 30 | PREDICTED_SCORES = 'predicted_scores' 31 | PREDICTED_TEXTS = 'predicted_texts' 32 | EXIST_PREFIX = 'exist_prefix' 33 | 34 | 35 | class InternalFtrType: 36 | """Internal feature types 37 | 38 | These are the feature types that will be passed to layers inside the model 39 | """ 40 | LAST_MEMORY_STATE = 'last_memory_state' 41 | LAST_CARRY_STATE = 'last_carry_state' 42 | 43 | SENTENCES = 'sentences' 44 | MIN_LEN = 'min_len' 45 | MAX_LEN = 'max_len' 46 | 47 | NUM_CLS = 'num_cls' 48 | NUM_SEP = 'num_sep' 49 | 50 | TOKENIZED_IDS = 'tokenized_ids' 51 | TOKENIZED_TEXTS = 'tokenized_texts' 52 | LENGTH = 'length' 53 | EMBEDDED = 'embedded' 54 | 55 | LOGIT = 'logits' 56 | LABEL = 'labels' 57 | RNN_OUTPUT = '_rnn_output' 58 | SAMPLE_ID = '_sample_id' 59 | 60 | EXIST_KEY = 'exist_key' 61 | EXIST_PREFIX = OutputFtrType.EXIST_PREFIX 62 | COMPLETION_INDICES = 'completion_indices' 63 | COMPLETION_VOCAB_MASK = 'completion_vocab_mask' 64 | 65 | SEQUENCE_TO_ENCODE = 'sequence_to_encode' 66 | 67 | 68 | def as_list(value): 69 | """Returns value as a list 70 | 71 | If the value is not a list, it will be converted as a list containing one single item 72 | """ 73 | if isinstance(value, list): 74 | return value 75 | return [value] 76 | 77 | 78 | def iterate_items_with_list_val(dct): 79 | """Helper function that iterates the dict items 80 | 81 | If the value is not a list, it will be converted as a list containing one single item 82 | """ 83 | return [(key, as_list(value)) for key, value in dct.items()] 84 | 85 | 86 | def get_feature_types(): 87 | """ Returns the list of feature names defined in class FtrName""" 88 | constant_to_name_tuples = filter(lambda x: not x[0].startswith(('_', '__')), vars(InputFtrType).items()) # [(QUERY, query), (WEIGHT, weight), ...] 89 | feature_types = [t[1] for t in constant_to_name_tuples] 90 | return feature_types 91 | 92 | 93 | class HParams(Mapping): 94 | """ 95 | Hyper parameter class that behaves similar to the original tf 1.x HParams class w.r.t. 
functionality used in Smart Compose 96 | """ 97 | 98 | def __init__(self, **kwargs): 99 | for key, value in kwargs.items(): 100 | setattr(self, key, value) 101 | 102 | def __len__(self): 103 | return len(self.__dict__) 104 | 105 | def __iter__(self): 106 | for k in self.__dict__: 107 | yield k 108 | 109 | def __getitem__(self, item): 110 | return self.__dict__.get(item) 111 | 112 | def __repr__(self): 113 | return self.__dict__.__repr__() 114 | 115 | def to_json(self): 116 | """Serializes hparams to json 117 | 118 | Reference: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/hparam.py#L529 119 | """ 120 | 121 | def remove_callables(x): 122 | """Omit callable elements from input with arbitrary nesting.""" 123 | if isinstance(x, dict): 124 | return {k: remove_callables(v) for k, v in six.iteritems(x) 125 | if not callable(v)} 126 | elif isinstance(x, list): 127 | return [remove_callables(i) for i in x if not callable(i)] 128 | return x 129 | 130 | return json.dumps(remove_callables(self.__dict__)) 131 | 132 | 133 | def _get_hparam_path(out_dir): 134 | return os.path.join(out_dir, _HPARAM_FILE) 135 | 136 | 137 | def save_hparams(out_dir, hparams): 138 | """Saves hparams""" 139 | hparams_file = _get_hparam_path(out_dir) 140 | logging.info("Saving hparams to %s" % hparams_file) 141 | with codecs.getwriter("utf-8")(tf.compat.v1.gfile.GFile(hparams_file, "wb")) as f: 142 | f.write(hparams.to_json()) 143 | 144 | 145 | def load_hparams(out_dir): 146 | """Loads hparams""" 147 | hparams_file = _get_hparam_path(out_dir) 148 | if tf.compat.v1.gfile.Exists(hparams_file): 149 | logging.info("Loading hparams from %s" % hparams_file) 150 | with codecs.getreader("utf-8")(tf.compat.v1.gfile.GFile(hparams_file, "rb")) as f: 151 | try: 152 | hparams_values = json.load(f) 153 | hparams = HParams(**hparams_values) 154 | except ValueError: 155 | logging.error("Can't load hparams file") 156 | return None 157 | return hparams 158 | else: 159 | return None 160 | 161 | 162 | def get_input_files(input_patterns): 163 | """Returns a list of file paths that match every pattern in input_patterns 164 | 165 | :param input_patterns a comma-separated string 166 | :return list of file paths 167 | """ 168 | input_files = [] 169 | for input_pattern in input_patterns.split(","): 170 | if tf.io.gfile.isdir(input_pattern): 171 | input_pattern = os.path.join(input_pattern, '*') 172 | input_files.extend(tf.compat.v1.gfile.Glob(input_pattern)) 173 | return input_files 174 | 175 | 176 | def estimate_steps_per_epoch(input_pattern, batch_size): 177 | """ Estimates train steps per epoch for tfrecord files 178 | 179 | Counting exact total number of examples is time consuming and unnecessary, 180 | We count the first file and use the total file size to estimate total number of examples. 
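For example (hypothetical numbers): if the first file holds 1,000 records in 10 MB, the ratio is ~10 KB per record; a second 25 MB file is then estimated at ~2,500 records, giving ~3,500 examples in total and, with batch_size=32, about 109 steps per epoch.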
181 | """ 182 | input_files = get_input_files(input_pattern) 183 | 184 | file_1st = input_files[0] 185 | file_1st_num_examples = sum(1 for _ in tf.compat.v1.python_io.tf_record_iterator(file_1st)) 186 | logging.info("number of examples in first file: {0}".format(file_1st_num_examples)) 187 | 188 | file_1st_size = tf.compat.v1.gfile.GFile(file_1st).size() 189 | logging.info("first file size: {0}".format(file_1st_size)) 190 | 191 | file_size_num_example_ratio = float(file_1st_size) / file_1st_num_examples 192 | 193 | estimated_num_examples = sum([int(tf.compat.v1.gfile.GFile(fn).size() / file_size_num_example_ratio) 194 | for fn in input_files]) 195 | logging.info("Estimated number of examples: {0}".format(estimated_num_examples)) 196 | 197 | return int(estimated_num_examples / batch_size) 198 | -------------------------------------------------------------------------------- /src/smart_compose/utils/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/src/smart_compose/utils/testing/__init__.py -------------------------------------------------------------------------------- /src/smart_compose/utils/testing/data_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from smart_compose.utils.vocab_utils import read_vocab 6 | 7 | 8 | class DataSetup: 9 | """Class containing common setup on file paths, layer params used in unit tests""" 10 | resource_dir = os.path.join(os.getcwd(), 'test', 'smart_compose', 'resources') 11 | out_dir = os.path.join(resource_dir, "output") 12 | data_dir = os.path.join(resource_dir, "train", "dataset", "tfrecord") 13 | 14 | # Vocab 15 | vocab_file = os.path.join(resource_dir, 'vocab.txt') 16 | large_vocab_file = os.path.join(resource_dir, 'vocab.30k.txt') 17 | vocab_layer_dir = os.path.join(resource_dir, 'vocab_layer') 18 | vocab_hub_url = os.path.join(resource_dir, 'vocab_layer_hub') 19 | vocab_table_py = read_vocab(vocab_file) 20 | vocab_size = len(vocab_table_py) 21 | 22 | # Embedding layer 23 | we_file = '' 24 | embedding_layer_dir = os.path.join(resource_dir, 'embedding_layer') 25 | embedding_hub_url = os.path.join(resource_dir, 'embedding_layer_hub') 26 | 27 | # Special tokens 28 | CLS = '[CLS]' 29 | PAD = '[PAD]' 30 | SEP = '[SEP]' 31 | UNK = '[UNK]' 32 | CLS_ID = vocab_table_py[CLS] 33 | PAD_ID = vocab_table_py[PAD] 34 | SEP_ID = vocab_table_py[SEP] 35 | UNK_ID = vocab_table_py[UNK] 36 | 37 | # Vocab layer 38 | vocab_layer_param = {'CLS': CLS, 39 | 'SEP': SEP, 40 | 'PAD': PAD, 41 | 'UNK': UNK, 42 | 'vocab_file': vocab_file} 43 | 44 | # Embedding layer 45 | num_units = 10 46 | embedding_layer_param = {'vocab_layer_param': vocab_layer_param, 47 | 'vocab_hub_url': '', 48 | 'we_file': '', 49 | 'we_trainable': True, 50 | 'num_units': num_units} 51 | 52 | # Vocab layer with larger vocabulary size 53 | large_vocab_layer_param = {'CLS': CLS, 54 | 'SEP': SEP, 55 | 'PAD': PAD, 56 | 'UNK': UNK, 57 | 'vocab_file': large_vocab_file} 58 | 59 | # Embedding layer with larger embedding size 60 | num_units_large = 200 61 | embedding_layer_with_large_vocab_layer_param = {'vocab_layer_param': large_vocab_layer_param, 62 | 'vocab_hub_url': '', 63 | 'we_file': '', 64 | 'we_trainable': True, 65 | 'num_units': num_units_large} 66 | 67 | empty_url = '' 68 | 69 | # Beam search params 70 | min_len = 3 71 | max_len = 7 72 | beam_width = 10 73 | max_iter = 3 74 
| max_decode_length = 3 75 | min_seq_prob = 0.01 76 | length_norm_power = 0 77 | 78 | # Target testing 79 | target_text = tf.constant(['test', 'function', 'hello'], dtype=tf.dtypes.string) 80 | -------------------------------------------------------------------------------- /src/smart_compose/utils/testing/test_case.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import tensorflow as tf 5 | 6 | from . import data_setup 7 | 8 | 9 | class TestCase(tf.test.TestCase, data_setup.DataSetup): 10 | """ Unit test class""" 11 | 12 | def assertDictAllEqual(self, a: dict, b: dict): 13 | """ Checks that two dictionaries are the same """ 14 | self.assertIsInstance(a, dict) 15 | self.assertIsInstance(b, dict) 16 | self.assertAllEqual(a.keys(), b.keys()) 17 | 18 | for k in a.keys(): 19 | self.assertAllEqual(a[k], b[k]) 20 | 21 | def _cleanUp(self, dir): 22 | if os.path.exists(dir): 23 | shutil.rmtree(dir, ignore_errors=True) 24 | -------------------------------------------------------------------------------- /src/smart_compose/utils/testing/testing_utils.py: -------------------------------------------------------------------------------- 1 | """Unit test utilities""" 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | 9 | def create_sample_tfrecord(out_file): 10 | """Creates sample tfrecord to out_file""" 11 | print("Removing existing file {}".format(out_file)) 12 | if tf.io.gfile.exists(out_file): 13 | tf.io.gfile.remove(out_file) 14 | 15 | target_column_name = 'query' 16 | input_column_name = 'user_input' 17 | 18 | target_cases = [b'word function', b'data', b'', b'this is a sentence'] 19 | input_cases = [b'word function', b'data', b'', b'this is a sentence'] 20 | 21 | print("Composing fake tfrecord to file {}".format(out_file)) 22 | with tf.io.TFRecordWriter(out_file) as writer: 23 | with tf.Graph().as_default(), tf.compat.v1.Session(): 24 | num_instances_per_case = 10 25 | 26 | target_list = target_cases * num_instances_per_case 27 | input_list = input_cases * num_instances_per_case 28 | 29 | for inp, target in zip(input_list, target_list): 30 | features = { 31 | input_column_name: _bytes_feature([inp]), 32 | target_column_name: _bytes_feature([target]), 33 | } 34 | example_proto = tf.train.Example(features=tf.train.Features(feature=features)) 35 | writer.write(example_proto.SerializeToString()) 36 | 37 | 38 | def _bytes_feature(value): 39 | """Returns a bytes_list feature""" 40 | if isinstance(value, type(tf.constant(0))): 41 | value = value.numpy() # BytesList won't unpack a string from an EagerTensor. 
42 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 43 | 44 | 45 | def _float_feature(value): 46 | """Returns a float_list feature""" 47 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 48 | 49 | 50 | def _int64_feature(value): 51 | """Returns an int64_list feature""" 52 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 53 | 54 | 55 | def make_we_file(vocab_file, embedding_size, output_path): 56 | """Generates a word embedding file to output_path """ 57 | embedding = [] 58 | with tf.io.gfile.GFile(vocab_file, 'r') as fin: 59 | for _ in fin: 60 | embedding.append(np.random.uniform(-1, 1, [embedding_size])) 61 | embedding = np.array(embedding) 62 | pickle.dump(embedding, tf.io.gfile.GFile(output_path, 'w')) 63 | print(f'Dumped embedding to {output_path}') 64 | 65 | 66 | def timeit(method): 67 | """A decorator function to measure the run time of a function""" 68 | 69 | def timed(*args, **kw): 70 | ts = time.time() 71 | result = method(*args, **kw) 72 | te = time.time() 73 | if 'log_time' in kw: 74 | name = kw.get('log_name', method.__name__.upper()) 75 | kw['log_time'][name] = int((te - ts) * 1000) 76 | else: 77 | print('%r %2.2f ms' % (method.__name__, (te - ts) * 1000)) 78 | return result 79 | 80 | return timed 81 | 82 | 83 | if __name__ == '__main__': 84 | from smart_compose.utils.testing.data_setup import DataSetup 85 | from os.path import join 86 | 87 | create_sample_tfrecord(join(DataSetup.data_dir, 'test.tfrecord')) 88 | -------------------------------------------------------------------------------- /src/smart_compose/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | """ Vocabulary construction utilities """ 2 | import gzip 3 | import re 4 | import string 5 | 6 | import tensorflow as tf 7 | from tensorflow.python.ops import lookup_ops 8 | 9 | 10 | def strip(s): 11 | """Strips ascii whitespace characters off string s.""" 12 | return s.strip(string.whitespace) 13 | 14 | 15 | def split(s): 16 | """Split string s by whitespace characters.""" 17 | whitespace_lst = [re.escape(ws) for ws in string.whitespace] 18 | pattern = re.compile('|'.join(whitespace_lst)) 19 | return pattern.split(s) 20 | 21 | 22 | def read_vocab(input_file): 23 | """Read vocabulary file and return a dict""" 24 | if input_file is None: 25 | return None 26 | 27 | vocab = {} 28 | if input_file.endswith('.gz'): 29 | f = tf.io.gfile.GFile(input_file, 'r') 30 | fin = gzip.GzipFile(fileobj=f) 31 | else: 32 | fin = tf.io.gfile.GFile(input_file, 'r') 33 | for line in fin: 34 | word = split(strip(line))[0] 35 | vocab[word] = len(vocab) 36 | fin.close() 37 | return vocab 38 | 39 | 40 | def read_tf_vocab_inverse(input_file, UNK): 41 | """Read vocabulary (token->id) and return a tf hashtable (id->token)""" 42 | if input_file is None: 43 | return None 44 | 45 | keys, values = [], [] 46 | if input_file.endswith('.gz'): 47 | f = tf.io.gfile.GFile(input_file, 'r') 48 | fin = gzip.GzipFile(fileobj=f) 49 | else: 50 | fin = tf.io.gfile.GFile(input_file, 'r') 51 | for line in fin: 52 | word = split(strip(line))[0] 53 | 54 | keys.append(len(keys)) 55 | values.append(word) 56 | fin.close() 57 | 58 | initializer = lookup_ops.KeyValueTensorInitializer(tf.constant(keys), tf.constant(values)) 59 | vocab_table = lookup_ops.HashTable(initializer, UNK) 60 | return initializer, vocab_table 61 | 62 | 63 | def read_tf_vocab(input_file, UNK): 64 | """Read vocabulary and return a tf hashtable""" 65 | if input_file is None: 66 | return 
None 67 | 68 | keys, values = [], [] 69 | if input_file.endswith('.gz'): 70 | f = tf.io.gfile.GFile(input_file, 'r') 71 | fin = gzip.GzipFile(fileobj=f) 72 | else: 73 | fin = tf.io.gfile.GFile(input_file, 'r') 74 | for line in fin: 75 | word = split(strip(line))[0] 76 | keys.append(word) 77 | values.append(len(values)) 78 | fin.close() 79 | UNK_ID = keys.index(UNK) 80 | 81 | initializer = lookup_ops.KeyValueTensorInitializer(tf.constant(keys), tf.constant(values)) 82 | vocab_table = lookup_ops.HashTable(initializer, UNK_ID) 83 | return initializer, vocab_table 84 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/__init__.py -------------------------------------------------------------------------------- /test/detext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/__init__.py -------------------------------------------------------------------------------- /test/detext/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/layers/__init__.py -------------------------------------------------------------------------------- /test/detext/layers/test_cnn_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import cnn_layer 4 | from detext.utils.parsing_utils import InputFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestCnnLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for cnn_layer.py.""" 10 | num_filters = 20 11 | filter_window_sizes = [2] 12 | 13 | min_len = 3 14 | max_len = 4 15 | 16 | def testCnnLayer(self): 17 | """Test CNN layer """ 18 | for embedding_hub_url in ['', self.embedding_hub_url]: 19 | self._testCnnLayer(embedding_hub_url) 20 | 21 | def _testCnnLayer(self, embedding_hub_url): 22 | query = self.query 23 | 24 | doc_fields = [self.ranking_doc_field1, self.ranking_doc_field2] 25 | user_fields = [query, query, query] 26 | 27 | num_filters = self.num_filters 28 | filter_window_sizes = self.filter_window_sizes 29 | layer = cnn_layer.CnnLayer(filter_window_sizes=filter_window_sizes, 30 | num_filters=num_filters, num_doc_fields=2, num_user_fields=3, 31 | min_len=self.min_len, max_len=self.max_len, 32 | embedding_layer_param=self.embedding_layer_param, embedding_hub_url=embedding_hub_url) 33 | text_ftr_size = num_filters * len(filter_window_sizes) 34 | 35 | query_ftrs, doc_ftrs, user_ftrs = layer( 36 | {InputFtrType.QUERY_COLUMN_NAME: query, InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields, InputFtrType.USER_TEXT_COLUMN_NAMES: user_fields}) 37 | self.assertEqual(text_ftr_size, layer.text_ftr_size) 38 | self.assertAllEqual(query_ftrs.shape, [2, text_ftr_size]) 39 | self.assertAllEqual(doc_ftrs.shape, [2, 3, 2, text_ftr_size]) 40 | self.assertAllEqual(user_ftrs.shape, [2, 3, text_ftr_size]) 41 | # 1st query, 2nd doc, 2nd field should be the same as 2nd query, 1st doc, 2nd field (20, 5, 3, 1) 42 | self.assertAllEqual(doc_ftrs[0, 1, 1], doc_ftrs[1, 0, 1]) 43 | # 1st query, 1st doc, 1st field should NOT be the same as 1st query, 1st doc, 2nd 
field (1, 2, 3, 0) 44 | self.assertNotAllClose(doc_ftrs[0, 0, 0], doc_ftrs[0, 0, 1]) 45 | 46 | def testCnnConsistency(self): 47 | """ Test CNN consistency for data that only differ in batch sizes and padding tokens """ 48 | # doc_field1 = tf.constant(self.doc_field1, dtype=tf.int32) 49 | doc_field1 = tf.constant(self.ranking_doc_field1, dtype=tf.dtypes.string) 50 | doc_fields = [doc_field1] 51 | user_fields = None 52 | 53 | filter_window_sizes = [3] 54 | num_filters = self.num_filters 55 | 56 | layer = cnn_layer.CnnLayer(filter_window_sizes=filter_window_sizes, num_units=self.num_units, 57 | num_filters=num_filters, num_doc_fields=1, num_user_fields=0, 58 | min_len=self.min_len, max_len=self.max_len, 59 | embedding_layer_param=self.embedding_layer_param, embedding_hub_url=None) 60 | query = ['batch1 query 1'] 61 | query_ftrs, _, _ = layer( 62 | {InputFtrType.QUERY_COLUMN_NAME: query, InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields, InputFtrType.USER_TEXT_COLUMN_NAMES: user_fields}) 63 | 64 | query2 = query + ['batch 2 query build'] 65 | query_ftrs2, _, _ = layer( 66 | {InputFtrType.QUERY_COLUMN_NAME: query2, InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields, InputFtrType.USER_TEXT_COLUMN_NAMES: user_fields}) 67 | self.assertAllClose(query_ftrs[0], query_ftrs2[0]) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /test/detext/layers/test_embedding_layer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import shutil 3 | 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | 7 | from detext.layers import embedding_layer 8 | from detext.utils.layer_utils import get_sorted_dict 9 | from detext.utils.parsing_utils import InternalFtrType 10 | from detext.utils.testing.data_setup import DataSetup 11 | 12 | 13 | class TestEmbeddingLayer(tf.test.TestCase, DataSetup): 14 | """Tests embedding_layer.py""" 15 | num_cls_sep = 0 16 | min_len = 0 17 | max_len = 5 18 | 19 | sentences = tf.constant(['hello sent1', 'build sent2']) 20 | inputs = get_sorted_dict({InternalFtrType.SENTENCES: sentences, 21 | InternalFtrType.NUM_CLS: tf.constant(num_cls_sep, dtype=tf.dtypes.int32), 22 | InternalFtrType.NUM_SEP: tf.constant(num_cls_sep, dtype=tf.dtypes.int32), 23 | InternalFtrType.MIN_LEN: tf.constant(min_len, dtype=tf.dtypes.int32), 24 | InternalFtrType.MAX_LEN: tf.constant(max_len, dtype=tf.dtypes.int32)}) 25 | 26 | embedding_layer_param = {'vocab_layer_param': DataSetup.vocab_layer_param, 27 | 'vocab_hub_url': '', 28 | 'we_file': '', 29 | 'we_trainable': True, 30 | 'num_units': DataSetup.num_units} 31 | 32 | def testEmbeddingLayerApi(self): 33 | """Checks whether a given layer conforms to the detext embedding layer api""" 34 | layer = hub.load(self.embedding_hub_url) 35 | layer: embedding_layer.EmbeddingLayerBase 36 | 37 | self.assertEqual(layer.num_units(), self.num_units) 38 | self.assertEqual(layer.vocab_size(), self.vocab_size) 39 | 40 | tokenized = layer.tokenize_to_indices(self.inputs) 41 | expected_tokenized = {InternalFtrType.LENGTH: tf.constant([2, 2]), 42 | InternalFtrType.TOKENIZED_IDS: tf.constant([[0, 0], 43 | [4, 0]])} 44 | for k, v in tokenized.items(): 45 | self.assertAllEqual(v, expected_tokenized[k]) 46 | 47 | tokenized_result = tf.constant([[1, 2], [0, 1]]) 48 | tokenized_result_shape = tf.shape(tokenized_result) 49 | embedding_lookup_result = layer.embedding_lookup(tokenized_result) 50 | 
self.assertAllEqual(tf.shape(embedding_lookup_result), [tokenized_result_shape[0], tokenized_result_shape[1], layer.num_units()]) 51 | 52 | outputs = layer(self.inputs) 53 | self.assertAllEqual(tf.shape(outputs[InternalFtrType.EMBEDDED]), [tokenized_result_shape[0], tokenized_result_shape[1], layer.num_units()]) 54 | self.assertAllEqual(outputs[InternalFtrType.LENGTH], tf.constant([2, 2])) 55 | 56 | def testCreateEmbeddingLayer(self): 57 | for vocab_hub_url in ['', self.vocab_hub_url]: 58 | embedding_layer_param = copy.copy(self.embedding_layer_param) 59 | embedding_layer_param['vocab_hub_url'] = vocab_hub_url 60 | self._testCreateEmbeddingLayer('', embedding_layer_param) 61 | 62 | embedding_layer_param = copy.copy(self.embedding_layer_param) 63 | embedding_layer_param['we_file'] = self.we_file 64 | self._testCreateEmbeddingLayer('', embedding_layer_param) 65 | 66 | embedding_layer_param = copy.copy(self.embedding_layer_param) 67 | self._testCreateEmbeddingLayer(self.embedding_hub_url, embedding_layer_param) 68 | 69 | def _testCreateEmbeddingLayer(self, embedding_hub_url, embedding_layer_param): 70 | layer = embedding_layer.create_embedding_layer(embedding_layer_param, embedding_hub_url) 71 | outputs = layer(self.inputs) 72 | 73 | tf.saved_model.save(layer, self.embedding_layer_dir) 74 | 75 | loaded_layer = embedding_layer.create_embedding_layer(embedding_layer_param, embedding_hub_url=self.embedding_layer_dir) 76 | loaded_layer_outputs = loaded_layer(self.inputs) 77 | 78 | for k, v in outputs.items(): 79 | self.assertAllEqual(v, loaded_layer_outputs[k]) 80 | shutil.rmtree(self.embedding_layer_dir) 81 | 82 | 83 | if __name__ == '__main__': 84 | tf.test.main() 85 | -------------------------------------------------------------------------------- /test/detext/layers/test_feature_grouper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import feature_grouper 4 | from detext.layers.feature_grouper import FeatureGrouper 5 | from detext.utils import vocab_utils 6 | from detext.utils.parsing_utils import InputFtrType 7 | from detext.utils.testing.data_setup import DataSetup 8 | 9 | 10 | class TestFeatureGrouper(tf.test.TestCase, DataSetup): 11 | """Unit test for feature_grouper.py""" 12 | _, vocab_tf_table = vocab_utils.read_tf_vocab(DataSetup.vocab_file, '[UNK]') 13 | vocab_table = vocab_utils.read_vocab(DataSetup.vocab_file) 14 | 15 | PAD_ID = vocab_table[DataSetup.PAD] 16 | SEP_ID = vocab_table[DataSetup.SEP] 17 | CLS_ID = vocab_table[DataSetup.CLS] 18 | UNK_ID = vocab_table[DataSetup.UNK] 19 | 20 | max_filter_window_size = 0 21 | 22 | def testFeatureGrouperKerasInput(self): 23 | """Tests FeatureGrouper with tf.keras.Input""" 24 | nums_dense_ftrs = [2, 3] 25 | nums_sparse_ftrs = [10, 30] 26 | layer = FeatureGrouper() 27 | inputs = { 28 | InputFtrType.QUERY_COLUMN_NAME: tf.keras.Input(shape=(), dtype='string'), 29 | InputFtrType.USER_TEXT_COLUMN_NAMES: [tf.keras.Input(shape=(), dtype='string')], 30 | InputFtrType.USER_ID_COLUMN_NAMES: [tf.keras.Input(shape=(), dtype='string')], 31 | InputFtrType.DOC_TEXT_COLUMN_NAMES: [tf.keras.Input(shape=(None,), dtype='string')], 32 | InputFtrType.DOC_ID_COLUMN_NAMES: [tf.keras.Input(shape=(None,), dtype='string')], 33 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: [tf.keras.Input(shape=(num_dense_ftrs,), dtype='float32') for num_dense_ftrs in nums_dense_ftrs], 34 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [tf.keras.Input(shape=(num_sparse_ftrs,), dtype='float32', sparse=True) 
35 | for num_sparse_ftrs in nums_sparse_ftrs] 36 | } 37 | outputs = layer(inputs) 38 | self.assertLen(outputs, len(inputs)) 39 | 40 | def testFeatureGrouperTensor(self): 41 | """Tests FeatureGrouper with tensor input""" 42 | layer = FeatureGrouper() 43 | inputs = {InputFtrType.QUERY_COLUMN_NAME: tf.constant(['batch 1 user 1 build', 44 | 'batch 2 user 2 word'], dtype=tf.string), 45 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: [tf.constant([[1, 1], [2, 2]], dtype=tf.float32), 46 | tf.constant([[0], [1]], dtype=tf.float32)], 47 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [tf.sparse.from_dense(tf.constant([[1, 0], [2, 0]], dtype=tf.float32)), 48 | tf.sparse.from_dense(tf.constant([[1], [1]], dtype=tf.float32))] 49 | } 50 | expected_result = {InputFtrType.QUERY_COLUMN_NAME: tf.constant(['batch 1 user 1 build', 51 | 'batch 2 user 2 word'], dtype=tf.string), 52 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: tf.constant([[1, 1, 0], 53 | [2, 2, 1]]), 54 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [tf.constant([[1, 0], 55 | [2, 0]], dtype=tf.float32), 56 | tf.constant([[1], [1]], dtype=tf.float32)] 57 | } 58 | outputs = layer(inputs) 59 | 60 | self.assertEqual(len(outputs), len(expected_result)), "Outputs must have the same shape" 61 | for ftr_type, expected_ftr in expected_result.items(): 62 | output = outputs[ftr_type] 63 | if ftr_type == InputFtrType.SPARSE_FTRS_COLUMN_NAMES: 64 | output = [tf.sparse.to_dense(t) for t in output] 65 | for e, o in zip(expected_ftr, output): 66 | self.assertAllEqual(e, o) 67 | continue 68 | self.assertAllEqual(expected_ftr, output) 69 | 70 | def testConcatFtrOnLastDim(self): 71 | """Tests concatenate features on last dimension""" 72 | tensor_lst = [tf.constant([1, 2, 3], dtype='int32'), tf.constant([4, 5, 6], dtype='int32')] 73 | result = feature_grouper.concat_on_last_axis_dense(tensor_lst) 74 | expected_output = tf.constant([1, 2, 3, 4, 5, 6], dtype='int32') 75 | self.assertAllEqual(result, expected_output) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /test/detext/layers/test_feature_name_type_converter.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import tensorflow as tf 4 | 5 | from detext.layers import feature_name_type_converter 6 | from detext.utils.parsing_utils import InputFtrType, TaskType, InternalFtrType 7 | from detext.utils.testing.data_setup import DataSetup 8 | 9 | 10 | class TestFeatureTypeNameConverter(tf.test.TestCase, DataSetup): 11 | """Unit test for feature_name_type_converter.py""" 12 | ranking_dense_ftrs = tf.random.uniform(shape=[2, 3, 5]) 13 | cls_dense_ftrs = tf.random.uniform(shape=[2, 5]) 14 | 15 | ranking_inputs = OrderedDict(sorted({'query': DataSetup.query, 16 | 'usr_headline': DataSetup.user_fields[0], 17 | 'usr_title': DataSetup.user_fields[1], 18 | 'usrId_headline': DataSetup.user_id_fields[0], 19 | 'usrId_title': DataSetup.user_id_fields[1], 20 | 'doc_headline': DataSetup.ranking_doc_fields[0], 21 | 'doc_title': DataSetup.ranking_doc_fields[1], 22 | 'docId_headline': DataSetup.ranking_doc_id_fields[0], 23 | 'docId_title': DataSetup.ranking_doc_id_fields[1], 24 | 'dense_ftrs': ranking_dense_ftrs, 25 | 'task_id_field': tf.constant([2, 1]) 26 | }.items())) 27 | 28 | classification_inputs = OrderedDict(sorted({'query': DataSetup.query, 29 | 'usr_headline': DataSetup.user_fields[0], 30 | 'usr_title': DataSetup.user_fields[1], 31 | 'usrId_headline': 
DataSetup.user_id_fields[0], 32 | 'usrId_title': DataSetup.user_id_fields[1], 33 | 34 | 'doc_headline': DataSetup.cls_doc_fields[0], 35 | 'doc_title': DataSetup.cls_doc_fields[1], 36 | 'docId_headline': DataSetup.cls_doc_id_fields[0], 37 | 'docId_title': DataSetup.cls_doc_id_fields[1], 38 | 'dense_ftrs': cls_dense_ftrs, 39 | 'task_id_field': tf.constant([2, 1]) 40 | }.items())) 41 | 42 | feature_type2name_deep = {InputFtrType.QUERY_COLUMN_NAME: 'query', 43 | InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_headline', 'doc_title'], 44 | InputFtrType.DOC_ID_COLUMN_NAMES: ['docId_headline', 'docId_title'], 45 | InputFtrType.USER_TEXT_COLUMN_NAMES: ['usr_headline', 'usr_title'], 46 | InputFtrType.USER_ID_COLUMN_NAMES: ['usrId_headline', 'usrId_title']} 47 | feature_type2name_wide = {InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['dense_ftrs']} 48 | feature_type2name_multitask = {InputFtrType.TASK_ID_COLUMN_NAME: 'task_id_field'} 49 | 50 | feature_type2name = {**feature_type2name_deep, **feature_type2name_wide, **feature_type2name_multitask} 51 | 52 | def test_name_type_converter_ranking(self): 53 | task_type_list = [TaskType.RANKING, TaskType.CLASSIFICATION, TaskType.BINARY_CLASSIFICATION] 54 | inputs_list = [self.ranking_inputs, self.classification_inputs, self.classification_inputs] 55 | for task_type, inputs in zip(task_type_list, inputs_list): 56 | self._test_name_type_converter(task_type, inputs) 57 | 58 | def _test_name_type_converter(self, task_type, inputs): 59 | converter = feature_name_type_converter.FeatureNameTypeConverter(task_type=task_type, feature_type2name=self.feature_type2name) 60 | converter_type_list = [InternalFtrType.DEEP_FTR_BAG, InternalFtrType.WIDE_FTR_BAG, InternalFtrType.MULTITASK_FTR_BAG] 61 | result_feature_type2name_list = [self.feature_type2name_deep, self.feature_type2name_wide, self.feature_type2name_multitask] 62 | for converter_type, result_feature_type2_name in zip(converter_type_list, result_feature_type2name_list): 63 | outputs = converter({converter_type: inputs})[converter_type] 64 | self.assertAllEqual(sorted(outputs.keys()), sorted(result_feature_type2_name.keys())) 65 | 66 | # Classification expands an additional dimension so that cls inputs can reuse the ranking layers since it's treated as list_size=1 67 | if task_type in [TaskType.CLASSIFICATION, TaskType.BINARY_CLASSIFICATION]: 68 | if converter_type == InternalFtrType.WIDE_FTR_BAG: 69 | for t in outputs[InputFtrType.DENSE_FTRS_COLUMN_NAMES]: 70 | self.assertAllEqual(tf.rank(t), 3) 71 | if converter_type == InternalFtrType.DEEP_FTR_BAG: 72 | for t in outputs[InputFtrType.DOC_TEXT_COLUMN_NAMES]: 73 | self.assertAllEqual(tf.rank(t), 2) 74 | for t in outputs[InputFtrType.DOC_ID_COLUMN_NAMES]: 75 | self.assertAllEqual(tf.rank(t), 2) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /test/detext/layers/test_feature_normalizer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import feature_normalizer 4 | from detext.utils.testing.data_setup import DataSetup 5 | 6 | 7 | class TestFeatureNormalizer(DataSetup, tf.test.TestCase): 8 | """Unit test for feature_normalizer.py""" 9 | 10 | def testFeatureNormalizer(self): 11 | batch_size = 10 12 | list_size = 5 13 | num_ftrs = 20 14 | inputs = tf.random.uniform(shape=[batch_size, list_size, num_ftrs]) 15 | 16 | ftr_mean = 0.9 17 | ftr_std = 1.2 18 | normalizer = 
feature_normalizer.FeatureNormalizer(ftr_mean, ftr_std) 19 | 20 | self.assertEqual(normalizer.ftr_mean, ftr_mean) 21 | self.assertEqual(normalizer.ftr_std, ftr_std) 22 | outputs = normalizer(inputs) 23 | self.assertAllEqual(outputs, (inputs - ftr_mean) / ftr_std) 24 | 25 | 26 | if __name__ == '__main__': 27 | tf.test.main() 28 | -------------------------------------------------------------------------------- /test/detext/layers/test_feature_rescaler.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import feature_rescaler 4 | from detext.utils.testing.data_setup import DataSetup 5 | 6 | 7 | class TestFeatureRescaler(DataSetup, tf.test.TestCase): 8 | """Unit test for feature_rescaler.py""" 9 | 10 | def testFeatureRescaler(self): 11 | batch_size = 10 12 | list_size = 5 13 | num_ftrs = 20 14 | inputs = tf.random.uniform(shape=[batch_size, list_size, num_ftrs]) 15 | 16 | rescaler = feature_rescaler.FeatureRescaler(num_ftrs) 17 | 18 | self.assertEqual(rescaler._initial_w, 1.0) 19 | self.assertEqual(rescaler._initial_b, 0.0) 20 | outputs = rescaler(inputs) 21 | self.assertAllEqual(outputs, tf.tanh(inputs * rescaler._initial_w + rescaler._initial_b)) 22 | 23 | 24 | if __name__ == '__main__': 25 | tf.test.main() 26 | -------------------------------------------------------------------------------- /test/detext/layers/test_id_embed_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import id_embed_layer 4 | from detext.utils.parsing_utils import InputFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestIdEmbedLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for id_embed_layer.py.""" 10 | 11 | def testIdEmbedLayer(self): 12 | """test IdEmbedLayer outputs""" 13 | doc_id_field1 = tf.constant(self.ranking_doc_id_field1, dtype=tf.dtypes.string) 14 | doc_id_field2 = tf.constant(self.ranking_doc_id_field2, dtype=tf.dtypes.string) 15 | doc_id_fields = [doc_id_field1, doc_id_field2] 16 | 17 | user_id_field1 = tf.constant(self.user_id_field1, dtype=tf.dtypes.string) 18 | user_id_field2 = tf.constant(self.user_id_field2, dtype=tf.dtypes.string) 19 | user_id_fields = [user_id_field1, user_id_field2] 20 | 21 | id_embedder = id_embed_layer.IdEmbedLayer( 22 | num_id_fields=len(doc_id_fields) + len(user_id_fields), 23 | embedding_layer_param=self.embedding_layer_param, 24 | embedding_hub_url_for_id_ftr='' 25 | ) 26 | 27 | id_ftr_size = id_embedder.id_ftr_size 28 | self.assertEqual(self.num_units_for_id_ftr, id_ftr_size) 29 | 30 | doc_ftrs, user_ftrs = id_embedder({InputFtrType.DOC_ID_COLUMN_NAMES: doc_id_fields, InputFtrType.USER_ID_COLUMN_NAMES: user_id_fields}) 31 | self.assertAllEqual(doc_ftrs.shape, [2, 3, 2, id_ftr_size]) 32 | self.assertAllEqual(user_ftrs.shape, [2, 2, id_ftr_size]) 33 | # 1st query, 2nd doc, 2nd field should be the same as 2nd query, 1st doc, 2nd field (20, 5, 3, 1) 34 | self.assertAllEqual(doc_ftrs[0, 1, 1], doc_ftrs[1, 0, 1]) 35 | # 1st query, 1st doc, 1st field should be the same as 1st query, 1st doc, 2nd field (1, 2, 3, 0) 36 | self.assertAllEqual(doc_ftrs[0, 0, 0], doc_ftrs[0, 0, 1]) 37 | # For randomly chosen doc field (2nd sample, 3rd doc, 2nd field), vector should not be all zero because 38 | # initialized embedding should be non-zero 39 | self.assertNotAllClose(doc_ftrs[1, 2, 1], tf.zeros([self.num_units_for_id_ftr], dtype=tf.float32)) 40 | 41 | 42 | if __name__
== '__main__': 43 | tf.test.main() 44 | -------------------------------------------------------------------------------- /test/detext/layers/test_lstm_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import lstm_layer 4 | from detext.utils.parsing_utils import InputFtrType, InternalFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestLstmLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for lstm_layer.py.""" 10 | 11 | def testLstm(self): 12 | """Tests LSTM text encoder """ 13 | for num_layers in [1, 2]: 14 | for bidirectional in [False, True]: 15 | self._testLstm(num_layers, bidirectional, '') 16 | 17 | for embedding_hub_url in ['', self.embedding_hub_url]: 18 | self._testLstm(1, True, embedding_hub_url) 19 | 20 | def _testLstm(self, num_layers, bidirectional, embedding_hub_url): 21 | """Tests LSTM text encoder given input """ 22 | query = tf.constant(self.query, dtype=tf.dtypes.string) 23 | doc_field1 = tf.constant(self.ranking_doc_field1, dtype=tf.dtypes.string) 24 | doc_field2 = tf.constant(self.ranking_doc_field2, dtype=tf.dtypes.string) 25 | doc_fields = [doc_field1, doc_field2] 26 | user_fields = [query, query, query] 27 | num_units = self.num_units 28 | 29 | layer = lstm_layer.LstmLayer( 30 | we_trainable=True, 31 | num_units=self.num_units, 32 | num_doc_fields=2, 33 | num_layers=num_layers, 34 | forget_bias=0.5, 35 | rnn_dropout=0., 36 | bidirectional=bidirectional, 37 | min_len=self.min_len, max_len=self.max_len, 38 | embedding_layer_param=self.embedding_layer_param, embedding_hub_url=embedding_hub_url 39 | ) 40 | text_ftr_size = num_units 41 | 42 | query_ftrs, doc_ftrs, user_ftrs = layer( 43 | {InputFtrType.QUERY_COLUMN_NAME: query, InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields, InputFtrType.USER_TEXT_COLUMN_NAMES: user_fields}, 44 | training=False) 45 | 46 | self.assertEqual(text_ftr_size, layer.text_ftr_size) 47 | self.assertAllEqual(query_ftrs.shape, [2, text_ftr_size]) 48 | self.assertAllEqual(doc_ftrs.shape, [2, 3, 2, text_ftr_size]) 49 | self.assertAllEqual(user_ftrs.shape, [2, 3, text_ftr_size]) 50 | 51 | # 1st query, 2nd doc, 2nd field should be the same as 2nd query, 1st doc, 2nd field 52 | self.assertAllEqual(doc_ftrs[0, 1, 1], doc_ftrs[1, 0, 1]) 53 | # 1st query, 1st doc, 1st field should NOT be the same as 1st query, 2nd doc, 1st field 54 | self.assertNotAllClose(doc_ftrs[0, 0, 0], doc_ftrs[0, 1, 0]) 55 | 56 | def testApplyLstmOnText(self): 57 | """Tests apply_lstm_on_text() """ 58 | for num_layers in [1, 2]: 59 | for bidirectional in [True]: 60 | self._testApplyLstmOnText(num_layers, bidirectional) 61 | 62 | def _testApplyLstmOnText(self, num_layers, bidirectional): 63 | """Tests apply_lstm_on_text() given input """ 64 | query = tf.constant(self.query, dtype=tf.dtypes.string) 65 | doc_field1 = tf.constant(self.ranking_doc_field1, dtype=tf.dtypes.string) 66 | doc_field2 = tf.constant(self.ranking_doc_field2, dtype=tf.dtypes.string) 67 | doc_fields = [doc_field1, doc_field2] 68 | 69 | num_units = self.num_units 70 | layer = lstm_layer.LstmLayer( 71 | we_trainable=True, 72 | num_units=self.num_units, 73 | num_layers=num_layers, 74 | forget_bias=0.5, 75 | rnn_dropout=0., 76 | bidirectional=bidirectional, 77 | min_len=self.min_len, max_len=self.max_len, 78 | embedding_layer_param=self.embedding_layer_param, embedding_hub_url=None 79 | ) 80 | 81 | query_ftrs, doc_ftrs, user_ftrs = layer({InputFtrType.QUERY_COLUMN_NAME: query, 
InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields}, training=False) 82 | results = lstm_layer.apply_lstm_on_text(query, layer.text_encoders, layer.embedding, 83 | bidirectional, self.min_len, self.max_len, layer.num_cls_sep, 84 | False) 85 | query_seq_outputs = results[InternalFtrType.SEQ_OUTPUTS] 86 | query_memory_state = results[InternalFtrType.LAST_MEMORY_STATE] 87 | 88 | # Make sure layer.call() and apply_lstm_on_text output the same result 89 | self.assertAllEqual(query_ftrs, query_memory_state) 90 | 91 | # Make sure sequence outputs are different at each token 92 | self.assertNotAllEqual(query_seq_outputs[0][2], query_seq_outputs[0][0]) 93 | self.assertNotAllEqual(query_seq_outputs[0][2], query_seq_outputs[0][1]) 94 | if not bidirectional: 95 | # Make sure output is the last state that's not masked out by the sequence mask inferred from sequence length 96 | expected = tf.stack([query_seq_outputs[0][2], query_seq_outputs[1][3]], axis=0) 97 | self.assertAllEqual(query_memory_state, expected) 98 | else: 99 | first_query_end = min(self.query_length[0] + 1, self.max_len - 1) 100 | second_query_end = min(self.query_length[1] + 1, self.max_len - 1) 101 | # In the bidirectional LSTM cases, the last state of the backward layer is the memory state of the **first** token. 102 | # Therefore, the checking above does not apply. Instead, we check the forward and backward output separately 103 | expected_fw_last_state = tf.stack([query_seq_outputs[0][first_query_end][:num_units // 2], query_seq_outputs[1][second_query_end][:num_units // 2]], 104 | axis=0) 105 | expected_bw_last_state = tf.stack([query_seq_outputs[0][0][num_units // 2:], query_seq_outputs[1][0][num_units // 2:]], axis=0) 106 | self.assertAllEqual(tf.slice(query_memory_state, [0, 0], [len(query), num_units // 2]), expected_fw_last_state) 107 | self.assertAllEqual(tf.slice(query_memory_state, [0, num_units // 2], [len(query), num_units // 2]), expected_bw_last_state) 108 | 109 | 110 | if __name__ == "__main__": 111 | tf.test.main() 112 | -------------------------------------------------------------------------------- /test/detext/layers/test_multi_layer_perceptron.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import multi_layer_perceptron 4 | 5 | 6 | class TestMultiLayerPerceptron(tf.test.TestCase): 7 | """Unit test of multi_layer_perceptron.py""" 8 | 9 | def testMultiLayerPerceptron(self): 10 | """Tests MultiLayerPerceptron class""" 11 | num_hidden = [3, 5, 1] 12 | activations = ['tanh'] * len(num_hidden) 13 | prefix = '' 14 | layer = multi_layer_perceptron.MultiLayerPerceptron(num_hidden, activations, prefix) 15 | 16 | input_shape = [2, 3, 4] 17 | x = tf.random.uniform(input_shape) 18 | y = layer(x) 19 | self.assertEqual(y.shape, input_shape[:-1] + num_hidden[-1:]) 20 | 21 | 22 | if __name__ == '__main__': 23 | tf.test.main() 24 | -------------------------------------------------------------------------------- /test/detext/layers/test_representation_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import representation_layer 4 | from detext.utils.parsing_utils import InputFtrType, InternalFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestRepresentationLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for representation_layer.py """ 10 | min_len = 3 11 | max_len = 4 12 | 13 | ftr_ext = 'cnn' 14 | 
text_encoder_param = DataSetup.cnn_param 15 | 16 | def testRepresentationLayer(self): 17 | """Tests RepLayer""" 18 | doc_id_fields_lst = [self.ranking_doc_id_fields, None] 19 | num_doc_id_fields_lst = [self.num_doc_id_fields, 0] 20 | user_id_fields_lst = [self.user_id_fields, None] 21 | num_user_id_fields_lst = [self.num_user_id_fields, 0] 22 | 23 | for doc_id_fields, user_id_fields, num_doc_id_fields, num_user_id_fields in zip(doc_id_fields_lst, user_id_fields_lst, num_doc_id_fields_lst, 24 | num_user_id_fields_lst): 25 | add_doc_projection = False 26 | add_user_projection = False 27 | self._testRepresentationLayer(add_doc_projection, add_user_projection, doc_id_fields, user_id_fields, num_doc_id_fields, num_user_id_fields) 28 | 29 | add_doc_projection = True 30 | add_user_projection = False 31 | self._testRepresentationLayer(add_doc_projection, add_user_projection, doc_id_fields, user_id_fields, num_doc_id_fields, num_user_id_fields) 32 | 33 | add_doc_projection = True 34 | add_user_projection = True 35 | self._testRepresentationLayer(add_doc_projection, add_user_projection, doc_id_fields, user_id_fields, num_doc_id_fields, num_user_id_fields) 36 | 37 | def _testRepresentationLayer(self, add_doc_projection, add_user_projection, doc_id_fields, user_id_fields, num_doc_id_fields, num_user_id_fields): 38 | """Tests RepLayer given input""" 39 | layer = representation_layer.RepresentationLayer(self.ftr_ext, self.num_doc_fields, self.num_user_fields, 40 | num_doc_id_fields, num_user_id_fields, add_doc_projection, add_user_projection, 41 | self.text_encoder_param, self.id_encoder_param) 42 | outputs = layer( 43 | {InputFtrType.QUERY_COLUMN_NAME: self.query, InputFtrType.DOC_TEXT_COLUMN_NAMES: self.ranking_doc_fields, 44 | InputFtrType.USER_TEXT_COLUMN_NAMES: self.user_fields, 45 | InputFtrType.DOC_ID_COLUMN_NAMES: doc_id_fields, InputFtrType.USER_ID_COLUMN_NAMES: user_id_fields}, False) 46 | 47 | self.assertEqual(tf.shape(outputs[InternalFtrType.QUERY_FTRS])[-1], layer.ftr_size) 48 | self.assertAllEqual(tf.shape(outputs[InternalFtrType.DOC_FTRS])[-2:], [layer.output_num_doc_fields, layer.ftr_size]) 49 | self.assertAllEqual(tf.shape(outputs[InternalFtrType.USER_FTRS])[-2:], [layer.output_num_user_fields, layer.ftr_size]) 50 | 51 | 52 | if __name__ == '__main__': 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /test/detext/layers/test_scoring_layer.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import tensorflow as tf 4 | 5 | from detext.layers import scoring_layer 6 | from detext.utils.parsing_utils import InputFtrType 7 | from detext.utils.testing.data_setup import DataSetup 8 | 9 | 10 | class TestEmbeddingInteractionLayer(tf.test.TestCase, DataSetup): 11 | """Unit test for scoring_layer.py""" 12 | batch_size = 3 13 | list_size = 10 14 | interaction_ftr_size = 5 15 | 16 | def testScoringLayer(self): 17 | """Tests ScoringLayer""" 18 | task_ids_list = [[0], [213, 3213, 4]] 19 | num_classes_list = [1, 20] 20 | for task_ids, num_classes in product(task_ids_list, num_classes_list): 21 | self._testScoringLayer(task_ids, num_classes) 22 | 23 | def _testScoringLayer(self, task_ids, num_classes): 24 | """Tests ScoringLayer under given settings""" 25 | inputs = { 26 | InputFtrType.TASK_ID_COLUMN_NAME: tf.constant(0, dtype=tf.dtypes.int32), 27 | **{scoring_layer.ScoringLayer.get_scoring_ftrs_key(i): tf.random.uniform([self.batch_size, self.list_size, 
self.interaction_ftr_size]) 28 | for i in range(len(task_ids))} 29 | } 30 | 31 | layer = scoring_layer.ScoringLayer(task_ids, num_classes) 32 | outputs = layer(inputs) 33 | self.assertAllEqual(tf.shape(outputs), [self.batch_size, self.list_size, num_classes]) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /test/detext/layers/test_shallow_tower_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import shallow_tower_layer 4 | from detext.utils.parsing_utils import InputFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestShallowTowerLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for shallow_tower_layer.py""" 10 | dense_tensor = [tf.constant([[0, 3]], dtype=tf.dtypes.float32), 11 | tf.constant([[4, 20]], dtype=tf.dtypes.float32)] 12 | sparse_tensor = [tf.sparse.from_dense(t) for t in dense_tensor] 13 | 14 | sparse_embedding_size = 10 15 | nums_sparse_ftrs = [2, 2] 16 | batch_size = tf.shape(dense_tensor[0])[0] 17 | 18 | class RangeInitializer(tf.keras.initializers.Initializer): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def __call__(self, shape, dtype=None, **kwargs): 23 | return tf.reshape(tf.range(0, tf.reduce_prod(shape), dtype=dtype), shape=shape) 24 | 25 | def test(self): 26 | inputs = { 27 | InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES: self.sparse_tensor 28 | } 29 | first_embedding = tf.range(10, dtype=tf.float32) 30 | second_embedding = tf.range(10, 20, dtype=tf.float32) 31 | 32 | expected_result = tf.expand_dims(first_embedding * (self.dense_tensor[0][0][0] + self.dense_tensor[1][0][0]) + 33 | second_embedding * (self.dense_tensor[0][0][1] + self.dense_tensor[1][0][1]), axis=0) 34 | 35 | layer = shallow_tower_layer.ShallowTowerLayer(nums_shallow_tower_sparse_ftrs=self.nums_sparse_ftrs, 36 | num_classes=self.sparse_embedding_size, 37 | initializer=self.RangeInitializer()) 38 | outputs = layer(inputs) 39 | embedding_size = self.sparse_embedding_size 40 | self.assertAllEqual(tf.shape(outputs), [self.batch_size, embedding_size]) 41 | self.assertAllEqual(outputs, expected_result) 42 | 43 | 44 | if __name__ == '__main__': 45 | tf.test.main() 46 | -------------------------------------------------------------------------------- /test/detext/layers/test_sparse_embedding_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.layers import sparse_embedding_layer 4 | from detext.utils.parsing_utils import InputFtrType 5 | from detext.utils.testing.data_setup import DataSetup 6 | 7 | 8 | class TestSparseEmbeddingLayer(tf.test.TestCase, DataSetup): 9 | """Unit test for sparse_embedding_layer.py""" 10 | dense_tensor = [tf.constant([[0, 3]], dtype=tf.dtypes.float32), 11 | tf.constant([[4, 20]], dtype=tf.dtypes.float32)] 12 | sparse_tensor = [tf.sparse.from_dense(t) for t in dense_tensor] 13 | 14 | sparse_embedding_size = 10 15 | nums_sparse_ftrs = [2, 2] 16 | batch_size = tf.shape(dense_tensor[0])[0] 17 | 18 | class RangeInitializer(tf.keras.initializers.Initializer): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def __call__(self, shape, dtype=None, **kwargs): 23 | return tf.reshape(tf.range(0, tf.reduce_prod(shape), dtype=dtype), shape=shape) 24 | 25 | def testConcat(self): 26 | self._testConcatRangeInitializer() 27 | 
self._testConcatOnesInitializer() 28 | 29 | def _testConcatOnesInitializer(self): 30 | embedding = tf.ones([1, self.sparse_embedding_size], dtype=tf.dtypes.float32) 31 | expected_result_list = [tf.concat([embedding * tf.reduce_sum(self.dense_tensor[0][0]), 32 | embedding * tf.reduce_sum(self.dense_tensor[1][0])], axis=-1), 33 | tf.concat([embedding, embedding], axis=-1)] 34 | 35 | combiner_list = ['sum', 'mean'] 36 | for combiner, expected_result in zip(combiner_list, expected_result_list): 37 | self._testInput(combiner, 'concat', expected_result, 'ones') 38 | 39 | def _testConcatRangeInitializer(self): 40 | first_embedding = tf.range(10, dtype=tf.float32) 41 | second_embedding = tf.range(10, 20, dtype=tf.float32) 42 | 43 | expected_result = tf.expand_dims( 44 | tf.concat([first_embedding * self.dense_tensor[0][0][0], first_embedding * self.dense_tensor[1][0][0]], axis=-1) + 45 | tf.concat([second_embedding * self.dense_tensor[0][0][1], second_embedding * self.dense_tensor[1][0][1]], axis=-1), 46 | axis=0) 47 | self._testInput('sum', 'concat', expected_result, self.RangeInitializer()) 48 | 49 | def testSum(self): 50 | self._testSumRangeInitializer() 51 | self._testSumOnesInitializer() 52 | 53 | def _testSumRangeInitializer(self): 54 | first_embedding = tf.range(10, dtype=tf.float32) 55 | second_embedding = tf.range(10, 20, dtype=tf.float32) 56 | 57 | expected_result = tf.expand_dims(first_embedding * (self.dense_tensor[0][0][0] + self.dense_tensor[1][0][0]) + 58 | second_embedding * (self.dense_tensor[0][0][1] + self.dense_tensor[1][0][1]), axis=0) 59 | self._testInput('sum', 'sum', expected_result, self.RangeInitializer()) 60 | 61 | def _testSumOnesInitializer(self): 62 | expected_result_list = [tf.ones([1, self.sparse_embedding_size], dtype=tf.dtypes.float32) * tf.reduce_sum(self.dense_tensor), 63 | tf.ones([1, self.sparse_embedding_size], dtype=tf.dtypes.float32) * len(self.dense_tensor)] 64 | 65 | combiner_list = ['sum', 'mean'] 66 | for combiner, expected_result in zip(combiner_list, expected_result_list): 67 | self._testInput(combiner, 'sum', expected_result, 'ones') 68 | 69 | def _testInput(self, sparse_embedding_same_ftr_combiner, sparse_embedding_cross_ftr_combiner, expected_result, initializer): 70 | inputs = { 71 | InputFtrType.SPARSE_FTRS_COLUMN_NAMES: self.sparse_tensor 72 | } 73 | layer = sparse_embedding_layer.SparseEmbeddingLayer(sparse_embedding_size=self.sparse_embedding_size, 74 | nums_sparse_ftrs=self.nums_sparse_ftrs, 75 | initializer=initializer, sparse_embedding_cross_ftr_combiner=sparse_embedding_cross_ftr_combiner, 76 | sparse_embedding_same_ftr_combiner=sparse_embedding_same_ftr_combiner) 77 | outputs = layer(inputs) 78 | embedding_size = {'sum': self.sparse_embedding_size, 'concat': self.sparse_embedding_size * len(self.nums_sparse_ftrs)}[ 79 | sparse_embedding_cross_ftr_combiner] 80 | self.assertAllEqual(tf.shape(outputs), [self.batch_size, embedding_size]) 81 | self.assertAllEqual(outputs, expected_result) 82 | 83 | 84 | if __name__ == '__main__': 85 | tf.test.main() 86 | -------------------------------------------------------------------------------- /test/detext/layers/test_vocab_layer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import shutil 3 | 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | 7 | from detext.layers import vocab_layer 8 | from detext.utils.layer_utils import get_sorted_dict 9 | from detext.utils.parsing_utils import InternalFtrType 10 | from 
detext.utils.testing.data_setup import DataSetup 11 | 12 | 13 | class TestVocabLayer(tf.test.TestCase, DataSetup): 14 | num_cls_sep = 1 15 | sentences = tf.constant(['hello sent1', 'build build build build sent2']) 16 | inputs = get_sorted_dict({InternalFtrType.SENTENCES: sentences, 17 | InternalFtrType.NUM_CLS: tf.constant(num_cls_sep, dtype=tf.dtypes.int32), 18 | InternalFtrType.NUM_SEP: tf.constant(num_cls_sep, dtype=tf.dtypes.int32), 19 | InternalFtrType.MIN_LEN: tf.constant(DataSetup.min_len, dtype=tf.dtypes.int32), 20 | InternalFtrType.MAX_LEN: tf.constant(DataSetup.max_len, dtype=tf.dtypes.int32)}) 21 | 22 | def testAddClsSep(self): 23 | vocab_layer_param = copy.copy(self.vocab_layer_param) 24 | inputs = copy.copy(self.inputs) 25 | inputs['min_len'] = 6 26 | inputs['max_len'] = 7 27 | inputs['num_cls'] = 2 28 | inputs['num_sep'] = 2 29 | 30 | layer = vocab_layer.create_vocab_layer(vocab_layer_param, '') 31 | outputs = layer(inputs) 32 | 33 | self.assertAllEqual(outputs[InternalFtrType.TOKENIZED_IDS][0], 34 | tf.constant([self.CLS_ID, self.CLS_ID, self.UNK_ID, self.UNK_ID, self.SEP_ID, self.SEP_ID, self.PAD_ID])) 35 | 36 | def testAdjustLen(self): 37 | vocab_layer_param = copy.copy(self.vocab_layer_param) 38 | inputs = copy.copy(self.inputs) 39 | inputs['min_len'] = 12 40 | inputs['max_len'] = 16 41 | 42 | layer = vocab_layer.create_vocab_layer(vocab_layer_param, '') 43 | outputs = layer(inputs) 44 | shape = tf.shape(outputs[InternalFtrType.TOKENIZED_IDS]) 45 | self.assertAllEqual(shape, tf.constant([2, 12])) 46 | 47 | inputs['min_len'] = 0 48 | inputs['max_len'] = 1 49 | outputs = layer(inputs) 50 | shape = tf.shape(outputs[InternalFtrType.TOKENIZED_IDS]) 51 | self.assertAllEqual(shape, tf.constant([2, 1])) 52 | 53 | def testLength(self): 54 | vocab_layer_param = copy.copy(self.vocab_layer_param) 55 | inputs = copy.copy(self.inputs) 56 | inputs['min_len'] = 1 57 | inputs['max_len'] = 16 58 | inputs['num_cls'] = 0 59 | inputs['num_sep'] = 0 60 | 61 | layer = vocab_layer.create_vocab_layer(vocab_layer_param, '') 62 | outputs = layer(inputs) 63 | self.assertAllEqual(outputs[InternalFtrType.LENGTH], tf.constant([2, 5])) 64 | 65 | inputs['num_cls'] = 1 66 | inputs['num_sep'] = 1 67 | layer = vocab_layer.create_vocab_layer(vocab_layer_param, '') 68 | outputs = layer(inputs) 69 | self.assertAllEqual(outputs[InternalFtrType.LENGTH], tf.constant([4, 7])) 70 | 71 | def testVocabLayerApi(self): 72 | """Checks whether a given layer conforms to the DeText vocab layer API""" 73 | layer = hub.load(self.vocab_hub_url) 74 | layer: vocab_layer.VocabLayerBase 75 | 76 | self.assertEqual(layer.vocab_size(), self.vocab_size) 77 | self.assertEqual(layer.pad_id(), self.PAD_ID) 78 | 79 | inputs = self.inputs 80 | outputs = layer(inputs) 81 | expected_outputs = {InternalFtrType.LENGTH: tf.constant([4, 7]), 82 | InternalFtrType.TOKENIZED_IDS: tf.constant([[1, 0, 0, 2, 3, 3, 3], 83 | [1, 4, 4, 4, 4, 0, 2]])} 84 | 85 | for k, v in outputs.items(): 86 | self.assertAllEqual(v, expected_outputs[k]) 87 | 88 | def testCreateVocabLayer(self): 89 | for vocab_hub_url in ['', self.vocab_hub_url]: 90 | self._testCreateVocabLayer(vocab_hub_url) 91 | 92 | def _testCreateVocabLayer(self, vocab_hub_url): 93 | layer = vocab_layer.create_vocab_layer(self.vocab_layer_param, vocab_hub_url) 94 | outputs = layer(self.inputs) 95 | tf.saved_model.save(layer, self.vocab_layer_dir) 96 | 97 | loaded_layer = vocab_layer.create_vocab_layer(None, self.vocab_layer_dir) 98 | loaded_layer_outputs = loaded_layer(self.inputs) 99 | 100 | for k, 
v in outputs.items(): 101 | self.assertAllEqual(v, loaded_layer_outputs[k]) 102 | shutil.rmtree(self.vocab_layer_dir) 103 | 104 | 105 | if __name__ == '__main__': 106 | tf.test.main() 107 | -------------------------------------------------------------------------------- /test/detext/resources/bert-hub/assets/uncased_vocab.txt: -------------------------------------------------------------------------------- 1 | [UNK] 2 | [CLS] 3 | [SEP] 4 | [PAD] 5 | build 6 | word 7 | function 8 | able 9 | test 10 | -------------------------------------------------------------------------------- /test/detext/resources/bert-hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/bert-hub/saved_model.pb -------------------------------------------------------------------------------- /test/detext/resources/bert-hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/bert-hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/detext/resources/bert-hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/bert-hub/variables/variables.index -------------------------------------------------------------------------------- /test/detext/resources/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 64, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 30, 8 | "max_position_embeddings": 64, 9 | "num_attention_heads": 2, 10 | "num_hidden_layers": 2, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30 13 | } 14 | -------------------------------------------------------------------------------- /test/detext/resources/embedding_layer_hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/embedding_layer_hub/saved_model.pb -------------------------------------------------------------------------------- /test/detext/resources/embedding_layer_hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/embedding_layer_hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/detext/resources/embedding_layer_hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/embedding_layer_hub/variables/variables.index -------------------------------------------------------------------------------- /test/detext/resources/libert-sp-hub/assets/__tokenizer_type__: 
-------------------------------------------------------------------------------- 1 | sentencepiece -------------------------------------------------------------------------------- /test/detext/resources/libert-sp-hub/assets/spbpe.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-sp-hub/assets/spbpe.model -------------------------------------------------------------------------------- /test/detext/resources/libert-sp-hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-sp-hub/saved_model.pb -------------------------------------------------------------------------------- /test/detext/resources/libert-sp-hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-sp-hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/detext/resources/libert-sp-hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-sp-hub/variables/variables.index -------------------------------------------------------------------------------- /test/detext/resources/libert-space-hub/assets/__tokenizer_type__: -------------------------------------------------------------------------------- 1 | space -------------------------------------------------------------------------------- /test/detext/resources/libert-space-hub/assets/vocab.txt: -------------------------------------------------------------------------------- 1 | [UNK] 2 | [CLS] 3 | [SEP] 4 | [PAD] 5 | build 6 | word 7 | function 8 | able 9 | test -------------------------------------------------------------------------------- /test/detext/resources/libert-space-hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-space-hub/saved_model.pb -------------------------------------------------------------------------------- /test/detext/resources/libert-space-hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-space-hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/detext/resources/libert-space-hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/libert-space-hub/variables/variables.index -------------------------------------------------------------------------------- /test/detext/resources/multilingual_vocab.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/multilingual_vocab.txt.gz -------------------------------------------------------------------------------- /test/detext/resources/run_detext.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | PYTHONPATH=../../src python ../../src/detext/run_detext.py \ 3 | --ftr_ext cnn \ 4 | --label_column_name label \ 5 | --dense_ftrs_column_names wide_ftrs \ 6 | --doc_text_column_names doc_title \ 7 | --emb_sim_func inner concat diff \ 8 | --rescale_dense_ftrs True \ 9 | --ltr softmax \ 10 | --max_len 32 \ 11 | --min_len 3 \ 12 | --filter_window_sizes 1 2 3 \ 13 | --num_filters 50 \ 14 | --num_hidden 100 \ 15 | --num_train_steps 10 \ 16 | --num_units 60 \ 17 | --vocab_hub_url vocab_layer_hub \ 18 | --embedding_hub_url embedding_layer_hub \ 19 | --nums_dense_ftrs 2 \ 20 | --num_sparse_ftrs 100 \ 21 | --sparse_emb_size 10 \ 22 | --pmetric ndcg@10 \ 23 | --all_metrics ndcg@10 precision@1 \ 24 | --random_seed 123 \ 25 | --steps_per_stats 1 \ 26 | --steps_per_eval 2 \ 27 | --num_eval_rounds 10 \ 28 | --test_batch_size 2 \ 29 | --train_batch_size 2 \ 30 | --num_hidden 100 \ 31 | --explicit_allreduce True \ 32 | --learning_rate 0.01 \ 33 | --optimizer adamw \ 34 | --dev_file sample_data/hc_examples.tfrecord \ 35 | --test_file sample_data/hc_examples.tfrecord \ 36 | --train_file sample_data/hc_examples.tfrecord \ 37 | --vocab_file vocab.txt \ 38 | --distribution_strategy one_device \ 39 | --num_gpu 0 \ 40 | --out_dir /tmp/detext-output/hc_cnn_f50_u32_h100 41 | -------------------------------------------------------------------------------- /test/detext/resources/sample_data/hc_examples.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/sample_data/hc_examples.tfrecord -------------------------------------------------------------------------------- /test/detext/resources/train/binary_classification/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/train/binary_classification/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /test/detext/resources/train/classification/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/train/classification/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /test/detext/resources/train/dataset/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/train/dataset/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /test/detext/resources/train/multitask/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/train/multitask/tfrecord/test.tfrecord 
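The train/*/tfrecord/test.tfrecord entries above are binary test fixtures read by tests such as test_train_flow_helper.py via data_fn.input_fn_tfrecord. A minimal sketch for inspecting one of them locally (not part of the repository; it assumes TensorFlow 2.x is installed, the command is run from the repository root, and the chosen path is one of the fixture files listed above):

import tensorflow as tf

# Hypothetical path to one of the fixture files listed above
path = 'test/detext/resources/train/multitask/tfrecord/test.tfrecord'

dataset = tf.data.TFRecordDataset(path)
for raw_record in dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())  # decode the serialized tf.train.Example
    print(example)  # shows the feature names and values of the first record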
-------------------------------------------------------------------------------- /test/detext/resources/train/ranking/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/train/ranking/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /test/detext/resources/vocab.txt: -------------------------------------------------------------------------------- 1 | [UNK] 2 | [CLS] 3 | [SEP] 4 | [PAD] 5 | build 6 | word 7 | function 8 | able 9 | test 10 | -------------------------------------------------------------------------------- /test/detext/resources/vocab_layer_hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/vocab_layer_hub/saved_model.pb -------------------------------------------------------------------------------- /test/detext/resources/vocab_layer_hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/vocab_layer_hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/detext/resources/vocab_layer_hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/vocab_layer_hub/variables/variables.index -------------------------------------------------------------------------------- /test/detext/resources/we.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/resources/we.pkl -------------------------------------------------------------------------------- /test/detext/test_metaclass.py: -------------------------------------------------------------------------------- 1 | from detext.metaclass import SingletonMeta 2 | 3 | 4 | class Dummy(metaclass=SingletonMeta): 5 | pass 6 | 7 | 8 | def test_singleton(): 9 | p = Dummy() 10 | q = Dummy() 11 | assert id(p) == id(q) 12 | -------------------------------------------------------------------------------- /test/detext/test_tf2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class TestTF2(tf.test.TestCase): 5 | """Test whether TF2 is set up successfully""" 6 | 7 | def test_eager_execution(self): 8 | self.assertTrue(tf.executing_eagerly()) 9 | self.assertTrue(tf.__version__.startswith('2.')) 10 | 11 | 12 | if __name__ == "__main__": 13 | tf.test.main() 14 | -------------------------------------------------------------------------------- /test/detext/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/train/__init__.py -------------------------------------------------------------------------------- /test/detext/train/test_optimization.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from detext.train import optimization 4 | from detext.train.optimization import BERT_VAR_PREFIX 5 | 6 | 7 | class TestOptimization(tf.test.TestCase): 8 | @classmethod 9 | def _get_loss(cls, x, y_true, model, loss_obj): 10 | with tf.GradientTape() as tape: 11 | y_pred = model(x) 12 | loss = loss_obj(y_true, y_pred) 13 | return loss, tape 14 | 15 | @classmethod 16 | def _minimize(cls, x, y_true, model, loss_obj, optimizer): 17 | loss, tape = cls._get_loss(x, y_true, model, loss_obj) 18 | 19 | train_vars = model.trainable_variables 20 | grads = tape.gradient(loss, train_vars) 21 | grads_and_vars = zip(grads, train_vars) 22 | optimizer.apply_gradients(grads_and_vars) 23 | return loss 24 | 25 | def _train_linear_model(self, x, y_true, init_lr, num_train_steps, num_warmup_steps, process_grads_and_vars_fn): 26 | """Helper function to train a linear model""" 27 | optimizer = optimization.create_optimizer(init_lr=init_lr, 28 | num_train_steps=num_train_steps, 29 | num_warmup_steps=num_warmup_steps, 30 | optimizer='sgd', 31 | use_lr_schedule=True, 32 | use_bias_correction_for_adamw=False) 33 | 34 | model = tf.keras.Sequential(tf.keras.layers.Dense(1, use_bias=False, kernel_initializer=tf.keras.initializers.zeros())) 35 | loss_obj = tf.keras.losses.MeanSquaredError() 36 | 37 | for _ in range(2): 38 | loss, tape = self._get_loss(x, y_true, model, loss_obj) 39 | grads_and_vars = process_grads_and_vars_fn(tape, optimizer, loss, model.trainable_variables, []) 40 | optimizer.apply_gradients(grads_and_vars) 41 | return model 42 | 43 | def testProcessGradsAndVars(self): 44 | """Tests process_grads_and_vars with/without explicit allreduce""" 45 | init_lr = 0.05 46 | num_train_steps = 10 47 | num_warmup_steps = 3 48 | 49 | x = tf.constant([[0.1, 0.2], [0.3, 0.1]], dtype=tf.float32) 50 | y_true = x[:, 0] + x[:, 1] 51 | 52 | model_explicit_allreduce = self._train_linear_model(x, y_true, init_lr, num_train_steps, num_warmup_steps, 53 | optimization.process_grads_and_vars_using_explicit_allreduce) 54 | model_implicit_allreduce = self._train_linear_model(x, y_true, init_lr, num_train_steps, num_warmup_steps, 55 | optimization.process_grads_and_vars_without_explicit_allreduce) 56 | 57 | self.assertAllEqual([x.numpy() for x in model_explicit_allreduce.trainable_variables], 58 | [x.numpy() for x in model_implicit_allreduce.trainable_variables]) 59 | 60 | def testSplitBertGradsAndVars(self): 61 | """ Tests split_bert_grads_and_vars() """ 62 | grads = [tf.constant(3.0, dtype='float32'), tf.constant(4.0, dtype='float32')] 63 | variables = [tf.Variable(1.0, name='rep_model/' + BERT_VAR_PREFIX + '_model/some_var'), tf.Variable(1.0, name='some_var')] 64 | bert_grads_and_vars, non_bert_grads_and_vars = optimization.split_bert_grads_and_vars(zip(grads, variables)) 65 | 66 | self.assertAllEqual(bert_grads_and_vars, [(grads[0], variables[0])]) 67 | self.assertAllEqual(non_bert_grads_and_vars, [(grads[1], variables[1])]) 68 | 69 | def testClipByGlobalNorm(self): 70 | """ Tests clip_by_global_norm() """ 71 | grads = [tf.constant(3.0, dtype='float32'), tf.constant(4.0, dtype='float32')] 72 | variables = [tf.Variable(1.0), tf.Variable(1.0)] 73 | 74 | clip_norm_lst = [1.0, 50.0] 75 | expected_grads_lst = [[g / 5 for g in grads], [g for g in grads]] 76 | assert len(clip_norm_lst) == len(expected_grads_lst) 77 | 78 | for clip_norm, expected_grads in zip(clip_norm_lst, expected_grads_lst): 79 | 
self._testClipByGlobalNorm(grads, variables, clip_norm, expected_grads) 80 | 81 | def _testClipByGlobalNorm(self, grads, variables, clip_norm, expected_grads): 82 | """ Tests clip_by_global_norm() given clip norm""" 83 | grads_and_vars = optimization.clip_by_global_norm(zip(grads, variables), clip_norm) 84 | result_grad = [x[0] for x in grads_and_vars] 85 | self.assertAllEqual(result_grad, expected_grads) 86 | 87 | def testCreateOptimizer(self): 88 | """ Tests create_optimizer() """ 89 | init_lr = 0.05 90 | num_train_steps = 10 91 | num_warmup_steps = 3 92 | num_bp_steps = 5 93 | 94 | x = tf.constant([[0.1, 0.2], [0.3, 0.1]], dtype=tf.float32) 95 | y_true = x[:, 0] + x[:, 1] 96 | 97 | for optimizer_type in ['sgd', 'adam', 'adamw', 'lamb']: 98 | optimizer = optimization.create_optimizer(init_lr=init_lr, 99 | num_train_steps=num_train_steps, 100 | num_warmup_steps=num_warmup_steps, 101 | optimizer=optimizer_type, 102 | use_lr_schedule=True, 103 | use_bias_correction_for_adamw=False) 104 | 105 | model = tf.keras.Sequential(tf.keras.layers.Dense(1, use_bias=False, kernel_initializer=tf.keras.initializers.zeros())) 106 | loss_obj = tf.keras.losses.MeanSquaredError() 107 | 108 | prev_loss = self._minimize(x, y_true, model, loss_obj, optimizer).numpy() 109 | prev_lr = optimizer._decayed_lr('float32').numpy() 110 | for step in range(1, num_bp_steps): 111 | loss = self._minimize(x, y_true, model, loss_obj, optimizer).numpy() 112 | 113 | # When warm up steps > 0, lr will be 0 when calculating prev_loss and therefore no backprop will be executed 114 | # This will cause loss_at_step_0 = prev_loss 115 | if step > 1: 116 | self.assertLess(loss, prev_loss, f"Loss should be declining at each step. Step:{step}") 117 | 118 | # Learning rate check 119 | lr = optimizer._decayed_lr('float32').numpy() 120 | if step < num_warmup_steps: 121 | self.assertGreater(lr, prev_lr, f"Learning rate should be increasing during warm up. Step:{step}") 122 | else: 123 | self.assertLess(lr, prev_lr, f"Learning rate should be decreasing after warm up. 
Step:{step}") 124 | 125 | prev_loss = loss 126 | prev_lr = lr 127 | 128 | 129 | if __name__ == '__main__': 130 | tf.test.main() 131 | -------------------------------------------------------------------------------- /test/detext/train/test_train_flow_helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import tensorflow as tf 4 | from detext.train import data_fn, model, train_flow_helper 5 | from detext.train.constant import Constant 6 | from detext.utils.parsing_utils import HParams, InputFtrType, TaskType 7 | from detext.utils.testing.data_setup import DataSetup 8 | 9 | 10 | class TestTrainFlowHelper(tf.test.TestCase, DataSetup): 11 | """Unit test for train_flow_helper.py""" 12 | batch_size = 2 13 | nums_dense_ftrs = [3] 14 | task_type = TaskType.RANKING 15 | 16 | text_encoder_param = copy.copy(DataSetup.cnn_param) 17 | text_encoder_param.num_doc_fields = 1 18 | text_encoder_param.num_user_fields = 3 19 | 20 | id_encoder_param = copy.copy(DataSetup.id_encoder_param) 21 | id_encoder_param.num_id_fields = 2 22 | 23 | rep_layer_param = copy.copy(DataSetup.rep_layer_param) 24 | rep_layer_param.text_encoder_param = text_encoder_param 25 | rep_layer_param.id_encoder_param = id_encoder_param 26 | rep_layer_param.num_doc_fields = 1 27 | rep_layer_param.num_user_fields = 3 28 | 29 | rep_layer_param.num_doc_id_fields = 1 30 | rep_layer_param.num_user_id_fields = 1 31 | 32 | feature_type2name = {InputFtrType.QUERY_COLUMN_NAME: 'query', 33 | InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_completedQuery'], 34 | InputFtrType.DOC_ID_COLUMN_NAMES: ['docId_completedQuery'], 35 | InputFtrType.USER_TEXT_COLUMN_NAMES: ['usr_headline', 'usr_skills', 'usr_currTitles'], 36 | InputFtrType.USER_ID_COLUMN_NAMES: ['usrId_currTitles'], 37 | InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['wide_ftrs'], 38 | InputFtrType.LABEL_COLUMN_NAME: 'label', 39 | InputFtrType.WEIGHT_COLUMN_NAME: 'weight'} 40 | feature_name2num = {'wide_ftrs': nums_dense_ftrs[0]} 41 | 42 | deep_match_param = HParams(feature_name2num=feature_name2num, 43 | use_dense_ftrs=True, 44 | use_deep=True, 45 | has_query=True, 46 | use_sparse_ftrs=False, 47 | sparse_embedding_cross_ftr_combiner='concat', 48 | sparse_embedding_same_ftr_combiner='sum', 49 | sparse_embedding_size=10, 50 | emb_sim_func=['inner'], 51 | rep_layer_param=rep_layer_param, 52 | ftr_mean=None, ftr_std=None, 53 | num_hidden=[3], 54 | rescale_dense_ftrs=False, 55 | num_classes=1, 56 | task_ids=None) 57 | 58 | def testPredict(self): 59 | """Tests predict()""" 60 | dataset = data_fn.input_fn_tfrecord(input_pattern=self.data_dir, 61 | batch_size=self.batch_size, 62 | mode=tf.estimator.ModeKeys.EVAL, 63 | feature_type2name=self.feature_type2name, 64 | feature_name2num=self.feature_name2num, 65 | input_pipeline_context=None, 66 | ) 67 | 68 | detext_model = model.create_detext_model(self.feature_type2name, task_type=self.task_type, **self.deep_match_param) 69 | predicted_output = train_flow_helper.predict_with_additional_info(dataset, detext_model, self.feature_type2name) 70 | 71 | for output in predicted_output: 72 | for key in [train_flow_helper._SCORES, self.feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, Constant()._DEFAULT_WEIGHT_FTR_NAME), 73 | self.feature_type2name.get(InputFtrType.UID_COLUMN_NAME, Constant()._DEFAULT_UID_FTR_NAME), 74 | self.feature_type2name[InputFtrType.LABEL_COLUMN_NAME]]: 75 | self.assertIn(key, output) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | 
-------------------------------------------------------------------------------- /test/detext/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/detext/utils/__init__.py -------------------------------------------------------------------------------- /test/detext/utils/test_parsing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import tensorflow as tf 5 | from detext.utils import parsing_utils 6 | from detext.utils.parsing_utils import InputFtrType 7 | 8 | root_dir = os.path.join(os.path.dirname(__file__), "../resources") 9 | test_tfrecord_path = os.path.join(root_dir, 'train', 'dataset', 'tfrecord', 'test.tfrecord') 10 | 11 | 12 | class TestParsingUtils(tf.test.TestCase): 13 | atol = 1e-3 14 | 15 | def testIterateItemsWithListVal(self): 16 | """Tests iterate_items_with_list_val""" 17 | dct_lst = [{'a': 'a'}, 18 | {'a': ['a']}] 19 | expected_result_lst = [[('a', ['a'])], 20 | [('a', ['a'])]] 21 | assert len(dct_lst) == len(expected_result_lst), 'Number of test data and result must match' 22 | 23 | for dct, expected_result in zip(dct_lst, expected_result_lst): 24 | self.assertCountEqual(expected_result, list(parsing_utils.iterate_items_with_list_val(dct))) 25 | 26 | def testGetFeatureNames(self): 27 | """Tests get_feature_names() """ 28 | self.assertCountEqual( 29 | [InputFtrType.QUERY_COLUMN_NAME, InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.TASK_ID_COLUMN_NAME, 30 | InputFtrType.UID_COLUMN_NAME, InputFtrType.DOC_TEXT_COLUMN_NAMES, 31 | InputFtrType.USER_TEXT_COLUMN_NAMES, InputFtrType.DOC_ID_COLUMN_NAMES, InputFtrType.USER_ID_COLUMN_NAMES, 32 | InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES, 33 | InputFtrType.DENSE_FTRS_COLUMN_NAMES, InputFtrType.SPARSE_FTRS_COLUMN_NAMES], 34 | parsing_utils.get_feature_types()) 35 | 36 | def testHparamsLoadAndSave(self): 37 | """Tests loading and saving of hparams""" 38 | hparams = parsing_utils.HParams(a=1, b=2, c=[1, 2, 3]) 39 | parsing_utils.save_hparams(root_dir, hparams) 40 | loaded_hparams = parsing_utils.load_hparams(root_dir) 41 | self.assertEqual(hparams, loaded_hparams) 42 | os.remove(parsing_utils._get_hparam_path(root_dir)) 43 | 44 | def testComputeFtrMeanStd(self): 45 | """Tests compute_ftr_mean_std() """ 46 | true_mean = [3310.9523809523807, 0.05952380952380952, 5.440476190476191] 47 | true_std = [7346.840887180146, 0.23660246326609274, 3.1028066805800445] 48 | 49 | output_file = os.path.join(root_dir, 'tmp_ftr_mean_std_output') 50 | mean, std = parsing_utils.compute_mean_std(test_tfrecord_path, output_file, 3) 51 | os.remove(output_file) 52 | 53 | self.assertAllClose(mean, true_mean, atol=self.atol) 54 | self.assertAllClose(std, true_std, atol=self.atol) 55 | 56 | def testLoadFtrMeanStd(self): 57 | """Tests load_ftr_mean_std() """ 58 | true_mean = [1, 2, 3, 4] 59 | true_std = [3, 2, 1, 3] 60 | 61 | # Test file generated from spark 62 | separator = "," 63 | filepath_spark = os.path.join(root_dir, "tmp_ftr_mean_std.fromspark") 64 | with open(filepath_spark, 'w') as fout: 65 | fout.write("# Feature mean std file. 
Mean is first line, std is second line\n") 66 | fout.write(separator.join([str(x) for x in true_mean]) + "\n") 67 | fout.write(separator.join([str(x) for x in true_std]) + "\n") 68 | 69 | ftr_mean, ftr_std = parsing_utils.load_ftr_mean_std(filepath_spark) 70 | self.assertEqual(ftr_mean, true_mean) 71 | self.assertEqual(ftr_std, true_std) 72 | os.remove(filepath_spark) 73 | 74 | # Test pickle file 75 | filepath_pkl = os.path.join(root_dir, "tmp_ftr_mean_std.pkl") 76 | with tf.compat.v1.gfile.Open(filepath_pkl, 'wb') as fout: 77 | pickle.dump((true_mean, true_std), fout, protocol=2) 78 | ftr_mean, ftr_std = parsing_utils.load_ftr_mean_std(filepath_pkl) 79 | self.assertEqual(ftr_mean, true_mean) 80 | self.assertEqual(ftr_std, true_std) 81 | os.remove(filepath_pkl) 82 | 83 | def testEstimateStepsPerEpoch(self): 84 | """Tests estimate_steps_per_epoch() """ 85 | num_record = parsing_utils.estimate_steps_per_epoch(test_tfrecord_path, 1) 86 | self.assertEqual(num_record, 10) 87 | 88 | def testGetNumFields(self): 89 | """Tests get_num_fields() """ 90 | num_fields = parsing_utils.get_num_fields('doc_', ['doc_headline', 'docId_title']) 91 | self.assertEqual(num_fields, 1) 92 | 93 | 94 | if __name__ == '__main__': 95 | tf.test.main() 96 | -------------------------------------------------------------------------------- /test/smart_compose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/__init__.py -------------------------------------------------------------------------------- /test/smart_compose/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/layers/__init__.py -------------------------------------------------------------------------------- /test/smart_compose/layers/test_embedding_layer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import shutil 3 | 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | 7 | from smart_compose.layers import embedding_layer 8 | from smart_compose.utils.layer_utils import get_sorted_dict 9 | from smart_compose.utils.parsing_utils import InternalFtrType 10 | from smart_compose.utils.testing.data_setup import DataSetup 11 | from smart_compose.utils.testing.test_case import TestCase 12 | 13 | 14 | class TestEmbeddingLayer(TestCase): 15 | """Tests embedding_layer.py""" 16 | num_cls = 0 17 | num_sep = 0 18 | 19 | min_len = 0 20 | max_len = 5 21 | 22 | sentences = tf.constant(['hello sent1', 'build sent2']) 23 | inputs = get_sorted_dict({'sentences': sentences, 24 | 'num_cls': tf.constant(num_cls, dtype=tf.dtypes.int32), 25 | 'num_sep': tf.constant(num_sep, dtype=tf.dtypes.int32), 26 | 'min_len': tf.constant(min_len, dtype=tf.dtypes.int32), 27 | 'max_len': tf.constant(max_len, dtype=tf.dtypes.int32)}) 28 | 29 | embedding_layer_param = {'vocab_layer_param': DataSetup.vocab_layer_param, 30 | 'vocab_hub_url': '', 31 | 'we_file': '', 32 | 'we_trainable': True, 33 | 'num_units': DataSetup.num_units} 34 | 35 | def testCreateEmbeddingLayer(self): 36 | """Tests create_embedding_layer() """ 37 | for vocab_hub_url in ['', self.vocab_hub_url]: 38 | embedding_layer_param = copy.copy(self.embedding_layer_param) 39 | embedding_layer_param['vocab_hub_url'] = vocab_hub_url 40 | 
self._testCreateEmbeddingLayer('', embedding_layer_param) 41 | 42 | embedding_layer_param = copy.copy(self.embedding_layer_param) 43 | embedding_layer_param['we_file'] = self.we_file 44 | self._testCreateEmbeddingLayer('', embedding_layer_param) 45 | 46 | embedding_layer_param = copy.copy(self.embedding_layer_param) 47 | self._testCreateEmbeddingLayer(self.embedding_hub_url, embedding_layer_param) 48 | 49 | def _testCreateEmbeddingLayer(self, embedding_hub_url, embedding_layer_param): 50 | layer = embedding_layer.create_embedding_layer(embedding_layer_param, embedding_hub_url) 51 | outputs = layer(self.inputs) 52 | 53 | tf.saved_model.save(layer, self.embedding_layer_dir) 54 | 55 | loaded_layer = embedding_layer.create_embedding_layer(embedding_layer_param, self.embedding_layer_dir) 56 | loaded_layer_outputs = loaded_layer(self.inputs) 57 | 58 | for k, v in outputs.items(): 59 | self.assertAllEqual(v, loaded_layer_outputs[k]) 60 | 61 | shutil.rmtree(self.embedding_layer_dir) 62 | 63 | def testEmbeddingLayerApi(self): 64 | """Checks whether a given layer conforms to the smart compose embedding layer api""" 65 | layer = hub.load(self.embedding_hub_url) 66 | layer: embedding_layer.EmbeddingLayerBase 67 | 68 | self.assertEqual(layer.num_units(), self.num_units) 69 | self.assertEqual(layer.vocab_size(), self.vocab_size) 70 | self.assertEqual(layer.sep_id(), self.SEP_ID) 71 | 72 | tokenized = layer.tokenize_to_indices(self.inputs) 73 | expected_tokenized = {InternalFtrType.LENGTH: tf.constant([2, 2]), 74 | InternalFtrType.TOKENIZED_IDS: tf.constant([[0, 0], 75 | [4, 0]])} 76 | for k, v in tokenized.items(): 77 | self.assertAllEqual(v, expected_tokenized[k]) 78 | 79 | tokenized_result = tf.constant([[1, 2], [0, 1]]) 80 | tokenized_result_shape = tf.shape(tokenized_result) 81 | embedding_lookup_result = layer.embedding_lookup(tokenized_result) 82 | self.assertAllEqual(tf.shape(embedding_lookup_result), [tokenized_result_shape[0], tokenized_result_shape[1], layer.num_units()]) 83 | 84 | outputs = layer(self.inputs) 85 | self.assertAllEqual(tf.shape(outputs[InternalFtrType.EMBEDDED]), [tokenized_result_shape[0], tokenized_result_shape[1], layer.num_units()]) 86 | self.assertAllEqual(outputs[InternalFtrType.LENGTH], tf.constant([2, 2])) 87 | 88 | 89 | if __name__ == '__main__': 90 | tf.test.main() 91 | -------------------------------------------------------------------------------- /test/smart_compose/layers/test_prefix_search.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from smart_compose.layers import prefix_search 4 | from smart_compose.layers import vocab_layer 5 | from smart_compose.utils.parsing_utils import InternalFtrType 6 | from smart_compose.utils.testing.test_case import TestCase 7 | 8 | 9 | class TestPrefixSearch(TestCase): 10 | """ Unit test for prefix_search.py """ 11 | min_len = 0 12 | max_len = 3 13 | num_cls = 1 14 | num_sep = 0 15 | searcher = prefix_search.PrefixSearcher(vocab_layer.create_vocab_layer(TestCase.vocab_layer_param, ''), 16 | min_len, max_len, num_cls, num_sep) 17 | 18 | def testPrefixSearch(self): 19 | prefix_list = [ 20 | tf.constant('b'), tf.constant('build s'), tf.constant('h'), tf.constant(''), tf.constant('b '), tf.constant(' ') 21 | ] 22 | exist_prefix_list = [ 23 | True, True, False, False, True, True 24 | ] 25 | vocab_mask_list = [ 26 | [False] * 4 + [True] + [False] * 11, 27 | [False] * 12 + [True, True] + [False] * 2, 28 | [False] * 16, 29 | [False] * 16, 30 | [True] * 16, 31 | 
[True] * 16 32 | ] 33 | length_list = [1, 2, 1, 0, 2, 1] 34 | 35 | assert len(prefix_list) == len(exist_prefix_list) == len(vocab_mask_list) == len(length_list), 'Test input list must have the same size' 36 | for prefix, exist_prefix, vocab_mask, length in zip(prefix_list, exist_prefix_list, vocab_mask_list, length_list): 37 | self._testPrefixSearch(prefix, exist_prefix, vocab_mask, length) 38 | 39 | def _testPrefixSearch(self, prefix, exist_prefix, vocab_mask, length): 40 | outputs = self.searcher(prefix) 41 | 42 | self.assertAllEqual(outputs[InternalFtrType.EXIST_PREFIX], exist_prefix) 43 | self.assertAllEqual(outputs[InternalFtrType.COMPLETION_VOCAB_MASK], vocab_mask) 44 | self.assertAllEqual(outputs[InternalFtrType.LENGTH], length) 45 | 46 | def testKeyValueArrayDict(self): 47 | keys_list = [[1, 2, 3], 48 | [1, 2, 3], 49 | ['10', '2', '3']] 50 | values_list = [[[2, 3, 4], [5, 0], [1]], 51 | [[2, 3, 4], [5, 0], [1]], 52 | [['2', '3', '4'], ['5', '0'], ['1']]] 53 | test_key_list = [2, -1, '2'] 54 | default_values = [-1, -1, ""] 55 | expected_value_list = [ 56 | tf.convert_to_tensor( 57 | [0, default_values[0], default_values[0], default_values[0], default_values[0], 5], 58 | dtype='int32' 59 | ), 60 | tf.convert_to_tensor( 61 | [0, default_values[1], default_values[1], default_values[1], default_values[1], 5], 62 | dtype='int32' 63 | ), 64 | tf.convert_to_tensor( 65 | ['0', default_values[2], default_values[2], default_values[2], default_values[2], '5'], 66 | dtype='string' 67 | ) 68 | ] 69 | key_type_list = ['int32', 'int32', 'string'] 70 | exist_prefix_list = [tf.convert_to_tensor(True), tf.convert_to_tensor(False), tf.convert_to_tensor(True)] 71 | for keys, values, test_key, expected_value, default_value, key_type, exist_prefix in zip( 72 | keys_list, values_list, test_key_list, expected_value_list, default_values, key_type_list, exist_prefix_list): 73 | self._testKeyValueArrayDict(keys, values, test_key, expected_value, default_value, key_type, exist_prefix) 74 | 75 | def _testKeyValueArrayDict(self, keys, values, test_key, expected_value, default_value, key_type, exist_prefix): 76 | table = prefix_search.KeyValueArrayDict(keys, values) 77 | outputs = table.lookup(tf.convert_to_tensor(test_key, dtype=key_type)) 78 | self.assertAllEqual(outputs[InternalFtrType.EXIST_KEY], exist_prefix) 79 | if exist_prefix.numpy(): 80 | self.assertAllEqual(tf.sparse.to_dense(outputs[InternalFtrType.COMPLETION_INDICES], default_value=default_value), expected_value) 81 | 82 | 83 | if __name__ == '__main__': 84 | tf.test.main() 85 | -------------------------------------------------------------------------------- /test/smart_compose/resources/embedding_layer_hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/embedding_layer_hub/saved_model.pb -------------------------------------------------------------------------------- /test/smart_compose/resources/embedding_layer_hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/embedding_layer_hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/smart_compose/resources/embedding_layer_hub/variables/variables.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/embedding_layer_hub/variables/variables.index -------------------------------------------------------------------------------- /test/smart_compose/resources/train/dataset/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/train/dataset/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /test/smart_compose/resources/vocab.txt: -------------------------------------------------------------------------------- 1 | [UNK] 2 | [CLS] 3 | [SEP] 4 | [PAD] 5 | build 6 | word 7 | function 8 | able 9 | test 10 | this 11 | is 12 | a 13 | source 14 | sentence 15 | target 16 | token 17 | -------------------------------------------------------------------------------- /test/smart_compose/resources/vocab_layer_hub/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/vocab_layer_hub/saved_model.pb -------------------------------------------------------------------------------- /test/smart_compose/resources/vocab_layer_hub/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/vocab_layer_hub/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /test/smart_compose/resources/vocab_layer_hub/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/resources/vocab_layer_hub/variables/variables.index -------------------------------------------------------------------------------- /test/smart_compose/test_run_smart_compose.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tensorflow as tf 4 | 5 | from smart_compose.run_smart_compose import main 6 | from smart_compose.utils.parsing_utils import InputFtrType 7 | from smart_compose.utils.testing.test_case import TestCase 8 | 9 | 10 | class TestRun(TestCase): 11 | target_column_name = 'query' 12 | 13 | dataset_args = [ 14 | "--test_file", TestCase.data_dir, 15 | "--dev_file", TestCase.data_dir, 16 | "--train_file", TestCase.data_dir, 17 | "--out_dir", TestCase.out_dir, 18 | 19 | "--num_train_steps", "300", 20 | "--steps_per_stats", "100", 21 | "--steps_per_eval", "100", 22 | 23 | "--train_batch_size", "10", 24 | "--test_batch_size", "10", 25 | 26 | "--resume_training", "False", 27 | "--distribution_strategy", "one_device", 28 | "--num_gpu", "0", 29 | "--run_eagerly", "True", 30 | ] 31 | 32 | optimization_args = [ 33 | "--learning_rate", "0.01", 34 | "--optimizer", "adamw", 35 | "--pmetric", "perplexity" 36 | ] 37 | 38 | network_args = [ 39 | "--num_units", "30", 40 | ] 41 | 42 | feature_args = [ 43 | f"--{InputFtrType.TARGET_COLUMN_NAME}", target_column_name, 44 | "--vocab_file", TestCase.vocab_file, 45 | 46 | 
"--max_len", "16", 47 | "--min_len", "3", 48 | ] 49 | 50 | args = dataset_args + feature_args + network_args + optimization_args 51 | 52 | def testRunSmartCompose(self): 53 | sys.argv[1:] = self.args 54 | model = main(sys.argv) 55 | 56 | # {'exist_prefix': True, 57 | # 'predicted_scores': 58 | # [[-9.4395971e+00, -1.0921703e+01, -1.1221247e+01, -1.1409840e+01, 59 | # -1.3343969e+01, -1.7302521e+01, -1.0000000e+07, -1.0000000e+07, 60 | # -1.0000000e+07, -1.0000000e+07]], 61 | # 'predicted_texts': 62 | # [[b'this is a [SEP]', b'this is [UNK] [SEP]', 63 | # b'this is sentence [SEP]', b'this is [SEP] [PAD]', 64 | # b'this is function [SEP]', b'this is [PAD] [PAD]', 65 | # b'[PAD] [PAD] [PAD] [PAD]', b'[PAD] [PAD] [PAD] [PAD]', 66 | # b'[PAD] [PAD] [PAD] [PAD]', b'[PAD] [PAD] [PAD] [PAD]']], 67 | # } 68 | print(model({self.target_column_name: ["this is"]})) 69 | # {'exist_prefix': True, 70 | # 'predicted_scores': 71 | # [[-3.75747681e-04, -9.05515480e+00, -1.13139982e+01, 72 | # -1.16390209e+01, -1.19552555e+01, -1.38605614e+01, 73 | # -1.38952370e+01, -1.43704872e+01, -1.00000000e+07, 74 | # -1.00000000e+07]], 75 | # 'predicted_texts': 76 | # [[b'[CLS] word function [SEP]', b'[CLS] word [SEP] [PAD]', 77 | # b'[CLS] word [UNK] [SEP]', b'[CLS] word is [SEP]', 78 | # b'[CLS] word sentence [SEP]', b'[CLS] word source [SEP]', 79 | # b'[CLS] word build [SEP]', b'[CLS] word [PAD] [PAD]', 80 | # b'[PAD] [PAD] [PAD] [PAD]', b'[PAD] [PAD] [PAD] [PAD]']], 81 | # } 82 | print(model({self.target_column_name: ["word"]})) 83 | # {'exist_prefix': True, 84 | # 'predicted_scores': 85 | # [[-3.7574768e-04, -9.0551538e+00, -1.0219677e+01, -1.0755968e+01, 86 | # -1.0904562e+01, -1.1283651e+01, -1.1313998e+01, -1.1424127e+01, 87 | # -1.1639020e+01, -1.1955254e+01]], 88 | # 'predicted_texts': 89 | # [[b'word function [SEP] [PAD]', b'word [SEP] [PAD] [PAD]', 90 | # b'word function function [SEP]', b'word function sentence [SEP]', 91 | # b'word is function [SEP]', b'word word function [SEP]', 92 | # b'word [UNK] [SEP] [PAD]', b'word function [UNK] [SEP]', 93 | # b'word is [SEP] [PAD]', b'word sentence [SEP] [PAD]']], 94 | # } 95 | print(model({self.target_column_name: ["word "]})) 96 | # {'exist_prefix': True, 97 | # 'predicted_scores': 98 | # [[-1.3312817e-02, -6.3512154e+00, -6.8055477e+00, -7.4862165e+00, 99 | # -1.3018910e+01, -1.3128208e+01, -1.3179526e+01, -1.3493284e+01, 100 | # -1.3816385e+01, -1.3923851e+01]], 101 | # 'predicted_texts': 102 | # [[b'word build [SEP] [PAD]', b'word build [UNK] [SEP]', 103 | # b'word build function [SEP]', b'word build sentence [SEP]', 104 | # b'word build a [SEP]', b'word build test [SEP]', 105 | # b'word build word [SEP]', b'word build [PAD] [PAD]', 106 | # b'word build build [SEP]', b'word build is [SEP]']] 107 | # } 108 | print(model({self.target_column_name: ["word b"]})) 109 | 110 | self._cleanUp(TestCase.out_dir) 111 | 112 | 113 | if __name__ == '__main__': 114 | tf.test.main() 115 | -------------------------------------------------------------------------------- /test/smart_compose/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/train/__init__.py -------------------------------------------------------------------------------- /test/smart_compose/train/test_data_fn.py: -------------------------------------------------------------------------------- 1 | from os.path import join as path_join 2 | 3 | import 
tensorflow as tf 4 | from official.utils.misc import distribution_utils 5 | 6 | from smart_compose.train import data_fn 7 | from smart_compose.utils import vocab_utils 8 | from smart_compose.utils.parsing_utils import InputFtrType, iterate_items_with_list_val 9 | from smart_compose.utils.testing.test_case import TestCase 10 | 11 | 12 | class TestData(TestCase): 13 | """Unit test for data_fn.py""" 14 | _, vocab_tf_table = vocab_utils.read_tf_vocab(TestCase.vocab_file, '[UNK]') 15 | vocab_table = TestCase.vocab_table_py 16 | 17 | CLS = '[CLS]' 18 | PAD = '[PAD]' 19 | SEP = '[SEP]' 20 | 21 | PAD_ID = vocab_table[PAD] 22 | SEP_ID = vocab_table[SEP] 23 | CLS_ID = vocab_table[CLS] 24 | 25 | target_column_name = 'query' 26 | 27 | def testInputFnBuilderTfrecord(self): 28 | """ Tests function input_fn_builder() """ 29 | one_device_strategy = distribution_utils.get_distribution_strategy('one_device', num_gpus=0) 30 | for strategy in [None, one_device_strategy]: 31 | self._testInputFnBuilderTfrecord(strategy) 32 | 33 | def _testInputFnBuilderTfrecord(self, strategy): 34 | """ Tests function input_fn_builder() for given strategy """ 35 | data_dir = path_join(self.data_dir) 36 | 37 | # Create a dataset 38 | # Read schema 39 | # Parse and process data in dataset 40 | feature_type_2_name = { 41 | InputFtrType.TARGET_COLUMN_NAME: self.target_column_name, 42 | } 43 | 44 | def _input_fn_tfrecord(ctx): 45 | return data_fn.input_fn_tfrecord(input_pattern=data_dir, 46 | batch_size=batch_size, 47 | mode=tf.estimator.ModeKeys.EVAL, 48 | feature_type_2_name=feature_type_2_name, 49 | input_pipeline_context=ctx) 50 | 51 | batch_size = 2 52 | if strategy is not None: 53 | dataset = strategy.experimental_distribute_datasets_from_function(_input_fn_tfrecord) 54 | else: 55 | dataset = _input_fn_tfrecord(None) 56 | 57 | # Make iterator 58 | for features, label in dataset: 59 | for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type_2_name): 60 | if ftr_type in (InputFtrType.TARGET_COLUMN_NAME,): 61 | self.assertLen(ftr_name_lst, 1), f'Length for current ftr type ({ftr_type}) should be 1' 62 | ftr_name = ftr_name_lst[0] 63 | self.assertIn(ftr_name, label) 64 | continue 65 | 66 | # Check source and target text shape 67 | self.assertAllEqual(label[self.target_column_name].shape, [batch_size]) 68 | 69 | break 70 | 71 | 72 | if __name__ == "__main__": 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /test/smart_compose/train/test_losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from smart_compose.train.losses import compute_text_generation_loss, compute_regularization_penalty, compute_loss 4 | from smart_compose.utils.testing.test_case import TestCase 5 | 6 | 7 | class TestLoss(TestCase): 8 | """Unit test for losses.py""" 9 | places = 5 10 | atol = 10 ** (-places) 11 | 12 | def testComputeLoss(self): 13 | """Tests compute_loss() """ 14 | labels = tf.constant([[1, 2], 15 | [1, 2]], dtype=tf.dtypes.float32) 16 | logits = tf.constant([[[0, 1, 0], [0, 0, 1]], 17 | [[0, 1, 0], [0, 0, 1]]], dtype=tf.dtypes.float32) 18 | lengths = tf.constant([2, 2], dtype=tf.dtypes.int32) 19 | text_generation_loss = compute_text_generation_loss(logits=logits, labels=labels, lengths=lengths) 20 | 21 | l1_scale = l2_scale = 0.1 22 | var = tf.constant(3.0) 23 | regularization_penalty = compute_regularization_penalty(l1_scale, l2_scale, [var]) # Get elastic net loss 24 | 25 | 
self.assertEqual(text_generation_loss + regularization_penalty, compute_loss(l1=l1_scale, l2=l2_scale, 26 | logits=logits, labels=labels, lengths=lengths, 27 | trainable_vars=[var])) 28 | 29 | def testComputeTextGenerationLoss(self): 30 | """Tests compute_text_generation_loss()""" 31 | # Loss of correct prediction must be smaller than that of incorrect prediction 32 | labels = tf.constant([[1, 2], 33 | [1, 2]], dtype=tf.dtypes.float32) 34 | logits = tf.constant([[[0, 1, 0], [0, 0, 1]], 35 | [[0, 1, 0], [0, 0, 1]]], dtype=tf.dtypes.float32) 36 | lengths = tf.constant([2, 2], dtype=tf.dtypes.int32) 37 | loss_of_correct_prediction = compute_text_generation_loss(logits=logits, labels=labels, lengths=lengths) 38 | 39 | labels = tf.constant([[1, 2], 40 | [1, 2]], dtype=tf.dtypes.float32) 41 | logits = tf.constant([[[1, 0, 0], [1, 0, 0]], 42 | [[1, 0, 0], [1, 0, 0]]], dtype=tf.dtypes.float32) 43 | lengths = tf.constant([2, 2], dtype=tf.dtypes.int32) 44 | loss_of_incorrect_prediction = compute_text_generation_loss(logits=logits, labels=labels, lengths=lengths) 45 | self.assertAllGreater(loss_of_incorrect_prediction, loss_of_correct_prediction) 46 | 47 | # Verify effectiveness of length: when length = 0, loss must be 0 48 | labels = tf.constant([[1, 2], 49 | [1, 2]], dtype=tf.dtypes.float32) 50 | logits = tf.constant([[[0, 1, 0], [0, 0, 1]], 51 | [[0, 1, 0], [0, 0, 1]]], dtype=tf.dtypes.float32) 52 | lengths = tf.constant([0, 0], dtype=tf.dtypes.int32) 53 | zero_length_loss = compute_text_generation_loss(logits=logits, labels=labels, lengths=lengths) 54 | self.assertEqual(zero_length_loss, 0) 55 | 56 | def testRegularizationPenalty(self): 57 | """Tests correctness of regularization penalty """ 58 | # Test positive variable 59 | var = tf.constant(3.0) 60 | self._testRegularizationPenalty(var, 0.1, 0.2, 0.3, 1.8) 61 | 62 | # Test negative variable 63 | var = tf.constant(-3.0) 64 | self._testRegularizationPenalty(var, 0.1, 0.2, 0.3, 1.8) 65 | 66 | def _testRegularizationPenalty(self, var, l1_scale, l2_scale, l1_penalty_truth, l2_penalty_truth): 67 | """Tests regularization for a given variable value""" 68 | l1_penalty = compute_regularization_penalty(l1_scale, None, [var]) # Get L1 loss 69 | l2_penalty = compute_regularization_penalty(None, l2_scale, [var]) # Get L2 loss 70 | elastic_net_penalty = compute_regularization_penalty(l1_scale, l2_scale, [var]) # Get elastic net loss 71 | 72 | self.assertAlmostEqual(l1_penalty.numpy(), l1_penalty_truth, places=self.places) 73 | self.assertAlmostEqual(l2_penalty.numpy(), l2_penalty_truth, places=self.places) 74 | self.assertAlmostEqual(elastic_net_penalty.numpy(), l1_penalty_truth + l2_penalty_truth, places=self.places) 75 | 76 | 77 | if __name__ == '__main__': 78 | tf.test.main() 79 | -------------------------------------------------------------------------------- /test/smart_compose/train/test_metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from smart_compose.train.metrics import NegativePerplexity 4 | from smart_compose.utils.testing.test_case import TestCase 5 | 6 | 7 | class TestMetrics(TestCase): 8 | """Unit test for metrics.py""" 9 | 10 | def testPerplexity(self): 11 | """Tests class Perplexity""" 12 | metric = NegativePerplexity() 13 | metric.reset_states() 14 | 15 | initial_value = metric.result() 16 | self.assertEqual(initial_value, -1) 17 | 18 | labels = tf.constant([[1, 2], 19 | [1, 2]], dtype=tf.dtypes.int32) 20 | logits = tf.constant([[[0, 1, 0], [0, 0, 1]], 21 | 
[[0, 1, 0], [0, 0, 1]]], dtype=tf.dtypes.float32) 22 | lengths = tf.constant([2, 2], dtype=tf.dtypes.int32) 23 | metric.update_state(labels=labels, logits=logits, lengths=lengths) 24 | updated_value = metric.result() 25 | perplexity_val = updated_value.numpy() 26 | 27 | metric.update_state(labels=labels, logits=logits, lengths=lengths) 28 | updated_value = metric.result() 29 | # Perplexity should be the same if inputs of the two updates are the same 30 | self.assertEqual(updated_value, perplexity_val) 31 | 32 | metric.update_state(labels=labels - 1, logits=logits, lengths=lengths) 33 | updated_value = metric.result() 34 | # Perplexity should change if once a different input is given 35 | self.assertNotEqual(updated_value, perplexity_val) 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /test/smart_compose/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/test/smart_compose/utils/__init__.py -------------------------------------------------------------------------------- /test/smart_compose/utils/test_layer_utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import tensorflow as tf 4 | 5 | from smart_compose.utils import layer_utils 6 | from smart_compose.utils.testing.test_case import TestCase 7 | 8 | 9 | class TestLayerUtils(TestCase): 10 | """Unit test for layer_utils.py""" 11 | 12 | def testExpandToSameRank(self): 13 | """Tests expand_to_same_rank() """ 14 | a = tf.ones([1, 2]) 15 | b = tf.ones([1, 2, 3, 4]) 16 | expanded_a = layer_utils.expand_to_same_rank(a, b) 17 | self.assertAllEqual(tf.shape(expanded_a), [1, 2, 1, 1]) 18 | 19 | def testGetLastValidElements(self): 20 | """Tests get_last_valid_elements() """ 21 | a = layer_utils.get_last_valid_elements(tf.constant([ 22 | ['a', 'b', 'c'], 23 | ['d', 'pad', 'pad'] 24 | ]), batch_size=2, seq_len=tf.constant([3, 1])) 25 | self.assertAllEqual(a, tf.constant(['c', 'd'])) 26 | 27 | def testTileBatch(self): 28 | """Tests tile_batch() """ 29 | a = layer_utils.tile_batch(tf.constant([['a'], ['b']]), multiplier=2) 30 | self.assertAllEqual( 31 | a, tf.constant([ 32 | ['a'], ['a'], ['b'], ['b'] 33 | ]) 34 | ) 35 | 36 | def test_get_abstract_tf_function(self): 37 | class A(ABC): 38 | @abstractmethod 39 | @tf.function 40 | def test_tf_function1(self): 41 | pass 42 | 43 | @abstractmethod 44 | def test_function(self): 45 | pass 46 | 47 | @tf.function 48 | @abstractmethod 49 | def test_tf_function2(self): 50 | pass 51 | 52 | class B(A): 53 | pass 54 | 55 | abstract_tf_funcs = layer_utils.get_tf_function_names(B) 56 | self.assertAllEqual( 57 | abstract_tf_funcs, 58 | ['test_tf_function1', 'test_tf_function2'] 59 | ) 60 | 61 | 62 | if __name__ == '__main__': 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /test/smart_compose/utils/test_parsing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from smart_compose.utils import parsing_utils 6 | from smart_compose.utils.parsing_utils import InputFtrType 7 | from smart_compose.utils.testing.test_case import TestCase 8 | 9 | 10 | class TestParsingUtils(TestCase): 11 | """Unit test for parsing_utils.py""" 12 | atol = 1e-3 13 | 14 | def 
testIterateItemsWithListVal(self): 15 | """Tests iterate_items_with_list_val""" 16 | dct_lst = [{'a': 'a'}, 17 | {'a': ['a']}] 18 | expected_result_lst = [[('a', ['a'])], 19 | [('a', ['a'])]] 20 | assert len(dct_lst) == len(expected_result_lst), 'Number of test data and result must match' 21 | 22 | for dct, expected_result in zip(dct_lst, expected_result_lst): 23 | self.assertCountEqual(expected_result, list(parsing_utils.iterate_items_with_list_val(dct))) 24 | 25 | def testGetFeatureTypes(self): 26 | """Tests get_feature_types() """ 27 | self.assertCountEqual( 28 | [InputFtrType.TARGET_COLUMN_NAME], 29 | parsing_utils.get_feature_types()) 30 | 31 | def testHparamsLoadAndSave(self): 32 | """Tests loading and saving of hparams""" 33 | hparams = parsing_utils.HParams(a=1, b=2, c=[1, 2, 3]) 34 | parsing_utils.save_hparams(self.resource_dir, hparams) 35 | loaded_hparams = parsing_utils.load_hparams(self.resource_dir) 36 | self.assertEqual(hparams, loaded_hparams) 37 | os.remove(parsing_utils._get_hparam_path(self.resource_dir)) 38 | 39 | def testEstimateStepsPerEpoch(self): 40 | """Tests estimate_steps_per_epoch() """ 41 | num_record = parsing_utils.estimate_steps_per_epoch(self.data_dir, 1) 42 | self.assertEqual(num_record, 40) 43 | 44 | 45 | if __name__ == '__main__': 46 | tf.test.main() 47 | -------------------------------------------------------------------------------- /test/smart_compose/utils/test_vocab_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from smart_compose.utils import vocab_utils 4 | from smart_compose.utils.testing.test_case import TestCase 5 | 6 | 7 | class TestVocabUtils(TestCase): 8 | """Unit test for vocab_utils.py""" 9 | 10 | def testReadVocab(self): 11 | """Tests read_vocab()""" 12 | vocab = vocab_utils.read_vocab(self.vocab_file) 13 | self.assertEqual(len(vocab), 16) 14 | self.assertEqual(vocab[self.UNK], 0) 15 | 16 | def testReadTfVocab(self): 17 | """Tests read_tf_vocab()""" 18 | _, vocab = vocab_utils.read_tf_vocab(self.vocab_file, self.UNK) 19 | self.assertEqual(vocab.size(), 16) 20 | self.assertEqual(vocab.lookup(tf.constant(self.UNK)), 0) 21 | 22 | def testReadTfVocabInverse(self): 23 | """Tests read_tf_vocab_inverse()""" 24 | _, vocab = vocab_utils.read_tf_vocab_inverse(self.vocab_file, self.UNK) 25 | self.assertEqual(vocab.size(), 16) 26 | self.assertEqual(vocab.lookup(tf.constant(self.UNK_ID)), self.UNK) 27 | 28 | 29 | if __name__ == '__main__': 30 | tf.test.main() 31 | -------------------------------------------------------------------------------- /thumbnail_DeText.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/detext/671d43c5ffc83cae635174ed15c58d0bc84b76ef/thumbnail_DeText.png --------------------------------------------------------------------------------
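Every test module listed above follows the same pattern: it defines tf.test.TestCase subclasses and closes with tf.test.main(), so any single file can be run directly with the Python interpreter once the packages under src/ are importable. The snippet below is a minimal, hypothetical runner — it is not a file in this repository — that instead uses standard-library unittest discovery to execute all test_*.py modules under test/ in one pass; the editable-install assumption and the top-level "test" directory name are taken from the tree above, not from any documented entry point.

import unittest

# Hypothetical test runner (not part of the repo). Assumes the detext and
# smart_compose packages are importable, e.g. after an editable install of
# the sources under src/ (for instance `pip install -e .`).
if __name__ == "__main__":
    # tf.test.TestCase subclasses are ordinary unittest.TestCase classes,
    # so standard discovery picks up every test_*.py module under test/.
    suite = unittest.defaultTestLoader.discover("test", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)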