├── .gitattributes ├── .github ├── actionlint-matcher.json ├── actionlint.yaml ├── copy-pr-bot.yaml ├── release-drafter.yml └── workflows │ ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation-request.md │ ├── feature-request.md │ ├── submit-question.md │ └── task.md │ ├── check-base-branch.yaml │ ├── cpu-ci.yml │ ├── docs-preview-pr.yaml │ ├── docs-remove-stale-reviews.yaml │ ├── docs-sched-rebuild.yaml │ ├── gpu-ci.yml │ ├── lint.yaml │ ├── merlin.yml │ ├── packages.yaml │ ├── release-drafter.yaml │ ├── require-label.yaml │ ├── set-stable-branch.yaml │ ├── tox.yml │ └── triage.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── .pylintrc ├── CLA.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── ci ├── ignore_codespell_words.txt ├── pr.gpu.Jenkinsfile └── test_unit.sh ├── conda └── recipe │ └── meta.yaml ├── docs ├── Makefile ├── README.md ├── make.bat └── source │ ├── _static │ ├── .gitkeep │ ├── NVIDIA-LogoBlack.svg │ ├── NVIDIA-LogoWhite.svg │ ├── css │ │ ├── custom.css │ │ └── versions.css │ ├── favicon.png │ └── js │ │ └── rtd-version-switcher.js │ ├── _templates │ ├── footer.html │ ├── layout.html │ ├── merlin-ecosystem.html │ └── versions.html │ ├── api │ ├── index.rst │ ├── merlin.dag.rst │ ├── merlin.io.rst │ └── merlin.schema.rst │ ├── conf.py │ ├── index.rst │ └── toc.yaml ├── merlin ├── config │ └── __init__.py ├── core │ ├── __init__.py │ ├── _version.py │ ├── compat │ │ ├── __init__.py │ │ ├── tensorflow.py │ │ └── torch.py │ ├── dispatch.py │ ├── has_gpu.py │ ├── protocols.py │ └── utils.py ├── dag │ ├── __init__.py │ ├── executors.py │ ├── graph.py │ ├── node.py │ ├── operator.py │ ├── ops │ │ ├── __init__.py │ │ ├── add_metadata.py │ │ ├── concat_columns.py │ │ ├── grouping.py │ │ ├── rename.py │ │ ├── selection.py │ │ ├── stat_operator.py │ │ ├── subgraph.py │ │ ├── subset_columns.py │ │ ├── subtraction.py │ │ └── udf.py │ ├── runtime.py │ ├── selector.py │ └── utils.py ├── dispatch │ └── lazy.py ├── 
dtypes │ ├── __init__.py │ ├── aliases.py │ ├── base.py │ ├── mapping.py │ ├── mappings │ │ ├── __init__.py │ │ ├── cudf.py │ │ ├── merlin.py │ │ ├── numpy.py │ │ ├── pandas.py │ │ ├── python.py │ │ ├── tf.py │ │ ├── torch.py │ │ └── triton.py │ ├── registry.py │ └── shape.py ├── io │ ├── __init__.py │ ├── avro.py │ ├── csv.py │ ├── dask.py │ ├── dataframe_engine.py │ ├── dataframe_iter.py │ ├── dataset.py │ ├── dataset_engine.py │ ├── fsspec_utils.py │ ├── hugectr.py │ ├── parquet.py │ ├── shuffle.py │ ├── worker.py │ ├── writer.py │ └── writer_factory.py ├── schema │ ├── __init__.py │ ├── io │ │ ├── __init__.py │ │ ├── proto_utils.py │ │ ├── schema_bp.py │ │ └── tensorflow_metadata.py │ ├── schema.py │ └── tags.py ├── table │ ├── __init__.py │ ├── conversions.py │ ├── cupy_column.py │ ├── numpy_column.py │ ├── tensor_column.py │ ├── tensor_table.py │ ├── tensorflow_column.py │ └── torch_column.py └── testing │ ├── __init__.py │ └── assert_equals.py ├── mypy.ini ├── pyproject.toml ├── requirements-dev.txt ├── requirements-docs.txt ├── requirements-gpu.txt ├── requirements-test-cpu.txt ├── requirements-test-gpu.txt ├── requirements-test.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── conftest.py └── unit │ ├── core │ ├── test_dispatch.py │ ├── test_protocols.py │ └── test_version.py │ ├── dag │ ├── ops │ │ ├── test_addmetadata.py │ │ ├── test_rename.py │ │ ├── test_selection.py │ │ ├── test_stat_op.py │ │ ├── test_subgraph.py │ │ └── test_udf.py │ ├── test_base_operator.py │ ├── test_column_selector.py │ ├── test_dag_utils.py │ ├── test_executors.py │ └── test_graph.py │ ├── dispatch │ └── test_lazy_dispatch.py │ ├── dtypes │ ├── test_cudf.py │ ├── test_module.py │ └── test_shape.py │ ├── io │ ├── test_avro.py │ ├── test_dataset.py │ ├── test_io.py │ └── test_worker.py │ ├── schema │ ├── test_column_schemas.py │ ├── test_schema.py │ ├── test_schema_io.py │ └── test_tags.py │ ├── table │ ├── test_convert_column.py │ ├── 
test_tensor_column.py │ └── test_tensor_table.py │ └── utils │ └── test_utils.py ├── tox.ini └── versioneer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | merlin/core/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/actionlint-matcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "owner": "actionlint", 5 | "pattern": [ 6 | { 7 | "regexp": "^(?:\\x1b\\[\\d+m)?(.+?)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*: (?:\\x1b\\[\\d+m)*(.+?)(?:\\x1b\\[\\d+m)* \\[(.+?)\\]$", 8 | "file": 1, 9 | "line": 2, 10 | "column": 3, 11 | "message": 4, 12 | "code": 5 13 | } 14 | ] 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /.github/actionlint.yaml: -------------------------------------------------------------------------------- 1 | self-hosted-runner: 2 | # Labels of self-hosted runner in array of string 3 | labels: 4 | - 1GPU 5 | - 2GPU 6 | - linux-amd64-gpu-p100-latest-1 7 | -------------------------------------------------------------------------------- /.github/copy-pr-bot.yaml: -------------------------------------------------------------------------------- 1 | # Configuration file for `copy-pr-bot` GitHub App 2 | # https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ 3 | 4 | enabled: true 5 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | categories: 2 | - title: "⚠ Breaking Changes" 3 | labels: 4 | - "breaking" 5 | - title: "🐜 Bug Fixes" 6 | labels: 7 | - "bug" 8 | - title: "🚀 Features" 9 | labels: 10 | - "feature" 11 | - "enhancement" 12 | - title: "📄 Documentation" 13 | labels: 14 | - "documentation" 15 | 
- "examples" 16 | - title: "🔧 Maintenance" 17 | labels: 18 | - "build" 19 | - "dependencies" 20 | - "chore" 21 | - "ci" 22 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 23 | exclude-labels: 24 | - "skip-changelog" 25 | template: | 26 | ## What’s Changed 27 | 28 | $CHANGES 29 | -------------------------------------------------------------------------------- /.github/workflows/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Submit a bug report to help us improve Merlin Core 4 | title: "[BUG]" 5 | labels: "status/needs-triage, bug" 6 | assignees: "" 7 | --- 8 | 9 | ### Bug description 10 | 11 | 12 | 13 | ### Steps/Code to reproduce bug 14 | 15 | 16 | 17 | 1. 18 | 2. 19 | 3. 20 | 21 | ### Expected behavior 22 | 23 | 24 | 25 | ### Environment details 26 | 27 | - Merlin version: 28 | - Platform: 29 | - Python version: 30 | - PyTorch version (GPU?): 31 | - Tensorflow version (GPU?): 32 | 33 | ### Additional context 34 | 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/ISSUE_TEMPLATE/documentation-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation request 3 | about: Report incorrect or needed documentation 4 | title: "[DOC]" 5 | labels: "status/needs-triage, area/documentation" 6 | assignees: "" 7 | --- 8 | 9 | ## Report incorrect documentation 10 | 11 | ### Location of incorrect documentation 12 | 13 | 14 | 15 | ### Describe the problems or issues found in the documentation 16 | 17 | 18 | 19 | ### Steps taken to verify documentation is incorrect 20 | 21 | 22 | 23 | ### Suggested fix for documentation 24 | 25 | 26 | 27 | --- 28 | 29 | ## Report needed documentation 30 | 31 | ### Report needed documentation 32 | 33 | 34 | 35 | ### Describe the documentation you'd like 36 | 37 | 38 | 39 | ### Steps taken to search for needed 
documentation\*\* 40 | 41 | 42 | -------------------------------------------------------------------------------- /.github/workflows/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature request" 3 | about: Submit a proposal/request for a new Merlin Core feature 4 | title: "[FEA]" 5 | labels: "status/needs-triage, kind/feature-request" 6 | assignees: "" 7 | --- 8 | 9 | # 🚀 Feature request 10 | 11 | 13 | 14 | ## Motivation 15 | 16 | 19 | 20 | ## Your contribution 21 | 22 | 25 | -------------------------------------------------------------------------------- /.github/workflows/ISSUE_TEMPLATE/submit-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓ Questions & Help" 3 | about: Ask a general question about Merlin Core 4 | title: "[QST]" 5 | labels: "status/needs-triage, kind/question" 6 | assignees: "" 7 | --- 8 | 9 | # ❓ Questions & Help 10 | 11 | ## Details 12 | 13 | 14 | -------------------------------------------------------------------------------- /.github/workflows/ISSUE_TEMPLATE/task.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Task 3 | about: A general task that we're tracking in Github 4 | title: "[Task]" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ### Description 10 | 11 | **Additional context** 12 | Add any other context, code examples, or references to existing implementations about the task here. 
13 | -------------------------------------------------------------------------------- /.github/workflows/check-base-branch.yaml: -------------------------------------------------------------------------------- 1 | name: Require Development Base Branch 2 | 3 | on: 4 | pull_request: 5 | types: [synchronize, opened, reopened, labeled, unlabeled] 6 | 7 | jobs: 8 | check: 9 | uses: NVIDIA-Merlin/.github/.github/workflows/check-base-branch.yml@main 10 | -------------------------------------------------------------------------------- /.github/workflows/cpu-ci.yml: -------------------------------------------------------------------------------- 1 | name: Core Tests (CPU) 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | tags: 8 | - "v[0-9]+.[0-9]+.[0-9]+" 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | uses: ./.github/workflows/tox.yml 15 | with: 16 | env: test-cpu 17 | 18 | store-pr-information: 19 | needs: [build] 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Store PR information 23 | run: | 24 | mkdir ./pr 25 | echo ${{ github.event.number }} > ./pr/pr.txt 26 | echo ${{ github.event.pull_request.merged }} > ./pr/merged.txt 27 | echo ${{ github.event.action }} > ./pr/action.txt 28 | - name: Upload PR information 29 | uses: actions/upload-artifact@v2 30 | with: 31 | name: pr 32 | path: pr/ 33 | -------------------------------------------------------------------------------- /.github/workflows/docs-preview-pr.yaml: -------------------------------------------------------------------------------- 1 | name: docs-preview-pr 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Packages"] 6 | types: [completed] 7 | 8 | env: 9 | WF_ID: ${{ github.event.workflow_run.id }} 10 | 11 | jobs: 12 | preview: 13 | uses: nvidia-merlin/.github/.github/workflows/docs-preview-pr-common.yaml@main 14 | -------------------------------------------------------------------------------- /.github/workflows/docs-remove-stale-reviews.yaml: 
-------------------------------------------------------------------------------- 1 | name: docs-remove-stale-reviews 2 | 3 | on: 4 | schedule: 5 | # 42 minutes after 0:00 UTC on Sundays 6 | - cron: "42 0 * * 0" 7 | workflow_dispatch: 8 | 9 | jobs: 10 | remove: 11 | uses: nvidia-merlin/.github/.github/workflows/docs-remove-stale-reviews-common.yaml@main 12 | -------------------------------------------------------------------------------- /.github/workflows/docs-sched-rebuild.yaml: -------------------------------------------------------------------------------- 1 | name: docs-sched-rebuild 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: 7 | - "v[0-9]+.[0-9]+.[0-9]+" 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | runs-on: [ubuntu-latest] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 0 18 | ref: main 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: 3.9 23 | - name: Install Ubuntu packages 24 | run: | 25 | sudo apt-get update -y 26 | sudo apt-get install -y protobuf-compiler 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip setuptools==59.4.0 wheel tox 30 | - name: Building docs (multiversion) 31 | run: | 32 | # setup local branches that we'd like to build docs for 33 | # required for sphinx-multiversion to find these 34 | git branch --track stable origin/stable || true 35 | tox -e docs-multi 36 | - name: Delete unnecessary files 37 | run: | 38 | find docs/build -name .doctrees -prune -exec rm -rf {} \; 39 | find docs/build -name .buildinfo -exec rm {} \; 40 | - name: Upload HTML 41 | uses: actions/upload-artifact@v3 42 | with: 43 | name: html-build-artifact 44 | path: docs/build/html 45 | if-no-files-found: error 46 | retention-days: 1 47 | 48 | # Identify the dir for the HTML. 
49 | store-html: 50 | needs: [build] 51 | runs-on: ubuntu-latest 52 | steps: 53 | - uses: actions/checkout@v3 54 | with: 55 | ref: "gh-pages" 56 | - name: Initialize Git configuration 57 | run: | 58 | git config user.name docs-sched-rebuild 59 | git config user.email do-not-send-@github.com 60 | - name: Download artifacts 61 | uses: actions/download-artifact@v3 62 | with: 63 | name: html-build-artifact 64 | - name: Copy HTML directories 65 | run: | 66 | ls -asl 67 | for i in */ 68 | do 69 | echo "Git adding ${i}" 70 | git add "${i}" 71 | done 72 | - name: Check or create dot-no-jekyll file 73 | run: | 74 | if [ -f ".nojekyll" ]; then 75 | echo "The dot-no-jekyll file already exists." 76 | exit 0 77 | fi 78 | touch .nojekyll 79 | git add .nojekyll 80 | - name: Check or create redirect page 81 | env: 82 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 83 | run: | 84 | resp=$(grep 'http-equiv="refresh"' index.html 2>/dev/null) || true 85 | if [ -n "${resp}" ]; then 86 | echo "The redirect file already exists." 87 | exit 0 88 | fi 89 | 90 | # If any of these commands fail, fail the build. 91 | def_branch="stable" 92 | html_url=$(gh api "repos/${GITHUB_REPOSITORY}/pages" --jq ".html_url") 93 | 94 | cat > index.html << EOF 95 | 96 | 97 | 98 | Redirect to documentation 99 | 100 | 102 | 107 | 108 | 109 |

Please follow the link to the 110 | '${def_branch}' branch documentation.

