├── .gitattributes
├── .github
├── actionlint-matcher.json
├── actionlint.yaml
├── copy-pr-bot.yaml
├── release-drafter.yml
└── workflows
│ ├── ISSUE_TEMPLATE
│ ├── bug-report.md
│ ├── documentation-request.md
│ ├── feature-request.md
│ ├── submit-question.md
│ └── task.md
│ ├── check-base-branch.yaml
│ ├── cpu-ci.yml
│ ├── docs-preview-pr.yaml
│ ├── docs-remove-stale-reviews.yaml
│ ├── docs-sched-rebuild.yaml
│ ├── gpu-ci.yml
│ ├── lint.yaml
│ ├── merlin.yml
│ ├── packages.yaml
│ ├── release-drafter.yaml
│ ├── require-label.yaml
│ ├── set-stable-branch.yaml
│ ├── tox.yml
│ └── triage.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .prettierignore
├── .pylintrc
├── CLA.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── ci
├── ignore_codespell_words.txt
├── pr.gpu.Jenkinsfile
└── test_unit.sh
├── conda
└── recipe
│ └── meta.yaml
├── docs
├── Makefile
├── README.md
├── make.bat
└── source
│ ├── _static
│ ├── .gitkeep
│ ├── NVIDIA-LogoBlack.svg
│ ├── NVIDIA-LogoWhite.svg
│ ├── css
│ │ ├── custom.css
│ │ └── versions.css
│ ├── favicon.png
│ └── js
│ │ └── rtd-version-switcher.js
│ ├── _templates
│ ├── footer.html
│ ├── layout.html
│ ├── merlin-ecosystem.html
│ └── versions.html
│ ├── api
│ ├── index.rst
│ ├── merlin.dag.rst
│ ├── merlin.io.rst
│ └── merlin.schema.rst
│ ├── conf.py
│ ├── index.rst
│ └── toc.yaml
├── merlin
├── config
│ └── __init__.py
├── core
│ ├── __init__.py
│ ├── _version.py
│ ├── compat
│ │ ├── __init__.py
│ │ ├── tensorflow.py
│ │ └── torch.py
│ ├── dispatch.py
│ ├── has_gpu.py
│ ├── protocols.py
│ └── utils.py
├── dag
│ ├── __init__.py
│ ├── executors.py
│ ├── graph.py
│ ├── node.py
│ ├── operator.py
│ ├── ops
│ │ ├── __init__.py
│ │ ├── add_metadata.py
│ │ ├── concat_columns.py
│ │ ├── grouping.py
│ │ ├── rename.py
│ │ ├── selection.py
│ │ ├── stat_operator.py
│ │ ├── subgraph.py
│ │ ├── subset_columns.py
│ │ ├── subtraction.py
│ │ └── udf.py
│ ├── runtime.py
│ ├── selector.py
│ └── utils.py
├── dispatch
│ └── lazy.py
├── dtypes
│ ├── __init__.py
│ ├── aliases.py
│ ├── base.py
│ ├── mapping.py
│ ├── mappings
│ │ ├── __init__.py
│ │ ├── cudf.py
│ │ ├── merlin.py
│ │ ├── numpy.py
│ │ ├── pandas.py
│ │ ├── python.py
│ │ ├── tf.py
│ │ ├── torch.py
│ │ └── triton.py
│ ├── registry.py
│ └── shape.py
├── io
│ ├── __init__.py
│ ├── avro.py
│ ├── csv.py
│ ├── dask.py
│ ├── dataframe_engine.py
│ ├── dataframe_iter.py
│ ├── dataset.py
│ ├── dataset_engine.py
│ ├── fsspec_utils.py
│ ├── hugectr.py
│ ├── parquet.py
│ ├── shuffle.py
│ ├── worker.py
│ ├── writer.py
│ └── writer_factory.py
├── schema
│ ├── __init__.py
│ ├── io
│ │ ├── __init__.py
│ │ ├── proto_utils.py
│ │ ├── schema_bp.py
│ │ └── tensorflow_metadata.py
│ ├── schema.py
│ └── tags.py
├── table
│ ├── __init__.py
│ ├── conversions.py
│ ├── cupy_column.py
│ ├── numpy_column.py
│ ├── tensor_column.py
│ ├── tensor_table.py
│ ├── tensorflow_column.py
│ └── torch_column.py
└── testing
│ ├── __init__.py
│ └── assert_equals.py
├── mypy.ini
├── pyproject.toml
├── requirements-dev.txt
├── requirements-docs.txt
├── requirements-gpu.txt
├── requirements-test-cpu.txt
├── requirements-test-gpu.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
├── __init__.py
├── conftest.py
└── unit
│ ├── core
│ ├── test_dispatch.py
│ ├── test_protocols.py
│ └── test_version.py
│ ├── dag
│ ├── ops
│ │ ├── test_addmetadata.py
│ │ ├── test_rename.py
│ │ ├── test_selection.py
│ │ ├── test_stat_op.py
│ │ ├── test_subgraph.py
│ │ └── test_udf.py
│ ├── test_base_operator.py
│ ├── test_column_selector.py
│ ├── test_dag_utils.py
│ ├── test_executors.py
│ └── test_graph.py
│ ├── dispatch
│ └── test_lazy_dispatch.py
│ ├── dtypes
│ ├── test_cudf.py
│ ├── test_module.py
│ └── test_shape.py
│ ├── io
│ ├── test_avro.py
│ ├── test_dataset.py
│ ├── test_io.py
│ └── test_worker.py
│ ├── schema
│ ├── test_column_schemas.py
│ ├── test_schema.py
│ ├── test_schema_io.py
│ └── test_tags.py
│ ├── table
│ ├── test_convert_column.py
│ ├── test_tensor_column.py
│ └── test_tensor_table.py
│ └── utils
│ └── test_utils.py
├── tox.ini
└── versioneer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | merlin/core/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.github/actionlint-matcher.json:
--------------------------------------------------------------------------------
1 | {
2 | "problemMatcher": [
3 | {
4 | "owner": "actionlint",
5 | "pattern": [
6 | {
7 | "regexp": "^(?:\\x1b\\[\\d+m)?(.+?)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*: (?:\\x1b\\[\\d+m)*(.+?)(?:\\x1b\\[\\d+m)* \\[(.+?)\\]$",
8 | "file": 1,
9 | "line": 2,
10 | "column": 3,
11 | "message": 4,
12 | "code": 5
13 | }
14 | ]
15 | }
16 | ]
17 | }
18 |
--------------------------------------------------------------------------------
/.github/actionlint.yaml:
--------------------------------------------------------------------------------
1 | self-hosted-runner:
2 | # Labels of self-hosted runner in array of string
3 | labels:
4 | - 1GPU
5 | - 2GPU
6 | - linux-amd64-gpu-p100-latest-1
7 |
--------------------------------------------------------------------------------
/.github/copy-pr-bot.yaml:
--------------------------------------------------------------------------------
1 | # Configuration file for `copy-pr-bot` GitHub App
2 | # https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/
3 |
4 | enabled: true
5 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | categories:
2 | - title: "⚠ Breaking Changes"
3 | labels:
4 | - "breaking"
5 | - title: "🐜 Bug Fixes"
6 | labels:
7 | - "bug"
8 | - title: "🚀 Features"
9 | labels:
10 | - "feature"
11 | - "enhancement"
12 | - title: "📄 Documentation"
13 | labels:
14 | - "documentation"
15 | - "examples"
16 | - title: "🔧 Maintenance"
17 | labels:
18 | - "build"
19 | - "dependencies"
20 | - "chore"
21 | - "ci"
22 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)"
23 | exclude-labels:
24 | - "skip-changelog"
25 | template: |
26 | ## What’s Changed
27 |
28 | $CHANGES
29 |
--------------------------------------------------------------------------------
/.github/workflows/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F41B Bug Report"
3 | about: Submit a bug report to help us improve Merlin Core
4 | title: "[BUG]"
5 | labels: "status/needs-triage, bug"
6 | assignees: ""
7 | ---
8 |
9 | ### Bug description
10 |
11 |
12 |
13 | ### Steps/Code to reproduce bug
14 |
15 |
16 |
17 | 1.
18 | 2.
19 | 3.
20 |
21 | ### Expected behavior
22 |
23 |
24 |
25 | ### Environment details
26 |
27 | - Merlin version:
28 | - Platform:
29 | - Python version:
30 | - PyTorch version (GPU?):
31 | - Tensorflow version (GPU?):
32 |
33 | ### Additional context
34 |
35 |
36 |
--------------------------------------------------------------------------------
/.github/workflows/ISSUE_TEMPLATE/documentation-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Documentation request
3 | about: Report incorrect or needed documentation
4 | title: "[DOC]"
5 | labels: "status/needs-triage, area/documentation"
6 | assignees: ""
7 | ---
8 |
9 | ## Report incorrect documentation
10 |
11 | ### Location of incorrect documentation
12 |
13 |
14 |
15 | ### Describe the problems or issues found in the documentation
16 |
17 |
18 |
19 | ### Steps taken to verify documentation is incorrect
20 |
21 |
22 |
23 | ### Suggested fix for documentation
24 |
25 |
26 |
27 | ---
28 |
29 | ## Report needed documentation
30 |
31 | ### Report needed documentation
32 |
33 |
34 |
35 | ### Describe the documentation you'd like
36 |
37 |
38 |
39 | ### Steps taken to search for needed documentation
40 |
41 |
42 |
--------------------------------------------------------------------------------
/.github/workflows/ISSUE_TEMPLATE/feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F680 Feature request"
3 | about: Submit a proposal/request for a new Merlin Core feature
4 | title: "[FEA]"
5 | labels: "status/needs-triage, kind/feature-request"
6 | assignees: ""
7 | ---
8 |
9 | # 🚀 Feature request
10 |
11 |
13 |
14 | ## Motivation
15 |
16 |
19 |
20 | ## Your contribution
21 |
22 |
25 |
--------------------------------------------------------------------------------
/.github/workflows/ISSUE_TEMPLATE/submit-question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "❓ Questions & Help"
3 | about: Ask a general question about Merlin Core
4 | title: "[QST]"
5 | labels: "status/needs-triage, kind/question"
6 | assignees: ""
7 | ---
8 |
9 | # ❓ Questions & Help
10 |
11 | ## Details
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.github/workflows/ISSUE_TEMPLATE/task.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Task
3 | about: A general task that we're tracking in GitHub
4 | title: "[Task]"
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | ### Description
10 |
11 | **Additional context**
12 | Add any other context, code examples, or references to existing implementations about the task here.
13 |
--------------------------------------------------------------------------------
/.github/workflows/check-base-branch.yaml:
--------------------------------------------------------------------------------
1 | name: Require Development Base Branch
2 |
3 | on:
4 | pull_request:
5 | types: [synchronize, opened, reopened, labeled, unlabeled]
6 |
7 | jobs:
8 | check:
9 | uses: NVIDIA-Merlin/.github/.github/workflows/check-base-branch.yml@main
10 |
--------------------------------------------------------------------------------
/.github/workflows/cpu-ci.yml:
--------------------------------------------------------------------------------
1 | name: Core Tests (CPU)
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [main]
7 | tags:
8 | - "v[0-9]+.[0-9]+.[0-9]+"
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | build:
14 | uses: ./.github/workflows/tox.yml
15 | with:
16 | env: test-cpu
17 |
18 | store-pr-information:
19 | needs: [build]
20 | runs-on: ubuntu-latest
21 | steps:
22 | - name: Store PR information
23 | run: |
24 | mkdir ./pr
25 | echo ${{ github.event.number }} > ./pr/pr.txt
26 | echo ${{ github.event.pull_request.merged }} > ./pr/merged.txt
27 | echo ${{ github.event.action }} > ./pr/action.txt
28 | - name: Upload PR information
29 | uses: actions/upload-artifact@v2
30 | with:
31 | name: pr
32 | path: pr/
33 |
--------------------------------------------------------------------------------
/.github/workflows/docs-preview-pr.yaml:
--------------------------------------------------------------------------------
1 | name: docs-preview-pr
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Packages"]
6 | types: [completed]
7 |
8 | env:
9 | WF_ID: ${{ github.event.workflow_run.id }}
10 |
11 | jobs:
12 | preview:
13 | uses: nvidia-merlin/.github/.github/workflows/docs-preview-pr-common.yaml@main
14 |
--------------------------------------------------------------------------------
/.github/workflows/docs-remove-stale-reviews.yaml:
--------------------------------------------------------------------------------
1 | name: docs-remove-stale-reviews
2 |
3 | on:
4 | schedule:
5 | # 42 minutes after 0:00 UTC on Sundays
6 | - cron: "42 0 * * 0"
7 | workflow_dispatch:
8 |
9 | jobs:
10 | remove:
11 | uses: nvidia-merlin/.github/.github/workflows/docs-remove-stale-reviews-common.yaml@main
12 |
--------------------------------------------------------------------------------
/.github/workflows/docs-sched-rebuild.yaml:
--------------------------------------------------------------------------------
1 | name: docs-sched-rebuild
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | tags:
7 | - "v[0-9]+.[0-9]+.[0-9]+"
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | runs-on: [ubuntu-latest]
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 | with:
17 | fetch-depth: 0
18 | ref: main
19 | - name: Set up Python 3.9
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: 3.9
23 | - name: Install Ubuntu packages
24 | run: |
25 | sudo apt-get update -y
26 | sudo apt-get install -y protobuf-compiler
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip setuptools==59.4.0 wheel tox
30 | - name: Building docs (multiversion)
31 | run: |
32 | # setup local branches that we'd like to build docs for
33 | # required for sphinx-multiversion to find these
34 | git branch --track stable origin/stable || true
35 | tox -e docs-multi
36 | - name: Delete unnecessary files
37 | run: |
38 | find docs/build -name .doctrees -prune -exec rm -rf {} \;
39 | find docs/build -name .buildinfo -exec rm {} \;
40 | - name: Upload HTML
41 | uses: actions/upload-artifact@v3
42 | with:
43 | name: html-build-artifact
44 | path: docs/build/html
45 | if-no-files-found: error
46 | retention-days: 1
47 |
48 | # Identify the dir for the HTML.
49 | store-html:
50 | needs: [build]
51 | runs-on: ubuntu-latest
52 | steps:
53 | - uses: actions/checkout@v3
54 | with:
55 | ref: "gh-pages"
56 | - name: Initialize Git configuration
57 | run: |
58 | git config user.name docs-sched-rebuild
59 | git config user.email do-not-send-@github.com
60 | - name: Download artifacts
61 | uses: actions/download-artifact@v3
62 | with:
63 | name: html-build-artifact
64 | - name: Copy HTML directories
65 | run: |
66 | ls -asl
67 | for i in */
68 | do
69 | echo "Git adding ${i}"
70 | git add "${i}"
71 | done
72 | - name: Check or create dot-no-jekyll file
73 | run: |
74 | if [ -f ".nojekyll" ]; then
75 | echo "The dot-no-jekyll file already exists."
76 | exit 0
77 | fi
78 | touch .nojekyll
79 | git add .nojekyll
80 | - name: Check or create redirect page
81 | env:
82 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
83 | run: |
84 | resp=$(grep 'http-equiv="refresh"' index.html 2>/dev/null) || true
85 | if [ -n "${resp}" ]; then
86 | echo "The redirect file already exists."
87 | exit 0
88 | fi
89 |
90 | # If any of these commands fail, fail the build.
91 | def_branch="stable"
92 | html_url=$(gh api "repos/${GITHUB_REPOSITORY}/pages" --jq ".html_url")
93 |
94 | cat > index.html << EOF
95 |
96 |
97 |
98 | Redirect to documentation
99 |
100 |
102 |
107 |
108 |
109 | Please follow the link to the
110 | '${def_branch}' branch documentation.
111 |
112 |
113 | EOF
114 |
115 | git add index.html
116 | - name: Commit changes to the GitHub Pages branch
117 | run: |
118 | git status
119 | if git commit -m 'Pushing changes to GitHub Pages.'; then
120 | git push -f
121 | else
122 | echo "Nothing changed."
123 | fi
124 |
--------------------------------------------------------------------------------
/.github/workflows/gpu-ci.yml:
--------------------------------------------------------------------------------
1 | name: GPU CI
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches:
7 | - main
8 | - "pull-request/[0-9]+"
9 | tags:
10 | - "v[0-9]+.[0-9]+.[0-9]+"
11 |
12 | jobs:
13 | gpu-ci:
14 | runs-on: linux-amd64-gpu-p100-latest-1
15 | container:
16 | image: nvcr.io/nvstaging/merlin/merlin-ci-runner:latest
17 | env:
18 | NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
19 | options: --shm-size=1G
20 | credentials:
21 | username: $oauthtoken
22 | password: ${{ secrets.NGC_TOKEN }}
23 |
24 | steps:
25 | - uses: actions/checkout@v3
26 | with:
27 | fetch-depth: 0
28 | - name: Run tests
29 | run: |
30 | ref_type=${{ github.ref_type }}
31 | branch=main
32 | if [[ $ref_type == "tag"* ]]
33 | then
34 | raw=$(git branch -r --contains ${{ github.ref_name }})
35 | branch=${raw/origin\/}
36 | fi
37 | tox -e test-gpu -- "$branch"
38 |
39 | gpu-ci-not-visible:
40 | runs-on: linux-amd64-gpu-p100-latest-1
41 | container:
42 | image: nvcr.io/nvstaging/merlin/merlin-ci-runner:latest
43 | env:
44 | NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
45 | options: --shm-size=1G
46 | credentials:
47 | username: $oauthtoken
48 | password: ${{ secrets.NGC_TOKEN }}
49 |
50 | steps:
51 | - uses: actions/checkout@v3
52 | with:
53 | fetch-depth: 0
54 | - name: Run tests
55 | run: |
56 | ref_type=${{ github.ref_type }}
57 | branch=main
58 | if [[ $ref_type == "tag"* ]]
59 | then
60 | raw=$(git branch -r --contains ${{ github.ref_name }})
61 | branch=${raw/origin\/}
62 | fi
63 | tox -e test-gpu-not-visible -- "$branch"
64 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 | tags:
8 | - "v[0-9]+.[0-9]+.[0-9]+"
9 |
10 | jobs:
11 | pre-commit:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v3
15 | - uses: actions/setup-python@v4
16 | - uses: pre-commit/action@v3.0.0
17 |
18 | actionlint:
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: actions/checkout@v3
22 | - name: Check workflow files
23 | run: |
24 | echo "::add-matcher::.github/actionlint-matcher.json"
25 | bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/fd7ba3c382e13dcc0248e425b4cbc3f1185fa3ee/scripts/download-actionlint.bash)
26 | ./actionlint
27 | shell: bash
28 |
--------------------------------------------------------------------------------
/.github/workflows/merlin.yml:
--------------------------------------------------------------------------------
1 | name: "Test Merlin"
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [main]
7 | tags:
8 | - "v[0-9]+.[0-9]+.[0-9]+"
9 | pull_request:
10 | branches: [main]
11 |
12 | jobs:
13 | dataloader:
14 | name: "Dataloader (CPU)"
15 | uses: ./.github/workflows/tox.yml
16 | with:
17 | env: test-dataloader-cpu
18 |
19 | systems:
20 | name: "Systems (CPU)"
21 | uses: ./.github/workflows/tox.yml
22 | with:
23 | env: test-systems-cpu
24 |
25 | models:
26 | name: "Models (CPU)"
27 | uses: ./.github/workflows/tox.yml
28 | with:
29 | env: test-models-cpu
30 |
31 | nvtabular:
32 | name: "NVTabular (CPU)"
33 | uses: ./.github/workflows/tox.yml
34 | with:
35 | env: test-nvtabular-cpu
36 |
--------------------------------------------------------------------------------
/.github/workflows/release-drafter.yaml:
--------------------------------------------------------------------------------
1 | name: release-drafter
2 |
3 | on:
4 | push:
5 | # trigger on tags only
6 | tags:
7 | - "v[0-9]+.[0-9]+.[0-9]+"
8 |
9 | workflow_dispatch:
10 |
11 | jobs:
12 | update_release_draft:
13 | uses: nvidia-merlin/.github/.github/workflows/release-drafter-common.yaml@main
14 |
--------------------------------------------------------------------------------
/.github/workflows/require-label.yaml:
--------------------------------------------------------------------------------
1 | name: Require PR Labels
2 |
3 | on:
4 | pull_request:
5 | types: [synchronize, opened, reopened, labeled, unlabeled]
6 |
7 | jobs:
8 | check-labels:
9 | uses: nvidia-merlin/.github/.github/workflows/require-label.yaml@main
10 |
--------------------------------------------------------------------------------
/.github/workflows/set-stable-branch.yaml:
--------------------------------------------------------------------------------
1 | name: Set Stable Branch
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types: [published, deleted]
7 |
8 | jobs:
9 | set-stable-branch:
10 | uses: NVIDIA-Merlin/.github/.github/workflows/set-stable-branch.yaml@main
11 |
--------------------------------------------------------------------------------
/.github/workflows/tox.yml:
--------------------------------------------------------------------------------
1 | name: "Run Tox Env"
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | env:
7 | description: "The name of the tox environment to run"
8 | required: true
9 | type: string
10 |
11 | jobs:
12 | check:
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | python-version: [3.8]
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install Ubuntu packages
26 | run: |
27 | sudo apt-get update -y
28 | sudo apt-get install -y protobuf-compiler
29 | - name: Install and upgrade python packages
30 | run: |
31 | python -m pip install --upgrade pip setuptools==59.4.0 wheel tox
32 | - name: Get Branch name
33 | id: get-branch-name
34 | uses: NVIDIA-Merlin/.github/actions/branch-name@main
35 | - name: Run tests
36 | run: |
37 | branch="${{ steps.get-branch-name.outputs.branch }}"
38 | GIT_COMMIT=$(git rev-parse HEAD) tox -e ${{ inputs.env }} -- $branch
39 |
--------------------------------------------------------------------------------
/.github/workflows/triage.yml:
--------------------------------------------------------------------------------
1 | name: triage_issues
2 | on:
3 | issues:
4 | types: [opened, reopened]
5 |
6 | jobs:
7 | triage_issue:
8 | uses: nvidia-merlin/.github/.github/workflows/triage.yaml@main
9 | secrets:
10 | TRIAGE_APP_ID: ${{ secrets.TRIAGE_APP_ID }}
11 | TRIAGE_APP_PEM: ${{ secrets.TRIAGE_APP_PEM }}
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints/*
2 | /.*_checkpoints/
3 | .ipynb_checkpoints/*
4 | */.jupyter/*
5 | */.local/*
6 |
7 | cache/*
8 | data*/*
9 | *.parquet
10 | *.orc
11 | *.csv
12 |
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Sphinx documentation
62 | docs/build/
63 | docs/source/README.md
64 | docs/source/LICENSE
65 |
66 | # IPython
67 | profile_default/
68 | ipython_config.py
69 |
70 | # mypy
71 | .mypy_cache/
72 | .dmypy.json
73 | dmypy.json
74 |
75 | # PyCharm
76 | .idea
77 |
78 | .vscode/
79 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # imports
3 | - repo: https://github.com/MarcoGorelli/absolufy-imports
4 | rev: v0.3.1
5 | hooks:
6 | - id: absolufy-imports
7 | - repo: https://github.com/timothycrosley/isort
8 | rev: 5.12.0
9 | hooks:
10 | - id: isort
11 | additional_dependencies: [toml]
12 | exclude: ^examples/
13 | # types
14 | - repo: https://github.com/pre-commit/mirrors-mypy
15 | rev: "v0.991"
16 | hooks:
17 | - id: mypy
18 | language_version: python3
19 | args:
20 | [
21 | --non-interactive,
22 | --install-types,
23 | --namespace-packages,
24 | --explicit-package-bases,
25 | ]
26 | exclude: ^docs/
27 | # code style
28 | - repo: https://github.com/python/black
29 | rev: 23.1.0
30 | hooks:
31 | - id: black
32 | - repo: https://github.com/pycqa/pylint
33 | rev: v2.16.0
34 | hooks:
35 | - id: pylint
36 | - repo: https://github.com/pycqa/flake8
37 | rev: 6.0.0
38 | hooks:
39 | - id: flake8
40 | - repo: https://github.com/pre-commit/mirrors-prettier
41 | rev: v2.7.1
42 | hooks:
43 | - id: prettier
44 | types_or: [yaml, markdown]
45 | # notebooks
46 | - repo: https://github.com/s-weigand/flake8-nb
47 | rev: v0.5.2
48 | hooks:
49 | - id: flake8-nb
50 | files: \.ipynb$
51 | # documentation
52 | - repo: https://github.com/econchick/interrogate
53 | rev: 1.5.0
54 | hooks:
55 | - id: interrogate
56 | exclude: ^(build|docs|merlin/io|tests|setup.py|versioneer.py)
57 | args: [--config=pyproject.toml]
58 | - repo: https://github.com/codespell-project/codespell
59 | rev: v2.2.2
60 | hooks:
61 | - id: codespell
62 | # security
63 | - repo: https://github.com/PyCQA/bandit
64 | rev: 1.7.4
65 | hooks:
66 | - id: bandit
67 | args: [--verbose, -ll, -x, tests, examples, bench]
68 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | build
2 | conda
3 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | ignore-patterns=_version.py,model_config_pb2.py
4 |
5 | extension-pkg-allow-list=hugectr,nvtabular_cpp
6 |
7 | [MESSAGES CONTROL]
8 | disable=fixme,
9 | # docstrings aren't required (yet).
10 | missing-function-docstring,
11 | missing-module-docstring,
12 | missing-class-docstring,
13 |
14 | # formatting checks that we're handling with black/isort
15 | wrong-import-order,
16 | wrong-import-position,
17 | ungrouped-imports,
18 | line-too-long,
19 | superfluous-parens,
20 | trailing-whitespace,
21 |
22 | # we'll probably never enable these checks
23 | invalid-name,
24 | import-error,
25 |
26 | # disable code-complexity checks for now
27 | # TODO: should we configure the thresholds for these rather than just disable?
28 | too-many-function-args,
29 | too-many-instance-attributes,
30 | too-many-locals,
31 | too-many-branches,
32 | too-many-nested-blocks,
33 | too-many-statements,
34 | too-many-arguments,
35 | too-many-return-statements,
36 | too-many-lines,
37 | too-few-public-methods,
38 |
39 | # many of these checks would be great to include at some point, but would
40 | # require some changes to our codebase
41 | useless-return,
42 | protected-access,
43 | arguments-differ,
44 | unused-argument,
45 | unused-variable,
46 | abstract-method,
47 | no-name-in-module,
48 | attribute-defined-outside-init,
49 | redefined-outer-name,
50 | import-outside-toplevel,
51 | no-else-continue,
52 | no-else-return,
53 | no-else-raise,
54 | no-member,
55 | super-with-arguments,
56 | unsupported-assignment-operation,
57 | inconsistent-return-statements,
58 | duplicate-string-formatting-argument,
59 | len-as-condition,
60 | cyclic-import,
61 |
62 | # producing false positives
63 | unexpected-keyword-arg,
64 | not-an-iterable,
65 | unsubscriptable-object
66 |
67 | [SIMILARITIES]
68 | min-similarity-lines=20
69 | ignore-comments=yes
70 | ignore-docstrings=yes
71 | ignore-imports=yes
72 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | If you are interested in contributing to the project, your contributions will fall
4 | into three categories:
5 |
6 | 1. You want to report a bug, feature request, or documentation issue
7 | - File an [issue](https://github.com/NVIDIA-Merlin/core/issues/new/choose)
8 | describing what you encountered or what you want to see changed.
9 | - The maintainers evaluate the issues and triage them, scheduling
10 | them for a release. If you believe the issue needs priority attention
11 | comment on the issue to notify the team.
12 | 2. You want to propose a new feature and implement it
13 | - Post about your intended feature, and we shall discuss the design and
14 | implementation.
15 | - Once we agree that the plan looks good, go ahead and implement it, using
16 | the [code contributions](#code-contributions) guide below.
17 | 3. You want to implement a feature or bug-fix for an outstanding issue
18 | - Follow the [code contributions](#code-contributions) guide below.
19 | - If you need more context on a particular issue, please ask and we shall
20 | provide.
21 |
22 | ## Code contributions
23 |
24 | ### Your first issue
25 |
26 | 1. Read the project's [README.md](https://github.com/NVIDIA-Merlin/core/blob/main/README.md)
27 | to learn how to setup the development environment.
28 | 2. Find an issue to work on. The best way is to look for the
29 | [good first issue](https://github.com/NVIDIA-Merlin/core/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
30 | or [help wanted](https://github.com/NVIDIA-Merlin/core/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) labels.
31 | 3. Comment on the issue saying you are going to work on it.
32 | 4. Code! Make sure to update unit tests!
33 | 5. When done, [create your pull request](https://github.com/NVIDIA-Merlin/core/compare).
34 | 6. Verify that CI passes all [status checks](https://help.github.com/articles/about-status-checks/).
35 | Fix if needed.
36 | 7. Wait for other developers to review your code and update code as needed.
37 | 8. After your pull request is reviewed and approved, a maintainer will merge your
38 | pull request.
39 |
40 | Remember, if you are unsure about anything, don't hesitate to comment on issues
41 | and ask for clarifications!
42 |
43 | ## Label your PRs
44 |
45 | This repository uses the release-drafter action to draft and create our change log.
46 |
47 | Please add one of the following labels to your PR to specify the type of contribution
48 | and help categorize the PR in our change log:
49 |
50 | - `breaking` -- The PR creates a breaking change to the API.
51 | - `bug` -- The PR fixes a problem with the code.
52 | - `feature` or `enhancement` -- The PR introduces a backward-compatible feature.
53 | - `documentation` or `examples` -- The PR is an addition or update to documentation.
54 | - `build`, `dependencies`, `chore`, or `ci` -- The PR is related to maintaining the
55 | repository or the project.
56 |
57 | By default, an unlabeled PR is listed at the top of the change log and is not
58 | grouped under a heading like _Features_ that groups similar PRs.
59 | Labeling the PRs so we can categorize them is preferred.
60 |
61 | If, for some reason, you do not believe your PR should be included in the change
62 | log, you can add the `skip-changelog` label.
63 | This label excludes the PR from the change log.
64 |
65 | For more information, see `.github/release-drafter.yml` in the repository
66 | or go to <https://github.com/release-drafter/release-drafter>.
67 |
68 | ## Attribution
69 |
70 | Portions adopted from .
71 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include versioneer.py
2 | recursive-include merlin *.py *.proto
3 | include requirements.txt
4 | include requirements-dev.txt
5 |
6 | include merlin/core/_version.py
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Merlin Core](https://github.com/NVIDIA-Merlin/core)
2 |
3 | [](https://pypi.python.org/pypi/merlin-core/)
4 | [](LICENSE)
5 | [](https://nvidia-merlin.github.io/core)
6 |
7 | The Merlin Core library provides the core utilities for [NVIDIA Merlin](https://github.com/NVIDIA-Merlin) libraries
8 | like [NVTabular](https://github.com/NVIDIA-Merlin/NVTabular), [Transformers4Rec](https://github.com/NVIDIA-Merlin/Transformers4Rec)
9 | and [Merlin Models](https://github.com/NVIDIA-Merlin/models).
10 | For example, the [merlin.io.Dataset](https://nvidia-merlin.github.io/core/stable/api/merlin.io.html#merlin.io.Dataset) and [merlin.schema.Schema](https://nvidia-merlin.github.io/core/stable/api/merlin.schema.html#merlin.schema.Schema) classes are fundamental for working with data and building recommender systems with Merlin.
11 |
12 | ## Installation
13 |
14 | ### Installing Merlin Core Using Pip
15 |
16 | ```shell
17 | pip install merlin-core
18 | ```
19 |
20 | ### Installing Merlin Core Using Conda
21 |
22 | ```shell
23 | conda install -c nvidia -c rapidsai -c numba -c conda-forge merlin-core python=3.7 cudatoolkit=11.2
24 | ```
25 |
26 | ### Running Merlin Core with Docker
27 |
28 | As a fundamental library for Merlin, Merlin Core is included in the Merlin Containers.
29 |
30 | Refer to the [Merlin Containers](https://nvidia-merlin.github.io/Merlin/main/containers.html) documentation page for information about the Merlin container names, URLs to the container images on the NVIDIA GPU Cloud catalog, and key Merlin components.
31 |
32 | ## Feedback and Support
33 |
34 | To report bugs or get help, please open an issue on the [GitHub repo](https://github.com/NVIDIA-Merlin/core/issues).
35 |
--------------------------------------------------------------------------------
/ci/ignore_codespell_words.txt:
--------------------------------------------------------------------------------
1 | te
2 | coo
3 | ser
4 | fo
5 | ot
6 | lik
7 | usera
8 |
--------------------------------------------------------------------------------
/ci/pr.gpu.Jenkinsfile:
--------------------------------------------------------------------------------
1 | pipeline {
2 | agent {
3 | docker {
4 | image 'nvcr.io/nvstaging/merlin/merlin-ci-runner-wrapper'
5 | label 'merlin_gpu_gcp || merlin_gpu'
6 | registryCredentialsId 'jawe-nvcr-io'
7 | registryUrl 'https://nvcr.io'
8 | args "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all"
9 | }
10 | }
11 |
12 | options {
13 | buildDiscarder(logRotator(numToKeepStr: '10'))
14 | ansiColor('xterm')
15 | disableConcurrentBuilds(abortPrevious: true)
16 | }
17 |
18 | stages {
19 | stage("test-gpu") {
20 | options {
21 | timeout(time: 60, unit: 'MINUTES', activity: true)
22 | }
23 | steps {
24 | sh """#!/bin/bash
25 | #set -e
26 | printenv
27 |
28 | rm -rf $HOME/.cudf/
29 | export TF_MEMORY_ALLOCATION="0.1"
30 | export CUDA_VISIBLE_DEVICES=0,1
31 | export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
32 | export MKL_SERVICE_FORCE_INTEL=1
33 | export NR_USER=true
34 |
35 | tox -re test-gpu
36 |
37 | rm -rf "nvtabular-$GIT_COMMIT"
38 | rm -rf "models-$GIT_COMMIT"
39 | rm -rf "systems-$GIT_COMMIT"
40 | """
41 | }
42 | }
43 | }
44 | }
--------------------------------------------------------------------------------
/ci/test_unit.sh:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | #!/bin/bash
18 | set -e
19 |
20 | cd /core/
21 |
22 | # Run tests
23 | pytest -rxs /core/tests/unit
24 |
25 |
--------------------------------------------------------------------------------
/conda/recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION.
2 |
3 | # Usage:
4 | # conda build . -c defaults -c conda-forge -c numba -c rapidsai
5 |
6 |
7 | {% set version = environ.get('GIT_DESCRIBE_TAG', '0.1').lstrip('v') + environ.get('VERSION_SUFFIX', '') %}
8 | {% set git_revision_count=environ.get('GIT_DESCRIBE_NUMBER', 0) %}
9 | {% set setup_py_data = load_setup_py_data() %}
10 |
11 |
12 | package:
13 | name: merlin-core
14 | version: {{ version }}
15 |
16 | source:
17 | path: ../../
18 |
19 | build:
20 | number: {{ git_revision_count }}
21 | noarch: python
22 | script: python -m pip install . -vvv
23 |
24 | requirements:
25 | build:
26 | - python
27 | - setuptools
28 | run:
29 | - python
30 | {% for req in setup_py_data.get('install_requires', []) %}
31 | - {{ req }}
32 | {% endfor %}
33 | - cupy>=10
34 | - nvtx>=0.2.1
35 |
36 | about:
37 | home: https://github.com/NVIDIA-Merlin/core
38 | license_file: LICENSE
39 | summary: Core Utilities for NVIDIA Merlin
40 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | This folder contains the scripts necessary to build documentation for the
4 | merlin-core library. You can find the [generated documentation
5 | here](https://nvidia-merlin.github.io/core).
6 |
7 | ## Contributing to Docs
8 |
9 | You build the documentation with the `tox` command and specify the `docs` environment.
10 | The following steps are one way of many to build the documentation before opening a merge request.
11 |
12 | 1. Create a virtual environment:
13 |
14 | ```shell
15 | python -m venv .venv
16 | ```
17 |
18 | 1. Activate the virtual environment:
19 |
20 | ```shell
21 | source .venv/bin/activate
22 | ```
23 |
24 | 1. Install tox in the virtual environment:
25 |
26 | ```shell
27 | python -m pip install --upgrade pip
28 | python -m pip install tox
29 | ```
30 |
31 | 1. Build the documentation with tox:
32 |
33 | ```shell
34 | tox -e docs
35 | ```
36 |
37 | These steps run Sphinx in your shell and create HTML in the `docs/build/html/`
38 | directory.
39 |
40 | ## Preview the Changes
41 |
42 | View the docs web page by opening the HTML in your browser. First, navigate to
43 | the `build/html/` directory and then run the following command:
44 |
45 | ```shell
46 | python -m http.server
47 | ```
48 |
49 | Afterward, open a web browser and access <http://localhost:8000>.
50 |
51 | Check that your edits are formatted correctly and read well.
52 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-Merlin/core/6d396aa1d5e6ffe60bb133dcfe93d869b8548ba2/docs/source/_static/.gitkeep
--------------------------------------------------------------------------------
/docs/source/_static/NVIDIA-LogoBlack.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_static/NVIDIA-LogoWhite.svg:
--------------------------------------------------------------------------------
1 |
2 |
59 |
--------------------------------------------------------------------------------
/docs/source/_static/css/versions.css:
--------------------------------------------------------------------------------
1 | /* Version Switcher */
2 |
3 | .rst-versions {
4 | flex-align: bottom;
5 | bottom: 0;
6 | left: 0;
7 | z-index: 400
8 | }
9 |
10 | .rst-versions a {
11 | color: var(--nv-green);
12 | text-decoration: none
13 | }
14 |
15 | .rst-versions .rst-badge-small {
16 | display: none
17 | }
18 |
19 | .rst-versions .rst-current-version {
20 | padding: 12px;
21 | display: block;
22 | text-align: right;
23 | font-size: 90%;
24 | cursor: pointer;
25 | border-top: 1px solid rgba(0,0,0,.1);
26 | *zoom:1
27 | }
28 |
29 | .rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after {
30 | display: table;
31 | content: ""
32 | }
33 |
34 | .rst-versions .rst-current-version:after {
35 | clear: both
36 | }
37 |
38 | .rst-versions .rst-current-version .fa-book {
39 | float: left
40 | }
41 |
42 | .rst-versions .rst-current-version .icon-book {
43 | float: left
44 | }
45 |
46 | .rst-versions .rst-current-version.rst-out-of-date {
47 | background-color: #E74C3C;
48 | color: #fff
49 | }
50 |
51 | .rst-versions .rst-current-version.rst-active-old-version {
52 | background-color: #F1C40F;
53 | color: #000
54 | }
55 |
56 | .rst-versions.shift-up {
57 | height: auto;
58 | max-height: 100%
59 | }
60 |
61 | .rst-versions.shift-up .rst-other-versions {
62 | display: block
63 | }
64 |
65 | .rst-versions .rst-other-versions {
66 | font-size: 90%;
67 | padding: 12px;
68 | color: gray;
69 | display: none
70 | }
71 |
72 | .rst-versions .rst-other-versions hr {
73 | display: block;
74 | height: 1px;
75 | border: 0;
76 | margin: 20px 0;
77 | padding: 0;
78 | border-top: solid 1px #413d3d
79 | }
80 |
81 | .rst-versions .rst-other-versions dd {
82 | display: inline-block;
83 | margin: 0
84 | }
85 |
86 | .rst-versions .rst-other-versions dd a {
87 | display: inline-block;
88 | padding: 6px;
89 | color: var(--nv-green);
90 | font-weight: 500;
91 | }
92 |
93 | .rst-versions.rst-badge {
94 | width: auto;
95 | bottom: 20px;
96 | right: 20px;
97 | left: auto;
98 | border: none;
99 | max-width: 300px
100 | }
101 |
102 | .rst-versions.rst-badge .icon-book {
103 | float: none
104 | }
105 |
106 | .rst-versions.rst-badge .fa-book {
107 | float: none
108 | }
109 |
110 | .rst-versions.rst-badge.shift-up .rst-current-version {
111 | text-align: right
112 | }
113 |
114 | .rst-versions.rst-badge.shift-up .rst-current-version .fa-book {
115 | float: left
116 | }
117 |
118 | .rst-versions.rst-badge.shift-up .rst-current-version .icon-book {
119 | float: left
120 | }
121 |
122 | .rst-versions.rst-badge .rst-current-version {
123 | width: auto;
124 | height: 30px;
125 | line-height: 30px;
126 | padding: 0 6px;
127 | display: block;
128 | text-align: center
129 | }
130 |
131 | @media screen and (max-width: 768px) {
132 | .rst-versions {
133 | width:85%;
134 | display: none
135 | }
136 |
137 | .rst-versions.shift {
138 | display: block
139 | }
140 | }
141 |
--------------------------------------------------------------------------------
/docs/source/_static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-Merlin/core/6d396aa1d5e6ffe60bb133dcfe93d869b8548ba2/docs/source/_static/favicon.png
--------------------------------------------------------------------------------
/docs/source/_static/js/rtd-version-switcher.js:
--------------------------------------------------------------------------------
1 | var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery');
2 | var doc = $(document);
3 | doc.on('click', "[data-toggle='rst-current-version']", function() {
4 | $("[data-toggle='rst-versions']").toggleClass("shift-up");
5 | });
6 |
--------------------------------------------------------------------------------
/docs/source/_templates/footer.html:
--------------------------------------------------------------------------------
1 |
2 | Privacy Policy |
3 | Manage My Privacy |
4 | Do Not Sell or Share My Data |
5 | Terms of Service |
6 | Accessibility |
7 | Corporate Policies |
8 | Product Security |
9 | Contact
10 |
--------------------------------------------------------------------------------
/docs/source/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {%- extends "!layout.html" %}
2 |
3 | {%- block extrahead %}
4 | {%- if analytics_id %}
5 |
6 |
7 |
16 | {% endif %}
17 |
18 |
19 |
20 |
21 | {%- endblock %}
22 |
--------------------------------------------------------------------------------
/docs/source/_templates/merlin-ecosystem.html:
--------------------------------------------------------------------------------
1 |
14 |
--------------------------------------------------------------------------------
/docs/source/_templates/versions.html:
--------------------------------------------------------------------------------
1 | {%- if current_version %}
2 |
3 |
4 |
5 | v: {{ current_version.name }}
6 |
7 |
8 |
9 | {%- if versions.tags %}
10 |
11 | - Tags
12 | {%- for item in versions.tags %}
13 | - {{ item.name }}
14 | {%- endfor %}
15 |
16 | {%- endif %}
17 | {%- if versions.branches %}
18 |
19 | - Branches
20 | {%- for item in versions.branches %}
21 | - {{ item.name }}
22 | {%- endfor %}
23 |
24 | {%- endif %}
25 |
26 |
27 | {%- endif %}
28 |
--------------------------------------------------------------------------------
/docs/source/api/index.rst:
--------------------------------------------------------------------------------
1 | merlin namespace
2 | ================
3 |
4 | .. py:module:: merlin
5 |
6 | Subpackages
7 | -----------
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 |
12 | merlin.dag
13 | merlin.io
14 | merlin.schema
15 |
--------------------------------------------------------------------------------
/docs/source/api/merlin.dag.rst:
--------------------------------------------------------------------------------
1 | Merlin DAG
2 | ------------------
3 |
4 | .. autosummary::
5 | :toctree: generated
6 |
7 | merlin.dag.Operator
8 |
9 | merlin.dag.Graph
10 | merlin.dag.Node
11 | merlin.dag.ColumnSelector
12 |
--------------------------------------------------------------------------------
/docs/source/api/merlin.io.rst:
--------------------------------------------------------------------------------
1 | Merlin IO
2 | ------------------
3 |
4 | .. autosummary::
5 | :toctree: generated
6 |
7 | merlin.io.Dataset
8 |
--------------------------------------------------------------------------------
/docs/source/api/merlin.schema.rst:
--------------------------------------------------------------------------------
1 | Merlin Schema
2 | ------------------
3 |
4 | .. autosummary::
5 | :toctree: generated
6 |
7 | merlin.schema.Schema
8 | merlin.schema.ColumnSchema
9 | merlin.schema.Tags
10 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Merlin Core
2 | ===========
3 |
4 | The Merlin Core library provides the core utilities and classes for the NVIDIA Merlin project.
5 |
6 | To learn more, start with the `Introduction <README.html>`_.
7 |
8 | Related Resources
9 | -----------------
10 |
11 | Merlin Core GitHub Repository
12 | `<https://github.com/NVIDIA-Merlin/core>`_
13 |
14 | About Merlin
15 | Merlin is the overarching project that brings together the Merlin projects.
16 | See the `documentation <https://nvidia-merlin.github.io/Merlin/stable/README.html>`_
17 | or the `repository <https://github.com/NVIDIA-Merlin/Merlin>`_ on GitHub.
18 |
19 | Developer website for Merlin
20 | More information about Merlin is available at our developer website:
21 | `<https://developer.nvidia.com/nvidia-merlin>`_.
22 |
23 | Index
24 | -----
25 |
26 | * :ref:`genindex`
27 |
--------------------------------------------------------------------------------
/docs/source/toc.yaml:
--------------------------------------------------------------------------------
1 | root: index
2 | subtrees:
3 | - caption: Contents
4 | entries:
5 | - file: README.md
6 | title: Introduction
7 | - file: api/index.rst
8 | title: API Documentation
9 |
--------------------------------------------------------------------------------
/merlin/config/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2024, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | _DASK_QUERY_PLANNING_ENABLED = False
18 | try:
19 | # Disable query-planning and string conversion
20 | import dask
21 |
22 | dask.config.set(
23 | {
24 | "dataframe.query-planning": False,
25 | "dataframe.convert-string": False,
26 | }
27 | )
28 | except ImportError:
29 | dask = None
30 | else:
31 | import sys
32 |
33 | import dask.dataframe as dd
34 | from packaging.version import parse
35 |
36 | if parse(dask.__version__) > parse("2024.6.0"):
37 | # For newer versions of dask, we can just check
38 | # the official DASK_EXPR_ENABLED constant
39 | _DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED
40 | else:
41 | # For older versions of dask, we must assume query
42 | # planning is enabled if dask_expr was imported
43 | # (because we can't know for sure)
44 | _DASK_QUERY_PLANNING_ENABLED = "dask_expr" in sys.modules
45 |
46 |
47 | def validate_dask_configs():
48 | """Central check for problematic config options in Dask"""
49 | if _DASK_QUERY_PLANNING_ENABLED:
50 | raise NotImplementedError(
51 | "Merlin does not support the query-planning API in "
52 | "Dask Dataframe yet. Please make sure query-planning is "
53 | "disabled before dask.dataframe is imported.\n\ne.g."
54 | "dask.config.set({'dataframe.query-planning': False})"
55 | "\n\nOr set the environment variable: "
56 | "export DASK_DATAFRAME__QUERY_PLANNING=False"
57 | )
58 |
59 | if dask is not None and dask.config.get("dataframe.convert-string"):
60 | raise NotImplementedError(
61 | "Merlin does not support automatic string conversion in "
62 | "Dask Dataframe yet. Please make sure this option is "
63 | "disabled.\n\ne.g."
64 | "dask.config.set({'dataframe.convert-string': False})"
65 | "\n\nOr set the environment variable: "
66 | "export DASK_DATAFRAME__CONVERT_STRING=False"
67 | )
68 |
--------------------------------------------------------------------------------
/merlin/core/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022-2024, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from merlin.config import validate_dask_configs
18 | from merlin.core import _version
19 |
20 | __version__ = _version.get_versions()["version"]
21 | validate_dask_configs()
22 |
--------------------------------------------------------------------------------
/merlin/core/compat/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # pylint: disable=unused-import
18 | import warnings
19 |
20 | from numba import cuda
21 |
22 | from merlin.core.has_gpu import HAS_GPU # noqa pylint: disable=unused-import
23 |
24 | if not cuda.is_available():
25 | cuda = None
26 |
27 | try:
28 | import psutil
29 | except ImportError:
30 | psutil = None
31 |
32 |
33 | def pynvml_mem_size(kind="total", index=0):
34 | """Get Memory Info for device.
35 |
36 | Parameters
37 | ----------
38 | kind : str, optional
39 | Either "free" or "total", by default "total"
40 | index : int, optional
41 | Device Index, by default 0
42 |
43 | Returns
44 | -------
45 | int
46 | Either free or total memory on device depending on the kind parameter.
47 |
48 | Raises
49 | ------
50 | ValueError
51 | When kind is not one of {"free", "total"}
52 | """
53 | import pynvml
54 |
55 | pynvml.nvmlInit()
56 | size = None
57 | if kind == "free":
58 | size = int(pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(index)).free)
59 | elif kind == "total":
60 | size = int(pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(index)).total)
61 | else:
62 | raise ValueError(f"{kind} not a supported option for device_mem_size.")
63 | pynvml.nvmlShutdown()
64 | return size
65 |
66 |
67 | def device_mem_size(kind="total", cpu=False):
68 | """Get Memory Info for either CPU or GPU.
69 |
70 | Parameters
71 | ----------
72 | kind : str, optional
73 | Either "total" or "free", by default "total"
74 | cpu : bool, optional
75 | Specifies whether to check memory for CPU or GPU, by default False
76 |
77 | Returns
78 | -------
79 | int
80 | Free or total memory on device
81 |
82 | Raises
83 | ------
84 | ValueError
85 | When kind is provided with an unsupported value.
86 | """
87 | # Use psutil (if available) for cpu mode
88 | cpu = cpu or not cuda
89 | if cpu and psutil:
90 | if kind == "total":
91 | return psutil.virtual_memory().total
92 | elif kind == "free":
93 | return psutil.virtual_memory().free
94 | elif cpu:
95 | warnings.warn("Please install psutil for full cpu=True support.")
96 | # Assume 1GB of memory
97 | return int(1e9)
98 |
99 | if kind not in ["free", "total"]:
100 | raise ValueError(f"{kind} not a supported option for device_mem_size.")
101 | try:
102 | if kind == "free":
103 | return int(cuda.current_context().get_memory_info()[0])
104 | else:
105 | return int(cuda.current_context().get_memory_info()[1])
106 | except NotImplementedError:
107 | if kind == "free":
108 | # Not using NVML "free" memory, because it will not include RMM-managed memory
109 | warnings.warn("get_memory_info is not supported. Using total device memory from NVML.")
110 | size = pynvml_mem_size(kind="total", index=0)
111 | return size
112 |
113 |
114 | try:
115 | import numpy
116 | except ImportError:
117 | numpy = None
118 |
119 | try:
120 | import pandas
121 | except ImportError:
122 | pandas = None
123 |
124 | if HAS_GPU:
125 | try:
126 | import cupy
127 | except ImportError:
128 | cupy = None
129 |
130 | try:
131 | import cudf
132 | except ImportError:
133 | cudf = None
134 |
135 | try:
136 | import dask_cudf
137 | except ImportError:
138 | dask_cudf = None
139 | else:
140 | # Without a GPU available none of these packages should be used
141 | cupy = None
142 | cudf = None
143 | dask_cudf = None
144 |
--------------------------------------------------------------------------------
/merlin/core/compat/tensorflow.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # pylint: disable=unused-import
18 | import os
19 | import warnings
20 |
21 | from packaging import version
22 |
23 | from merlin.core.compat import HAS_GPU, device_mem_size
24 |
25 | try:
26 | import tensorflow
27 |
28 | def configure_tensorflow(memory_allocation=None, device=None):
29 | """Utility to help configure tensorflow to not use 100% of gpu memory as buffer"""
30 | tf = tensorflow
31 | total_gpu_mem_mb = device_mem_size(kind="total", cpu=(not HAS_GPU)) / (1024**2)
32 |
33 | if memory_allocation is None:
34 | memory_allocation = os.environ.get("TF_MEMORY_ALLOCATION", 0.5)
35 |
36 | if float(memory_allocation) < 1:
37 | memory_allocation = total_gpu_mem_mb * float(memory_allocation)
38 | memory_allocation = int(memory_allocation)
39 | assert memory_allocation < total_gpu_mem_mb
40 |
41 | if HAS_GPU:
42 | tf_devices = tf.config.list_physical_devices("GPU")
43 |
44 | if len(tf_devices) == 0:
45 | raise ImportError("TensorFlow is not configured for GPU")
46 |
47 | for tf_device in tf_devices:
48 | try:
49 | tf.config.set_logical_device_configuration(
50 | tf_device,
51 | [tf.config.LogicalDeviceConfiguration(memory_limit=memory_allocation)],
52 | )
53 | except RuntimeError:
54 | warnings.warn(
55 | "TensorFlow runtime already initialized, may not be enough memory for cudf"
56 | )
57 | try:
58 | tf.config.experimental.set_virtual_device_configuration(
59 | tf_device,
60 | [
61 | tf.config.experimental.VirtualDeviceConfiguration(
62 | memory_limit=memory_allocation
63 | )
64 | ],
65 | )
66 | except RuntimeError as e:
67 | # Virtual devices must be set before GPUs have been initialized
68 | warnings.warn(str(e))
69 |
70 | # versions using TF earlier than 2.3.0 need to use extension
71 | # library for dlpack support to avoid memory leak issue
72 | __TF_DLPACK_STABLE_VERSION = "2.3.0"
73 | if version.parse(tf.__version__) < version.parse(__TF_DLPACK_STABLE_VERSION):
74 | try:
75 | from tfdlpack import from_dlpack
76 | except ModuleNotFoundError as e:
77 | message = (
78 | "If using TensorFlow < 2.3.0, you must install tfdlpack-gpu extension library"
79 | )
80 | raise ModuleNotFoundError(message) from e
81 |
82 | else:
83 | from tensorflow.experimental.dlpack import from_dlpack
84 |
85 | return from_dlpack
86 |
87 | configure_tensorflow()
88 |
89 | from tensorflow.python.framework import ops as tf_ops
90 | except ImportError:
91 | tensorflow = None
92 | tf_ops = None
93 |
--------------------------------------------------------------------------------
/merlin/core/compat/torch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # pylint: disable=unused-import
18 |
19 | try:
20 | import torch
21 | except ImportError:
22 | torch = None
23 |
--------------------------------------------------------------------------------
/merlin/core/has_gpu.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # pylint: disable=unused-import
18 | import os
19 |
20 | from dask.distributed.diagnostics import nvml
21 |
22 |
23 | def _get_gpu_count():
24 | """Get Number of GPU devices accounting for CUDA_VISIBLE_DEVICES environment variable"""
25 | # Using the `dask.distributed.diagnostics.nvml.device_get_count`
26 | # helper function from dask to check device counts with NVML
27 | # since this handles some complexity of checking NVML state for us.
28 |
29 | # Note: We can't use `numba.cuda.gpus`, since this has some side effects
30 | # that are incompatible with Dask-CUDA. If CUDA runtime functions are
31 | # called before Dask-CUDA can spawn worker processes
32 | # then Dask-CUDA it will not work correctly (raises an exception)
33 | nvml_device_count = nvml.device_get_count()
34 | if nvml_device_count == 0:
35 | return 0
36 | try:
37 | cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
38 | if cuda_visible_devices:
39 | return len(cuda_visible_devices.split(","))
40 | else:
41 | return 0
42 | except KeyError:
43 | return nvml_device_count
44 |
45 |
46 | HAS_GPU = _get_gpu_count() > 0
47 |
--------------------------------------------------------------------------------
/merlin/dag/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022-2024, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # flake8: noqa
17 |
18 | from merlin.config import validate_dask_configs
19 |
20 | validate_dask_configs()
21 |
22 | from merlin.dag.graph import Graph
23 | from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
24 | from merlin.dag.operator import DataFormats, Operator, Supports
25 | from merlin.dag.selector import ColumnSelector
26 | from merlin.dag.utils import group_values_offsets, ungroup_values_offsets
27 |
28 | BaseOperator = Operator
29 |
--------------------------------------------------------------------------------
/merlin/dag/ops/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # alias submodules here to avoid breaking everything with moving to submodules
17 | # flake8: noqa
18 | from merlin.dag.ops.add_metadata import (
19 | AddMetadata,
20 | AddProperties,
21 | AddTags,
22 | TagAsItemFeatures,
23 | TagAsItemID,
24 | TagAsUserFeatures,
25 | TagAsUserID,
26 | )
27 | from merlin.dag.ops.concat_columns import ConcatColumns
28 | from merlin.dag.ops.grouping import GroupingOp
29 | from merlin.dag.ops.rename import Rename
30 | from merlin.dag.ops.selection import SelectionOp
31 | from merlin.dag.ops.subset_columns import SubsetColumns
32 | from merlin.dag.ops.subtraction import SubtractionOp
33 | from merlin.dag.ops.udf import UDF
34 |
--------------------------------------------------------------------------------
/merlin/dag/ops/add_metadata.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from merlin.dag.operator import Operator
17 | from merlin.schema.tags import Tags
18 |
19 |
class AddMetadata(Operator):
    """
    Attaches user-defined tags and properties to a Schema.
    """

    def __init__(self, tags=None, properties=None):
        super().__init__()
        # Normalize missing arguments to empty containers.
        self.tags = tags if tags else []
        self.properties = properties if properties else {}

    @property
    def output_tags(self):
        """Tags to apply to the columns of the output schema."""
        return self.tags

    @property
    def output_properties(self):
        """Properties to apply to the columns of the output schema."""
        return self.properties
38 |
39 |
class AddTags(AddMetadata):
    """Convenience operator that adds only tags (no properties) to a Schema."""

    def __init__(self, tags=None):
        super().__init__(tags=tags)
43 |
44 |
class AddProperties(AddMetadata):
    """Convenience operator that adds only properties (no tags) to a Schema."""

    def __init__(self, properties=None):
        super().__init__(properties=properties)
48 |
49 |
50 | # Wrappers for common features
class TagAsUserID(AddTags):
    """Tags columns with the fixed set [Tags.ID, Tags.USER]."""

    def __init__(self, tags=None):
        # NOTE(review): the `tags` argument is accepted but ignored here;
        # this operator always applies a fixed tag set.
        super().__init__(tags=[Tags.ID, Tags.USER])
54 |
55 |
class TagAsItemID(AddTags):
    """Tags columns with the fixed set [Tags.ID, Tags.ITEM]."""

    def __init__(self, tags=None):
        # NOTE(review): the `tags` argument is accepted but ignored here;
        # this operator always applies a fixed tag set.
        super().__init__(tags=[Tags.ID, Tags.ITEM])
59 |
60 |
class TagAsUserFeatures(AddTags):
    """Tags columns with the fixed set [Tags.USER]."""

    def __init__(self, tags=None):
        # NOTE(review): the `tags` argument is accepted but ignored here;
        # this operator always applies a fixed tag set.
        super().__init__(tags=[Tags.USER])
64 |
65 |
class TagAsItemFeatures(AddTags):
    """Tags columns with the fixed set [Tags.ITEM]."""

    def __init__(self, tags=None):
        # NOTE(review): the `tags` argument is accepted but ignored here;
        # this operator always applies a fixed tag set.
        super().__init__(tags=[Tags.ITEM])
69 |
--------------------------------------------------------------------------------
/merlin/dag/ops/concat_columns.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from merlin.core.protocols import Transformable
18 | from merlin.dag.operator import Operator
19 | from merlin.dag.selector import ColumnSelector
20 | from merlin.schema import Schema
21 |
22 |
class ConcatColumns(Operator):
    """
    Implementation backing the `+` operator used when constructing graphs:
    merges the columns coming from parent and dependency nodes.
    """

    def __init__(self, label=None):
        self._label = label or self.__class__.__name__
        super().__init__()

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: ColumnSelector = None,
        dependencies_selector: ColumnSelector = None,
    ) -> ColumnSelector:
        """
        Merge the selectors of the nodes being added together.

        Parameters
        ----------
        input_schema : Schema
            Combined schema of the columns arriving from upstream nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)
        parents_selector : ColumnSelector
            Combined column selectors of parent nodes
        dependencies_selector : ColumnSelector
            Combined column selectors of dependency nodes

        Returns
        -------
        ColumnSelector
            Combined column selectors of parent and dependency nodes
        """
        combined = parents_selector + dependencies_selector
        if not combined.subgroups:
            # No subgroup structure upstream: select every column present in
            # the input schema instead.
            return ColumnSelector(input_schema.column_names)
        return super().compute_selector(input_schema, combined)

    def compute_input_schema(
        self,
        root_schema: Schema,
        parents_schema: Schema,
        deps_schema: Schema,
        selector: ColumnSelector,
    ) -> Schema:
        """
        Merge the schemas of the nodes being added together.

        Parameters
        ----------
        root_schema : Schema
            Schema of the columns from the input dataset
        parents_schema : Schema
            Schema of the columns from the parent nodes
        deps_schema : Schema
            Schema of the columns from the dependency nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)

        Returns
        -------
        Schema
            Combined schema of columns from parents and dependencies
        """
        return parents_schema + deps_schema

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """
        Return the selected columns from the input dataframe.

        The real work of this operator happens at schema-computation time for
        `+` nodes in the graph, so the `transform` method is just a column
        selection.

        Parameters
        ----------
        col_selector : ColumnSelector
            The columns to apply this operator to
        transformable : Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        Transformable
            The selected columns of the input data
        """
        return super()._get_columns(transformable, col_selector)

    @property
    def label(self) -> str:
        """Display name of this operator"""
        return self._label
125 |
--------------------------------------------------------------------------------
/merlin/dag/ops/grouping.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from merlin.dag.ops.selection import SelectionOp
4 | from merlin.dag.selector import ColumnSelector
5 | from merlin.schema import Schema
6 |
7 |
class GroupingOp(SelectionOp):
    """Selection operator that wraps its upstream selectors in a subgroup.

    The combined selectors of parent and dependency nodes are treated as a
    single grouped selection.
    """

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: Optional[ColumnSelector] = None,
        dependencies_selector: Optional[ColumnSelector] = None,
    ) -> ColumnSelector:
        combined = parents_selector + dependencies_selector
        # Nest the upstream selection as a subgroup of a fresh selector.
        grouped = ColumnSelector(subgroups=combined)
        return super().compute_selector(input_schema, grouped)
23 |
--------------------------------------------------------------------------------
/merlin/dag/ops/rename.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from merlin.core.protocols import Transformable
17 | from merlin.dag.operator import Operator
18 | from merlin.dag.selector import ColumnSelector
19 |
20 |
class Rename(Operator):
    """Renames columns by one of several methods:

    - applying a user defined function to each column name
    - appending a postfix string to every column name
    - renaming a single column to a single fixed string

    Example usage::

        # Rename columns after LogOp
        cont_features = cont_names >> nvt.ops.LogOp() >> Rename(postfix='_log')
        processor = nvt.Workflow(cont_features)

    Parameters
    ----------
    f : callable, optional
        Function that takes a column name and returns a new column name
    postfix : str, optional
        If set each column name in the output will have this string appended to it
    name : str, optional
        If set, a single input column will be renamed to this string
    """

    def __init__(self, f=None, postfix=None, name=None):
        # At least one renaming strategy must be supplied.
        if not f and postfix is None and name is None:
            raise ValueError("must specify name, f, or postfix, for Rename op")

        self.f = f
        self.postfix = postfix
        self.name = name
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        selected = transformable[col_selector.names]
        # Dict insertion order follows col_selector.names, so the renamed
        # labels line up with the selected columns.
        selected.columns = list(  # type: ignore[assignment]
            self.column_mapping(col_selector)
        )
        return selected

    transform.__doc__ = Operator.transform.__doc__

    def column_mapping(self, col_selector):
        """Build a mapping from each new column name to its source column."""
        mapping = {}
        for col_name in col_selector.names:
            # Strategies are checked in priority order: f, then postfix, then name.
            if self.f:
                new_name = self.f(col_name)
            elif self.postfix:
                new_name = col_name + self.postfix
            elif self.name:
                if len(col_selector.names) != 1:
                    raise RuntimeError("Single column name provided for renaming multiple columns")
                new_name = self.name
            else:
                raise RuntimeError(
                    "The Rename op requires one of f, postfix, or name to be provided"
                )

            mapping[new_name] = [col_name]

        return mapping
84 |
--------------------------------------------------------------------------------
/merlin/dag/ops/selection.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from __future__ import annotations
17 |
18 | import logging
19 |
20 | from merlin.core.protocols import Transformable
21 | from merlin.dag.operator import Operator
22 | from merlin.dag.selector import ColumnSelector
23 | from merlin.schema import Schema
24 |
25 | LOG = logging.getLogger("SelectionOp")
26 |
27 |
class SelectionOp(Operator):
    """
    Implements the behavior of selection (e.g. input) nodes in a graph.
    """

    def __init__(self, selector=None):
        # The selector may also arrive later via the graph machinery,
        # so None is a valid initial value.
        self.selector = selector
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """
        Return the selected columns from the input dataframe.

        Most of this operator's functionality lives in schema computation for
        selection nodes in the graph, so the `transform` method is just a
        column selection.

        Parameters
        ----------
        col_selector : ColumnSelector
            The columns to apply this operator to
        transformable : Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        Transformable
            The selected columns of the input data
        """
        active_selector = col_selector or self.selector
        return super()._get_columns(transformable, active_selector)

    def compute_input_schema(
        self,
        root_schema: Schema,
        parents_schema: Schema,
        deps_schema: Schema,
        selector: ColumnSelector,
    ) -> Schema:
        """
        Select this node's columns from the combined upstream schemas.

        Parameters
        ----------
        root_schema : Schema
            Schema of the columns from the input dataset
        parents_schema : Schema
            Schema of the columns from the parent nodes
        deps_schema : Schema
            Schema of the columns from the dependency nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)

        Returns
        -------
        Schema
            Schema of selected columns from input, parents, and dependencies
        """
        combined = root_schema + parents_schema + deps_schema
        return combined.select(self.selector)

    def compute_output_schema(
        self,
        input_schema: Schema,
        col_selector: ColumnSelector,
        prev_output_schema: Schema = None,
    ) -> Schema:
        """Given a set of schemas and a column selector for the input columns,
        returns a set of schemas for the transformed columns this operator will produce

        Parameters
        ----------
        input_schema : Schema
            The schemas of the columns to apply this operator to
        col_selector : ColumnSelector
            The column selector to apply to the input schema

        Returns
        -------
        Schema
            The schemas of the columns produced by this operator
        """
        active = col_selector or self.selector
        if active.all:
            # A wildcard selector means "everything": expand it to the
            # concrete list of input column names.
            active = ColumnSelector(input_schema.column_names)
        return super().compute_output_schema(input_schema, active, prev_output_schema)
114 |
--------------------------------------------------------------------------------
/merlin/dag/ops/stat_operator.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Any
17 |
18 | import dask.dataframe as dd
19 |
20 | from merlin.dag.operator import Operator
21 | from merlin.dag.selector import ColumnSelector
22 |
23 |
class StatOperator(Operator):
    """
    Base class for statistical operator classes. This adds a 'fit' and 'finalize' method
    on top of the Operator class. Subclasses implement `fit` (returning uncomputed
    dask statistics) and `fit_finalize` (storing the computed results on the operator).
    """

    # Whether statistics have been computed for this operator.
    # NOTE(review): declared but never updated within this class — presumably
    # managed by subclasses or the workflow; verify against callers.
    fitted = False

    def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame) -> Any:
        """Calculate statistics for this operator, and return a dask future
        to these statistics, which will be computed by the workflow."""

        raise NotImplementedError(
            """The dask operations needed to return a dictionary of uncomputed statistics."""
        )

    def fit_finalize(self, dask_stats):
        """Finalize statistics calculation - the workflow calls this function with
        the computed statistics from the 'fit' object'"""

        raise NotImplementedError(
            """Follow-up operations to convert dask statistics in to member variables"""
        )

    def clear(self):
        """zero and reinitialize all relevant statistical properties"""
        raise NotImplementedError("clear isn't implemented for this op!")

    def set_storage_path(self, new_path, copy=False):
        """Certain stat operators need external storage - for instance Categorify writes out
        parquet files containing the categorical mapping. When we save the operator, we
        also want to save these files as part of the bundle. Implementing this method
        lets statoperators bundle their dependent files into the new path that we're writing
        out (note that this could happen after the operator is created)
        """
        # Intentionally a no-op by default: operators without external files
        # have nothing to relocate.
59 |
--------------------------------------------------------------------------------
/merlin/dag/ops/subset_columns.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from merlin.core.protocols import Transformable
18 | from merlin.dag.operator import Operator
19 | from merlin.dag.selector import ColumnSelector
20 |
21 |
class SubsetColumns(Operator):
    """
    This operator class provides an implementation for the `[]` operator
    used in constructing graphs.
    """

    def __init__(self, label=None):
        self._label = label or self.__class__.__name__
        super().__init__()

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """Simply returns the selected output columns from the input dataframe

        The main functionality of this operator has to do with computing the schemas
        for `[]` nodes in the Workflow graph, so very little has to happen in the
        `transform` method.

        Parameters
        -----------
        columns: list of str or list of list of str
            The columns to apply this operator to
        transformable: Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        DataFrame
            Returns a transformed dataframe for this operator
        """
        return super()._get_columns(transformable, col_selector)

    @property
    def label(self) -> str:
        """Display name of this operator"""
        return self._label
57 | return self._label
58 |
--------------------------------------------------------------------------------
/merlin/dag/ops/subtraction.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from __future__ import annotations
17 |
18 | from merlin.core.protocols import Transformable
19 | from merlin.dag.operator import Operator
20 | from merlin.dag.selector import ColumnSelector
21 | from merlin.schema import Schema
22 |
23 |
class SubtractionOp(Operator):
    """
    Implementation backing the `-` operator used when constructing graphs:
    removes a set of columns from the upstream schema.
    """

    def __init__(self, selector=None):
        # Optional selector naming the columns to subtract.
        self.selector = selector
        super().__init__()

    def compute_selector(
        self,
        input_schema: Schema,
        selector: ColumnSelector,
        parents_selector: ColumnSelector = None,
        dependencies_selector: ColumnSelector = None,
    ) -> ColumnSelector:
        """
        Build a selector covering every column in the input schema.

        Parameters
        ----------
        input_schema : Schema
            Combined schema of the columns coming from upstream nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)
        parents_selector : ColumnSelector
            Combined column selectors of parent nodes
        dependencies_selector : ColumnSelector
            Combined column selectors of dependency nodes

        Returns
        -------
        ColumnSelector
            Selector of all columns from the input schema
        """
        wildcard = ColumnSelector("*")
        return super().compute_selector(input_schema, wildcard)

    def compute_input_schema(
        self,
        root_schema: Schema,
        parents_schema: Schema,
        deps_schema: Schema,
        selector: ColumnSelector,
    ) -> Schema:
        """
        Compute the schema remaining after subtracted columns are removed.

        Parameters
        ----------
        root_schema : Schema
            Schema of the columns from the input dataset
        parents_schema : Schema
            Schema of the columns from the parent nodes
        deps_schema : Schema
            Schema of the columns from the dependency nodes
        selector : ColumnSelector
            Existing column selector for this node in the graph (often None)

        Returns
        -------
        Schema
            Remaining schema of columns from parents after removing dependencies
        """
        if deps_schema.column_schemas:
            # Subtracting another node: drop every column it produces.
            return parents_schema - deps_schema
        # Subtracting a plain selector: drop the columns it names.
        return parents_schema.excluding(self.selector or selector)

    def transform(
        self, col_selector: ColumnSelector, transformable: Transformable
    ) -> Transformable:
        """
        Return the remaining columns from the input dataframe.

        The real work of this operator happens at schema-computation time for
        `-` nodes in the graph, so the `transform` method is just a column
        selection.

        Parameters
        ----------
        col_selector : ColumnSelector
            The columns to apply this operator to
        transformable : Transformable
            A pandas or cudf dataframe that this operator will work on

        Returns
        -------
        Transformable
            The remaining columns of the input data
        """
        active_selector = self.selector or col_selector
        return super()._get_columns(transformable, active_selector)
121 |
--------------------------------------------------------------------------------
/merlin/dag/runtime.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from merlin.core.protocols import Transformable
17 | from merlin.dag.executors import LocalExecutor
18 | from merlin.dag.graph import Graph
19 |
20 |
class Runtime:
    """A Graph Runtime.

    Pairs a Graph with an Executor; can be used as a base class for
    custom runtimes.
    """

    def __init__(self, executor=None):
        """Construct a Runtime.

        Parameters
        ----------
        executor : Executor, optional
            The Graph Executor to use to use for the transform, by default None
        """
        # Fall back to a LocalExecutor when none is supplied.
        self.executor = executor if executor else LocalExecutor()
        # Operator lookup table; empty by default.
        self.op_table = {}

    def transform(self, graph: Graph, transformable: Transformable):
        """Run the graph with the input data.

        Parameters
        ----------
        graph : Graph
            Graph of nodes container operator chains for data manipulation.
        transformable : Transformable
            Input data to transform in graph.

        Returns
        -------
        Transformable
            Input data after it has been transformed via graph.
        """
        output_nodes = [graph.output_node]
        return self.executor.transform(transformable, output_nodes)

    def export(self):
        """Optional method.
        Implemented for runtimes that require an exported artifact to transform.
        """
        raise NotImplementedError
60 |
--------------------------------------------------------------------------------
/merlin/dag/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 |
def ungroup_values_offsets(grouped_cols: dict) -> dict:
    """
    Flatten columns with values/offsets tuples in a dictionary to separate keys

    Parameters
    ----------
    grouped_cols : dict
        A dictionary of column arrays including values/offsets tuples

    Returns
    -------
    dict
        A dictionary of column arrays with separate keys for values and offsets
    """
    flat_cols = {}
    for name, col in grouped_cols.items():
        if not isinstance(col, tuple):
            flat_cols[name] = col
        else:
            # A tuple represents a ragged column: (values, offsets).
            flat_cols[f"{name}__values"] = col[0]
            flat_cols[f"{name}__offsets"] = col[1]
    return flat_cols
42 |
43 |
def group_values_offsets(flat_cols: dict) -> dict:
    """
    Convert separate values/offsets keys for columns into tuples w/ a single key

    Parameters
    ----------
    flat_cols : dict
        A dictionary of column arrays with separate keys for values and offsets

    Returns
    -------
    dict
        A dictionary of column arrays including values/offsets tuples

    Raises
    ------
    KeyError
        If a ``<name>__values`` key is present without its matching
        ``<name>__offsets`` key
    """
    grouped_cols = {}

    for key, value in flat_cols.items():
        if key.endswith("__values"):
            # Strip only the trailing suffix; `str.replace` would also remove
            # any "__values" substring appearing earlier in the column name,
            # producing the wrong base name and a bad offsets lookup.
            col_name = key[: -len("__values")]
            grouped_cols[col_name] = (value, flat_cols[f"{col_name}__offsets"])
        elif key.endswith("__offsets"):
            # Offsets are paired up with their values key above; skip here.
            pass
        else:
            grouped_cols[key] = value

    return grouped_cols
70 |
--------------------------------------------------------------------------------
/merlin/dtypes/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # flake8: noqa
18 | from merlin.dtypes import aliases, mappings
19 | from merlin.dtypes.aliases import *
20 | from merlin.dtypes.base import DType
21 | from merlin.dtypes.registry import _dtype_registry
22 | from merlin.dtypes.shape import Dimension, Shape
23 |
24 | # Convenience alias for registering dtypes
25 | register = _dtype_registry.register
26 |
27 |
def dtype(external_dtype):
    """Translate a framework-specific dtype into a Merlin DType.

    Tries the registered Merlin dtype mappings first and falls back on
    converting through numpy when no direct match is found.
    """
    # There's no universal default we can map `None` to across frameworks,
    # so surface the problem to the caller instead of guessing.
    if external_dtype is None:
        raise TypeError(
            "Merlin doesn't provide a default dtype mapping for `None`. "
            "This differs from the Numpy behavior you may be expecting, "
            "which treats `None` as an alias for `np.float64`. If you're "
            "expecting this dtype to be non-`None`, there may be an issue "
            "in upstream code. If you'd like to allow this dtype to be `None`, "
            "you can use a `try/except` to catch this error."
        )

    # Already a Merlin dtype: nothing to translate.
    if isinstance(external_dtype, DType):
        return external_dtype

    # Attempt the registered Merlin dtype mappings directly.
    first_exc = None
    merlin_dtype = None
    try:
        merlin_dtype = _dtype_registry.to_merlin(external_dtype)
    except (TypeError, KeyError, AttributeError) as exc:
        first_exc = exc

    if first_exc is None and merlin_dtype != aliases.unknown:
        return merlin_dtype

    # No direct match: fall back on converting to a numpy dtype and
    # matching that instead.
    try:
        return _dtype_registry.to_merlin_via_numpy(external_dtype)
    except TypeError as numpy_exc:
        # Prefer re-raising the original exception: it mentions the actual
        # external dtype rather than the interim numpy conversion.
        if first_exc:
            raise first_exc from numpy_exc
        raise
73 |
--------------------------------------------------------------------------------
/merlin/dtypes/aliases.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
from merlin.dtypes.base import DType, ElementType, ElementUnit

# Canonical Merlin dtype objects. The framework mapping modules translate
# external dtypes to and from these shared instances.

# Unsigned Integer
# NOTE(review): these rely on DType's default for `signed` (presumably
# False/unset) — confirm in merlin.dtypes.base
uint8 = DType("uint8", ElementType.UInt, 8)
uint16 = DType("uint16", ElementType.UInt, 16)
uint32 = DType("uint32", ElementType.UInt, 32)
uint64 = DType("uint64", ElementType.UInt, 64)

# Signed Integer
int8 = DType("int8", ElementType.Int, 8, signed=True)
int16 = DType("int16", ElementType.Int, 16, signed=True)
int32 = DType("int32", ElementType.Int, 32, signed=True)
int64 = DType("int64", ElementType.Int, 64, signed=True)

# Float
float16 = DType("float16", ElementType.Float, 16, signed=True)
float32 = DType("float32", ElementType.Float, 32, signed=True)
float64 = DType("float64", ElementType.Float, 64, signed=True)

# Date/Time
# Mirrors numpy's unit-qualified datetime64 dtypes, from year down to nanosecond
datetime64 = DType("datetime64", ElementType.DateTime, 64)
datetime64Y = DType("datetime64[Y]", ElementType.DateTime, 64, ElementUnit.Year)
datetime64M = DType("datetime64[M]", ElementType.DateTime, 64, ElementUnit.Month)
datetime64D = DType("datetime64[D]", ElementType.DateTime, 64, ElementUnit.Day)
datetime64h = DType("datetime64[h]", ElementType.DateTime, 64, ElementUnit.Hour)
datetime64m = DType("datetime64[m]", ElementType.DateTime, 64, ElementUnit.Minute)
datetime64s = DType("datetime64[s]", ElementType.DateTime, 64, ElementUnit.Second)
datetime64ms = DType("datetime64[ms]", ElementType.DateTime, 64, ElementUnit.Millisecond)
datetime64us = DType("datetime64[us]", ElementType.DateTime, 64, ElementUnit.Microsecond)
datetime64ns = DType("datetime64[ns]", ElementType.DateTime, 64, ElementUnit.Nanosecond)

# Miscellaneous
string = DType("str", ElementType.String)
boolean = DType("bool", ElementType.Bool)
object_ = DType("object", ElementType.Object)  # trailing underscore avoids shadowing builtin `object`
struct = DType("struct", ElementType.Struct)
# `unknown` is the fallback the registry returns when no mapping matches
unknown = DType("unknown", ElementType.Unknown)
54 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # flake8: noqa
18 | from merlin.dtypes.mappings import cudf, merlin, numpy, pandas, python, tf, torch, triton
19 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/cudf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy as np
17 |
18 | import merlin.dtypes.aliases as mn
19 | from merlin.core.compat import cudf
20 | from merlin.core.dispatch import is_string_dtype
21 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor
22 | from merlin.dtypes.registry import _dtype_registry
23 |
24 |
def cudf_translator(raw_dtype) -> np.dtype:
    """
    Translate a cudf categorical dtype to the numpy dtype of its categories

    Parameters
    ----------
    raw_dtype : cudf dtype
        The dtype to be translated; expected to expose a ``_categories``
        attribute whose ``dtype`` describes the category elements

    Returns
    -------
    np.dtype
        The numpy dtype of the category elements, with string categories
        normalized to ``np.dtype("str")``
    """
    element_dtype = raw_dtype._categories.dtype
    return np.dtype("str") if is_string_dtype(element_dtype) else element_dtype
44 |
45 |
if cudf:
    try:
        # Register this mapping only when cudf is importable. The mapping
        # references `cudf.StructDtype` below, and the translator handles
        # categorical dtypes (recognized by their `_categories` attribute)
        # by converting them to numpy via `cudf_translator`.
        # NOTE(review): only ImportError is caught here, though nothing inside
        # the try performs an import — presumably defensive; confirm intent.

        cudf_dtypes = DTypeMapping(
            {
                mn.struct: [cudf.StructDtype],
            },
            translator=NumpyPreprocessor("cudf", cudf_translator, attrs=["_categories"]),
        )
        _dtype_registry.register("cudf", cudf_dtypes)
    except ImportError as exc:
        from warnings import warn

        warn(f"cuDF dtype mappings did not load successfully due to an error: {exc.msg}")
62 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/merlin.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import merlin.dtypes.aliases as mn
17 | from merlin.dtypes.registry import _dtype_registry
18 |
merlin_dtypes = {
    # Maps each Merlin dtype to the string aliases that resolve to it, so
    # lookups by name (e.g. "int64", "string") work through the registry
    # Unsigned Integer
    mn.uint8: ["uint8"],
    mn.uint16: ["uint16"],
    mn.uint32: ["uint32"],
    mn.uint64: ["uint64"],
    # Signed integer
    mn.int8: ["int8"],
    mn.int16: ["int16"],
    mn.int32: ["int32"],
    mn.int64: ["int64"],
    # Floating Point
    mn.float16: ["float16"],
    mn.float32: ["float32"],
    mn.float64: ["float64"],
    # Date/Time
    mn.datetime64: ["datetime64"],
    mn.datetime64Y: ["datetime64[Y]"],
    mn.datetime64M: ["datetime64[M]"],
    mn.datetime64D: ["datetime64[D]"],
    mn.datetime64h: ["datetime64[h]"],
    mn.datetime64m: ["datetime64[m]"],
    mn.datetime64s: ["datetime64[s]"],
    mn.datetime64ms: ["datetime64[ms]"],
    mn.datetime64us: ["datetime64[us]"],
    mn.datetime64ns: ["datetime64[ns]"],
    # Miscellaneous
    # `string` and `boolean` each accept two spellings
    mn.string: ["str", "string"],
    mn.object_: ["object"],
    mn.struct: ["struct"],
    mn.boolean: ["bool", "boolean"],
}
# Registered like any framework mapping, under the name "merlin"
_dtype_registry.register("merlin", merlin_dtypes)
52 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/numpy.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy as np
17 |
18 | import merlin.dtypes.aliases as mn
19 | from merlin.dtypes.registry import _dtype_registry
20 |
numpy_dtypes = {
    # Each Merlin dtype maps to its `np.dtype` instance and, where one
    # exists, the corresponding numpy scalar type or Python builtin
    # Unsigned Integer
    mn.uint8: [np.dtype("uint8"), np.uint8],
    mn.uint16: [np.dtype("uint16"), np.uint16],
    mn.uint32: [np.dtype("uint32"), np.uint32],
    mn.uint64: [np.dtype("uint64"), np.uint64],
    # Signed integer
    mn.int8: [np.dtype("int8"), np.int8],
    mn.int16: [np.dtype("int16"), np.int16],
    mn.int32: [np.dtype("int32"), np.int32],
    mn.int64: [np.dtype("int64"), np.int64],
    # Floating Point
    mn.float16: [np.dtype("float16"), np.float16],
    mn.float32: [np.dtype("float32"), np.float32],
    mn.float64: [np.dtype("float64"), np.float64],
    # Date/Time
    mn.datetime64: [np.dtype("datetime64"), np.datetime64],
    # Unit-qualified datetime dtypes only exist as `np.dtype` instances
    mn.datetime64Y: [np.dtype("datetime64[Y]")],
    mn.datetime64M: [np.dtype("datetime64[M]")],
    mn.datetime64D: [np.dtype("datetime64[D]")],
    mn.datetime64h: [np.dtype("datetime64[h]")],
    mn.datetime64m: [np.dtype("datetime64[m]")],
    mn.datetime64s: [np.dtype("datetime64[s]")],
    mn.datetime64ms: [np.dtype("datetime64[ms]")],
    mn.datetime64us: [np.dtype("datetime64[us]")],
    mn.datetime64ns: [np.dtype("datetime64[ns]")],
    # Miscellaneous
    mn.string: [np.dtype("str"), str],
    mn.object_: [np.dtype("O"), object],
    mn.boolean: [np.dtype("bool"), bool],
}
_dtype_registry.register("numpy", numpy_dtypes)
53 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/pandas.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Callable
17 |
18 | import merlin.dtypes.aliases as mn
19 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor
20 | from merlin.dtypes.registry import _dtype_registry
21 |
try:
    import pandas as pd
    from pandas.core.dtypes.base import ExtensionDtype

    def _translate_to_numpy(raw_dtype):
        """
        Resolve a pandas extension dtype to its numpy equivalent

        Instantiates the dtype first when a callable dtype object is
        supplied, then reads its ``numpy_dtype`` attribute.
        """
        # NOTE(review): an ExtensionDtype *instance* is generally not callable,
        # so this instantiation branch appears to target objects that are
        # both — confirm which inputs actually take this path
        if isinstance(raw_dtype, ExtensionDtype) and isinstance(raw_dtype, Callable):
            raw_dtype = raw_dtype()

        return raw_dtype.numpy_dtype

    pandas_dtypes = DTypeMapping(
        {
            # Map both the instantiated dtype objects and their classes
            mn.string: [pd.StringDtype(), pd.StringDtype],
            mn.boolean: [pd.BooleanDtype(), pd.BooleanDtype],
        },
        # Dtypes exposing `numpy_dtype` can fall back to numpy translation
        translator=NumpyPreprocessor("pandas", _translate_to_numpy, attrs=["numpy_dtype"]),
    )
    _dtype_registry.register("pandas", pandas_dtypes)
except ImportError as exc:
    from warnings import warn

    warn(f"Pandas dtype mappings did not load successfully due to an error: {exc.msg}")
44 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/python.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import merlin.dtypes.aliases as mn
17 | from merlin.dtypes.mapping import DTypeMapping
18 | from merlin.dtypes.registry import _dtype_registry
19 |
20 | python_dtypes = DTypeMapping(
21 | {
22 | mn.boolean: bool,
23 | mn.int64: int,
24 | mn.float64: float,
25 | mn.string: str,
26 | }
27 | )
28 | _dtype_registry.register("python", python_dtypes)
29 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/tf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import merlin.dtypes.aliases as mn
17 | from merlin.dtypes.mapping import DTypeMapping, NumpyPreprocessor
18 | from merlin.dtypes.registry import _dtype_registry
19 |
# Only define a TensorFlow dtype mapping if `tensorflow` is available
try:
    # Import under a distinct alias: the original code imported this module
    # as `tf_dtypes` and then rebound that same name to the DTypeMapping
    # below, which worked only because every attribute access happened to be
    # evaluated before the rebinding. A separate alias removes the shadowing.
    from tensorflow import dtypes as tf_dtypes_lib

    tf_dtypes = DTypeMapping(
        {
            # Unsigned Integer
            mn.uint8: [tf_dtypes_lib.uint8],
            mn.uint16: [tf_dtypes_lib.uint16],
            mn.uint32: [tf_dtypes_lib.uint32],
            mn.uint64: [tf_dtypes_lib.uint64],
            # Signed integer
            mn.int8: [tf_dtypes_lib.int8],
            mn.int16: [tf_dtypes_lib.int16],
            mn.int32: [tf_dtypes_lib.int32],
            mn.int64: [tf_dtypes_lib.int64],
            # Floating Point
            mn.float16: [tf_dtypes_lib.float16],
            mn.float32: [tf_dtypes_lib.float32],
            mn.float64: [tf_dtypes_lib.float64],
            # Miscellaneous
            mn.boolean: [tf_dtypes_lib.bool],
        },
        # Matching on the base class lets the registry recognize any TF DType
        base_class=tf_dtypes_lib.DType,
        # TF dtypes expose `as_numpy_dtype`, enabling the numpy fallback path
        translator=NumpyPreprocessor(
            "tf", lambda raw: raw.as_numpy_dtype, attrs=["as_numpy_dtype"]
        ),
    )
    # Register the same mapping under both common names for the framework
    _dtype_registry.register("tf", tf_dtypes)
    _dtype_registry.register("tensorflow", tf_dtypes)
except ImportError as exc:
    from warnings import warn

    warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
53 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/torch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import merlin.dtypes.aliases as mn
17 | from merlin.dtypes.registry import _dtype_registry
18 |
try:
    from torch import bool as bool_
    from torch import float16, float32, float64, int8, int16, int32, int64, uint8

    # Note: only `uint8` is imported and mapped in the unsigned section
    # below; no wider unsigned widths are mapped for torch here
    torch_dtypes = {
        # Unsigned Integer
        mn.uint8: [uint8],
        # Signed integer
        mn.int8: [int8],
        mn.int16: [int16],
        mn.int32: [int32],
        mn.int64: [int64],
        # Floating Point
        mn.float16: [float16],
        mn.float32: [float32],
        mn.float64: [float64],
        # Miscellaneous
        mn.boolean: [bool_],
    }
    # Register the same mapping under both common names for the framework
    _dtype_registry.register("torch", torch_dtypes)
    _dtype_registry.register("pytorch", torch_dtypes)
except ImportError as exc:
    from warnings import warn

    warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")
44 |
--------------------------------------------------------------------------------
/merlin/dtypes/mappings/triton.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import merlin.dtypes.aliases as mn
17 | from merlin.dtypes.registry import _dtype_registry
18 |
# Only define a Triton dtype mapping if `tritonclient` is available
try:
    import tritonclient.grpc.model_config_pb2 as model_config

    # The type constants on `model_config` are literally integers,
    # so this only works if we don't have any other dtypes that
    # either are or are equivalent to integers. We work around
    # this by checking base classes for dtypes that have them (e.g.
    # Tensorflow.)
    triton_dtypes = {
        # Unsigned Integer
        mn.uint8: [model_config.TYPE_UINT8],
        mn.uint16: [model_config.TYPE_UINT16],
        mn.uint32: [model_config.TYPE_UINT32],
        mn.uint64: [model_config.TYPE_UINT64],
        # Signed integer
        mn.int8: [model_config.TYPE_INT8],
        mn.int16: [model_config.TYPE_INT16],
        mn.int32: [model_config.TYPE_INT32],
        mn.int64: [model_config.TYPE_INT64],
        # Floating Point
        mn.float16: [model_config.TYPE_FP16],
        mn.float32: [
            model_config.TYPE_FP32,
        ],
        mn.float64: [model_config.TYPE_FP64],
        # Miscellaneous
        # No entries exist here for the datetime/struct/object Merlin dtypes
        mn.string: [model_config.TYPE_STRING],
        mn.boolean: [model_config.TYPE_BOOL],
    }
    _dtype_registry.register("triton", triton_dtypes)
except ImportError as exc:
    from warnings import warn

    warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")
54 |
--------------------------------------------------------------------------------
/merlin/dtypes/registry.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Dict, Union
17 |
18 | import merlin.dtypes.aliases as mn
19 | from merlin.dtypes.mapping import DTypeMapping
20 |
21 |
class DTypeMappingRegistry:
    """
    A registry of mappings between Merlin dtypes and the dtypes of many external frameworks
    """

    def __init__(self):
        # Maps a framework name (e.g. "numpy", "torch") to its DTypeMapping
        self.mappings = {}

    def __iter__(self):
        # Iterate over the registered framework names (dict insertion order)
        return iter(self.mappings)

    def register(self, name: str, mapping: Union[Dict, DTypeMapping]):
        """
        Register a mapping between Merlin and external dtypes by name

        Plain dictionaries are wrapped in a ``DTypeMapping`` automatically,
        and re-registering an existing name replaces the previous mapping.

        Parameters
        ----------
        name : str
            Name of the new mapping to register
        mapping : Union[Dict, DTypeMapping]
            Mapping between Merlin and external dtypes
        """
        if not isinstance(mapping, DTypeMapping):
            mapping = DTypeMapping(mapping)

        self.mappings[name] = mapping

    def from_merlin(self, merlin_dtype, mapping_name):
        """
        Map a Merlin dtype to an external dtype

        Parameters
        ----------
        merlin_dtype : DType
            A Merlin dtype object
        mapping_name : str
            The name of the external framework mapping to apply

        Returns
        -------
        Any
            An external framework dtype object, or ``mn.unknown`` if the
            requested mapping has no match for this Merlin dtype

        Raises
        ------
        KeyError
            If no mapping is registered under ``mapping_name``
        """
        mapping = self.mappings[mapping_name]
        # NOTE(review): despite this method's name, a match is converted with
        # `mapping.to_merlin(...)`, which reads as the Merlin-ward direction.
        # Confirm against DTypeMapping whether a from-merlin/to-external
        # conversion method should be called here instead.
        if mapping.matches_merlin(merlin_dtype):
            return mapping.to_merlin(merlin_dtype)

        return mn.unknown

    def to_merlin(self, external_dtype):
        """
        Map an external dtype to a Merlin dtype

        Registered mappings are tried in registration order; the first one
        that matches ``external_dtype`` wins.

        Parameters
        ----------
        external_dtype : Any
            A dtype object from an external framework

        Returns
        -------
        DType
            A Merlin DType object, or ``mn.unknown`` if no registered
            mapping matches the external dtype
        """
        for framework, mapping in self.mappings.items():
            if mapping.matches_external(external_dtype):
                return mapping.to_merlin(external_dtype)

        return mn.unknown

    def to_merlin_via_numpy(self, external_dtype):
        """
        Map an external dtype to a Merlin dtype by converting the external type to Numpy first

        This is sometimes useful for external framework dtypes that don't have a clear
        one-to-one mapping with a Merlin dtype, like cuDF's CategoricalDtype. We can often do
        some additional preprocessing on the external framework's dtype to determine the
        Numpy dtype of the elements, and then use that as an intermediary to find the
        corresponding Merlin dtype.

        Parameters
        ----------
        external_dtype : Any
            A dtype object from an external framework

        Returns
        -------
        DType
            A Merlin DType object
        """
        numpy_dtype = None

        # Use the first registered translator that recognizes this dtype
        for mapping in self.mappings.values():
            if mapping.translator and mapping.translator.matches(external_dtype):
                numpy_dtype = mapping.translator.to_numpy(external_dtype)
                break

        # NOTE(review): if no translator matched, `numpy_dtype` is still None
        # here and is passed to `to_merlin` — presumably producing
        # `mn.unknown`; confirm that's the intended fallback.
        return self.to_merlin(numpy_dtype)
134 |
135 |
136 | _dtype_registry = DTypeMappingRegistry()
137 |
--------------------------------------------------------------------------------
/merlin/io/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022-2024, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # flake8: noqa
17 |
from merlin.config import validate_dask_configs

# Check Dask-related configuration eagerly, before the IO modules below are
# imported, so misconfiguration is reported as early as possible
validate_dask_configs()

from merlin.io import dataframe_iter, dataset, shuffle
from merlin.io.dataframe_iter import DataFrameIter
from merlin.io.dataset import MERLIN_METADATA_DIR_NAME, Dataset
from merlin.io.shuffle import Shuffle, shuffle_df
26 |
--------------------------------------------------------------------------------
/merlin/io/csv.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import functools
17 |
18 | import dask.dataframe as dd
19 | from dask.bytes import read_bytes
20 | from dask.utils import parse_bytes
21 | from fsspec.core import get_fs_token_paths
22 | from fsspec.utils import infer_compression
23 |
24 | from merlin.core.compat import dask_cudf
25 | from merlin.core.compat import numpy as np
26 | from merlin.io.dataset_engine import DatasetEngine
27 |
28 |
class CSVDatasetEngine(DatasetEngine):
    """CSVDatasetEngine

    Thin wrapper around dask_cudf.read_csv.
    """

    def __init__(self, paths, part_size, storage_options=None, cpu=False, **kwargs):
        """
        Parameters
        ----------
        paths : list of str
            Paths (or a single directory) containing the CSV data
        part_size : int or str
            Target partition size passed to the dask CSV reader
        storage_options : dict, optional
            fsspec storage options; also forwarded to the CSV reader
        cpu : bool, optional
            Whether to read with dask (CPU) instead of dask_cudf (GPU)
        **kwargs :
            Additional keyword arguments forwarded to the CSV reader
        """
        # pylint: disable=access-member-before-definition
        super().__init__(paths, part_size, cpu=cpu, storage_options=storage_options)
        self._meta = {}
        self.csv_kwargs = kwargs
        self.csv_kwargs["storage_options"] = storage_options

        # CSV reader needs a list of files
        # (Assume flat directory structure if this is a dir)
        if len(self.paths) == 1 and self.fs.isdir(self.paths[0]):
            self.paths = self.fs.glob(self.fs.sep.join([self.paths[0], "*"]))

    def to_ddf(self, columns=None, cpu=None):
        """Return the dataset as a (lazy) dask or dask_cudf DataFrame"""
        # Check if we are using cpu
        cpu = self.cpu if cpu is None else cpu
        if cpu:
            ddf = dd.read_csv(self.paths, blocksize=self.part_size, **self.csv_kwargs)
        else:
            # NOTE(review): the GPU path passes `chunksize` while the CPU path
            # uses `blocksize`; newer dask_cudf versions renamed this
            # parameter — confirm against the supported dask_cudf version
            ddf = dask_cudf.read_csv(self.paths, chunksize=self.part_size, **self.csv_kwargs)
        if columns:
            ddf = ddf[columns]
        return ddf

    @property  # type: ignore
    @functools.lru_cache(1)
    def _file_partition_map(self):
        # Maps each file name (directory prefix stripped) to the array of
        # partition indices produced from that file's byte blocks.
        # NOTE(review): `lru_cache` on a method keys the cache on `self`, so
        # this keeps a reference to the engine instance alive; with
        # maxsize=1 only the most recent instance's map stays cached.
        ind = 0
        _pp_map = {}
        for path, blocks in zip(
            *_byte_block_counts(
                self.paths,
                self.part_size,
                **self.csv_kwargs,
            )
        ):
            _pp_map[path.split(self.fs.sep)[-1]] = np.arange(ind, ind + blocks)
            ind += blocks
        return _pp_map

    def to_cpu(self):
        # Subsequent `to_ddf` calls will use dask's CPU reader
        self.cpu = True

    def to_gpu(self):
        # Subsequent `to_ddf` calls will use dask_cudf's GPU reader
        self.cpu = False
79 |
80 |
def _byte_block_counts(
    urlpath,
    blocksize,
    lineterminator=None,
    compression="infer",
    storage_options=None,
    **kwargs,
):
    """Return a list of paths and block counts.

    Logic copied from dask.bytes.read_bytes

    Parameters
    ----------
    urlpath : str or list of str
        Path(s) or glob(s) locating the data files
    blocksize : int or str or None
        Target byte size per block; strings like "128MB" are parsed, and
        the value is ignored (set to None) for compressed files
    lineterminator : str, optional
        Block delimiter; anything other than a single character falls
        back to a newline
    compression : str, optional
        Compression scheme, inferred from the first path by default
    storage_options : dict, optional
        fsspec storage options, expanded into the ``read_bytes`` call
    **kwargs :
        Accepted so callers can forward their full CSV kwargs;
        NOTE(review): mutated below but never read afterward — appears
        vestigial from the dask original

    Returns
    -------
    tuple(list, list)
        The file paths and the number of byte blocks in each file
    """

    if lineterminator is not None and len(lineterminator) == 1:
        kwargs["lineterminator"] = lineterminator
    else:
        # Fall back to newline (also replaces multi-character terminators)
        lineterminator = "\n"

    if compression == "infer":
        # Infer compression from the first resolved path's file extension
        paths = get_fs_token_paths(urlpath, mode="rb", storage_options=storage_options)[2]
        compression = infer_compression(paths[0])

    if isinstance(blocksize, str):
        blocksize = parse_bytes(blocksize)
    if blocksize and compression:
        # Compressed files can't be split into byte blocks; read whole files
        blocksize = None

    b_out = read_bytes(
        urlpath,
        delimiter=lineterminator.encode(),
        blocksize=blocksize,
        sample=False,
        compression=compression,
        include_path=True,
        **(storage_options or {}),
    )
    _, values, paths = b_out

    # read_bytes may return a flat list for a single file; normalize to a
    # list-of-lists so each entry corresponds to one file's blocks
    if not isinstance(values[0], (tuple, list)):
        values = [values]

    return paths, [len(v) for v in values]
123 |
--------------------------------------------------------------------------------
/merlin/io/dataframe_iter.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 |
class DataFrameIter:
    """
    Iterate over the partitions of a Dask DataFrame, optionally restricted
    to a subset of columns and/or partition indices, repeating the pass for
    multiple epochs.
    """

    def __init__(self, ddf, columns=None, indices=None, partition_lens=None, epochs=1):
        """
        Parameters
        ----------
        ddf : dask.dataframe.DataFrame or dask_cudf.DataFrame
            The collection whose partitions will be iterated
        columns : list of str, optional
            If provided, each yielded partition is restricted to these columns
        indices : list of int, optional
            Partition indices to iterate; anything other than a real ``list``
            (including tuples/arrays) is ignored and all partitions are used
        partition_lens : list of int, optional
            Known per-partition row counts, used to compute ``len`` cheaply
        epochs : int, optional
            Number of passes over the selected partitions (default 1)
        """
        # NOTE: only a real `list` is honored; other sequence types fall
        # back to iterating every partition
        self.indices = indices if isinstance(indices, list) else range(ddf.npartitions)
        self._ddf = ddf
        self.columns = columns
        self.partition_lens = partition_lens
        self.epochs = epochs

    def __len__(self):
        """Total number of rows yielded across all epochs."""
        if self.partition_lens:
            # Use metadata-based partition-size information
            # if/when it is available. Note that this metadata
            # will not be correct if rows where added or dropped
            # after IO (within Ops).
            return sum(self.partition_lens[i] for i in self.indices) * self.epochs
        if len(self.indices) < self._ddf.npartitions:
            # Only a subset of partitions is selected; count just those rows
            return len(self._ddf.partitions[self.indices]) * self.epochs
        return len(self._ddf) * self.epochs

    def __iter__(self):
        """Yield each selected partition as a computed (in-memory) dataframe."""
        # The epoch number itself is unused; we just repeat the pass
        for _ in range(self.epochs):
            for i in self.indices:
                part = self._ddf.get_partition(i)
                if self.columns:
                    yield part[self.columns].compute(scheduler="synchronous")
                else:
                    yield part.compute(scheduler="synchronous")
46 |
--------------------------------------------------------------------------------
/merlin/io/dataset_engine.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from dask.utils import natural_sort_key
18 | from fsspec.core import get_fs_token_paths
19 |
20 |
class DatasetEngine:
    """Base class for Dask-powered IO engines. Engines must provide a ``to_ddf`` method."""

    def __init__(self, paths, part_size, cpu=False, storage_options=None):
        # Normalize path ordering so partition numbering is deterministic
        ordered_paths = sorted(paths, key=natural_sort_key)
        self.paths = ordered_paths
        self.part_size = part_size
        self.storage_options = storage_options
        filesystem, token, stripped_paths = get_fs_token_paths(
            ordered_paths, mode="rb", storage_options=self.storage_options
        )
        self.fs = filesystem
        self.fs_token = token
        self.stripped_paths = stripped_paths
        self.cpu = cpu

    def to_ddf(self, columns=None, cpu=None):
        raise NotImplementedError(""" Return a dask.dataframe.DataFrame or dask_cudf.DataFrame""")

    def to_cpu(self):
        raise NotImplementedError(""" Move data to CPU memory """)

    def to_gpu(self):
        raise NotImplementedError(""" Move data to GPU memory """)

    @property
    def _path_partition_map(self):
        # Subclasses may override this to map file paths to partition indices
        return None

    @property
    def num_rows(self):
        raise NotImplementedError(""" Returns the number of rows in the dataset """)

    def sample_data(self, n=1):
        """Return a sample of real data from the dataset

        Scan the partitions of the underlying Dask collection until one
        with rows is found, then return that partition's first ``n`` rows.
        If every partition is empty, fall back to the Dask metadata.
        """
        collection = self.to_ddf()
        for index in range(collection.npartitions):
            sample = collection.partitions[index].head(n)
            if len(sample):
                return sample
        return collection._meta
68 |
--------------------------------------------------------------------------------
/merlin/io/hugectr.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import os
17 | from uuid import uuid4
18 |
19 | import numpy as np
20 |
21 | from merlin.io.writer import ThreadedWriter
22 |
23 |
class HugeCTRWriter(ThreadedWriter):
    """ThreadedWriter that emits data in the HugeCTR binary format.

    Each output file begins with a 64-byte header (eight int64 fields)
    that is reserved at construction time and back-filled with real
    metadata by ``_close_writers``.
    """

    def __init__(self, out_dir, suffix=".data", **kwargs):
        super().__init__(out_dir, **kwargs)
        self.suffix = suffix
        if self.use_guid:
            # Embed a GUID in each file name to avoid collisions when
            # multiple processes write to the same directory
            self.data_paths = [
                os.path.join(out_dir, f"{i}i.{uuid4().hex}{self.suffix}")
                for i in range(self.num_out_files)
            ]
        else:
            self.data_paths = [
                os.path.join(out_dir, f"{i}{self.suffix}") for i in range(self.num_out_files)
            ]
        self.data_writers = [open(f, "wb") for f in self.data_paths]
        # Reserve 64 bytes for the header. The placeholder bytes are
        # identical for every file, so compute them once outside the loop
        # (previously rebuilt per-writer with an unused enumerate index).
        header_bytes = np.zeros(8, dtype=np.longlong).tobytes()
        for writer in self.data_writers:
            writer.write(header_bytes)

    def _write_table(self, idx, data):
        """Append the rows of ``data`` to output file ``idx``.

        Per-sample layout: float32 labels, float32 dense features, then
        for each categorical slot an int32 nnz (always 1) followed by a
        uint32 category value.
        """
        # Prepare data format
        np_label = data[self.labels].to_pandas().astype(np.single).to_numpy()
        np_conts = data[self.conts].to_pandas().astype(np.single).to_numpy()
        nnz = np.intc(1)
        np_cats = data[self.cats].to_pandas().astype(np.uintc).to_numpy()
        # Write all the data samples
        for i, label in enumerate(np_label):
            # Write Label
            self.data_writers[idx].write(label.tobytes())
            # Write conts (HugeCTR: dense)
            self.data_writers[idx].write(np_conts[i].tobytes())
            # Write cats (HugeCTR: Slots)
            for j, _ in enumerate(np_cats[i]):
                self.data_writers[idx].write(nnz.tobytes())
                self.data_writers[idx].write(np_cats[i][j].tobytes())

    def _write_thread(self):
        """Worker loop: pop ``(idx, data)`` items off the queue and write
        them until the end-of-data sentinel is received."""
        while True:
            item = self.queue.get()
            try:
                if item is self._eod:
                    break
                idx, data = item
                # Serialize concurrent writes to the same output file
                with self.write_locks[idx]:
                    self._write_table(idx, data)
            finally:
                self.queue.task_done()

    def _close_writers(self):
        """Back-fill each file's 64-byte header with final metadata
        and close all writers."""
        for i, writer in enumerate(self.data_writers):
            if self.cats:
                # Write HugeCTR Metadata
                writer.seek(0)
                # Header layout (all int64):
                #   error_check (0: no error check; 1: check_num)
                #   num of samples in this file
                #   dimension of the labels
                #   dimension of the features
                #   slot_num for each embedding
                #   three fields reserved for future use
                header = np.array(
                    [
                        0,
                        self.num_samples[i],
                        len(self.labels),
                        len(self.conts),
                        len(self.cats),
                        0,
                        0,
                        0,
                    ],
                    dtype=np.longlong,
                )
                writer.write(header.tobytes())
            writer.close()

    def _bytesio_to_disk(self):
        raise ValueError("hugectr binary format doesn't support PER_WORKER shuffle yet")
102 |
--------------------------------------------------------------------------------
/merlin/io/shuffle.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import enum
17 | import warnings
18 |
19 | import pandas as pd
20 | from packaging.version import Version
21 |
22 | from merlin.core.compat import cudf
23 |
# `DataFrame.sample(ignore_index=...)` was added to pandas in 1.3.0
_IGNORE_INDEX_SUPPORTED = Version(pd.__version__) >= Version("1.3.0")

# cudf gained `ignore_index` support in 22.04. Define the flag
# unconditionally so `shuffle_df` never raises NameError in
# environments where cudf is not installed (previously the name was
# only bound inside the `if cudf:` branch).
if cudf:
    try:
        _CUDF_IGNORE_INDEX_SUPPORTED = Version(cudf.__version__) >= Version("22.04.0")
    except ImportError:
        _CUDF_IGNORE_INDEX_SUPPORTED = None
else:
    _CUDF_IGNORE_INDEX_SUPPORTED = None
31 |
32 |
class Shuffle(enum.Enum):
    """Shuffling granularity options.

    NOTE(review): ``FULL`` is rejected by ``_check_shuffle_arg``
    ("not yet supported"). ``PER_PARTITION`` and ``PER_WORKER``
    presumably shuffle within each output partition / across each
    worker's partitions respectively — confirm against the writer
    implementations that consume this enum.
    """

    PER_PARTITION = 0
    PER_WORKER = 1
    FULL = 2
37 |
38 |
39 | #
40 | # Helper Function definitions
41 | #
42 |
43 |
44 | def _check_shuffle_arg(shuffle):
45 | if shuffle is None:
46 | return shuffle
47 |
48 | if isinstance(shuffle, Shuffle):
49 | if shuffle == Shuffle.FULL:
50 | raise ValueError('`shuffle="full"` is not yet supported.')
51 | elif shuffle is True:
52 | shuffle = Shuffle.PER_WORKER
53 | warnings.warn("`shuffle=True` is deprecated. Using `PER_WORKER`.", DeprecationWarning)
54 | elif shuffle is False:
55 | shuffle = None
56 | else:
57 | raise ValueError(f"`shuffle={shuffle}` not recognized.")
58 | return shuffle
59 |
60 |
def shuffle_df(df, size=None, keep_index=False):
    """Shuffles a DataFrame, returning a new dataframe with randomly
    ordered rows"""
    sample_size = size or len(df)

    if not isinstance(df, pd.DataFrame):
        # cudf path
        if _CUDF_IGNORE_INDEX_SUPPORTED:
            return df.sample(n=sample_size, ignore_index=not keep_index)
        return df.sample(n=sample_size, keep_index=keep_index)

    # pandas path
    if _IGNORE_INDEX_SUPPORTED:
        return df.sample(n=sample_size, ignore_index=not keep_index)
    # Pandas<1.3.0: emulate ignore_index by resetting after sampling
    shuffled = df.sample(n=sample_size)
    return shuffled if keep_index else shuffled.reset_index(drop=True)
78 |
--------------------------------------------------------------------------------
/merlin/io/writer_factory.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from fsspec.core import get_fs_token_paths
17 |
18 | from merlin.io.hugectr import HugeCTRWriter
19 | from merlin.io.parquet import CPUParquetWriter, GPUParquetWriter
20 |
21 |
def writer_factory(
    output_format,
    output_path,
    out_files_per_proc,
    shuffle,
    use_guid=False,
    bytes_io=False,
    num_threads=0,
    cpu=False,
    fns=None,
    suffix=None,
    fs=None,
    **kwargs,  # Format-specific arguments
):
    """Construct a ThreadedWriter for the requested output format.

    Returns ``None`` when ``output_format`` is ``None``; otherwise
    resolves the writer class (and filesystem) via
    ``_writer_cls_factory`` and instantiates it.
    """
    if output_format is None:
        return None

    writer_cls, filesystem = _writer_cls_factory(output_format, output_path, cpu=cpu, fs=fs)
    writer_options = dict(
        num_out_files=out_files_per_proc,
        shuffle=shuffle,
        fs=filesystem,
        use_guid=use_guid,
        bytes_io=bytes_io,
        num_threads=num_threads,
        cpu=cpu,
        fns=fns,
        suffix=suffix,
    )
    writer_options.update(kwargs)  # Format-specific arguments
    return writer_cls(output_path, **writer_options)
53 |
54 |
55 | def _writer_cls_factory(output_format, output_path, cpu=None, fs=None):
56 | if output_format == "parquet" and cpu:
57 | writer_cls = CPUParquetWriter
58 | elif output_format == "parquet":
59 | writer_cls = GPUParquetWriter
60 | elif output_format == "hugectr":
61 | writer_cls = HugeCTRWriter
62 | else:
63 | raise ValueError("Output format not yet supported.")
64 |
65 | if fs is None:
66 | fs = get_fs_token_paths(output_path)[0]
67 | return writer_cls, fs
68 |
--------------------------------------------------------------------------------
/merlin/schema/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # flake8: noqa
18 | from merlin.schema.schema import ColumnSchema, Schema
19 | from merlin.schema.tags import Tags, TagSet, TagsType
20 |
--------------------------------------------------------------------------------
/merlin/schema/io/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
--------------------------------------------------------------------------------
/merlin/schema/io/proto_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | from typing import TypeVar
16 |
17 | import betterproto
18 | from betterproto import Message as BetterProtoMessage
19 | from google.protobuf import json_format, text_format
20 | from google.protobuf.message import Message as ProtoMessage
21 |
22 | ProtoMessageType = TypeVar("ProtoMessageType", bound=BetterProtoMessage)
23 |
24 |
def has_field(message: ProtoMessageType, field_name: str) -> bool:
    """Check if a Protobuf message has a particular field

    Parameters
    ----------
    message : ProtoMessageType
        Protobuf message object
    field_name : str
        Name of the field to look for

    Returns
    -------
    bool
        `True` if the named field was serialized on the wire for this
        message object
    """
    return betterproto.serialized_on_wire(getattr(message, field_name))
43 |
44 |
def copy_better_proto_message(better_proto_message: ProtoMessageType, **kwargs) -> ProtoMessageType:
    """Create a copy of a Protobuf message

    Parameters
    ----------
    better_proto_message : ProtoMessageType
        The message to copy
    **kwargs :
        Attribute values to set on the copy after it is created

    Returns
    -------
    ProtoMessageType
        Copy of better_proto_message with any overrides applied
    """
    # Round-trip through the wire format to produce an independent copy
    duplicate = type(better_proto_message)().parse(bytes(better_proto_message))
    for field_name, field_value in kwargs.items():
        setattr(duplicate, field_name, field_value)
    return duplicate
64 |
65 |
def better_proto_to_proto_text(
    better_proto_message: BetterProtoMessage, message: ProtoMessage
) -> str:
    """Convert a BetterProto message object to Protobuf text format

    Parameters
    ----------
    better_proto_message : BetterProtoMessage
        The message to convert
    message : ProtoMessage
        A blank (raw) Protobuf message object to parse into

    Returns
    -------
    str
        Protobuf text representation of better_proto_message
    """
    # Serialize to wire bytes, reparse into the raw protobuf object,
    # then render with google.protobuf's text formatter
    wire_bytes = bytes(better_proto_message)
    message.ParseFromString(wire_bytes)
    return text_format.MessageToString(message)
87 |
88 |
def proto_text_to_better_proto(
    better_proto_message: ProtoMessageType, proto_text: str, message: ProtoMessage
) -> ProtoMessageType:
    """Convert a Protobuf text format message into a BetterProto message object

    Parameters
    ----------
    better_proto_message : ProtoMessageType
        A BetterProto message object of the desired type
    proto_text : str
        The Protobuf text format message to convert
    message : ProtoMessage
        A blank (raw) Protobuf message object to parse proto_text into

    Returns
    -------
    ProtoMessageType
        A BetterProto message object containing attributes parsed from proto_text
    """
    # text format -> raw proto -> JSON -> betterproto object
    raw_proto = text_format.Parse(proto_text, message)
    as_json = json_format.MessageToJson(raw_proto)
    return type(better_proto_message)().from_json(as_json)
113 |
--------------------------------------------------------------------------------
/merlin/table/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # flake8: noqa
18 | from merlin.table.conversions import df_from_tensor_table, tensor_table_from_df
19 | from merlin.table.cupy_column import CupyColumn
20 | from merlin.table.numpy_column import NumpyColumn
21 | from merlin.table.tensor_column import Device, TensorColumn
22 | from merlin.table.tensor_table import TensorTable
23 | from merlin.table.tensorflow_column import TensorflowColumn
24 | from merlin.table.torch_column import TorchColumn
25 |
--------------------------------------------------------------------------------
/merlin/table/cupy_column.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Callable, Optional, Type
17 |
18 | from merlin.core.compat import cupy as cp
19 | from merlin.table.conversions import _from_dlpack_gpu, _to_dlpack
20 | from merlin.table.tensor_column import Device, TensorColumn
21 |
22 |
class CupyColumn(TensorColumn):
    """
    A SeriesLike column backed by CuPy arrays
    """

    @classmethod
    def array_type(cls) -> Optional[Type]:
        """
        The type of the arrays backing this column
        """
        return cp.ndarray if cp else None

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to build backing arrays for this column.

        NOTE(review): unlike ``array_type``, this does not guard
        against ``cp`` being unavailable.
        """
        return cp.asarray

    @classmethod
    def supported_devices(cls):
        """
        List of device types supported by this column type
        """
        return [Device.GPU]

    def __init__(
        self,
        values: "cp.ndarray",
        offsets: "cp.ndarray" = None,
        dtype=None,
        _ref=None,
        _unsafe=False,
    ):
        super().__init__(values, offsets, dtype, _ref=_ref, _device=Device.GPU, _unsafe=_unsafe)

    def cpu(self):
        """
        Move this column's data to host (i.e. CPU) memory

        Returns
        -------
        NumpyColumn
            A copy of this column backed by NumPy arrays
        """
        # Imported locally; NumpyColumn is re-exported from merlin.table,
        # which also imports this module
        from merlin.table import NumpyColumn

        host_values = cp.asnumpy(self.values)
        host_offsets = None
        if self.offsets is not None:
            host_offsets = cp.asnumpy(self.offsets)

        return NumpyColumn(host_values, host_offsets)

    def gpu(self):
        """
        Move this column's data to device (i.e. GPU) memory

        Returns
        -------
        CupyColumn
            This column, unchanged — the data already lives on the GPU
        """
        return self

    @property
    def _flatten_values(self):
        # One-dimensional form of the (possibly multi-dim) values array
        return self.values.flatten()

    def _reshape_values(self, values, shape):
        return cp.reshape(values, shape)
89 |
90 |
@_to_dlpack.register_lazy("cupy")
def _register_to_dlpack_from_cupy():
    # Deferred registration: this body only runs once cupy is imported
    import cupy as cp

    @_to_dlpack.register(cp.ndarray)
    def _to_dlpack_from_cp_tensor(tensor):
        # Cast bool arrays to int8 before DLPack export — presumably
        # because downstream DLPack consumers don't accept bool dtypes;
        # confirm against the conversion callers
        if tensor.dtype == cp.dtype("bool"):
            tensor = tensor.astype(cp.dtype("int8"))
        return tensor
100 |
101 |
@_from_dlpack_gpu.register_lazy("cupy")
def _register_from_dlpack_gpu_to_cupy():
    # Deferred registration: this body only runs once cupy is imported
    import cupy as cp

    @_from_dlpack_gpu.register(cp.ndarray)
    def _from_dlpack_gpu_to_cupy(to, array) -> cp.ndarray:
        # `array` is consumed directly by cupy's DLPack importer
        return cp.from_dlpack(array)
109 |
--------------------------------------------------------------------------------
/merlin/table/numpy_column.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Callable, Type
17 |
18 | from merlin.core.compat import cupy as cp
19 | from merlin.core.compat import numpy as np
20 | from merlin.table.conversions import _from_dlpack_cpu, _to_dlpack
21 | from merlin.table.tensor_column import Device, TensorColumn
22 |
23 |
class NumpyColumn(TensorColumn):
    """
    A SeriesLike column backed by NumPy arrays
    """

    @classmethod
    def array_type(cls) -> Type:
        """
        The type of the arrays backing this column
        """
        return np.ndarray

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to build backing arrays for this column"""
        return np.array

    @classmethod
    def supported_devices(cls):
        """
        List of device types supported by this column type
        """
        return [Device.CPU]

    def __init__(
        self,
        values: "np.ndarray",
        offsets: "np.ndarray" = None,
        dtype=None,
        _ref=None,
        _unsafe=False,
    ):
        super().__init__(values, offsets, dtype, _ref=_ref, _device=Device.CPU, _unsafe=_unsafe)

    def cpu(self):
        """
        Move this column's data to host (i.e. CPU) memory

        Returns
        -------
        NumpyColumn
            This column, unchanged — the data already lives on the CPU
        """
        return self

    def gpu(self):
        """
        Move this column's data to device (i.e. GPU) memory

        Returns
        -------
        CupyColumn
            A copy of this column backed by CuPy arrays
        """
        # Imported locally; CupyColumn is re-exported from merlin.table,
        # which also imports this module
        from merlin.table import CupyColumn

        device_values = cp.asarray(self.values)
        device_offsets = None
        if self.offsets is not None:
            device_offsets = cp.asarray(self.offsets)

        return CupyColumn(device_values, device_offsets)

    @property
    def _flatten_values(self):
        # One-dimensional form of the (possibly multi-dim) values array
        return self.values.flatten()

    def _reshape_values(self, values, shape):
        return np.reshape(values, shape)
91 |
92 |
@_from_dlpack_cpu.register_lazy("numpy")
def _register_from_dlpack_cpu_to_numpy():
    # Deferred registration: this body only runs once numpy is imported
    import numpy as np

    @_from_dlpack_cpu.register(np.ndarray)
    def _from_dlpack_cpu_to_numpy(to, array):
        # Fall through the NumPy DLPack API variants by version,
        # raising a clear error on NumPy builds that predate DLPack
        try:
            # private `_from_dlpack` method added in 1.22.0
            return np._from_dlpack(array)
        except AttributeError:
            pass
        try:
            # public `from_dlpack` method added in 1.23.0
            return np.from_dlpack(array)
        except AttributeError as exc:
            raise NotImplementedError(
                "NumPy does not implement the DLPack Standard until version 1.22.0, "
                f"currently running {np.__version__}"
            ) from exc
112 |
113 |
@_to_dlpack.register_lazy("numpy")
def _register_from_numpy_to_dlpack_cpu():
    # Deferred registration: this body only runs once numpy is imported
    import numpy as np

    @_to_dlpack.register(np.ndarray)
    def _to_dlpack_cpu_from_numpy(array):
        # Cast bool arrays to int8 before DLPack export — presumably
        # because downstream DLPack consumers don't accept bool dtypes;
        # confirm against the conversion callers
        if array.dtype == np.dtype("bool"):
            array = array.astype(np.dtype("int8"))
        return array
123 |
--------------------------------------------------------------------------------
/merlin/table/torch_column.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Callable, Optional, Type
17 |
18 | from merlin.core.compat.torch import torch as th
19 | from merlin.table.conversions import _from_dlpack_cpu, _from_dlpack_gpu, _to_dlpack
20 | from merlin.table.tensor_column import Device, TensorColumn
21 |
22 |
class TorchColumn(TensorColumn):
    """
    A SeriesLike column backed by Torch tensors
    """

    framework_name = "torch"

    @classmethod
    def array_type(cls) -> Optional[Type]:
        """
        The type of the arrays backing this column
        """
        return th.Tensor if th else None

    @classmethod
    def array_constructor(cls) -> Callable:
        """Callable used to build backing tensors for this column"""
        return th.tensor

    @classmethod
    def supported_devices(cls):
        """
        List of device types supported by this column type
        """
        return [Device.CPU, Device.GPU]

    def __init__(
        self, values: "th.Tensor", offsets: "th.Tensor" = None, dtype=None, _ref=None, _unsafe=False
    ):
        # Reject mixed-device inputs: values and offsets must live on
        # the same device
        device_of_values = self._th_device(values)
        if offsets is not None:
            device_of_offsets = self._th_device(offsets)
            if device_of_values != device_of_offsets:
                raise ValueError(
                    f"Values and offsets were detected on different devices: "
                    f"values ({device_of_values}) and offsets ({device_of_offsets})."
                )

        super().__init__(
            values, offsets, dtype, _device=device_of_values, _ref=_ref, _unsafe=_unsafe
        )

    def cpu(self):
        """
        Move this column's data to host (i.e. CPU) memory

        Returns
        -------
        TorchColumn
            A copy of this column backed by Torch CPU tensors
        """
        if self.device is Device.CPU:
            return self

        host_values = self.values.cpu()
        host_offsets = None if self.offsets is None else self.offsets.cpu()
        return TorchColumn(host_values, host_offsets)

    def gpu(self):
        """
        Move this column's data to device (i.e. GPU) memory

        Returns
        -------
        TorchColumn
            A copy of this column backed by Torch GPU tensors
        """
        if self.device is Device.GPU:
            return self

        device_values = self.values.cuda()
        device_offsets = None if self.offsets is None else self.offsets.cuda()
        return TorchColumn(device_values, device_offsets)

    @property
    def device(self) -> Device:
        # Derived from the backing values tensor's placement
        return self._th_device(self.values)

    @property
    def _flatten_values(self):
        return th.flatten(self.values)

    def _reshape_values(self, values, shape):
        return th.reshape(values, shape)

    def _th_device(self, tensor):
        # Map a torch tensor's placement onto the Device enum
        return Device.GPU if tensor.is_cuda else Device.CPU
109 |
110 |
@_to_dlpack.register_lazy("torch")
def _register_to_dlpack_from_torch():
    # Deferred registration: this body only runs once torch is imported
    import torch as th

    @_to_dlpack.register(th.Tensor)
    def _to_dlpack_from_torch_tensor(tensor):
        # Torch tensors are passed through unchanged
        return tensor
118 |
119 |
@_from_dlpack_cpu.register_lazy("torch")
def _register_from_dlpack_cpu_to_torch():
    # Deferred registration: this body only runs once torch is imported
    import torch as th

    @_from_dlpack_cpu.register(th.Tensor)
    def _from_dlpack_cpu_to_torch(target_type, array):
        # `array` is consumed directly by torch's DLPack importer
        return th.utils.dlpack.from_dlpack(array)
127 |
128 |
@_from_dlpack_gpu.register_lazy("torch")
def _register_from_dlpack_gpu_to_torch():
    # Deferred registration: this body only runs once torch is imported
    import torch as th

    @_from_dlpack_gpu.register(th.Tensor)
    def _from_dlpack_gpu_to_torch(target_type, array):
        # GPU path: export an explicit DLPack capsule via __dlpack__()
        # before handing it to torch (unlike the CPU path above)
        return th.utils.dlpack.from_dlpack(array.__dlpack__())
136 |
--------------------------------------------------------------------------------
/merlin/testing/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | # flake8: noqa
18 | from merlin.testing.assert_equals import assert_transformable_equal
19 |
--------------------------------------------------------------------------------
/merlin/testing/assert_equals.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from merlin.core.compat import pandas as pd
18 | from merlin.dispatch.lazy import lazy_singledispatch
19 | from merlin.table import TensorTable
20 |
21 |
def assert_table_equal(left: TensorTable, right: TensorTable):
    """Assert that two TensorTables are equal.

    Both tables are moved to CPU, converted to dataframes, and
    compared with ``pd.testing.assert_frame_equal``.
    """
    pd.testing.assert_frame_equal(left.cpu().to_df(), right.cpu().to_df())
24 |
25 |
@lazy_singledispatch
def assert_transformable_equal(left, right):
    """Assert that two transformable objects are equal, dispatching on
    the type of ``left`` (TensorTable, cudf or pandas DataFrame).

    The base implementation is unreachable for registered types; any
    unregistered type raises NotImplementedError.
    """
    raise NotImplementedError
29 |
30 |
@assert_transformable_equal.register(TensorTable)
def _assert_equal_table(left, right):
    # TensorTable inputs are compared via `assert_table_equal`
    assert_table_equal(left, right)
34 |
35 |
@assert_transformable_equal.register_lazy("cudf")
def _register_assert_equal_df_cudf():
    # Deferred registration: this body only runs once cudf is imported
    import cudf

    @assert_transformable_equal.register(cudf.DataFrame)
    def _assert_equal_df_cudf(left, right):
        # cudf DataFrames are compared with cudf's own frame-equality helper
        cudf.testing.assert_frame_equal(left, right)
43 |
44 |
@assert_transformable_equal.register_lazy("pandas")
def _register_assert_equal_pandas():
    # Deferred registration: this body only runs once pandas is imported
    import pandas

    @assert_transformable_equal.register(pandas.DataFrame)
    def _assert_equal_pandas(left, right):
        # pandas DataFrames are compared with pandas' frame-equality helper
        pandas.testing.assert_frame_equal(left, right)
52 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.8
3 | warn_unused_configs = True
4 | exclude = versioneer.py
5 | ignore_missing_imports = True
6 | show_traceback = True
7 | strict_optional = False
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [tool.black]
9 | line-length = 100
10 |
11 | [tool.isort]
12 | use_parentheses = true
13 | multi_line_output = 3
14 | include_trailing_comma = true
15 | force_grid_wrap = 0
16 | ensure_newline_before_comments = true
17 | line_length = 100
18 | balanced_wrapping = true
19 | indent = " "
20 | known_third_party = ["cudf", "cupy", "dask", "dask_cuda", "dask_cudf", "numba", "numpy", "pytest", "torch", "rmm", "tensorflow"]
21 | skip = ["build", ".eggs"]
22 |
23 | [tool.interrogate]
24 | ignore-init-method = true
25 | ignore-init-module = true
26 | ignore-magic = true
27 | ignore-module = true
28 | ignore-private = true
29 | ignore-property-decorators = true
30 | ignore-nested-classes = true
31 | ignore-nested-functions = true
32 | ignore-semiprivate = true
33 | ignore-setters = true
34 | fail-under = 70
35 | exclude = ["build", "docs", "merlin/core", "merlin/io", "tests", "setup.py", "versioneer.py"]
36 | verbose = 1
37 | omit-covered-files = true
38 | quiet = false
39 | whitelist-regex = []
40 | ignore-regex = []
41 | color = true
42 |
43 | [tool.pytest.ini_options]
44 | filterwarnings = [
45 | 'ignore:`np.*` is a deprecated alias:DeprecationWarning',
46 | 'ignore:The default dtype for empty Series:DeprecationWarning',
47 | 'ignore:General-metadata information not detected:UserWarning',
48 | 'ignore:Changing an NVTabular Dataset to CPU mode:UserWarning',
49 | 'ignore:Initializing an NVTabular Dataset in CPU mode:UserWarning',
50 | 'ignore:Performing a hash-based transformation:UserWarning',
51 | 'ignore:WARNING..cuDF.to_dlpack',
52 | 'ignore:::numba.cuda.envvar:',
53 | 'ignore:Call to deprecated create function:DeprecationWarning',
54 | 'ignore:Only created .* files did not have enough partitions to create .* file:UserWarning',
55 | 'ignore:distutils Version classes are deprecated. Use packaging.version instead:DeprecationWarning',
56 | ]
57 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # packages necessary to run tests and push PRs
2 | # assumes requirements for merlin-core are already installed
3 |
4 | pytest>=5
5 | pytest-cov>=2
6 | pytest-xdist
7 |
--------------------------------------------------------------------------------
/requirements-docs.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | -r requirements-dev.txt
3 |
4 | sphinxcontrib-applehelp==1.0.4
5 | sphinxcontrib-devhelp==1.0.2
6 | sphinxcontrib-htmlhelp==2.0.1
7 | sphinxcontrib-qthelp==1.0.3
8 | sphinxcontrib-serializinghtml==1.1.5
9 | sphinx-multiversion@git+https://github.com/mikemckiernan/sphinx-multiversion.git
10 | sphinxcontrib-copydirs@git+https://github.com/mikemckiernan/sphinxcontrib-copydirs.git
11 | recommonmark~=0.7.1
12 | Jinja2<3.1
13 | natsort~=8.4.0
14 | myst-nb~=1.1.0
15 | linkify-it-py~=2.0.3
16 | sphinx-external-toc~=1.0.1
17 | attrs~=23.2.0
18 | sphinx-book-theme~=1.1.2
19 | sphinx_design~=0.5.0
20 |
--------------------------------------------------------------------------------
/requirements-gpu.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 |
3 | # cudf>=21.12
4 | # dask-cudf>=21.12
5 | # dask-cuda>=21.12
6 |
--------------------------------------------------------------------------------
/requirements-test-cpu.txt:
--------------------------------------------------------------------------------
1 | -r requirements-test.txt
2 | -r requirements.txt
3 |
--------------------------------------------------------------------------------
/requirements-test-gpu.txt:
--------------------------------------------------------------------------------
1 | -r requirements-test.txt
2 | -r requirements-gpu.txt
3 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | -r requirements-dev.txt
3 |
4 | # This contains common libraries for testing.
5 |
6 | # NOTE: You should pip install requirements-test-[cpu|gpu].txt for device-specific test
7 | # requirements, which will include the dependencies defined in this file.
8 |
9 | pytest>=5
10 | pytest-cov>=2
11 | testbook==0.4.2
12 |
13 | # needed to make test_s3 work
14 | # moto>=2
15 | # boto3==1.17
16 | # s3fs>=2021.4
17 | # aiobotocore>=1.3.3
18 | # flask
19 | # flask-cors
20 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dask>=2022.11.1
2 | dask-cuda>=22.12.0
3 | distributed>=2022.11.1
4 | fsspec>=2022.7.1
5 | numpy>=1.22.0
6 | pandas>=1.2.0,<1.6.0dev0
7 | numba>=0.54
8 | pyarrow>=5.0.0
9 | protobuf>=3.0.0
10 | tqdm>=4.0
11 | tensorflow-metadata>=1.2.0
12 | betterproto<2.0.0
13 | packaging
14 | npy-append-array
15 |
16 | # pynvml==11.5.0 is incompatible with distributed<2023.2.1
17 | pynvml>=11.0.0,<11.5
18 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 100
3 | exclude = build,.eggs,*_pb2.py
4 | ignore = E203,W503
5 | per-file-ignores =
6 | examples/criteo_benchmark.py:E402
7 | examples/dataloader_bench.py:E402
8 |
9 | [flake8_nb]
10 | max-line-length = 120
11 | ignore = E203,E402,W503
12 |
13 | [pydocstyle]
14 | ignore = D100,D102,D103,D104,D105,D107,D203,D205,D211,D212,D213,D400,D401,D413,D415
15 |
16 | [codespell]
17 | skip = .*pb2.py,./.git,./.github,./bench,./dist,./docs/build,.*egg-info.*,versioneer.py,*.csv,*.parquet
18 | ignore-words = ./ci/ignore_codespell_words.txt
19 | count =
20 | quiet-level = 3
21 |
22 | # See the docstring in versioneer.py for instructions. Note that you must
23 | # re-run 'versioneer.py setup' after changing this section, and commit the
24 | # resulting files.
25 |
26 | [versioneer]
27 | VCS = git
28 | style = pep440
29 | versionfile_source = merlin/core/_version.py
30 | versionfile_build = merlin/core/_version.py
31 | tag_prefix = v
32 | parentdir_prefix = merlin-core-
33 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import os
17 | import sys
18 |
19 | from setuptools import find_namespace_packages, setup
20 |
# Import versioneer, falling back to a path-based import for PEP 517/518 builds.
try:
    import versioneer
except ImportError:
    # we have a versioneer.py file living in the same directory as this file, but
    # if we're using pep 517/518 to build from pyproject.toml its not going to find it
    # https://github.com/python-versioneer/python-versioneer/issues/193#issue-408237852
    # make this work by adding this directory to the python path
    sys.path.append(os.path.dirname(os.path.realpath(__file__)))
    import versioneer
31 |
def parse_requirements(filename):
    """Load requirements from a pip requirements file.

    Parameters
    ----------
    filename : str
        Path to a pip-style requirements file.

    Returns
    -------
    list of str
        Non-empty, non-comment lines, stripped of surrounding whitespace.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original left the handle open until garbage collection.
    with open(filename, encoding="utf-8") as req_file:
        lineiter = (line.strip() for line in req_file)
        return [line for line in lineiter if line and not line.startswith("#")]
36 |
37 |
# Runtime dependencies come straight from requirements.txt so pip installs
# and source builds stay in sync with that file.
install_reqs = parse_requirements("./requirements.txt")

setup(
    name="merlin-core",
    # The version string is derived from git tags by versioneer.
    version=versioneer.get_version(),
    # Namespace packages: every package under the "merlin" namespace is included.
    packages=find_namespace_packages(include=["merlin*"]),
    url="https://github.com/NVIDIA-Merlin/core",
    author="NVIDIA Corporation",
    license="Apache 2.0",
    long_description=open("README.md", encoding="utf8").read(),
    long_description_content_type="text/markdown",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Programming Language :: Python :: 3",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Software Development :: Libraries",
        "Topic :: Scientific/Engineering",
    ],
    zip_safe=False,
    python_requires=">=3.8",
    install_requires=install_reqs,
    # versioneer supplies build/sdist command classes that embed the version.
    cmdclass=versioneer.get_cmdclass(),
)
62 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-Merlin/core/6d396aa1d5e6ffe60bb133dcfe93d869b8548ba2/tests/__init__.py
--------------------------------------------------------------------------------
/tests/unit/core/test_dispatch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy as np
17 | import pandas as pd
18 | import pytest
19 |
20 | from merlin.core.compat import HAS_GPU
21 | from merlin.core.compat import cupy as cp
22 | from merlin.core.dispatch import (
23 | concat_columns,
24 | is_list_dtype,
25 | list_val_dtype,
26 | make_df,
27 | nullable_series,
28 | )
29 |
# Exercise the GPU code path only when a GPU is actually available.
_DEVICES = ["cpu", "gpu"] if HAS_GPU else ["cpu"]
34 |
35 |
@pytest.mark.parametrize("device", _DEVICES)
def test_list_dtypes(device):
    """List columns should be detected as such and report their leaf value dtype."""
    # NOTE: the unused `tmpdir` fixture was removed; it forced pytest to create
    # a temporary directory this test never touched.
    df = make_df(device=device)
    df["vals"] = [
        [[0, 1, 2], [3, 4], [5]],
    ]
    # Check that the index can be arbitrary
    df.set_index(np.array([2]), drop=True, inplace=True)

    assert is_list_dtype(df["vals"])
    assert list_val_dtype(df["vals"]) == np.dtype(np.int64)
47 |
48 |
@pytest.mark.parametrize("device", _DEVICES)
def test_concat_columns(device):
    # Column-wise concatenation should keep left-to-right column order.
    left = make_df({"a": [1, 2], "b": [[3], [4, 5]]}, device=device)
    right = make_df({"c": [3, 4, 5]}, device=device)

    result = concat_columns([left, right])

    assert result.columns.to_list() == ["a", "b", "c"]
56 |
57 |
@pytest.mark.skipif(not (cp and HAS_GPU), reason="Cupy not available")
def test_pandas_cupy_combo():
    """pandas rejects cupy arrays directly, but make_df converts them correctly."""
    rand_cp_nd_arr = cp.random.uniform(0.0, 1.0, size=100)
    with pytest.raises(TypeError) as exc_info:
        pd.DataFrame(rand_cp_nd_arr)

    # Assert against the exception itself (`exc_info.value`), not the
    # ExceptionInfo wrapper's repr, which is the documented pytest idiom.
    assert "Implicit conversion to a NumPy array is not allowed" in str(exc_info.value)
    pd_df = pd.DataFrame(rand_cp_nd_arr.get())[0]
    mk_df = make_df(rand_cp_nd_arr)[0]
    assert all(pd_df.to_numpy() == mk_df.to_numpy())
68 |
69 |
# (numpy dtype name, matching pandas nullable extension dtype name)
_NULLABLE_DTYPE_PAIRS = [
    ("int8", "Int8"),
    ("int16", "Int16"),
    ("int32", "Int32"),
    ("int64", "Int64"),
    ("uint8", "UInt8"),
    ("uint16", "UInt16"),
    ("uint32", "UInt32"),
    ("uint64", "UInt64"),
    ("float32", "Float32"),
    ("float64", "Float64"),
]


@pytest.mark.parametrize(
    ["data", "dtype", "expected_series"],
    [
        [[None], np.dtype(numpy_name), pd.Series([pd.NA], dtype=pandas_name)]
        for numpy_name, pandas_name in _NULLABLE_DTYPE_PAIRS
    ],
)
def test_nullable_series(data, dtype, expected_series):
    # A None input should round-trip to the matching pandas nullable dtype.
    result = nullable_series(data, pd.DataFrame(), dtype)
    pd.testing.assert_series_equal(result, expected_series)
88 |
--------------------------------------------------------------------------------
/tests/unit/core/test_protocols.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | from merlin.core.compat import HAS_GPU, cudf
19 | from merlin.core.dispatch import make_df, make_series
20 | from merlin.core.protocols import DataFrameLike, DictLike, SeriesLike, Transformable
21 |
# Only include the GPU device when both CUDA hardware and cudf are present.
_DEVICES = ["cpu", None] if (HAS_GPU and cudf) else ["cpu"]
26 |
27 |
@pytest.mark.parametrize("protocol", [DictLike])
def test_dictionary_is_dictlike(protocol):
    # A plain dict should satisfy the DictLike structural protocol.
    assert isinstance({}, protocol)
33 |
34 |
@pytest.mark.parametrize("device", _DEVICES)
@pytest.mark.parametrize("protocol", [DictLike, DataFrameLike, Transformable])
def test_dataframes_match_protocols(protocol, device):
    # Dataframes on every device should satisfy all dataframe-ish protocols.
    frame = make_df({}, device=device)
    assert isinstance(frame, protocol)
41 |
42 |
@pytest.mark.parametrize("device", _DEVICES)
def test_series_are_serieslike(device):
    # Series on every device should satisfy the SeriesLike protocol.
    series = make_series([0], device=device)
    assert isinstance(series, SeriesLike)
48 |
--------------------------------------------------------------------------------
/tests/unit/core/test_version.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 | from packaging.version import Version
18 |
19 | import merlin.core
20 |
21 |
@pytest.mark.skip(reason="requires full clone of repo to ensure versions are detected.")
def test_version():
    # Skipped by default: versioneer derives __version__ from git history, so a
    # shallow/tagless checkout would presumably report an unparseable version.
    assert Version(merlin.core.__version__) >= Version("0.6.0")
25 |
--------------------------------------------------------------------------------
/tests/unit/dag/ops/test_addmetadata.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import merlin.dag.ops.add_metadata as ops
5 | from merlin.dag import ColumnSelector
6 | from merlin.schema import ColumnSchema, Schema
7 |
8 |
@pytest.mark.parametrize("properties", [{}, {"p1": "1"}])
@pytest.mark.parametrize("tags", [[], ["TAG1", "TAG2"]])
@pytest.mark.parametrize(
    "op",
    [
        ops.AddMetadata(tags=["excellent"], properties={"domain": {"min": 0, "max": 20}}),
        ops.AddTags(tags=["excellent"]),
        ops.AddProperties(properties={"domain": {"min": 0, "max": 20}}),
        ops.TagAsUserID(),
        ops.TagAsItemID(),
        ops.TagAsUserFeatures(),
        ops.TagAsItemFeatures(),
    ],
)
@pytest.mark.parametrize("selection", [["1"], ["2", "3"], ["1", "2", "3", "4"]])
def test_schema_out(tags, properties, selection, op):
    """The op's output schema should agree with its _compute_dtype/_compute_tags/
    _compute_properties hooks for every selected column, and columns outside the
    selection should be dropped from the output schema."""
    # Create columnSchemas: five columns named "0".."4" that all share the same
    # dtype/tags/properties combination under test.
    column_schemas = []
    all_cols = []
    for x in range(5):
        all_cols.append(str(x))
        column_schemas.append(
            ColumnSchema(str(x), dtype=np.int32, tags=tags, properties=properties)
        )

    # Turn to Schema
    input_schema = Schema(column_schemas)

    # run schema through op
    selector = ColumnSelector(selection)
    output_schema = op.compute_output_schema(input_schema, selector)

    # Every output column derived from a selected input column should match the
    # metadata the op's internal hooks compute for it.
    # (NOTE(review): the original comment here said "should have dtype float",
    # but the columns are created with np.int32 above.)
    for input_col_name in selector.names:
        output_col_names = [name for name in output_schema.column_schemas if input_col_name in name]
        if output_col_names:
            for output_col_name in output_col_names:
                result_schema = output_schema.column_schemas[output_col_name]

                expected_dtype = op._compute_dtype(
                    ColumnSchema(output_col_name),
                    Schema([input_schema.column_schemas[input_col_name]]),
                ).dtype

                expected_tags = op._compute_tags(
                    ColumnSchema(output_col_name),
                    Schema([input_schema.column_schemas[input_col_name]]),
                ).tags

                expected_properties = op._compute_properties(
                    ColumnSchema(output_col_name),
                    Schema([input_schema.column_schemas[input_col_name]]),
                ).properties

                assert result_schema.dtype == expected_dtype
                if output_col_name in selector.names:
                    assert result_schema.properties == expected_properties

                    assert len(result_schema.tags) == len(expected_tags)
                else:
                    # Derived (renamed) columns may gain extra tags; the expected
                    # ones must still all be present.
                    assert set(expected_tags).issubset(result_schema.tags)

    # Columns that were not selected must not appear in the output schema.
    not_used = [col for col in all_cols if col not in selector.names]
    for input_col_name in not_used:
        assert input_col_name not in output_schema.column_schemas
74 |
--------------------------------------------------------------------------------
/tests/unit/dag/ops/test_rename.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy as np
17 | import pandas as pd
18 | import pytest
19 |
20 | from merlin.core.compat import cudf
21 | from merlin.dag import ColumnSelector
22 | from merlin.dag.ops.rename import Rename
23 | from merlin.table import TensorTable
24 | from merlin.testing import assert_transformable_equal
25 |
# Run the op against every transformable container available in this environment.
transformables = [pd.DataFrame, TensorTable]
if cudf:
    transformables.append(cudf.DataFrame)
29 |
30 |
@pytest.mark.parametrize("transformable", transformables)
def test_rename(transformable):
    """Rename should support callable-, postfix-, and fixed-name-based renaming."""
    frame = transformable({"x": np.array([1, 2, 3, 4, 5]), "y": np.array([6, 7, 8, 9, 10])})

    both_columns = ColumnSelector(["x", "y"])
    x_only = ColumnSelector(["x"])

    # (selector, op, expected output columns) — checked in order.
    cases = [
        (
            both_columns,
            Rename(f=lambda name: name.upper()),
            {"X": np.array([1, 2, 3, 4, 5]), "Y": np.array([6, 7, 8, 9, 10])},
        ),
        (
            both_columns,
            Rename(postfix="_lower"),
            {
                "x_lower": np.array([1, 2, 3, 4, 5]),
                "y_lower": np.array([6, 7, 8, 9, 10]),
            },
        ),
        (x_only, Rename(name="z"), {"z": np.array([1, 2, 3, 4, 5])}),
        (x_only, Rename(f=lambda name: name.upper()), {"X": np.array([1, 2, 3, 4, 5])}),
    ]

    for selector, op, expected_data in cases:
        transformed = op.transform(selector, frame)
        assert_transformable_equal(transformed, transformable(expected_data))
63 |
--------------------------------------------------------------------------------
/tests/unit/dag/ops/test_selection.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | from merlin.dag import ColumnSelector
19 | from merlin.dag.ops.selection import SelectionOp
20 | from merlin.schema import ColumnSchema, Schema
21 |
22 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_selection_transform(df):
    # Transforming with a selection op should keep only the selected columns.
    op = SelectionOp(ColumnSelector(["x", "y"]))

    result = op.transform(ColumnSelector(), df)

    assert (result.columns == ["x", "y"]).all()
31 |
32 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_selection_output_column_names(df):
    # The op should report exactly the column names it was built with.
    selector = ColumnSelector(["x", "y"])

    result = SelectionOp(selector).output_column_names(selector)

    assert result.names == ["x", "y"]
41 |
42 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_selection_output_schema(df):
    # The output schema should be restricted to the selected columns.
    selector = ColumnSelector(["x", "y"])
    input_schema = Schema([ColumnSchema(col) for col in df.columns])

    result = SelectionOp(selector).compute_output_schema(input_schema, ColumnSelector())

    assert result.column_names == ["x", "y"]
52 |
53 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_selection_wildcard_output_schema(df):
    # A wildcard selector should pass the whole input schema through unchanged.
    input_schema = Schema([ColumnSchema(col) for col in df.columns])
    op = SelectionOp(ColumnSelector("*"))

    result = op.compute_output_schema(input_schema, ColumnSelector())

    assert result.column_names == input_schema.column_names
63 |
--------------------------------------------------------------------------------
/tests/unit/dag/ops/test_stat_op.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | from typing import Dict
17 |
18 | import dask.dataframe as dd
19 | import numpy as np
20 | import pandas as pd
21 | import pytest
22 |
23 | from merlin.core.compat import cudf
24 | from merlin.dag import ColumnSelector, Graph
25 | from merlin.dag.executors import DaskExecutor
26 | from merlin.dag.ops.stat_operator import StatOperator
27 | from merlin.io.dataset import Dataset
28 |
# Run each test against every dataframe type available in this environment.
transformables = [pd.DataFrame]
if cudf:
    transformables.append(cudf.DataFrame)
32 |
33 |
class FitOp(StatOperator):
    """Stat operator that records whether fit() has run exactly once since clear()."""

    def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame):
        # True only on the first fit after the stats were cleared.
        first_fit = "fit_exactly_once" not in self.stats
        self.stats: Dict[str, bool] = {"fit_exactly_once": first_fit}

    def fit_finalize(self, dask_stats):
        return dask_stats

    def clear(self):
        self.stats = {}
44 |
45 |
@pytest.mark.parametrize("transformable", transformables)
@pytest.mark.parametrize("engine", ["parquet"])
def test_fitted_stat_op(transformable, engine):
    # Fitting again with refit=False must not re-run the stat computation.
    frame = transformable({"x": np.array([1, 2, 3, 4, 5]), "y": np.array([6, 7, 8, 9, 10])})

    op = FitOp()
    graph = Graph(["x", "y"] >> op)
    executor = DaskExecutor()

    executor.fit(Dataset(frame), graph)
    assert op.stats == {"fit_exactly_once": True}

    # Second fit with refit=False: stats should be untouched.
    executor.fit(Dataset(frame), graph, refit=False)
    assert op.stats == {"fit_exactly_once": True}
61 |
62 |
@pytest.mark.parametrize("transformable", transformables)
@pytest.mark.parametrize("engine", ["parquet"])
def test_fit_op_before_transform(transformable, engine):
    """Transforming with an unfitted stat op should raise; fitting first succeeds."""
    # (Renamed from test_fit_op_before_transfrom — "transfrom" was a typo.)
    df = transformable({"x": np.array([1, 2, 3, 4, 5]), "y": np.array([6, 7, 8, 9, 10])})

    op = FitOp()
    graph = ["x", "y"] >> op
    graph = Graph(graph)
    executor = DaskExecutor()
    graph.construct_schema(Dataset(df).schema)

    # Using an unfitted stat op to transform data is an error.
    with pytest.raises(RuntimeError) as exc:
        executor.transform(Dataset(df).to_ddf(), graph)
    assert "attempting to use them to transform data" in str(exc.value)

    # After fitting, the same transform succeeds and fit ran exactly once.
    executor.fit(Dataset(df), graph)
    executor.transform(Dataset(df).to_ddf(), graph)
    assert op.stats == {"fit_exactly_once": True}
80 |
--------------------------------------------------------------------------------
/tests/unit/dag/ops/test_subgraph.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import pytest
18 |
19 | from merlin.core.protocols import Transformable
20 | from merlin.dag.executors import DaskExecutor, LocalExecutor
21 | from merlin.dag.graph import Graph
22 | from merlin.dag.operator import Operator
23 | from merlin.dag.ops.stat_operator import StatOperator
24 | from merlin.dag.ops.subgraph import Subgraph
25 | from merlin.dag.selector import ColumnSelector
26 | from merlin.io import Dataset
27 | from merlin.schema import Schema
28 |
29 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph(df):
    # A Subgraph op should run its wrapped graph inside the parent graph.
    inner_ops = ["x"] >> Operator() >> Operator()
    subgraph_op = Subgraph("subgraph", inner_ops)
    main_graph = Graph(["x", "y"] >> Operator() >> subgraph_op >> Operator())

    main_graph.construct_schema(Schema(list(df.columns)))

    result = LocalExecutor().transform(df, main_graph)
    # The inner graph selects only "x", so the output should equal that column.
    assert (result == df[["x"]]).all()[0]

    # The named subgraph should be retrievable from the parent graph.
    assert main_graph.subgraph("subgraph") == subgraph_op.graph
44 |
45 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph_fit(dataset):
    """Fitting a parent graph should also fit stat operators nested in a Subgraph."""

    class FitTestOp(StatOperator):
        # Minimal stat op that just records whether fit() was called.
        def fit(self, col_selector: ColumnSelector, ddf):
            self.stats = {"fit": True}

        def clear(self):
            self.stats = {}

        def fit_finalize(self, dask_stats):
            return self.stats

    fit_test_op = FitTestOp()
    subgraph_op = Subgraph("subgraph", ["x"] >> fit_test_op)
    main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator()

    main_graph = Graph(main_graph_ops)
    main_graph.construct_schema(dataset.schema)

    executor = DaskExecutor()
    executor.fit(dataset, main_graph)
    result_df = executor.transform(dataset.to_ddf(), main_graph)

    # The subgraph selects only "x", so the output should equal that column...
    assert (result_df.compute() == dataset.to_ddf().compute()[["x"]]).all()[0]
    # ...and the nested op must have been fitted via the parent graph's fit.
    assert main_graph.subgraph("subgraph").output_node.op.stats["fit"] is True
71 |
72 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_subgraph_looping(dataset):
    """A Subgraph with loop_until should re-run its graph until the predicate holds."""

    class LoopingTestOp(Operator):
        # Adds 1.0 to the selected columns on each pass through the subgraph.
        def transform(
            self, col_selector: ColumnSelector, transformable: Transformable
        ) -> Transformable:
            return transformable[col_selector.names] + 1.0

    subgraph = ["x"] >> LoopingTestOp()
    subgraph_op = Subgraph(
        "subgraph",
        subgraph,
        # Keep looping until every value in "x" exceeds 5.0.
        loop_until=lambda transformable: (transformable["x"] > 5.0).all(),
    )
    main_graph_ops = ["x", "y"] >> Operator() >> subgraph_op >> Operator()

    main_graph = Graph(main_graph_ops)
    main_graph.construct_schema(dataset.schema)

    # Zero out "x" so the subgraph must loop several times before the predicate holds.
    df = dataset.to_ddf().compute()
    df["x"] = df["x"] * 0.0
    dataset = Dataset(df)

    executor = DaskExecutor()
    executor.fit(dataset, main_graph)
    result_df = executor.transform(dataset.to_ddf(), main_graph)

    assert (result_df.compute()[["x"]] > 5.0).all()[0]
101 |
--------------------------------------------------------------------------------
/tests/unit/dag/test_base_operator.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | from merlin.dag.graph import Graph
19 | from merlin.dag.operator import Operator
20 | from merlin.dag.selector import ColumnSelector
21 | from merlin.schema import Schema
22 |
23 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_graph_validates_schemas(dataset, engine):
    # Requesting columns absent from the dataset schema should fail loudly.
    graph = Graph(["a", "b", "c"] >> Operator())

    with pytest.raises(ValueError) as exc_info:
        graph.construct_schema(dataset.schema)
    assert "Missing column" in str(exc_info.value)
33 |
34 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_compute_selector_validates_schemas(dataset, engine):
    # Selecting a column missing from the schema should raise a ValueError.
    schema = Schema(["a", "b"])
    missing_column = ColumnSelector(["c"])

    with pytest.raises(ValueError) as exc_info:
        Operator().compute_selector(schema, missing_column, ColumnSelector(), ColumnSelector())
    assert "Missing column" in str(exc_info.value)
45 |
46 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_compute_input_schema_validates_schemas(dataset, engine):
    # The selector must be validated against each of the three input schemas.
    op = Operator()
    schema = Schema(["a", "b"])
    selector = ColumnSelector(["c"])

    # Place the (incomplete) schema in each argument position in turn.
    schema_triples = [
        (schema, Schema(), Schema()),
        (Schema(), schema, Schema()),
        (Schema(), Schema(), schema),
    ]
    for root_schema, parents_schema, deps_schema in schema_triples:
        with pytest.raises(ValueError) as exc_info:
            op.compute_input_schema(root_schema, parents_schema, deps_schema, selector)
        assert "Missing column" in str(exc_info.value)
67 |
68 |
@pytest.mark.parametrize("engine", ["parquet"])
def test_compute_output_schema_validates_schemas(dataset, engine):
    # Output schema computation should also reject unknown selector columns.
    schema = Schema(["a", "b"])
    unknown_column = ColumnSelector(["c"])

    with pytest.raises(ValueError) as exc_info:
        Operator().compute_output_schema(schema, unknown_column)
    assert "Missing column" in str(exc_info.value)
79 |
--------------------------------------------------------------------------------
/tests/unit/dag/test_dag_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy as np
17 |
18 | from merlin.dag import group_values_offsets, ungroup_values_offsets
19 |
20 |
def test_flat_dict_to_tuple_dict():
    """Round-trip between flat `__values`/`__offsets` keys and grouped tuples."""
    scalar_col = np.array([1, 2, 3, 4, 5])
    ragged_values = np.array([6, 7, 8, 9, 10])
    ragged_offsets = np.array([0, 2, 5])

    flat = {
        "col1": scalar_col,
        "col2__values": ragged_values,
        "col2__offsets": ragged_offsets,
    }
    grouped = {"col1": scalar_col, "col2": (ragged_values, ragged_offsets)}

    assert ungroup_values_offsets(grouped) == flat
    assert group_values_offsets(flat) == grouped
32 |
--------------------------------------------------------------------------------
/tests/unit/dag/test_executors.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import numpy as np
18 | import pandas as pd
19 |
20 | from merlin.core.dispatch import make_df
21 | from merlin.dag import Graph
22 | from merlin.dag.executors import LocalExecutor
23 | from merlin.dag.operator import Operator
24 | from merlin.schema.schema import ColumnSchema, Schema
25 | from merlin.table import TensorTable
26 |
27 |
def test_local_executor_with_dataframe():
    """LocalExecutor.transform keeps only the columns selected by the graph.

    Works with either a pandas or a cudf frame from `make_df`; both columns
    are normalized to pandas Series before comparing.
    """
    df = make_df({"a": [1, 2, 3], "b": [4, 5, 6]})
    schema = Schema([ColumnSchema("a", dtype=np.int64), ColumnSchema("b", dtype=np.int64)])
    operator = ["a"] >> Operator()
    graph = Graph(operator)
    graph.construct_schema(schema)

    executor = LocalExecutor()
    result = executor.transform(df, [graph.output_node])

    result_a = result["a"].to_pandas() if not isinstance(result["a"], pd.Series) else result["a"]
    # Bug fix: the original fell back to `result["a"]` here instead of `df["a"]`,
    # which made the pandas path compare the result against itself.
    df_a = df["a"].to_pandas() if not isinstance(df["a"], pd.Series) else df["a"]

    assert all(result_a == df_a)
    assert "b" not in result.columns
43 |
44 |
def test_local_executor_with_dataframe_like():
    """LocalExecutor handles TensorTable inputs and drops unselected columns."""
    table = TensorTable(
        {"a": np.array([1, 2, 3], dtype=np.int64), "b": np.array([4, 5, 6], dtype=np.int64)}
    )
    schema = Schema([ColumnSchema("a", dtype=np.int64), ColumnSchema("b", dtype=np.int64)])
    graph = Graph(["a"] >> Operator())
    graph.construct_schema(schema)

    result = LocalExecutor().transform(table, [graph.output_node])

    assert result["a"] == table["a"]
    assert "b" not in result.columns
59 |
--------------------------------------------------------------------------------
/tests/unit/dispatch/test_lazy_dispatch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | from merlin.dispatch.lazy import lazy_singledispatch
19 |
20 | try:
21 | import tensorflow as tf
22 | except ImportError:
23 | tf = None
24 |
25 |
# Module-level dispatcher exercised by the test below: resolves an
# implementation based on the runtime type of its single argument.
return_type_name = lazy_singledispatch("return_type_name")


# Eagerly-registered handler for ints (selected via the type annotation).
@return_type_name.register
def return_int(arg: int):
    return "int"


# Eagerly-registered handler for floats.
@return_type_name.register
def return_float(arg: float):
    return "float"


# Lazy registration hook keyed on the "tensorflow" package: presumably invoked
# by lazy_singledispatch only once tensorflow is involved, deferring the heavy
# import — TODO confirm against merlin.dispatch.lazy.
@return_type_name.register_lazy("tensorflow")
def register_tf_to_array():
    import tensorflow as tf  # pylint:disable=reimported

    # Registered inside the hook so tf.Tensor is only referenced after import.
    @return_type_name.register(tf.Tensor)
    def return_tensor(arg: tf.Tensor):
        return "tensor"
46 |
47 |
@pytest.mark.skipif(not tf, reason="requires tensorflow")
def test_lazy_dispatch():
    """Dispatch hits the eager int/float handlers and the lazily-registered
    tensorflow handler; unregistered types raise NotImplementedError."""
    assert return_type_name(5) == "int"
    assert return_type_name(5.0) == "float"
    assert return_type_name(tf.constant([1, 2, 3, 4])) == "tensor"

    with pytest.raises(NotImplementedError) as exc:
        return_type_name("abc")
    assert "doesn't have a registered implementation" in str(exc.value)
62 |
--------------------------------------------------------------------------------
/tests/unit/dtypes/test_cudf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2023, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | import merlin.dtypes as md
19 | from merlin.core.compat import cudf
20 |
21 |
@pytest.mark.skipif(not cudf, reason="CUDF is required to test its dtypes")
def test_cudf_struct_dtype():
    """Struct dtypes translate in both directions between cudf and Merlin."""
    # cudf -> merlin
    assert md.dtype(cudf.StructDtype({"a": "int64", "b": "string"})) == md.struct

    # merlin -> cudf
    assert md.struct.to("cudf") == cudf.StructDtype
31 |
--------------------------------------------------------------------------------
/tests/unit/dtypes/test_module.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import numpy
17 | import pytest
18 |
19 | import merlin.dtypes as md
20 |
21 |
@pytest.mark.parametrize("python_type, merlin_type", [(int, md.int64)])
def test_python_types_convert_correctly(python_type, merlin_type):
    """Built-in Python types map onto the expected Merlin dtypes."""
    converted = md.dtype(python_type)
    assert converted == merlin_type
25 |
26 |
@pytest.mark.parametrize("numpy_type, merlin_type", [(numpy.int64, md.int64)])
def test_numpy_types_convert_correctly(numpy_type, merlin_type):
    """NumPy scalar types map onto the expected Merlin dtypes."""
    converted = md.dtype(numpy_type)
    assert converted == merlin_type
30 |
31 |
@pytest.mark.parametrize(
    "dtype_name, dtype",
    [
        ("int32", md.int32),
        ("float64", md.float64),
        ("string", md.string),
        ("struct", md.struct),
    ],
)
def test_string_aliases_can_be_used(dtype_name, dtype):
    """String names resolve to the corresponding Merlin dtype objects."""
    resolved = md.dtype(dtype_name)
    assert resolved == dtype
43 |
44 |
def test_type_mappings_can_be_registered():
    """Registering a custom mapping lets md.dtype translate an external type."""

    class ExternalType:
        pass

    custom_dtype = md.DType("test", md.ElementType.Int, 4096, signed=True)
    md.register("test", {custom_dtype: ExternalType})

    assert md.dtype(ExternalType) == custom_dtype
54 |
55 |
def test_unknown_types_return_unknown():
    """Types with no registered mapping resolve to md.unknown rather than raising."""

    class NeverRegistered:
        pass

    assert md.dtype(NeverRegistered) == md.unknown
62 |
--------------------------------------------------------------------------------
/tests/unit/io/test_avro.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import os
17 |
18 | import numpy as np
19 | import pandas as pd
20 | import pytest
21 | from dask.dataframe import assert_eq
22 | from dask.dataframe.io.demo import names as name_list
23 |
24 | import merlin.io
25 | from merlin.core.compat import cudf
26 |
if cudf:
    dask_cudf = pytest.importorskip("dask_cudf")
else:
    # Bug fix: the original bare `pytest.mark.skip(reason=...)` expression was
    # a no-op — constructing a mark at module level skips nothing. Use
    # `pytest.skip` with `allow_module_level=True` to actually skip this
    # module when cudf is unavailable, matching the apparent intent.
    pytest.skip("cudf did not import successfully", allow_module_level=True)
# Require uavro and fastavro library.
# Note that fastavro is only required to write
# avro files for testing, while uavro is actually
# used by AvroDatasetEngine.
fa = pytest.importorskip("fastavro")
pytest.importorskip("uavro")
37 |
38 |
@pytest.mark.parametrize("part_size", [None, "1KB"])
@pytest.mark.parametrize("size", [100, 5000])
@pytest.mark.parametrize("nfiles", [1, 2])
def test_avro_basic(tmpdir, part_size, size, nfiles):
    """Write avro files with fastavro, read them back through
    merlin.io.Dataset, and compare row counts, partition counts, and content.
    """
    # Define avro schema
    schema = fa.parse_schema(
        {
            "name": "avro.example.User",
            "type": "record",
            "fields": [
                {"name": "name", "type": "string"},
                {"name": "age", "type": "int"},
            ],
        }
    )

    # Write avro dataset with two files.
    # Collect block and record (row) count while writing.
    nblocks = 0
    nrecords = 0
    paths = [os.path.join(str(tmpdir), f"test.{i}.avro") for i in range(nfiles)]
    records = []
    for path in paths:
        names = np.random.choice(name_list, size)
        ages = np.random.randint(18, 100, size)
        data = [{"name": names[i], "age": ages[i]} for i in range(size)]
        with open(path, "wb") as f:
            fa.writer(f, schema, data)
        # Re-read what was written to establish the expected block/record
        # counts and the reference records for the final comparison.
        with open(path, "rb") as fo:
            avro_reader = fa.block_reader(fo)
            for block in avro_reader:
                nrecords += block.num_records
                nblocks += 1
                records += list(block)
    # Exercise the single-path (non-list) API when there is only one file.
    if nfiles == 1:
        paths = paths[0]

    # Read back with dask.dataframe
    df = merlin.io.Dataset(paths, part_size=part_size, engine="avro").to_ddf()

    # Check basic length and partition count
    # With a tiny 1KB part size, each avro block becomes its own partition.
    if part_size == "1KB":
        assert df.npartitions == nblocks
    assert len(df) == nrecords

    # Full comparison
    # The avro "int" field reads back as int32, so cast the pandas reference.
    expect = pd.DataFrame.from_records(records)
    expect["age"] = expect["age"].astype("int32")
    assert_eq(df.compute().reset_index(drop=True), expect)
88 |
--------------------------------------------------------------------------------
/tests/unit/schema/test_tags.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import pytest
17 |
18 | from merlin.schema.tags import COMPOUND_TAGS, Tags, TagSet
19 |
20 |
def test_tagset_init_normalizes_tags_to_enum():
    """String tag names matching known Tags are converted to enum members."""
    tag_set = TagSet(["continuous", "list", "custom_tag"])

    assert Tags.CONTINUOUS in tag_set._tags
    assert Tags.LIST in tag_set._tags
26 |
27 |
def test_tagset_init_collision_error():
    """Constructing a TagSet with mutually-exclusive tags raises a ValueError."""
    with pytest.raises(ValueError) as err:
        TagSet(["continuous", "categorical"])

    message = str(err.value)
    assert "continuous" in message
    assert "categorical" in message
    assert "incompatible" in message
35 |
36 |
def test_tagset_is_iterable():
    """A TagSet supports both iteration and len()."""
    source_tags = ["continuous", "list"]
    tag_set = TagSet(source_tags)

    assert len(tag_set) == len(source_tags)
    for member in tag_set:
        assert member.value in source_tags
43 |
44 |
def test_tagset_add():
    """TagSet addition accepts a bare string, a list, or another TagSet."""
    base_tags = [Tags.CONTINUOUS, Tags.LIST, "custom_tag"]
    base = TagSet(base_tags)

    # Each supported right-hand-side form produces the same combined set.
    for addition in ("custom_tag2", ["custom_tag2"], TagSet(["custom_tag2"])):
        combined = base + addition
        assert len(combined) == 4
        assert "custom_tag2" in combined
        assert all(tag in combined for tag in base_tags)
68 |
69 |
def test_tagset_sub():
    """TagSet subtraction accepts a bare string, a list, or another TagSet."""
    base_tags = [Tags.CONTINUOUS, Tags.LIST]
    base = TagSet(base_tags + ["custom_tag"])
    assert len(base) == 3

    # Each supported right-hand-side form removes the same tag.
    for removal in ("custom_tag", ["custom_tag"], TagSet(["custom_tag"])):
        remaining = base - removal
        assert len(remaining) == 2
        assert "custom_tag" not in remaining
        assert all(tag in remaining for tag in base_tags)
94 |
95 |
def test_tagset_add_collision_error():
    """Adding a tag incompatible with an existing one raises, for every RHS form."""
    tag_set = TagSet(["continuous", "list", "custom_tag"])

    for addition in ("categorical", ["categorical"], TagSet(["categorical"])):
        with pytest.raises(ValueError) as err:
            tag_set + addition  # pylint: disable=W0104

        message = str(err.value)
        assert "continuous" in message
        assert "categorical" in message
        assert "incompatible" in message
126 |
127 |
def test_tagset_atomizes_compound_tags():
    """Compound tags are replaced by their atomic constituents on construction."""
    for compound, atoms in COMPOUND_TAGS.items():
        tag_set = TagSet([compound])
        assert compound not in tag_set
        assert all(atom in tag_set for atom in atoms)
134 |
--------------------------------------------------------------------------------
/tests/unit/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022, NVIDIA CORPORATION.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import pytest
18 |
19 | from merlin.core.compat import HAS_GPU
20 | from merlin.core.utils import Distributed, Serial, global_dask_client, set_dask_client
21 |
# Parametrization values for the `cpu` flag: only exercise the GPU
# (cpu=False) path when a GPU is actually available.
if HAS_GPU:
    _CPU = [True, False]
else:
    _CPU = [True]
26 |
27 |
@pytest.mark.parametrize("cpu", _CPU)
def test_serial_context(client, cpu):
    """`Serial` temporarily clears the global dask client, then restores it."""
    set_dask_client(client=client)
    assert global_dask_client() == client

    with Serial():
        # Inside the block, no global client is visible.
        assert global_dask_client() is None

    # On exit, the previous client is back in place.
    assert global_dask_client() == client
42 |
43 |
@pytest.mark.parametrize("cpu", _CPU)
@pytest.mark.parametrize("nested_serial", [True, False])
def test_nvt_distributed(cpu, nested_serial):
    """`Distributed` deploys a fresh local cluster when no global client is set,
    registers it globally for the duration of the block, and reverts the global
    client to None on exit. A nested `Serial` block temporarily disables the
    client again while inside `Distributed`.
    """
    if cpu:
        distributed = pytest.importorskip("distributed")
        cluster_type = "cpu"
        cluster_cls = distributed.LocalCluster
    else:
        dask_cuda = pytest.importorskip("dask_cuda")
        cluster_type = "cuda"
        cluster_cls = dask_cuda.LocalCUDACluster

    # Set the global client to None
    set_dask_client(client=None)
    assert global_dask_client() is None

    # Check that a new local cluster is deployed within
    # a `with Distributed()` block
    with Distributed(cluster_type=cluster_type, n_workers=1, force_new=True) as dist:
        assert dist.client is not None
        assert global_dask_client() == dist.client
        assert len(dist.cluster.workers) == 1
        assert isinstance(dist.cluster, cluster_cls)

        # Check that we can nest a `with Serial()` block
        # inside a `with Distributed()` block
        if nested_serial:
            with Serial():
                assert global_dask_client() is None
            # Leaving `Serial` restores the Distributed block's client.
            assert global_dask_client() == dist.client

    # Global client should revert to None outside
    # the `with Distributed()` block
    assert global_dask_client() is None
78 |
79 |
@pytest.mark.parametrize("cpu", _CPU)
def test_nvt_distributed_force(client, cpu):
    """With `force_new=True`, `Distributed` deploys a fresh cluster even when a
    global client already exists; without it, the existing client is reused and
    a UserWarning is emitted. The original client is restored after each block.
    """
    if cpu:
        distributed = pytest.importorskip("distributed")
        cluster_type = "cpu"
        cluster_cls = distributed.LocalCluster
    else:
        dask_cuda = pytest.importorskip("dask_cuda")
        cluster_type = "cuda"
        cluster_cls = dask_cuda.LocalCUDACluster

    # Set distributed client
    set_dask_client(client=client)
    assert global_dask_client() == client

    # Check that a new local cluster is deployed within
    # a `with Distributed()` block. Since we are using
    # `force_new=True`, the new cluster should NOT be
    # the same as the original `client`.
    with Distributed(cluster_type=cluster_type, force_new=True, n_workers=1) as dist:
        assert dist.client != client
        assert global_dask_client() == dist.client
        assert len(dist.cluster.workers) == 1
        assert isinstance(dist.cluster, cluster_cls)

    # We should revert to the original client
    # outside the `with Distributed()` block
    assert global_dask_client() == client

    # Check that the default behavior is to avoid
    # deploying a new cluster (and warning the user)
    # if an existing client is detected
    with pytest.warns(UserWarning):
        with Distributed(cluster_type=cluster_type, n_workers=1) as dist:
            assert dist.client == client
            assert global_dask_client() == dist.client

    # We should revert to the original client
    # outside the `with Distributed()` block
    assert global_dask_client() == client
120 |
--------------------------------------------------------------------------------