├── .azure
│   ├── docker-build.yml
│   ├── gpu-coverage.yml
│   ├── gpu-tests.yml
│   └── notebook-runs.yml
├── .codecov.yml
├── .git-blame-ignore-revs
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── documentation.md
│   │   ├── feature_request.md
│   │   └── program_coverage.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   ├── labeling-config.yml
│   ├── lightning-probot.yml
│   ├── load-jit-coverage-report.py
│   ├── load-quickstart-report.py
│   ├── run-benchmark-as-lit-jobs.py
│   ├── run-jit-coverage-as-lit-studios.py
│   ├── run-quickstart-as-lit-jobs.py
│   └── workflows
│       ├── auto-cc.yml
│       ├── ci-benchmark.yml
│       ├── ci-checks.yml
│       ├── ci-jit-coverage.yml
│       ├── ci-quickstart.yml
│       ├── ci-testing.yml
│       ├── docs-build.yml
│       ├── label-conflicts.yml
│       ├── labeler.yml
│       ├── release-nightly.yml
│       ├── release-pypi.yml
│       └── stale-branches.yaml
├── .gitignore
├── .lightning
│   └── workflows
│       ├── all-tests.yaml
│       ├── notebooks.yaml
│       └── transformer-engine.yaml
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── dockers
│   ├── README.md
│   ├── ubuntu-cuda
│   │   └── Dockerfile
│   ├── with-apex
│   │   └── Dockerfile
│   └── with-dev
│       └── Dockerfile
├── docs
│   ├── .build_docs.sh
│   ├── .readthedocs.yaml
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── _static
│       │   ├── copybutton.js
│       │   └── images
│       │       ├── LightningThunderDarkModewByline.png
│       │       ├── LightningThunderLightModewByline.png
│       │       ├── how_it_works.png
│       │       ├── icon.svg
│       │       ├── lightning_thunder_lightmode_nobyline.png
│       │       ├── logo-large.svg
│       │       ├── logo-small.svg
│       │       ├── logo.png
│       │       ├── logo.svg
│       │       ├── normalized_training_throughput_zero2.png
│       │       ├── pretrain_perf.png
│       │       └── training_throughput_single.png
│       ├── _templates
│       │   └── theme_variables.jinja
│       ├── advanced
│       │   ├── contrib_thunder.rst
│       │   ├── extending.rst
│       │   └── inside_thunder.rst
│       ├── basic
│       │   ├── faq.rst
│       │   ├── inspecting_traces.rst
│       │   ├── mlp_mnist.rst
│       │   ├── overview.rst
│       │   └── sharp_edges.rst
│       ├── conf.py
│       ├── fundamentals
│       │   ├── examine.rst
│       │   ├── hello_world.rst
│       │   └── installation.rst
│       ├── index.rst
│       ├── intermediate
│       │   ├── additional_executors.rst
│       │   ├── benchmarking.rst
│       │   ├── ddp.rst
│       │   └── whats_next.rst
│       └── reference
│           ├── clang
│           │   └── index.rst
│           ├── common
│           │   └── index.rst
│           ├── core
│           │   ├── baseutils.rst
│           │   ├── codeutils.rst
│           │   ├── devices.rst
│           │   ├── dtypes.rst
│           │   ├── index.rst
│           │   ├── langctxs.rst
│           │   ├── prims.rst
│           │   ├── proxies.rst
│           │   ├── pytree.rst
│           │   ├── rematerialization.rst
│           │   ├── symbol.rst
│           │   ├── trace.rst
│           │   └── transforms.rst
│           ├── distributed
│           │   └── index.rst
│           ├── dynamo
│           │   └── index.rst
│           ├── examine
│           │   └── index.rst
│           ├── executors
│           │   ├── apexex.rst
│           │   ├── cudnnex.rst
│           │   ├── index.rst
│           │   ├── nvfuserex.rst
│           │   ├── passes.rst
│           │   ├── pythonex.rst
│           │   ├── torch_compile.rst
│           │   ├── torchex.rst
│           │   ├── triton_crossentropy.rst
│           │   └── utils.rst
│           ├── extend
│           │   └── index.rst
│           ├── plugins
│           │   └── index.rst
│           ├── recipes
│           │   └── index.rst
│           ├── thunder.rst
│           ├── torch
│           │   └── index.rst
│           └── transforms
│               └── index.rst
├── examples
│   ├── coverage
│   │   ├── all.txt
│   │   ├── jit_coverage_hf.py
│   │   └── requirements.txt
│   └── quickstart
│       ├── hf_bert.py
│       ├── hf_llm.py
│       ├── mlp.py
│       ├── requirements.txt
│       ├── vit_hf.py
│       └── vit_tv.py
├── notebooks
│   ├── .ignore.ci
│   ├── adding_custom_operator.ipynb
│   ├── adding_custom_operator_backward.ipynb
│   ├── dev_tutorials
│   │   ├── extend.ipynb
│   │   ├── fsdp_tutorial.ipynb
│   │   └── thunder-add-vjp-rule.md
│   ├── extend_thunder_with_cuda_python.ipynb
│   ├── hello_world_thunderfx.ipynb
│   ├── liger_kernel.ipynb
│   ├── thunder_trace_intro.ipynb
│   ├── writing_a_trace_transform_cpu_offloading.ipynb
│   └── zero_to_thunder.ipynb
├── pyproject.toml
├── requirements.txt
├── requirements
│   ├── base.txt
│   ├── coverage.txt
│   ├── devel.txt
│   ├── docs.txt
│   ├── notebooks.txt
│   └── test.txt
├── scripts
│   ├── bisect_nvfuser.py
│   ├── build_from_source.sh
│   ├── remove-torch-lines.sh
│   ├── sanity-check.sh
│   └── validate_build.py
├── setup.py
└── thunder
    ├── __about__.py
    ├── __init__.py
    ├── benchmarks
    │   ├── __init__.py
    │   ├── benchmark_hf.py
    │   ├── benchmark_litgpt.py
    │   ├── benchmark_peft.py
    │   ├── conftest.py
    │   ├── distributed.py
    │   ├── einsum.py
    │   ├── targets.py
    │   ├── test_benchmark_litgpt.py
    │   └── utils.py
    ├── clang
    │   ├── __init__.py
    │   ├── langctx.py
    │   └── utils.py
    ├── common.py
    ├── core
    │   ├── __init__.py
    │   ├── baseutils.py
    │   ├── codeutils.py
    │   ├── compile_data.py
    │   ├── devices.py
    │   ├── dtypes.py
    │   ├── interpreter.py
    │   ├── jit_ext.py
    │   ├── langctxs.py
    │   ├── module.py
    │   ├── options.py
    │   ├── patterns.py
    │   ├── prims.py
    │   ├── profile.py
    │   ├── proxies.py
    │   ├── pytree.py
    │   ├── recipe.py
    │   ├── rematerialization.py
    │   ├── symbol.py
    │   ├── trace.py
    │   ├── trace_interpreter.py
    │   ├── transform_common.py
    │   ├── transforms.py
    │   ├── update_aliases.py
    │   ├── utils.py
    │   └── vjp_utils.py
    ├── dev_utils
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── check_trace.py
    │   ├── debug_transform.py
    │   ├── nvtx_profile_transform.py
    │   └── utils.py
    ├── distributed
    │   ├── __init__.py
    │   ├── bucketing.py
    │   ├── checkpoint.py
    │   ├── prims.py
    │   ├── tensor_parallel
    │   │   ├── __init__.py
    │   │   ├── column_wise.py
    │   │   ├── common.py
    │   │   ├── optimize_comm.py
    │   │   └── row_wise.py
    │   ├── transforms
    │   │   ├── __init__.py
    │   │   ├── ddp.py
    │   │   ├── ddp_v2.py
    │   │   ├── fsdp.py
    │   │   └── fsdp_v2.py
    │   └── utils.py
    ├── dynamo
    │   ├── __init__.py
    │   ├── benchmark_utils.py
    │   ├── compiler.py
    │   ├── compiler_graph_benchmark.py
    │   ├── report.py
    │   ├── repro_script_template.py
    │   ├── splitter.py
    │   └── utils.py
    ├── examine
    │   ├── __init__.py
    │   └── memory_calculation.py
    ├── executors
    │   ├── __init__.py
    │   ├── apex_entropyex_impl.py
    │   ├── apex_fused_rms_norm_impl.py
    │   ├── apexex.py
    │   ├── cudnn_layernormex.py
    │   ├── cudnn_sdpa.py
    │   ├── cudnnex.py
    │   ├── custom_op_ex.py
    │   ├── data_dependent_partition.py
    │   ├── fa3ex.py
    │   ├── nvfuserex.py
    │   ├── nvfuserex_impl.py
    │   ├── passes.py
    │   ├── pythonex.py
    │   ├── sdpaex.py
    │   ├── torch_autograd.py
    │   ├── torch_compile.py
    │   ├── torchex.py
    │   ├── transformer_engine_v1ex.py
    │   ├── transformer_engineex.py
    │   ├── transformer_engineex_impl.py
    │   ├── triton_crossentropy.py
    │   ├── triton_crossentropy_impl.py
    │   ├── triton_utils.py
    │   └── utils.py
    ├── extend
    │   └── __init__.py
    ├── numpy
    │   ├── __init__.py
    │   └── langctx.py
    ├── plugins
    │   ├── __init__.py
    │   ├── distributed.py
    │   ├── fp8.py
    │   ├── quantization.py
    │   └── reduce_overhead.py
    ├── py.typed
    ├── recipes
    │   ├── __init__.py
    │   ├── base.py
    │   └── hf_transformers.py
    ├── tests
    │   ├── README.md
    │   ├── __init__.py
    │   ├── bf16.py
    │   ├── conftest.py
    │   ├── coverage_tests
    │   │   ├── __init__.py
    │   │   ├── test_coverage_hf_diffusers.py
    │   │   └── test_coverage_trace.py
    │   ├── distributed
    │   │   ├── __init__.py
    │   │   ├── helper.py
    │   │   ├── modules.py
    │   │   ├── test_checkpoint.py
    │   │   ├── test_ddp.py
    │   │   ├── test_dtensor.py
    │   │   ├── test_fsdp.py
    │   │   ├── test_ops.py
    │   │   └── test_tensor_parallel.py
    │   ├── framework.py
    │   ├── hf_bart_self_attn.py
    │   ├── litgpt_model.py
    │   ├── llama2_model.py
    │   ├── make_tensor.py
    │   ├── module_example.py
    │   ├── nanogpt_model.py
    │   ├── opinfos.py
    │   ├── test_apex_cross_entropy_executor.py
    │   ├── test_apex_fused_norms.py
    │   ├── test_auto_register_torchops.py
    │   ├── test_autocast.py
    │   ├── test_check_trace.py
    │   ├── test_core.py
    │   ├── test_cudnn_executor.py
    │   ├── test_dynamo.py
    │   ├── test_einops.py
    │   ├── test_elementwise.py
    │   ├── test_examine.py
    │   ├── test_examine_memory.py
    │   ├── test_extend.py
    │   ├── test_fa3_executor.py
    │   ├── test_grad.py
    │   ├── test_inplace_copy.py
    │   ├── test_interpreter.py
    │   ├── test_jit_general.py
    │   ├── test_networks.py
    │   ├── test_nvfuser.py
    │   ├── test_nvfuser_remat.py
    │   ├── test_ops.py
    │   ├── test_patterns.py
    │   ├── test_pythonex.py
    │   ├── test_randomness.py
    │   ├── test_recipes.py
    │   ├── test_reductions.py
    │   ├── test_sdpaex_executor.py
    │   ├── test_shape_ops.py
    │   ├── test_torch_compile_executor.py
    │   ├── test_torch_library_custom_op.py
    │   ├── test_torch_library_custom_op_with_lists.py
    │   ├── test_transformer_engine_executor.py
    │   ├── test_transformer_engine_v1_executor.py
    │   ├── test_transforms.py
    │   ├── test_triton_ce.py
    │   ├── test_update_aliases.py
    │   └── utils.py
    ├── torch
    │   ├── __init__.py
    │   ├── custom_op.py
    │   ├── default_torch_ops.py
    │   ├── experimental
    │   │   ├── __init__.py
    │   │   ├── dtensor_codeutils.py
    │   │   ├── dtensor_proxy.py
    │   │   ├── dtensor_torch_and_prims.py
    │   │   └── dtensor_utils.py
    │   └── langctx.py
    └── transforms
        ├── __init__.py
        ├── autocast.py
        ├── autodiff.py
        ├── constant_folding.py
        ├── cudagraph.py
        ├── extraction_only_prologue_transform.py
        ├── materialization.py
        ├── prune_prologue_checks.py
        ├── qlora.py
        ├── quantization.py
        └── utils.py
/.azure/gpu-coverage.yml:
--------------------------------------------------------------------------------
1 | trigger:
2 | tags:
3 | include: ["*"]
4 | paths:
5 | include:
6 | - ".azure/gpu-coverage.yml"
7 | - "requirements/coverage.txt"
8 | - "thunder/tests/coverage_tests/**"
9 | branches:
10 | include:
11 | - "main"
12 | - "release/*"
13 | - "refs/tags/*"
14 |
15 | pr:
16 | branches:
17 | include: ["*"]
18 |
19 | jobs:
20 | - job: coverage
21 | strategy:
22 | matrix:
23 | "w/ torch 2.7.1":
24 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.7.1-dev"
25 | # how much time to give 'run always even if cancelled tasks' before stopping them
26 | cancelTimeoutInMinutes: "2"
27 | pool: "lit-rtx-3090"
28 | variables:
29 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' )
30 | TORCH_HOME: "/var/tmp/torch"
31 | PIP_CACHE_DIR: "/var/tmp/pip"
32 | PYTHONHASHSEED: "0"
33 | NCCL_DEBUG: "INFO"
34 | ALLOW_COVERAGE_TRACE: "1"
35 | container:
36 | image: "pytorchlightning/lightning-thunder:$(docker-image)"
37 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp"
38 | workspace:
39 | clean: all
40 | steps:
41 | - bash: |
42 | echo $(DEVICES)
43 | lspci | egrep 'VGA|3D'
44 | dpkg-query -W -f='${Package} ${Version}\n' libnccl2 libnccl-dev
45 | whereis nvidia
46 | nvidia-smi
47 | which python && which pip
48 | python --version
49 | pip --version
50 | pip list
51 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
52 | displayName: "Image info & NVIDIA"
53 |
54 | - bash: |
55 | set -ex
56 | # drop torch from requirements so as not to interfere with the preinstalled one
57 | bash scripts/remove-torch-lines.sh requirements/base.txt
58 | cat requirements/base.txt
59 |
60 | # double check on test requirements
61 | pip install -U -r requirements/base.txt -r requirements/coverage.txt
62 |
63 | # https://docs.codecov.com/docs/codecov-uploader
64 | curl -Os https://uploader.codecov.io/latest/linux/codecov
65 | chmod +x codecov
66 |
67 | # install this package
68 | python setup.py develop
69 | displayName: "Install package & ..."
70 |
71 | - bash: bash scripts/sanity-check.sh
72 | displayName: "Sanity check / details"
73 |
74 | - bash: |
75 | PYTHONPATH=$(pwd)/thunder/tests pytest thunder/tests/coverage_tests
76 | timeoutInMinutes: "45"
77 | displayName: "Testing: coverage_tests"
78 |
--------------------------------------------------------------------------------
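Note on the `DEVICES` variable above: it derives the visible GPU indices from the Azure agent name, taking everything after the last underscore and defaulting to GPU `0`. A minimal sketch of that rule (the sample agent names are hypothetical):

```python
# Sketch of the DEVICES expression above: GPU ids come from the suffix of the
# agent name after the last underscore, defaulting to "0". Sample names are made up.
def devices_from_agent_name(name: str) -> str:
    return name.split("_")[-1] if "_" in name else "0"

assert devices_from_agent_name("lit-rtx-3090_0,1") == "0,1"
assert devices_from_agent_name("lit-rtx-3090") == "0"  # no suffix -> first GPU
```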
/.azure/notebook-runs.yml:
--------------------------------------------------------------------------------
1 | trigger:
2 | tags:
3 | include: ["*"]
4 | branches:
5 | include:
6 | - "main"
7 | - "release/*"
8 | - "refs/tags/*"
9 |
10 | pr:
11 | branches:
12 | include: ["*"]
13 |
14 | jobs:
15 | - job: jupyter
16 | strategy:
17 | matrix:
18 | "notebooks w/ torch 2.8":
19 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev"
20 | "notebooks w/ torch-nightly":
21 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev"
22 | # how long to run the job before automatically cancelling
23 | timeoutInMinutes: "45"
24 | # how much time to give 'run always even if cancelled tasks' before stopping them
25 | cancelTimeoutInMinutes: "2"
26 | pool: "lit-rtx-3090"
27 | variables:
28 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' )
29 | TORCH_HOME: "/var/tmp/torch"
30 | PIP_CACHE_DIR: "/var/tmp/pip"
31 | container:
32 | image: "pytorchlightning/lightning-thunder:$(docker-image)"
33 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp"
34 | workspace:
35 | clean: all
36 | steps:
37 | - bash: |
38 | echo $(DEVICES)
39 | lspci | egrep 'VGA|3D'
40 | whereis nvidia
41 | nvidia-smi
42 | which python && which pip
43 | python --version
44 | pip --version
45 | pip list
46 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
47 | displayName: "Image info & NVIDIA"
48 |
49 | - bash: |
50 | set -ex
51 | # drop torch from requirements so as not to interfere with the preinstalled one
52 | bash scripts/remove-torch-lines.sh requirements/base.txt
53 | cat requirements/base.txt
54 | pip install -U -r requirements/notebooks.txt
55 | # install this package
56 | python setup.py develop
57 | # double check on test requirements
58 | echo "Install special requirements for notebooks"
59 | displayName: "Install package & ..."
60 |
61 | - bash: |
62 | set -ex
63 | pip list
64 | bash scripts/sanity-check.sh
65 | displayName: "Sanity check / details"
66 |
67 | - bash: |
68 | set -ex
69 | # list all notebooks in this folder
70 | find . -name "*.ipynb" > all.txt
71 | # drop all "./" from beginning of each line
72 | sed -i 's/^\.\///' all.txt
73 | # filter out the ones that are listed in .ignore.ci
74 | grep -Fxv -f .ignore.ci all.txt > ci.txt
75 | # iterate over all listed notebooks and execute them with jupyter
76 | while read -r line; do
77 | echo "Processing $line"
78 | jupyter execute $line --timeout=300
79 | done <<< $(cat ci.txt)
80 | workingDirectory: "notebooks/"
81 | displayName: "Execute notebooks"
82 |
--------------------------------------------------------------------------------
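The "Execute notebooks" step selects which notebooks to run with a small `find`/`sed`/`grep -Fxv` pipeline: collect every `*.ipynb`, strip the leading `./`, and drop exact matches listed in `.ignore.ci`. A rough Python equivalent of that selection logic, assuming it runs from the repository root:

```python
from pathlib import Path

def notebooks_to_run(root: str = "notebooks") -> list[str]:
    """Rough equivalent of the find/sed/grep -Fxv pipeline above."""
    root_path = Path(root)
    all_nbs = sorted(str(p.relative_to(root_path)) for p in root_path.rglob("*.ipynb"))
    ignored = set((root_path / ".ignore.ci").read_text().splitlines())
    # grep -Fxv keeps only lines with no exact (full-line, fixed-string) match
    return [nb for nb in all_nbs if nb not in ignored]
```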
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 | # https://docs.codecov.io/docs/codecovyml-reference
6 | codecov:
7 | bot: "codecov-io"
8 | strict_yaml_branch: "yaml-config"
9 | require_ci_to_pass: yes
10 | notify:
11 | # after_n_builds: 2
12 | wait_for_ci: yes
13 |
14 | coverage:
15 | precision: 0 # 2 = xx.xx%, 0 = xx%
16 | round: nearest # how coverage is rounded: down/up/nearest
17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green
18 | status:
19 | # https://codecov.readme.io/v1.0/docs/commit-status
20 | project:
21 | default:
22 | informational: true
23 | target: 99% # specify the target coverage for each commit status
24 | threshold: 30% # allow the project coverage to drop by this much
25 | # https://github.com/codecov/support/wiki/Filtering-Branches
26 | # branches: main
27 | if_ci_failed: error
28 | # https://github.com/codecov/support/wiki/Patch-Status
29 | patch:
30 | default:
31 | informational: true
32 | target: 50% # specify the target "X%" coverage to hit
33 | # threshold: 50% # allow this much decrease on patch
34 | changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 | annotations: false
39 |
40 | parsers:
41 | gcov:
42 | branch_detection:
43 | conditional: true
44 | loop: true
45 | macro: false
46 | method: false
47 | javascript:
48 | enable_partials: false
49 |
50 | comment: false
51 |
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # Move optimize_allreduce_in_ddp_backward to thunder.distributed.transforms namespace #2087
2 | 7a3d5d4e64da010eea42210a065960e3568870ac
3 | # Split `test_ddp` into `test_ddp`, `test_fsdp` and `test_ops` (#772)
4 | b63e8713141044ed8c184f2f4c8305048d477319
5 | # Enable `ruff-check` in pre-commit (#2192)
6 | ae51153db2e623e8b9fb3609de3a6ad976addd0a
7 | # Enable ruff format in pre-commit (#2142)
8 | 0c6f9a91f1ac955bd5c1087ae26d120d7ab184a3
9 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Each line is a file pattern followed by one or more owners.
2 |
3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
4 | # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request.
5 |
6 | # Thank you to our previous code owners for their service:
7 | # @carmocca @borda
8 |
9 | * @mruberry @lantiga @t-vi @KaelanDt
10 |
11 | # CI/CD and configs
12 | /.azure/ @KaelanDt @lantiga @t-vi
13 | /.github/ @KaelanDt @lantiga @t-vi
14 | /.lightning/ @KaelanDt @lantiga @t-vi
15 | /dockers/ @KaelanDt @lantiga @t-vi
16 | Makefile @KaelanDt @lantiga @t-vi
17 | *.yml @KaelanDt @lantiga @t-vi
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | assignees: ''
6 | ---
7 |
8 | *Note*: If you have a model or program that is not supported yet but should be, please use the program coverage template.
9 |
10 | ## 🐛 Bug
11 |
12 |
13 |
14 | ### To Reproduce
15 |
16 | Steps to reproduce the behavior:
17 |
18 | 1. Go to '...'
19 | 1. Run '....'
20 | 1. Scroll down to '....'
21 | 1. See error
22 |
23 |
24 |
25 | #### Code sample
26 |
27 |
29 |
30 | ### Expected behavior
31 |
32 |
33 |
34 | ### Environment
35 |
36 | - PyTorch Version (e.g., 1.0):
37 | - OS (e.g., Linux):
38 | - How you installed PyTorch (`conda`, `pip`, source):
39 | - Build command you used (if compiling from source):
40 | - Python version:
41 | - CUDA/cuDNN version:
42 | - GPU models and configuration:
43 | - Any other relevant information:
44 |
45 | ### Additional context
46 |
47 |
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos and doc fixes, please go ahead and:
12 |
13 | 1. Create an issue.
14 | 1. Fix the typo.
15 | 1. Submit a PR.
16 |
17 | Thanks!
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Feature
10 |
11 |
12 |
13 | ### Motivation
14 |
15 |
16 |
17 | ### Pitch
18 |
19 |
20 |
21 | ### Alternatives
22 |
23 |
24 |
25 | ### Additional context
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/program_coverage.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Program Coverage
3 | about: Expand the programs / models Thunder can process
4 | title: ''
5 | labels: program-coverage
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Model / language coverage
10 |
11 |
12 |
13 | ### Pitch
14 |
15 |
16 |
17 | ### Alternatives / Potential work-arounds
18 |
19 |
24 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | Before submitting
3 |
4 | - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
6 | - [ ] Did you make sure to update the docs?
7 | - [ ] Did you write any new necessary tests?
8 |
9 |
10 |
11 | ## What does this PR do?
12 |
13 | Fixes # (issue).
14 |
15 | ## PR review
16 |
17 | Anyone in the community is free to review the PR once the tests have passed.
18 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
19 |
20 | ## Did you have fun?
21 |
22 | Make sure you had fun coding 🙃
23 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # Basic dependabot.yml file with
2 | # minimum configuration for two package managers
3 |
4 | version: 2
5 | updates:
6 | # Enable version updates for python
7 | - package-ecosystem: "pip"
8 | # Look for `requirements` files in the root (`/`) directory
9 | directory: "/"
10 | schedule:
11 | interval: "monthly"
12 | pull-request-branch-name:
13 | separator: "-"
14 | # Allow up to 5 open pull requests for pip dependencies
15 | open-pull-requests-limit: 5
16 |
17 | # Enable version updates for GitHub Actions
18 | - package-ecosystem: "github-actions"
19 | directory: "/"
20 | schedule:
21 | interval: "monthly"
22 | pull-request-branch-name:
23 | separator: "-"
24 | # Allow up to 5 open pull requests for GitHub Actions
25 | open-pull-requests-limit: 5
26 | groups:
27 | GHA-updates:
28 | patterns:
29 | - "*"
30 |
--------------------------------------------------------------------------------
/.github/labeling-config.yml:
--------------------------------------------------------------------------------
1 | documentation:
2 | - changed-files:
3 | - any-glob-to-any-file:
4 | - docs/**/*
5 | - dev_tutorials/*
6 |
7 | "ci":
8 | - changed-files:
9 | - any-glob-to-any-file:
10 | - .azure/*
11 | - .github/*
12 | - .github/workflows/*
13 | - .lightning/workflows/*
14 | - dockers/**/*
15 |
16 | "docker":
17 | - changed-files:
18 | - any-glob-to-any-file:
19 | - dockers/**/*
20 | - .azure/docker-build.yml
21 |
22 | "install":
23 | - changed-files:
24 | - any-glob-to-any-file:
25 | - setup.py
26 | - pyproject.toml
27 | - MANIFEST.in
28 |
29 | "dependencies":
30 | - changed-files:
31 | - any-glob-to-any-file:
32 | - requirements/*
33 | - requirements.txt
34 |
--------------------------------------------------------------------------------
/.github/lightning-probot.yml:
--------------------------------------------------------------------------------
1 | tracking_issue: 72
2 |
--------------------------------------------------------------------------------
/.github/load-jit-coverage-report.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 |
5 | def main(report_path: str = "jit_coverage_report.json"):
6 | """Load and print the jit coverage report."""
7 | with open(report_path) as fp:
8 | report = json.load(fp)
9 |
10 | all_count = len(report)
11 | success = [el for el in report if el["status"] == "[SUCCESS]"]
12 | skipped = [el for el in report if el["status"] == "[SKIPPED]"]
13 | failure = [el for el in report if el["status"] == "[FAILURE]"]
14 |
15 | print("thunder.jit coverage report")
16 | print(f"🟢 {len(success)}/{all_count} [SUCCESS]")
17 | print(f"🟡 {len(skipped)}/{all_count} [SKIPPED]")
18 | print(f"⛔ {len(failure)}/{all_count} [FAILURE]")
19 |
20 |
21 | if __name__ == "__main__":
22 | # optional path to the report file
23 | system_args = sys.argv[1:]
24 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {}
25 | main(**main_args)
26 |
--------------------------------------------------------------------------------
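The loader above expects `jit_coverage_report.json` to hold a list of entries whose `status` field is one of `[SUCCESS]`, `[SKIPPED]`, or `[FAILURE]`, the format written by `run-jit-coverage-as-lit-studios.py`. A quick way to see the expected schema, using made-up entries:

```python
import json

# Hypothetical entries matching the schema the loader expects.
report = [
    {"model": "model-a", "status": "[SUCCESS]"},
    {"model": "model-b", "status": "[SKIPPED]", "last": "unsupported op"},
    {"model": "model-c", "status": "[FAILURE]", "last": "trace error"},
]
with open("jit_coverage_report.json", "w") as fp:
    json.dump(report, fp)

# `python .github/load-jit-coverage-report.py` then reports
# 1/3 [SUCCESS], 1/3 [SKIPPED] and 1/3 [FAILURE].
```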
/.github/load-quickstart-report.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 |
5 | def main(report_path: str = "quickstart_report.json"):
6 | """Load and print the quickstart report."""
7 | with open(report_path) as fp:
8 | report = json.load(fp)
9 |
10 | success_count = sum(status == "completed" for status in report.values())
11 | overall_status = "🟢" if success_count == len(report) else "⛔"
12 | print(f"Quickstart report {overall_status} with {success_count} out of {len(report)} successful:")
13 | # Sort so that entries with status "completed" (successful jobs) are last
14 | for name, status in sorted(report.items(), key=lambda x: x[1] == "completed"):
15 | status_icon = "✔️" if status == "completed" else "❌"
16 | print(f"{status_icon} {name}")
17 |
18 |
19 | if __name__ == "__main__":
20 | # optional path to the report file
21 | system_args = sys.argv[1:]
22 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {}
23 | main(**main_args)
24 |
--------------------------------------------------------------------------------
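Here the report is a flat name-to-status mapping, and `sorted(..., key=lambda x: x[1] == "completed")` exploits the fact that boolean keys sort `False` (0) before `True` (1), so failed scripts print before successful ones. The idiom in isolation, with invented statuses:

```python
# Boolean sort keys order False before True, so non-"completed" entries come first.
report = {"mlp.py": "completed", "hf_llm.py": "failed", "hf_bert.py": "completed"}
ordered = sorted(report.items(), key=lambda x: x[1] == "completed")
print([name for name, _ in ordered])  # ['hf_llm.py', 'mlp.py', 'hf_bert.py']
```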
/.github/run-benchmark-as-lit-jobs.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import sys
4 | import time
5 | from datetime import datetime
6 |
7 | from lightning_sdk import Studio, Job, Machine, Status
8 |
9 |
10 | def main(gh_run_id: str = ""):
11 | if not gh_run_id:
12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S")
13 | print("Creating studio...")
14 | s = Studio(f"thunder-benchmark-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True)
15 |
16 | print("Uploading package and benchmark script...")
17 | s.upload_folder("dist", remote_path="dist")
18 | pkg_path = glob.glob("dist/*.whl")[0]
19 | s.upload_file("thunder/benchmarks/benchmark_hf.py", remote_path="benchmarks/benchmark_hf.py")
20 |
21 | print("Starting studio...")
22 | s.start()
23 | print("Installing Thunder and dependencies...")
24 | s.run(
25 | f"pip install {pkg_path} -U transformers==4.52.4 nvidia-cudnn-frontend 'numpy<2.0' 'nvfuser_cu128_torch27==0.2.27.dev20250615'"
26 | )
27 |
28 | print("Running HF benchmark script...")
29 | job = Job.run(
30 | name=f"benchmark-run{gh_run_id}",
31 | command="pip list && python benchmarks/benchmark_hf.py",
32 | studio=s,
33 | machine=Machine.L40S,
34 | interruptible=True,
35 | )
36 |
37 | print("Stopping studio...")
38 | s.stop()
39 |
40 | print("Waiting for job to finish...")
41 | job.wait()
42 | status = str(job.status).lower()
43 | print(f"[{job.status}]\t {job.name}")
44 |
45 | report = {"benchmark_hf.py": status}
46 | with open("benchmark_hf_report.json", "w") as fp:
47 | json.dump(report, fp, indent=4)
48 |
49 | if job.status != Status.Completed:
50 | print("=" * 80)
51 | print(f"===== benchmark_hf.py -> {job.status} =====")
52 | print("=" * 80)
53 | print(job.logs)
54 | print("=" * 80)
55 | time.sleep(3)
56 | raise RuntimeError(f"Benchmark HF job {job.status}")
57 | # clean up
58 | job.delete()
59 | s.delete()
60 |
61 |
62 | if __name__ == "__main__":
63 | # parse command line arguments
64 | args = sys.argv[1:]
65 | main(*args)
66 |
--------------------------------------------------------------------------------
/.github/run-jit-coverage-as-lit-studios.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from datetime import datetime
3 | import glob
4 | import os
5 | import json
6 |
7 | from lightning_sdk import Studio, Machine
8 | from tqdm import tqdm
9 |
10 |
11 | def main(gh_run_id: str = ""):
12 | if not gh_run_id:
13 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S")
14 |
15 | batches = []
16 | chunk_size = 16
17 | with open("examples/coverage/all.txt") as f:
18 | lines = [el.strip() for el in f.readlines()]  # strip newlines; they are re-added on write
19 | chunks = [lines[i : i + chunk_size] for i in range(0, len(lines), chunk_size)]
20 | for i, chunk_lines in enumerate(chunks):
21 | filename = f"{i:03d}.txt"
22 | with open(f"examples/coverage/{filename}", "w") as f:
23 | f.writelines([el + "\n" for el in chunk_lines])
24 | batches.append((filename, chunk_lines))
25 |
26 | print("Creating studio...")
27 | s = Studio(f"thunder-jit-coverage-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True)
28 |
29 | print("Uploading package and scripts...")
30 | s.upload_folder("dist", remote_path="dist")
31 | pkg_path = glob.glob("dist/*.whl")[0]
32 | s.upload_folder("examples/coverage", remote_path="coverage")
33 |
34 | print("Starting studio...")
35 | s.start(machine=Machine.L40S, interruptible=False)
36 |
37 | print("Installing Thunder and other requirements...")
38 | s.run(f"pip install {pkg_path} -U -r coverage/requirements.txt")
39 |
40 | hf_token = os.environ["HF_TOKEN"]
41 | print("Running thunder.jit coverage...")
42 | for filename, models in tqdm(batches, unit_scale=chunk_size):
43 | s.run(
44 | f"HF_TOKEN={hf_token} python coverage/jit_coverage_hf.py --models-file coverage/{filename} --output-dir data"
45 | )
46 |
47 | print("Aggregating results...")
48 | s.run("python coverage/jit_coverage_hf.py --output-dir data --results-file data.json")
49 |
50 | data_json = s.run("cat data.json")
51 | data = json.loads(data_json)
52 | success = [el for el in data if el["status"] == "[SUCCESS]"]
53 | skipped = [el for el in data if el["status"] == "[SKIPPED]"]
54 | failure = [el for el in data if el["status"] == "[FAILURE]"]
55 |
56 | for el in success:
57 | print(f"🟢 [SUCCESS] {el['model']}")
58 |
59 | for el in skipped:
60 | print(f"🟡 [SKIPPED] {el['model']}")
61 | print(f" Error: {el['last']}")
62 |
63 | for el in failure:
64 | print(f"⛔️ [FAILURE] {el['model']}")
65 | print(f" Error: {el['last']}")
66 |
67 | with open("jit_coverage_report.json", "w") as f:
68 | f.write(data_json)
69 |
70 | print("Stopping studio...")
71 | s.stop()
72 | s.delete()
73 |
74 |
75 | if __name__ == "__main__":
76 | # parse command line arguments
77 | args = sys.argv[1:]
78 | main(*args)
79 |
--------------------------------------------------------------------------------
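The batching above is plain slice-based chunking: the model list is split into consecutive slices of at most `chunk_size` entries so that each `s.run(...)` call handles a bounded batch. The idiom in isolation (the sizes are arbitrary):

```python
# Slice-based chunking: consecutive batches of at most chunk_size items.
chunk_size = 16
lines = [f"model-{i}" for i in range(40)]  # arbitrary stand-in for all.txt
chunks = [lines[i : i + chunk_size] for i in range(0, len(lines), chunk_size)]
assert [len(c) for c in chunks] == [16, 16, 8]
```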
/.github/run-quickstart-as-lit-jobs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from datetime import datetime
4 | import glob
5 | import os.path
6 |
7 | from lightning_sdk import Studio, Job, Machine, Status
8 |
9 |
10 | def main(gh_run_id: str = ""):
11 | if not gh_run_id:
12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S")
13 | print("Creating studio...")
14 | s = Studio(f"thunder-quickstarts-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True)
15 | print("Uploading package and scripts...")
16 | s.upload_folder("dist", remote_path="dist")
17 | pkg_path = glob.glob("dist/*.whl")[0]
18 | s.upload_folder("examples/quickstart", remote_path="quickstart")
19 |
20 | print("Starting studio...")
21 | s.start()
22 | print("Installing Thunder and other requirements...")
23 | s.run(f"pip install {pkg_path} -U -r quickstart/requirements.txt")
24 |
25 | ls_quickstart = glob.glob("examples/quickstart/*.py")
26 | print("Found quickstart scripts:", ls_quickstart)
27 |
28 | print("Running quickstart scripts...")
29 | jobs = {
30 | os.path.basename(script): Job.run(
31 | name=f"ci-run{gh_run_id}_{script}",
32 | command=f"pip list && python quickstart/{os.path.basename(script)}",
33 | studio=s,
34 | machine=Machine.L40S,
35 | interruptible=True,
36 | )
37 | for script in ls_quickstart
38 | }
39 |
40 | print("Stopping studio...")
41 | s.stop()
42 |
43 | print("Waiting for jobs to finish...")
44 | report, failures = {}, {}
45 | for name, job in jobs.items():
46 | job.wait()
47 | print(f"[{job.status}]\t {job.name}")
48 | report[name] = str(job.status).lower()
49 | if job.status != Status.Completed:
50 | failures[name] = job.logs
51 | else: # clean up successful jobs
52 | job.delete()
53 |
54 | with open("quickstart_report.json", "w") as fp:
55 | json.dump(report, fp, indent=4)
56 |
57 | print("Showing logs of failed jobs...")
58 | separator = "=" * 80
59 | for name, logs in failures.items():
60 | offset = "=" * (80 - 5 - 2 - len(name))
61 | print(f"{separator}\n===== {name} {offset}\n{separator}")
62 | print(logs)
63 | print(separator + "\n" * 5)
64 |
65 | assert not failures
66 |
67 | print("Cleaning up...")
68 | s.delete()
69 |
70 |
71 | if __name__ == "__main__":
72 | # parse command line arguments
73 | args = sys.argv[1:]
74 | main(*args)
75 |
--------------------------------------------------------------------------------
/.github/workflows/auto-cc.yml:
--------------------------------------------------------------------------------
1 | name: Probot
2 |
3 | on:
4 | issues:
5 | types: [labeled]
6 | # should use `pull_request_target` but it's blocked by
7 | # https://github.com/probot/probot/issues/1635
8 | # so this job will not run on forks until the above is fixed
9 | pull_request:
10 | types: [labeled, ready_for_review]
11 |
12 | jobs:
13 | auto-cc:
14 | runs-on: ubuntu-latest
15 | if: github.event_name == 'issues' || github.event.pull_request.draft == false
16 | timeout-minutes: 5
17 | steps:
18 | - uses: Lightning-AI/probot@v5
19 | env:
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | with:
22 | job: auto-cc
23 |
--------------------------------------------------------------------------------
/.github/workflows/ci-benchmark.yml:
--------------------------------------------------------------------------------
1 | name: Benchmark Models
2 |
3 | on:
4 | workflow_dispatch: {}
5 | pull_request:
6 | paths:
7 | - ".github/workflows/ci-benchmark.yml"
8 | - ".github/run-benchmark-as-lit-jobs.py"
9 | push: # enable tracing which commit to main broke it...
10 | branches: [main]
11 |
12 | concurrency:
13 | group: ${{ github.workflow }}-${{ github.ref }}
14 | cancel-in-progress: false
15 |
16 | defaults:
17 | run:
18 | shell: bash
19 |
20 | jobs:
21 | launcher-benchmark:
22 | runs-on: "ubuntu-22.04"
23 |
24 | steps:
25 | - uses: actions/checkout@v5
26 |
27 | - uses: actions/setup-python@v5
28 | with:
29 | python-version: "3.10"
30 |
31 | - name: Build Thunder Package
32 | run: |
33 | pip install -U build
34 | python -m build --sdist --wheel --outdir dist/
35 | ls -l dist/
36 |
37 | - name: Launch Benchmark Job in Lightning Studio
38 | env:
39 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
40 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
41 | run: |
42 | pip install lightning_sdk -U -q
43 | python .github/run-benchmark-as-lit-jobs.py ${{ github.run_id }}
44 |
45 | - name: Post Slack Notification
46 | if: always() && github.event_name != 'pull_request'
47 | uses: act10ns/slack@v2
48 | with:
49 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }}
50 | status: ${{ job.status }}
51 | message: |
52 | *Benchmark Triggered Manually* - [${{ job.status }}]
53 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
54 |
--------------------------------------------------------------------------------
/.github/workflows/ci-checks.yml:
--------------------------------------------------------------------------------
1 | name: General checks
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request: {}
7 |
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
10 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
11 |
12 | jobs:
13 | check-schema:
14 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
15 | with:
16 | azure-dir: ".azure"
17 |
18 | check-package:
19 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
20 | with:
21 | actions-ref: main
22 | import-name: "thunder"
23 | artifact-name: dist-packages-${{ github.sha }}
24 | testing-matrix: |
25 | {
26 | "os": ["ubuntu-latest", "macOS-latest", "windows-latest"],
27 | "python-version": ["3.10", "3.11"]
28 | }
29 |
--------------------------------------------------------------------------------
/.github/workflows/ci-jit-coverage.yml:
--------------------------------------------------------------------------------
1 | name: CI jit coverage
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | paths:
8 | - ".github/workflows/ci-jit-coverage.yml"
9 | - ".github/run-jit-coverage-as-lit-studios.py"
10 | - ".github/load-jit-coverage-report.py"
11 | - "examples/coverage/*"
12 | workflow_dispatch: {}
13 | schedule:
14 | - cron: "0 0 * * *" # every midnight
15 |
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
19 |
20 | defaults:
21 | run:
22 | shell: bash
23 |
24 | jobs:
25 | launcher-jit-coverage:
26 | runs-on: "ubuntu-22.04"
27 | #timeout-minutes: 55
28 | steps:
29 | - uses: actions/checkout@v5
30 | - uses: actions/setup-python@v5
31 | with:
32 | python-version: "3.10"
33 |
34 | - name: Build package
35 | run: |
36 | pip install -U build
37 | python -m build --sdist --wheel --outdir dist/
38 | ls -l dist/
39 |
40 | - name: Run thunder.jit coverage
41 | env:
42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
44 | HF_TOKEN: ${{ secrets.HF_TOKEN }}
45 | run: |
46 | pip install lightning_sdk tqdm -U -q
47 | python .github/run-jit-coverage-as-lit-studios.py ${{ github.run_id }}
48 |
49 | - name: Load report
50 | if: always()
51 | id: load-report
52 | run: |
53 | report=$(python .github/load-jit-coverage-report.py)
54 | echo "REPORT<<EOF" >> $GITHUB_ENV
55 | echo "$report" >> $GITHUB_ENV
56 | echo "EOF" >> $GITHUB_ENV
57 |
58 | - uses: act10ns/slack@v2
59 | # if: always() && github.event_name != 'pull_request'
60 | with:
61 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }}
62 | status: ${{ job.status }}
63 | message: |
64 | *thunder.jit coverage CI* - [${{ job.status }}]
65 | ${{ env.REPORT }}
66 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
67 |
--------------------------------------------------------------------------------
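The "Load report" step uses GitHub Actions' heredoc-style syntax for multi-line environment values: append `NAME<<DELIMITER`, then the value, then the delimiter line to the file pointed to by `$GITHUB_ENV`. A small Python sketch of the same mechanism (the helper name is ours; it assumes it runs inside a workflow step, where the runner sets `GITHUB_ENV`):

```python
import os

def set_multiline_env(name: str, value: str, delimiter: str = "EOF") -> None:
    """Append a multi-line variable using the NAME<<DELIMITER syntax GitHub Actions parses."""
    env_file = os.environ["GITHUB_ENV"]  # path provided by the Actions runner
    with open(env_file, "a") as fp:
        fp.write(f"{name}<<{delimiter}\n{value}\n{delimiter}\n")
```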
/.github/workflows/ci-quickstart.yml:
--------------------------------------------------------------------------------
1 | name: CI quickstart
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | paths:
8 | - ".github/workflows/ci-quickstart.yml"
9 | - ".github/run-quickstart-as-lit-jobs.py"
10 | - ".github/load-quickstart-report.py"
11 | - "examples/quickstart/*"
12 | workflow_dispatch: {}
13 | schedule:
14 | - cron: "0 0 * * *" # every midnight
15 |
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
19 |
20 | defaults:
21 | run:
22 | shell: bash
23 |
24 | jobs:
25 | launcher-quickstart:
26 | runs-on: "ubuntu-22.04"
27 | #timeout-minutes: 55
28 | steps:
29 | - uses: actions/checkout@v5
30 | - uses: actions/setup-python@v5
31 | with:
32 | python-version: "3.10"
33 |
34 | - name: Build package
35 | run: |
36 | pip install -U build
37 | python -m build --sdist --wheel --outdir dist/
38 | ls -l dist/
39 |
40 | - name: Run scripts in jobs
41 | env:
42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
44 | run: |
45 | pip install lightning_sdk -U -q
46 | python .github/run-quickstart-as-lit-jobs.py ${{ github.run_id }}
47 |
48 | - name: Load report
49 | if: always()
50 | id: load-report
51 | run: |
52 | report=$(python .github/load-quickstart-report.py)
53 | echo "REPORT<<EOF" >> $GITHUB_ENV
54 | echo "$report" >> $GITHUB_ENV
55 | echo "EOF" >> $GITHUB_ENV
56 |
57 | - uses: act10ns/slack@v2
58 | if: always() && github.event_name != 'pull_request' && steps.load-report.conclusion == 'success'
59 | with:
60 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }}
61 | status: ${{ job.status }}
62 | message: |
63 | *Quickstart CI* - [${{ job.status }}]
64 | ${{ env.REPORT }}
65 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
66 |
--------------------------------------------------------------------------------
/.github/workflows/docs-build.yml:
--------------------------------------------------------------------------------
1 | name: "Build (& deploy) Docs"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, reopened, ready_for_review, synchronize]
8 | workflow_dispatch: {}
9 |
10 | concurrency:
11 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
12 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
13 |
14 | defaults:
15 | run:
16 | shell: bash
17 |
18 | jobs:
19 | docs-make:
20 | if: github.event.pull_request.draft == false
21 | runs-on: ubuntu-22.04
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | target: ["html", "doctest", "linkcheck"]
26 | env:
27 | ARTIFACT_DAYS: 0
28 | PYPI_LOCAL_DIR: "pypi_pkgs/"
29 | steps:
30 | - uses: actions/checkout@v5
31 | - uses: actions/setup-python@v5
32 | with:
33 | python-version: "3.10"
34 |
35 | - name: Pull sphinx template
36 | run: make get-sphinx-theme
37 | - name: Install pandoc
38 | timeout-minutes: 5
39 | run: sudo apt-get install -y pandoc
40 | - name: Install package & dependencies
41 | timeout-minutes: 20
42 | run: pip install . -U -r requirements/docs.txt
43 |
44 | - name: Make ${{ matrix.target }}
45 | working-directory: docs/
46 | # allow failing link check and doctest if you run with dispatch
47 | continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }}
48 | run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
49 |
50 | - name: Keep artifact
51 | if: github.event_name == 'pull_request'
52 | run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV
53 | - name: Upload built docs
54 | if: ${{ matrix.target == 'html' }}
55 | uses: actions/upload-artifact@v4
56 | with:
57 | name: docs-html-${{ github.sha }}
58 | path: docs/build/html/
59 | retention-days: ${{ env.ARTIFACT_DAYS }}
60 |
61 | deploy-docs:
62 | needs: docs-make
63 | if: github.repository_owner == 'Lightning-AI' && github.event_name == 'push'
64 | runs-on: ubuntu-latest
65 | env:
66 | GCP_TARGET: "gs://lightning-docs-thunder"
67 | steps:
68 | - uses: actions/download-artifact@v5
69 | with:
70 | name: docs-html-${{ github.sha }}
71 | path: docs/build/html/
72 |
73 | - name: Authenticate to Google Cloud
74 | uses: google-github-actions/auth@v3
75 | with:
76 | credentials_json: ${{ secrets.GCS_SA_KEY }}
77 |
78 | - name: Setup gcloud
79 | uses: google-github-actions/setup-gcloud@v3
80 | with:
81 | project_id: ${{ secrets.GCS_PROJECT }}
82 |
83 | # Uploading docs to GCS, so they can be served on lightning.ai
84 | #- name: Upload docs/thunder/stable to GCS 🪣
85 | # if: startsWith(github.ref, 'refs/heads/release/')
86 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/stable
87 |
88 | # Uploading docs to GCS, so they can be served on lightning.ai
89 | - name: Upload docs/thunder/latest to GCS 🪣
90 | if: github.ref == 'refs/heads/main'
91 | run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest
92 |
93 | # Uploading docs to GCS, so they can be served on lightning.ai
94 | #- name: Upload docs/thunder/release to GCS 🪣
95 | # if: startsWith(github.ref, 'refs/tags/')
96 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ github.ref_name }}
97 |
98 | # Uploading docs as an archive to GCS, so they can serve as a backup
99 | #- name: Upload docs as archive to GCS 🪣
100 | # if: startsWith(github.ref, 'refs/tags/')
101 | # working-directory: docs/build
102 | # run: |
103 | # zip ${{ github.ref_name }}.zip -r html/
104 | # gsutil cp ${{ github.ref_name }}.zip ${GCP_TARGET}
105 |
--------------------------------------------------------------------------------
/.github/workflows/label-conflicts.yml:
--------------------------------------------------------------------------------
1 | name: Label merge conflicts
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request_target:
7 | types: ["synchronize", "reopened", "opened"]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | triage-conflicts:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd
18 | with:
19 | CONFLICT_LABEL_NAME: "has conflicts"
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | MAX_RETRIES: 3
22 | WAIT_MS: 5000
23 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: "Pull Request Labeler"
2 | on: [pull_request_target]
3 |
4 | jobs:
5 | triage-prs:
6 | permissions:
7 | contents: read
8 | pull-requests: write
9 | runs-on: ubuntu-latest
10 | steps:
11 | # Uploads repository content to the runner
12 | - uses: actions/checkout@v5
13 | - uses: actions/labeler@v5
14 | with:
15 | # The path to the label configuration file.
16 | configuration-path: .github/labeling-config.yml
17 | # Whether to remove labels when matching files are reverted or no longer changed by the PR
18 | sync-labels: true
19 |
--------------------------------------------------------------------------------
/.github/workflows/release-nightly.yml:
--------------------------------------------------------------------------------
1 | name: Nightly packages
2 |
3 | on:
4 | pull_request: # this shall test only the part of workflow before publishing
5 | branches: [main, "release/*"]
6 | types: [opened, reopened, ready_for_review, synchronize]
7 | paths:
8 | - ".github/workflows/release-nightly.yml"
9 | schedule:
10 | - cron: "0 0 * * 0" # on Sundays
11 | workflow_dispatch: {}
12 |
13 | defaults:
14 | run:
15 | shell: bash
16 |
17 | jobs:
18 | releasing-nightly:
19 | runs-on: ubuntu-22.04
20 | steps:
21 | - uses: actions/checkout@v5
22 | - uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.10"
25 |
26 | - name: Install dependencies
27 | run: python -m pip install --user --upgrade setuptools wheel packaging
28 | - name: Build package
29 | env:
30 | CONVERT_VERSION2NIGHTLY: "1"
31 | run: python setup.py sdist bdist_wheel
32 |
33 | # We do this, since failures on test.pypi aren't that bad
34 | - name: Publish to Test PyPI
35 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
36 | uses: pypa/gh-action-pypi-publish@v1.12.4
37 | with:
38 | user: __token__
39 | password: ${{ secrets.test_pypi_password }}
40 | repository_url: https://test.pypi.org/legacy/
41 |
42 | - name: Publish distribution 📦 to PyPI
43 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
44 | uses: pypa/gh-action-pypi-publish@v1.12.4
45 | with:
46 | user: __token__
47 | password: ${{ secrets.pypi_password }}
48 |
--------------------------------------------------------------------------------
/.github/workflows/release-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push:
6 | branches: [main]
7 | release:
8 | types: [published]
9 |
10 | # based on https://github.com/pypa/gh-action-pypi-publish
11 |
12 | jobs:
13 | releasing-pypi:
14 | runs-on: ubuntu-22.04
15 | steps:
16 | - uses: actions/checkout@v5
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: python -m pip install --user --upgrade setuptools wheel packaging
23 | - name: Build
24 | run: python setup.py sdist bdist_wheel
25 |
26 | # We do this, since failures on test.pypi aren't that bad
27 | - name: Publish to Test PyPI
28 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
29 | uses: pypa/gh-action-pypi-publish@v1.12.4
30 | with:
31 | user: __token__
32 | password: ${{ secrets.test_pypi_password }}
33 | repository_url: https://test.pypi.org/legacy/
34 |
35 | - name: Publish distribution 📦 to PyPI
36 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
37 | uses: pypa/gh-action-pypi-publish@v1.12.4
38 | with:
39 | user: __token__
40 | password: ${{ secrets.pypi_password }}
41 |
--------------------------------------------------------------------------------
/.github/workflows/stale-branches.yaml:
--------------------------------------------------------------------------------
1 | name: Delete abandoned branches
2 |
3 | on:
4 | pull_request:
5 | branches: ["main"]
6 | paths:
7 | - ".github/workflows/stale-branches.yaml"
8 | # Run daily at midnight
9 | schedule:
10 | - cron: "0 0 * * *"
11 | # Allow workflow to be manually run from the GitHub UI
12 | workflow_dispatch:
13 |
14 | concurrency:
15 | group: ${{ github.workflow }}-${{ github.ref }}
16 | cancel-in-progress: true
17 |
18 | jobs:
19 | cleanup-old-branches:
20 | runs-on: ubuntu-latest
21 | name: Satisfy my repo CDO
22 | env:
23 | DRY_RUN: no
24 | steps:
25 | - name: Set dry run
26 | if: ${{ github.event_name == 'pull_request' }}
27 | run: echo "DRY_RUN=yes" >> $GITHUB_ENV
28 | - name: Delete those pesky dead branches
29 | uses: phpdocker-io/github-actions-delete-abandoned-branches@v2
30 | id: delete_stuff
31 | with:
32 | github_token: ${{ github.token }}
33 | last_commit_age_days: 90
34 | ignore_branches: main
35 | # Disable dry run and actually get stuff deleted
36 | # For a PR, always perform dry run
37 | dry_run: ${{ env.DRY_RUN }}
38 |
39 | - name: Get output
40 | run: "echo 'Deleted branches: ${{ steps.delete_stuff.outputs.deleted_branches }}'"
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # Python venv
30 | bin/
31 | lib64
32 | pyvenv.cfg
33 | share/
34 | etc/
35 | include/
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 | docs/source/api/
68 | docs/source/*.md
69 | docs/source/notebooks/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # VSCode
78 | .vscode/
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
88 | __pypackages__/
89 |
90 | # Celery stuff
91 | celerybeat-schedule
92 | celerybeat.pid
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 |
121 | # Pyre type checker
122 | .pyre/
123 |
124 | # PyCharm
125 | .idea/
126 |
127 | # Lightning logs
128 | lightning_logs
129 | *.gz
130 | .DS_Store
131 | .*_submit.py
132 |
133 | # Editor temporary file
134 | *.swn
135 | *.swo
136 | *.swp
137 | *.swm
138 | *~
139 |
140 | # Build artifacts
141 | scripts/build
142 |
143 | # Profiler traces
144 | benchmarks/traces
145 |
146 | # benchmark results
147 | .results
148 |
149 | .ruff_cache/
150 |
151 | # some CI artifacts
152 | notebooks/all.txt
153 | notebooks/ci.txt
154 |
155 | quickstart_report.json
156 |
--------------------------------------------------------------------------------
/.lightning/workflows/all-tests.yaml:
--------------------------------------------------------------------------------
1 | trigger:
2 | push:
3 | branches: ["main"]
4 | pull_request:
5 | branches: ["main"]
6 |
7 | timeout: "60" # minutes
8 | interruptible: False
9 | parametrize:
10 | matrix:
11 | image:
12 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev"
13 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev"
14 | testing: ["main", "ops", "grads"]
15 | machine: ["L4"]
16 | exclude: []
17 | include:
18 | - image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev"
19 | testing: "distributed"
20 | machine: "L4_X_2"
21 | - image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev"
22 | testing: "distributed"
23 | machine: "L4_X_2"
24 |
25 | env:
26 | CI: "true" # skip some tests with CI
27 | NCCL_DEBUG: "INFO"
28 | NCCL_IGNORE_DISABLED_P2P: "1"
29 | TORCH_VERSION: "2.7.1"
30 | CUDA_LAUNCH_BLOCKING: "1" # for debugging purposes, to get better stack traces
31 |
32 | run: |
33 | whereis nvidia
34 | nvidia-smi
35 | python --version
36 | pip --version
37 | pip list
38 | set -ex
39 |
40 |   # drop torch from requirements so as not to interfere with the preinstalled one
41 | bash scripts/remove-torch-lines.sh requirements/base.txt
42 | cat requirements/base.txt
43 |
44 | # double check on test requirements
45 | pip install -U -r requirements/base.txt -r requirements/test.txt
46 |
47 | # https://docs.codecov.com/docs/codecov-uploader
48 | curl -Os https://uploader.codecov.io/latest/linux/codecov
49 | chmod +x codecov
50 |
51 | # install this package
52 | python setup.py develop
53 |
54 | bash scripts/sanity-check.sh
55 |
56 | if [ "${testing}" == "main" ]; then
57 | coverage run --source thunder -m \
58 | pytest thunder/tests/ \
59 | -m "not standalone" \
60 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \
61 | --random-order-seed=42 \
62 | --durations=250 \
63 | --timeout=360 \
64 | --numprocesses=9 \
65 | --ignore=thunder/tests/distributed --ignore=thunder/tests/test_networks.py \
66 | --ignore=thunder/tests/test_ops.py --ignore=thunder/tests/test_grad.py
67 | coverage run --source thunder -m \
68 | pytest \
69 | thunder/tests/test_networks.py \
70 | -m "not standalone" \
71 | -v --durations=0 \
72 | --random-order-seed=42 \
73 | --numprocesses=3
74 | elif [ "${testing}" == "ops" ]; then
75 | coverage run --source thunder -m \
76 | pytest thunder/tests/test_ops.py \
77 | -m "not standalone" \
78 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \
79 | --random-order-seed=42 \
80 | --durations=250 \
81 | --timeout=240 \
82 | --numprocesses=9
83 | elif [ "${testing}" == "grads" ]; then
84 | coverage run --source thunder -m \
85 | pytest thunder/tests/test_grad.py \
86 | -m "not standalone" \
87 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \
88 | --random-order-seed=42 \
89 | --durations=250 \
90 | --timeout=360 \
91 | --numprocesses=9
92 | elif [ "${testing}" == "distributed" ]; then
93 | pytest thunder/tests/distributed \
94 | -v --durations=0 \
95 | --random-order-seed=42
96 | else
97 | echo "Unknown testing type: ${testing}"
98 | exit 1
99 | fi
100 |
101 | # TODO: compile coverage results
102 | #python -m coverage report
103 | #python -m coverage xml
104 | # upload to codecov
105 | # TODO: add >> --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion)
106 | #./codecov --flags=gpu,pytest,${testing} --name="GPU-coverage" --env=linux
107 |
--------------------------------------------------------------------------------
/.lightning/workflows/notebooks.yaml:
--------------------------------------------------------------------------------
1 | trigger:
2 | push:
3 | branches: ["main"]
4 | pull_request:
5 | branches: ["main"]
6 |
7 | timeout: "50" # minutes
8 | machine: "L4"
9 | interruptible: False
10 | parametrize:
11 | matrix:
12 | image:
13 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev"
14 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev"
15 | exclude: []
16 | include: []
17 |
18 | run: |
19 | whereis nvidia
20 | nvidia-smi
21 | python --version
22 | pip --version
23 | pip list
24 | set -ex
25 |
26 |   # drop torch from requirements so as not to interfere with the preinstalled one
27 | bash scripts/remove-torch-lines.sh requirements/base.txt
28 | cat requirements/base.txt
29 | # double check on test requirements
30 | pip install -q -U -r requirements/base.txt -r requirements/notebooks.txt
31 | # install this package
32 | python setup.py develop
33 |
34 | bash scripts/sanity-check.sh
35 |
36 | # list all notebooks in this folder
37 | cd notebooks/
38 | find . -name "*.ipynb" > all.txt
39 | # drop all "./" from beginning of each line
40 | sed -i 's/^\.\///' all.txt
41 | # filter out the ones that are listed in .ignore.ci
42 | grep -Fxv -f .ignore.ci all.txt > ci.txt
43 | # iterate over all listed notebooks and execute them with jupyter
44 | while read -r line; do
45 | echo "Processing $line"
46 | jupyter execute $line --timeout=300
47 | done <<< $(cat ci.txt)
48 |
--------------------------------------------------------------------------------
/.lightning/workflows/transformer-engine.yaml:
--------------------------------------------------------------------------------
1 | trigger:
2 | push:
3 | branches: ["main"]
4 | pull_request:
5 | branches: ["main"]
6 |
7 | timeout: "30" # minutes
8 | machine: "L4"
9 | interruptible: False
10 | image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev"
11 | parametrize:
12 | matrix:
13 | test_file:
14 | - test_transformer_engine_executor.py
15 | - test_transformer_engine_v1_executor.py
16 |
17 | run: |
18 | whereis nvidia
19 | nvidia-smi
20 | pip list
21 | set -ex
22 |
23 | # conda install -c conda-forge libstdcxx-ng
24 | # sudo apt install libstdc++6 libstdc++-*-dev
25 | pip install . -U -q -r requirements/test.txt
26 | # Need to explicitly point to cudnn.h as it is installed at a non-standard location
27 | # Ref: https://github.com/NVIDIA/TransformerEngine/issues/918#issuecomment-2187703769
28 | CPLUS_INCLUDE_PATH="/usr/local/lib/python3.10/dist-packages/nvidia/cudnn/include/" pip install --no-build-isolation 'transformer_engine[pytorch]'
29 | pip list # for debugging purposes
30 | pytest thunder/tests/${test_file} -v -rs
31 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
7 | # submodules: true
8 |
9 | repos:
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v6.0.0
12 | hooks:
13 | - id: end-of-file-fixer
14 | - id: trailing-whitespace
15 | exclude: README.md
16 | - id: check-case-conflict
17 | - id: check-yaml
18 | - id: check-toml
19 | - id: check-json
20 | - id: check-added-large-files
21 | args: ["--maxkb=400", "--enforce-all"]
22 | exclude: notebooks
23 | - id: check-docstring-first
24 | - id: detect-private-key
25 |
26 | - repo: https://github.com/asottile/pyupgrade
27 | rev: v3.20.0
28 | hooks:
29 | - id: pyupgrade
30 | args: ["--py310-plus"]
31 | name: Upgrade code
32 | exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py"
33 |
34 | - repo: https://github.com/codespell-project/codespell
35 | rev: v2.4.1
36 | hooks:
37 | - id: codespell
38 | additional_dependencies: [tomli]
39 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing
40 |
41 | - repo: https://github.com/astral-sh/ruff-pre-commit
42 | rev: v0.13.0
43 | hooks:
44 | - id: ruff-check
45 | args: ["--fix"]
46 | - id: ruff-format
47 | types_or: [python]
48 | exclude: "examples"
49 |
50 | - repo: https://github.com/executablebooks/mdformat
51 | rev: 0.7.22
52 | hooks:
53 | - id: mdformat
54 | additional_dependencies:
55 | - mdformat-gfm
56 | - mdformat-black
57 | - mdformat_frontmatter
58 | exclude: "examples"
59 |
60 | - repo: https://github.com/sphinx-contrib/sphinx-lint
61 | rev: v1.0.0
62 | hooks:
63 | - id: sphinx-lint
64 |
65 | - repo: https://github.com/JoC0de/pre-commit-prettier
66 | rev: b3e25fa39aa676c36bc18eb9eae6f26d9bb63f39 # v3.6.2 using SHA as tags are not persistent
67 | hooks:
68 | - id: prettier
69 | files: \.(json|yml|yaml|toml)$
70 | # https://prettier.io/docs/en/options.html#print-width
71 | args: ["--print-width=120"]
72 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-exclude __pycache__ *.py[cod] *.orig
2 |
3 | # Include the README and CHANGELOG
4 | include *.md
5 | recursive-include pl_sandbox *.md
6 |
7 | # Include the license file
8 | include LICENSE
9 |
10 | # Exclude build configs
11 | exclude *.toml
12 | exclude *.svg
13 | exclude *.yml
14 | exclude *.yaml
15 |
16 | # exclude tests from package
17 | recursive-exclude thunder/tests *
18 | recursive-exclude site *
19 | exclude thunder/tests
20 |
21 | # Exclude the documentation files
22 | recursive-exclude docs *
23 | exclude docs
24 |
25 | # Include the Requirements
26 | recursive-include requirements *.txt
27 |
28 | # Exclude Makefile
29 | exclude Makefile
30 |
31 | prune .git
32 | prune .github
33 | prune notebook*
34 | prune temp*
35 | prune test*
36 | prune benchmark*
37 | prune examples*
38 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test clean docs
2 |
3 | # assumes you have installed the needed packages
4 | export SPHINX_MOCK_REQUIREMENTS=0
5 |
6 | test: clean
7 | pip install -q -r requirements.txt -r requirements/test.txt
8 |
9 | # use this to run tests
10 | python -m coverage run --source thunder -m pytest thunder/tests -v
11 | python -m coverage report
12 |
13 | get-sphinx-theme:
14 | pip install -q awscli
15 | mkdir -p dist/
16 | aws s3 sync --no-sign-request s3://sphinx-packages/ dist/
17 | pip install lai-sphinx-theme -f dist/
18 |
19 | docs: clean get-sphinx-theme
20 | pip install -e . --quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html
21 | cd docs ; python -m sphinx -b html -W --keep-going source build
22 |
23 | clean:
24 | # clean all temp runs
25 | rm -rf .mypy_cache
26 | rm -rf .pytest_cache
27 | rm -rf ./docs/build
28 | rm -rf ./docs/source/**/generated
29 | rm -rf ./docs/source/api
30 | rm -rf _ckpt_*
31 |
--------------------------------------------------------------------------------
/dockers/README.md:
--------------------------------------------------------------------------------
1 | # Docker images
2 |
3 | ## Build images from Dockerfiles
4 |
5 | You can build the images yourself; note that this takes a long time, so be prepared.
6 |
7 | ```bash
8 | # build with specific arguments
9 | docker image build -t lightning:ubuntu-cuda-py3.10-cuda12.1.1 -f dockers/ubuntu-cuda/Dockerfile --build-arg "CUDA_VERSION=12.1.1" .
10 | ```
11 |
12 | To run your Docker image, use
13 |
14 | ```bash
15 | docker image list
16 | docker run --rm -it lightning:ubuntu-cuda-py3.10-cuda12.1.1 bash
17 | ```
18 |
19 | ## Run docker image with GPUs
20 |
21 | To run a Docker image with access to your GPUs, you first need to install the NVIDIA Container Toolkit:
22 |
23 | ```bash
24 | # Add the package repositories
25 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
26 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
27 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
28 |
29 | sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
30 | sudo systemctl restart docker
31 | ```
32 |
33 | and later run the docker image with `--gpus=all`. For example,
34 |
35 | ```bash
36 | docker run --rm -it --gpus=all pytorchlightning/lightning:ubuntu-cuda-py3.10-cuda12.1.0
37 | ```
38 |
--------------------------------------------------------------------------------
/dockers/with-apex/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ARG BASE_IMAGE_TAG
16 |
17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG}
18 |
19 | ARG APEX_CHECKOUT="master"
20 |
21 | SHELL ["/bin/bash", "-c"]
22 |
23 | RUN \
24 | # building Apex from source
25 | pip install "pip>=23.1" packaging && \
26 | git clone https://github.com/NVIDIA/apex && \
27 | cd apex && \
28 | git checkout ${APEX_CHECKOUT} && \
29 | # https://github.com/NVIDIA/apex#linux
30 | pip install -v \
31 | --disable-pip-version-check \
32 | --no-cache-dir \
33 | --no-build-isolation \
34 | --config-settings "--build-option=--xentropy" \
35 | . && \
36 | cd .. && \
37 | rm -rf apex
38 |
39 | RUN \
40 | # Show what we have
41 | pip --version && \
42 | pip list && \
43 | python -c "import apex"
44 |
--------------------------------------------------------------------------------
/dockers/with-dev/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ARG BASE_IMAGE_TAG
16 |
17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG}
18 |
19 | SHELL ["/bin/bash", "-c"]
20 |
21 | COPY requirements/ requirements/
22 |
23 | RUN \
24 | ls -lh requirements/ && \
25 | CUDA_VERSION_MM=${CUDA_VERSION%.*} && \
26 | MAKEFLAGS="-j$(( $(nproc) / 4 ))" && \
27 | pip install -U -r requirements/test.txt && \
28 | rm -rf requirements/
29 |
30 | RUN \
31 | # Show what we have
32 | pip --version && \
33 | pip list
34 |
--------------------------------------------------------------------------------
/docs/.build_docs.sh:
--------------------------------------------------------------------------------
1 | make clean
2 | make html --debug --jobs $(nproc)
3 |
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx
10 | sphinx:
11 | fail_on_warning: true
12 | configuration: docs/source/conf.py
13 |
14 | build:
15 | os: "ubuntu-22.04"
16 | tools:
17 | python: "3.10"
18 | commands:
19 | - printenv
20 | - pwd ; pip install -q py-tree ; py-tree .
21 | - make docs
22 | - mkdir -p _readthedocs ; mv docs/build _readthedocs/html
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W
6 | SPHINXBUILD = python $(shell which sphinx-build)
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/copybutton.js:
--------------------------------------------------------------------------------
1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */
2 | $(document).ready(function() {
3 | /* Add a [>>>] button on the top-right corner of code samples to hide
4 | * the >>> and ... prompts and the output and thus make the code
5 | * copyable. */
6 | var div = $('.highlight-python .highlight,' +
7 | '.highlight-python3 .highlight,' +
8 | '.highlight-pycon .highlight,' +
9 | '.highlight-default .highlight');
10 | var pre = div.find('pre');
11 |
12 | // get the styles from the current theme
13 | pre.parent().parent().css('position', 'relative');
14 | var hide_text = 'Hide the prompts and output';
15 | var show_text = 'Show the prompts and output';
16 | var border_width = pre.css('border-top-width');
17 | var border_style = pre.css('border-top-style');
18 | var border_color = pre.css('border-top-color');
19 | var button_styles = {
20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0',
21 | 'border-color': border_color, 'border-style': border_style,
22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%',
23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em',
24 | 'border-radius': '0 3px 0 0'
25 | }
26 |
27 | // create and add the button to all the code blocks that contain >>>
28 | div.each(function(index) {
29 | var jthis = $(this);
30 | if (jthis.find('.gp').length > 0) {
31 | var button = $('<span class="copybutton">&gt;&gt;&gt;</span>');
32 | button.css(button_styles)
33 | button.attr('title', hide_text);
34 | button.data('hidden', 'false');
35 | jthis.prepend(button);
36 | }
37 | // tracebacks (.gt) contain bare text elements that need to be
38 | // wrapped in a span to work with .nextUntil() (see later)
39 | jthis.find('pre:has(.gt)').contents().filter(function() {
40 | return ((this.nodeType == 3) && (this.data.trim().length > 0));
41 | }).wrap('<span>');
42 | });
43 |
44 | // define the behavior of the button when it's clicked
45 | $('.copybutton').click(function(e){
46 | e.preventDefault();
47 | var button = $(this);
48 | if (button.data('hidden') === 'false') {
49 | // hide the code output
50 | button.parent().find('.go, .gp, .gt').hide();
51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden');
52 | button.css('text-decoration', 'line-through');
53 | button.attr('title', show_text);
54 | button.data('hidden', 'true');
55 | } else {
56 | // show the code output
57 | button.parent().find('.go, .gp, .gt').show();
58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible');
59 | button.css('text-decoration', 'none');
60 | button.attr('title', hide_text);
61 | button.data('hidden', 'false');
62 | }
63 | });
64 | });
65 |
--------------------------------------------------------------------------------
/docs/source/_static/images/LightningThunderDarkModewByline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/LightningThunderDarkModewByline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/LightningThunderLightModewByline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/LightningThunderLightModewByline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/how_it_works.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/how_it_works.png
--------------------------------------------------------------------------------
/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/logo.png
--------------------------------------------------------------------------------
/docs/source/_static/images/normalized_training_throughput_zero2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/normalized_training_throughput_zero2.png
--------------------------------------------------------------------------------
/docs/source/_static/images/pretrain_perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/pretrain_perf.png
--------------------------------------------------------------------------------
/docs/source/_static/images/training_throughput_single.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/training_throughput_single.png
--------------------------------------------------------------------------------
/docs/source/_templates/theme_variables.jinja:
--------------------------------------------------------------------------------
1 | {%- set external_urls = {
2 | 'github': 'https://github.com/Lightning-AI/lightning-thunder',
3 | 'github_issues': 'https://github.com/Lightning-AI/lightning-thunder/issues',
4 | 'contributing': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/CONTRIBUTING.md',
5 | 'governance': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/governance.md',
6 | 'docs': 'https://lightning-thunder.rtfd.io/en/latest',
7 | 'twitter': 'https://twitter.com/LightningAI',
8 | 'discuss': 'https://pytorch-lightning.slack.com',
9 | 'tutorials': 'https://lightning-thunder.readthedocs.io/en/latest/#tutorials',
10 | 'previous_pytorch_versions': 'https://lightning-thunder.rtfd.io/en/latest/',
11 | 'home': 'https://lightning-thunder.rtfd.io/en/latest/',
12 | 'get_started': 'https://lightning-thunder.readthedocs.io/en/latest/introduction_guide.html',
13 | 'features': 'https://lightning-thunder.rtfd.io/en/latest/',
14 | 'blog': 'https://www.Lightning-AI.ai/blog',
15 | 'resources': 'https://lightning-thunder.readthedocs.io/en/latest/#community-examples',
16 | 'support': 'https://lightning-thunder.rtfd.io/en/latest/',
17 | }
18 | -%}
19 |
--------------------------------------------------------------------------------
/docs/source/advanced/extending.rst:
--------------------------------------------------------------------------------
1 | Extending Thunder
2 | #################
3 |
4 | ..
5 | TODO RC1: update using the extend API
6 |
7 | This section describes how to add an executor to Thunder for a PyTorch operation.
8 |
9 | First, define a Python function with the same signature as the targeted operation, and have it call your implementation. For example, the Apex executor for ``torch.nn.functional.cross_entropy`` might define its implementation like::
10 |
11 | import torch
12 | import xentropy_cuda
13 |
14 | def apex_xentropy(
15 | a: torch.Tensor, # a is an actual PyTorch tensor
16 | target,
17 | weight=None,
18 | size_average=None,
19 | ignore_index=-100,
20 | reduce=None,
21 | reduction="mean",
22 | label_smoothing=0.0,
23 | ):
24 | half_to_float = a.dtype == torch.float16  # illustrative assumption: return float32 losses for half inputs
25 | losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float)
26 | return losses  # reduction handling omitted for brevity
25 |
26 | When this implementation is used it will be called with actual PyTorch tensors, and not with proxies.
27 |
28 | Next, define a “checker” function with the same signature as the targeted operation that returns ``True`` if your implementation can execute the targeted operation and ``False`` otherwise. Unlike implementations, checkers are called with proxies rather than actual PyTorch tensors, because they're called at optimization time. The purpose of a checker function is to let an executor target only specific inputs to an operation and defer to another executor on other inputs.
29 |
30 | A checker function for the Apex executor might look like::
31 |
32 | from thunder.core.proxies import TensorProxy
33 |
34 | def apex_xentropy_checker(
35 | a: TensorProxy, # a is a proxy
36 | target,
37 | weight=None,
38 | size_average=None,
39 | ignore_index=-100,
40 | reduce=None,
41 | reduction="mean",
42 | label_smoothing=0.0,
43 | ):
44 | # Apex's xentropy only supports "sum", "mean" or "none" reductions
45 | if reduction not in ["sum", "mean", "none"]:
46 | return False
47 |
48 | return True
49 |
50 | Create a mapping from the name of the PyTorch operation to your replacement implementation's name, its checker, and its implementation::
51 |
52 | _op_to_xentropy = {
53 | "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy),
54 | }
55 |
56 | Then define a registration function that practitioners can call to access your executor::
57 |
58 | def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None:
59 | from thunder.executors import add_operator_executor
60 |
61 | return add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors)
62 |
63 | You can test your executor by registering it, compiling a function that calls the targeted operator, and then verifying that your operation is called (by inspecting the execution trace) and producing the correct output. A good example of this is the tests for the Apex executor.
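
For instance, a minimal sketch of such a test (using the names defined above; the inputs and the exact trace contents are illustrative)::

    import torch
    import thunder

    register_apex_xentropyex()

    def foo(logits, labels):
        return torch.nn.functional.cross_entropy(logits, labels)

    jitted_foo = thunder.jit(foo)

    logits = torch.randn(16, 10, device="cuda")
    labels = torch.randint(0, 10, (16,), device="cuda")
    result = jitted_foo(logits, labels)

    # the replacement implementation should appear in the execution trace ...
    assert "apex_xentropy" in str(thunder.last_traces(jitted_foo)[-1])
    # ... and produce the same result as eager PyTorch
    torch.testing.assert_close(result, foo(logits, labels))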
64 |
--------------------------------------------------------------------------------
/docs/source/basic/sharp_edges.rst:
--------------------------------------------------------------------------------
1 | The sharp edges
2 | ###############
3 |
4 | This section describes features in the language that are not (yet) supported by Thunder, along with their workarounds. You might encounter these when compiling a module or function with Thunder. The examples should give you a good idea of how to change your program so that it works.
5 |
6 | Note that the fact that something is not supported today doesn't mean it won't be supported at some point in the future. Feel free to reach out to help us prioritize.
7 |
8 | Inplace operations
9 | ------------------
10 |
11 | Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon.
12 |
13 |
14 | Tensor subclasses
15 | -----------------
16 |
17 | Thunder currently supports Python data types and PyTorch tensors as inputs of functions and models.
18 |
19 | Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today.
20 |
21 |
22 | Tracing Python builtins, standard library operations and functions that call other languages
23 | --------------------------------------------------------------------------------------------
24 |
25 | Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed:
26 |
27 | 1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement.
28 | 2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>`
29 | 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will.
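
For example, a minimal sketch of how rules 1 and 3 can surprise you::

    import random

    import torch
    import thunder

    def foo(a):
        print("tracing!")        # side effect: runs while tracing, but won't appear in the trace
        scale = random.random()  # non-Thunder operation: the sampled value is baked into the trace
        return a * scale

    jitted_foo = thunder.jit(foo)

    a = torch.ones(2, 2)
    jitted_foo(a)  # prints "tracing!" and samples scale once
    jitted_foo(a)  # cached execution: no print, and the same baked-in scale is reused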
30 |
31 | ..
32 | Certain op-level behavior
33 | -------------------------
34 | 1. Ops which have not yet been added to Thunder. Please let us know if there’s missing operator support you would like to see and we will be happy to help.
35 | 2. Data dependent control flow (e.g. ``if x.any()``). Since Thunder generates traces of programs ahead of the actual execution, control flow depending on the values of tensors as opposed to their metadata cannot be handled by Thunder.
36 |
37 |
38 | Using Thunder-optimized Modules
39 | -------------------------------
40 |
41 | Compiling a module produces a "Thunder-optimized module". A Thunder-optimized module is less dynamic than the original module, which facilitates tracing and optimization. It holds a reference to the original module and shares its parameters with it.
42 |
43 | While modifying the original model's parameters will reflect in the Thunder-optimized module, other changes to the original module will not. In particular:
44 |
45 | - Whether the model is in ``train`` or ``eval`` mode is captured at compilation time and treated as constant (see the sketch at the end of this section)
46 | - The structure of the module is captured at compilation time, and changing the original module's structure will likely break the Thunder-optimized module
47 | - Non-parameter attributes of the module may or may not be captured at compile time and treated as constants
48 |
49 | Not all features of PyTorch modules are currently supported, either. Module hooks are not supported, and adding new module attributes in a module's ``forward()`` method is only partially supported.
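
As a sketch of the train/eval pitfall noted above (assuming, as stated, that the mode is captured at compilation time)::

    import torch
    import thunder

    model = torch.nn.Dropout(p=0.5)
    model.eval()

    jitted_model = thunder.jit(model)
    x = torch.ones(4)
    jitted_model(x)  # traced in eval mode: dropout is a no-op

    model.train()    # flipping the original module's mode ...
    jitted_model(x)  # ... is not reflected; this still runs the eval-mode trace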
50 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/examine.rst:
--------------------------------------------------------------------------------
1 | Using Examine
2 | #############
3 |
4 | We recommend using Thunder's ``examine()`` before compiling a function or a module.
5 |
6 | Thunder cannot run every PyTorch module, but you can quickly find out what is missing using ``examine()``.
7 |
8 | ``examine()`` not only determines if Thunder can compile the module, but provides a helpful report to use when filing an issue requesting support.
9 |
10 | You can run examine like this::
11 |
12 | from thunder.examine import examine
13 |
14 | model = MyModel(...)
15 | examine(model, *args, **kwargs)
16 |
17 | Here ``*args`` and ``**kwargs`` are valid inputs to the model. If examine determines that Thunder can run the module or function as expected, it will print::
18 |
19 | The function appears to be working as expected
20 |
21 | When ``examine`` encounters a module or function with one or more operators it doesn't support, it will specify the operators, like this::
22 |
23 | import torch
24 | import thunder
25 | from thunder.examine import examine
26 |
27 | def foo(a):
28 | return torch.triu(a)
29 |
30 | a = torch.full((2, 2), 1., device='cuda')
31 | examine(foo, a)
32 |
33 | Running the above will print::
34 |
35 | Found 1 distinct operations, of which 0 (0.0%) are supported
36 | Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
37 | _VariableFunctionsClass.triu of torch
38 |
39 | To recap, ``examine()`` lets you know if Thunder can run a module, and if it can't it will provide a report to file an issue asking for support.
40 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/hello_world.rst:
--------------------------------------------------------------------------------
1 | Hello World
2 | ###########
3 |
4 | Here is a simple example of how Thunder lets you compile and run PyTorch modules and functions::
5 |
6 | import torch
7 | import thunder
8 |
9 | def foo(a, b):
10 | return a + b
11 |
12 | jitted_foo = thunder.jit(foo)
13 |
14 | a = torch.full((2, 2), 1)
15 | b = torch.full((2, 2), 3)
16 |
17 | result = jitted_foo(a, b)
18 |
19 | print(result)
20 |
21 | # prints
22 | # tensor(
23 | # [[4, 4],
24 | # [4, 4]])
25 |
26 | The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of bigger PyTorch programs.
27 |
28 | Thunder currently understands a subset of PyTorch operations and a subset of Python, but support is expanding quickly. If an operator is not yet supported, open an issue on the Thunder repo to get help.
29 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/installation.rst:
--------------------------------------------------------------------------------
1 | Install Lightning Thunder
2 | #########################
3 |
4 | Minimal dependencies
5 | ====================
6 |
7 | Follow these instructions to install PyTorch, nvFuser, and finally Thunder.
8 |
9 | Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.5.x)::
10 |
11 | pip install --pre nvfuser-cu121-torch25
12 |
13 | Replace ``cu121`` with ``cu118`` depending on your CUDA version. nvFuser builds typically support the latest point release of stable PyTorch versions.
14 | For torch 2.5, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation
15 |
16 | You're all set with minimal dependencies, so you can follow `Install Thunder`_.
17 |
18 | Dependencies so far don't include additional optimized kernels for execution like OpenAI Triton, cuDNN Fusion, or Apex.
19 | These are described in the following section, `Optional dependencies`_.
20 |
21 | Optional dependencies
22 | =====================
23 |
24 | Install Apex
25 | ------------
26 |
27 | Thunder can use NVIDIA's Apex to accelerate some PyTorch operations. To install the Apex executor, first clone the Apex git repository and then execute the following command in the project's root directory::
28 |
29 | git clone https://github.com/NVIDIA/apex.git
30 | cd apex
31 |
32 | pip install -v --no-cache-dir --no-build-isolation --config-settings "--build-option=--xentropy" ./
33 |
34 | Install cuDNN
35 | -------------
36 |
37 | Thunder can use NVIDIA's cuDNN Python frontend bindings to accelerate some PyTorch operations. The cuDNN backend is a pure-Python package which needs to be installed separately::
38 |
39 |     pip install nvidia-cudnn-cu12
40 |     pip install nvidia-cudnn-frontend
41 |
42 | You're all set, now follow `Install Thunder`_.
43 |
44 | Install OpenAI Triton
45 | ---------------------
46 |
47 | Thunder can easily integrate OpenAI Triton kernels. You can install Triton using::
48 |
49 | pip install triton
50 |
51 |
52 | Install Thunder
53 | ===============
54 |
55 | You can now install Thunder::
56 |
57 | pip install git+https://github.com/Lightning-AI/lightning-thunder.git
58 |
59 | Alternatively you can clone the Thunder repository and install locally::
60 |
61 | git clone https://github.com/Lightning-AI/lightning-thunder.git
62 | cd lightning-thunder
63 |
64 | pip install .
65 |
--------------------------------------------------------------------------------
/docs/source/intermediate/additional_executors.rst:
--------------------------------------------------------------------------------
1 | Additional executors
2 | ####################
3 |
4 | nvFuser and PyTorch are not the only executors available in Thunder today. Additional executors can be added to Thunder prior to compilation through a registration mechanism, which makes it easy to have specialized executors perform certain operations more efficiently.
5 |
6 | This section lists the executors supported by Thunder beyond nvFuser and PyTorch.
7 |
8 | Triton CrossEntropy Executor
9 | ============================
10 |
11 | The Triton CrossEntropy executor can execute ``torch.nn.functional.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used as in the following example::
12 |
13 | import torch
14 | import thunder
15 | from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex
16 |
17 | def xentropy(logits, labels, weight, reduction, ignore_index):
18 | return thunder.torch.cross_entropy(
19 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index
20 | )
21 |
22 | jitted_xentropy = thunder.jit(
23 | xentropy,
24 | executors=[triton_cross_entropy_ex,]
25 | )
26 |
27 | device = 'cuda'
28 | dtype = torch.float32
29 |
30 | logits = torch.randn([2048, 50257], device=device, dtype=dtype)
31 | labels = torch.randint(0, 50257, [2048], device=device)
32 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False)
33 | reduction = "sum"
34 | ignore_index = labels[5].item()
35 |
36 | jitted_xentropy(logits, labels, weight, reduction, ignore_index)
37 | traces = thunder.last_traces(jitted_xentropy)
38 | print(traces[-1])
39 |
40 | This prints::
41 |
42 | # Constructed by Delete Last Used (took 0 milliseconds)
43 | import torch
44 | from thunder.executors.torchex import no_autocast
45 |
46 | @torch.no_grad()
47 | @no_autocast
48 | def computation(logits, labels, weight):
49 | # logits: "cuda:0 f32[2048, 50257]"
50 | # labels: "cuda:0 i64[2048]"
51 | # weight: "cuda:0 f32[50257]"
52 | t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]"
53 | del logits, labels, weight
54 | return t23
55 |
56 | As shown in the above trace, ``triton_crossentropy()`` is the one running the operation.
57 |
58 | Apex CrossEntropy Executor
59 | ==========================
60 |
61 | The Apex CrossEntropy executor can execute ``torch.nn.functional.cross_entropy()`` through an optimized kernel, like this::
62 |
63 | import torch
64 | import thunder
65 | from thunder.executors.apexex import apex_ex
66 |
67 | def xentropy(logits, labels):
68 | return thunder.torch.cross_entropy(
69 | logits, labels, reduction='mean', ignore_index=-1
70 | )
71 |
72 | jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,])
73 |
74 | device = 'cuda'
75 | dtype = torch.float32
76 |
77 | logits = torch.randn([2048, 50257], device=device, dtype=dtype)
78 | labels = torch.randint(0, 50257, [2048], device=device)
79 |
80 | jitted_xentropy(logits, labels)
81 | traces = thunder.last_traces(jitted_xentropy)
82 | print(traces[-1])
83 |
84 | This prints::
85 |
86 | # Constructed by Delete Last Used (took 0 milliseconds)
87 | import torch
88 | from thunder.executors.torchex import no_autocast
89 |
90 | @torch.no_grad()
91 | @no_autocast
92 | def computation(logits, labels):
93 | # logits: "cuda:0 f32[2048, 50257]"
94 | # labels: "cuda:0 i64[2048]"
95 | (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0)
96 | del logits, labels
97 | return t18
98 |
99 | showing that Apex is running the operation.
100 |
101 | cuDNN SDPA Executor
102 | ===================
103 |
104 | TODO RC1
105 |
106 | TransformerEngine Executor
107 | ==========================
108 |
109 | TODO RC1
110 |
--------------------------------------------------------------------------------
/docs/source/intermediate/ddp.rst:
--------------------------------------------------------------------------------
1 | Distributed Data Parallel (DDP)
2 | ###############################
3 |
4 | Thunder has its own Distributed Data Parallel (DDP) transform that we recommend using, although compiled modules also work with PyTorch's DDP transform.
5 |
6 | You can wrap a model in Thunder's ddp like this::
7 |
8 | from thunder.distributed import ddp
9 |
10 | model = MyModel()
11 | ddp_model = ddp(model)
12 | cmodel = thunder.jit(ddp_model)
13 |
14 | Specifying which rank to broadcast from is optional. ``ddp()`` broadcasts from the lowest rank in the process group if ``broadcast_from`` is not specified.
15 |
16 | Thunder's ddp is compatible with PyTorch distributed runners like ``torchrun`` (https://pytorch.org/docs/stable/elastic/run.html).
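
For example, a minimal sketch of a script intended to be launched with ``torchrun`` (the script name, model, and sizes here are stand-ins)::

    import torch
    import torch.distributed as dist
    import thunder
    from thunder.distributed import ddp

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()  # single-node assumption: global rank == local rank
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(16, 16, device=f"cuda:{rank}")
    jitted_model = thunder.jit(ddp(model))

    out = jitted_model(torch.randn(4, 16, device=f"cuda:{rank}"))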
17 |
18 | When using PyTorch's DDP, call DDP on the jitted module::
19 |
20 | from torch.nn.parallel import DistributedDataParallel as DDP
21 |
22 | model = MyModel()
23 | jitted_model = thunder.jit(model)
24 | ddp_model = DDP(jitted_model)
25 |
26 | The ability of Thunder to express distributed algorithms like DDP as a simple transform on the trace is one of Thunder's strengths and is being leveraged to quickly implement more elaborate distributed strategies, like Fully Sharded Data Parallel (FSDP).
27 |
--------------------------------------------------------------------------------
/docs/source/intermediate/whats_next.rst:
--------------------------------------------------------------------------------
1 | What's Next
2 | ###########
3 |
4 | Thunder is developing rapidly, and this section mentions some of what's happening. Please reach out (see Get Involved) if you're interested in one of these topics.
5 |
6 | Compiling the Training Loop
7 | ===========================
8 |
9 | Thunder currently supports compiling PyTorch modules (forward computation, loss calculation, and backward computation), but we plan to support compiling the entire training loop, including the optimizer step, for maximum performance.
10 |
11 | Dynamic Caching
12 | ===============
13 |
14 | Thunder currently supports either no caching or static caching, and static caching requires recompiling whenever a module is called with inputs whose metadata differs from past inputs. This can be overly strict. For example, adding two tensors with shape ``(5, 5)`` is essentially the same program as adding two tensors with shape ``(10, 10)``. Dynamic caching will determine whether the new metadata would actually produce a new trace, significantly reducing compilation time when training some models.
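
A sketch of the recompilation behavior that dynamic caching aims to avoid (the calls and counts here are illustrative)::

    import torch
    import thunder

    def foo(a, b):
        return a + b

    jitted_foo = thunder.jit(foo)

    jitted_foo(torch.ones(5, 5), torch.ones(5, 5))      # compiles a trace for (5, 5) inputs
    jitted_foo(torch.ones(5, 5), torch.ones(5, 5))      # same metadata: cache hit
    jitted_foo(torch.ones(10, 10), torch.ones(10, 10))  # new shapes: recompiles under static caching

    print(thunder.cache_misses(jitted_foo))  # expected 2 under static caching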
15 |
16 | Memory Layouts and Strides
17 | ==========================
18 |
19 | Thunder does not currently model any stride information on tensor proxies. In the future we will likely model some stride information, like memory layout (e.g. channels-last), to support integration with PyTorch programs that use memory layout, and to let executors use memory layout to inform kernel selection.
20 |
21 | Functional transforms: vmap and AMP
22 | ===================================
23 |
24 | Thunder already has early implementations of JAX's vmap transform and PyTorch's Automatic Mixed Precision (AMP) autocasting, and we're extending our support for these transforms so practitioners can easily apply a variety of composable transforms to PyTorch modules.
25 |
--------------------------------------------------------------------------------
/docs/source/reference/clang/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.clang
2 |
3 | thunder.clang
4 | =============
5 |
6 | Thunder Core Language
7 |
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
12 | maybe_convert_to_dtype
13 | device_put
14 | arange
15 | convolution
16 | full
17 | full_like
18 | uniform
19 | uniform_like
20 | diagonal
21 | expand
22 | flatten
23 | movedim
24 | reshape
25 | slice_in_dim
26 | squeeze
27 | transpose
28 | stride_order
29 | take
30 | index_add
31 | take_along_axis
32 | scatter_add
33 | unsqueeze
34 | cat
35 | stack
36 | compute_broadcast_shape
37 | matrix_transpose
38 | maybe_broadcast
39 |
40 | Unary
41 | ~~~~~
42 |
43 | .. autosummary::
44 | :toctree: generated/
45 |
46 | abs
47 | acos
48 | acosh
49 | asin
50 | asinh
51 | atan
52 | atanh
53 | bitwise_not
54 | ceil
55 | cos
56 | cosh
57 | erf
58 | erfc
59 | erfcinv
60 | erfinv
61 | exp
62 | exp2
63 | expm1
64 | floor
65 | isfinite
66 | lgamma
67 | log
68 | log10
69 | log1p
70 | log2
71 | ndtri
72 | neg
73 | reciprocal
74 | round
75 | rsqrt
76 | sigmoid
77 | sign
78 | signbit
79 | silu
80 | sin
81 | sinh
82 | sqrt
83 | tan
84 | tanh
85 | trunc
86 |
87 | Binary
88 | ~~~~~~
89 |
90 | .. autosummary::
91 | :toctree: generated/
92 |
93 | add
94 | atan2
95 | bitwise_and
96 | bitwise_or
97 | bitwise_xor
98 | copysign
99 | eq
100 | floor_divide
101 | fmod
102 | mod
103 | ge
104 | gt
105 | logical_and
106 | le
107 | lt
108 | mul
109 | ne
110 | nextafter
111 | pow
112 | remainder
113 | sub
114 | true_divide
115 |
116 | Conditional
117 | ~~~~~~~~~~~
118 |
119 | .. autosummary::
120 | :toctree: generated/
121 |
122 | where
123 |
--------------------------------------------------------------------------------
/docs/source/reference/common/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.common
2 |
3 | thunder.common
4 | ==============
5 |
6 | Common functions and classes for Thunder.
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | CACHE_OPTIONS
12 | CompileData
13 | CompileStats
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/baseutils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.baseutils
2 | :noindex:
3 |
4 | Baseutils
5 | ---------
6 |
7 | .. currentmodule:: thunder.core.baseutils
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
12 | check
13 | check_type
14 | ProxyInterface
15 | NumberProxyInterface
16 | TensorProxyInterface
17 | SymbolInterface
18 | BoundSymbolInterface
19 |
--------------------------------------------------------------------------------
/docs/source/reference/core/codeutils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.codeutils
2 |
3 | Codeutils
4 | ---------
5 |
6 | .. currentmodule:: thunder.core.codeutils
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | prettyprint
12 | SigInfo
13 | get_siginfo
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/devices.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.devices
2 |
3 | Devices
4 | -------
5 |
6 | .. currentmodule:: thunder.core.devices
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | DeviceType
12 | devicetype_string
13 | Device
14 | device_from_string
15 | to_device
16 |
--------------------------------------------------------------------------------
/docs/source/reference/core/dtypes.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.dtypes
2 |
3 | Dtypes
4 | ------
5 |
6 | ``dtype``\s such as ``float32`` are exposed on :mod:`thunder`, e.g. as ``thunder.float32``.
7 |
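For example, a minimal sketch (assuming ``to_dtype`` accepts PyTorch dtypes)::

    import torch
    import thunder
    from thunder.core.dtypes import to_dtype

    print(thunder.float32)          # the dtype object exposed on the thunder namespace
    print(to_dtype(torch.float32))  # convert a torch dtype to its thunder equivalent
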
8 | .. currentmodule:: thunder.core.dtypes
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 | :nosignatures:
13 |
14 | dtype
15 | to_dtype
16 | float32
17 | float16
18 | bfloat16
19 | float64
20 | int8
21 | int16
22 | int32
23 | int64
24 | uint8
25 | bool8
26 | complex32
27 | complex64
28 | complex128
29 |
--------------------------------------------------------------------------------
/docs/source/reference/core/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core
2 |
3 | thunder.core
4 | ============
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | baseutils
10 | codeutils
11 | devices
12 | dtypes
13 | langctxs
14 | prims
15 | proxies
16 | pytree
17 | rematerialization
18 | symbol
19 | trace
20 | transforms
21 |
--------------------------------------------------------------------------------
/docs/source/reference/core/langctxs.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.langctxs
2 |
3 | langctxs
4 | --------
5 |
6 | .. currentmodule:: thunder.core.langctxs
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | set_langctx
12 | get_langctx
13 | reset_langctx
14 | langctx
15 |
--------------------------------------------------------------------------------
/docs/source/reference/core/prims.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.prims
2 |
3 | .. todo Add other prims after resolving "Cannot resolve forward reference in type annotations of "thunder.core.prims.all_reduce": name 'Callable' is not defined"
4 |
5 | Prims
6 | -----
7 |
8 | .. currentmodule:: thunder.core.prims
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | PrimIDs
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/proxies.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.proxies
2 |
3 | Proxies
4 | -------
5 |
6 | .. currentmodule:: thunder.core.proxies
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Variable
12 | variableify
13 | unvariableify
14 | CollectionProxy
15 | pyval
16 | pytype
17 | ComplexProxy
18 | IntegerProxy
19 | FloatProxy
20 | FutureTensorProxy
21 | TensorProxy
22 | is_proxyable
23 | proxy
24 |
--------------------------------------------------------------------------------
/docs/source/reference/core/pytree.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.pytree
2 |
3 | PyTree
4 | ------
5 |
6 | .. list-table:: pytree API map
7 | :widths: 50 50
8 | :header-rows: 1
9 |
10 | * - thunder.core.pytree
11 | - pytorch pytree
12 | * - :func:`tree_map`
13 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map
14 | * - :func:`tree_flatten`
15 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_flatten
16 | * - :func:`tree_unflatten`
17 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_unflatten
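
For example, a minimal sketch of the optree-style API mapped above::

    from thunder.core.pytree import tree_flatten, tree_map

    nested = {"a": [1, 2], "b": (3,)}

    doubled = tree_map(lambda x: x * 2, nested)  # {"a": [2, 4], "b": (6,)}
    leaves, spec = tree_flatten(nested)          # leaves: [1, 2, 3]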
18 |
--------------------------------------------------------------------------------
/docs/source/reference/core/rematerialization.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.rematerialization
2 |
3 | Rematerialization
4 | -----------------
5 |
6 | .. currentmodule:: thunder.core.rematerialization
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | rematerialize
12 |
--------------------------------------------------------------------------------
/docs/source/reference/core/symbol.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.symbol
2 |
3 | Symbol
4 | ------
5 |
6 | .. currentmodule:: thunder.core.symbol
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Symbol
12 | BoundSymbol
13 | BoundSymbolRHS
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/trace.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.trace
2 |
3 | Trace
4 | -----
5 |
6 | .. currentmodule:: thunder.core.trace
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | TraceCtx
12 | from_trace
13 | set_tracectx
14 | get_tracectx
15 | reset_tracectx
16 |
--------------------------------------------------------------------------------
/docs/source/reference/core/transforms.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.transforms
2 |
3 | Transforms
4 | ----------
5 |
6 | .. currentmodule:: thunder.core.transforms
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Node
12 | bsym_list_to_dag
13 | toposort_bsym_dag
14 | insert_inplace
15 | replace_inplace
16 | VISIT_TYPE
17 | visitor_transform
18 |
--------------------------------------------------------------------------------
/docs/source/reference/distributed/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.distributed
2 |
3 | thunder.distributed
4 | ===================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | ddp
10 | fsdp
11 | FSDPType
12 | FSDPBucketingStrategy
13 | set_skip_data_parallel_grad_sync
14 | reset_skip_data_parallel_grad_sync
15 | get_skip_data_parallel_grad_sync
16 | skip_data_parallel_grad_sync
17 | column_parallel
18 | row_parallel
19 |
--------------------------------------------------------------------------------
/docs/source/reference/dynamo/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.dynamo
2 |
3 | thunder.dynamo
4 | ==============
5 |
6 | .. autosummary::
7 | :toctree:
8 |
9 | ThunderCompiler
10 | thunderfx
11 | ThunderFXCompiledObject
12 |
--------------------------------------------------------------------------------
/docs/source/reference/examine/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.examine
2 |
3 | thunder.examine
4 | ===============
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | examine
10 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/apexex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.apexex
2 |
3 | APEX Executor
4 | -------------
5 |
6 | Executor of `NVIDIA/apex <https://github.com/NVIDIA/apex>`_.
7 |
8 | .. currentmodule:: thunder.executors.apexex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | apex_entropy_available
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/cudnnex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.cudnnex
2 |
3 | cuDNN Executor
4 | --------------
5 |
6 | Executor of `NVIDIA/cudnn-frontend <https://github.com/NVIDIA/cudnn-frontend>`_.
7 |
8 | .. currentmodule:: thunder.executors.cudnnex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | cudnn_ex
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors
2 |
3 | thunder.executors
4 | =================
5 |
6 | .. currentmodule:: thunder.executors
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | get_nvfuser_executor
12 | get_torch_executor
13 | nvfuser_available
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 |
18 | apexex
19 | cudnnex
20 | nvfuserex
21 | passes
22 | pythonex
23 | torch_compile
24 | torchex
25 | triton_crossentropy
26 | utils
27 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/nvfuserex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.nvfuserex
2 |
3 | nvFuser Executor
4 | ----------------
5 |
6 | An executor that dispatches operators to `NVIDIA/Fuser <https://github.com/NVIDIA/Fuser>`_.
7 |
8 | .. currentmodule:: thunder.executors.nvfuserex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | .. fixme(crcrpar) add methods and classes
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/passes.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.passes
2 |
3 | Optimization Passes
4 | -------------------
5 |
6 | .. currentmodule:: thunder.executors.passes
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | transform_for_execution
12 | update_fusion_call_ctx
13 | del_last_used
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/pythonex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.pythonex
2 |
3 |
4 | Python Executor
5 | ---------------
6 |
7 | .. currentmodule:: thunder.executors.pythonex
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/torch_compile.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.torch_compile
2 |
3 | Torch Compile Executor
4 | ----------------------
5 |
6 | .. currentmodule:: thunder.executors.torch_compile
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | make_compiled
12 | TorchCompileExecutor
13 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/torchex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.torchex
2 |
3 |
4 | PyTorch Executor
5 | ----------------
6 |
7 | .. currentmodule:: thunder.executors.torch_autograd
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/triton_crossentropy.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.triton_crossentropy
2 |
3 | Triton Executor
4 | ---------------
5 |
6 | Executor of `openai/triton <https://github.com/openai/triton>`_.
7 |
8 | .. currentmodule:: thunder.executors.triton_crossentropy
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/utils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.utils
2 |
3 | Utils
4 | -----
5 |
6 | .. currentmodule:: thunder.executors.utils
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Region
12 |
--------------------------------------------------------------------------------
/docs/source/reference/extend/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.extend
2 |
3 | thunder.extend
4 | ==============
5 |
6 | .. autosummary::
7 | :toctree:
8 |
9 | register_executor
10 | deregister_executor
11 | get_all_executors
12 | get_default_executors
13 | get_always_executors
14 | get_executor
15 | set_default_executors
16 | set_always_executors
17 | add_default_executor
18 | add_always_executor
19 | remove_default_executor
20 | remove_always_executor
21 | Executor
22 |
23 |
--------------------------------------------------------------------------------
/docs/source/reference/plugins/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.plugins
2 |
3 | thunder.plugins
4 | ==================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | DDP
10 | FSDP
11 | QuantizeInt4
12 | FP8
13 | ReduceOverhead
14 |
--------------------------------------------------------------------------------
/docs/source/reference/recipes/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.recipes
2 |
3 | thunder.recipes
4 | ==================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | BaseRecipe
10 | HFTransformers
11 |
--------------------------------------------------------------------------------
/docs/source/reference/thunder.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder
2 |
3 | thunder
4 | =======
5 |
6 |
7 | Compiling functions and modules
8 | -------------------------------
9 |
10 |
11 | .. autosummary::
12 | :toctree: generated/
13 |
14 | jit
15 |
16 |
17 | Querying information on compiled functions and modules
18 | ------------------------------------------------------
19 |
20 |
21 | .. autosummary::
22 | :toctree: generated/
23 |
24 | DebugOptions
25 | compile_data
26 | compile_stats
27 | last_traces
28 | last_backward_traces
29 | last_prologue_traces
30 | cache_option
31 | cache_hits
32 | cache_misses
33 | list_transforms
34 | last_interpreted_instructions
35 | last_interpreter_log
36 | last_compile_options
37 | ..
38 | compile
39 | grad
40 |
41 | JITed Model wrapper
42 | -------------------
43 |
44 | .. autoclass:: ThunderModule
45 | :members: no_sync
46 | :exclude-members: forward
47 |
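
A minimal sketch of the query workflow on a jitted callable, assuming each query helper takes the jitted function; the toy function is illustrative, and the `last_traces` usage mirrors examples/quickstart/mlp.py:

    import torch
    import thunder

    def fn(x):
        return torch.nn.functional.relu(x)

    jfn = thunder.jit(fn)
    jfn(torch.randn(4))

    print(thunder.last_traces(jfn)[-1])  # final computation trace
    print(thunder.cache_hits(jfn), thunder.cache_misses(jfn))
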
--------------------------------------------------------------------------------
/docs/source/reference/torch/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.torch
2 |
3 | thunder.torch
4 | -------------
5 |
6 | A PyTorch dialect in thunder.
7 |
8 | .. currentmodule:: thunder.torch
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | torchsymbol
14 | size
15 | to_float
16 | _register_custom_op
17 |
18 | Under Construction
19 |
20 | .. fixme "Cannot resolve forward reference in type annotations of "thunder.torch.abs": name 'Callable' is not defined"
21 |
22 | Operators
23 | ~~~~~~~~~
24 |
25 | Unary
26 | ~~~~~
27 |
28 | Binary
29 | ~~~~~~
30 |
31 | Conditional
32 | ~~~~~~~~~~~
33 |
34 | Tensor Creation
35 | ~~~~~~~~~~~~~~~
36 |
37 | Shape Operation
38 | ~~~~~~~~~~~~~~~
39 |
--------------------------------------------------------------------------------
/docs/source/reference/transforms/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.transforms
2 |
3 | thunder.transforms
4 | ==================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | MaterializationTransform
10 | ConstantFolding
11 |
--------------------------------------------------------------------------------
/examples/coverage/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.52.4
2 | accelerate
3 | nvfuser-cu128-torch27
4 | nvidia-cudnn-frontend
5 | einops
6 | tiktoken
7 |
--------------------------------------------------------------------------------
/examples/quickstart/hf_bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | # model_name = "bert-large-uncased"
11 | model_name = "bert-base-uncased"
12 |
13 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
14 |
15 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
16 |
17 | with torch.device(device):
18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
19 | model.requires_grad_(False)
20 | model.eval()
21 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
22 | model.to(device)
23 |
24 | inp = tokenizer(["Hello world!"], return_tensors="pt")
25 |
26 | print(f"Eager: {benchmark(model, **inp):.2f}ms")
27 |
28 | thunder_model = thunder.compile(
29 | model,
30 | recipe="hf-transformers",
31 | )
32 |
33 | print(f"Thunder: {benchmark(thunder_model, **inp):.2f}ms")
34 |
35 | if torch.cuda.is_available():
36 | thunder_model = thunder.compile(model, plugins="reduce-overhead")
37 |
38 | print(f"Thunder with 'reduce-overhead': {benchmark(thunder_model, **inp):.2f}ms")
39 |
40 |
41 | if __name__ == "__main__":
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 |
44 | main()
45 |
--------------------------------------------------------------------------------
/examples/quickstart/hf_llm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark_n
7 |
8 |
9 | def main():
10 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
11 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
12 | # model_name = "meta-llama/Llama-3.1-8B"
13 | model_name = "meta-llama/Llama-3.2-1B"
14 | # model_name = "Qwen/Qwen2.5-7B-Instruct"
15 | # model_name = "microsoft/phi-4"
16 |
17 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
18 |
19 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
20 |
21 | with torch.device(device):
22 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
23 | model.requires_grad_(False)
24 | model.eval()
25 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
26 | model.to(device)
27 |
28 | inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt")
29 |
30 | def generate(model, inp, cache=None):
31 | out = model.generate(**inp, do_sample=False, cache_implementation=cache, max_new_tokens=100)
32 | print(tokenizer.decode(out[0].tolist()))
33 |
34 | print("\nGenerating with PyTorch eager:")
35 | eager_time = benchmark_n(2, generate, model, inp)
36 |
37 | thunder_model = thunder.compile(
38 | model,
39 | recipe="hf-transformers",
40 | )
41 |
42 | print("\nGenerating with Thunder:")
43 | thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static")
44 |
45 | print(f"\nEager: {eager_time:.2f}ms")
46 | print(f"Thunder: {thunder_time:.2f}ms")
47 |
48 |
49 | if __name__ == "__main__":
50 | torch.backends.cuda.matmul.allow_tf32 = True
51 |
52 | main()
53 |
--------------------------------------------------------------------------------
/examples/quickstart/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import thunder
5 |
6 |
7 | def main():
8 | model = nn.Sequential(
9 | nn.Linear(2048, 4096),
10 | nn.ReLU(),
11 | nn.Linear(4096, 64)
12 | )
13 |
14 | thunder_model = thunder.compile(model)
15 | x = torch.randn(64, 2048)
16 | y = thunder_model(x)
17 |
18 | print(thunder_model)
19 |
20 | print(thunder.last_traces(thunder_model)[-1])
21 |
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
/examples/quickstart/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.55.4
2 | accelerate
3 | nvfuser-cu128-torch27
4 | nvidia-cudnn-frontend
5 |
--------------------------------------------------------------------------------
/examples/quickstart/vit_hf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import ViTForImageClassification
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
11 |
12 | with torch.device(device):
13 | model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float32)
14 | model.requires_grad_(False)
15 | model.eval()
16 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
17 | model.to(device)
18 |
19 | inp = torch.randn(128, 3, 224, 224)
20 |
21 | out = model(inp)
22 |
23 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None)
24 |
25 | thunder_out = thunder_model(inp)
26 |
27 | torch.testing.assert_close(out, thunder_out, atol=1e-2, rtol=1e-2)
28 |
29 | print(f"Eager: {benchmark(model, inp):.2f}ms")
30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms")
31 |
32 |
33 | if __name__ == "__main__":
34 | torch.set_float32_matmul_precision('high')
35 |
36 | main()
37 |
--------------------------------------------------------------------------------
/examples/quickstart/vit_tv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models as models
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
11 |
12 | with torch.device(device):
13 | model = models.vit_b_16()
14 | model.requires_grad_(False)
15 | model.eval()
16 |
17 | inp = torch.randn(128, 3, 224, 224)
18 |
19 | out = model(inp)
20 |
21 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None)
22 |
23 | thunder_out = thunder_model(inp)
24 |
25 | # print(thunder.last_traces(thunder_model)[-1])
26 |
27 | torch.testing.assert_close(out, thunder_out)
28 |
29 | print(f"Eager: {benchmark(model, inp):.2f}ms")
30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms")
31 |
32 |
33 | if __name__ == "__main__":
34 | torch.set_float32_matmul_precision('high')
35 |
36 | main()
37 |
--------------------------------------------------------------------------------
/notebooks/.ignore.ci:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/notebooks/.ignore.ci
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/base.txt
2 |
--------------------------------------------------------------------------------
/requirements/base.txt:
--------------------------------------------------------------------------------
1 | torch >=2.7.1
2 | looseversion ==1.3.0
3 | lightning-utilities >0.7.0
4 | numpy
5 | networkx >=3.3
6 | optree >=0.12.1
7 | opt_einsum >=3.3.0
 8 | # todo: temporary pin for `NameError: name '_C' is not defined`
9 | mpmath <1.4.0
10 | # Support for 3.12
11 | dill >=0.3.8
12 |
--------------------------------------------------------------------------------
/requirements/coverage.txt:
--------------------------------------------------------------------------------
1 | coverage ~=7.9.1
2 | pytest ==8.4.2
3 | pytest-cov ==6.2.1
4 | pytest-benchmark ==5.1.0
5 | transformers ==4.52.4
6 | lightning_sdk
7 | diffusers==0.35.1
8 | accelerate
9 | bitsandbytes==0.48.0
10 |
--------------------------------------------------------------------------------
/requirements/devel.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | -r base.txt
3 | -r test.txt
4 |
5 | datasets>=3.5.0
6 | peft>=0.15.2
7 |
8 | torchvision
9 | torchaudio
10 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | sphinx ==5.3.0
2 | myst-parser ==1.0.0
3 | nbsphinx ~=0.9.7
4 | ipython[all] ~=8.37.0
5 | pandoc ==2.4
6 | docutils >=0.16
7 | sphinxcontrib-fulltoc @ git+https://github.com/t-vi/sphinx-contrib-fulltoc@master
8 | sphinxcontrib-mockautodoc
9 |
10 | sphinx-autodoc-typehints ==1.23.0
11 | sphinx-paramlinks ==0.6.0
12 | sphinx-togglebutton ==0.3.2
13 | sphinx-copybutton ==0.5.2
14 | snowballstemmer < 4
15 |
16 | # installed from S3 location and fetched in advance
17 | lai-sphinx-theme
18 | # alternative / backup (old) theme
19 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip
20 |
--------------------------------------------------------------------------------
/requirements/notebooks.txt:
--------------------------------------------------------------------------------
1 | # just reuse the base requirements
2 | -r base.txt
3 |
4 | ipython[all] ~=8.37.0
5 | numpy
6 | liger-kernel == 0.4.0
7 | cuda-python >=12.6.1, <14.0.0
8 | litgpt == 0.5.1
9 |
--------------------------------------------------------------------------------
/requirements/test.txt:
--------------------------------------------------------------------------------
1 | coverage ~=7.9.1
2 | pytest ==8.4.2
3 | pytest-benchmark ==5.1.0
4 | pytest-timeout ==2.4.0
5 | pytest-cov ==6.2.1
6 | pytest-xdist ==3.8.0
7 | pytest-random-order ==1.2.0
8 | pytest-timestamper ==0.0.10
9 | graphviz ==0.21
10 | fdm ==0.5.0
11 | expecttest ==0.3.0 # for test_ddp.py
12 | hypothesis ~=6.136.6 # for test_ddp.py
13 | numpy
14 | einops # for test_einops.py
15 | litgpt==0.5.0 # for the model definition in tests and benchmarks # todo: needs update to latest
16 | absl-py # thunder/benchmarks/test_benchmark_litgpt.py
17 | pandas # thunder/benchmarks/test_benchmark_litgpt.py
18 | xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
19 | jsonargparse # thunder/benchmarks/benchmark_litgpt.py
20 | bitsandbytes==0.48.0; 'arm' not in platform_machine and 'aarch' not in platform_machine
21 | bitsandbytes>=0.42,<0.43; 'arm' in platform_machine or 'aarch' in platform_machine
22 | transformers==4.52.4 # for test_networks.py
23 | diffusers==0.35.1 # for test_networks.py
24 | accelerate # for test_networks.py
25 |
26 | asvdb @ git+https://github.com/rapidsai/asvdb.git
27 | asv >=0.6.4
28 |
--------------------------------------------------------------------------------
/scripts/bisect_nvfuser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 |
5 |
6 | class Runner:
7 | def __init__(self, command_and_args: list[str], verbose: bool):
8 | self._command_and_args = command_and_args
9 | self._verbose = verbose
10 |
11 | def env_var_for_fuel(self) -> str:
12 | return "NVFUSER_OPTIMIZATION_FUEL"
13 |
14 | def run(self, fuel: int) -> int:
15 | os.environ[self.env_var_for_fuel()] = str(fuel)
16 | print(f"Running {self._command_and_args} with {fuel} units of fuel...")
17 | result = subprocess.run(
18 | self._command_and_args,
19 | check=False,
20 | stdout=(None if self._verbose else subprocess.DEVNULL),
21 | stderr=(None if self._verbose else subprocess.DEVNULL),
22 | )
23 | print(f"Command returned {result.returncode} with {fuel} units of fuel.")
24 | return result.returncode
25 |
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser(
29 | description="""
30 | A script that bisects nvFuser's optimization fuel to isolate a compiler bug.
31 |
32 | See `thunder.extend.FusionExecutor.get_fuel/set_fuel` for what optimization fuel is and how it is implemented.
33 |
34 | This script finds the minimum unit of optimization fuel between `lower_bound` and `upper_bound` that causes `command_and_args` to fail (i.e., exit with a non-zero status). The user can then run `NVFUSER_OPTIMIZATION_FUEL=<fuel> <command_and_args>` as a minimal reproducer to identify the nvFusion that triggers the failure. Likely, the last nvFusion generated is the culprit.
35 | """
36 | )
37 | parser.add_argument("lower_bound", type=int, help="the lower bound for bisecting")
38 | parser.add_argument("upper_bound", type=int, help="the upper bound for bisecting")
39 | parser.add_argument("command_and_args", type=str, nargs=argparse.REMAINDER, help="the command and args to run")
40 | parser.add_argument("--verbose", action="store_true", help="whether to show stdout/stderr of the subprocess")
41 | args = parser.parse_args()
42 |
43 | runner = Runner(args.command_and_args, args.verbose)
44 |
45 | # The corner cases are not needed for the correctness of the binary search. They catch errors earlier when the user specifies a wrong lower/upper bound.
46 | if (exitcode := runner.run(args.lower_bound)) != 0:
47 | print("No need to bisect. Command failed with the lower bound.")
48 | exit(0)
49 |
50 | if (exitcode := runner.run(args.upper_bound)) == 0:
51 | print("Bisecting failed. Command passed with the upper bound.")
52 | exit(1)
53 |
54 | # Find the smallest fuel that fails `command_and_args`.
55 | low = args.lower_bound + 1 # +1 because we know `lower_bound` passed.
56 | high = args.upper_bound
57 | while low < high:
58 | mid = (low + high) // 2
59 | exitcode = runner.run(mid)
60 | if exitcode == 0:
61 | low = mid + 1
62 | else:
63 | high = mid
64 | assert low == high
65 |
66 | print("Bisecting succeeded. Run the following command as a minimal reproducer:")
67 | print(f" {runner.env_var_for_fuel()}={low} {' '.join(args.command_and_args)}")
68 | print("The last nvFusion likely triggered the failure.")
69 |
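
A typical invocation (the bounds and the test command here are hypothetical) is `python scripts/bisect_nvfuser.py --verbose 0 1000 pytest thunder/tests/test_ops.py`; note that `--verbose` must precede the positional arguments, because everything after the bounds is captured by `argparse.REMAINDER` as the command to run.
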
--------------------------------------------------------------------------------
/scripts/remove-torch-lines.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Get the name of the file to process
4 | filename=$1
5 |
6 | # Create a temporary file to store the filtered lines
7 | tempfile=$(mktemp)
8 |
9 | # Loop through the original file and remove lines including the word "torch"
10 | while read -r line; do
11 | if [[ ! "$line" =~ "torch" ]]; then
12 | echo "$line" >> "$tempfile"
13 | fi
14 | done < "$filename"
15 |
16 | # Move the temporary file to the original file
17 | mv "$tempfile" "$filename"
18 |
--------------------------------------------------------------------------------
/scripts/sanity-check.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -ex
4 | pip list
5 | python -c "import torch ; assert torch.cuda.is_available(), 'missing GPU support!'"
6 | python -c "import torch ; v = torch.__version__ ; assert str(v).startswith('2'), v"
7 | python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'"
8 | python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'"
9 |
--------------------------------------------------------------------------------
/scripts/validate_build.py:
--------------------------------------------------------------------------------
1 | """Check that `scripts/build_from_source.sh` has produced a correct build."""
2 |
3 | import os
4 |
5 | BUILD_DIR = os.path.abspath(os.path.join(os.path.split(__file__)[0], "build"))
6 |
7 |
8 | def main():
9 | import torch
10 |
11 | expected = os.path.join(BUILD_DIR, "pytorch", "torch", "__init__.py")
12 | actual = os.path.abspath(torch.__file__)
13 | assert expected == actual, f"{expected} vs. {actual}"
14 |
15 | import nvfuser
16 | from looseversion import LooseVersion
17 |
18 | assert LooseVersion(nvfuser.version()) >= LooseVersion("0.0.6"), nvfuser.version()
19 |
20 | import thunder
21 |
22 | expected = os.path.abspath(os.path.join(BUILD_DIR, "..", "..", "thunder", "__init__.py"))
23 | assert expected == thunder.__file__, f"{expected} vs. {thunder.__file__}"
24 |
25 | print("Build looks good!")
26 |
27 |
28 | if __name__ == "__main__":
29 | main()
30 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from importlib.util import module_from_spec, spec_from_file_location
4 |
5 | from packaging.requirements import Requirement as parse_requirements
6 | from setuptools import find_packages, setup
7 |
8 |
9 | _PATH_ROOT = os.path.dirname(__file__)
10 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements")
11 | # check if os env. variable is set to convert version to nightly
12 | _CONVERT_VERSION = int(os.environ.get("CONVERT_VERSION2NIGHTLY", 0))
13 |
14 |
15 | def _load_py_module(fname, pkg="thunder"):
16 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname))
17 | py = module_from_spec(spec)
18 | spec.loader.exec_module(py)
19 | return py
20 |
21 |
22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list[str]:
23 | """Parse requirements file -> list of strings"""
24 | reqs: list[str] = []
25 | with open(os.path.join(path_dir, file_name)) as f:
26 | for line in f:
27 | if (line := line.strip()) and not line.startswith("#"):
28 | reqs.append(line)
29 | # Filter out requirements referring to local paths or specific URLs (if any)
30 | reqs = [str(parse_requirements(r)) for r in reqs if "@" not in r and "://" not in r]
31 | return reqs
32 |
33 |
34 | def convert_version2nightly(about_file: str = "thunder/__about__.py") -> None:
35 | """Load the actual version and convert it to the nightly version."""
36 | from datetime import datetime
37 |
38 | # load the about file
39 | with open(about_file) as fo:
40 | lines = fo.readlines()
41 | idx = None
42 | # find the line with version
43 | for i, ln in enumerate(lines):
44 | if ln.startswith("__version__"):
45 | idx = i
46 | break
47 | if idx is None:
48 | raise ValueError("The version is not found in the `__about__.py` file.")
49 | # parse the version from variable assignment
50 | version = lines[idx].split("=")[1].strip().strip('"')
51 | # parse X.Y.Z version and prune any suffix
52 | vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
53 | # create timestamp YYYYMMDD
54 | timestamp = datetime.now().strftime("%Y%m%d")
55 | version = f"{'.'.join(vers.groups())}.dev{timestamp}"
56 | # print the new version
57 | lines[idx] = f'__version__ = "{version}"\n'
58 | # dump updated lines
59 | with open(about_file, "w") as fo:
60 | fo.writelines(lines)
61 |
62 |
63 | def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
64 | """Load the readme as the long description."""
65 | path_readme = os.path.join(path_dir, "README.md")
66 | with open(path_readme, encoding="utf-8") as fp:
67 | text = fp.read()
68 | # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
69 | github_source_url = os.path.join(homepage, "raw", version)
70 | # replace relative repository path to absolute link to the release
71 | # do not replace all "docs" as in the readme we replace some other sources with particular path to docs
72 | text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")
73 | return text
74 |
75 |
76 | if _CONVERT_VERSION:
77 | convert_version2nightly()
78 |
79 | about = _load_py_module("__about__.py")
80 |
81 | setup(
82 | version=about.__version__,
83 | packages=find_packages(exclude=["thunder/tests", "docs"]),
84 | long_description=_load_readme_description(
85 | path_dir=_PATH_ROOT,
86 | homepage=about.__homepage__,
87 | version=about.__version__,
88 | ),
89 | long_description_content_type="text/markdown",
90 | install_requires=_load_requirements(_PATH_REQUIRES, file_name="base.txt"),
91 | include_package_data=True,
92 | zip_safe=False,
93 | )
94 |
--------------------------------------------------------------------------------
/thunder/__about__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.6.dev0"
2 | __author__ = "Lightning-AI et al"
3 | __author_email__ = "community@lightning.ai"
4 | __copyright__ = f"2024, 2025 {__author__}"
5 | __homepage__ = "https://github.com/Lightning-AI/lightning-thunder"
6 | __docs__ = "Lightning Thunder project."
7 |
8 |
9 | __all__ = [
10 | "__author__",
11 | "__author_email__",
12 | "__copyright__",
13 | "__docs__",
14 | "__homepage__",
15 | "__version__",
16 | ]
17 |
--------------------------------------------------------------------------------
/thunder/clang/langctx.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
4 | from thunder.core.pytree import tree_flatten
5 | from thunder.core.proxies import TensorProxy, NumberProxy
6 |
7 | #
 8 | # Creates and registers the core (clang) language context
 9 | #
10 | # NOTE This is done separately from the definition of thunder.clang operations, because the
11 | # language context must be available before those operations are defined
12 |
13 | _method_name_to_fn_map: dict[str, Callable] = {}
14 |
15 |
16 | # Creates and registers the core language context
17 | class ClangCtx(LanguageContext):
18 | def __init__(self):
19 | super().__init__("core")
20 |
21 | def has_method(self, id: str) -> bool:
22 | return id in _method_name_to_fn_map
23 |
24 | def get_method(self, id: str, *args, **kwargs) -> Callable:
25 | # Note: concrete implementations should only raise AttributeError or
26 | # return None for "missing" methods as the proxies will
27 | # route __getattr__ to here and hasattr relies on __getattr__
28 | # throwing AttributeError (only) when the attribute does
29 | # not exist.
30 | inps, _ = tree_flatten((args, kwargs))
31 |
32 | has_proxy_input: bool = False
33 | for x in inps:
34 | if isinstance(x, TensorProxy) or isinstance(x, NumberProxy):
35 | has_proxy_input = True
36 | break
37 |
38 | if has_proxy_input:
39 | method: None | Callable = _method_name_to_fn_map.get(id, None)
40 | if method is None:
41 | raise AttributeError(f"The {self.name} language context has no method {id}")
42 | return method
43 |
44 | # has_proxy_input is False
45 | # Defers to the primitive language context when there are no proxy inputs
46 | # (the primitive language context handles operations on numbers)
47 | primsctx: LanguageContext = resolve_language(Languages.PRIMS)
48 | if not primsctx.has_method(id):
49 | raise AttributeError(
50 | f"Attempting to call method {id} in the core language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
51 | )
52 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
53 | return prim_method
54 |
55 |
56 | clangctx = ClangCtx()
57 | register_langctx(Languages.CLANG, clangctx)
58 |
59 |
60 | # Registers a method with the core language context
61 | def register_method(method_name: str, method: Callable, /) -> None:
62 | _method_name_to_fn_map[method_name] = method
63 |
--------------------------------------------------------------------------------
/thunder/clang/utils.py:
--------------------------------------------------------------------------------
1 | from numbers import Number
2 | from collections.abc import Sequence
3 | from collections.abc import Callable
4 |
5 | from thunder.core import utils
6 | import thunder.core.dtypes as dtypes
7 | from thunder.core.symbol import Symbol
8 |
9 | from thunder.core.proxies import (
10 | NumberProxy,
11 | TensorProxy,
12 | )
13 |
14 |
15 | def create_maybe_convert_to_dtype_with_prim(conversion_prim: Symbol):
16 | assert isinstance(conversion_prim, Symbol)
17 |
18 | def maybe_convert_to_dtype(a, dtype, *, enforce_safe_casting=False):
19 | """If `a` has the same dtype as the given dtype, returns `a` unmodified.
20 |
21 | Otherwise returns `a` converted to the given dtype.
22 | """
23 |
24 | utils.check(utils.is_dtype(dtype), lambda: f"Unknown dtype {dtype}!")
25 |
26 | if isinstance(a, Sequence):
27 | return tuple(maybe_convert_to_dtype(x, dtype) for x in a)
28 | if isinstance(a, TensorProxy):
29 | # Translates numbertypes to dtypes
30 | if dtypes.is_numbertype(dtype):
31 | dtype = dtypes.numbertype_to_dtype(dtype)
32 | elif isinstance(a, (Number, NumberProxy)):
33 | # NOTE This allows conversions like (5, float32) -> 5., which is a little odd
34 | dtype = utils.dtype_to_numbertype(dtype)
35 | else:
36 | raise ValueError(
37 | f"Trying to convert the type of the data of an unknown object {a} of {type(a)} that is not a tensor, number, or sequence!"
38 | )
39 |
40 | if not utils.are_same_dtypes(a, dtype):
41 | if enforce_safe_casting:
42 | utils.check(
43 | utils.can_safe_cast_to(cast_from=utils.to_dtype(a), cast_to=dtype),
44 | lambda: f"Can't safely cast from a={a} with dtype {utils.to_dtype(a)} to {dtype}!",
45 | )
46 |
47 | return conversion_prim(a, dtype)
48 |
49 | return a
50 |
51 | return maybe_convert_to_dtype
52 |
53 |
54 | # TODO Add supported dtypes
55 | def _elementwise_unary_wrapper(
56 | a,
57 | *,
58 | prim,
59 | type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
60 | dtype_conversion_fn: Callable[[TensorProxy | NumberProxy, dtypes.dtype], TensorProxy | NumberProxy],
61 | ):
62 | computation_dtype, result_dtype = utils.elementwise_type_promotion(a, type_promotion_kind=type_promotion_kind)
63 |
64 | a = dtype_conversion_fn(a, computation_dtype)
65 | result = prim(a)
66 | result = dtype_conversion_fn(result, result_dtype)
67 |
68 | return result
69 |
--------------------------------------------------------------------------------
/thunder/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/core/__init__.py
--------------------------------------------------------------------------------
/thunder/core/compile_data.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from contextvars import ContextVar
3 | from typing import Any
4 |
5 | from thunder.core.options import CACHE_OPTIONS
6 |
7 | #
8 | # Setting and querying the "compile_data_and_stats" context variable, which contains
9 | # a tuple of (CompileData, CompileStats) objects for the current trace.
10 | #
11 |
12 | _compile_data = ContextVar("compile_data", default=(None, None))
13 |
14 |
15 | #
16 | # Setting, getting, and resetting the context variable
17 | #
18 |
19 |
20 | # NOTE This just acquires the compile data part of the context var's tuple
21 | def get_compile_data() -> None | Any:
22 | """Returns the current compile data."""
23 | cd, cs = _compile_data.get()
24 |
25 | return cd
26 |
27 |
28 | def set_compile_data_and_stats(cd, cs, /):
29 | """Sets the current compile data.
30 |
31 | This is used to pass compile data to functions that are called during compilation.
32 | """
33 | token = _compile_data.set((cd, cs))
34 | return token
35 |
36 |
37 | def reset_compile_data_and_stats(token, /):
38 | """Resets the compile data."""
39 | _compile_data.reset(token)
40 |
41 |
42 | @contextmanager
43 | def compile_data_and_stats(cd, cs, /):
44 | """Sets the current compile data for the duration of the context."""
45 | token = set_compile_data_and_stats(cd, cs)
46 | try:
47 | yield
48 | finally:
49 | reset_compile_data_and_stats(token)
50 |
51 |
52 | #
53 | # Query helpers
54 | #
55 |
56 |
57 | def get_compile_option(option: str, description: str, /) -> None | Any:
58 | cd, cs = _compile_data.get()
59 |
60 | if cd is None or cs is None:
61 | return None
62 |
63 | # See NOTE Different categories of compile options in thunder/__init__.py
64 | cs.last_compile_reasons[option].append(description)
65 | return cd.compile_options.get(option, None)
66 |
67 |
68 | # Whether or not the caching option uses symbolic values
69 | def get_cache_option() -> CACHE_OPTIONS:
70 | cd = get_compile_data()
71 | if cd is None:
72 | return CACHE_OPTIONS.CONSTANT_VALUES
73 | return cd.cache_option
74 |
75 |
76 | # TODO RC1 Remove the try (hack for when operating outside of this contextvar being set)
77 | def using_symbolic_values() -> bool:
78 | try:
79 | return get_cache_option() is CACHE_OPTIONS.SYMBOLIC_VALUES
80 | except:
81 | return False
82 |
83 |
84 | def using_jit() -> bool:
85 | try:
86 | cd, cs = _compile_data.get()
87 | return cd.using_jit
88 | except:
89 | return False
90 |
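
A sketch of how an executor or transform might query a user-supplied compile option while the context variable is set; the option name here is hypothetical:

    from thunder.core.compile_data import get_compile_option

    # Returns None unless the user passed e.g. thunder.jit(fn, use_fast_path=True);
    # the description is recorded in the compile stats for later inspection.
    use_fast_path = get_compile_option("use_fast_path", "Enables a hypothetical fast path.")
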
--------------------------------------------------------------------------------
/thunder/core/profile.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | import contextlib
3 | import functools
4 | import os
5 | import warnings
6 |
7 | import torch
8 |
9 |
10 | _ENABLED = os.getenv("THUNDER_ANNOTATE_TRACES") in ("1", "y", "Y")
11 |
12 | # However, nvtx is incredibly cheap so we no longer bother requiring the
13 | # environment variable.
14 | try:
15 | import nvtx
16 | except ImportError:
17 | if _ENABLED:
18 | msg = "Requested nvtx but the package is not available."
19 | msg += "\nUse `python -m pip install nvtx`."
20 | warnings.warn(msg)
21 | _ENABLED = False
22 | raise
23 |
24 |
25 | def profiling_enabled() -> bool:
26 | return _ENABLED
27 |
28 |
29 | @contextlib.contextmanager
30 | def add_markers(msg: str) -> None:
31 | if not profiling_enabled():
32 | yield
33 | return
34 |
35 | assert "\n" not in msg, msg # Both NVTX and JSON forbid newlines
36 | assert '"' not in msg, msg # The PyTorch profiler does not properly escape quotations
37 |
38 | with torch.profiler.record_function(msg):
39 | torch.cuda.nvtx.range_push(msg)
40 | try:
41 | yield
42 |
43 | finally:
44 | torch.cuda.nvtx.range_pop()
45 |
46 |
47 | # The main interface to profiling something. Generally used as a decorator:
48 | # @thunder.core.profile.annotate_for_profile("foo")
49 | # def foo(...): ...
50 | # but alternatively as a `with` context:
51 | # with thunder.core.profile.annotate_for_profile("name for a block of code"):
52 | # # ... code ...
53 | annotate_for_profile: Callable[[str], None] = None
54 |
55 |
56 | if _ENABLED:
57 | annotate_for_profile = functools.partial(nvtx.annotate, domain="thunder")
58 | else:
59 |
60 | class _no_annotate(contextlib.nullcontext):
61 | """
62 | A profiling decorator that does nothing.
63 | """
64 |
65 | def __init__(self, *args, **kwargs):
66 | super().__init__()
67 |
68 | def __call__(self, fqn):
69 | return fqn
70 |
71 | annotate_for_profile = _no_annotate
72 |
--------------------------------------------------------------------------------
/thunder/dev_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/dev_utils/__init__.py
--------------------------------------------------------------------------------
/thunder/dev_utils/benchmark.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 |
5 |
6 | def benchmark_n(n, model_or_fn, /, *args, **kwargs):
7 | for _ in range(n):
8 | _ = model_or_fn(*args, **kwargs)
9 | start_event = torch.cuda.Event(enable_timing=True)
10 | end_event = torch.cuda.Event(enable_timing=True)
11 | torch.cuda.synchronize()
12 | start_event.record()
13 | for _ in range(n):
14 | _ = model_or_fn(*args, **kwargs)
15 | end_event.record()
16 | torch.cuda.synchronize()
17 | return start_event.elapsed_time(end_event) / n
18 |
19 |
20 | benchmark = partial(benchmark_n, 10)
21 |
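
Timing uses CUDA events, so a CUDA device is assumed; a usage sketch with a toy module:

    import torch
    from thunder.dev_utils.benchmark import benchmark, benchmark_n

    model = torch.nn.Linear(8, 8).to("cuda")
    x = torch.randn(4, 8, device="cuda")

    print(f"{benchmark(model, x):.2f}ms")         # mean over 10 iterations
    print(f"{benchmark_n(100, model, x):.2f}ms")  # mean over 100 iterations
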
--------------------------------------------------------------------------------
/thunder/dev_utils/check_trace.py:
--------------------------------------------------------------------------------
1 | import thunder
2 |
3 | CHECK_VERSION = 3
4 |
5 |
6 | def check_subsymbols(parent_bsym):
7 | if parent_bsym.sym.is_prim:
8 | # assert that there are no subsymbols?
9 | return
10 | known_proxies = {a.name: a for a in parent_bsym.flat_proxy_args}
11 | for bsym in parent_bsym.subsymbols:
12 | for a in bsym.flat_proxy_args:
13 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args"
14 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args"
15 | for o in bsym.flat_proxy_outs:
16 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs"
17 | known_proxies[o.name] = o
18 | check_subsymbols(bsym)
19 | for o in parent_bsym.flat_proxy_outs:
20 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {parent_bsym}"
21 |
22 |
23 | def check_trace(trace, *, version=CHECK_VERSION):
24 | """checks a trace for consistency
25 |
26 | The check is versioned for the benefit of CI and other automated testing.
27 |
28 | As a user, don't pass a version to get all implemented checks.
29 |
30 | If you add new checks and do not fix all newly detected inconsistencies,
31 | bump the CHECK_VERSION and make your tests only apply to this latest version.
32 |
33 | Please do file issues for things that fail with the latest versions so we can
34 | catch up.
35 | """
36 | # TODO:
37 | # - args vs. unpack trivial
38 | # - args vs. flat_args in return
39 | known_proxies = {}
40 | for bsym in trace.bound_symbols:
41 | if (version >= 3) and bsym.sym == thunder.core.prims.unpack_sequence:
42 | coll = bsym.args[0].collection()
43 | assert len(coll) == len(bsym.output), f"unpack collection length mismatch {bsym}"
44 | for c, o in zip(coll, bsym.output):
45 | if o is None: # unused outputs
46 | continue
47 | if isinstance(c, thunder.Proxy):
48 | assert c is o, f"mismatch in unpack collection: {c} {o} {bsym}"
49 |
50 | for a in bsym.flat_proxy_args:
51 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args"
52 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args"
53 | for o in bsym.flat_proxy_outs:
54 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs"
55 | known_proxies[o.name] = o
56 | check_subsymbols(bsym)
57 |
58 | tr = thunder.core.trace.from_trace(trace)
59 | with thunder.core.trace.tracectx(tr):
60 | for bsym in trace.bound_symbols:
61 | if bsym.sym.name.startswith("unpack") or bsym.sym.name in {"pack_buffer"}:
62 | continue
63 | res = bsym.sym(*bsym.args, **bsym.kwargs)
64 |
65 | def check_shape(x, y):
66 | if isinstance(x, thunder.TensorProxy) and y is not None: # formal output can be none if unused
67 | assert x.shape == y.shape, f"shape of proxy {y.name} recomputes to {x.shape} incorrectly in {bsym}"
68 | return x
69 |
70 | thunder.core.utils.safe_map_flat(check_shape, res, bsym.output)
71 |
72 |
73 | class CheckedListOfTraces(list):
74 | def __init__(self, *args, **kwargs):
75 | super().__init__(*args, **kwargs)
76 | cd = thunder.core.compile_data.get_compile_data()
77 | if cd.debug_options.check_traces is True:
78 | self._check_version = CHECK_VERSION
79 | elif not cd.debug_options.check_traces:
80 | self._check_version = 0
81 | else:
82 | self._check_version = cd.debug_options.check_traces
83 |
84 | def append(self, trace):
85 | check_trace(trace, version=self._check_version)
86 | super().append(trace)
87 |
88 | def extend(self, traces):
89 | for tr in traces:
90 | check_trace(tr, version=self._check_version)
91 | super().extend(traces)
92 |
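
A minimal sketch of running the checker on the final trace of a jitted toy function:

    import torch
    import thunder
    from thunder.dev_utils.check_trace import check_trace

    def fn(x):
        return x * 2

    jfn = thunder.jit(fn)
    jfn(torch.randn(4))

    check_trace(thunder.last_traces(jfn)[-1])  # raises AssertionError on an inconsistency
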
--------------------------------------------------------------------------------
/thunder/dev_utils/nvtx_profile_transform.py:
--------------------------------------------------------------------------------
1 | from thunder.core.trace import TraceCtx as Trace, from_trace, TraceProvenance
2 | from thunder.dev_utils.utils import NON_COMPUTATION_PRIMS
3 | from thunder.extend import OperatorExecutor
4 | import time
5 | import torch
6 | import thunder
7 |
8 |
9 | class Timer:
10 | def __init__(self):
11 | self.start_time_ns = None
12 | self.end_time_ns = None
13 |
14 | def __enter__(self):
15 | self.start_time_ns = time.perf_counter_ns()
16 | return self
17 |
18 | def __exit__(self, *args):
19 | self.end_time_ns = time.perf_counter_ns()
20 |
21 | def get_elapsed_time_in_ms(self):
22 | elapsed_time_ns = self.end_time_ns - self.start_time_ns
23 | return elapsed_time_ns // int(1e6)
24 |
25 |
26 | nvtx_profiler_ex = OperatorExecutor("nvtx_profiler_ex")
27 |
28 |
29 | def nvtx_push_impl(msg):
30 | torch.cuda.nvtx.range_push(msg)
31 |
32 |
33 | def nvtx_pop_impl():
34 | torch.cuda.nvtx.range_pop()
35 |
36 |
37 | # Symbols for profiling.
38 | nvtx_push = nvtx_profiler_ex.register_operator("nvtx_range_push", meta=lambda msg: None, fn=nvtx_push_impl)
39 | nvtx_pop = nvtx_profiler_ex.register_operator("nvtx_range_pop", meta=lambda: None, fn=nvtx_pop_impl)
40 |
41 |
42 | class NvtxProfileTransform(thunder.core.transforms.Transform):
43 | def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace:
44 | with Timer() as timer:
45 | profile_trace = from_trace(trace)
46 |
47 | for bound_symbol in trace.bound_symbols:
48 | if bound_symbol.sym.id in NON_COMPUTATION_PRIMS:
49 | profile_trace.bound_symbols.append(bound_symbol)
50 | continue
51 |
52 | # Add nvtx range for the symbol.
53 | profile_trace.bound_symbols.append(
54 | nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None)
55 | )
56 | profile_trace.bound_symbols.append(bound_symbol)
57 | profile_trace.bound_symbols.append(nvtx_pop.bind(output=None))
58 |
59 | profile_trace.set_provenance(
60 | TraceProvenance(f"NVTX Profile Transform (took {timer.get_elapsed_time_in_ms()} milliseconds)")
61 | )
62 | return profile_trace
63 |
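
A sketch of attaching the transform when jitting, assuming `thunder.jit` accepts a list of transforms and that a CUDA device is available for the NVTX ranges to be meaningful:

    import torch
    import thunder
    from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

    def fn(x):
        return x * 2

    jfn = thunder.jit(fn, transforms=[NvtxProfileTransform()])
    jfn(torch.randn(4, device="cuda"))  # ranges show up in an nsys/NVTX timeline
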
--------------------------------------------------------------------------------
/thunder/distributed/tensor_parallel/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.distributed
2 |
3 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType
4 |
5 | if torch.distributed.is_available():
6 | from thunder.distributed.tensor_parallel.column_wise import column_parallel
7 | from thunder.distributed.tensor_parallel.row_wise import row_parallel
8 | else:
9 | column_parallel = None
10 | row_parallel = None
11 |
12 |
13 | __all__ = [
14 | "TensorParallelLayerType",
15 | "column_parallel",
16 | "row_parallel",
17 | ]
18 |
--------------------------------------------------------------------------------
/thunder/distributed/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | if torch.distributed.is_available():
4 | from .ddp import apply_bucketing_to_grad_allreduce
5 | from .fsdp import FSDPCommBucketing
6 | else:
7 | apply_bucketing_to_grad_allreduce = None
8 | FSDPCommBucketing = None
9 |
10 | __all__ = [
11 | "apply_bucketing_to_grad_allreduce",
12 | "FSDPCommBucketing",
13 | ]
14 |
--------------------------------------------------------------------------------
/thunder/dynamo/__init__.py:
--------------------------------------------------------------------------------
1 | from thunder.dynamo.compiler import (
2 | ThunderCompiler,
3 | ThunderFXCompiledObject,
4 | thunder_optimize,
5 | thunder_profile,
6 | thunderfx,
7 | )
8 |
9 |
10 | __all__ = [
11 | "ThunderCompiler",
12 | "ThunderFXCompiledObject",
13 | "thunder_optimize",
14 | "thunder_profile",
15 | "thunderfx",
16 | ]
17 |
--------------------------------------------------------------------------------
/thunder/dynamo/repro_script_template.py:
--------------------------------------------------------------------------------
1 | FXGRAPH_CLASS_NAME = "DynamoModule"
2 | INPUTS_NAME = "inputs"
3 | CALLABLE_NAME = "model"
4 | COMPILED_CALLABLE_NAME = "compiled_model"
5 | THUNDER_IMPORT_STRS = """
6 | from thunder.dev_utils import nvtx_profile_transform
7 | """
8 |
9 | pytest_benchmark_multi_exe_code_template = f'''
10 | # NOTE: This script requires `pytest-benchmark==4.0.0` to be installed.
11 | # To execute the script, run `pytest {{graph_name}}_benchmark.py --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type`
12 | # To check the peak allocated CUDA memory, use --benchmark-json=json_file_name and look at the "max_allocated_memory_MB" field in the json file
13 | # To run tests for a specific compute_type, use the pytest `-k` option.
14 | # For example:
15 | # - `-k "forward"` will run only the forward pass.
16 | #
17 | # Available options:
18 | # - compute_type: "forward", "backward"
19 |
20 | import pytest
21 | from thunder.benchmarks.targets import parametrize_compute_type_only_training, benchmark_for_compute_type, ComputeType
22 | {{torch_import_str}}
23 | {{import_str}}
24 | {THUNDER_IMPORT_STRS}
25 |
26 | # NOTE: The reproducer function has already been processed by TorchDynamo.
27 | # If we let it go through TorchDynamo again, it could be segmented further.
28 | # To avoid this, we directly use Inductor here.
29 | # See issue https://github.com/Lightning-AI/lightning-thunder/issues/1521
30 | def torch_inductor(fn, inputs):
31 | from torch._inductor import compile as inductor_compile
32 | from torch.fx import symbolic_trace
33 |
34 | fx_graph = symbolic_trace(fn)
35 | return inductor_compile(fx_graph, inputs)
36 |
37 | {{executors}}
38 | {{executor_names}}
39 |
40 | {{dynamo_module}}
41 |
42 | @pytest.mark.parametrize(
43 | "executor",
44 | executors,
45 | ids=executor_names,
46 | )
47 | {{compute_type_decorator}}
48 | def test_{{graph_name}}(benchmark, executor, compute_type):
49 | {{inputs}}
50 |
51 | model = DynamoModule()
52 | if executor is None:
53 | compiled_model = model
54 | elif executor == torch_inductor:
55 | compiled_model = executor(model, inputs)
56 | else:
57 | compiled_model = executor(model)
58 | {{call_benchmark}}
59 |
60 | """
 61 | Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`:
62 | {{torch_env}}
63 |
64 | Versions of Thunder related libraries:
65 | {{thunder_pkgs}}
66 |
67 | {{extra_comment_str}}
68 | """
69 | '''
70 |
71 |
72 | bsym_torch_compile_repro_template = '''
73 | """
74 | {extra_comment_str}
75 | """
76 | {python_func}
77 |
78 | from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
79 | import thunder.examine
80 |
81 | inputs = {inputs}
82 |
83 | jfn = thunder.jit({func_name})
84 | jfn(*inputs)
85 |
86 | trc = thunder.last_traces(jfn)[-1]
87 | fusion_symbols = thunder.examine.get_fusion_symbols(trc)
88 | assert len(fusion_symbols) == 1
89 | bsym = fusion_symbols[0]
90 |
91 | # NOTE: The nvFusion function cannot be compiled directly using `torch.compile`.
92 | # It must first be processed by Thunder into BoundSymbols and compiled with `make_torch_compile_callable`.
93 | # Additionally, it's recommended to visually verify that `bsym` matches the
94 | # `nvFusion` function above by printing it using `print(bsym)`.
95 | torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
96 | '''
97 |
98 | repro_bench_code_template = f"""
99 | {{import_str}}
100 | {THUNDER_IMPORT_STRS}
101 |
102 | {{dynamo_module}}
103 | def test_{{graph_name}}():
104 | {{inputs}}
105 |
106 | model = {FXGRAPH_CLASS_NAME}()
107 | """
108 |
109 | main_code = """
110 | if __name__ == "__main__":
111 | test_{graph_name}()
112 | """
113 |
114 | comment_str_template = '''
115 | """
116 | Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`:
117 | {torch_env}
118 |
119 | Versions of Thunder related libraries:
120 | {thunder_pkgs}
121 |
122 | {extra_comment_str}
123 | """
124 | '''
125 |
--------------------------------------------------------------------------------
/thunder/executors/__init__.py:
--------------------------------------------------------------------------------
1 | import thunder.executors.passes as passes
2 |
3 | import thunder.extend as extend
4 |
5 |
6 | # NOTE The executors submodule depends on the extend submodule
7 |
8 | __all__ = [
9 | "passes",
10 | "get_torch_executor",
11 | "get_nvfuser_executor",
12 | "nvfuser_available",
13 | ]
14 |
15 |
16 | def get_nvfuser_executor() -> None | extend.Executor:
17 | return extend.get_executor("nvfuser")
18 |
19 |
20 | def get_torch_executor() -> extend.Executor:
21 | return extend.get_executor("torch")
22 |
23 |
24 | def nvfuser_available() -> bool:
25 | return get_nvfuser_executor() is not None
26 |
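
A sketch of gating on executor availability, the same check scripts/sanity-check.sh performs from the command line:

    from thunder.executors import get_nvfuser_executor, get_torch_executor, nvfuser_available

    executors = [get_torch_executor()]
    if nvfuser_available():
        executors.insert(0, get_nvfuser_executor())
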
--------------------------------------------------------------------------------
/thunder/executors/apex_fused_rms_norm_impl.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | import math
3 |
4 |
5 | from thunder.core.proxies import TensorProxy
6 | from thunder.torch import TensorLike
7 | from thunder.executors.apexex import apex_ex
8 |
9 |
10 | APEX_FUSED_NORMS_AVAILABLE = True
11 | try:
12 | # Fused layer norm is only importable if torch.distributed is available
13 | # https://github.com/NVIDIA/apex/issues/1853
14 | from torch.distributed import is_available
15 |
16 | if not is_available():
17 | raise ImportError
18 | import fused_layer_norm_cuda
19 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction # noqa: F401
20 | except ImportError:
21 | APEX_FUSED_NORMS_AVAILABLE = False
22 |
23 |
24 | def apex_fused_norms_available() -> bool:
25 | return APEX_FUSED_NORMS_AVAILABLE
26 |
27 |
28 | def apex_fused_rms_norm_forward_affine_meta(
29 | input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float
30 | ):
31 | unnormalized_dims = len(input.shape) - len(normalized_shape)
32 | invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),))
33 | return TensorProxy(like=input), invvar
34 |
35 |
36 | def apex_fused_rms_norm_backward_affine_meta(
37 | grad_output: TensorLike,
38 | invvar: TensorLike,
39 | input_or_output: TensorLike,
40 | normalized_shape: Sequence[int],
41 | weight_,
42 | eps: float,
43 | memory_efficient: bool,
44 | ):
45 | return TensorProxy(like=grad_output), TensorProxy(like=weight_)
46 |
47 |
48 | # Create a new symbol and register lookaside only if import is available.
49 | if apex_fused_norms_available():
50 | apex_fused_rms_norm_forward_affine = apex_ex.register_operator(
51 | "apex_fused_rms_norm_forward_affine",
52 | meta=apex_fused_rms_norm_forward_affine_meta,
53 | fn=fused_layer_norm_cuda.rms_forward_affine,
54 | replaces=fused_layer_norm_cuda.rms_forward_affine,
55 | )
56 |
57 | apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator(
58 | "apex_fused_rms_norm_forward_affine_mixed_dtypes",
59 | meta=apex_fused_rms_norm_forward_affine_meta,
60 | fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
61 | replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
62 | )
63 |
64 | apex_fused_rms_norm_backward_affine = apex_ex.register_operator(
65 | "apex_fused_rms_norm_backward_affine",
66 | meta=apex_fused_rms_norm_backward_affine_meta,
67 | fn=fused_layer_norm_cuda.rms_backward_affine,
68 | replaces=fused_layer_norm_cuda.rms_backward_affine,
69 | )
70 |
--------------------------------------------------------------------------------
/thunder/executors/apexex.py:
--------------------------------------------------------------------------------
1 | from thunder.extend import OperatorExecutor, register_executor
2 |
3 |
4 | # TODO Does apex have a version this should use?
5 | apex_ex = OperatorExecutor("apex", version="0.1")
6 | register_executor(apex_ex)
7 |
8 |
9 | from thunder.executors.apex_entropyex_impl import apex_entropy_available
10 | from thunder.executors.apex_fused_rms_norm_impl import apex_fused_norms_available
11 |
12 | __all__ = ["apex_ex", "apex_entropy_available", "apex_fused_norms_available"]
13 |
--------------------------------------------------------------------------------
/thunder/executors/cudnnex.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | from looseversion import LooseVersion
4 | from thunder.extend import OperatorExecutor
5 |
6 | __all__ = ["cudnn_version", "required_cudnn_version", "cudnn_available", "cudnn_ex", "torch_to_cudnn_dtype"]
7 |
8 |
9 | #
10 | # Functions for detecting cudnn and its version
11 | #
12 | def cudnn_version() -> LooseVersion | None:
13 | try:
14 | import cudnn
15 |
16 | if hasattr(cudnn, "__version__"):
17 | return LooseVersion(cudnn.__version__)
18 |
19 | # NOTE: This import of cudnn may or may not have version info
20 | return LooseVersion("0.0.0")
21 | except ImportError:
22 | pass
23 |
24 | # NOTE This occurs when cudnn couldn't be imported
25 | return None
26 |
27 |
28 | def required_cudnn_version() -> LooseVersion:
29 | # History of versions:
30 | # Using 1.3.0+ because it works better with other libraries (e.g. torch) that also build on top of cudnn
31 | # Using 1.5.0+ because it handles exception with unsupported graphs better
32 | # Using 1.5.1 because of a compatibility fix
33 | # Using 1.5.2 to allow stride 0
34 | return LooseVersion("1.5.2")
35 |
36 |
37 | def cudnn_available() -> bool:
38 | v = cudnn_version()
39 | if v is None:
40 | return False
41 | required = required_cudnn_version()
42 | if v < required:
43 | import warnings
44 |
45 | msg = f"Your cuDNN installation is out of date. Thunder requires version {required}, but found version {v}."
46 | warnings.warn(msg)
47 | return False
48 | return True
49 |
50 |
51 | cudnn_ex: None | OperatorExecutor = None
52 | torch_to_cudnn_dtype: None | Callable = None
53 | cudnn = None
54 |
55 | if cudnn_available():
56 | import thunder.executors.cudnn_sdpa as sdpa_impl
57 |
58 | torch_to_cudnn_dtype = sdpa_impl.torch_to_cudnn_dtype
59 | cudnn_ex = sdpa_impl.cudnn_ex
60 |
--------------------------------------------------------------------------------
/thunder/executors/custom_op_ex.py:
--------------------------------------------------------------------------------
1 | """Executor for `torch.library.custom_op` operators"""
2 |
3 | from thunder.extend import OperatorExecutor
4 |
5 |
6 | __all__ = [
7 | "custom_op_ex",
8 | ]
9 |
10 |
11 | custom_op_ex = OperatorExecutor("custom_op")
12 |
--------------------------------------------------------------------------------
/thunder/executors/nvfuserex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from looseversion import LooseVersion
3 |
4 | from thunder.extend import FusionExecutor
5 |
6 | __all__ = ["nvfuser_version", "required_nvfuser_version", "nvfuser_available", "nvfuserex"]
7 |
8 |
9 | #
10 | # Functions for detecting nvFuser and its version
11 | #
12 | def nvfuser_version() -> LooseVersion | None:
13 | # Short-circuits if CUDA isn't available
14 | if not torch.cuda.is_available():
15 | return None
16 |
17 | try:
18 | import nvfuser
19 |
20 | if hasattr(nvfuser, "version"):
21 | return LooseVersion(nvfuser.version())
22 |
23 | # NOTE: This import of nvFuser may or may not have version info
24 | return LooseVersion("0.0.0")
25 | except ImportError:
26 | pass
27 |
28 | # NOTE This occurs when nvFuser couldn't be imported
29 | return None
30 |
31 |
32 | def required_nvfuser_version() -> LooseVersion:
33 | return LooseVersion("0.2.8")
34 |
35 |
36 | def nvfuser_available() -> bool:
37 | v = nvfuser_version()
38 | if v is None:
39 | return False
40 |
41 | required = required_nvfuser_version()
42 | if v < required:
43 | import warnings
44 |
45 | msg = f"Your nvfuser installation is out of date. Thunder requires version {required}, but found version {v}."
46 | warnings.warn(msg)
47 | return False
48 | return True
49 |
50 |
51 | nvfuserex: None | FusionExecutor = None
52 | if nvfuser_available():
53 | import thunder.executors.nvfuserex_impl as impl
54 |
55 | nvfuserex = impl.ex
56 |
--------------------------------------------------------------------------------
/thunder/executors/transformer_engineex.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from collections.abc import Callable
4 |
5 | from lightning_utilities.core.imports import package_available
6 |
7 | from thunder import Transform
8 | from thunder.extend import StatefulExecutor
9 | from thunder.core.trace import TraceCtx
10 |
11 | import torch
12 |
13 | __all__ = ["transformer_engine_ex", "TransformerEngineTransform", "_te_activation_checkpointing_transform"]
14 |
15 | transformer_engine_ex: None | StatefulExecutor = None
16 | TransformerEngineTransform: None | Transform = None
17 | _te_activation_checkpointing_transform: None | Callable[[TraceCtx], TraceCtx] = None
18 |
19 | if torch.cuda.is_available():
20 | if package_available("transformer_engine"):
21 | import thunder.executors.transformer_engineex_impl as impl
22 |
23 | transformer_engine_ex = impl.transformer_engine_ex
24 | TransformerEngineTransform = impl.TransformerEngineTransform
25 | _te_activation_checkpointing_transform = impl._te_activation_checkpointing_transform
26 |
27 | else:
28 | warnings.warn("transformer_engine module not found!")
29 |
--------------------------------------------------------------------------------
/thunder/executors/triton_crossentropy.py:
--------------------------------------------------------------------------------
1 | from thunder.executors import triton_utils
2 | from thunder.extend import OperatorExecutor
3 |
4 | import torch
5 |
6 | triton_version: None | str = triton_utils.triton_version()
7 |
8 | triton_ex: None | OperatorExecutor = None
9 |
10 | if torch.cuda.is_available():
11 | if triton_version is not None:
12 | try:
13 | from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex
14 |
15 | triton_ex = impl_ex
16 | except Exception:
17 | import warnings
18 |
19 | warnings.warn("triton is present but cannot be initialized")
20 | triton_version = None
21 |
--------------------------------------------------------------------------------
/thunder/executors/triton_utils.py:
--------------------------------------------------------------------------------
1 | from looseversion import LooseVersion
2 | from lightning_utilities.core.imports import package_available
3 |
4 |
5 | def triton_version() -> None | str:
6 | if not package_available("triton"):
7 | return None
8 |
9 | import triton
10 |
11 | return triton.__version__
12 |
13 |
14 | def is_triton_version_at_least(minimum_version: str) -> bool:
15 | version = triton_version()
16 |
17 | if version is None:
18 | return False
19 |
20 | return LooseVersion(version) >= minimum_version
21 |
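
A small sketch of how these helpers compose; the "3.0" threshold below is an arbitrary example, not a requirement from the source.

```python
from thunder.executors import triton_utils

# triton_version() returns None when triton isn't importable, so the
# version check degrades gracefully on machines without triton.
if triton_utils.is_triton_version_at_least("3.0"):
    print(f"triton {triton_utils.triton_version()} is new enough")
```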
--------------------------------------------------------------------------------
/thunder/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | from numbers import Number
2 | from collections.abc import Callable
3 |
4 | from thunder.core.langctx import langctx, Languages
5 | from thunder.numpy.langctx import register_method
6 |
7 | from thunder.core.proxies import TensorProxy
8 | from thunder.core.symbol import Symbol
9 | import thunder.clang as clang
10 |
11 |
12 | #
13 | # NumPy operator definitions
14 | #
15 | # NOTE NumPy language support is demonstrative. PRs extending it are welcome!
16 |
17 |
18 | # Decorator that sets the language context and constructs a Symbol for each function
19 | class npsymbol:
20 | def __init__(self, *, method_name: None | str = None):
21 | self.method_name: None | str = method_name
22 |
23 | def __call__(self, fn: Callable) -> Symbol:
24 | _fn = langctx(Languages.NUMPY)(fn)
25 | # TODO: register _fn as opaque with the interpreter or do this in jit_ext?
26 | sym = Symbol(name=fn.__name__, meta=_fn)
27 |
28 | if self.method_name is not None:
29 | register_method(self.method_name, _fn)
30 |
31 | return sym
32 |
33 |
34 | #
35 | # Tensor properties
36 | #
37 |
38 |
39 | # NOTE Named `compute_len` so that it doesn't conflict with built-in `len`
40 | def compute_len(a: TensorProxy, /) -> int:
41 | return a.shape[0]
42 |
43 |
44 | register_method("len", compute_len)
45 |
46 |
47 | def size(a: TensorProxy, /) -> int:
48 | return a.numel
49 |
50 |
51 | register_method("size", size)
52 |
53 |
54 | #
55 | # Elementwise binary operators
56 | #
57 |
58 |
59 | # TODO Create a factory that adds ufunc support to elementwise operations
60 | @npsymbol(method_name="add")
61 |
62 |
63 | def add(a: Number | TensorProxy, b: Number | TensorProxy, /, *, where: None | Number | TensorProxy = None):
64 | result = clang.add(a, b)
65 | if where is not None:
66 | return clang.where(where, result, a)
67 | return result
68 |
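
Following the `npsymbol` pattern above, a second operator could be registered the same way. This is a hypothetical sketch; it assumes `clang.sub` exists alongside `clang.add` and `clang.where`.

```python
# Hypothetical: registers a NumPy-style subtract with a `sub` method.
@npsymbol(method_name="sub")
def subtract(a: Number | TensorProxy, b: Number | TensorProxy, /):
    return clang.sub(a, b)
```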
--------------------------------------------------------------------------------
/thunder/numpy/langctx.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
4 | from thunder.core.pytree import tree_flatten
5 | from thunder.core.proxies import TensorProxy
6 |
7 | #
8 | # Creates and registers the NumPy language context
9 | #
10 | # NOTE That this is done separately from the definition of thunder.numpy operations, because the
11 | # language context must be available before those operations are defined
12 |
13 | _method_name_to_fn_map: dict[str, Callable] = {}
14 |
15 |
16 | # The NumPy language context
17 | class NumPyCtx(LanguageContext):
18 | def __init__(self):
19 | super().__init__("numpy")
20 |
21 | def has_method(self, id: str) -> bool:
22 | return id in _method_name_to_fn_map
23 |
24 | def get_method(self, id: str, *args, **kwargs) -> Callable:
25 | # Note: concrete implementations should only raise AttributeError or
26 | # return None for "missing" methods as the proxies will
27 | # route __getattr__ to here and hasattr relies on __getattr__
28 | # throwing AttributeError (only) when the attribute does
29 | # not exist.
30 | inps, _ = tree_flatten((args, kwargs))
31 |
32 | has_tensor_input: bool = False
33 | for x in inps:
34 | if isinstance(x, TensorProxy):
35 | has_tensor_input = True
36 | break
37 |
38 | if has_tensor_input:
39 | method: None | Callable = _method_name_to_fn_map.get(id, None)
40 | if method is None:
41 | raise AttributeError(f"The {self.name} language context has no method {id}")
42 | return method
43 |
44 | # has_tensor_input is False
45 | # Defers to the primitive language context when there are no tensor inputs
46 | # (the primitive language context handles operations on numbers)
47 | primsctx: LanguageContext = resolve_language(Languages.PRIMS)
48 | if not primsctx.has_method(id):
49 | raise AttributeError(
50 | f"Attempting to call method {id} in the numpy language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
51 | )
52 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
53 | return prim_method
54 |
55 |
56 | numpyctx = NumPyCtx()
57 | register_langctx(Languages.NUMPY, numpyctx)
58 | register_langctx("numpy", numpyctx)
59 |
60 |
61 | # Registers a method with the NumPy language context
62 | def register_method(method_name: str, method: Callable, /) -> None:
63 | _method_name_to_fn_map[method_name] = method
64 |
--------------------------------------------------------------------------------
/thunder/plugins/__init__.py:
--------------------------------------------------------------------------------
1 | from thunder.plugins.distributed import DDP, FSDP
2 | from thunder.plugins.quantization import QuantizeInt4
3 | from thunder.plugins.fp8 import FP8
4 | from thunder.plugins.reduce_overhead import ReduceOverhead
5 |
6 | names_to_plugins = {
7 | "ddp": DDP,
8 | "fsdp": FSDP,
9 | "quantize-int4": QuantizeInt4,
10 | "fp8": FP8,
11 | "reduce-overhead": ReduceOverhead,
12 | }
13 |
14 |
15 | def get_plugin(name):
16 | return names_to_plugins.get(name)
17 |
18 |
19 | def get_plugin_names():
20 | return list(names_to_plugins.keys())
21 |
22 |
23 | def register_plugin(name, cls):
24 | names_to_plugins[name] = cls
25 |
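
A sketch of extending the registry; `MyPlugin` is hypothetical, and the base-class import mirrors the one used by the plugins in this package.

```python
from thunder import Plugin
from thunder.plugins import register_plugin, get_plugin, get_plugin_names

class MyPlugin(Plugin):  # hypothetical plugin
    def setup_executors(self):
        return []

register_plugin("my-plugin", MyPlugin)
assert get_plugin("my-plugin") is MyPlugin
assert "my-plugin" in get_plugin_names()
```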
--------------------------------------------------------------------------------
/thunder/plugins/fp8.py:
--------------------------------------------------------------------------------
1 | from thunder import Plugin
2 |
3 |
4 | class FP8(Plugin):
5 | """
6 | Plugin that enables FP8 precision via NVIDIA Transformer Engine for higher-throughput matrix operations.
7 |
8 | See `lightning-thunder/thunder/executors/transformer_engine_v1ex.py` for implementation details.
9 | """
10 |
11 | def setup_executors(self):
12 | """
13 | Imports the Transformer Engine executor.
14 |
15 | Returns:
16 | list[Executor]: A list containing the Transformer Engine executor.
17 |
18 | """
19 | from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex
20 |
21 | return [transformer_engine_v1_ex]
22 |
--------------------------------------------------------------------------------
/thunder/plugins/quantization.py:
--------------------------------------------------------------------------------
1 | from thunder import Plugin
2 |
3 |
4 | class QuantizeInt4(Plugin):
5 | """
6 | Plugin for 4-bit integer quantization using BitsAndBytes.
7 |
8 | This plugin applies a 4-bit linear quantization transform to
9 | model weights, reducing memory footprint and improving
10 | throughput for both training and inference.
11 |
12 | See https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/functional.py#L889 for more details.
13 | """
14 |
15 | def setup_transforms(self):
16 | """
17 | Fetches the BitsAndBytes quantization transform.
18 |
19 | Returns:
20 | list[Transform]: A list containing the BitsAndBytes quantization transform.
21 | """
22 |
23 | from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit
24 |
25 | return [BitsAndBytesLinearQuant4bit()]
26 |
27 | def setup_executors(self):
28 | """
29 | Fetches the BitsAndBytes quantization executor.
30 |
31 | Returns:
32 | list[Executor]: A list containing the BitsAndBytes quantization executor.
33 |
34 | """
35 | from thunder.transforms.quantization import get_bitsandbytes_executor
36 |
37 | return [get_bitsandbytes_executor()]
38 |
--------------------------------------------------------------------------------
/thunder/plugins/reduce_overhead.py:
--------------------------------------------------------------------------------
1 | from thunder.core.recipe import Plugin, PluginPolicy
2 | from thunder.transforms.cudagraph import CUDAGraphTransform
3 |
4 |
5 | class ReduceOverhead(Plugin):
6 | """
7 | Plugin to enable CUDA Graphs and reduce CPU overhead.
8 | """
9 |
10 | policy = PluginPolicy.POST
11 |
12 | def setup_transforms(self):
13 | """
14 | Fetches the CUDAGraph transform.
15 |
16 | Returns:
17 | list[Transform]: A list containing the CUDAGraph transform.
18 | """
19 | return [CUDAGraphTransform()]
20 |
--------------------------------------------------------------------------------
/thunder/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/py.typed
--------------------------------------------------------------------------------
/thunder/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from thunder.recipes.base import BaseRecipe
4 | from thunder.recipes.hf_transformers import HFTransformers
5 |
6 |
7 | names_to_recipes: dict[str, type[Any]] = {
8 | "default": BaseRecipe,
9 | "hf-transformers": HFTransformers,
10 | }
11 |
12 |
13 | def get_recipe_class(name: str) -> type[Any]:
14 | return names_to_recipes.get(name)
15 |
16 |
17 | def get_recipes() -> list[str]:
18 | return list(names_to_recipes.keys())
19 |
20 |
21 | def register_recipe(name: str, cls: type[Any]):
22 | if name == "auto":
23 | raise ValueError("Recipe name 'auto' is reserved.")
24 | names_to_recipes[name] = cls
25 |
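
A sketch of registering a custom recipe; `MyRecipe` is hypothetical. Note that the name "auto" is reserved and raises a `ValueError`.

```python
from thunder.recipes import BaseRecipe, register_recipe, get_recipe_class

class MyRecipe(BaseRecipe):  # hypothetical custom recipe
    pass

register_recipe("my-recipe", MyRecipe)
assert get_recipe_class("my-recipe") is MyRecipe
```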
--------------------------------------------------------------------------------
/thunder/tests/README.md:
--------------------------------------------------------------------------------
1 | Testing is important!
2 |
3 | Most Thunder operators are tested by automatically generated tests that use data from the operator's "OpInfo", defined
4 | in `opinfos.py`. These tests typically compare Thunder's operator behavior with another framework, using the
5 | "sample inputs" defined by its OpInfo. Each OpInfo should define a sample input generator and “reference” implementation in PyTorch (and in the future, NumPy).
6 |
7 | When an operator is very similar across the three language levels (user, core, and primitive), like `cos`, only one variation is typically tested (the core language variant is preferred).
8 | However, there are also cases where testing a torch language operation or primitive operation directly makes sense.
9 |
10 | Operator tests are autogenerated in `test_ops.py`. To run the tests for a particular operator, use pytest’s -k option.
11 | For example, to run the tests for `cos`, the command would be
12 |
13 | ```bash
14 | pytest test_ops.py -k cos -v
15 | ```
16 |
17 | This will run tests for Thunder’s different executors, supported dtypes, and supported devicetypes.
18 |
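Because `-k` accepts boolean expressions, filters can be combined. The substrings below assume the generated test names embed the executor and dtype:

```bash
# run only the float32 variants of the cos tests (name pattern assumed)
pytest test_ops.py -k "cos and float32" -v
```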
--------------------------------------------------------------------------------
/thunder/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/__init__.py
--------------------------------------------------------------------------------
/thunder/tests/bf16.py:
--------------------------------------------------------------------------------
1 | import thunder.core
2 | import torch
3 |
4 |
5 | def device_supports_bf16(device: None | str | torch.device | thunder.core.devices.Device, /) -> bool:
6 | """Torch has its own `torch.cuda.is_bf16_supported()`, but the contract of
7 | this API changed in December 2023 to be "can bf16 be represented?", which
8 | is true even for devices that existed before bf16 was invented.
9 |
10 | The contract for this API is that the device implements bf16 operations."""
11 | if not torch.cuda.is_available():
12 | return False
13 |
14 | dev: torch.device = thunder.core.devices.to_torch_device(device)
15 |
16 | if dev.type != "cuda":
17 | return False
18 |
19 | cuda_major: int
20 | cuda_minor: int
21 | cuda_major, cuda_minor = torch.cuda.get_device_capability(dev)
22 | return cuda_major >= 8
23 |
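
A usage sketch; "cuda:0" is an arbitrary example device.

```python
import torch
from thunder.tests.bf16 import device_supports_bf16

# True only for CUDA devices with compute capability >= 8.0 (Ampere and newer).
if device_supports_bf16("cuda:0"):
    x = torch.randn(4, device="cuda:0", dtype=torch.bfloat16)
```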
--------------------------------------------------------------------------------
/thunder/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pytest_benchmark
3 | from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
4 |
5 | import torch
6 |
7 | try:
8 | import nvfuser
9 | except ImportError:
10 | nvfuser = None
11 |
12 |
13 | @pytest.fixture(autouse=True)
14 | def gpu_memory(request):
15 | if torch.cuda.is_available():
16 | torch.cuda.empty_cache()
17 | torch.cuda.reset_peak_memory_stats()
18 | gpu_mem_before = torch.cuda.max_memory_allocated() / 2**30
19 | yield
20 | if torch.cuda.is_available():
21 | gpu_mem = torch.cuda.max_memory_allocated() / 2**30
22 | gpu_mem_limit = request.config.getoption("--gpu-mem-limit")
23 | gpu_mem_use = gpu_mem - gpu_mem_before
24 | if gpu_mem_limit:
25 | assert gpu_mem_use <= gpu_mem_limit, (
26 | f"test needs {gpu_mem - gpu_mem_before:.2f}GB VRAM, only {gpu_mem_limit:.2f}GB allowed"
27 | )
28 | torch.cuda.empty_cache()
29 | torch.cuda.reset_peak_memory_stats()
30 |
31 |
32 | @pytest.fixture(autouse=True)
33 | def test_cleanup(request):
34 | yield
35 | if nvfuser is not None:
36 | nvfuser.FusionCache.reset()
37 |
38 |
39 | @pytest.hookimpl(hookwrapper=True)
40 | def pytest_benchmark_group_stats(config, benchmarks, group_by):
41 | """
42 | The function customize the behavior for ThunderCompilerGraphBenchmarking.
43 | The custom grouping function is only invoked when the `--benchmark-group-by`
44 | option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'.
45 | For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`.
46 |
47 | Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats
48 | """
49 | prefix = "graph-by-graph:"
50 | outcome = yield
51 | if group_by.startswith(prefix):
52 | group_by = group_by[len(prefix) :]
53 | for bench in benchmarks:
54 | if bench["params"] is None:
55 | bench["params"] = {}
56 | # Benchmarks with the same `params` share the same dict.
57 | # We need to copy the original dictionary to add parameters specific to each graph.
58 | else:
59 | bench["params"] = bench["params"].copy()
60 | if bench["param"] is None:
61 | bench["param"] = ""
62 |
63 | name = bench["name"]
64 | gid, module_name, ex = name.split("-")[-3:]
65 | # Add "GraphID", "SplitModuleName", and "executor" as params in the benchmark
66 | gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
67 | bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex})
68 | bench["param"] += f"-{gid}-{module_name}-{ex}"
69 |
70 | result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by)
71 | outcome.force_result(result)
72 |
73 |
74 | def pytest_collection_modifyitems(items):
75 | items.sort(key=lambda item: item.name)
76 |
77 |
78 | def pytest_addoption(parser):
79 | parser.addoption("--gpu-mem-limit", type=float)
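
For reference, a command line exercising both hooks defined above; the path and limit are illustrative.

```bash
# fail any test whose peak GPU memory growth exceeds 8 GiB, and group
# benchmark stats graph by graph as described in the hook's docstring
pytest thunder/tests --gpu-mem-limit 8.0 \
    --benchmark-group-by="graph-by-graph:param:GraphID,param:SplitModuleName"
```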
80 |
--------------------------------------------------------------------------------
/thunder/tests/coverage_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/coverage_tests/__init__.py
--------------------------------------------------------------------------------
/thunder/tests/coverage_tests/test_coverage_hf_diffusers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import pytest
4 |
5 |
6 | if os.getenv("ALLOW_COVERAGE_TRACE") != "1":
7 | pytest.skip("Skipping test_coverage_hf_diffusers.py in regular CI", allow_module_level=True)
8 |
9 | hf_diffusers_unet2d_condition_model_ids = [
10 | "runwayml/stable-diffusion-v1-5",
11 | "CompVis/stable-diffusion-v1-4",
12 | "ionet-official/bc8-alpha",
13 | "stabilityai/sd-turbo",
14 | "runwayml/stable-diffusion-inpainting",
15 | "stabilityai/stable-diffusion-xl-base-1.0",
16 | "stabilityai/stable-diffusion-xl-refiner-1.0",
17 | "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
18 | ]
19 |
20 | from thunder.tests.framework import requiresCUDA
21 |
22 |
23 | @requiresCUDA
24 | @pytest.mark.parametrize("model_id", hf_diffusers_unet2d_condition_model_ids)
25 | def test_hf_diffusers(model_id):
26 | from thunder.dynamo import thunderfx
27 | from diffusers import UNet2DConditionModel
28 |
29 | unet_config = UNet2DConditionModel.load_config(model_id, subfolder="unet", torch_dtype=torch.bfloat16)
30 | unet = UNet2DConditionModel(unet_config)
31 | in_channels = unet.config.in_channels
32 | cross_attention_dim = unet.config.cross_attention_dim
33 | addition_embed_type = unet.config.addition_embed_type
34 |
35 | sample_size = 4
36 | batch_size = 1
37 | seq_length = 4
38 |
39 | if "xl" in model_id:
40 | time_ids_dim = 6
41 | text_embeds_dim = 4
42 | if "refiner" in model_id:
43 | time_ids_dim = 2
44 | text_embeds_dim = 4
45 | else:
46 | time_ids_dim = None
47 | text_embeds_dim = None
48 |
49 | input_shape = (batch_size, in_channels, sample_size, sample_size)
50 | hidden_states_shape = (batch_size, seq_length, cross_attention_dim)
51 |
52 | unet = unet.to("cuda", dtype=torch.bfloat16).requires_grad_(True)
53 | compiled_model = thunderfx(unet)
54 |
55 | def make_inputs(dtype=torch.bfloat16):
56 | added_cond_kwargs = {}
57 | with torch.device("cuda"):
58 | input = torch.randn(input_shape, dtype=dtype)
59 | hidden_states = torch.randn(hidden_states_shape, dtype=dtype)
60 | timestep = torch.ones(batch_size, dtype=torch.long)
61 | if addition_embed_type is not None:
62 | assert text_embeds_dim is not None and time_ids_dim is not None
63 | time_ids_shape = (batch_size, time_ids_dim)
64 | text_embeds_shape = (batch_size, text_embeds_dim)
65 | added_cond_kwargs["time_ids"] = torch.randn(time_ids_shape, device="cuda", dtype=dtype)
66 | added_cond_kwargs["text_embeds"] = torch.randn(text_embeds_shape, device="cuda", dtype=dtype)
67 | return (input, timestep, hidden_states), {"added_cond_kwargs": added_cond_kwargs}
68 |
69 | compiled_args, compiled_kwargs = make_inputs(torch.bfloat16)
70 | compiled_output = compiled_model(*compiled_args, **compiled_kwargs)
71 |
72 | ref_output = unet(*compiled_args, **compiled_kwargs)
73 |
74 | ref_output = ref_output.sample
75 | compiled_output = compiled_output.sample
76 |
77 | torch.testing.assert_close(compiled_output, ref_output, rtol=1e-2, atol=2e-1)
78 |
79 | # TODO: Currently fails, needs investigation https://github.com/Lightning-AI/lightning-thunder/issues/2153
80 | # loss_grad = torch.randn_like(compiled_output)
81 | # grads_ref = torch.autograd.grad(ref_output, unet.parameters(), grad_outputs=loss_grad)
82 | # grads_compiled = torch.autograd.grad(compiled_output, unet.parameters(), grad_outputs=loss_grad)
83 | # torch.testing.assert_close(grads_ref, grads_compiled, rtol=1e-1, atol=1e-1)
84 |
--------------------------------------------------------------------------------
/thunder/tests/coverage_tests/test_coverage_trace.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import thunder
3 | import os
4 | import pytest
5 |
6 |
7 | if os.getenv("ALLOW_COVERAGE_TRACE") != "1":
8 | pytest.skip("Skipping test_coverage_trace.py in regular CI", allow_module_level=True)
9 |
10 |
11 | from transformers import (
12 | AutoConfig,
13 | AutoModel,
14 | AutoModelForCausalLM,
15 | AutoModelForSeq2SeqLM,
16 | AutoModelForImageClassification,
17 | )
18 | from thunder.tests.test_core import run_prologue
19 |
20 |
21 | MODEL_LIST = [
22 | "gpt2",
23 | "bert-base-uncased",
24 | "google/reformer-enwik8",
25 | "facebook/bart-large",
26 | "t5-base",
27 | "xlnet-base-cased",
28 | "facebook/dinov2-small",
29 | "albert-base-v2",
30 | "google/electra-base-discriminator",
31 | "facebook/opt-1.3b",
32 | "google/vit-base-patch16-224",
33 | ]
34 |
35 | MODEL_TASKS = {"openai/clip-vit-large-patch14": "clip", "facebook/dinov2-small": "vision"}
36 |
37 |
38 | def get_model_class(model_name, config):
39 | task = MODEL_TASKS.get(model_name, "text")
40 | if task == "vision":
41 | return AutoModelForImageClassification
42 | elif config.architectures and "CausalLM" in config.architectures[0]:
43 | return AutoModelForCausalLM
44 | elif config.architectures and "Seq2SeqLM" in config.architectures[0]:
45 | return AutoModelForSeq2SeqLM
46 | else:
47 | return AutoModel
48 |
49 |
50 | # custom input depending on model task
51 | def get_dummy_input(model_name, config):
52 | task = MODEL_TASKS.get(model_name, "text")
53 |
54 | if task == "vision":
55 | return {"pixel_values": torch.randn(1, 3, 224, 224, device="cpu", dtype=torch.float32)}
56 | else:
57 | return {"input_ids": torch.randint(0, 1000, (1, 16), device="cpu")}
58 |
59 |
60 | @pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2436")
61 | @pytest.mark.parametrize("model_name", MODEL_LIST)
62 | def test_model_trace(model_name):
63 | print(f"\n=== Testing {model_name} ===")
64 |
65 | config = AutoConfig.from_pretrained(model_name)
66 | model_class = get_model_class(model_name, config)
67 | model = model_class.from_config(config).to("meta")
68 | input_sample = get_dummy_input(model_name, config)
69 |
70 | jmodel = thunder.jit(model)
71 | ce, pro_to_comp, pro_to_epi = run_prologue(jmodel, **input_sample)
72 |
73 | print(f"[SUCCESS] {model_name} Trace acquired!")
74 |
--------------------------------------------------------------------------------
/thunder/tests/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/distributed/__init__.py
--------------------------------------------------------------------------------
/thunder/tests/distributed/modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import ClassVar, TYPE_CHECKING
3 |
4 | import torch.nn as nn
5 |
6 | from thunder.core import utils
7 |
8 | if TYPE_CHECKING:
9 | import torch
10 |
11 |
12 | __all__ = [
13 | "ParallelMLP",
14 | ]
15 |
16 |
17 | class ParallelMLP(nn.Module):
18 | """Simplified version of Megatron/NeMo's ParallelMLP.
19 |
20 | Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61
21 | """
22 |
23 | COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",)
24 | ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",)
25 |
26 | SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh")
27 |
28 | def __init__(
29 | self,
30 | hidden_size: int,
31 | ffn_hidden_size: int | None = None,
32 | bias: bool = True,
33 | gelu_approximate: str = "none",
34 | ) -> None:
35 | utils.check(
36 | gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX,
37 | lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}",
38 | )
39 | if ffn_hidden_size is None:
40 | ffn_hidden_size = 4 * hidden_size
41 |
42 | super().__init__()
43 | self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias)
44 | self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
45 | self.gelu = nn.GELU(approximate=gelu_approximate)
46 |
47 | def forward(self, x: torch.Tensor) -> torch.Tensor:
48 | four_h = self.gelu(self.dense_h_to_4h(x))
49 | h = self.dense_4h_to_h(four_h)
50 | return h
51 |
--------------------------------------------------------------------------------
/thunder/tests/hf_bart_self_attn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This is an abbreviated version of the HuggingFace Bart Encoder Self-Attention
16 | # module. The block was edited to remove non-essential code for
17 | # self-attention.
18 |
19 | import torch
20 | from torch import nn
21 |
22 |
23 | class BartAttention(nn.Module):
24 | """Multi-headed attention from 'Attention Is All You Need' paper"""
25 |
26 | def __init__(
27 | self,
28 | embed_dim: int,
29 | num_heads: int,
30 | dropout: float = 0.0,
31 | bias: bool = True,
32 | ):
33 | super().__init__()
34 | self.embed_dim = embed_dim
35 | self.num_heads = num_heads
36 | self.dropout = dropout
37 | self.head_dim = embed_dim // num_heads
38 |
39 | if (self.head_dim * num_heads) != self.embed_dim:
40 | raise ValueError(
41 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
42 | f" and `num_heads`: {num_heads})."
43 | )
44 | self.scaling = self.head_dim**-0.5
45 |
46 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
47 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
48 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
49 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
50 |
51 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
52 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
53 |
54 | def forward(
55 | self,
56 | hidden_states: torch.Tensor,
57 | attention_mask: torch.Tensor,
58 | ) -> torch.Tensor:
59 | """Input shape: Batch x Time x Channel"""
60 | bsz, tgt_len, _ = hidden_states.size()
61 |
62 | # get query proj
63 | query_states = self.q_proj(hidden_states) * self.scaling
64 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
65 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
66 |
67 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
68 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
69 | key_states = key_states.view(*proj_shape)
70 | value_states = value_states.view(*proj_shape)
71 |
72 | src_len = key_states.size(1)
73 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
74 |
75 | if attention_mask is not None:
76 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
77 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
78 |
79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
80 |
81 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
82 |
83 | attn_output = torch.bmm(attn_probs, value_states)
84 |
85 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
86 | attn_output = attn_output.transpose(1, 2)
87 |
88 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
89 | # partitioned across GPUs when using tensor-parallelism.
90 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
91 |
92 | attn_output = self.out_proj(attn_output)
93 |
94 | return attn_output
95 |
--------------------------------------------------------------------------------
/thunder/tests/litgpt_model.py:
--------------------------------------------------------------------------------
1 | """Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
2 |
3 | configs = [
4 | # diverse sample of configs FOR TESTING that cover all major checkpoints variants architecturally but with reduced
5 | # size
6 | dict(name="gpt-neox-like", block_size=128, n_layer=2, n_embd=64, n_head=4, padding_multiple=8),
7 | dict(
8 | name="llama1-like",
9 | block_size=128,
10 | vocab_size=320,
11 | padding_multiple=64,
12 | n_layer=2,
13 | n_head=4,
14 | n_embd=64,
15 | rotary_percentage=1.0,
16 | parallel_residual=False,
17 | bias=False,
18 | norm_class_name="RMSNorm",
19 | norm_eps=1e-6,
20 | mlp_class_name="LLaMAMLP",
21 | intermediate_size=1376,
22 | ),
23 | dict(
24 | name="long-context-like",
25 | block_size=512,
26 | vocab_size=320,
27 | padding_multiple=64,
28 | n_layer=2,
29 | n_head=4,
30 | n_embd=64,
31 | rotary_percentage=1.0,
32 | parallel_residual=False,
33 | bias=False,
34 | norm_class_name="RMSNorm",
35 | mlp_class_name="LLaMAMLP",
36 | intermediate_size=11008,
37 | rope_condense_ratio=4,
38 | ),
39 | dict(
40 | name="llama2-like",
41 | vocab_size=320,
42 | padding_multiple=64,
43 | n_layer=2,
44 | n_head=4,
45 | n_embd=64,
46 | rotary_percentage=1.0,
47 | parallel_residual=False,
48 | bias=False,
49 | norm_class_name="RMSNorm",
50 | mlp_class_name="LLaMAMLP",
51 | intermediate_size=1376,
52 | ),
53 | dict(
54 | name="falcon-7b-like",
55 | block_size=128,
56 | padded_vocab_size=254,
57 | n_layer=2,
58 | n_head=7,
59 | n_embd=448,
60 | rotary_percentage=1.0,
61 | n_query_groups=1,
62 | bias=False,
63 | shared_attention_norm=True,
64 | ),
65 | dict(
66 | name="falcon-40b-like",
67 | block_size=128,
68 | padded_vocab_size=508,
69 | n_layer=2,
70 | n_head=64,
71 | n_embd=256,
72 | rotary_percentage=1.0,
73 | n_query_groups=4,
74 | bias=False,
75 | ),
76 | dict(
77 | name="codellama2-like",
78 | block_size=1024,
79 | vocab_size=2001,
80 | padding_multiple=16,
81 | n_layer=2,
82 | n_head=4,
83 | n_embd=64,
84 | rotary_percentage=1.0,
85 | parallel_residual=False,
86 | bias=False,
87 | norm_class_name="RMSNorm",
88 | norm_eps=1e-05,
89 | mlp_class_name="LLaMAMLP",
90 | intermediate_size=1376,
91 | rope_base=1000000,
92 | ),
93 | dict(
94 | name="mixtral-like",
95 | block_size=512,
96 | padded_vocab_size=500,
97 | n_layer=2,
98 | n_head=64,
99 | n_embd=256,
100 | rotary_percentage=1.0,
101 | n_query_groups=8,
102 | parallel_residual=False,
103 | bias=False,
104 | norm_class_name="RMSNorm",
105 | norm_eps=1e-05,
106 | mlp_class_name="LLaMAMoE",
107 | intermediate_size=224,
108 | rope_base=1000000,
109 | n_expert=8,
110 | n_expert_per_token=2,
111 | ),
112 | ]
113 |
114 | name_to_config = {config["name"]: config for config in configs}
115 |
116 |
117 | import litgpt
118 |
119 | # add the testing configurations
120 | litgpt.config.name_to_config.update(name_to_config)
121 | name_to_config.update(litgpt.config.name_to_config)
122 |
123 | # manually expose for backwards compatibility
124 | Config = litgpt.Config
125 | GPT = litgpt.GPT
126 | RMSNorm = litgpt.model.RMSNorm
127 | CausalSelfAttention = litgpt.model.CausalSelfAttention
128 | LLaMAMLP = litgpt.model.LLaMAMLP
129 | build_rope_cache = litgpt.model.build_rope_cache
130 | apply_rope = litgpt.model.apply_rope
131 | Block = litgpt.model.Block
132 |
--------------------------------------------------------------------------------
/thunder/tests/module_example.py:
--------------------------------------------------------------------------------
1 | def returns_two():
2 | return 2
3 |
4 |
5 | def _returns_three():
6 | return 3
7 |
8 |
9 | def returns_five():
10 | return 5
11 |
--------------------------------------------------------------------------------
/thunder/tests/test_apex_fused_norms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.testing import assert_close
4 |
5 | fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda")
6 |
7 | from torch.distributed import is_available
8 | from thunder.executors.apexex import apex_ex
9 | import thunder
10 |
11 |
12 | # See https://github.com/NVIDIA/apex/issues/1853
13 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
14 | @pytest.mark.parametrize("requires_grad", [True, False])
15 | @pytest.mark.parametrize("memory_efficient", [True, False])
16 | def test_apex_fused_rms_norm(requires_grad, memory_efficient):
17 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
18 |
19 | def fn(x, weight, normalized_shape, eps):
20 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)
21 |
22 | device = "cuda"
23 | normalized_shape = (3, 2)
24 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad)
25 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad)
26 | eps = 1e-5
27 |
28 | expected = fn(x, weight, normalized_shape, eps)
29 | jfn = thunder.jit(fn, executors=[apex_ex])
30 | actual = jfn(x, weight, normalized_shape, eps)
31 |
32 | assert_close(actual, expected)
33 |
34 | if requires_grad:
35 | grad_output = torch.rand_like(actual)
36 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output)
37 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output)
38 |
39 | assert_close(actual_grad, expected_grad)
40 |
41 |
42 | # See https://github.com/NVIDIA/apex/issues/1853
43 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
44 | @pytest.mark.parametrize("requires_grad", [True, False])
45 | @pytest.mark.parametrize("memory_efficient", [True, False])
46 | def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient):
47 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
48 |
49 | def fn(x, weight, normalized_shape, eps):
50 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)
51 |
52 | device = "cuda"
53 | normalized_shape = (3, 2)
54 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad)
55 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad)
56 | eps = 1e-5
57 |
58 | expected = fn(x, weight, normalized_shape, eps)
59 | jfn = thunder.jit(fn, executors=())
60 | actual = jfn(x, weight, normalized_shape, eps)
61 |
62 | assert_close(actual, expected)
63 |
64 | if requires_grad:
65 | grad_output = torch.rand_like(actual)
66 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output)
67 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output)
68 |
69 | assert_close(actual_grad, expected_grad)
70 |
--------------------------------------------------------------------------------
/thunder/tests/test_check_trace.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.dev_utils.check_trace import check_trace
3 | import torch
4 | import pytest
5 |
6 |
7 | def test_missing_symbol():
8 | def fn(a, b):
9 | c = a + b
10 | d = 2 * c
11 | return d
12 |
13 | jfn = thunder.jit(fn)
14 |
15 | a = torch.randn(2, 2)
16 | b = torch.randn(2, 2)
17 |
18 | jfn(a, b)
19 |
20 | tr = thunder.last_traces(jfn)[-1]
21 | check_trace(tr)
22 |
23 | del tr.bound_symbols[-3]
24 |
25 | with pytest.raises(AssertionError, match="unknown proxy"):
26 | check_trace(tr)
27 |
28 |
29 | def test_debug_option():
30 | class BrokenTransform(thunder.core.transform_common.Transform):
31 | def transform_traces_pre_prologue(self, pro, comp, epi, **kwargs):
32 | new_comp = thunder.core.trace.from_trace(comp)
33 | new_comp.bound_symbols = comp.bound_symbols[:]
34 | del new_comp.bound_symbols[2]
35 | return pro, new_comp, epi
36 |
37 | def fn(a, b):
38 | c = a + b
39 | d = 2 * c
40 | return d
41 |
42 | a = torch.randn(2, 2)
43 | b = torch.randn(2, 2)
44 |
45 | # works
46 | jfn = thunder.jit(fn, debug_options=thunder.DebugOptions(check_traces=True))
47 | jfn(a, b)
48 |
49 | # broken with nice error
50 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), debug_options=thunder.DebugOptions(check_traces=True))
51 |
52 | with pytest.raises(AssertionError, match="unknown proxy"):
53 | jfn(a, b)
54 |
55 | # broken with less nice error
56 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), executors=())
57 | with pytest.raises(UnboundLocalError, match="cannot access local|referenced before assignment"):
58 | jfn(a, b)
59 |
--------------------------------------------------------------------------------
/thunder/tests/test_examine.py:
--------------------------------------------------------------------------------
1 | import thunder.examine
2 | import torch
3 |
4 |
5 | def test_examine_fn():
6 | def foo(x):
7 | x[0] = 5 * x[1]
8 |
9 | x = torch.ones(2, 2)
10 | thunder.examine.examine(foo, x)
11 |
12 |
13 | def test_examine_jfn():
14 | def foo(x):
15 | x[0] = 5 * x[1]
16 |
17 | jfoo = thunder.jit(foo)
18 | x = torch.ones(2, 2)
19 | thunder.examine.examine(jfoo, x)
20 |
21 |
22 | def test_examine_noncallable(capsys):
23 | x = torch.ones(2, 2)
24 | y = torch.ones(2, 2)
25 | thunder.examine.examine(x, y)
26 | captured = capsys.readouterr()
27 | assert "expected `fn` to be a callable" in captured.out
28 |
--------------------------------------------------------------------------------
/thunder/tests/test_pythonex.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import thunder
4 |
5 |
6 | def _run_cache_symbolic_values(fn, ref_fn, *args):
7 | jit_fn = thunder.jit(fn, cache="symbolic values")
8 | out = jit_fn(*args)
9 |
10 | out_ref = ref_fn(*args)
11 | assert out == out_ref
12 |
13 |
14 | def test_fmod():
15 | def foo(a, b):
16 | return a % b
17 |
18 | _run_cache_symbolic_values(foo, foo, 2.0, 1.3)
19 |
20 |
21 | def test_bitwise_or():
22 | def foo(a, b):
23 | return a | b
24 |
25 | _run_cache_symbolic_values(foo, foo, 3, 5)
26 |
27 |
28 | def test_bitwise_and():
29 | def foo(a, b):
30 | return a & b
31 |
32 | _run_cache_symbolic_values(foo, foo, 3, 5)
33 |
34 |
35 | def test_bitwise_xor():
36 | def foo(a, b):
37 | return a ^ b
38 |
39 | _run_cache_symbolic_values(foo, foo, 3, 5)
40 |
41 |
42 | def test_math_atan2():
43 | def foo(a, b):
44 | # TODO: calling through math.atan2 bakes in constant, this needs to be investigated.
45 | return thunder.clang.atan2(a, b)
46 |
47 | # NOTE: we have thunder.clang in foo, which cannot be run with non-proxy
48 | _run_cache_symbolic_values(foo, math.atan2, 2.0, 1.3)
49 |
50 |
51 | def test_math_fmod():
52 | def foo(a, b):
53 | return thunder.clang.fmod(a, b)
54 |
55 | _run_cache_symbolic_values(foo, math.fmod, 2.0, 1.3)
56 |
--------------------------------------------------------------------------------
/thunder/tests/test_reductions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.testing import assert_close
3 |
4 | import thunder
5 | import thunder.torch as ttorch
6 | from thunder.tests.framework import instantiate
7 |
8 |
9 | # TODO: convert these tests to OpInfo generated tests
10 |
11 |
12 | @instantiate(dtypes=(thunder.float32,))
13 | def test_torch_var(executor, device, dtype):
14 | # Tests passing all arguments as function inputs
15 | def foo(a, dim, *, keepdim=False, correction=1):
16 | return ttorch.var(a, dim, keepdim=keepdim, correction=correction)
17 |
18 | traced_foo = executor.make_callable(foo)
19 |
20 | tdtype = ttorch.to_torch_dtype(dtype)
21 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
22 |
23 | # Full reduction
24 | thunder_result = traced_foo(a, [0, 1])
25 | torch_result = torch.var(a, [0, 1])
26 | assert_close(thunder_result, torch_result)
27 |
28 | # Reduce along dim 1
29 | thunder_result = traced_foo(a, [1])
30 | torch_result = torch.var(a, [1])
31 | assert_close(thunder_result, torch_result)
32 |
33 | # Specifying the correction
34 | thunder_result = traced_foo(a, [1], correction=2)
35 | torch_result = torch.var(a, [1], correction=2)
36 | assert_close(thunder_result, torch_result)
37 |
38 | # Specifying keepdim
39 | thunder_result = traced_foo(a, [1], keepdim=True, correction=2)
40 | torch_result = torch.var(a, [1], keepdim=True, correction=2)
41 | assert_close(thunder_result, torch_result)
42 |
43 | # Tests passing arguments as constants
44 | def foo(a):
45 | return ttorch.var(a, [0, 1], keepdim=True, correction=2)
46 |
47 | traced_foo = executor.make_callable(foo)
48 |
49 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
50 |
51 | thunder_result = traced_foo(a)
52 | torch_result = torch.var(a, [0, 1], keepdim=True, correction=2)
53 | assert_close(thunder_result, torch_result)
54 |
55 |
56 | @instantiate(dtypes=(thunder.float32,))
57 | def test_torch_mean(executor, device, dtype):
58 | def foo(a, dim=None, keepdim=False, *, dtype=None):
59 | return ttorch.mean(a, dim, keepdim, dtype=dtype)
60 |
61 | traced_foo = executor.make_callable(foo)
62 |
63 | tdtype = ttorch.to_torch_dtype(dtype)
64 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
65 |
66 | # Full reduction
67 | thunder_result = traced_foo(a, [0, 1])
68 | torch_result = torch.mean(a, [0, 1])
69 | assert_close(thunder_result, torch_result)
70 |
71 | # Reduce along dim 1
72 | thunder_result = traced_foo(a, [1])
73 | torch_result = torch.mean(a, [1])
74 | assert_close(thunder_result, torch_result)
75 |
76 | # Reduce with () dims
77 | thunder_result = traced_foo(a, ())
78 | torch_result = torch.mean(a, ())
79 | assert_close(thunder_result, torch_result)
80 |
81 |
82 | @instantiate(dtypes=(thunder.float32,))
83 | def test_var_mean(executor, device, dtype):
84 | def foo(a, dim=None, keepdim=False, *, correction=1):
85 | return ttorch.var_mean(a, dim, keepdim=keepdim, correction=correction)
86 |
87 | traced_foo = executor.make_callable(foo)
88 |
89 | tdtype = ttorch.to_torch_dtype(dtype)
90 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
91 |
92 | # Full reduction
93 | thunder_result = traced_foo(a, [0, 1])
94 | torch_result = torch.var_mean(a, [0, 1])
95 | assert_close(thunder_result, torch_result)
96 |
97 | # Reduce along dim 1
98 | thunder_result = traced_foo(a, [1])
99 | torch_result = torch.var_mean(a, [1])
100 | assert_close(thunder_result, torch_result)
101 |
102 | # Tests passing arguments as constants
103 | def foo(a):
104 | return ttorch.var_mean(a, [0, 1], keepdim=True, correction=2)
105 |
106 | traced_foo = executor.make_callable(foo)
107 |
108 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
109 |
110 | thunder_result = traced_foo(a)
111 | torch_result = torch.var_mean(a, [0, 1], keepdim=True, correction=2)
112 | assert_close(thunder_result, torch_result)
113 |
--------------------------------------------------------------------------------
/thunder/tests/test_shape_ops.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | import torch
3 |
4 |
5 | def test_pad_cast_value_itof():
6 | """
7 | Pad should cast the given value to the type of tensor and pad that value.
8 | """
9 |
10 | def fqn():
11 | x = torch.tensor([2, 3], dtype=torch.int32)
12 | y = torch.nn.functional.pad(x, pad=(1, 2), value=6.4)
13 | return y
14 |
15 | th_fqn = thunder.jit(fqn)
16 | v = th_fqn()
17 | assert v[0] == 6
18 | assert v[1] == 2
19 | assert v[2] == 3
20 | assert v[3] == 6
21 | assert v[4] == 6
22 |
23 |
24 | def test_pad_cast_value_ftoi():
25 | """
26 | Pad should cast the given value to the type of tensor and pad that value.
27 | """
28 |
29 | def fqn():
30 | x = torch.tensor([2.4, 3.8])
31 | y = torch.nn.functional.pad(x, pad=(1, 2), value=1)
32 | return y
33 |
34 | th_fqn = thunder.jit(fqn)
35 | v = th_fqn()
36 | assert v[0] == 1.0
37 | assert v[1] == 2.4
38 | assert v[2] == 3.8
39 | assert v[3] == 1.0
40 | assert v[4] == 1.0
41 |
--------------------------------------------------------------------------------
/thunder/tests/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | import functools
4 |
5 |
6 | def is_output_differentiable(x):
7 | # grad_fn is set only if one of the input `requires_grad=True`
8 | # and the op is differentiable.
9 | # Example:
10 | # >>> x = torch.ones(3, requires_grad=True)
11 | # >>> y = torch.ones(3, requires_grad=False)
12 | # >>> (x + x).grad_fn  # <AddBackward0 object at 0x...>
13 | # >>> (y + y).grad_fn  # None
14 | # >>> (y + x).grad_fn  # <AddBackward0 object at 0x...>
15 | # >>> (x < 1).grad_fn  # None (non-differentiable op)
16 | # Op with differentiable and non-differentiable outputs.
17 | # >>> torch.topk(x, k=2)
18 | # torch.return_types.topk(
19 | # values=tensor([1., 1.], grad_fn=<TopkBackward0>),
20 | # indices=tensor([0, 1]))
21 | # >>> torch.topk(torch.ones(3, requires_grad=False), k=2)
22 | # torch.return_types.topk(
23 | # values=tensor([1., 1.]),
24 | # indices=tensor([0, 1]))
25 | return x.grad_fn is not None or is_returning_self(x)
26 |
27 |
28 | def is_returning_self(x):
29 | if x.is_leaf and x.requires_grad:
30 | return True
31 | return False
32 |
33 |
34 | def filter_differentiable_outputs(outputs):
35 | if isinstance(outputs, torch.Tensor):
36 | # Otherwise `filter` below will
37 | # iterate over the Tensor data.
38 | outputs = [outputs]
39 |
40 | return list(filter(is_output_differentiable, outputs))
41 |
42 |
43 | def is_sm120_orsm121():
44 | return torch.cuda.get_device_capability() in ((12, 1), (12, 0))
45 |
46 |
47 | def skip_on_sm120_and_sm121(fn):
48 | @functools.wraps(fn)
49 | def wrapped_fn(*args, **kwargs):
50 | if is_sm120_orsm121():
51 | pytest.skip("Skipped on SM120/121")
52 | else:
53 | fn(*args, **kwargs)
54 |
55 | return wrapped_fn
56 |
--------------------------------------------------------------------------------
/thunder/torch/experimental/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/torch/experimental/__init__.py
--------------------------------------------------------------------------------
/thunder/torch/experimental/dtensor_codeutils.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import torch
3 | from thunder.torch.experimental.dtensor_utils import run_only_if_distributed_is_available
4 |
5 | if torch.distributed.is_available():
6 | from torch.distributed.tensor._dtensor_spec import DTensorSpec
7 |
8 |
9 | @run_only_if_distributed_is_available
10 | def is_dtensor_spec(x: Any) -> bool:
11 | return isinstance(x, DTensorSpec)
12 |
--------------------------------------------------------------------------------
/thunder/torch/langctx.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
4 | from thunder.core.pytree import tree_flatten
5 | from thunder.core.proxies import TensorProxy
6 |
7 | #
8 | # Creates and registers the torch language context
9 | #
10 | # NOTE That this is done separately from the definition of thunder.torch operations, because the
11 | # language context must be available before those operations are defined
12 |
13 | _method_name_to_fn_map: dict[str, Callable] = {}
14 | _property_name_to_fn_map: dict[str, Callable] = {}
15 |
16 |
17 | class TorchCtx(LanguageContext):
18 | def __init__(self):
19 | super().__init__("torch")
20 |
21 | def has_method(self, id: str) -> bool:
22 | return id in _method_name_to_fn_map or id in _property_name_to_fn_map
23 |
24 | def get_method(self, id: str, *args, **kwargs) -> Callable:
25 | # Note: concrete implementations should only raise AttributeError or
26 | # return None for "missing" methods as the proxies will
27 | # route __getattr__ to here and hasattr relies on __getattr__
28 | # throwing AttributeError (only) when the attribute does
29 | # not exist.
30 | inps, _ = tree_flatten((args, kwargs))
31 |
32 | has_tensor_input: bool = False
33 | for x in inps:
34 | if isinstance(x, TensorProxy):
35 | has_tensor_input = True
36 | break
37 | if has_tensor_input:
38 | method: None | Callable = _method_name_to_fn_map.get(id, None)
39 | prop: None | Callable = _property_name_to_fn_map.get(id, None)
40 | if method is None and prop is None:
41 | raise AttributeError(f"The {self.name} language context has no method or attribute {id}")
42 | if method:
43 | return method
44 | else:
45 | return prop(inps[0])
46 |
47 | # has_tensor_input is False
48 | # Defers to the CLANG language context when there are no tensor inputs
49 | # (the clang language context handles operations on NumberProxies and Numbers)
50 | primsctx: LanguageContext = resolve_language(Languages.CLANG)
51 | if not primsctx.has_method(id):
52 | raise AttributeError(
53 | f"Attempting to call method {id} in the torch language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
54 | )
55 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
56 | return prim_method
57 |
58 |
59 | torchctx = TorchCtx()
60 | register_langctx(Languages.TORCH, torchctx)
61 | register_langctx("torch", torchctx)
62 |
63 |
64 | # Registers a method with the torch language context
65 | def register_method(method_name: str, method: Callable, /) -> None:
66 | _method_name_to_fn_map[method_name] = method
67 |
68 |
69 | def register_property(property_name, property) -> None:
70 | _property_name_to_fn_map[property_name] = property
71 |
--------------------------------------------------------------------------------
/thunder/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .constant_folding import ConstantFolding
2 | from .materialization import MaterializationTransform
3 | from .qlora import LORATransform
4 | from .prune_prologue_checks import PrunePrologueChecks
5 | from .extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
6 |
7 |
8 | __all__ = [
9 | "ConstantFolding",
10 | "LORATransform",
11 | "MaterializationTransform",
12 | "PrunePrologueChecks",
13 | "ExtractionOnlyPrologueTransform",
14 | ]
15 |
--------------------------------------------------------------------------------
/thunder/transforms/extraction_only_prologue_transform.py:
--------------------------------------------------------------------------------
1 | from thunder.core.prims import PrimIDs
2 | from thunder.core.proxies import ProxyTag
3 | from thunder.core.trace import from_trace
4 | from thunder.core.transform_common import Transform
5 |
6 |
7 | __all__ = [
8 | "ExtractionOnlyPrologueTransform",
9 | ]
10 |
11 |
12 | class ExtractionOnlyPrologueTransform(Transform):
13 | """Exclude :func:`~thunder.core.prims.check_tensor_shape_and_metadata` from prologue trace of ThunderCompiler.
14 |
15 | This transform is mainly used by :class:`~thunder.dynamo.ThunderCompiler` to remove the check of input tensors
16 | when either they are :class:`torch.nn.Parameter` or all of them don't have any symbolic shape.
17 |
18 | Args:
19 | skip_check_on_input_tensors: If :obj:`True`, remove all the check from the prologue trace as TorchDynamo caching would do enough.
20 | Otherwise, remove the checks of tensor proxies with ``ProxyTag.STATIC_MEMORY_LOCATION``.
21 | """
22 |
23 | def __init__(self, skip_check_on_input_tensors: bool = False):
24 | self.skip_check_on_input_tensors = skip_check_on_input_tensors
25 |
26 | def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
27 | new_prologue_trace = from_trace(prologue_trace)
28 | new_bsyms = []
29 |
30 | for bsym in prologue_trace.bound_symbols:
31 | # NOTE - We assume TensorProxy's tagged with `STATIC_MEMORY_LOCATION` to
32 | # be Parameters or Buffer. It should be safe to disable check for
33 | # tensors we deem to be static.
34 | if bsym.sym.id == PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA and (
35 | self.skip_check_on_input_tensors or ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags
36 | ):
37 | continue
38 |
39 | new_bsyms.append(bsym)
40 |
41 | new_prologue_trace.bound_symbols = new_bsyms
42 |
43 | new_prologue_trace.set_provenance("Extraction only prologue pass")
44 | return new_prologue_trace, computation_trace, epilogue_trace
45 |
--------------------------------------------------------------------------------
/thunder/transforms/prune_prologue_checks.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.core import prims as prims
3 |
4 |
5 | class PrunePrologueChecks(thunder.core.transform_common.Transform):
6 | """A Transform to prune Prologue checks
7 |
8 | This transform removes prologue checks and can be applied when the user
9 | controls the environment enough to ensure that these checks would always
10 | succeed. By default, we remove all checks that use the module.
11 | """
12 |
13 | def __init__(self, prune_all_checks=False):
14 | self.prune_all_checks = prune_all_checks
15 |
16 | def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):
17 | def is_check(bsym):
18 | return bsym.sym in {
19 | prims.check_tensor_shape_and_metadata,
20 | prims.check_string_value,
21 | prims.check_number_type_and_value,
22 | prims.check_len,
23 | prims.check_literal_like,
24 | }
25 |
26 | if not self.prune_all_checks:
27 | bsyms_to_skip = set()
28 | module_member_names = set()
29 |
30 | for bsym in prologue_trace.bound_symbols:
31 | if bsym.sym in {prims.unpack_trivial, prims.unpack_cache_info, prims.python_return}:
32 | # These don't have inputs but need to be skipped so they don't trigger false positives
33 | # (python_return may have no inputs)
34 | continue
35 | if bsym.sym is prims.unpack_function_obj:
36 | for o in bsym.flat_proxy_outs:
37 | module_member_names.add(o.name)
38 | continue
39 |
40 | input_names = {i.name in module_member_names for i in bsym.flat_proxy_args}
41 | if all(input_names):
42 | # all() is trivially True when there are no proxy inputs; the only expected such symbol,
43 | # unpack_function_obj (the root of module_member_names), is handled above
44 | assert input_names, f"unexpected symbol {bsym.sym.name} without inputs"
45 | for o in bsym.flat_proxy_outs:
46 | module_member_names.add(o.name)
47 | if is_check(bsym):
48 | bsyms_to_skip.add(bsym)
49 |
50 | def should_skip_bsym(bsym):
51 | return bsym in bsyms_to_skip
52 |
53 | else:
54 | should_skip_bsym = is_check
55 |
56 | new_prologue_trace = thunder.core.trace.from_trace(prologue_trace)
57 | for bsym in prologue_trace.bound_symbols:
58 | if not should_skip_bsym(bsym):
59 | new_prologue_trace.bound_symbols.append(bsym.from_bsym())
60 |
61 | new_prologue_trace = thunder.core.transform_common.dce(new_prologue_trace)
62 | new_prologue_trace.set_provenance(thunder.core.trace.TraceProvenance(f"{self.__class__.__name__}"))
63 |
64 | return new_prologue_trace, compute_trace, epilogue_trace
65 |
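
A usage sketch mirroring how transforms are passed to `thunder.jit` elsewhere in this repo's tests; pruning is safe only when the caller guarantees inputs always satisfy the pruned checks.

```python
import torch
import thunder
from thunder.transforms import PrunePrologueChecks

model = torch.nn.Linear(4, 4)
jmodel = thunder.jit(model, transforms=[PrunePrologueChecks()])
jmodel(torch.randn(2, 4))
```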
--------------------------------------------------------------------------------
/thunder/transforms/utils.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.core import utils
3 | from thunder.core import prims
4 | from thunder.core.trace import TraceCtx
5 | from thunder.core.pytree import tree_map
6 |
7 |
8 | def get_checks(prologue_trace):
9 | # returns a dictionary mapping model param names to (check bsym, get param bsym)
10 | check_dict = {}
11 | prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace)
12 | for bsym in prologue_trace.bound_symbols:
13 | if bsym.sym == prims.unpack_parameter or bsym.sym == prims.unpack_buffer:
14 | param_thunder_module, param_name = bsym.args
15 | checks = [
16 | bsym2 for bsym2 in prologue_consumers[bsym.output] if bsym2.sym == prims.check_tensor_shape_and_metadata
17 | ]
18 | assert len(checks) == 1, (
19 | f"expected each parameter and buffer to have exactly one checker, but {bsym.output} has {len(checks)}"
20 | )
21 | assert isinstance(param_name, str)
22 | check_dict[param_name] = (checks[0], bsym)
23 | return check_dict
24 |
25 |
26 | def add_trace_output(trace, output, subindex: int | None = None):
27 | ret_node = trace.bound_symbols[-1]
28 | assert ret_node.sym == prims.python_return
29 | assert len(ret_node.args) == 1
30 | (ret_args,) = ret_node.args
31 |
32 | if subindex is None:
33 | ret_args = (*ret_args, output)
34 | else:
35 | assert isinstance(ret_args[subindex], tuple)
36 | ret_args = (*ret_args[:subindex], (*ret_args[subindex], output), *ret_args[subindex + 1 :])
37 |
38 | ret_node.args = (ret_args,)
39 |
40 |
41 | def trace_with_replaced_proxy_metadata(trace: TraceCtx, proxy_replacement_metadata) -> TraceCtx:
42 | t = TraceCtx(trace.fn)
43 |
44 | proxymap: dict[str, thunder.Proxy] = {}
45 |
46 | def map_proxy(p):
47 | if isinstance(p, thunder.Proxy):
48 | return proxymap[p.name]
49 | return p
50 |
51 | def create_proxy(p):
52 | if isinstance(p, thunder.Proxy):
53 | if p.name in proxymap: # happens with subsymbols
54 | return proxymap[p.name]
55 | with thunder.core.trace.tracectx(t):
56 | np = p.replace(**proxy_replacement_metadata.get(p.name, {}))
57 | proxymap[p.name] = np
58 | return np
59 | return p
60 |
61 | def process_bound_symbols(src_bound_symbols, target_bound_symbols):
62 | for bsym in src_bound_symbols:
63 | new_args = tree_map(map_proxy, bsym.args)
64 | new_kwargs = tree_map(map_proxy, bsym.kwargs)
65 | new_output = tree_map(create_proxy, bsym.output)
66 | new_bsym = bsym.from_bsym(output=new_output, args=new_args, kwargs=new_kwargs, subsymbols=[])
67 | target_bound_symbols.append(new_bsym)
68 | if len(bsym.subsymbols) > 0:
69 | process_bound_symbols(bsym.subsymbols, new_bsym.subsymbols)
70 |
71 | process_bound_symbols(trace.bound_symbols, t.bound_symbols)
72 |
73 | t.args = tree_map(map_proxy, trace.args)
74 | t.kwargs = tree_map(map_proxy, trace.kwargs)
75 | t._siginfo = trace._siginfo
76 | return t
77 |
--------------------------------------------------------------------------------