111 | 112 | 113 | EOF 114 | 115 | git add index.html 116 | - name: Commit changes to the GitHub Pages branch 117 | run: | 118 | git status 119 | if git commit -m 'Pushing changes to GitHub Pages.'; then 120 | git push -f 121 | else 122 | echo "Nothing changed." 123 | fi 124 | -------------------------------------------------------------------------------- /.github/workflows/gpu-ci.yml: -------------------------------------------------------------------------------- 1 | name: GPU CI 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - main 8 | - "pull-request/[0-9]+" 9 | tags: 10 | - "v[0-9]+.[0-9]+.[0-9]+" 11 | 12 | jobs: 13 | gpu-ci: 14 | runs-on: linux-amd64-gpu-p100-latest-1 15 | container: 16 | image: nvcr.io/nvstaging/merlin/merlin-ci-runner:latest 17 | env: 18 | NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }} 19 | options: --shm-size=1G 20 | credentials: 21 | username: $oauthtoken 22 | password: ${{ secrets.NGC_TOKEN }} 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | with: 27 | fetch-depth: 0 28 | - name: Run tests 29 | run: | 30 | ref_type=${{ github.ref_type }} 31 | branch=main 32 | if [[ $ref_type == "tag"* ]] 33 | then 34 | raw=$(git branch -r --contains ${{ github.ref_name }}) 35 | branch=${raw/origin\/} 36 | fi 37 | tox -e test-gpu -- "$branch" 38 | 39 | gpu-ci-not-visible: 40 | runs-on: linux-amd64-gpu-p100-latest-1 41 | container: 42 | image: nvcr.io/nvstaging/merlin/merlin-ci-runner:latest 43 | env: 44 | NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }} 45 | options: --shm-size=1G 46 | credentials: 47 | username: $oauthtoken 48 | password: ${{ secrets.NGC_TOKEN }} 49 | 50 | steps: 51 | - uses: actions/checkout@v3 52 | with: 53 | fetch-depth: 0 54 | - name: Run tests 55 | run: | 56 | ref_type=${{ github.ref_type }} 57 | branch=main 58 | if [[ $ref_type == "tag"* ]] 59 | then 60 | raw=$(git branch -r --contains ${{ github.ref_name }}) 61 | branch=${raw/origin\/} 62 | fi 63 | tox -e test-gpu-not-visible -- "$branch" 
64 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | tags: 8 | - "v[0-9]+.[0-9]+.[0-9]+" 9 | 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - uses: actions/setup-python@v4 16 | - uses: pre-commit/action@v3.0.0 17 | 18 | actionlint: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Check workflow files 23 | run: | 24 | echo "::add-matcher::.github/actionlint-matcher.json" 25 | bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/fd7ba3c382e13dcc0248e425b4cbc3f1185fa3ee/scripts/download-actionlint.bash) 26 | ./actionlint 27 | shell: bash 28 | -------------------------------------------------------------------------------- /.github/workflows/merlin.yml: -------------------------------------------------------------------------------- 1 | name: "Test Merlin" 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | tags: 8 | - "v[0-9]+.[0-9]+.[0-9]+" 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | dataloader: 14 | name: "Dataloader (CPU)" 15 | uses: ./.github/workflows/tox.yml 16 | with: 17 | env: test-dataloader-cpu 18 | 19 | systems: 20 | name: "Systems (CPU)" 21 | uses: ./.github/workflows/tox.yml 22 | with: 23 | env: test-systems-cpu 24 | 25 | models: 26 | name: "Models (CPU)" 27 | uses: ./.github/workflows/tox.yml 28 | with: 29 | env: test-models-cpu 30 | 31 | nvtabular: 32 | name: "NVTabular (CPU)" 33 | uses: ./.github/workflows/tox.yml 34 | with: 35 | env: test-nvtabular-cpu 36 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yaml: -------------------------------------------------------------------------------- 1 | name: release-drafter 2 | 3 | on: 
4 | push: 5 | # trigger on tags only 6 | tags: 7 | - "v[0-9]+.[0-9]+.[0-9]+" 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | update_release_draft: 13 | uses: nvidia-merlin/.github/.github/workflows/release-drafter-common.yaml@main 14 | -------------------------------------------------------------------------------- /.github/workflows/require-label.yaml: -------------------------------------------------------------------------------- 1 | name: Require PR Labels 2 | 3 | on: 4 | pull_request: 5 | types: [synchronize, opened, reopened, labeled, unlabeled] 6 | 7 | jobs: 8 | check-labels: 9 | uses: nvidia-merlin/.github/.github/workflows/require-label.yaml@main 10 | -------------------------------------------------------------------------------- /.github/workflows/set-stable-branch.yaml: -------------------------------------------------------------------------------- 1 | name: Set Stable Branch 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published, deleted] 7 | 8 | jobs: 9 | set-stable-branch: 10 | uses: NVIDIA-Merlin/.github/.github/workflows/set-stable-branch.yaml@main 11 | -------------------------------------------------------------------------------- /.github/workflows/tox.yml: -------------------------------------------------------------------------------- 1 | name: "Run Tox Env" 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | env: 7 | description: "The name of the tox environment to run" 8 | required: true 9 | type: string 10 | 11 | jobs: 12 | check: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | python-version: [3.8] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install Ubuntu packages 26 | run: | 27 | sudo apt-get update -y 28 | sudo apt-get install -y protobuf-compiler 29 | - name: Install and upgrade python packages 30 | run: | 31 | 
python -m pip install --upgrade pip setuptools==59.4.0 wheel tox 32 | - name: Get Branch name 33 | id: get-branch-name 34 | uses: NVIDIA-Merlin/.github/actions/branch-name@main 35 | - name: Run tests 36 | run: | 37 | branch="${{ steps.get-branch-name.outputs.branch }}" 38 | GIT_COMMIT=$(git rev-parse HEAD) tox -e ${{ inputs.env }} -- $branch 39 | -------------------------------------------------------------------------------- /.github/workflows/triage.yml: -------------------------------------------------------------------------------- 1 | name: triage_issues 2 | on: 3 | issues: 4 | types: [opened, reopened] 5 | 6 | jobs: 7 | triage_issue: 8 | uses: nvidia-merlin/.github/.github/workflows/triage.yaml@main 9 | secrets: 10 | TRIAGE_APP_ID: ${{ secrets.TRIAGE_APP_ID }} 11 | TRIAGE_APP_PEM: ${{ secrets.TRIAGE_APP_PEM }} 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints/* 2 | /.*_checkpoints/ 3 | .ipynb_checkpoints/* 4 | */.jupyter/* 5 | */.local/* 6 | 7 | cache/* 8 | data*/* 9 | *.parquet 10 | *.orc 11 | *.csv 12 | 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Sphinx documentation 62 | docs/build/ 63 | docs/source/README.md 
64 | docs/source/LICENSE 65 | 66 | # IPython 67 | profile_default/ 68 | ipython_config.py 69 | 70 | # mypy 71 | .mypy_cache/ 72 | .dmypy.json 73 | dmypy.json 74 | 75 | # PyCharm 76 | .idea 77 | 78 | .vscode/ 79 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # imports 3 | - repo: https://github.com/MarcoGorelli/absolufy-imports 4 | rev: v0.3.1 5 | hooks: 6 | - id: absolufy-imports 7 | - repo: https://github.com/timothycrosley/isort 8 | rev: 5.12.0 9 | hooks: 10 | - id: isort 11 | additional_dependencies: [toml] 12 | exclude: ^examples/ 13 | # types 14 | - repo: https://github.com/pre-commit/mirrors-mypy 15 | rev: "v0.991" 16 | hooks: 17 | - id: mypy 18 | language_version: python3 19 | args: 20 | [ 21 | --non-interactive, 22 | --install-types, 23 | --namespace-packages, 24 | --explicit-package-bases, 25 | ] 26 | exclude: ^docs/ 27 | # code style 28 | - repo: https://github.com/python/black 29 | rev: 23.1.0 30 | hooks: 31 | - id: black 32 | - repo: https://github.com/pycqa/pylint 33 | rev: v2.16.0 34 | hooks: 35 | - id: pylint 36 | - repo: https://github.com/pycqa/flake8 37 | rev: 6.0.0 38 | hooks: 39 | - id: flake8 40 | - repo: https://github.com/pre-commit/mirrors-prettier 41 | rev: v2.7.1 42 | hooks: 43 | - id: prettier 44 | types_or: [yaml, markdown] 45 | # notebooks 46 | - repo: https://github.com/s-weigand/flake8-nb 47 | rev: v0.5.2 48 | hooks: 49 | - id: flake8-nb 50 | files: \.ipynb$ 51 | # documentation 52 | - repo: https://github.com/econchick/interrogate 53 | rev: 1.5.0 54 | hooks: 55 | - id: interrogate 56 | exclude: ^(build|docs|merlin/io|tests|setup.py|versioneer.py) 57 | args: [--config=pyproject.toml] 58 | - repo: https://github.com/codespell-project/codespell 59 | rev: v2.2.2 60 | hooks: 61 | - id: codespell 62 | # security 63 | - repo: https://github.com/PyCQA/bandit 64 | rev: 1.7.4 65 | hooks: 
66 | - id: bandit 67 | args: [--verbose, -ll, -x, tests, examples, bench] 68 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | build 2 | conda 3 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | ignore-patterns=_version.py,model_config_pb2.py 4 | 5 | extension-pkg-allow-list=hugectr,nvtabular_cpp 6 | 7 | [MESSAGES CONTROL] 8 | disable=fixme, 9 | # docstrings aren't required (yet). 10 | missing-function-docstring, 11 | missing-module-docstring, 12 | missing-class-docstring, 13 | 14 | # formatting checks that we're handling with black/isort 15 | wrong-import-order, 16 | wrong-import-position, 17 | ungrouped-imports, 18 | line-too-long, 19 | superfluous-parens, 20 | trailing-whitespace, 21 | 22 | # we'll probably never enable these checks 23 | invalid-name, 24 | import-error, 25 | 26 | # disable code-complexity checks for now 27 | # TODO: should we configure the thresholds for these rather than just disable? 
28 | too-many-function-args, 29 | too-many-instance-attributes, 30 | too-many-locals, 31 | too-many-branches, 32 | too-many-nested-blocks, 33 | too-many-statements, 34 | too-many-arguments, 35 | too-many-return-statements, 36 | too-many-lines, 37 | too-few-public-methods, 38 | 39 | # many of these checks would be great to include at some point, but would 40 | # require some changes to our codebase 41 | useless-return, 42 | protected-access, 43 | arguments-differ, 44 | unused-argument, 45 | unused-variable, 46 | abstract-method, 47 | no-name-in-module, 48 | attribute-defined-outside-init, 49 | redefined-outer-name, 50 | import-outside-toplevel, 51 | no-else-continue, 52 | no-else-return, 53 | no-else-raise, 54 | no-member, 55 | super-with-arguments, 56 | unsupported-assignment-operation, 57 | inconsistent-return-statements, 58 | duplicate-string-formatting-argument, 59 | len-as-condition, 60 | cyclic-import, 61 | 62 | # producing false positives 63 | unexpected-keyword-arg, 64 | not-an-iterable, 65 | unsubscriptable-object 66 | 67 | [SIMILARITIES] 68 | min-similarity-lines=20 69 | ignore-comments=yes 70 | ignore-docstrings=yes 71 | ignore-imports=yes 72 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | If you are interested in contributing to the project, your contributions will fall 4 | into three categories: 5 | 6 | 1. You want to report a bug, feature request, or documentation issue 7 | - File an [issue](https://github.com/NVIDIA-Merlin/core/issues/new/choose) 8 | describing what you encountered or what you want to see changed. 9 | - The maintainers evaluate the issues and triage them, scheduling 10 | them for a release. If you believe the issue needs priority attention 11 | comment on the issue to notify the team. 12 | 2. 
You want to propose a new feature and implement it 13 | - Post about your intended feature, and we shall discuss the design and 14 | implementation. 15 | - Once we agree that the plan looks good, go ahead and implement it, using 16 | the [code contributions](#code-contributions) guide below. 17 | 3. You want to implement a feature or bug-fix for an outstanding issue 18 | - Follow the [code contributions](#code-contributions) guide below. 19 | - If you need more context on a particular issue, please ask and we shall 20 | provide. 21 | 22 | ## Code contributions 23 | 24 | ### Your first issue 25 | 26 | 1. Read the project's [README.md](https://github.com/NVIDIA-Merlin/core/blob/main/README.md) 27 | to learn how to set up the development environment. 28 | 2. Find an issue to work on. The best way is to look for the 29 | [good first issue](https://github.com/NVIDIA-Merlin/core/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) 30 | or [help wanted](https://github.com/NVIDIA-Merlin/core/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) labels. 31 | 3. Comment on the issue saying you are going to work on it. 32 | 4. Code! Make sure to update unit tests! 33 | 5. When done, [create your pull request](https://github.com/NVIDIA-Merlin/core/compare). 34 | 6. Verify that CI passes all [status checks](https://help.github.com/articles/about-status-checks/). 35 | Fix if needed. 36 | 7. Wait for other developers to review your code and update code as needed. 37 | 8. After your pull request is reviewed and approved, a maintainer will merge your 38 | pull request. 39 | 40 | Remember, if you are unsure about anything, don't hesitate to comment on issues 41 | and ask for clarifications! 42 | 43 | ## Label your PRs 44 | 45 | This repository uses the release-drafter action to draft and create our change log. 
46 | 47 | Please add one of the following labels to your PR to specify the type of contribution 48 | and help categorize the PR in our change log: 49 | 50 | - `breaking` -- The PR creates a breaking change to the API. 51 | - `bug` -- The PR fixes a problem with the code. 52 | - `feature` or `enhancement` -- The PR introduces a backward-compatible feature. 53 | - `documentation` or `examples` -- The PR is an addition or update to documentation. 54 | - `build`, `dependencies`, `chore`, or `ci` -- The PR is related to maintaining the 55 | repository or the project. 56 | 57 | By default, an unlabeled PR is listed at the top of the change log and is not 58 | grouped under a heading like _Features_ that groups similar PRs. 59 | Labeling the PRs so we can categorize them is preferred. 60 | 61 | If, for some reason, you do not believe your PR should be included in the change 62 | log, you can add the `skip-changelog` label. 63 | This label excludes the PR from the change log. 64 | 65 | For more information, see `.github/release-drafter.yml` in the repository 66 | or go to . 67 | 68 | ## Attribution 69 | 70 | Portions adopted from . 
71 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | recursive-include merlin *.py *.proto 3 | include requirements.txt 4 | include requirements-dev.txt 5 | 6 | include merlin/core/_version.py 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Merlin Core](https://github.com/NVIDIA-Merlin/core) 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/merlin-core?color=orange&label=version)](https://pypi.python.org/pypi/merlin-core/) 4 | [![LICENSE](https://img.shields.io/github/license/NVIDIA-Merlin/core)](LICENSE) 5 | [![Documentation](https://img.shields.io/badge/documentation-blue.svg)](https://nvidia-merlin.github.io/core) 6 | 7 | The Merlin Core library provides the core utilities for [NVIDIA Merlin](https://github.com/NVIDIA-Merlin) libraries 8 | like [NVTabular](https://github.com/NVIDIA-Merlin/NVTabular), [Transformers4Rec](https://github.com/NVIDIA-Merlin/Transformers4Rec) 9 | and [Merlin Models](https://github.com/NVIDIA-Merlin/models). 10 | For example, the [merlin.io.Dataset](https://nvidia-merlin.github.io/core/stable/api/merlin.io.html#merlin.io.Dataset) and [merlin.schema.Schema](https://nvidia-merlin.github.io/core/stable/api/merlin.schema.html#merlin.schema.Schema) classes are fundamental for working with data and building recommender systems with Merlin. 
11 | 12 | ## Installation 13 | 14 | ### Installing Merlin Core Using Pip 15 | 16 | ```shell 17 | pip install merlin-core 18 | ``` 19 | 20 | ### Installing Merlin Core Using Conda 21 | 22 | ```shell 23 | conda install -c nvidia -c rapidsai -c numba -c conda-forge merlin-core python=3.7 cudatoolkit=11.2 24 | ``` 25 | 26 | ### Running Merlin Core with Docker 27 | 28 | As a fundamental library for Merlin, Merlin Core is included in the Merlin Containers. 29 | 30 | Refer to the [Merlin Containers](https://nvidia-merlin.github.io/Merlin/main/containers.html) documentation page for information about the Merlin container names, URLs to the container images on the NVIDIA GPU Cloud catalog, and key Merlin components. 31 | 32 | ## Feedback and Support 33 | 34 | To report bugs or get help, please open an issue on the [GitHub repo](https://github.com/NVIDIA-Merlin/core/issues). 35 | -------------------------------------------------------------------------------- /ci/ignore_codespell_words.txt: -------------------------------------------------------------------------------- 1 | te 2 | coo 3 | ser 4 | fo 5 | ot 6 | lik 7 | usera 8 | -------------------------------------------------------------------------------- /ci/pr.gpu.Jenkinsfile: -------------------------------------------------------------------------------- 1 | pipeline { 2 | agent { 3 | docker { 4 | image 'nvcr.io/nvstaging/merlin/merlin-ci-runner-wrapper' 5 | label 'merlin_gpu_gcp || merlin_gpu' 6 | registryCredentialsId 'jawe-nvcr-io' 7 | registryUrl 'https://nvcr.io' 8 | args "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" 9 | } 10 | } 11 | 12 | options { 13 | buildDiscarder(logRotator(numToKeepStr: '10')) 14 | ansiColor('xterm') 15 | disableConcurrentBuilds(abortPrevious: true) 16 | } 17 | 18 | stages { 19 | stage("test-gpu") { 20 | options { 21 | timeout(time: 60, unit: 'MINUTES', activity: true) 22 | } 23 | steps { 24 | sh """#!/bin/bash 25 | #set -e 26 | printenv 27 | 28 | rm -rf $HOME/.cudf/ 29 | export 
#!/bin/bash
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Fix: the shebang must be the very first line of the file to take effect.
# It previously sat below the license header (line 17), so the script was run
# by whatever shell invoked it instead of bash, and `set -e` semantics were
# not guaranteed.

# Abort on the first failing command so CI reports the failure.
set -e

cd /core/

# Run tests
pytest -rxs /core/tests/unit
-c defaults -c conda-forge -c numba -c rapidsai 5 | 6 | 7 | {% set version = environ.get('GIT_DESCRIBE_TAG', '0.1').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} 8 | {% set git_revision_count=environ.get('GIT_DESCRIBE_NUMBER', 0) %} 9 | {% set setup_py_data = load_setup_py_data() %} 10 | 11 | 12 | package: 13 | name: merlin-core 14 | version: {{ version }} 15 | 16 | source: 17 | path: ../../ 18 | 19 | build: 20 | number: {{ git_revision_count }} 21 | noarch: python 22 | script: python -m pip install . -vvv 23 | 24 | requirements: 25 | build: 26 | - python 27 | - setuptools 28 | run: 29 | - python 30 | {% for req in setup_py_data.get('install_requires', []) %} 31 | - {{ req }} 32 | {% endfor %} 33 | - cupy>=10 34 | - nvtx>=0.2.1 35 | 36 | about: 37 | home: https://github.com/NVIDIA-Merlin/core 38 | license_file: LICENSE 39 | summary: Core Utilities for NVIDIA Merlin 40 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | This folder contains the scripts necessary to build documentation for the 4 | merlin-core library. You can find the [generated documentation 5 | here](https://nvidia-merlin.github.io/core). 6 | 7 | ## Contributing to Docs 8 | 9 | You build the documentation with the `tox` command and specify the `docs` environment. 10 | The following steps are one way of many to build the documentation before opening a merge request. 11 | 12 | 1. Create a virtual environment: 13 | 14 | ```shell 15 | python -m venv .venv 16 | ``` 17 | 18 | 1. Activate the virtual environment: 19 | 20 | ```shell 21 | source .venv/bin/activate 22 | ``` 23 | 24 | 1. Install tox in the virtual environment: 25 | 26 | ```shell 27 | python -m pip install --upgrade pip 28 | python -m pip install tox 29 | ``` 30 | 31 | 1. Build the documentation with tox: 32 | 33 | ```shell 34 | tox -e docs 35 | ``` 36 | 37 | These steps run Sphinx in your shell and create HTML in the `docs/build/html/` 38 | directory. 39 | 40 | ## Preview the Changes 41 | 42 | View the docs web page by opening the HTML in your browser. First, navigate to 43 | the `build/html/` directory and then run the following command: 44 | 45 | ```shell 46 | python -m http.server 47 | ``` 48 | 49 | Afterward, open a web browser and access . 50 | 51 | Check that yours edits formatted correctly and read well. 
52 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-Merlin/core/6d396aa1d5e6ffe60bb133dcfe93d869b8548ba2/docs/source/_static/.gitkeep -------------------------------------------------------------------------------- /docs/source/_static/NVIDIA-LogoBlack.svg: -------------------------------------------------------------------------------- 1 | NVIDIA-LogoBlack -------------------------------------------------------------------------------- /docs/source/_static/NVIDIA-LogoWhite.svg: -------------------------------------------------------------------------------- 1 | 2 | 13 | 32 | 34 | 36 | 37 | NVIDIA-LogoBlack 39 | 43 | 48 | 53 | 58 | 59 | 
-------------------------------------------------------------------------------- /docs/source/_static/css/versions.css: -------------------------------------------------------------------------------- 1 | /* Version Switcher */ 2 | 3 | .rst-versions { 4 | flex-align: bottom; 5 | bottom: 0; 6 | left: 0; 7 | z-index: 400 8 | } 9 | 10 | .rst-versions a { 11 | color: var(--nv-green); 12 | text-decoration: none 13 | } 14 | 15 | .rst-versions .rst-badge-small { 16 | display: none 17 | } 18 | 19 | .rst-versions .rst-current-version { 20 | padding: 12px; 21 | display: block; 22 | text-align: right; 23 | font-size: 90%; 24 | cursor: pointer; 25 | border-top: 1px solid rgba(0,0,0,.1); 26 | *zoom:1 27 | } 28 | 29 | .rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after { 30 | display: table; 31 | content: "" 32 | } 33 | 34 | .rst-versions .rst-current-version:after { 35 | clear: both 36 | } 37 | 38 | .rst-versions .rst-current-version .fa-book { 39 | float: left 40 | } 41 | 42 | .rst-versions .rst-current-version .icon-book { 43 | float: left 44 | } 45 | 46 | .rst-versions .rst-current-version.rst-out-of-date { 47 | background-color: #E74C3C; 48 | color: #fff 49 | } 50 | 51 | .rst-versions .rst-current-version.rst-active-old-version { 52 | background-color: #F1C40F; 53 | color: #000 54 | } 55 | 56 | .rst-versions.shift-up { 57 | height: auto; 58 | max-height: 100% 59 | } 60 | 61 | .rst-versions.shift-up .rst-other-versions { 62 | display: block 63 | } 64 | 65 | .rst-versions .rst-other-versions { 66 | font-size: 90%; 67 | padding: 12px; 68 | color: gray; 69 | display: none 70 | } 71 | 72 | .rst-versions .rst-other-versions hr { 73 | display: block; 74 | height: 1px; 75 | border: 0; 76 | margin: 20px 0; 77 | padding: 0; 78 | border-top: solid 1px #413d3d 79 | } 80 | 81 | .rst-versions .rst-other-versions dd { 82 | display: inline-block; 83 | margin: 0 84 | } 85 | 86 | .rst-versions .rst-other-versions dd a { 87 | display: inline-block; 88 | padding: 
6px; 89 | color: var(--nv-green); 90 | font-weight: 500; 91 | } 92 | 93 | .rst-versions.rst-badge { 94 | width: auto; 95 | bottom: 20px; 96 | right: 20px; 97 | left: auto; 98 | border: none; 99 | max-width: 300px 100 | } 101 | 102 | .rst-versions.rst-badge .icon-book { 103 | float: none 104 | } 105 | 106 | .rst-versions.rst-badge .fa-book { 107 | float: none 108 | } 109 | 110 | .rst-versions.rst-badge.shift-up .rst-current-version { 111 | text-align: right 112 | } 113 | 114 | .rst-versions.rst-badge.shift-up .rst-current-version .fa-book { 115 | float: left 116 | } 117 | 118 | .rst-versions.rst-badge.shift-up .rst-current-version .icon-book { 119 | float: left 120 | } 121 | 122 | .rst-versions.rst-badge .rst-current-version { 123 | width: auto; 124 | height: 30px; 125 | line-height: 30px; 126 | padding: 0 6px; 127 | display: block; 128 | text-align: center 129 | } 130 | 131 | @media screen and (max-width: 768px) { 132 | .rst-versions { 133 | width:85%; 134 | display: none 135 | } 136 | 137 | .rst-versions.shift { 138 | display: block 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-Merlin/core/6d396aa1d5e6ffe60bb133dcfe93d869b8548ba2/docs/source/_static/favicon.png -------------------------------------------------------------------------------- /docs/source/_static/js/rtd-version-switcher.js: -------------------------------------------------------------------------------- 1 | var jQuery = (typeof(window) != 'undefined') ? 
window.jQuery : require('jquery'); 2 | var doc = $(document); 3 | doc.on('click', "[data-toggle='rst-current-version']", function() { 4 | $("[data-toggle='rst-versions']").toggleClass("shift-up"); 5 | }); 6 | -------------------------------------------------------------------------------- /docs/source/_templates/footer.html: -------------------------------------------------------------------------------- 1 |

2 | Privacy Policy | 3 | Manage My Privacy | 4 | Do Not Sell or Share My Data | 5 | Terms of Service | 6 | Accessibility | 7 | Corporate Policies | 8 | Product Security | 9 | Contact 10 |

-------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {%- extends "!layout.html" %} 2 | 3 | {%- block extrahead %} 4 | {%- if analytics_id %} 5 | 6 | 7 | 16 | {% endif %} 17 | 18 | 19 | 20 | 21 | {%- endblock %} 22 | -------------------------------------------------------------------------------- /docs/source/_templates/merlin-ecosystem.html: -------------------------------------------------------------------------------- 1 | 14 | -------------------------------------------------------------------------------- /docs/source/_templates/versions.html: -------------------------------------------------------------------------------- 1 | {%- if current_version %} 2 |
3 | 4 | 5 | v: {{ current_version.name }} 6 | 7 | 8 |
9 | {%- if versions.tags %} 10 |
11 |
Tags
12 | {%- for item in versions.tags %} 13 |
{{ item.name }}
14 | {%- endfor %} 15 |
16 | {%- endif %} 17 | {%- if versions.branches %} 18 |
19 |
Branches
20 | {%- for item in versions.branches %} 21 |
{{ item.name }}
22 | {%- endfor %} 23 |
24 | {%- endif %} 25 |
26 |
27 | {%- endif %} 28 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | merlin namespace 2 | ================ 3 | 4 | .. py:module:: merlin 5 | 6 | Subpackages 7 | ----------- 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | 12 | merlin.dag 13 | merlin.io 14 | merlin.schema 15 | -------------------------------------------------------------------------------- /docs/source/api/merlin.dag.rst: -------------------------------------------------------------------------------- 1 | Merlin DAG 2 | ------------------ 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | 7 | merlin.dag.Operator 8 | 9 | merlin.dag.Graph 10 | merlin.dag.Node 11 | merlin.dag.ColumnSelector 12 | -------------------------------------------------------------------------------- /docs/source/api/merlin.io.rst: -------------------------------------------------------------------------------- 1 | Merlin IO 2 | ------------------ 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | 7 | merlin.io.Dataset 8 | -------------------------------------------------------------------------------- /docs/source/api/merlin.schema.rst: -------------------------------------------------------------------------------- 1 | Merlin Schema 2 | ------------------ 3 | 4 | .. autosummary:: 5 | :toctree: generated 6 | 7 | merlin.schema.Schema 8 | merlin.schema.ColumnSchema 9 | merlin.schema.Tags 10 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Merlin Core 2 | =========== 3 | 4 | The Merlin Core library provides the core utilities and classes for the NVIDIA Merlin project. 5 | 6 | To learn more, start with the `Introduction `_. 
7 | 8 | Related Resources 9 | ----------------- 10 | 11 | Merlin Core GitHub Repository 12 | ``_ 13 | 14 | About Merlin 15 | Merlin is the overarching project that brings together the Merlin projects. 16 | See the `documentation `_ 17 | or the `repository `_ on GitHub. 18 | 19 | Developer website for Merlin 20 | More information about Merlin is available at our developer website: 21 | ``_. 22 | 23 | Index 24 | ----- 25 | 26 | * :ref:`genindex` 27 | -------------------------------------------------------------------------------- /docs/source/toc.yaml: -------------------------------------------------------------------------------- 1 | root: index 2 | subtrees: 3 | - caption: Contents 4 | entries: 5 | - file: README.md 6 | title: Introduction 7 | - file: api/index.rst 8 | title: API Documentation 9 | -------------------------------------------------------------------------------- /merlin/config/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2024, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | _DASK_QUERY_PLANNING_ENABLED = False 18 | try: 19 | # Disable query-planning and string conversion 20 | import dask 21 | 22 | dask.config.set( 23 | { 24 | "dataframe.query-planning": False, 25 | "dataframe.convert-string": False, 26 | } 27 | ) 28 | except ImportError: 29 | dask = None 30 | else: 31 | import sys 32 | 33 | import dask.dataframe as dd 34 | from packaging.version import parse 35 | 36 | if parse(dask.__version__) > parse("2024.6.0"): 37 | # For newer versions of dask, we can just check 38 | # the official DASK_EXPR_ENABLED constant 39 | _DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED 40 | else: 41 | # For older versions of dask, we must assume query 42 | # planning is enabled if dask_expr was imported 43 | # (because we can't know for sure) 44 | _DASK_QUERY_PLANNING_ENABLED = "dask_expr" in sys.modules 45 | 46 | 47 | def validate_dask_configs(): 48 | """Central check for problematic config options in Dask""" 49 | if _DASK_QUERY_PLANNING_ENABLED: 50 | raise NotImplementedError( 51 | "Merlin does not support the query-planning API in " 52 | "Dask Dataframe yet. Please make sure query-planning is " 53 | "disabled before dask.dataframe is imported.\n\ne.g." 54 | "dask.config.set({'dataframe.query-planning': False})" 55 | "\n\nOr set the environment variable: " 56 | "export DASK_DATAFRAME__QUERY_PLANNING=False" 57 | ) 58 | 59 | if dask is not None and dask.config.get("dataframe.convert-string"): 60 | raise NotImplementedError( 61 | "Merlin does not support automatic string conversion in " 62 | "Dask Dataframe yet. Please make sure this option is " 63 | "disabled.\n\ne.g." 
64 | "dask.config.set({'dataframe.convert-string': False})" 65 | "\n\nOr set the environment variable: " 66 | "export DASK_DATAFRAME__CONVERT_STRING=False" 67 | ) 68 | -------------------------------------------------------------------------------- /merlin/core/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from merlin.config import validate_dask_configs 18 | from merlin.core import _version 19 | 20 | __version__ = _version.get_versions()["version"] 21 | validate_dask_configs() 22 | -------------------------------------------------------------------------------- /merlin/core/compat/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# pylint: disable=unused-import
import warnings

from numba import cuda

from merlin.core.has_gpu import HAS_GPU  # noqa pylint: disable=unused-import

# Treat an unusable CUDA runtime the same as "no CUDA at all".
if not cuda.is_available():
    cuda = None

try:
    import psutil
except ImportError:
    psutil = None


def pynvml_mem_size(kind="total", index=0):
    """Get Memory Info for device.

    Parameters
    ----------
    kind : str, optional
        Either "free" or "total", by default "total"
    index : int, optional
        Device Index, by default 0

    Returns
    -------
    int
        Either free or total memory on device depending on the kind parameter.

    Raises
    ------
    ValueError
        When kind is not one of {"free", "total"}
    """
    import pynvml

    # Validate *before* touching NVML so we never init without shutting down.
    if kind not in ("free", "total"):
        raise ValueError(f"{kind} not a supported option for device_mem_size.")

    pynvml.nvmlInit()
    try:
        info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(index))
        size = int(info.free) if kind == "free" else int(info.total)
    finally:
        # Fix: nvmlShutdown was previously skipped when the handle or
        # memory-info lookup raised, leaking the NVML session.
        pynvml.nvmlShutdown()
    return size


def device_mem_size(kind="total", cpu=False):
    """Get Memory Info for either CPU or GPU.

    Parameters
    ----------
    kind : str, optional
        Either "total" or "free", by default "total"
    cpu : bool, optional
        Specifies whether to check memory for CPU or GPU, by default False

    Returns
    -------
    int
        Free or total memory on device

    Raises
    ------
    ValueError
        When kind is provided with an unsupported value.
    """
    # Use psutil (if available) for cpu mode; fall back to cpu mode
    # automatically when no usable CUDA runtime was detected at import time.
    cpu = cpu or not cuda
    if cpu and psutil:
        if kind == "total":
            return psutil.virtual_memory().total
        elif kind == "free":
            return psutil.virtual_memory().free
    elif cpu:
        warnings.warn("Please install psutil for full cpu=True support.")
        # Assume 1GB of memory
        return int(1e9)

    if kind not in ["free", "total"]:
        raise ValueError(f"{kind} not a supported option for device_mem_size.")
    try:
        # get_memory_info() returns (free, total) for the current context.
        if kind == "free":
            return int(cuda.current_context().get_memory_info()[0])
        else:
            return int(cuda.current_context().get_memory_info()[1])
    except NotImplementedError:
        if kind == "free":
            # Not using NVML "free" memory, because it will not include RMM-managed memory
            warnings.warn("get_memory_info is not supported. Using total device memory from NVML.")
        size = pynvml_mem_size(kind="total", index=0)
        return size


try:
    import numpy
except ImportError:
    numpy = None

try:
    import pandas
except ImportError:
    pandas = None

if HAS_GPU:
    try:
        import cupy
    except ImportError:
        cupy = None

    try:
        import cudf
    except ImportError:
        cudf = None

    try:
        import dask_cudf
    except ImportError:
        dask_cudf = None
else:
    # Without a GPU available none of these packages should be used
    cupy = None
    cudf = None
    dask_cudf = None
# pylint: disable=unused-import
import os
import warnings

from packaging import version

from merlin.core.compat import HAS_GPU, device_mem_size

try:
    import tensorflow

    def configure_tensorflow(memory_allocation=None, device=None):
        """Utility to help configure tensorflow to not use 100% of gpu memory as buffer

        Parameters
        ----------
        memory_allocation : float or int, optional
            Fraction of device memory (values < 1) or an absolute budget in MB.
            Defaults to the ``TF_MEMORY_ALLOCATION`` environment variable, or 0.5.
        device : optional
            NOTE(review): currently unused in this body — confirm whether
            per-device configuration was intended.

        Returns
        -------
        callable
            A ``from_dlpack`` function appropriate for the installed TF version.
        """
        tf = tensorflow
        # Total device memory in MB (falls back to host memory without a GPU).
        total_gpu_mem_mb = device_mem_size(kind="total", cpu=(not HAS_GPU)) / (1024**2)

        if memory_allocation is None:
            memory_allocation = os.environ.get("TF_MEMORY_ALLOCATION", 0.5)

        # Values below 1 are interpreted as a fraction of total memory.
        if float(memory_allocation) < 1:
            memory_allocation = total_gpu_mem_mb * float(memory_allocation)
        memory_allocation = int(memory_allocation)
        assert memory_allocation < total_gpu_mem_mb

        if HAS_GPU:
            tf_devices = tf.config.list_physical_devices("GPU")

            if len(tf_devices) == 0:
                raise ImportError("TensorFlow is not configured for GPU")

            for tf_device in tf_devices:
                try:
                    # Cap TF's memory pool so it does not claim the whole device.
                    tf.config.set_logical_device_configuration(
                        tf_device,
                        [tf.config.LogicalDeviceConfiguration(memory_limit=memory_allocation)],
                    )
                except RuntimeError:
                    warnings.warn(
                        "TensorFlow runtime already initialized, may not be enough memory for cudf"
                    )
                    # Retry through the older experimental virtual-device API.
                    try:
                        tf.config.experimental.set_virtual_device_configuration(
                            tf_device,
                            [
                                tf.config.experimental.VirtualDeviceConfiguration(
                                    memory_limit=memory_allocation
                                )
                            ],
                        )
                    except RuntimeError as e:
                        # Virtual devices must be set before GPUs have been initialized
                        warnings.warn(str(e))

        # versions using TF earlier than 2.3.0 need to use extension
        # library for dlpack support to avoid memory leak issue
        __TF_DLPACK_STABLE_VERSION = "2.3.0"
        if version.parse(tf.__version__) < version.parse(__TF_DLPACK_STABLE_VERSION):
            try:
                from tfdlpack import from_dlpack
            except ModuleNotFoundError as e:
                message = (
                    "If using TensorFlow < 2.3.0, you must install tfdlpack-gpu extension library"
                )
                raise ModuleNotFoundError(message) from e

        else:
            from tensorflow.experimental.dlpack import from_dlpack

        return from_dlpack

    # Configure TF eagerly at import time so later TF use sees the memory cap.
    configure_tensorflow()

    from tensorflow.python.framework import ops as tf_ops
except ImportError:
    # TensorFlow is not installed; expose None sentinels for feature checks.
    tensorflow = None
    tf_ops = None
# pylint: disable=unused-import
import os

from dask.distributed.diagnostics import nvml


def _get_gpu_count():
    """Get Number of GPU devices accounting for CUDA_VISIBLE_DEVICES environment variable"""
    # Using the `dask.distributed.diagnostics.nvml.device_get_count`
    # helper function from dask to check device counts with NVML
    # since this handles some complexity of checking NVML state for us.

    # Note: We can't use `numba.cuda.gpus`, since this has some side effects
    # that are incompatible with Dask-CUDA. If CUDA runtime functions are
    # called before Dask-CUDA can spawn worker processes
    # then Dask-CUDA will not work correctly (raises an exception)
    nvml_device_count = nvml.device_get_count()
    if nvml_device_count == 0:
        return 0

    # Single `os.environ.get` lookup replaces the previous try/except
    # KeyError pattern (same behavior, one lookup):
    #   unset  -> trust the NVML device count
    #   empty  -> all devices masked, report zero
    #   "a,b"  -> count the comma-separated visible devices
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_visible_devices is None:
        return nvml_device_count
    if cuda_visible_devices:
        return len(cuda_visible_devices.split(","))
    return 0


# Module-level flag consumed throughout Merlin to branch CPU vs GPU code paths.
HAS_GPU = _get_gpu_count() > 0
15 | # 16 | # flake8: noqa 17 | 18 | from merlin.config import validate_dask_configs 19 | 20 | validate_dask_configs() 21 | 22 | from merlin.dag.graph import Graph 23 | from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes 24 | from merlin.dag.operator import DataFormats, Operator, Supports 25 | from merlin.dag.selector import ColumnSelector 26 | from merlin.dag.utils import group_values_offsets, ungroup_values_offsets 27 | 28 | BaseOperator = Operator 29 | -------------------------------------------------------------------------------- /merlin/dag/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class AddMetadata(Operator):
    """
    Operator that attaches user-defined tags and properties to the
    column schemas that pass through it.

    Parameters
    ----------
    tags : list, optional
        Tags to add to the output column schemas.
    properties : dict, optional
        Properties to add to the output column schemas.
    """

    def __init__(self, tags=None, properties=None):
        super().__init__()
        self.tags = tags or []
        self.properties = properties or {}

    @property
    def output_tags(self):
        # Tags applied to every output column's schema
        return self.tags

    @property
    def output_properties(self):
        # Properties applied to every output column's schema
        return self.properties


class AddTags(AddMetadata):
    """Convenience operator that only adds tags (no properties)."""

    def __init__(self, tags=None):
        super().__init__(tags=tags)


class AddProperties(AddMetadata):
    """Convenience operator that only adds properties (no tags)."""

    def __init__(self, properties=None):
        super().__init__(properties=properties)


# Wrappers for common features.
#
# Fix: these wrappers previously accepted a `tags` argument and silently
# discarded it. Caller-supplied tags are now appended to the preset tags,
# which is backward-compatible for callers that don't pass `tags`.
class TagAsUserID(AddTags):
    """Tag column(s) as the user ID."""

    def __init__(self, tags=None):
        super().__init__(tags=[Tags.ID, Tags.USER] + list(tags or []))


class TagAsItemID(AddTags):
    """Tag column(s) as the item ID."""

    def __init__(self, tags=None):
        super().__init__(tags=[Tags.ID, Tags.ITEM] + list(tags or []))


class TagAsUserFeatures(AddTags):
    """Tag column(s) as user features."""

    def __init__(self, tags=None):
        super().__init__(tags=[Tags.USER] + list(tags or []))


class TagAsItemFeatures(AddTags):
    """Tag column(s) as item features."""

    def __init__(self, tags=None):
        super().__init__(tags=[Tags.ITEM] + list(tags or []))
class ConcatColumns(Operator):
    """
    Implementation of the `+` operator used when constructing graphs.

    The interesting work happens at schema-computation time; the runtime
    `transform` is a plain column selection.
    """

    def __init__(self, label=None):
        self._label = label or self.__class__.__name__
        super().__init__()

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: ColumnSelector = None,
        dependencies_selector: ColumnSelector = None,
    ) -> ColumnSelector:
        """
        Combine the selectors of the nodes being added.

        Parameters
        ----------
        input_schema : Schema
            Combined schema of the columns coming from upstream nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)
        parents_selector : ColumnSelector
            Combined column selectors of parent nodes
        dependencies_selector : ColumnSelector
            Combined column selectors of dependency nodes

        Returns
        -------
        ColumnSelector
            Combined column selectors of parent and dependency nodes
        """
        combined_upstream = parents_selector + dependencies_selector
        if combined_upstream.subgroups:
            # Upstream selectors carry grouping structure; let the base
            # implementation resolve the combined selector.
            return super().compute_selector(input_schema, combined_upstream)
        # No grouping structure: select every column of the input schema.
        return ColumnSelector(input_schema.column_names)

    def compute_input_schema(
        self,
        root_schema: Schema,
        parents_schema: Schema,
        deps_schema: Schema,
        selector: ColumnSelector,
    ) -> Schema:
        """
        Combine the schemas of the nodes being added.

        Parameters
        ----------
        root_schema : Schema
            Schema of the columns from the input dataset
        parents_schema : Schema
            Schema of the columns from the parent nodes
        deps_schema : Schema
            Schema of the columns from the dependency nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)

        Returns
        -------
        Schema
            Combined schema of columns from parents and dependencies
        """
        return parents_schema + deps_schema

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """Return the selected columns from the input data.

        The main functionality of this operator is computing schemas for `+`
        nodes in the graph, so very little happens at transform time.

        Parameters
        ----------
        col_selector : ColumnSelector
            The columns to apply this operator to
        transformable : Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        Transformable
            The selected columns of the input data
        """
        return super()._get_columns(transformable, col_selector)

    @property
    def label(self) -> str:
        """Display name of this operator."""
        return self._label
dependencies_selector 17 | new_selector = ColumnSelector(subgroups=upstream_selector) 18 | selector = super().compute_selector( 19 | input_schema, 20 | new_selector, 21 | ) 22 | return selector 23 | -------------------------------------------------------------------------------- /merlin/dag/ops/rename.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class Rename(Operator):
    """This operation renames columns by one of several methods:

    - using a user defined lambda function to transform column names
    - appending a postfix string to every column name
    - renaming a single column to a single fixed string

    Example usage::

        # Rename columns after LogOp
        cont_features = cont_names >> nvt.ops.LogOp() >> Rename(postfix='_log')
        processor = nvt.Workflow(cont_features)

    Parameters
    ----------
    f : callable, optional
        Function that takes a column name and returns a new column name
    postfix : str, optional
        If set each column name in the output will have this string appended to it
    name : str, optional
        If set, a single input column will be renamed to this string

    Raises
    ------
    ValueError
        If none of ``f``, ``postfix``, or ``name`` is provided.
    """

    def __init__(self, f=None, postfix=None, name=None):
        # Fix: compare against None rather than relying on truthiness so that
        # falsy-but-provided values (e.g. ``postfix=""``) are treated as given.
        if f is None and postfix is None and name is None:
            raise ValueError("must specify name, f, or postfix, for Rename op")

        self.f = f
        self.postfix = postfix
        self.name = name
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        transformable = transformable[col_selector.names]
        # column_mapping preserves insertion order, so the new names line up
        # with the selected columns.
        transformable.columns = list(  # type: ignore[assignment]
            self.column_mapping(col_selector).keys()
        )
        return transformable

    transform.__doc__ = Operator.transform.__doc__

    def column_mapping(self, col_selector):
        """Map each new column name to the single input column it renames.

        Fix: use ``is not None`` checks instead of truthiness. Previously
        ``Rename(postfix="")`` passed ``__init__`` validation but fell through
        every branch here and raised RuntimeError.
        """
        column_mapping = {}
        for col_name in col_selector.names:
            if self.f is not None:
                new_col_name = self.f(col_name)
            elif self.postfix is not None:
                new_col_name = col_name + self.postfix
            elif self.name is not None:
                if len(col_selector.names) == 1:
                    new_col_name = self.name
                else:
                    raise RuntimeError("Single column name provided for renaming multiple columns")
            else:
                raise RuntimeError(
                    "The Rename op requires one of f, postfix, or name to be provided"
                )

            column_mapping[new_col_name] = [col_name]

        return column_mapping
31 | """ 32 | 33 | def __init__(self, selector=None): 34 | self.selector = selector 35 | super().__init__() 36 | 37 | def transform( 38 | self, col_selector: ColumnSelector, transformable: Transformable 39 | ) -> Transformable: 40 | """Simply returns the selected output columns from the input dataframe 41 | 42 | The main functionality of this operator has to do with computing the schemas 43 | for selection nodes in the Workflow graph, so very little has to happen in the 44 | `transform` method. 45 | 46 | Parameters 47 | ----------- 48 | columns: list of str or list of list of str 49 | The columns to apply this operator to 50 | transformable: Transformable 51 | A pandas or cudf dataframe that this operator will work on 52 | 53 | Returns 54 | ------- 55 | DataFrame 56 | Returns a transformed dataframe for this operator 57 | """ 58 | selector = col_selector or self.selector 59 | return super()._get_columns(transformable, selector) 60 | 61 | def compute_input_schema( 62 | self, 63 | root_schema: Schema, 64 | parents_schema: Schema, 65 | deps_schema: Schema, 66 | selector: ColumnSelector, 67 | ) -> Schema: 68 | """ 69 | Return the schemas of columns 70 | 71 | Parameters 72 | ---------- 73 | root_schema : Schema 74 | Schema of the columns from the input dataset 75 | parents_schema : Schema 76 | Schema of the columns from the parent nodes 77 | deps_schema : Schema 78 | Schema of the columns from the dependency nodes 79 | selector : ColumnSelector 80 | Existing column selector for this node in the graph (often None) 81 | 82 | Returns 83 | ------- 84 | Schema 85 | Schema of selected columns from input, parents, and dependencies 86 | """ 87 | upstream_schema = root_schema + parents_schema + deps_schema 88 | return upstream_schema.select(self.selector) 89 | 90 | def compute_output_schema( 91 | self, 92 | input_schema: Schema, 93 | col_selector: ColumnSelector, 94 | prev_output_schema: Schema = None, 95 | ) -> Schema: 96 | """Given a set of schemas and a column selector for 
the input columns, 97 | returns a set of schemas for the transformed columns this operator will produce 98 | 99 | Parameters 100 | ----------- 101 | input_schema: Schema 102 | The schemas of the columns to apply this operator to 103 | col_selector: ColumnSelector 104 | The column selector to apply to the input schema 105 | Returns 106 | ------- 107 | Schema 108 | The schemas of the columns produced by this operator 109 | """ 110 | selector = col_selector or self.selector 111 | if selector.all: 112 | selector = ColumnSelector(input_schema.column_names) 113 | return super().compute_output_schema(input_schema, selector, prev_output_schema) 114 | -------------------------------------------------------------------------------- /merlin/dag/ops/stat_operator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | from typing import Any 17 | 18 | import dask.dataframe as dd 19 | 20 | from merlin.dag.operator import Operator 21 | from merlin.dag.selector import ColumnSelector 22 | 23 | 24 | class StatOperator(Operator): 25 | """ 26 | Base class for statistical operator classes. This adds a 'fit' and 'finalize' method 27 | on top of the Operator class. 
28 | """ 29 | 30 | fitted = False 31 | 32 | def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame) -> Any: 33 | """Calculate statistics for this operator, and return a dask future 34 | to these statistics, which will be computed by the workflow.""" 35 | 36 | raise NotImplementedError( 37 | """The dask operations needed to return a dictionary of uncomputed statistics.""" 38 | ) 39 | 40 | def fit_finalize(self, dask_stats): 41 | """Finalize statistics calculation - the workflow calls this function with 42 | the computed statistics from the 'fit' object'""" 43 | 44 | raise NotImplementedError( 45 | """Follow-up operations to convert dask statistics in to member variables""" 46 | ) 47 | 48 | def clear(self): 49 | """zero and reinitialize all relevant statistical properties""" 50 | raise NotImplementedError("clear isn't implemented for this op!") 51 | 52 | def set_storage_path(self, new_path, copy=False): 53 | """Certain stat operators need external storage - for instance Categorify writes out 54 | parquet files containing the categorical mapping. When we save the operator, we 55 | also want to save these files as part of the bundle. Implementing this method 56 | lets statoperators bundle their dependent files into the new path that we're writing 57 | out (note that this could happen after the operator is created) 58 | """ 59 | -------------------------------------------------------------------------------- /merlin/dag/ops/subset_columns.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class SubsetColumns(Operator):
    """
    Implementation of the `[]` operator used when constructing graphs.
    """

    def __init__(self, label=None):
        self._label = label or self.__class__.__name__
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """Return the selected columns from the input data.

        The main functionality of this operator is computing schemas for
        subset nodes in the graph, so very little happens at transform time.

        Parameters
        ----------
        col_selector : ColumnSelector
            The columns to apply this operator to
        transformable : Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        Transformable
            The selected columns of the input data
        """
        return super()._get_columns(transformable, col_selector)

    @property
    def label(self) -> str:
        """Display name of this operator."""
        return self._label
27 | """ 28 | 29 | def __init__(self, selector=None): 30 | self.selector = selector 31 | super().__init__() 32 | 33 | def compute_selector( 34 | self, 35 | input_schema: Schema, 36 | selector: ColumnSelector, 37 | parents_selector: ColumnSelector = None, 38 | dependencies_selector: ColumnSelector = None, 39 | ) -> ColumnSelector: 40 | """ 41 | Creates selector of all columns from the input schema 42 | 43 | Parameters 44 | ---------- 45 | input_schema : Schema 46 | Combined schema of the columns coming from upstream nodes 47 | selector : ColumnSelector 48 | Existing column selector for this node in the graph (often None) 49 | parents_selector : ColumnSelector 50 | Combined column selectors of parent nodes 51 | dependencies_selector : ColumnSelector 52 | Combined column selectors of dependency nodes 53 | 54 | Returns 55 | ------- 56 | ColumnSelector 57 | Selector of all columns from the input schema 58 | """ 59 | return super().compute_selector( 60 | input_schema, 61 | ColumnSelector("*"), 62 | ) 63 | 64 | def compute_input_schema( 65 | self, 66 | root_schema: Schema, 67 | parents_schema: Schema, 68 | deps_schema: Schema, 69 | selector: ColumnSelector, 70 | ) -> Schema: 71 | """ 72 | Return remaining schemas of columns after removing dependencies 73 | 74 | Parameters 75 | ---------- 76 | root_schema : Schema 77 | Schema of the columns from the input dataset 78 | parents_schema : Schema 79 | Schema of the columns from the parent nodes 80 | deps_schema : Schema 81 | Schema of the columns from the dependency nodes 82 | selector : ColumnSelector 83 | Existing column selector for this node in the graph (often None) 84 | 85 | Returns 86 | ------- 87 | Schema 88 | Remaining schema of columns from parents after removing dependencies 89 | """ 90 | result = None 91 | if deps_schema.column_schemas: 92 | result = parents_schema - deps_schema 93 | else: 94 | subtraction_selector = self.selector or selector 95 | result = parents_schema.excluding(subtraction_selector) 96 | return 
result 97 | 98 | def transform( 99 | self, col_selector: ColumnSelector, transformable: Transformable 100 | ) -> Transformable: 101 | """Simply returns the selected output columns from the input dataframe 102 | 103 | The main functionality of this operator has to do with computing the schemas 104 | for `-` nodes in the Workflow graph, so very little has to happen in the 105 | `transform` method. 106 | 107 | Parameters 108 | ----------- 109 | columns: list of str or list of list of str 110 | The columns to apply this operator to 111 | transformable: Transformable 112 | A pandas or cudf dataframe that this operator will work on 113 | 114 | Returns 115 | ------- 116 | DataFrame 117 | Returns a transformed dataframe for this operator 118 | """ 119 | selector = self.selector or col_selector 120 | return super()._get_columns(transformable, selector) 121 | -------------------------------------------------------------------------------- /merlin/dag/runtime.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | from merlin.core.protocols import Transformable 17 | from merlin.dag.executors import LocalExecutor 18 | from merlin.dag.graph import Graph 19 | 20 | 21 | class Runtime: 22 | """A Graph Runtime. 23 | 24 | This class can be used as a base class for custom runtimes. 
25 | """ 26 | 27 | def __init__(self, executor=None): 28 | """Construct a Runtime. 29 | 30 | Parameters 31 | ---------- 32 | executor : Executor, optional 33 | The Graph Executor to use to use for the transform, by default None 34 | """ 35 | self.executor = executor or LocalExecutor() 36 | self.op_table = {} 37 | 38 | def transform(self, graph: Graph, transformable: Transformable): 39 | """Run the graph with the input data. 40 | 41 | Parameters 42 | ---------- 43 | graph : Graph 44 | Graph of nodes container operator chains for data manipulation. 45 | transformable : Transformable 46 | Input data to transform in graph. 47 | 48 | Returns 49 | ------- 50 | Transformable 51 | Input data after it has been transformed via graph. 52 | """ 53 | return self.executor.transform(transformable, [graph.output_node]) 54 | 55 | def export(self): 56 | """Optional method. 57 | Implemented for runtimes that require an exported artifact to transform. 58 | """ 59 | raise NotImplementedError 60 | -------------------------------------------------------------------------------- /merlin/dag/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | 18 | def ungroup_values_offsets(grouped_cols: dict) -> dict: 19 | """ 20 | Flatten columns with values/offsets tuples in a dictionary to separate keys 21 | 22 | Parameters 23 | ---------- 24 | grouped_cols : dict 25 | A dictionary of column arrays including values/offsets tuples 26 | 27 | Returns 28 | ------- 29 | dict 30 | A dictionary of column arrays with separate keys for values and offsets 31 | """ 32 | flat_cols = {} 33 | 34 | for key, value in grouped_cols.items(): 35 | if isinstance(value, tuple): 36 | flat_cols[f"{key}__values"] = value[0] 37 | flat_cols[f"{key}__offsets"] = value[1] 38 | else: 39 | flat_cols[key] = value 40 | 41 | return flat_cols 42 | 43 | 44 | def group_values_offsets(flat_cols: dict) -> dict: 45 | """ 46 | Convert separate values/offsets keys for columns into tuples w/ a single key 47 | 48 | Parameters 49 | ---------- 50 | flat_cols : dict 51 | A dictionary of column arrays with separate keys for values and offsets 52 | 53 | Returns 54 | ------- 55 | dict 56 | A dictionary of column arrays including values/offsets tuples 57 | """ 58 | grouped_cols = {} 59 | 60 | for key, value in flat_cols.items(): 61 | if key.endswith("__values"): 62 | col_name = key.replace("__values", "") 63 | grouped_cols[col_name] = (flat_cols[key], flat_cols[f"{col_name}__offsets"]) 64 | elif key.endswith("__offsets"): 65 | pass 66 | else: 67 | grouped_cols[key] = value 68 | 69 | return grouped_cols 70 | -------------------------------------------------------------------------------- /merlin/dtypes/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def dtype(external_dtype):
    """Translate an external framework dtype into a Merlin `DType`.

    Tries the registered Merlin dtype mappings first; when those produce no
    match (or only the `unknown` dtype), falls back to converting through
    numpy and matching that instead.
    """
    # A `None` dtype has no universal translation across frameworks, so raise
    # and help the downstream developer handle that case explicitly.
    if external_dtype is None:
        raise TypeError(
            "Merlin doesn't provide a default dtype mapping for `None`. "
            "This differs from the Numpy behavior you may be expecting, "
            "which treats `None` as an alias for `np.float64`. If you're "
            "expecting this dtype to be non-`None`, there may be an issue "
            "in upstream code. If you'd like to allow this dtype to be `None`, "
            "you can use a `try/except` to catch this error."
        )

    # Already a Merlin dtype — nothing to translate.
    if isinstance(external_dtype, DType):
        return external_dtype

    first_exc = None
    merlin_dtype = None

    try:
        merlin_dtype = _dtype_registry.to_merlin(external_dtype)
    except (TypeError, KeyError, AttributeError) as exc:
        first_exc = exc

    if first_exc is None and merlin_dtype != aliases.unknown:
        return merlin_dtype

    # Fall back on a numpy conversion.
    try:
        return _dtype_registry.to_merlin_via_numpy(external_dtype)
    except TypeError as numpy_exc:
        # Re-raise the original exception when there is one: it describes the
        # external dtype that actually caused the problem, not the interim
        # numpy dtype it was converted to.
        if first_exc:
            raise first_exc from numpy_exc
        raise
#

from merlin.dtypes.base import DType, ElementType, ElementUnit

# Canonical Merlin dtype instances. The mapping modules under
# `merlin.dtypes.mappings` translate external framework dtypes to these
# shared objects. Positional args are: name, element type, bit-width,
# and (for datetimes) the time-resolution unit.

# Unsigned Integer
uint8 = DType("uint8", ElementType.UInt, 8)
uint16 = DType("uint16", ElementType.UInt, 16)
uint32 = DType("uint32", ElementType.UInt, 32)
uint64 = DType("uint64", ElementType.UInt, 64)

# Signed Integer
int8 = DType("int8", ElementType.Int, 8, signed=True)
int16 = DType("int16", ElementType.Int, 16, signed=True)
int32 = DType("int32", ElementType.Int, 32, signed=True)
int64 = DType("int64", ElementType.Int, 64, signed=True)

# Float (always signed)
float16 = DType("float16", ElementType.Float, 16, signed=True)
float32 = DType("float32", ElementType.Float, 32, signed=True)
float64 = DType("float64", ElementType.Float, 64, signed=True)

# Date/Time — one alias per numpy-style resolution unit, plus a unit-less one
datetime64 = DType("datetime64", ElementType.DateTime, 64)
datetime64Y = DType("datetime64[Y]", ElementType.DateTime, 64, ElementUnit.Year)
datetime64M = DType("datetime64[M]", ElementType.DateTime, 64, ElementUnit.Month)
datetime64D = DType("datetime64[D]", ElementType.DateTime, 64, ElementUnit.Day)
datetime64h = DType("datetime64[h]", ElementType.DateTime, 64, ElementUnit.Hour)
datetime64m = DType("datetime64[m]", ElementType.DateTime, 64, ElementUnit.Minute)
datetime64s = DType("datetime64[s]", ElementType.DateTime, 64, ElementUnit.Second)
datetime64ms = DType("datetime64[ms]", ElementType.DateTime, 64, ElementUnit.Millisecond)
datetime64us = DType("datetime64[us]", ElementType.DateTime, 64, ElementUnit.Microsecond)
datetime64ns = DType("datetime64[ns]", ElementType.DateTime, 64, ElementUnit.Nanosecond)

# Miscellaneous — no fixed bit-width for these
string = DType("str", ElementType.String)
boolean = DType("bool", ElementType.Bool)
object_ = DType("object", ElementType.Object)
struct = DType("struct", ElementType.Struct)
unknown = DType("unknown", ElementType.Unknown)
-------------------------------------------------------------------------------- /merlin/dtypes/mappings/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # flake8: noqa 18 | from merlin.dtypes.mappings import cudf, merlin, numpy, pandas, python, tf, torch, triton 19 | -------------------------------------------------------------------------------- /merlin/dtypes/mappings/cudf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | import numpy as np 17 | 18 | import merlin.dtypes.aliases as mn 19 | from merlin.core.compat import cudf 20 | from merlin.core.dispatch import is_string_dtype 21 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor 22 | from merlin.dtypes.registry import _dtype_registry 23 | 24 | 25 | def cudf_translator(raw_dtype) -> np.dtype: 26 | """ 27 | Translate cudf dtypes to Numpy dtypes 28 | 29 | Parameters 30 | ---------- 31 | raw_dtype : cudf dtype 32 | The dtype to be translated 33 | 34 | Returns 35 | ------- 36 | np.dtype 37 | The result of translating raw_dtype to Numpy 38 | """ 39 | category_type = raw_dtype._categories.dtype 40 | if is_string_dtype(category_type): 41 | return np.dtype("str") 42 | else: 43 | return category_type 44 | 45 | 46 | if cudf: 47 | try: 48 | # We only want to register this mapping if cudf is available, even though 49 | # the mapping itself doesn't use cudf (yet?) 50 | 51 | cudf_dtypes = DTypeMapping( 52 | { 53 | mn.struct: [cudf.StructDtype], 54 | }, 55 | translator=NumpyPreprocessor("cudf", cudf_translator, attrs=["_categories"]), 56 | ) 57 | _dtype_registry.register("cudf", cudf_dtypes) 58 | except ImportError as exc: 59 | from warnings import warn 60 | 61 | warn(f"cuDF dtype mappings did not load successfully due to an error: {exc.msg}") 62 | -------------------------------------------------------------------------------- /merlin/dtypes/mappings/merlin.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import merlin.dtypes.aliases as mn
from merlin.dtypes.registry import _dtype_registry

# Maps each Merlin dtype onto the string name(s) that should resolve to it,
# so `merlin.dtypes.dtype("int64")` etc. work with plain strings.
merlin_dtypes = {
    # Unsigned Integer
    mn.uint8: ["uint8"],
    mn.uint16: ["uint16"],
    mn.uint32: ["uint32"],
    mn.uint64: ["uint64"],
    # Signed integer
    mn.int8: ["int8"],
    mn.int16: ["int16"],
    mn.int32: ["int32"],
    mn.int64: ["int64"],
    # Floating Point
    mn.float16: ["float16"],
    mn.float32: ["float32"],
    mn.float64: ["float64"],
    # Date/Time
    mn.datetime64: ["datetime64"],
    mn.datetime64Y: ["datetime64[Y]"],
    mn.datetime64M: ["datetime64[M]"],
    mn.datetime64D: ["datetime64[D]"],
    mn.datetime64h: ["datetime64[h]"],
    mn.datetime64m: ["datetime64[m]"],
    mn.datetime64s: ["datetime64[s]"],
    mn.datetime64ms: ["datetime64[ms]"],
    mn.datetime64us: ["datetime64[us]"],
    mn.datetime64ns: ["datetime64[ns]"],
    # Miscellaneous — some dtypes accept more than one spelling
    mn.string: ["str", "string"],
    mn.object_: ["object"],
    mn.struct: ["struct"],
    mn.boolean: ["bool", "boolean"],
}
# Registering a plain dict: the registry wraps it in a DTypeMapping
_dtype_registry.register("merlin", merlin_dtypes)
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/numpy.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np

import merlin.dtypes.aliases as mn
from merlin.dtypes.registry import _dtype_registry

# Maps each Merlin dtype onto both the `np.dtype(...)` instance and (where
# one exists) the numpy scalar type, so either spelling resolves correctly.
numpy_dtypes = {
    # Unsigned Integer
    mn.uint8: [np.dtype("uint8"), np.uint8],
    mn.uint16: [np.dtype("uint16"), np.uint16],
    mn.uint32: [np.dtype("uint32"), np.uint32],
    mn.uint64: [np.dtype("uint64"), np.uint64],
    # Signed integer
    mn.int8: [np.dtype("int8"), np.int8],
    mn.int16: [np.dtype("int16"), np.int16],
    mn.int32: [np.dtype("int32"), np.int32],
    mn.int64: [np.dtype("int64"), np.int64],
    # Floating Point
    mn.float16: [np.dtype("float16"), np.float16],
    mn.float32: [np.dtype("float32"), np.float32],
    mn.float64: [np.dtype("float64"), np.float64],
    # Date/Time — the unit-specific variants have no scalar type of their own,
    # so only the dtype instances are listed
    mn.datetime64: [np.dtype("datetime64"), np.datetime64],
    mn.datetime64Y: [np.dtype("datetime64[Y]")],
    mn.datetime64M: [np.dtype("datetime64[M]")],
    mn.datetime64D: [np.dtype("datetime64[D]")],
    mn.datetime64h: [np.dtype("datetime64[h]")],
    mn.datetime64m: [np.dtype("datetime64[m]")],
    mn.datetime64s: [np.dtype("datetime64[s]")],
    mn.datetime64ms: [np.dtype("datetime64[ms]")],
    mn.datetime64us: [np.dtype("datetime64[us]")],
    mn.datetime64ns: [np.dtype("datetime64[ns]")],
    # Miscellaneous — python builtins map through numpy here too
    mn.string: [np.dtype("str"), str],
    mn.object_: [np.dtype("O"), object],
    mn.boolean: [np.dtype("bool"), bool],
}
_dtype_registry.register("numpy", numpy_dtypes)
-------------------------------------------------------------------------------- /merlin/dtypes/mappings/pandas.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | from typing import Callable 17 | 18 | import merlin.dtypes.aliases as mn 19 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor 20 | from merlin.dtypes.registry import _dtype_registry 21 | 22 | try: 23 | import pandas as pd 24 | from pandas.core.dtypes.base import ExtensionDtype 25 | 26 | def _translate_to_numpy(raw_dtype): 27 | if isinstance(raw_dtype, ExtensionDtype) and isinstance(raw_dtype, Callable): 28 | raw_dtype = raw_dtype() 29 | 30 | return raw_dtype.numpy_dtype 31 | 32 | pandas_dtypes = DTypeMapping( 33 | { 34 | mn.string: [pd.StringDtype(), pd.StringDtype], 35 | mn.boolean: [pd.BooleanDtype(), pd.BooleanDtype], 36 | }, 37 | translator=NumpyPreprocessor("pandas", _translate_to_numpy, attrs=["numpy_dtype"]), 38 | ) 39 | _dtype_registry.register("pandas", pandas_dtypes) 40 | except ImportError as exc: 41 | from warnings import warn 42 | 43 | warn(f"Pandas dtype mappings did not load successfully due to an error: {exc.msg}") 44 | -------------------------------------------------------------------------------- /merlin/dtypes/mappings/python.py: 
--------------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import merlin.dtypes.aliases as mn
from merlin.dtypes.mapping import DTypeMapping

from merlin.dtypes.registry import _dtype_registry

# Maps Python builtin types onto Merlin dtypes. Unlike the other mapping
# modules the external values here are bare types rather than one-element
# lists — presumably DTypeMapping normalizes both forms; confirm in
# merlin/dtypes/mapping.py.
python_dtypes = DTypeMapping(
    {
        mn.boolean: bool,
        mn.int64: int,
        mn.float64: float,
        mn.string: str,
    }
)
_dtype_registry.register("python", python_dtypes)
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/tf.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | # 16 | import merlin.dtypes.aliases as mn 17 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor 18 | from merlin.dtypes.registry import _dtype_registry 19 | 20 | try: 21 | from tensorflow import dtypes as tf_dtypes 22 | 23 | tf_dtypes = DTypeMapping( 24 | { 25 | # Unsigned Integer 26 | mn.uint8: [tf_dtypes.uint8], 27 | mn.uint16: [tf_dtypes.uint16], 28 | mn.uint32: [tf_dtypes.uint32], 29 | mn.uint64: [tf_dtypes.uint64], 30 | # Signed integer 31 | mn.int8: [tf_dtypes.int8], 32 | mn.int16: [tf_dtypes.int16], 33 | mn.int32: [tf_dtypes.int32], 34 | mn.int64: [tf_dtypes.int64], 35 | # Floating Point 36 | mn.float16: [tf_dtypes.float16], 37 | mn.float32: [tf_dtypes.float32], 38 | mn.float64: [tf_dtypes.float64], 39 | # Miscellaneous 40 | mn.boolean: [tf_dtypes.bool], 41 | }, 42 | base_class=tf_dtypes.DType, 43 | translator=NumpyPreprocessor( 44 | "tf", lambda raw: raw.as_numpy_dtype, attrs=["as_numpy_dtype"] 45 | ), 46 | ) 47 | _dtype_registry.register("tf", tf_dtypes) 48 | _dtype_registry.register("tensorflow", tf_dtypes) 49 | except ImportError as exc: 50 | from warnings import warn 51 | 52 | warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}") 53 | -------------------------------------------------------------------------------- /merlin/dtypes/mappings/torch.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import merlin.dtypes.aliases as mn 17 | from merlin.dtypes.registry import _dtype_registry 18 | 19 | try: 20 | from torch import bool as bool_ 21 | from torch import float16, float32, float64, int8, int16, int32, int64, uint8 22 | 23 | torch_dtypes = { 24 | # Unsigned Integer 25 | mn.uint8: [uint8], 26 | # Signed integer 27 | mn.int8: [int8], 28 | mn.int16: [int16], 29 | mn.int32: [int32], 30 | mn.int64: [int64], 31 | # Floating Point 32 | mn.float16: [float16], 33 | mn.float32: [float32], 34 | mn.float64: [float64], 35 | # Miscellaneous 36 | mn.boolean: [bool_], 37 | } 38 | _dtype_registry.register("torch", torch_dtypes) 39 | _dtype_registry.register("pytorch", torch_dtypes) 40 | except ImportError as exc: 41 | from warnings import warn 42 | 43 | warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}") 44 | -------------------------------------------------------------------------------- /merlin/dtypes/mappings/triton.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
#
import merlin.dtypes.aliases as mn
from merlin.dtypes.registry import _dtype_registry

# Only define a Triton dtype mapping if `tritonclient` is available
try:
    import tritonclient.grpc.model_config_pb2 as model_config

    # The type constants on `model_config` are literally integers,
    # so this only works if we don't have any other dtypes that
    # either are or are equivalent to integers. We work around
    # this by checking base classes for dtypes that have them (e.g.
    # Tensorflow.)
    triton_dtypes = {
        # Unsigned Integer
        mn.uint8: [model_config.TYPE_UINT8],
        mn.uint16: [model_config.TYPE_UINT16],
        mn.uint32: [model_config.TYPE_UINT32],
        mn.uint64: [model_config.TYPE_UINT64],
        # Signed integer
        mn.int8: [model_config.TYPE_INT8],
        mn.int16: [model_config.TYPE_INT16],
        mn.int32: [model_config.TYPE_INT32],
        mn.int64: [model_config.TYPE_INT64],
        # Floating Point
        mn.float16: [model_config.TYPE_FP16],
        mn.float32: [
            model_config.TYPE_FP32,
        ],
        mn.float64: [model_config.TYPE_FP64],
        # Miscellaneous
        mn.string: [model_config.TYPE_STRING],
        mn.boolean: [model_config.TYPE_BOOL],
    }
    _dtype_registry.register("triton", triton_dtypes)
except ImportError as exc:
    from warnings import warn

    # NOTE(review): `exc.msg` can be None for some ImportErrors; presumably
    # acceptable for a warning string — confirm if warnings look odd
    warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")
--------------------------------------------------------------------------------
/merlin/dtypes/registry.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | from typing import Dict, Union 17 | 18 | import merlin.dtypes.aliases as mn 19 | from merlin.dtypes.mapping import DTypeMapping 20 | 21 | 22 | class DTypeMappingRegistry: 23 | """ 24 | A registry of mappings between Merlin dtypes and the dtypes of many external frameworks 25 | """ 26 | 27 | def __init__(self): 28 | self.mappings = {} 29 | 30 | def __iter__(self): 31 | return iter(self.mappings) 32 | 33 | def register(self, name: str, mapping: Union[Dict, DTypeMapping]): 34 | """ 35 | Register a mapping between Merlin and external dtypes by name 36 | 37 | Parameters 38 | ---------- 39 | name : str 40 | Name of the new mapping to register 41 | mapping : Union[Dict, DTypeMapping] 42 | Mapping between Merlin and external dtypes 43 | """ 44 | if not isinstance(mapping, DTypeMapping): 45 | mapping = DTypeMapping(mapping) 46 | 47 | self.mappings[name] = mapping 48 | 49 | def from_merlin(self, merlin_dtype, mapping_name): 50 | """ 51 | Map a Merlin dtype to an external dtype 52 | 53 | Parameters 54 | ---------- 55 | merlin_dtype : DType 56 | A Merlin dtype object 57 | mapping_name : str 58 | The name of the external framework mapping to apply 59 | 60 | Returns 61 | ------- 62 | Any 63 | An external framework dtype object 64 | 65 | Raises 66 | ------ 67 | TypeError 68 | If the Merlin dtype can't be mapped to an external dtype from the requested framework 69 | """ 70 | mapping = self.mappings[mapping_name] 71 | if mapping.matches_merlin(merlin_dtype): 72 | return mapping.to_merlin(merlin_dtype) 73 | 74 | return mn.unknown 
75 | 76 | def to_merlin(self, external_dtype): 77 | """ 78 | Map an external dtype to a Merlin dtype 79 | 80 | Parameters 81 | ---------- 82 | external_dtype : Any 83 | A dtype object from an external framework 84 | 85 | Returns 86 | ------- 87 | DType 88 | A Merlin DType object 89 | 90 | Raises 91 | ------ 92 | TypeError 93 | If the external dtype can't be mapped to a Merlin dtype 94 | """ 95 | for framework, mapping in self.mappings.items(): 96 | if mapping.matches_external(external_dtype): 97 | return mapping.to_merlin(external_dtype) 98 | 99 | return mn.unknown 100 | 101 | def to_merlin_via_numpy(self, external_dtype): 102 | """ 103 | Map an external dtype to a Merlin dtype by converting the external type to Numpy first 104 | 105 | This is sometimes useful for external framework dtypes that don't have a clear 106 | one-to-one mapping with a Merlin dtype, like cuDF's CategoricalDtype. We can often do 107 | some additional preprocessing on the external framework's dtype to determine the 108 | Numpy dtype of the elements, and then use that as an intermediary to find the 109 | corresponding Merlin dtype. 
110 | 111 | Parameters 112 | ---------- 113 | external_dtype : Any 114 | A dtype object from an external framework 115 | 116 | Returns 117 | ------- 118 | DType 119 | A Merlin DType object 120 | 121 | Raises 122 | ------ 123 | TypeError 124 | If the external dtype can't be mapped to a Merlin dtype 125 | """ 126 | numpy_dtype = None 127 | 128 | for mapping in self.mappings.values(): 129 | if mapping.translator and mapping.translator.matches(external_dtype): 130 | numpy_dtype = mapping.translator.to_numpy(external_dtype) 131 | break 132 | 133 | return self.to_merlin(numpy_dtype) 134 | 135 | 136 | _dtype_registry = DTypeMappingRegistry() 137 | -------------------------------------------------------------------------------- /merlin/io/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022-2024, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | # flake8: noqa 17 | 18 | from merlin.config import validate_dask_configs 19 | 20 | validate_dask_configs() 21 | 22 | from merlin.io import dataframe_iter, dataset, shuffle 23 | from merlin.io.dataframe_iter import DataFrameIter 24 | from merlin.io.dataset import MERLIN_METADATA_DIR_NAME, Dataset 25 | from merlin.io.shuffle import Shuffle, shuffle_df 26 | -------------------------------------------------------------------------------- /merlin/io/csv.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import functools 17 | 18 | import dask.dataframe as dd 19 | from dask.bytes import read_bytes 20 | from dask.utils import parse_bytes 21 | from fsspec.core import get_fs_token_paths 22 | from fsspec.utils import infer_compression 23 | 24 | from merlin.core.compat import dask_cudf 25 | from merlin.core.compat import numpy as np 26 | from merlin.io.dataset_engine import DatasetEngine 27 | 28 | 29 | class CSVDatasetEngine(DatasetEngine): 30 | """CSVDatasetEngine 31 | 32 | Thin wrapper around dask_cudf.read_csv. 
33 | """ 34 | 35 | def __init__(self, paths, part_size, storage_options=None, cpu=False, **kwargs): 36 | # pylint: disable=access-member-before-definition 37 | super().__init__(paths, part_size, cpu=cpu, storage_options=storage_options) 38 | self._meta = {} 39 | self.csv_kwargs = kwargs 40 | self.csv_kwargs["storage_options"] = storage_options 41 | 42 | # CSV reader needs a list of files 43 | # (Assume flat directory structure if this is a dir) 44 | if len(self.paths) == 1 and self.fs.isdir(self.paths[0]): 45 | self.paths = self.fs.glob(self.fs.sep.join([self.paths[0], "*"])) 46 | 47 | def to_ddf(self, columns=None, cpu=None): 48 | # Check if we are using cpu 49 | cpu = self.cpu if cpu is None else cpu 50 | if cpu: 51 | ddf = dd.read_csv(self.paths, blocksize=self.part_size, **self.csv_kwargs) 52 | else: 53 | ddf = dask_cudf.read_csv(self.paths, chunksize=self.part_size, **self.csv_kwargs) 54 | if columns: 55 | ddf = ddf[columns] 56 | return ddf 57 | 58 | @property # type: ignore 59 | @functools.lru_cache(1) 60 | def _file_partition_map(self): 61 | ind = 0 62 | _pp_map = {} 63 | for path, blocks in zip( 64 | *_byte_block_counts( 65 | self.paths, 66 | self.part_size, 67 | **self.csv_kwargs, 68 | ) 69 | ): 70 | _pp_map[path.split(self.fs.sep)[-1]] = np.arange(ind, ind + blocks) 71 | ind += blocks 72 | return _pp_map 73 | 74 | def to_cpu(self): 75 | self.cpu = True 76 | 77 | def to_gpu(self): 78 | self.cpu = False 79 | 80 | 81 | def _byte_block_counts( 82 | urlpath, 83 | blocksize, 84 | lineterminator=None, 85 | compression="infer", 86 | storage_options=None, 87 | **kwargs, 88 | ): 89 | """Return a list of paths and block counts. 
90 | 91 | Logic copied from dask.bytes.read_bytes 92 | """ 93 | 94 | if lineterminator is not None and len(lineterminator) == 1: 95 | kwargs["lineterminator"] = lineterminator 96 | else: 97 | lineterminator = "\n" 98 | 99 | if compression == "infer": 100 | paths = get_fs_token_paths(urlpath, mode="rb", storage_options=storage_options)[2] 101 | compression = infer_compression(paths[0]) 102 | 103 | if isinstance(blocksize, str): 104 | blocksize = parse_bytes(blocksize) 105 | if blocksize and compression: 106 | blocksize = None 107 | 108 | b_out = read_bytes( 109 | urlpath, 110 | delimiter=lineterminator.encode(), 111 | blocksize=blocksize, 112 | sample=False, 113 | compression=compression, 114 | include_path=True, 115 | **(storage_options or {}), 116 | ) 117 | _, values, paths = b_out 118 | 119 | if not isinstance(values[0], (tuple, list)): 120 | values = [values] 121 | 122 | return paths, [len(v) for v in values] 123 | -------------------------------------------------------------------------------- /merlin/io/dataframe_iter.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
#


class DataFrameIter:
    """Iterate the partitions of a dask DataFrame as in-memory frames.

    Parameters
    ----------
    ddf : dask DataFrame
        The collection whose partitions will be computed and yielded
    columns : list, optional
        If given, restrict each yielded partition to these columns
    indices : list, optional
        Partition indices to iterate; defaults to all partitions
    partition_lens : list, optional
        Known per-partition row counts (used to answer ``len`` cheaply)
    epochs : int, optional
        Number of times to repeat the full pass over the partitions
    """

    def __init__(self, ddf, columns=None, indices=None, partition_lens=None, epochs=1):
        self._ddf = ddf
        self.columns = columns
        self.partition_lens = partition_lens
        self.epochs = epochs
        # A non-list `indices` means "every partition"
        self.indices = indices if isinstance(indices, list) else range(ddf.npartitions)

    def __len__(self):
        if self.partition_lens:
            # Use metadata-based partition-size information if/when it is
            # available. Note that this metadata will not be correct if rows
            # were added or dropped after IO (within Ops).
            rows = sum(self.partition_lens[i] for i in self.indices)
        elif len(self.indices) < self._ddf.npartitions:
            # Subset of partitions: count only the selected ones
            rows = len(self._ddf.partitions[self.indices])
        else:
            rows = len(self._ddf)
        return rows * self.epochs

    def __iter__(self):
        for _ in range(self.epochs):
            for idx in self.indices:
                part = self._ddf.get_partition(idx)
                selected = part[self.columns] if self.columns else part
                # Synchronous scheduler: compute one partition at a time
                # in-process
                yield selected.compute(scheduler="synchronous")
                # Drop the reference before fetching the next partition
                part = None
15 | # 16 | 17 | from dask.utils import natural_sort_key 18 | from fsspec.core import get_fs_token_paths 19 | 20 | 21 | class DatasetEngine: 22 | """Base class for Dask-powered IO engines. Engines must provide a ``to_ddf`` method.""" 23 | 24 | def __init__(self, paths, part_size, cpu=False, storage_options=None): 25 | paths = sorted(paths, key=natural_sort_key) 26 | self.paths = paths 27 | self.part_size = part_size 28 | self.storage_options = storage_options 29 | fs, fs_token, paths2 = get_fs_token_paths( 30 | paths, mode="rb", storage_options=self.storage_options 31 | ) 32 | self.stripped_paths = paths2 33 | self.fs = fs 34 | self.fs_token = fs_token 35 | self.cpu = cpu 36 | 37 | def to_ddf(self, columns=None, cpu=None): 38 | raise NotImplementedError(""" Return a dask.dataframe.DataFrame or dask_cudf.DataFrame""") 39 | 40 | def to_cpu(self): 41 | raise NotImplementedError(""" Move data to CPU memory """) 42 | 43 | def to_gpu(self): 44 | raise NotImplementedError(""" Move data to GPU memory """) 45 | 46 | @property 47 | def _path_partition_map(self): 48 | return None 49 | 50 | @property 51 | def num_rows(self): 52 | raise NotImplementedError(""" Returns the number of rows in the dataset """) 53 | 54 | def sample_data(self, n=1): 55 | """Return a sample of real data from the dataset 56 | 57 | Sample the partitions of the underlying Dask collection 58 | until a non-empty partition is found. Then, use the first 59 | ``n`` rows of that partition to infer dtype info. If no 60 | non-empty partitions are found, use the Dask metadata. 61 | """ 62 | _ddf = self.to_ddf() 63 | for partition_index in range(_ddf.npartitions): 64 | _head = _ddf.partitions[partition_index].head(n) 65 | if len(_head): 66 | return _head 67 | return _ddf._meta 68 | -------------------------------------------------------------------------------- /merlin/io/hugectr.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 
class HugeCTRWriter(ThreadedWriter):
    """ThreadedWriter that emits data in HugeCTR's binary ("Norm") format.

    Each output file begins with a 64-byte header (eight int64 values) that is
    back-filled by ``_close_writers`` once the per-file sample counts are
    known.  Every sample record is: float32 label values, float32 continuous
    values, then one (int32 nnz, uint32 value) pair per categorical slot.
    """

    def __init__(self, out_dir, suffix=".data", **kwargs):
        super().__init__(out_dir, **kwargs)
        self.suffix = suffix
        if self.use_guid:
            # NOTE(review): the literal "i" between the file index and the guid
            # (producing names like "0i.<hex>.data") looks accidental — confirm
            # before changing, in case existing readers rely on current naming.
            self.data_paths = [
                os.path.join(out_dir, f"{i}i.{uuid4().hex}{self.suffix}")
                for i in range(self.num_out_files)
            ]
        else:
            self.data_paths = [
                os.path.join(out_dir, f"{i}{self.suffix}") for i in range(self.num_out_files)
            ]
        self.data_writers = [open(f, "wb") for f in self.data_paths]
        # Reserve 64 bytes for the header; real values are written on close.
        header = np.array([0, 0, 0, 0, 0, 0, 0, 0], dtype=np.longlong)
        for writer in self.data_writers:  # index was unused; iterate directly
            writer.write(header.tobytes())

    def _write_table(self, idx, data):
        """Serialize one chunk of ``data`` to output file ``idx``."""
        # Convert each column group to the dtypes HugeCTR expects:
        # float32 for labels/conts, uint32 for categoricals.
        np_label = data[self.labels].to_pandas().astype(np.single).to_numpy()
        np_conts = data[self.conts].to_pandas().astype(np.single).to_numpy()
        nnz = np.intc(1)  # fixed: exactly one value per categorical slot
        np_cats = data[self.cats].to_pandas().astype(np.uintc).to_numpy()
        # Write all the data samples
        writer = self.data_writers[idx]  # hoist lookup out of the per-row loop
        for i, label in enumerate(np_label):
            # Write Label
            writer.write(label.tobytes())
            # Write conts (HugeCTR: dense)
            writer.write(np_conts[i].tobytes())
            # Write cats (HugeCTR: Slots), each value preceded by its nnz count
            for cat_value in np_cats[i]:
                writer.write(nnz.tobytes())
                writer.write(cat_value.tobytes())

    def _write_thread(self):
        """Consume ``(idx, data)`` items from the queue until the EOD sentinel."""
        while True:
            item = self.queue.get()
            try:
                if item is self._eod:
                    break
                idx, data = item
                # Serialize concurrent writes to the same output file
                with self.write_locks[idx]:
                    self._write_table(idx, data)
            finally:
                self.queue.task_done()

    def _close_writers(self):
        """Back-fill the per-file headers (when applicable) and close all files."""
        for i, writer in enumerate(self.data_writers):
            if self.cats:
                # Write HugeCTR Metadata into the reserved 64-byte header:
                #   error_check (0: no error check; 1: check_num)
                #   num of samples in this file
                #   Dimension of the labels
                #   Dimension of the features
                #   slot_num for each embedding
                #   three fields reserved for future use
                writer.seek(0)
                header = np.array(
                    [
                        0,
                        self.num_samples[i],
                        len(self.labels),
                        len(self.conts),
                        len(self.cats),
                        0,
                        0,
                        0,
                    ],
                    dtype=np.longlong,
                )
                writer.write(header.tobytes())
            writer.close()
        return None

    def _bytesio_to_disk(self):
        raise ValueError("hugectr binary format doesn't support PER_WORKER shuffle yet")
# Whether pandas ``DataFrame.sample`` accepts ``ignore_index`` (added in 1.3.0)
_IGNORE_INDEX_SUPPORTED = Version(pd.__version__) >= Version("1.3.0")

# Always bind the cudf capability flag so ``shuffle_df`` can reference it even
# when cudf is unavailable; previously it was assigned only inside the
# ``if cudf:`` branch, leaving a potential NameError on CPU-only installs.
_CUDF_IGNORE_INDEX_SUPPORTED = None

if cudf:
    try:
        # ``ignore_index`` support was added to ``cudf.DataFrame.sample`` in 22.04
        _CUDF_IGNORE_INDEX_SUPPORTED = Version(cudf.__version__) >= Version("22.04.0")
    except ImportError:
        # NOTE(review): ImportError seems unlikely here since ``cudf`` is
        # already imported; kept for safety with partial/exotic cudf builds.
        _CUDF_IGNORE_INDEX_SUPPORTED = None


class Shuffle(enum.Enum):
    """Granularity of shuffling applied when writing a dataset."""

    PER_PARTITION = 0
    PER_WORKER = 1
    FULL = 2


#
# Helper Function definitions
#


def _check_shuffle_arg(shuffle):
    """Normalize and validate a user-provided ``shuffle`` argument.

    Accepts ``None``, a ``Shuffle`` member, or a (deprecated) bool, and
    returns either ``None`` or a ``Shuffle`` member.

    Raises
    ------
    ValueError
        If ``shuffle`` is ``Shuffle.FULL`` (not yet supported) or is not a
        recognized value.
    """
    if shuffle is None:
        return shuffle

    if isinstance(shuffle, Shuffle):
        if shuffle == Shuffle.FULL:
            raise ValueError('`shuffle="full"` is not yet supported.')
    elif shuffle is True:
        shuffle = Shuffle.PER_WORKER
        warnings.warn("`shuffle=True` is deprecated. Using `PER_WORKER`.", DeprecationWarning)
    elif shuffle is False:
        shuffle = None
    else:
        raise ValueError(f"`shuffle={shuffle}` not recognized.")
    return shuffle


def shuffle_df(df, size=None, keep_index=False):
    """Shuffles a DataFrame, returning a new dataframe with randomly
    ordered rows

    Parameters
    ----------
    df : pandas.DataFrame or cudf.DataFrame
        The frame to shuffle
    size : int, optional
        Number of rows to sample (defaults to all rows)
    keep_index : bool
        Whether to preserve the original index values
    """
    size = size or len(df)
    if isinstance(df, pd.DataFrame):
        if _IGNORE_INDEX_SUPPORTED:
            return df.sample(n=size, ignore_index=not keep_index)
        else:
            # Pandas<1.3.0: emulate ``ignore_index`` with reset_index
            if keep_index:
                return df.sample(n=size)
            return df.sample(n=size).reset_index(drop=True)
    else:
        # Non-pandas frames are assumed to be cudf DataFrames here
        if _CUDF_IGNORE_INDEX_SUPPORTED:
            return df.sample(n=size, ignore_index=not keep_index)
        else:
            return df.sample(n=size, keep_index=keep_index)
15 | # 16 | from fsspec.core import get_fs_token_paths 17 | 18 | from merlin.io.hugectr import HugeCTRWriter 19 | from merlin.io.parquet import CPUParquetWriter, GPUParquetWriter 20 | 21 | 22 | def writer_factory( 23 | output_format, 24 | output_path, 25 | out_files_per_proc, 26 | shuffle, 27 | use_guid=False, 28 | bytes_io=False, 29 | num_threads=0, 30 | cpu=False, 31 | fns=None, 32 | suffix=None, 33 | fs=None, 34 | **kwargs, # Format-specific arguments 35 | ): 36 | if output_format is None: 37 | return None 38 | 39 | writer_cls, fs = _writer_cls_factory(output_format, output_path, cpu=cpu, fs=fs) 40 | return writer_cls( 41 | output_path, 42 | num_out_files=out_files_per_proc, 43 | shuffle=shuffle, 44 | fs=fs, 45 | use_guid=use_guid, 46 | bytes_io=bytes_io, 47 | num_threads=num_threads, 48 | cpu=cpu, 49 | fns=fns, 50 | suffix=suffix, 51 | **kwargs, # Format-specific arguments 52 | ) 53 | 54 | 55 | def _writer_cls_factory(output_format, output_path, cpu=None, fs=None): 56 | if output_format == "parquet" and cpu: 57 | writer_cls = CPUParquetWriter 58 | elif output_format == "parquet": 59 | writer_cls = GPUParquetWriter 60 | elif output_format == "hugectr": 61 | writer_cls = HugeCTRWriter 62 | else: 63 | raise ValueError("Output format not yet supported.") 64 | 65 | if fs is None: 66 | fs = get_fs_token_paths(output_path)[0] 67 | return writer_cls, fs 68 | -------------------------------------------------------------------------------- /merlin/schema/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # flake8: noqa 18 | from merlin.schema.schema import ColumnSchema, Schema 19 | from merlin.schema.tags import Tags, TagSet, TagsType 20 | -------------------------------------------------------------------------------- /merlin/schema/io/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /merlin/schema/io/proto_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def has_field(message: ProtoMessageType, field_name: str) -> bool:
    """Check if a Protobuf message has a particular field set on the wire

    Parameters
    ----------
    message : ProtoMessageType
        Protobuf message object
    field_name : str
        Name of the field to look for

    Returns
    -------
    bool
        `True` if the named field exists on the message object

    """
    return betterproto.serialized_on_wire(getattr(message, field_name))


def copy_better_proto_message(better_proto_message: ProtoMessageType, **kwargs) -> ProtoMessageType:
    """Create a copy of a Protobuf message

    Parameters
    ----------
    better_proto_message : ProtoMessageType
        The message to copy
    **kwargs
        Field values to set (via ``setattr``) on the copied message

    Returns
    -------
    ProtoMessageType
        Copy of better_proto_message with any ``kwargs`` overrides applied

    """
    # Round-trip through the wire format to get a deep, independent copy
    output = better_proto_message.__class__().parse(bytes(better_proto_message))
    for key, val in kwargs.items():
        setattr(output, key, val)

    return output
def better_proto_to_proto_text(
    better_proto_message: BetterProtoMessage, message: ProtoMessage
) -> str:
    """Convert a BetterProto message object to Protobuf text format

    Parameters
    ----------
    better_proto_message : BetterProtoMessage
        The message to convert
    message : ProtoMessage
        A blank (raw) Protobuf message object to parse into.
        NOTE: this object is mutated in place by ``ParseFromString``.

    Returns
    -------
    str
        Protobuf text representation of better_proto_message

    """
    # Both message types share the same wire format, so a byte-level
    # round-trip is sufficient to convert between them
    message.ParseFromString(bytes(better_proto_message))

    return text_format.MessageToString(message)


def proto_text_to_better_proto(
    better_proto_message: ProtoMessageType, proto_text: str, message: ProtoMessage
) -> ProtoMessageType:
    """Convert a Protobuf text format message into a BetterProto message object

    Parameters
    ----------
    better_proto_message : ProtoMessageType
        A BetterProto message object of the desired type (used only to pick
        the output class; its contents are not read)
    proto_text : str
        The Protobuf text format message to convert
    message : ProtoMessage
        A blank (raw) Protobuf message object to parse proto_text into.
        NOTE: this object is mutated in place by ``text_format.Parse``.

    Returns
    -------
    ProtoMessageType
        A BetterProto message object containing attributes parsed from proto_text

    """
    # text -> raw proto -> JSON -> betterproto, since betterproto has no
    # direct text-format parser
    proto = text_format.Parse(proto_text, message)
    json_str = json_format.MessageToJson(proto)

    return better_proto_message.__class__().from_json(json_str)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # flake8: noqa 18 | from merlin.table.conversions import df_from_tensor_table, tensor_table_from_df 19 | from merlin.table.cupy_column import CupyColumn 20 | from merlin.table.numpy_column import NumpyColumn 21 | from merlin.table.tensor_column import Device, TensorColumn 22 | from merlin.table.tensor_table import TensorTable 23 | from merlin.table.tensorflow_column import TensorflowColumn 24 | from merlin.table.torch_column import TorchColumn 25 | -------------------------------------------------------------------------------- /merlin/table/cupy_column.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class CupyColumn(TensorColumn):
    """
    A SeriesLike column backed by CuPy arrays
    """

    @classmethod
    def array_type(cls) -> Optional[Type]:
        """
        The type of the arrays backing this column

        Returns ``None`` when CuPy is not installed.
        """
        return cp.ndarray if cp else None

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to construct new backing arrays for this column.

        NOTE(review): unlike ``array_type``, this does not guard against
        ``cp`` being None — assumes CuPy is installed when called.
        """
        return cp.asarray

    @classmethod
    def supported_devices(cls):
        """
        List of device types supported by this column type
        """
        return [Device.GPU]

    def __init__(
        self,
        values: "cp.ndarray",
        offsets: "cp.ndarray" = None,
        dtype=None,
        _ref=None,
        _unsafe=False,
    ):
        # CuPy arrays always live on the GPU, so the device is fixed here
        super().__init__(values, offsets, dtype, _ref=_ref, _device=Device.GPU, _unsafe=_unsafe)

    def cpu(self):
        """
        Move this column's data to host (i.e. CPU) memory

        Returns
        -------
        NumpyColumn
            A copy of this column backed by NumPy arrays
        """
        # Imported here to avoid a circular import with merlin.table
        from merlin.table import NumpyColumn

        values = cp.asnumpy(self.values)
        offsets = cp.asnumpy(self.offsets) if self.offsets is not None else None

        return NumpyColumn(values, offsets)

    def gpu(self):
        """
        Move this column's data to device (i.e. GPU) memory

        Returns
        -------
        CupyColumn
            This column, unchanged and backed by CuPy arrays
        """
        return self

    @property
    def _flatten_values(self):
        # 1-D copy of the values array
        return self.values.flatten()

    def _reshape_values(self, values, shape):
        return cp.reshape(values, shape)


@_to_dlpack.register_lazy("cupy")
def _register_to_dlpack_from_cupy():
    import cupy as cp

    @_to_dlpack.register(cp.ndarray)
    def _to_dlpack_from_cp_tensor(tensor):
        # Cast bool to int8 first — presumably because some DLPack consumers
        # don't accept bool arrays (TODO confirm)
        if tensor.dtype == cp.dtype("bool"):
            tensor = tensor.astype(cp.dtype("int8"))
        return tensor


@_from_dlpack_gpu.register_lazy("cupy")
def _register_from_dlpack_gpu_to_cupy():
    import cupy as cp

    @_from_dlpack_gpu.register(cp.ndarray)
    def _from_dlpack_gpu_to_cupy(to, array) -> cp.ndarray:
        return cp.from_dlpack(array)
class NumpyColumn(TensorColumn):
    """
    A SeriesLike column backed by NumPy arrays
    """

    @classmethod
    def array_type(cls) -> Type:
        """
        The type of the arrays backing this column
        """
        return np.ndarray

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to construct new backing arrays for this column."""
        return np.array

    @classmethod
    def supported_devices(cls):
        """
        List of device types supported by this column type
        """
        return [Device.CPU]

    def __init__(
        self,
        values: "np.ndarray",
        offsets: "np.ndarray" = None,
        dtype=None,
        _ref=None,
        _unsafe=False,
    ):
        # NumPy arrays always live in host memory, so the device is fixed here
        super().__init__(values, offsets, dtype, _ref=_ref, _device=Device.CPU, _unsafe=_unsafe)

    def cpu(self):
        """
        Move this column's data to host (i.e. CPU) memory

        Returns
        -------
        NumpyColumn
            This column, unchanged and backed by NumPy arrays
        """
        return self

    def gpu(self):
        """
        Move this column's data to device (i.e. GPU) memory

        Returns
        -------
        CupyColumn
            A copy of this column backed by CuPy arrays

        Notes
        -----
        Assumes CuPy is importable (``cp`` is not None); calling this on a
        CPU-only install will fail.
        """
        # Imported here to avoid a circular import with merlin.table
        from merlin.table import CupyColumn

        values = cp.asarray(self.values)
        offsets = cp.asarray(self.offsets) if self.offsets is not None else None

        return CupyColumn(values, offsets)

    @property
    def _flatten_values(self):
        # 1-D copy of the values array
        return self.values.flatten()

    def _reshape_values(self, values, shape):
        return np.reshape(values, shape)


@_from_dlpack_cpu.register_lazy("numpy")
def _register_from_dlpack_cpu_to_numpy():
    import numpy as np

    @_from_dlpack_cpu.register(np.ndarray)
    def _from_dlpack_cpu_to_numpy(to, array):
        try:
            # Prefer the public `from_dlpack` API (added in NumPy 1.23.0)
            return np.from_dlpack(array)
        except AttributeError:
            pass
        try:
            # Fall back to the private `_from_dlpack` (added in 1.22.0) for
            # NumPy 1.22.x, where the public name doesn't exist yet
            return np._from_dlpack(array)
        except AttributeError as exc:
            raise NotImplementedError(
                "NumPy does not implement the DLPack Standard until version 1.22.0, "
                f"currently running {np.__version__}"
            ) from exc


@_to_dlpack.register_lazy("numpy")
def _register_from_numpy_to_dlpack_cpu():
    import numpy as np

    @_to_dlpack.register(np.ndarray)
    def _to_dlpack_cpu_from_numpy(array):
        # Cast bool to int8 first — presumably because some DLPack consumers
        # don't accept bool arrays (TODO confirm)
        if array.dtype == np.dtype("bool"):
            array = array.astype(np.dtype("int8"))
        return array
class TorchColumn(TensorColumn):
    """
    A SeriesLike column backed by Torch tensors
    """

    framework_name = "torch"

    @classmethod
    def array_type(cls) -> Optional[Type]:
        """The tensor type backing this column (None when torch is absent)."""
        return th.Tensor if th else None

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to build new backing tensors."""
        return th.tensor

    @classmethod
    def supported_devices(cls):
        """Device types this column implementation can live on."""
        return [Device.CPU, Device.GPU]

    def __init__(
        self, values: "th.Tensor", offsets: "th.Tensor" = None, dtype=None, _ref=None, _unsafe=False
    ):
        """Build a column from torch tensors, requiring values and offsets
        to reside on the same device."""
        values_device = self._th_device(values)
        if offsets is not None:
            offsets_device = self._th_device(offsets)
            if values_device != offsets_device:
                raise ValueError(
                    f"Values and offsets were detected on different devices: "
                    f"values ({values_device}) and offsets ({offsets_device})."
                )
        super().__init__(values, offsets, dtype, _device=values_device, _ref=_ref, _unsafe=_unsafe)

    def cpu(self):
        """Return a host-memory copy of this column (self if already on CPU).

        Returns
        -------
        TorchColumn
            A column backed by Torch CPU tensors
        """
        if self.device is Device.CPU:
            return self
        host_offsets = None if self.offsets is None else self.offsets.cpu()
        return TorchColumn(self.values.cpu(), host_offsets)

    def gpu(self):
        """Return a device-memory copy of this column (self if already on GPU).

        Returns
        -------
        TorchColumn
            A column backed by Torch GPU tensors
        """
        if self.device is Device.GPU:
            return self
        dev_offsets = None if self.offsets is None else self.offsets.cuda()
        return TorchColumn(self.values.cuda(), dev_offsets)

    @property
    def device(self) -> Device:
        """Current device of the backing values tensor."""
        return self._th_device(self.values)

    @property
    def _flatten_values(self):
        # 1-D view/copy of the values tensor
        return th.flatten(self.values)

    def _reshape_values(self, values, shape):
        return th.reshape(values, shape)

    def _th_device(self, tensor):
        # Map a torch tensor's location onto the Merlin Device enum
        return Device.GPU if tensor.is_cuda else Device.CPU


@_to_dlpack.register_lazy("torch")
def _register_to_dlpack_from_torch():
    import torch as th

    @_to_dlpack.register(th.Tensor)
    def _to_dlpack_from_torch_tensor(tensor):
        # Torch tensors can be handed to DLPack converters as-is
        return tensor


@_from_dlpack_cpu.register_lazy("torch")
def _register_from_dlpack_cpu_to_torch():
    import torch as th

    @_from_dlpack_cpu.register(th.Tensor)
    def _from_dlpack_cpu_to_torch(target_type, array):
        return th.utils.dlpack.from_dlpack(array)


@_from_dlpack_gpu.register_lazy("torch")
def _register_from_dlpack_gpu_to_torch():
    import torch as th

    @_from_dlpack_gpu.register(th.Tensor)
    def _from_dlpack_gpu_to_torch(target_type, array):
        return th.utils.dlpack.from_dlpack(array.__dlpack__())
-------------------------------------------------------------------------------- /merlin/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | # flake8: noqa 18 | from merlin.testing.assert_equals import assert_transformable_equal 19 | -------------------------------------------------------------------------------- /merlin/testing/assert_equals.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def assert_table_equal(left: TensorTable, right: TensorTable):
    """Assert two TensorTables hold equal data.

    Both tables are moved to host memory and compared as pandas DataFrames.
    NOTE(review): assumes pandas is importable (``pd`` from compat is not
    None) — verify for CPU-only minimal installs.
    """
    pd.testing.assert_frame_equal(left.cpu().to_df(), right.cpu().to_df())


# Dispatch on the type of ``left``; concrete implementations are registered
# below (TensorTable eagerly, cudf/pandas lazily so the libraries are only
# imported when actually present).
@lazy_singledispatch
def assert_transformable_equal(left, right):
    raise NotImplementedError


@assert_transformable_equal.register(TensorTable)
def _assert_equal_table(left, right):
    assert_table_equal(left, right)


@assert_transformable_equal.register_lazy("cudf")
def _register_assert_equal_df_cudf():
    import cudf

    @assert_transformable_equal.register(cudf.DataFrame)
    def _assert_equal_df_cudf(left, right):
        cudf.testing.assert_frame_equal(left, right)


@assert_transformable_equal.register_lazy("pandas")
def _register_assert_equal_pandas():
    import pandas

    @assert_transformable_equal.register(pandas.DataFrame)
    def _assert_equal_pandas(left, right):
        pandas.testing.assert_frame_equal(left, right)
ensure_newline_before_comments = true 17 | line_length = 100 18 | balanced_wrapping = true 19 | indent = " " 20 | known_third_party = ["cudf", "cupy", "dask", "dask_cuda", "dask_cudf", "numba", "numpy", "pytest", "torch", "rmm", "tensorflow"] 21 | skip = ["build", ".eggs"] 22 | 23 | [tool.interrogate] 24 | ignore-init-method = true 25 | ignore-init-module = true 26 | ignore-magic = true 27 | ignore-module = true 28 | ignore-private = true 29 | ignore-property-decorators = true 30 | ignore-nested-classes = true 31 | ignore-nested-functions = true 32 | ignore-semiprivate = true 33 | ignore-setters = true 34 | fail-under = 70 35 | exclude = ["build", "docs", "merlin/core", "merlin/io", "tests", "setup.py", "versioneer.py"] 36 | verbose = 1 37 | omit-covered-files = true 38 | quiet = false 39 | whitelist-regex = [] 40 | ignore-regex = [] 41 | color = true 42 | 43 | [tool.pytest.ini_options] 44 | filterwarnings = [ 45 | 'ignore:`np.*` is a deprecated alias:DeprecationWarning', 46 | 'ignore:The default dtype for empty Series:DeprecationWarning', 47 | 'ignore:General-metadata information not detected:UserWarning', 48 | 'ignore:Changing an NVTabular Dataset to CPU mode:UserWarning', 49 | 'ignore:Initializing an NVTabular Dataset in CPU mode:UserWarning', 50 | 'ignore:Performing a hash-based transformation:UserWarning', 51 | 'ignore:WARNING..cuDF.to_dlpack', 52 | 'ignore:::numba.cuda.envvar:', 53 | 'ignore:Call to deprecated create function:DeprecationWarning', 54 | 'ignore:Only created .* files did not have enough partitions to create .* file:UserWarning', 55 | 'ignore:distutils Version classes are deprecated. 
Use packaging.version instead:DeprecationWarning', 56 | ] 57 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # packages necessary to run tests and push PRs 2 | # assumes requirements for merlin-core are already installed 3 | 4 | pytest>=5 5 | pytest-cov>=2 6 | pytest-xdist 7 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | -r requirements-dev.txt 3 | 4 | sphinxcontrib-applehelp==1.0.4 5 | sphinxcontrib-devhelp==1.0.2 6 | sphinxcontrib-htmlhelp==2.0.1 7 | sphinxcontrib-qthelp==1.0.3 8 | sphinxcontrib-serializinghtml==1.1.5 9 | sphinx-multiversion@git+https://github.com/mikemckiernan/sphinx-multiversion.git 10 | sphinxcontrib-copydirs@git+https://github.com/mikemckiernan/sphinxcontrib-copydirs.git 11 | recommonmark~=0.7.1 12 | Jinja2<3.1 13 | natsort~=8.4.0 14 | myst-nb~=1.1.0 15 | linkify-it-py~=2.0.3 16 | sphinx-external-toc~=1.0.1 17 | attrs~=23.2.0 18 | sphinx-book-theme~=1.1.2 19 | sphinx_design~=0.5.0 20 | -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | 3 | # cudf>=21.12 4 | # dask-cudf>=21.12 5 | # dask-cuda>=21.12 6 | -------------------------------------------------------------------------------- /requirements-test-cpu.txt: -------------------------------------------------------------------------------- 1 | -r requirements-test.txt 2 | -r requirements.txt 3 | -------------------------------------------------------------------------------- /requirements-test-gpu.txt: -------------------------------------------------------------------------------- 1 | -r requirements-test.txt 2 | -r 
requirements-gpu.txt 3 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | -r requirements-dev.txt 3 | 4 | # This contains common libraries for testing. 5 | 6 | # NOTE: You should pip install requirements-test-[cpu|gpu].txt for device-specific test 7 | # requirements, which will include the dependencies defined in this file. 8 | 9 | pytest>=5 10 | pytest-cov>=2 11 | testbook==0.4.2 12 | 13 | # needed to make test_s3 work 14 | # moto>=2 15 | # boto3==1.17 16 | # s3fs>=2021.4 17 | # aiobotocore>=1.3.3 18 | # flask 19 | # flask-cors 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dask>=2022.11.1 2 | dask-cuda>=22.12.0 3 | distributed>=2022.11.1 4 | fsspec>=2022.7.1 5 | numpy>=1.22.0 6 | pandas>=1.2.0,<1.6.0dev0 7 | numba>=0.54 8 | pyarrow>=5.0.0 9 | protobuf>=3.0.0 10 | tqdm>=4.0 11 | tensorflow-metadata>=1.2.0 12 | betterproto<2.0.0 13 | packaging 14 | npy-append-array 15 | 16 | # pynvml==11.5.0 is incompatible with distributed<2023.2.1 17 | pynvml>=11.0.0,<11.5 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | exclude = build,.eggs,*_pb2.py 4 | ignore = E203,W503 5 | per-file-ignores = 6 | examples/criteo_benchmark.py:E402 7 | examples/dataloader_bench.py:E402 8 | 9 | [flake8_nb] 10 | max-line-length = 120 11 | ignore = E203,E402,W503 12 | 13 | [pydocstyle] 14 | ignore = D100,D102,D103,D104,D105,D107,D203,D205,D211,D212,D213,D400,D401,D413,D415 15 | 16 | [codespell] 17 | skip = .*pb2.py,./.git,./.github,./bench,./dist,./docs/build,.*egg-info.*,versioneer.py,*.csv,*.parquet 18 | ignore-words = 
./ci/ignore_codespell_words.txt 19 | count = 20 | quiet-level = 3 21 | 22 | # See the docstring in versioneer.py for instructions. Note that you must 23 | # re-run 'versioneer.py setup' after changing this section, and commit the 24 | # resulting files. 25 | 26 | [versioneer] 27 | VCS = git 28 | style = pep440 29 | versionfile_source = merlin/core/_version.py 30 | versionfile_build = merlin/core/_version.py 31 | tag_prefix = v 32 | parentdir_prefix = merlin-core- 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def parse_requirements(filename):
    """Load requirements from a pip requirements file.

    Parameters
    ----------
    filename : str
        Path to the requirements file to read.

    Returns
    -------
    list of str
        The stripped, non-empty lines that are not comments
        (i.e. do not start with ``#``).
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original left the file open until garbage collection.
    with open(filename, encoding="utf-8") as req_file:
        stripped = (line.strip() for line in req_file)
        return [line for line in stripped if line and not line.startswith("#")]
-------------------------------------------------------------------------------- /tests/unit/core/test_dispatch.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import numpy as np 17 | import pandas as pd 18 | import pytest 19 | 20 | from merlin.core.compat import HAS_GPU 21 | from merlin.core.compat import cupy as cp 22 | from merlin.core.dispatch import ( 23 | concat_columns, 24 | is_list_dtype, 25 | list_val_dtype, 26 | make_df, 27 | nullable_series, 28 | ) 29 | 30 | if HAS_GPU: 31 | _DEVICES = ["cpu", "gpu"] 32 | else: 33 | _DEVICES = ["cpu"] 34 | 35 | 36 | @pytest.mark.parametrize("device", _DEVICES) 37 | def test_list_dtypes(tmpdir, device): 38 | df = make_df(device=device) 39 | df["vals"] = [ 40 | [[0, 1, 2], [3, 4], [5]], 41 | ] 42 | # Check that the index can be arbitrary 43 | df.set_index(np.array([2]), drop=True, inplace=True) 44 | 45 | assert is_list_dtype(df["vals"]) 46 | assert list_val_dtype(df["vals"]) == np.dtype(np.int64) 47 | 48 | 49 | @pytest.mark.parametrize("device", _DEVICES) 50 | def test_concat_columns(device): 51 | df1 = make_df({"a": [1, 2], "b": [[3], [4, 5]]}, device=device) 52 | df2 = make_df({"c": [3, 4, 5]}, device=device) 53 | data_frames = [df1, df2] 54 | res = concat_columns(data_frames) 55 | assert res.columns.to_list() == ["a", "b", "c"] 56 | 57 | 
# One case per integer/float width: a single null value should come back
# as the matching pandas nullable extension dtype holding pd.NA.
_NULLABLE_CASES = [
    ([None], np.dtype(np_name), pd.Series([pd.NA], dtype=pd_name))
    for np_name, pd_name in [
        ("int8", "Int8"),
        ("int16", "Int16"),
        ("int32", "Int32"),
        ("int64", "Int64"),
        ("uint8", "UInt8"),
        ("uint16", "UInt16"),
        ("uint32", "UInt32"),
        ("uint64", "UInt64"),
        ("float32", "Float32"),
        ("float64", "Float64"),
    ]
]


@pytest.mark.parametrize(["data", "dtype", "expected_series"], _NULLABLE_CASES)
def test_nullable_series(data, dtype, expected_series):
    """nullable_series should map a numpy dtype to the pandas nullable equivalent."""
    result = nullable_series(data, pd.DataFrame(), dtype)
    pd.testing.assert_series_equal(result, expected_series)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import pytest 17 | 18 | from merlin.core.compat import HAS_GPU, cudf 19 | from merlin.core.dispatch import make_df, make_series 20 | from merlin.core.protocols import DataFrameLike, DictLike, SeriesLike, Transformable 21 | 22 | if HAS_GPU and cudf: 23 | _DEVICES = ["cpu", None] 24 | else: 25 | _DEVICES = ["cpu"] 26 | 27 | 28 | @pytest.mark.parametrize("protocol", [DictLike]) 29 | def test_dictionary_is_dictlike(protocol): 30 | obj = {} 31 | 32 | assert isinstance(obj, protocol) 33 | 34 | 35 | @pytest.mark.parametrize("device", _DEVICES) 36 | @pytest.mark.parametrize("protocol", [DictLike, DataFrameLike, Transformable]) 37 | def test_dataframes_match_protocols(protocol, device): 38 | obj = make_df({}, device=device) 39 | 40 | assert isinstance(obj, protocol) 41 | 42 | 43 | @pytest.mark.parametrize("device", _DEVICES) 44 | def test_series_are_serieslike(device): 45 | obj = make_series([0], device=device) 46 | 47 | assert isinstance(obj, SeriesLike) 48 | -------------------------------------------------------------------------------- /tests/unit/core/test_version.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
@pytest.mark.skip(reason="requires full clone of repo to ensure versions are detected.")
def test_version():
    """The installed package version should be at least 0.6.0."""
    minimum = Version("0.6.0")
    assert Version(merlin.core.__version__) >= minimum
properties=properties) 32 | ) 33 | 34 | # Turn to Schema 35 | input_schema = Schema(column_schemas) 36 | 37 | # run schema through op 38 | selector = ColumnSelector(selection) 39 | output_schema = op.compute_output_schema(input_schema, selector) 40 | 41 | # should have dtype float 42 | for input_col_name in selector.names: 43 | output_col_names = [name for name in output_schema.column_schemas if input_col_name in name] 44 | if output_col_names: 45 | for output_col_name in output_col_names: 46 | result_schema = output_schema.column_schemas[output_col_name] 47 | 48 | expected_dtype = op._compute_dtype( 49 | ColumnSchema(output_col_name), 50 | Schema([input_schema.column_schemas[input_col_name]]), 51 | ).dtype 52 | 53 | expected_tags = op._compute_tags( 54 | ColumnSchema(output_col_name), 55 | Schema([input_schema.column_schemas[input_col_name]]), 56 | ).tags 57 | 58 | expected_properties = op._compute_properties( 59 | ColumnSchema(output_col_name), 60 | Schema([input_schema.column_schemas[input_col_name]]), 61 | ).properties 62 | 63 | assert result_schema.dtype == expected_dtype 64 | if output_col_name in selector.names: 65 | assert result_schema.properties == expected_properties 66 | 67 | assert len(result_schema.tags) == len(expected_tags) 68 | else: 69 | assert set(expected_tags).issubset(result_schema.tags) 70 | 71 | not_used = [col for col in all_cols if col not in selector.names] 72 | for input_col_name in not_used: 73 | assert input_col_name not in output_schema.column_schemas 74 | -------------------------------------------------------------------------------- /tests/unit/dag/ops/test_rename.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import numpy as np 17 | import pandas as pd 18 | import pytest 19 | 20 | from merlin.core.compat import cudf 21 | from merlin.dag import ColumnSelector 22 | from merlin.dag.ops.rename import Rename 23 | from merlin.table import TensorTable 24 | from merlin.testing import assert_transformable_equal 25 | 26 | transformables = [pd.DataFrame, TensorTable] 27 | if cudf: 28 | transformables.append(cudf.DataFrame) 29 | 30 | 31 | @pytest.mark.parametrize("transformable", transformables) 32 | def test_rename(transformable): 33 | df = transformable({"x": np.array([1, 2, 3, 4, 5]), "y": np.array([6, 7, 8, 9, 10])}) 34 | 35 | selector = ColumnSelector(["x", "y"]) 36 | 37 | op = Rename(f=lambda name: name.upper()) 38 | transformed = op.transform(selector, df) 39 | expected = transformable({"X": np.array([1, 2, 3, 4, 5]), "Y": np.array([6, 7, 8, 9, 10])}) 40 | assert_transformable_equal(transformed, expected) 41 | 42 | op = Rename(postfix="_lower") 43 | transformed = op.transform(selector, df) 44 | expected = transformable( 45 | { 46 | "x_lower": np.array([1, 2, 3, 4, 5]), 47 | "y_lower": np.array([6, 7, 8, 9, 10]), 48 | } 49 | ) 50 | assert_transformable_equal(transformed, expected) 51 | 52 | selector = ColumnSelector(["x"]) 53 | 54 | op = Rename(name="z") 55 | transformed = op.transform(selector, df) 56 | expected = transformable({"z": np.array([1, 2, 3, 4, 5])}) 57 | assert_transformable_equal(transformed, expected) 58 | 59 | op = Rename(f=lambda name: name.upper()) 60 | transformed = op.transform(selector, df) 61 | expected 
@pytest.mark.parametrize("engine", ["parquet"])
def test_selection_output_column_names(df):
    """SelectionOp.output_column_names should echo back the selector's columns."""
    selected = ColumnSelector(["x", "y"])
    result = SelectionOp(selected).output_column_names(selected)
    assert result.names == ["x", "y"]
= op.compute_output_schema(schema, ColumnSelector()) 50 | 51 | assert result_schema.column_names == ["x", "y"] 52 | 53 | 54 | @pytest.mark.parametrize("engine", ["parquet"]) 55 | def test_selection_wildcard_output_schema(df): 56 | selector = ColumnSelector("*") 57 | schema = Schema([ColumnSchema(col) for col in df.columns]) 58 | op = SelectionOp(selector) 59 | 60 | result_schema = op.compute_output_schema(schema, ColumnSelector()) 61 | 62 | assert result_schema.column_names == schema.column_names 63 | -------------------------------------------------------------------------------- /tests/unit/dag/ops/test_stat_op.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2023, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class FitOp(StatOperator):
    """Stat operator that records whether ``fit`` has run before.

    ``stats["fit_exactly_once"]`` is True only on the first fit after a
    ``clear()``; a second fit flips it to False, which lets tests detect
    unwanted refitting.
    """

    def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame):
        # True only when no prior fit has populated the stats dict.
        first_fit = "fit_exactly_once" not in self.stats
        self.stats: Dict[str, bool] = {"fit_exactly_once": first_fit}

    def fit_finalize(self, dask_stats):
        # Nothing to aggregate; hand the computed stats straight through.
        return dask_stats

    def clear(self):
        # Reset so the next fit counts as the first one again.
        self.stats = {}
use them to transform data" in str(exc.value) 76 | 77 | executor.fit(Dataset(df), graph) 78 | executor.transform(Dataset(df).to_ddf(), graph) 79 | assert op.stats == {"fit_exactly_once": True} 80 | -------------------------------------------------------------------------------- /tests/unit/dag/ops/test_subgraph.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph(df):
    """A Subgraph node should transform like its wrapped graph and be retrievable by name."""
    inner = ["x"] >> Operator() >> Operator()
    wrapped = Subgraph("subgraph", inner)
    outer = ["x", "y"] >> Operator() >> wrapped >> Operator()

    graph = Graph(outer)
    graph.construct_schema(Schema(list(df.columns)))

    out_df = LocalExecutor().transform(df, graph)
    # Only the subgraph's selected column survives the pipeline.
    assert (out_df == df[["x"]]).all()[0]

    # The subgraph should be addressable on the main graph by its name.
    assert graph.subgraph("subgraph") == wrapped.graph
74 | def test_subgraph_looping(dataset): 75 | class LoopingTestOp(Operator): 76 | def transform( 77 | self, col_selector: ColumnSelector, transformable: Transformable 78 | ) -> Transformable: 79 | return transformable[col_selector.names] + 1.0 80 | 81 | subgraph = ["x"] >> LoopingTestOp() 82 | subgraph_op = Subgraph( 83 | "subgraph", 84 | subgraph, 85 | loop_until=lambda transformable: (transformable["x"] > 5.0).all(), 86 | ) 87 | main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator() 88 | 89 | main_graph = Graph(main_graph_ops) 90 | main_graph.construct_schema(dataset.schema) 91 | 92 | df = dataset.to_ddf().compute() 93 | df["x"] = df["x"] * 0.0 94 | dataset = Dataset(df) 95 | 96 | executor = DaskExecutor() 97 | executor.fit(dataset, main_graph) 98 | result_df = executor.transform(dataset.to_ddf(), main_graph) 99 | 100 | assert (result_df.compute()[["x"]] > 5.0).all()[0] 101 | -------------------------------------------------------------------------------- /tests/unit/dag/test_base_operator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2022, NVIDIA CORPORATION. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
@pytest.mark.parametrize("engine", ["parquet"])
def test_compute_input_schema_validates_schemas(dataset, engine):
    """compute_input_schema should raise for a selector naming a missing column,
    regardless of which of the three schema slots carries the known columns."""
    op = Operator()
    known = Schema(["a", "b"])
    empty = Schema()
    missing = ColumnSelector(["c"])

    # The populated schema may arrive as root, parents, or dependencies.
    for root, parents, deps in (
        (known, empty, empty),
        (empty, known, empty),
        (empty, empty, known),
    ):
        with pytest.raises(ValueError) as exc_info:
            op.compute_input_schema(root, parents, deps, missing)
        assert "Missing column" in str(exc_info.value)
def test_flat_dict_to_tuple_dict():
    """Round-trip between flat ``__values``/``__offsets`` keys and grouped tuples."""
    dense = np.array([1, 2, 3, 4, 5])
    ragged_values = np.array([6, 7, 8, 9, 10])
    ragged_offsets = np.array([0, 2, 5])

    flat = {
        "col1": dense,
        "col2__values": ragged_values,
        "col2__offsets": ragged_offsets,
    }
    grouped = {"col1": dense, "col2": (ragged_values, ragged_offsets)}

    # The two helpers should be exact inverses on this pair of layouts.
    assert ungroup_values_offsets(grouped) == flat
    assert group_values_offsets(flat) == grouped
def test_local_executor_with_dataframe():
    """LocalExecutor.transform should keep only the selected column.

    Runs against whatever frame ``make_df`` produces (pandas or cudf);
    GPU-backed series are converted to pandas before comparison.
    """
    df = make_df({"a": [1, 2, 3], "b": [4, 5, 6]})
    schema = Schema([ColumnSchema("a", dtype=np.int64), ColumnSchema("b", dtype=np.int64)])
    operator = ["a"] >> Operator()
    graph = Graph(operator)
    graph.construct_schema(schema)

    executor = LocalExecutor()
    result = executor.transform(df, [graph.output_node])

    result_a = result["a"].to_pandas() if not isinstance(result["a"], pd.Series) else result["a"]
    # BUG FIX: the original fell back to result["a"] in the else branch,
    # comparing the result against itself instead of against the input frame.
    df_a = df["a"].to_pandas() if not isinstance(df["a"], pd.Series) else df["a"]

    assert all(result_a == df_a)
    assert "b" not in result.columns
return_type_name = lazy_singledispatch("return_type_name")


@return_type_name.register
def return_int(arg: int):
    """Dispatch target for ints."""
    return "int"


@return_type_name.register
def return_float(arg: float):
    """Dispatch target for floats."""
    return "float"


@return_type_name.register_lazy("tensorflow")
def register_tf_to_array():
    """Register the tf.Tensor implementation only when tensorflow is importable."""
    import tensorflow as tf  # pylint:disable=reimported

    @return_type_name.register(tf.Tensor)
    def return_tensor(arg: tf.Tensor):
        return "tensor"


@pytest.mark.skipif(not tf, reason="requires tensorflow")
def test_lazy_dispatch():
    """lazy_singledispatch picks implementations by argument type,
    including lazily-registered framework types."""
    assert return_type_name(5) == "int"
    assert return_type_name(5.0) == "float"
    assert return_type_name(tf.constant([1, 2, 3, 4])) == "tensor"

    # A type with no registered implementation raises NotImplementedError
    with pytest.raises(NotImplementedError) as exc:
        return_type_name("abc")

    assert "doesn't have a registered implementation" in str(exc.value)
@pytest.mark.skipif(not cudf, reason="CUDF is required to test its dtypes")
def test_cudf_struct_dtype():
    """cudf StructDtype maps to md.struct, and md.struct maps back to cudf."""
    # cudf -> merlin
    source_dtype = cudf.StructDtype({"a": "int64", "b": "string"})
    assert md.dtype(source_dtype) == md.struct

    # merlin -> cudf
    assert md.struct.to("cudf") == cudf.StructDtype
@pytest.mark.parametrize("python_type, merlin_type", [(int, md.int64)])
def test_python_types_convert_correctly(python_type, merlin_type):
    """Built-in Python types resolve to the expected Merlin dtypes."""
    assert md.dtype(python_type) == merlin_type


@pytest.mark.parametrize("numpy_type, merlin_type", [(numpy.int64, md.int64)])
def test_numpy_types_convert_correctly(numpy_type, merlin_type):
    """NumPy scalar types resolve to the expected Merlin dtypes."""
    assert md.dtype(numpy_type) == merlin_type


@pytest.mark.parametrize(
    "dtype_name, dtype",
    [
        ("int32", md.int32),
        ("float64", md.float64),
        ("string", md.string),
        ("struct", md.struct),
    ],
)
def test_string_aliases_can_be_used(dtype_name, dtype):
    """String aliases such as "int32" resolve to the matching Merlin dtype."""
    assert md.dtype(dtype_name) == dtype


def test_type_mappings_can_be_registered():
    """A custom external type can be registered and resolved via md.dtype."""

    class ExternalType:  # stand-in for a framework-specific dtype class
        pass

    custom_dtype = md.DType("test", md.ElementType.Int, 4096, signed=True)

    # NOTE(review): this mutates the global dtype registry and is not undone
    # afterwards — presumably harmless for other tests, but worth confirming.
    md.register("test", {custom_dtype: ExternalType})
    assert md.dtype(ExternalType) == custom_dtype


def test_unknown_types_return_unknown():
    """Unregistered types fall back to md.unknown instead of raising."""

    class UnregisteredType:
        pass

    assert md.dtype(UnregisteredType) == md.unknown
if cudf:
    dask_cudf = pytest.importorskip("dask_cudf")
else:
    # Bug fix: the original `pytest.mark.skip(reason=...)` bare expression only
    # *creates* a mark object and never applies it, so this module was NOT
    # actually skipped when cudf is unavailable. `pytest.skip` with
    # `allow_module_level=True` performs the intended module-level skip.
    pytest.skip("cudf did not import successfully", allow_module_level=True)
# Require uavro and fastavro library.
# Note that fastavro is only required to write
# avro files for testing, while uavro is actually
# used by AvroDatasetEngine.
fa = pytest.importorskip("fastavro")
pytest.importorskip("uavro")


@pytest.mark.parametrize("part_size", [None, "1KB"])
@pytest.mark.parametrize("size", [100, 5000])
@pytest.mark.parametrize("nfiles", [1, 2])
def test_avro_basic(tmpdir, part_size, size, nfiles):
    """Write avro files with fastavro, read them back through
    merlin.io.Dataset, and verify partitioning, length, and contents."""
    # Define avro schema
    schema = fa.parse_schema(
        {
            "name": "avro.example.User",
            "type": "record",
            "fields": [
                {"name": "name", "type": "string"},
                {"name": "age", "type": "int"},
            ],
        }
    )

    # Write an avro dataset with `nfiles` files.
    # Collect block and record (row) counts while writing.
    nblocks = 0
    nrecords = 0
    paths = [os.path.join(str(tmpdir), f"test.{i}.avro") for i in range(nfiles)]
    records = []
    for path in paths:
        names = np.random.choice(name_list, size)
        ages = np.random.randint(18, 100, size)
        data = [{"name": names[i], "age": ages[i]} for i in range(size)]
        with open(path, "wb") as f:
            fa.writer(f, schema, data)
        # Re-read the raw blocks to learn the expected partition/row counts
        with open(path, "rb") as fo:
            avro_reader = fa.block_reader(fo)
            for block in avro_reader:
                nrecords += block.num_records
                nblocks += 1
                records += list(block)
    if nfiles == 1:
        paths = paths[0]

    # Read back through merlin.io.Dataset as a dask dataframe
    df = merlin.io.Dataset(paths, part_size=part_size, engine="avro").to_ddf()

    # Check basic length and partition count; with a tiny part_size each
    # avro block should map to one partition
    if part_size == "1KB":
        assert df.npartitions == nblocks
    assert len(df) == nrecords

    # Full comparison (the avro "int" type reads back as int32)
    expect = pd.DataFrame.from_records(records)
    expect["age"] = expect["age"].astype("int32")
    assert_eq(df.compute().reset_index(drop=True), expect)
def test_tagset_init_normalizes_tags_to_enum():
    """String tags matching known enum values are converted to Tags members."""
    tag_set = TagSet(["continuous", "list", "custom_tag"])
    assert Tags.CONTINUOUS in tag_set._tags
    assert Tags.LIST in tag_set._tags


def test_tagset_init_collision_error():
    """Constructing a TagSet from incompatible tags raises a ValueError."""
    with pytest.raises(ValueError) as err:
        TagSet(["continuous", "categorical"])

    message = str(err.value)
    for fragment in ("continuous", "categorical", "incompatible"):
        assert fragment in message


def test_tagset_is_iterable():
    """TagSet supports iteration and len()."""
    source_tags = ["continuous", "list"]
    tag_set = TagSet(source_tags)
    assert len(tag_set) == len(source_tags)
    for member in tag_set:
        assert member.value in source_tags


def test_tagset_add():
    """Adding a str, a list, or another TagSet yields a superset TagSet."""
    base_tags = [Tags.CONTINUOUS, Tags.LIST, "custom_tag"]
    base = TagSet(base_tags)

    # The same addition expressed in each supported form
    for addition in ("custom_tag2", ["custom_tag2"], TagSet(["custom_tag2"])):
        combined = base + addition
        assert len(combined) == 4
        assert "custom_tag2" in combined
        assert all(original in combined for original in base_tags)


def test_tagset_sub():
    """Subtracting a str, a list, or another TagSet removes those tags."""
    kept_tags = [Tags.CONTINUOUS, Tags.LIST]
    base = TagSet(kept_tags + ["custom_tag"])
    assert len(base) == 3

    # The same removal expressed in each supported form
    for removal in ("custom_tag", ["custom_tag"], TagSet(["custom_tag"])):
        remaining = base - removal
        assert len(remaining) == 2
        assert "custom_tag" not in remaining
        assert all(original in remaining for original in kept_tags)


def test_tagset_add_collision_error():
    """Adding a tag that conflicts with an existing one raises a ValueError."""
    base = TagSet(["continuous", "list", "custom_tag"])

    for addition in ("categorical", ["categorical"], TagSet(["categorical"])):
        with pytest.raises(ValueError) as err:
            base + addition  # pylint: disable=W0104

        message = str(err.value)
        for fragment in ("continuous", "categorical", "incompatible"):
            assert fragment in message


def test_tagset_atomizes_compound_tags():
    """Compound tags are replaced by their atomic constituent tags."""
    for compound, atoms in COMPOUND_TAGS.items():
        tag_set = TagSet([compound])
        assert compound not in tag_set
        for atom in atoms:
            assert atom in tag_set
# Only parametrize over cpu=False when a GPU is actually available
if HAS_GPU:
    _CPU = [True, False]
else:
    _CPU = [True]


@pytest.mark.parametrize("cpu", _CPU)
def test_serial_context(client, cpu):
    """`Serial` clears the global dask client inside the block and restores it after.

    NOTE(review): the `cpu` parameter is unused in this body — presumably kept
    for parity with the other tests in this file; confirm it's intentional.
    """
    # Set distributed client
    set_dask_client(client=client)
    assert global_dask_client() == client

    # Check that the global dask client
    # becomes None in a `with Serial()` block
    with Serial():
        assert global_dask_client() is None

    # Global client should revert outside
    # the `with Serial()` block
    assert global_dask_client() == client


@pytest.mark.parametrize("cpu", _CPU)
@pytest.mark.parametrize("nested_serial", [True, False])
def test_nvt_distributed(cpu, nested_serial):
    """`Distributed` deploys a fresh local cluster when no global client exists.

    Also checks that a `Serial` block nests correctly inside a `Distributed`
    block, and that the global client reverts to None on exit.
    """
    if cpu:
        distributed = pytest.importorskip("distributed")
        cluster_type = "cpu"
        cluster_cls = distributed.LocalCluster
    else:
        dask_cuda = pytest.importorskip("dask_cuda")
        cluster_type = "cuda"
        cluster_cls = dask_cuda.LocalCUDACluster

    # Set the global client to None
    set_dask_client(client=None)
    assert global_dask_client() is None

    # Check that a new local cluster is deployed within
    # a `with Distributed()` block
    with Distributed(cluster_type=cluster_type, n_workers=1, force_new=True) as dist:
        assert dist.client is not None
        assert global_dask_client() == dist.client
        assert len(dist.cluster.workers) == 1
        assert isinstance(dist.cluster, cluster_cls)

        # Check that we can nest a `with Serial()` block
        # inside a `with Distributed()` block
        if nested_serial:
            with Serial():
                assert global_dask_client() is None
            assert global_dask_client() == dist.client

    # Global client should revert to None outside
    # the `with Distributed()` block
    assert global_dask_client() is None


@pytest.mark.parametrize("cpu", _CPU)
def test_nvt_distributed_force(client, cpu):
    """`Distributed(force_new=True)` deploys a new cluster even when a global
    client exists; without `force_new` it warns and reuses the existing client.
    """
    if cpu:
        distributed = pytest.importorskip("distributed")
        cluster_type = "cpu"
        cluster_cls = distributed.LocalCluster
    else:
        dask_cuda = pytest.importorskip("dask_cuda")
        cluster_type = "cuda"
        cluster_cls = dask_cuda.LocalCUDACluster

    # Set distributed client
    set_dask_client(client=client)
    assert global_dask_client() == client

    # Check that a new local cluster is deployed within
    # a `with Distributed()` block. Since we are using
    # `force_new=True`, the new cluster should NOT be
    # the same as the original `client`.
    with Distributed(cluster_type=cluster_type, force_new=True, n_workers=1) as dist:
        assert dist.client != client
        assert global_dask_client() == dist.client
        assert len(dist.cluster.workers) == 1
        assert isinstance(dist.cluster, cluster_cls)

    # We should revert to the original client
    # outside the `with Distributed()` block
    assert global_dask_client() == client

    # Check that the default behavior is to avoid
    # deploying a new cluster (and warning the user)
    # if an existing client is detected
    with pytest.warns(UserWarning):
        with Distributed(cluster_type=cluster_type, n_workers=1) as dist:
            assert dist.client == client
            assert global_dask_client() == dist.client

    # We should revert to the original client
    # outside the `with Distributed()` block
    assert global_dask_client() == client