├── .azure
├── docker-build.yml
├── gpu-tests.yml
├── notebook-runs.yml
├── remove-torch-lines.sh
└── sanity-check.sh
├── .codecov.yml
├── .git-blame-ignore-revs
├── .github
├── CODEOWNERS
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── documentation.md
│ ├── feature_request.md
│ └── program_coverage.md
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
├── labeling-config.yml
├── lightning-probot.yml
├── load-quickstart-report.py
├── run-benchmark-as-lit-jobs.py
├── run-quickstart-as-lit-jobs.py
└── workflows
│ ├── auto-cc.yml
│ ├── ci-benchmark.yml
│ ├── ci-checks.yml
│ ├── ci-quickstart.yml
│ ├── ci-testing.yml
│ ├── docs-build.yml
│ ├── label-conflicts.yml
│ ├── labeler.yml
│ ├── release-nightly.yml
│ ├── release-pypi.yml
│ └── stale-branches.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── dockers
├── README.md
├── ubuntu-cuda
│ └── Dockerfile
├── with-apex
│ └── Dockerfile
└── with-dev
│ └── Dockerfile
├── docs
├── .build_docs.sh
├── .readthedocs.yaml
├── Makefile
├── make.bat
└── source
│ ├── _static
│ ├── copybutton.js
│ └── images
│ │ ├── LightningThunderDarkModewByline.png
│ │ ├── LightningThunderLightModewByline.png
│ │ ├── how_it_works.png
│ │ ├── icon.svg
│ │ ├── lightning_thunder_lightmode_nobyline.png
│ │ ├── logo-large.svg
│ │ ├── logo-small.svg
│ │ ├── logo.png
│ │ ├── logo.svg
│ │ ├── normalized_training_throughput_zero2.png
│ │ ├── pretrain_perf.png
│ │ └── training_throughput_single.png
│ ├── _templates
│ └── theme_variables.jinja
│ ├── advanced
│ ├── contrib_thunder.rst
│ ├── extending.rst
│ └── inside_thunder.rst
│ ├── basic
│ ├── faq.rst
│ ├── inspecting_traces.rst
│ ├── mlp_mnist.rst
│ ├── overview.rst
│ └── sharp_edges.rst
│ ├── conf.py
│ ├── fundamentals
│ ├── examine.rst
│ ├── hello_world.rst
│ └── installation.rst
│ ├── index.rst
│ ├── intermediate
│ ├── additional_executors.rst
│ ├── benchmarking.rst
│ ├── ddp.rst
│ └── whats_next.rst
│ └── reference
│ ├── clang
│ └── index.rst
│ ├── common
│ └── index.rst
│ ├── core
│ ├── baseutils.rst
│ ├── codeutils.rst
│ ├── devices.rst
│ ├── dtypes.rst
│ ├── functionalization.rst
│ ├── index.rst
│ ├── langctxs.rst
│ ├── prims.rst
│ ├── proxies.rst
│ ├── pytree.rst
│ ├── rematerialization.rst
│ ├── symbol.rst
│ ├── trace.rst
│ └── transforms.rst
│ ├── distributed
│ └── index.rst
│ ├── dynamo
│ └── index.rst
│ ├── examine
│ └── index.rst
│ ├── executors
│ ├── apexex.rst
│ ├── cudnnex.rst
│ ├── index.rst
│ ├── nvfuserex.rst
│ ├── passes.rst
│ ├── pythonex.rst
│ ├── torch_compile.rst
│ ├── torchex.rst
│ ├── triton_crossentropy.rst
│ └── utils.rst
│ ├── extend
│ └── index.rst
│ ├── recipes
│ └── index.rst
│ ├── thunder.rst
│ ├── torch
│ └── index.rst
│ └── transforms
│ └── index.rst
├── examples
└── quickstart
│ ├── hf_bert.py
│ ├── hf_llm.py
│ ├── mlp.py
│ ├── requirements.txt
│ ├── vit_hf.py
│ └── vit_tv.py
├── notebooks
├── .ignore.ci
├── adding_custom_operator.ipynb
├── adding_custom_operator_backward.ipynb
├── dev_tutorials
│ ├── extend.ipynb
│ ├── fsdp_tutorial.ipynb
│ └── thunder-add-vjp-rule.md
├── extend_thunder_with_cuda_python.ipynb
├── hello_world_thunderfx.ipynb
├── liger_kernel.ipynb
├── thunder_trace_intro.ipynb
├── writing_a_trace_transform_cpu_offloading.ipynb
└── zero_to_thunder.ipynb
├── pyproject.toml
├── requirements.txt
├── requirements
├── base.txt
├── devel.txt
├── docs.txt
├── notebooks.txt
└── test.txt
├── scripts
├── bisect_nvfuser.py
├── build_from_source.sh
└── validate_build.py
├── setup.py
└── thunder
├── __about__.py
├── __init__.py
├── benchmarks
├── __init__.py
├── benchmark_hf.py
├── benchmark_litgpt.py
├── conftest.py
├── distributed.py
├── einsum.py
├── targets.py
├── test_benchmark_litgpt.py
└── utils.py
├── clang
├── __init__.py
└── langctx.py
├── common.py
├── core
├── __init__.py
├── baseutils.py
├── codeutils.py
├── compile_data.py
├── devices.py
├── dtypes.py
├── functionalization.py
├── interpreter.py
├── jit_ext.py
├── langctxs.py
├── module.py
├── options.py
├── patterns.py
├── prims.py
├── profile.py
├── proxies.py
├── pytree.py
├── recipe.py
├── rematerialization.py
├── symbol.py
├── trace.py
├── trace_interpreter.py
├── transform_common.py
├── transforms.py
├── update_aliases.py
├── utils.py
└── vjp_utils.py
├── dev_utils
├── __init__.py
├── benchmark.py
├── check_trace.py
├── debug_transform.py
├── nvtx_profile_transform.py
└── utils.py
├── distributed
├── __init__.py
├── bucketing.py
├── checkpoint.py
├── prims.py
├── tensor_parallel
│ ├── __init__.py
│ ├── column_wise.py
│ ├── common.py
│ ├── optimize_comm.py
│ └── row_wise.py
├── transforms
│ ├── __init__.py
│ ├── ddp.py
│ ├── ddp_v2.py
│ ├── fsdp.py
│ └── fsdp_v2.py
└── utils.py
├── dynamo
├── __init__.py
├── benchmark_utils.py
├── compiler.py
├── compiler_graph_benchmark.py
├── report.py
├── repro_script_template.py
├── splitter.py
└── utils.py
├── examine
├── __init__.py
└── memory_calculation.py
├── executors
├── __init__.py
├── apex_entropyex_impl.py
├── apex_fused_rms_norm_impl.py
├── apexex.py
├── cudnn_layernormex.py
├── cudnnex.py
├── data_dependent_partition.py
├── fa3ex.py
├── nvfuserex.py
├── nvfuserex_impl.py
├── passes.py
├── pythonex.py
├── sdpaex.py
├── torch_autograd.py
├── torch_compile.py
├── torchex.py
├── transformer_engine_v2ex.py
├── transformer_engine_v2ex_impl.py
├── transformer_engineex.py
├── triton_crossentropy.py
├── triton_crossentropy_impl.py
├── triton_utils.py
└── utils.py
├── extend
└── __init__.py
├── numpy
├── __init__.py
└── langctx.py
├── plugins
├── __init__.py
├── distributed.py
├── fp8.py
├── quantization.py
└── reduce_overhead.py
├── py.typed
├── recipes
├── __init__.py
├── base.py
└── hf_transformers.py
├── tests
├── README.md
├── __init__.py
├── bf16.py
├── conftest.py
├── distributed
│ ├── __init__.py
│ ├── helper.py
│ ├── modules.py
│ ├── test_checkpoint.py
│ ├── test_ddp.py
│ ├── test_fsdp.py
│ ├── test_ops.py
│ └── test_tensor_parallel.py
├── framework.py
├── hf_bart_self_attn.py
├── litgpt_model.py
├── llama2_model.py
├── make_tensor.py
├── module_example.py
├── nanogpt_model.py
├── opinfos.py
├── test_apex_cross_entropy_executor.py
├── test_apex_fused_norms.py
├── test_auto_register_torchops.py
├── test_autocast.py
├── test_check_trace.py
├── test_core.py
├── test_cudnn_executor.py
├── test_dynamo.py
├── test_einops.py
├── test_elementwise.py
├── test_examine.py
├── test_examine_memory.py
├── test_extend.py
├── test_fa3_executor.py
├── test_grad.py
├── test_inplace_copy.py
├── test_inplace_functionalization.py
├── test_interpreter.py
├── test_jit_general.py
├── test_networks.py
├── test_nvfuser.py
├── test_nvfuser_remat.py
├── test_ops.py
├── test_patterns.py
├── test_pythonex.py
├── test_randomness.py
├── test_recipes.py
├── test_reductions.py
├── test_sdpaex_executor.py
├── test_shape_ops.py
├── test_torch_compile_executor.py
├── test_transformer_engine_executor.py
├── test_transformer_engine_v2_executor.py
├── test_transforms.py
├── test_triton_ce.py
└── test_update_aliases.py
├── torch
├── __init__.py
├── default_torch_ops.py
└── langctx.py
└── transforms
├── __init__.py
├── autocast.py
├── autodiff.py
├── constant_folding.py
├── cudagraph.py
├── extraction_only_prologue_transform.py
├── materialization.py
├── prune_prologue_checks.py
├── qlora.py
├── quantization.py
└── utils.py
/.azure/docker-build.yml:
--------------------------------------------------------------------------------
1 | trigger:
2 | tags:
3 | include: ["*"]
4 | branches:
5 | include: ["main"]
6 | paths:
7 | include:
8 | - ".azure/docker-build.yml"
9 | - "dockers/**"
10 | - "requirements.txt"
11 | - "requirements/*.txt"
12 | - "setup.py"
13 | exclude:
14 | - "*.md"
15 | - "**/*.md"
16 |
17 | pr:
18 | branches:
19 | include: ["*"]
20 | paths:
21 | include:
22 | - ".azure/docker-build.yml"
23 | - "dockers/**"
24 | - "requirements.txt"
25 | - "requirements/*.txt"
26 | - "setup.py"
27 | exclude:
28 | - "*.md"
29 | - "**/*.md"
30 |
31 | schedules:
32 | - cron: "0 */2 * * *"
33 | displayName: rebuild dockers for CI every 2 hours
34 | branches:
35 | include: ["main"]
36 |
37 | jobs:
38 | - job: build_push
39 | strategy:
40 | matrix:
41 | "cuda 12.6 | torch 2.7.1 | cudnn FE v1.10.0":
42 | { CUDA_VERSION: "12.6.3", TORCH_VERSION: "2.7.1", TRITON_VERSION: "3.3.1", CUDNN_FRONTEND_VERSION: "1.10.0" }
43 | "cuda 12.6 | torch nightly | cudnn FE v1.10.0":
44 | { CUDA_VERSION: "12.6.3", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.10.0" }
45 | #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found
46 | # how much time to give 'run always even if cancelled tasks' before stopping them
47 | cancelTimeoutInMinutes: "2"
48 | timeoutInMinutes: "95"
49 | variables:
50 | UBUNTU_VERSION: "24.04"
51 | PYTHON_VERSION: "3.10"
52 | imageRepository: "pytorchlightning/lightning-thunder"
53 | imageTag: "ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-cudnn-fe$(CUDNN_FRONTEND_VERSION)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}"
54 | pool: "lit-rtx-3090"
55 | workspace:
56 | clean: all
57 | steps:
58 | - bash: |
59 | set -e
60 | echo $imageTag
61 | nvidia-smi
62 | docker image build \
63 | -t $(imageRepository):$(imageTag) \
64 | -f "dockers/ubuntu-cuda/Dockerfile" \
65 | --build-arg UBUNTU_VERSION="$(UBUNTU_VERSION)" \
66 | --build-arg CUDA_VERSION="$(CUDA_VERSION)" \
67 | --build-arg CUDNN_FRONTEND_VERSION="v$(CUDNN_FRONTEND_VERSION)" \
68 | --build-arg PYTHON_VERSION="$(PYTHON_VERSION)" \
69 | --build-arg TORCH_VERSION="$(TORCH_VERSION)" \
70 | --build-arg TORCH_INSTALL="$(TORCH_INSTALL)" \
71 | --build-arg TRITON_VERSION="$(TRITON_VERSION)" \
72 | . --no-cache
73 | timeoutInMinutes: "95"
74 | displayName: "Build base image"
75 |
76 | - bash: |
77 | docker image build \
78 | -t $(imageRepository):$(imageTag)-apex \
79 | -f "dockers/with-apex/Dockerfile" \
80 | --build-arg BASE_IMAGE_TAG="$(imageTag)" \
81 | . --no-cache
82 | timeoutInMinutes: "25"
83 | displayName: "Build Apex image"
84 |
85 | - bash: |
86 | docker image build \
87 | -t $(imageRepository):$(imageTag)-dev \
88 | -f "dockers/with-dev/Dockerfile" \
89 | --build-arg BASE_IMAGE_TAG="$(imageTag)-apex" \
90 | . --no-cache
91 | timeoutInMinutes: "25"
92 | displayName: "Build Dev image"
93 |
94 | - bash: |
95 | docker image ls | grep $(imageRepository)
96 | # drop pt from requirements so not to interfere with the existing one
97 | bash .azure/remove-torch-lines.sh requirements/base.txt
98 | mv .azure azure
99 | docker run --rm --gpus=all -v .:/workspace $(imageRepository):$(imageTag)-dev \
100 | bash -c "cd /workspace && ls -lh . && \
101 | pip install -q . && \
102 | bash azure/sanity-check.sh"
103 | timeoutInMinutes: "5"
104 | displayName: "Sanity check"
105 |
106 | - bash: |
107 | set -e
108 | echo $(imageRepository):$(imageTag)
109 | echo $(DOCKERHUB_PAT) | docker login --username $(DOCKERHUB_USER) --password-stdin
110 | docker push $(imageRepository):$(imageTag)-dev
111 | condition: ne(variables['Build.Reason'], 'PullRequest')
112 | timeoutInMinutes: "35"
113 | displayName: "Push base image"
114 |
--------------------------------------------------------------------------------
/.azure/notebook-runs.yml:
--------------------------------------------------------------------------------
1 | trigger:
2 | tags:
3 | include: ["*"]
4 | branches:
5 | include:
6 | - "main"
7 | - "release/*"
8 | - "refs/tags/*"
9 |
10 | pr:
11 | branches:
12 | include: ["*"]
13 |
14 | jobs:
15 | - job: jupyter
16 | strategy:
17 | matrix:
18 | "notebooks w/ torch 2.7":
19 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.7.1-dev"
20 | "notebooks w/ torch-nightly":
21 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev"
22 | # how long to run the job before automatically cancelling
23 | timeoutInMinutes: "45"
24 | # how much time to give 'run always even if cancelled tasks' before stopping them
25 | cancelTimeoutInMinutes: "2"
26 | pool: "lit-rtx-3090"
27 | variables:
28 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' )
29 | TORCH_HOME: "/var/tmp/torch"
30 | PIP_CACHE_DIR: "/var/tmp/pip"
31 | container:
32 | image: "pytorchlightning/lightning-thunder:$(docker-image)"
33 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp"
34 | workspace:
35 | clean: all
36 | steps:
37 | - bash: |
38 | echo $(DEVICES)
39 | lspci | egrep 'VGA|3D'
40 | whereis nvidia
41 | nvidia-smi
42 | which python && which pip
43 | python --version
44 | pip --version
45 | pip list
46 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
47 | displayName: "Image info & NVIDIA"
48 |
49 | - bash: |
50 | # drop pt from requirements so not to interfere with the existing one
51 | bash .azure/remove-torch-lines.sh requirements/base.txt
52 | cat requirements/base.txt
53 | # double check on test requirements
54 | pip install -r requirements/notebooks.txt
55 | # install this package
56 | python setup.py develop
57 | displayName: "Install package & ..."
58 |
59 | - bash: |
60 | set -ex
61 | bash .azure/sanity-check.sh
62 | displayName: "Sanity check / details"
63 |
64 | - bash: |
65 | set -ex
66 | # list all notebooks in this folder
67 | find . -name "*.ipynb" > all.txt
68 | # drop all "./" from beginning of each line
69 | sed -i 's/^\.\///' all.txt
70 | # filter out the ones that are listed in .ignore.ci
71 | grep -Fxv -f .ignore.ci all.txt > ci.txt
72 | # iterate over all listed notebooks and execute them with jupyter
73 | while read -r line; do
74 | echo "Processing $line"
75 | jupyter execute $line --timeout=300
76 | done <<< $(cat ci.txt)
77 | workingDirectory: "notebooks/"
78 | displayName: "Execute notebooks"
79 |
--------------------------------------------------------------------------------
/.azure/remove-torch-lines.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Get the name of the file to process
4 | filename=$1
5 |
6 | # Create a temporary file to store the filtered lines
7 | tempfile=$(mktemp)
8 |
9 | # Loop through the original file and remove lines including the word "torch"
10 | while read -r line; do
11 | if [[ ! "$line" =~ "torch" ]]; then
12 | echo "$line" >> "$tempfile"
13 | fi
14 | done < "$filename"
15 |
16 | # Move the temporary file to the original file
17 | mv "$tempfile" "$filename"
18 |
--------------------------------------------------------------------------------
/.azure/sanity-check.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -ex
4 | pip list
5 | python -c "import torch ; assert torch.cuda.is_available(), 'missing GPU support!'"
6 | python -c "import torch ; v = torch.__version__ ; assert str(v).startswith('2'), v"
7 | python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'"
8 | python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'"
9 |
--------------------------------------------------------------------------------
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 | # https://docs.codecov.io/docs/codecovyml-reference
6 | codecov:
7 | bot: "codecov-io"
8 | strict_yaml_branch: "yaml-config"
9 | require_ci_to_pass: yes
10 | notify:
11 | # after_n_builds: 2
12 | wait_for_ci: yes
13 |
14 | coverage:
15 | precision: 0 # 2 = xx.xx%, 0 = xx%
16 | round: nearest # how coverage is rounded: down/up/nearest
17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green
18 | status:
19 | # https://codecov.readme.io/v1.0/docs/commit-status
20 | project:
21 | default:
22 | informational: true
23 | target: 99% # specify the target coverage for each commit status
24 | threshold: 30% # allow this little decrease on project
25 | # https://github.com/codecov/support/wiki/Filtering-Branches
26 | # branches: main
27 | if_ci_failed: error
28 | # https://github.com/codecov/support/wiki/Patch-Status
29 | patch:
30 | default:
31 | informational: true
32 | target: 50% # specify the target "X%" coverage to hit
33 | # threshold: 50% # allow this much decrease on patch
34 | changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 | annotations: false
39 |
40 | parsers:
41 | gcov:
42 | branch_detection:
43 | conditional: true
44 | loop: true
45 | macro: false
46 | method: false
47 | javascript:
48 | enable_partials: false
49 |
50 | comment: false
51 |
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # Move optimize_allreduce_in_ddp_backward to thunder.distributed.transforms namespace #2087
2 | 7a3d5d4e64da010eea42210a065960e3568870ac
3 | # Split `test_ddp` into `test_ddp`, `test_fsdp` and `test_ops` (#772)
4 | b63e8713141044ed8c184f2f4c8305048d477319
5 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Each line is a file pattern followed by one or more owners.
2 |
3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
4 | # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request.
5 |
6 | # Thank you, our previous code owners for their service:
7 | # @carmocca
8 |
9 | * @mruberry @lantiga @t-vi
10 |
11 | # CI/CD and configs
12 | /.azure/ @borda @lantiga @t-vi
13 | /.github/ @borda @lantiga @t-vi
14 | /dockers/ @borda @lantiga @t-vi
15 | Makefile @borda @lantiga @t-vi
16 | *.yml @borda @lantiga @t-vi
17 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | assignees: ''
6 | ---
7 |
8 | *Note*: If you have a model or program that is not supported yet but should be, please use the program coverage template.
9 |
10 | ## 🐛 Bug
11 |
12 |
13 |
14 | ### To Reproduce
15 |
16 | Steps to reproduce the behavior:
17 |
18 | 1. Go to '...'
19 | 1. Run '....'
20 | 1. Scroll down to '....'
21 | 1. See error
22 |
23 |
24 |
25 | #### Code sample
26 |
27 |
29 |
30 | ### Expected behavior
31 |
32 |
33 |
34 | ### Environment
35 |
36 | - PyTorch Version (e.g., 1.0):
37 | - OS (e.g., Linux):
38 | - How you installed PyTorch (`conda`, `pip`, source):
39 | - Build command you used (if compiling from source):
40 | - Python version:
41 | - CUDA/cuDNN version:
42 | - GPU models and configuration:
43 | - Any other relevant information:
44 |
45 | ### Additional context
46 |
47 |
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos and doc fixes, please go ahead and:
12 |
13 | 1. Create an issue.
14 | 1. Fix the typo.
15 | 1. Submit a PR.
16 |
17 | Thanks!
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Feature
10 |
11 |
12 |
13 | ### Motivation
14 |
15 |
16 |
17 | ### Pitch
18 |
19 |
20 |
21 | ### Alternatives
22 |
23 |
24 |
25 | ### Additional context
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/program_coverage.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Program Coverage
3 | about: Expand the programs / models Thunder can process
4 | title: ''
5 | labels: program-coverage
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Model / language coverage
10 |
11 |
12 |
13 | ### Pitch
14 |
15 |
16 |
17 | ### Alternatives / Potential work-arounds
18 |
19 |
24 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | Before submitting
3 |
4 | - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
6 | - [ ] Did you make sure to update the docs?
7 | - [ ] Did you write any new necessary tests?
8 |
9 |
10 |
11 | ## What does this PR do?
12 |
13 | Fixes # (issue).
14 |
15 | ## PR review
16 |
17 | Anyone in the community is free to review the PR once the tests have passed.
18 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
19 |
20 | ## Did you have fun?
21 |
22 | Make sure you had fun coding 🙃
23 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # Basic dependabot.yml file with
2 | # minimum configuration for two package managers
3 |
4 | version: 2
5 | updates:
6 | # Enable version updates for python
7 | - package-ecosystem: "pip"
8 | # Look for a `requirements` in the `root` directory
9 | directory: "/"
10 | schedule:
11 | interval: "monthly"
12 | pull-request-branch-name:
13 | separator: "-"
14 | # Allow up to 5 open pull requests for pip dependencies
15 | open-pull-requests-limit: 5
16 |
17 | # Enable version updates for GitHub Actions
18 | - package-ecosystem: "github-actions"
19 | directory: "/"
20 | schedule:
21 | interval: "monthly"
22 | pull-request-branch-name:
23 | separator: "-"
24 | # Allow up to 5 open pull requests for GitHub Actions
25 | open-pull-requests-limit: 5
26 | groups:
27 | GHA-updates:
28 | patterns:
29 | - "*"
30 |
--------------------------------------------------------------------------------
/.github/labeling-config.yml:
--------------------------------------------------------------------------------
1 | documentation:
2 | - changed-files:
3 | - any-glob-to-any-file:
4 | - docs/**/*
5 | - dev_tutorials/*
6 |
7 | "ci":
8 | - changed-files:
9 | - any-glob-to-any-file:
10 | - .azure/*
11 | - .github/*
12 | - .github/workflows/*
13 | - dockers/**/*
14 |
15 | "docker":
16 | - changed-files:
17 | - any-glob-to-any-file:
18 | - dockers/**/*
19 | - .azure/docker-build.yml
20 |
21 | "install":
22 | - changed-files:
23 | - any-glob-to-any-file:
24 | - setup.py
25 | - pyproject.toml
26 | - MANIFEST.in
27 |
28 | "dependencies":
29 | - changed-files:
30 | - any-glob-to-any-file:
31 | - requirements/*
32 | - requirements.txt
33 |
--------------------------------------------------------------------------------
/.github/lightning-probot.yml:
--------------------------------------------------------------------------------
1 | tracking_issue: 72
2 |
--------------------------------------------------------------------------------
/.github/load-quickstart-report.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 |
5 | def main(report_path: str = "quickstart_report.json"):
6 | """Load and print the quickstart report."""
7 | with open(report_path) as fp:
8 | report = json.load(fp)
9 |
10 | success_count = sum(status == "completed" for status in report.values())
11 | overall_status = "🟢" if success_count == len(report) else "⛔"
12 | print(f"Quickstart report {overall_status} with {success_count} out of {len(report)} successful:")
13 | # Sort so that entries with status "success" (or "completed") are last
14 | for name, status in sorted(report.items(), key=lambda x: x[1] == "completed"):
15 | status_icon = "✔️" if status == "completed" else "❌"
16 | print(f"{status_icon} {name}")
17 |
18 |
19 | if __name__ == "__main__":
20 | # optional path to the report file
21 | system_args = sys.argv[1:]
22 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {}
23 | main(**main_args)
24 |
--------------------------------------------------------------------------------
/.github/run-benchmark-as-lit-jobs.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import sys
4 | import time
5 | from datetime import datetime
6 |
7 | from lightning_sdk import Studio, Job, Machine, Status
8 |
9 |
10 | def main(gh_run_id: str = ""):
11 | if not gh_run_id:
12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S")
13 | print("Creating studio...")
14 | s = Studio(f"thunder-benchmark-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True)
15 |
16 | print("Uploading package and benchmark script...")
17 | s.upload_folder("dist", remote_path="dist")
18 | pkg_path = glob.glob("dist/*.whl")[0]
19 | s.upload_file("thunder/benchmarks/benchmark_hf.py", remote_path="benchmarks/benchmark_hf.py")
20 |
21 | print("Starting studio...")
22 | s.start()
23 | print("Installing Thunder and dependencies...")
24 | s.run(f"pip install {pkg_path} -U transformers 'numpy<2.0' 'nvfuser_cu128_torch27==0.2.27.dev20250501'")
25 |
26 | print("Running HF benchmark script...")
27 | job = Job.run(
28 | name=f"benchmark-run{gh_run_id}",
29 | command="pip list && python benchmarks/benchmark_hf.py",
30 | studio=s,
31 | machine=Machine.L40S,
32 | interruptible=True,
33 | )
34 |
35 | print("Stopping studio...")
36 | s.stop()
37 |
38 | print("Waiting for job to finish...")
39 | job.wait()
40 | status = str(job.status).lower()
41 | print(f"[{job.status}]\t {job.name}")
42 |
43 | report = {"benchmark_hf.py": status}
44 | with open("benchmark_hf_report.json", "w") as fp:
45 | json.dump(report, fp, indent=4)
46 |
47 | if job.status != Status.Completed:
48 | print("=" * 80)
49 | print(f"===== benchmark_hf.py -> {job.status} =====")
50 | print("=" * 80)
51 | print(job.logs)
52 | print("=" * 80)
53 | time.sleep(3)
54 | raise RuntimeError(f"Benchmark HF job {job.status}")
55 | # clean up
56 | job.delete()
57 | s.delete()
58 |
59 |
60 | if __name__ == "__main__":
61 | # parse command line arguments
62 | args = sys.argv[1:]
63 | main(*args)
64 |
--------------------------------------------------------------------------------
/.github/run-quickstart-as-lit-jobs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from datetime import datetime
4 | import glob
5 | import os.path
6 |
7 | from lightning_sdk import Studio, Job, Machine, Status
8 |
9 |
10 | def main(gh_run_id: str = ""):
11 | if not gh_run_id:
12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S")
13 | print("Creating studio...")
14 | s = Studio(f"thunder-quickstarts-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True)
15 | print("Uploading package and scripts...")
16 | s.upload_folder("dist", remote_path="dist")
17 | pkg_path = glob.glob("dist/*.whl")[0]
18 | s.upload_folder("examples/quickstart", remote_path="quickstart")
19 |
20 | print("Starting studio...")
21 | s.start()
22 | print("Installing Thunder and other requirements...")
23 | s.run(f"pip install {pkg_path} -U -r quickstart/requirements.txt")
24 |
25 | ls_quickstart = glob.glob("examples/quickstart/*.py")
26 | print("Found quickstart scripts:", ls_quickstart)
27 |
28 | print("Running quickstart scripts...")
29 | jobs = {
30 | os.path.basename(script): Job.run(
31 | name=f"ci-run{gh_run_id}_{script}",
32 | command=f"pip list && python quickstart/{os.path.basename(script)}",
33 | studio=s,
34 | machine=Machine.L40S,
35 | interruptible=True,
36 | )
37 | for script in ls_quickstart
38 | }
39 |
40 | print("Stopping studio...")
41 | s.stop()
42 |
43 | print("Waiting for jobs to finish...")
44 | report, failures = {}, {}
45 | for name, job in jobs.items():
46 | job.wait()
47 | print(f"[{job.status}]\t {job.name}")
48 | report[name] = str(job.status).lower()
49 | if job.status != Status.Completed:
50 | failures[name] = job.logs
51 | else: # clean up successful jobs
52 | job.delete()
53 |
54 | with open("quickstart_report.json", "w") as fp:
55 | json.dump(report, fp, indent=4)
56 |
57 | print("Showing logs of failed jobs...")
58 | separator = "=" * 80
59 | for name, logs in failures.items():
60 | offset = "=" * (80 - 5 - 2 - len(name))
61 | print(f"{separator}\n===== {name} {offset}\n{separator}")
62 | print(logs)
63 | print(separator + "\n" * 5)
64 |
65 | assert not failures
66 |
67 | print("Cleaning up...")
68 | s.delete()
69 |
70 |
71 | if __name__ == "__main__":
72 | # parse command line arguments
73 | args = sys.argv[1:]
74 | main(*args)
75 |
--------------------------------------------------------------------------------
/.github/workflows/auto-cc.yml:
--------------------------------------------------------------------------------
1 | name: Probot
2 |
3 | on:
4 | issues:
5 | types: [labeled]
6 | # should use `pull_request_target` but it's blocked by
7 | # https://github.com/probot/probot/issues/1635
8 | # so this job will not run on forks until the above is fixed
9 | pull_request:
10 | types: [labeled, ready_for_review]
11 |
12 | jobs:
13 | auto-cc:
14 | runs-on: ubuntu-latest
15 | if: github.event_name == 'issue' || github.event.pull_request.draft == false
16 | timeout-minutes: 5
17 | steps:
18 | - uses: Lightning-AI/probot@v5
19 | env:
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | with:
22 | job: auto-cc
23 |
--------------------------------------------------------------------------------
/.github/workflows/ci-benchmark.yml:
--------------------------------------------------------------------------------
1 | name: Benchmark Models
2 |
3 | on:
4 | workflow_dispatch: {}
5 | pull_request:
6 | paths:
7 | - ".github/workflows/benchmark-hf.yml"
8 | - ".github/run-benchmark-as-lit-jobs.py"
9 | push: # enable tracing which commit to master broke it...
10 | branches: [main]
11 |
12 | concurrency:
13 | group: ${{ github.workflow }}-${{ github.ref }}
14 | cancel-in-progress: false
15 |
16 | defaults:
17 | run:
18 | shell: bash
19 |
20 | jobs:
21 | launcher-benchmark:
22 | runs-on: "ubuntu-22.04"
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 |
27 | - uses: actions/setup-python@v5
28 | with:
29 | python-version: "3.10"
30 |
31 | - name: Build Thunder Package
32 | run: |
33 | pip install -U build
34 | python -m build --sdist --wheel --outdir dist/
35 | ls -l dist/
36 |
37 | - name: Launch Benchmark Job in Lightning Studio
38 | env:
39 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
40 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
41 | run: |
42 | pip install lightning_sdk -U -q
43 | python .github/run-benchmark-as-lit-jobs.py ${{ github.run_id }}
44 |
45 | - name: Post Slack Notification
46 | if: always() && github.event_name != 'pull_request'
47 | uses: act10ns/slack@v2
48 | with:
49 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }}
50 | status: ${{ job.status }}
51 | message: |
52 | *Benchmark Triggered Manually* - [${{ job.status }}]
53 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
54 |
--------------------------------------------------------------------------------
/.github/workflows/ci-checks.yml:
--------------------------------------------------------------------------------
1 | name: General checks
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request: {}
7 |
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
10 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
11 |
12 | jobs:
13 | check-schema:
14 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
15 | with:
16 | azure-dir: ".azure"
17 |
18 | check-package:
19 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
20 | with:
21 | actions-ref: main
22 | import-name: "thunder"
23 | artifact-name: dist-packages-${{ github.sha }}
24 | testing-matrix: |
25 | {
26 | "os": ["ubuntu-latest", "macOS-latest", "windows-latest"],
27 | "python-version": ["3.10", "3.11"]
28 | }
29 |
--------------------------------------------------------------------------------
/.github/workflows/ci-quickstart.yml:
--------------------------------------------------------------------------------
1 | name: CI quickstart
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | paths:
8 | - ".github/workflows/ci-quickstart.yml"
9 | - ".github/run-quickstart-as-lit-jobs.py"
10 | - ".github/load-quickstart-report.py"
11 | - "examples/quickstart/*"
12 | workflow_dispatch: {}
13 | schedule:
14 | - cron: "0 0 * * *" # every midnight
15 |
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
19 |
20 | defaults:
21 | run:
22 | shell: bash
23 |
24 | jobs:
25 | launcher-quickstart:
26 | runs-on: "ubuntu-22.04"
27 | #timeout-minutes: 55
28 | steps:
29 | - uses: actions/checkout@v4
30 | - uses: actions/setup-python@v5
31 | with:
32 | python-version: "3.10"
33 |
34 | - name: Build package
35 | run: |
36 | pip install -U build
37 | python -m build --sdist --wheel --outdir dist/
38 | ls -l dist/
39 |
40 | - name: Run scripts in jobs
41 | env:
42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }}
43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }}
44 | run: |
45 | pip install lightning_sdk -U -q
46 | python .github/run-quickstart-as-lit-jobs.py ${{ github.run_id }}
47 |
48 | - name: Load report
49 | if: always()
50 | id: load-report
51 | run: |
52 | report=$(python .github/load-quickstart-report.py)
53 | echo "REPORT<> $GITHUB_ENV
54 | echo "$report" >> $GITHUB_ENV
55 | echo "EOF" >> $GITHUB_ENV
56 |
57 | - uses: act10ns/slack@v2
58 | if: always() && github.event_name != 'pull_request' && steps.load-report.conclusion == 'success'
59 | with:
60 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }}
61 | status: ${{ job.status }}
62 | message: |
63 | *Quickstart CI* - [${{ job.status }}]
64 | ${{ env.REPORT }}
65 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
66 |
--------------------------------------------------------------------------------
/.github/workflows/docs-build.yml:
--------------------------------------------------------------------------------
1 | name: "Build (& deploy) Docs"
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | types: [opened, reopened, ready_for_review, synchronize]
8 | workflow_dispatch: {}
9 |
10 | concurrency:
11 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
12 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
13 |
14 | defaults:
15 | run:
16 | shell: bash
17 |
18 | jobs:
19 | docs-make:
20 | if: github.event.pull_request.draft == false
21 | runs-on: ubuntu-22.04
22 | strategy:
23 | fail-fast: false
24 | matrix:
25 | target: ["html", "doctest", "linkcheck"]
26 | env:
27 | ARTIFACT_DAYS: 0
28 | PYPI_LOCAL_DIR: "pypi_pkgs/"
29 | steps:
30 | - uses: actions/checkout@v4
31 | - uses: actions/setup-python@v5
32 | with:
33 | python-version: "3.10"
34 |
35 | - name: Pull sphinx template
36 | run: make get-sphinx-theme
37 | - name: Install pandoc
38 | timeout-minutes: 5
39 | run: sudo apt-get install -y pandoc
40 | - name: Install package & dependencies
41 | timeout-minutes: 20
42 | run: pip install . -U -r requirements/docs.txt
43 |
44 | - name: Make ${{ matrix.target }}
45 | working-directory: docs/
46 | # allow failing link check and doctest if you run with dispatch
47 | continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }}
48 | run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
49 |
50 | - name: Keep artifact
51 | if: github.event_name == 'pull_request'
52 | run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV
53 | - name: Upload built docs
54 | if: ${{ matrix.target == 'html' }}
55 | uses: actions/upload-artifact@v4
56 | with:
57 | name: docs-html-${{ github.sha }}
58 | path: docs/build/html/
59 | retention-days: ${{ env.ARTIFACT_DAYS }}
60 |
61 | deploy-docs:
62 | needs: docs-make
63 | if: github.repository_owner == 'Lightning-AI' && github.event_name == 'push'
64 | runs-on: ubuntu-latest
65 | env:
66 | GCP_TARGET: "gs://lightning-docs-thunder"
67 | steps:
68 | - uses: actions/download-artifact@v4
69 | with:
70 | name: docs-html-${{ github.sha }}
71 | path: docs/build/html/
72 |
73 | - name: Authenticate to Google Cloud
74 | uses: google-github-actions/auth@v2
75 | with:
76 | credentials_json: ${{ secrets.GCS_SA_KEY }}
77 |
78 | - name: Setup gcloud
79 | uses: google-github-actions/setup-gcloud@v2
80 | with:
81 | project_id: ${{ secrets.GCS_PROJECT }}
82 |
83 | # Uploading docs to GCS, so they can be served on lightning.ai
84 | #- name: Upload docs/thunder/stable to GCS 🪣
85 | # if: startsWith(github.ref, 'refs/heads/release/')
86 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/stable
87 |
88 | # Uploading docs to GCS, so they can be served on lightning.ai
89 | - name: Upload docs/thunder/latest to GCS 🪣
90 | if: github.ref == 'refs/heads/main'
91 | run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest
92 |
93 | # Uploading docs to GCS, so they can be served on lightning.ai
94 | #- name: Upload docs/thunder/release to GCS 🪣
95 | # if: startsWith(github.ref, 'refs/tags/')
96 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ github.ref_name }}
97 |
98 | # Uploading docs as archive to GCS, so they can be as backup
99 | #- name: Upload docs as archive to GCS 🪣
100 | # if: startsWith(github.ref, 'refs/tags/')
101 | # working-directory: docs/build
102 | # run: |
103 | # zip ${{ github.ref_name }}.zip -r html/
104 | # gsutil cp ${{ github.ref_name }}.zip ${GCP_TARGET}
105 |
--------------------------------------------------------------------------------
/.github/workflows/label-conflicts.yml:
--------------------------------------------------------------------------------
1 | name: Label merge conflicts
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request_target:
7 | types: ["synchronize", "reopened", "opened"]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | triage-conflicts:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd
18 | with:
19 | CONFLICT_LABEL_NAME: "has conflicts"
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | MAX_RETRIES: 3
22 | WAIT_MS: 5000
23 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: "Pull Request Labeler"
2 | on: [pull_request_target]
3 |
4 | jobs:
5 | triage-prs:
6 | permissions:
7 | contents: read
8 | pull-requests: write
9 | runs-on: ubuntu-latest
10 | steps:
11 | # Uploads repository content to the runner
12 | - uses: actions/checkout@v4
13 | - uses: actions/labeler@v5
14 | with:
15 | # The path to the label configuration file.
16 | configuration-path: .github/labeling-config.yml
17 | # Whether removing labels when matching files are reverted or no longer changed by the PR
18 | sync-labels: true
19 |
--------------------------------------------------------------------------------
/.github/workflows/release-nightly.yml:
--------------------------------------------------------------------------------
1 | name: Nightly packages
2 |
3 | on:
4 | pull_request: # this shall test only the part of workflow before publishing
5 | branches: [main, "release/*"]
6 | types: [opened, reopened, ready_for_review, synchronize]
7 | paths:
8 | - ".github/workflows/release-nightly.yml"
9 | schedule:
10 | - cron: "0 0 * * 0" # on Sundays
11 | workflow_dispatch: {}
12 |
13 | defaults:
14 | run:
15 | shell: bash
16 |
17 | jobs:
18 | releasing-nightly:
19 | runs-on: ubuntu-22.04
20 | steps:
21 | - uses: actions/checkout@v4
22 | - uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.10"
25 |
26 | - name: Install dependencies
27 | run: python -m pip install --user --upgrade setuptools wheel packaging
28 | - name: Build package
29 | env:
30 | CONVERT_VERSION2NIGHTLY: "1"
31 | run: python setup.py sdist bdist_wheel
32 |
33 | # We do this, since failures on test.pypi aren't that bad
34 | - name: Publish to Test PyPI
35 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
36 | uses: pypa/gh-action-pypi-publish@v1.12.4
37 | with:
38 | user: __token__
39 | password: ${{ secrets.test_pypi_password }}
40 | repository_url: https://test.pypi.org/legacy/
41 |
42 | - name: Publish distribution 📦 to PyPI
43 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
44 | uses: pypa/gh-action-pypi-publish@v1.12.4
45 | with:
46 | user: __token__
47 | password: ${{ secrets.pypi_password }}
48 |
--------------------------------------------------------------------------------
/.github/workflows/release-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push:
6 | branches: [main]
7 | release:
8 | types: [published]
9 |
10 | # based on https://github.com/pypa/gh-action-pypi-publish
11 |
12 | jobs:
13 | releasing-pypi:
14 | runs-on: ubuntu-22.04
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install dependencies
22 | run: python -m pip install --user --upgrade setuptools wheel packaging
23 | - name: Build
24 | run: python setup.py sdist bdist_wheel
25 |
26 | # We do this, since failures on test.pypi aren't that bad
27 | - name: Publish to Test PyPI
28 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
29 | uses: pypa/gh-action-pypi-publish@v1.12.4
30 | with:
31 | user: __token__
32 | password: ${{ secrets.test_pypi_password }}
33 | repository_url: https://test.pypi.org/legacy/
34 |
35 | - name: Publish distribution 📦 to PyPI
36 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
37 | uses: pypa/gh-action-pypi-publish@v1.12.4
38 | with:
39 | user: __token__
40 | password: ${{ secrets.pypi_password }}
41 |
--------------------------------------------------------------------------------
/.github/workflows/stale-branches.yaml:
--------------------------------------------------------------------------------
1 | name: Delete abandoned branches
2 |
3 | on:
4 | pull_request:
5 | branches: ["main"]
6 | paths:
7 | - ".github/workflows/stale-branches.yaml"
8 | # Run daily at midnight
9 | schedule:
10 | - cron: "0 0 * * *"
11 | # Allow workflow to be manually run from the GitHub UI
12 | workflow_dispatch:
13 |
14 | concurrency:
15 | group: ${{ github.workflow }}-${{ github.ref }}
16 | cancel-in-progress: true
17 |
18 | jobs:
19 | cleanup-old-branches:
20 | runs-on: ubuntu-latest
21 | name: Satisfy my repo CDO
22 | env:
23 | DRY_RUN: no
24 | steps:
25 | - name: Set dry run
26 | if: ${{ github.event_name == 'pull_request' }}
27 | run: echo "DRY_RUN=yes" >> $GITHUB_ENV
28 | - name: Delete those pesky dead branches
29 | uses: phpdocker-io/github-actions-delete-abandoned-branches@v2
30 | id: delete_stuff
31 | with:
32 | github_token: ${{ github.token }}
33 | last_commit_age_days: 90
34 | ignore_branches: main
35 | # Disable dry run and actually get stuff deleted
36 | # For a PR, always perform dry run
37 | dry_run: ${{ env.DRY_RUN }}
38 |
39 | - name: Get output
40 | run: "echo 'Deleted branches: ${{ steps.delete_stuff.outputs.deleted_branches }}'"
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # Python venv
30 | bin/
31 | lib64
32 | pyvenv.cfg
33 | share/
34 | etc/
35 | include/
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 | docs/source/api/
68 | docs/source/*.md
69 | docs/source/notebooks/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # VSCode
78 | .vscode/
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
88 | __pypackages__/
89 |
90 | # Celery stuff
91 | celerybeat-schedule
92 | celerybeat.pid
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 |
121 | # Pyre type checker
122 | .pyre/
123 |
124 | # PyCharm
125 | .idea/
126 |
127 | # Lightning logs
128 | lightning_logs
129 | *.gz
130 | .DS_Store
131 | .*_submit.py
132 |
133 | # Editor temporary file
134 | *.swn
135 | *.swo
136 | *.swp
137 | *.swm
138 | *~
139 |
140 | # Build artifacts
141 | scripts/build
142 |
143 | # Profiler traces
144 | benchmarks/traces
145 |
146 | # benchmark results
147 | .results
148 |
149 | .ruff_cache/
150 |
151 | # come CI artifacts
152 | notebooks/all.txt
153 | notebooks/ci.txt
154 |
155 | quickstart_report.json
156 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3.10
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
7 | # submodules: true
8 |
9 | repos:
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v5.0.0
12 | hooks:
13 | - id: end-of-file-fixer
14 | - id: trailing-whitespace
15 | exclude: README.md
16 | - id: check-case-conflict
17 | - id: check-yaml
18 | - id: check-toml
19 | - id: check-json
20 | - id: check-added-large-files
21 | args: ["--maxkb=400", "--enforce-all"]
22 | exclude: notebooks
23 | - id: check-docstring-first
24 | - id: detect-private-key
25 |
26 | - repo: https://github.com/asottile/pyupgrade
27 | rev: v3.20.0
28 | hooks:
29 | - id: pyupgrade
30 | args: ["--py310-plus"]
31 | name: Upgrade code
32 | exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py"
33 |
34 | - repo: https://github.com/codespell-project/codespell
35 | rev: v2.4.1
36 | hooks:
37 | - id: codespell
38 | additional_dependencies: [tomli]
39 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing
40 |
41 | - repo: https://github.com/psf/black
42 | rev: 25.1.0
43 | hooks:
44 | - id: black
45 | name: Black code
46 | exclude: "examples"
47 |
48 | - repo: https://github.com/executablebooks/mdformat
49 | rev: 0.7.22
50 | hooks:
51 | - id: mdformat
52 | additional_dependencies:
53 | - mdformat-gfm
54 | - mdformat-black
55 | - mdformat_frontmatter
56 | exclude: "examples"
57 |
58 | - repo: https://github.com/sphinx-contrib/sphinx-lint
59 | rev: v1.0.0
60 | hooks:
61 | - id: sphinx-lint
62 |
63 | - repo: https://github.com/asottile/yesqa
64 | rev: v1.5.0
65 | hooks:
66 | - id: yesqa
67 |
68 | # - repo: https://github.com/charliermarsh/ruff-pre-commit
69 | # rev: v0.0.270
70 | # hooks:
71 | # - id: ruff
72 | # args: ["--fix"]
73 |
74 | - repo: https://github.com/pre-commit/mirrors-prettier
75 | rev: v3.1.0
76 | hooks:
77 | - id: prettier
78 | # https://prettier.io/docs/en/options.html#print-width
79 | files: \.(json|yml|yaml|toml)
80 | args: ["--print-width=120"]
81 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-exclude __pycache__ *.py[cod] *.orig
2 |
3 | # Include the README and CHANGELOG
4 | include *.md
5 | recursive-include pl_sandbox *.md
6 |
7 | # Include the license file
8 | include LICENSE
9 |
10 | # Exclude build configs
11 | exclude *.toml
12 | exclude *.svg
13 | exclude *.yml
14 | exclude *.yaml
15 |
16 | # exclude tests from package
17 | recursive-exclude thunder/tests *
18 | recursive-exclude site *
19 | exclude thunder/tests
20 |
21 | # Exclude the documentation files
22 | recursive-exclude docs *
23 | exclude docs
24 |
25 | # Include the Requirements
26 | recursive-include requirements *.txt
27 |
28 | # Exclude Makefile
29 | exclude Makefile
30 |
31 | prune .git
32 | prune .github
33 | prune notebook*
34 | prune temp*
35 | prune test*
36 | prune benchmark*
37 | prune examples*
38 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test clean docs
2 |
3 | # assume you have installed need packages
4 | export SPHINX_MOCK_REQUIREMENTS=0
5 |
6 | test: clean
7 | pip install -q -r requirements.txt -r requirements/test.txt
8 |
9 | # use this to run tests
10 | python -m coverage run --source thunder -m pytest thunder tests -v
11 | python -m coverage report
12 |
13 | get-sphinx-theme:
14 | pip install -q awscli
15 | mkdir -p dist/
16 | aws s3 sync --no-sign-request s3://sphinx-packages/ dist/
17 | pip install lai-sphinx-theme -f dist/
18 |
19 | docs: clean get-sphinx-theme
20 | pip install -e . --quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html
21 | cd docs ; python -m sphinx -b html -W --keep-going source build
22 |
23 | clean:
24 | # clean all temp runs
25 | rm -rf .mypy_cache
26 | rm -rf .pytest_cache
27 | rm -rf ./docs/build
28 | rm -rf ./docs/source/**/generated
29 | rm -rf ./docs/source/api
30 | rm -rf _ckpt_*
31 |
--------------------------------------------------------------------------------
/dockers/README.md:
--------------------------------------------------------------------------------
1 | # Docker images
2 |
3 | ## Build images from Dockerfiles
4 |
5 | You can build it on your own, note it takes lots of time, be prepared.
6 |
7 | ```bash
8 | # build with specific arguments
9 | docker image build -t lightning:ubuntu-cuda-py3.10-cuda12.1.1 -f dockers/ubuntu-cuda/Dockerfile --build-arg "CUDA_VERSION=12.1.1" .
10 | ```
11 |
12 | To run your docker use
13 |
14 | ```bash
15 | docker image list
16 | docker run --rm -it pytorch-lightning:ubuntu-cuda-py3.10-cuda11.7.0 bash
17 | ```
18 |
19 | ## Run docker image with GPUs
20 |
21 | To run docker image with access to your GPUs, you need to install
22 |
23 | ```bash
24 | # Add the package repositories
25 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
26 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
27 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
28 |
29 | sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
30 | sudo systemctl restart docker
31 | ```
32 |
33 | and later run the docker image with `--gpus=all`. For example,
34 |
35 | ```bash
36 | docker run --rm -it --gpus=all pytorchlightning/lightning:ubuntu-cuda-py3.10-cuda12.1.0
37 | ```
38 |
--------------------------------------------------------------------------------
/dockers/with-apex/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ARG BASE_IMAGE_TAG
16 |
17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG}
18 |
19 | ARG APEX_CHECKOUT="master"
20 |
21 | SHELL ["/bin/bash", "-c"]
22 |
23 | RUN \
24 | # building Apex from source
25 | pip install "pip>=23.1" packaging && \
26 | git clone https://github.com/NVIDIA/apex && \
27 | cd apex && \
28 | git checkout ${APEX_CHECKOUT} && \
29 | # https://github.com/NVIDIA/apex#linux
30 | pip install -v \
31 | --disable-pip-version-check \
32 | --no-cache-dir \
33 | --no-build-isolation \
34 | --config-settings "--build-option=--xentropy" \
35 | . && \
36 | cd .. && \
37 | rm -rf apex
38 |
39 | RUN \
40 | # Show what we have
41 | pip --version && \
42 | pip list && \
43 | python -c "import apex"
44 |
--------------------------------------------------------------------------------
/dockers/with-dev/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ARG BASE_IMAGE_TAG
16 |
17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG}
18 |
19 | SHELL ["/bin/bash", "-c"]
20 |
21 | COPY requirements/ requirements/
22 |
23 | RUN \
24 | ls -lh requirements/ && \
25 | CUDA_VERSION_MM=${CUDA_VERSION%.*} && \
26 | MAKEFLAGS="-j$(( $(nproc) / 4 ))" && \
27 | pip install -U -r requirements/test.txt && \
28 | rm -rf requirements/
29 |
30 | RUN \
31 | # Show what we have
32 | pip --version && \
33 | pip list
34 |
--------------------------------------------------------------------------------
/docs/.build_docs.sh:
--------------------------------------------------------------------------------
1 | make clean
2 | make html --debug --jobs $(nproc)
3 |
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx
10 | sphinx:
11 | fail_on_warning: true
12 | configuration: docs/conf.py
13 |
14 | build:
15 | os: "ubuntu-22.04"
16 | tools:
17 | python: "3.10"
18 | commands:
19 | - printenv
20 | - pwd ; pip install -q py-tree ; py-tree .
21 | - make docs
22 | - mkdir -p _readthedocs ; mv docs/build _readthedocs/html
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W
6 | SPHINXBUILD = python $(shell which sphinx-build)
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/copybutton.js:
--------------------------------------------------------------------------------
1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */
2 | $(document).ready(function() {
3 | /* Add a [>>>] button on the top-right corner of code samples to hide
4 | * the >>> and ... prompts and the output and thus make the code
5 | * copyable. */
6 | var div = $('.highlight-python .highlight,' +
7 | '.highlight-python3 .highlight,' +
8 | '.highlight-pycon .highlight,' +
9 | '.highlight-default .highlight');
10 | var pre = div.find('pre');
11 |
12 | // get the styles from the current theme
13 | pre.parent().parent().css('position', 'relative');
14 | var hide_text = 'Hide the prompts and output';
15 | var show_text = 'Show the prompts and output';
16 | var border_width = pre.css('border-top-width');
17 | var border_style = pre.css('border-top-style');
18 | var border_color = pre.css('border-top-color');
19 | var button_styles = {
20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0',
21 | 'border-color': border_color, 'border-style': border_style,
22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%',
23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em',
24 | 'border-radius': '0 3px 0 0'
25 | }
26 |
27 | // create and add the button to all the code blocks that contain >>>
28 | div.each(function(index) {
29 | var jthis = $(this);
30 | if (jthis.find('.gp').length > 0) {
31 | var button = $('>>>');
32 | button.css(button_styles)
33 | button.attr('title', hide_text);
34 | button.data('hidden', 'false');
35 | jthis.prepend(button);
36 | }
37 | // tracebacks (.gt) contain bare text elements that need to be
38 | // wrapped in a span to work with .nextUntil() (see later)
39 | jthis.find('pre:has(.gt)').contents().filter(function() {
40 | return ((this.nodeType == 3) && (this.data.trim().length > 0));
41 | }).wrap('');
42 | });
43 |
44 | // define the behavior of the button when it's clicked
45 | $('.copybutton').click(function(e){
46 | e.preventDefault();
47 | var button = $(this);
48 | if (button.data('hidden') === 'false') {
49 | // hide the code output
50 | button.parent().find('.go, .gp, .gt').hide();
51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden');
52 | button.css('text-decoration', 'line-through');
53 | button.attr('title', show_text);
54 | button.data('hidden', 'true');
55 | } else {
56 | // show the code output
57 | button.parent().find('.go, .gp, .gt').show();
58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible');
59 | button.css('text-decoration', 'none');
60 | button.attr('title', hide_text);
61 | button.data('hidden', 'false');
62 | }
63 | });
64 | });
65 |
--------------------------------------------------------------------------------
/docs/source/_static/images/LightningThunderDarkModewByline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/LightningThunderDarkModewByline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/LightningThunderLightModewByline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/LightningThunderLightModewByline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/how_it_works.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/how_it_works.png
--------------------------------------------------------------------------------
/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png
--------------------------------------------------------------------------------
/docs/source/_static/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/logo.png
--------------------------------------------------------------------------------
/docs/source/_static/images/normalized_training_throughput_zero2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/normalized_training_throughput_zero2.png
--------------------------------------------------------------------------------
/docs/source/_static/images/pretrain_perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/pretrain_perf.png
--------------------------------------------------------------------------------
/docs/source/_static/images/training_throughput_single.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/training_throughput_single.png
--------------------------------------------------------------------------------
/docs/source/_templates/theme_variables.jinja:
--------------------------------------------------------------------------------
1 | {%- set external_urls = {
2 | 'github': 'https://github.com/Lightning-AI/lightning-thunder',
3 | 'github_issues': 'https://github.com/Lightning-AI/lightning-thunder/issues',
4 | 'contributing': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/CONTRIBUTING.md',
5 | 'governance': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/governance.md',
6 | 'docs': 'https://lightning-thunder.rtfd.io/en/latest',
7 | 'twitter': 'https://twitter.com/LightningAI',
8 | 'discuss': 'https://pytorch-lightning.slack.com',
9 | 'tutorials': 'https://lightning-thunder.readthedocs.io/en/latest/#tutorials',
10 | 'previous_pytorch_versions': 'https://lightning-thunder.rtfd.io/en/latest/',
11 | 'home': 'https://lightning-thunder.rtfd.io/en/latest/',
12 | 'get_started': 'https://lightning-thunder.readthedocs.io/en/latest/introduction_guide.html',
13 | 'features': 'https://lightning-thunder.rtfd.io/en/latest/',
14 | 'blog': 'https://www.Lightning-AI.ai/blog',
15 | 'resources': 'https://lightning-thunder.readthedocs.io/en/latest/#community-examples',
16 | 'support': 'https://lightning-thunder.rtfd.io/en/latest/',
17 | }
18 | -%}
19 |
--------------------------------------------------------------------------------
/docs/source/advanced/extending.rst:
--------------------------------------------------------------------------------
1 | Extending Thunder
2 | #################
3 |
4 | ..
5 | TODO RC1: update using the extend API
6 |
7 | This section describes how to add an executor to Thunder for a PyTorch operation.
8 |
9 | First, define a Python function with the same signature as the targeted operation, and have it call your implementation. For example, the Apex executor for ``torch.nn.functional.cross_entropy`` might define its implementation like::
10 |
11 | import torch
12 | import xentropy_cuda
13 |
14 | def apex_xentropy(
15 | a: torch.Tensor, # a is an actual PyTorch tensor
16 | target,
17 | weight=None,
18 | size_average=None,
19 | ignore_index=-100,
20 | reduce=None,
21 | reduction="mean",
22 | label_smoothing=0.0,
23 | ):
24 | losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float)
25 |
26 | When this implementation is used it will be called with actual PyTorch tensors, and not with proxies.
27 |
28 | Next, define a “checker” function with the same signature as the targeted operation that returns True if your operation can execute the targeted operation and False otherwise. Checkers, unlike the implementations, are called with proxies, and not actual PyTorch tensors, because they're called at optimization time. The purpose of a checker function is to let executors target only specific inputs to an operation, and defer to another executor on other inputs.
29 |
30 | A checker function for the Apex executor might look like::
31 |
32 | from thunder.core.proxies import TensorProxy
33 |
34 | def apex_xentropy_checker(
35 | a: TensorProxy, # a is a proxy
36 | target,
37 | weight=None,
38 | size_average=None,
39 | ignore_index=-100,
40 | reduce=None,
41 | reduction="mean",
42 | label_smoothing=0.0,
43 | ):
44 | # Apex's xentropy only supports "sum", "mean" or "none" reductions
45 | if reduction not in ["sum", "mean", "none"]:
46 | return False
47 |
48 | return True
49 |
50 | Create a mapping from the name of the PyTorch operation to your replacement implementation's name, its checker, and its implementation::
51 |
52 | _op_to_xentropy = {
53 | "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy),
54 | }
55 |
56 | Then define a registration function that practitioners can call to access your executor::
57 |
58 | def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None:
59 | from thunder.executors import add_operator_executor
60 |
61 | return add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors)
62 |
63 | You can test your executor by registering it, compiling a function that calls the targeted operator, and then verifying that your operation is called (by inspecting the execution trace) and producing the correct output. A good example of this is the tests for the Apex executor.
64 |
--------------------------------------------------------------------------------
/docs/source/advanced/inside_thunder.rst:
--------------------------------------------------------------------------------
1 | Inside Thunder
2 | ##############
3 |
4 | This section elaborates on the design of some of Thunder's internals.
5 |
6 | Bytecode interpretation
7 | =======================
8 |
9 | Thunder's interpreter works by:
10 |
11 | 1. Disassembling the PyTorch module or function into CPython bytecode
12 | 2. Interpreting the bytecode using an extended Python interpreter
13 | 3. Generating a sequential trace of operations on tensors and numbers
14 |
15 | Representing Operations
16 | =======================
17 |
18 | Thunder supports a subset of PyTorch's operators (see ``thunder.torch.__init__.py`` for which operators are supported).
19 | Thunder has to define its own versions of PyTorch operators so it knows how to compile and trace them properly. This section details how Thunder represents operations.
20 |
21 | Let's start by looking at how ``torch.nn.functional.softmax`` appears in a trace of the nanoGPT (https://github.com/karpathy/nanoGPT) model::
22 |
23 | t63 = ltorch.softmax(t52, dim=-1) # t63: "cuda:0 f16[8, 12, 64, 64]"
24 | # t53 = prims.convert_element_type(t52, dtypes.float32) # t53: "cuda:0 f32[8, 12, 64, 64]"
25 | # t55 = ltorch.amax(t53, -1, keepdim=True) # t55: "cuda:0 f32[8, 12, 64, 1]"
26 | # t54 = prims.amax(t53, (3,)) # t54: "cuda:0 f32[8, 12, 64]"
27 | # t55 = prims.broadcast_in_dim(t54, [8, 12, 64, 1], [0, 1, 2]) # t55: "cuda:0 f32[8, 12, 64, 1]"
28 | # t57 = ltorch.sub(t53, t55) # t57: "cuda:0 f32[8, 12, 64, 64]"
29 | # t56 = prims.broadcast_in_dim(t55, [8, 12, 64, 64], (0, 1, 2, 3)) # t56: "cuda:0 f32[8, 12, 64, 64]"
30 | # t57 = prims.sub(t53, t56) # t57: "cuda:0 f32[8, 12, 64, 64]"
31 | # t58 = ltorch.exp(t57) # t58: "cuda:0 f32[8, 12, 64, 64]"
32 | # t58 = prims.exp(t57) # t58: "cuda:0 f32[8, 12, 64, 64]"
33 | # t60 = ltorch.sum(t58, -1, keepdim=True) # t60: "cuda:0 f32[8, 12, 64, 1]"
34 | # t59 = prims.sum(t58, (3,)) # t59: "cuda:0 f32[8, 12, 64]"
35 | # t60 = prims.broadcast_in_dim(t59, [8, 12, 64, 1], [0, 1, 2]) # t60: "cuda:0 f32[8, 12, 64, 1]"
36 | # t62 = ltorch.true_divide(t58, t60) # t62: "cuda:0 f32[8, 12, 64, 64]"
37 | # t61 = prims.broadcast_in_dim(t60, [8, 12, 64, 64], (0, 1, 2, 3)) # t61: "cuda:0 f32[8, 12, 64, 64]"
38 | # t62 = prims.div(t58, t61) # t62: "cuda:0 f32[8, 12, 64, 64]"
39 | # t63 = prims.convert_element_type(t62, dtypes.float16) # t63: "cuda:0 f16[8, 12, 64, 64]"
40 |
41 | Instead of the original operation we see a call to the corresponding ``thunder.torch`` operation, then a comment that describes the decomposition of this operation into other ``thunder.torch`` calls (identified by ``ltorch`` in the snippet above) and ``thunder.core.prims`` calls, with the calls to the *primitives* defined in ``thunder.core.prims`` being “terminal” — they decompose into nothing.
42 |
43 | Every Thunder operation can be decomposed into one or more primitives, and these decompositions are essential to trace, transform, and optimize them. For example, every primitive operation defines a “meta function” that maps proxy inputs to proxy outputs. When tracing, some inputs, like PyTorch tensors, are replaced with proxies. We know what the proxy output of operations like ``torch.softmax`` would be by decomposing it into primitives, essentially giving it an implicit meta function. If operations weren't defined in terms of primitives, then each operation would require its own meta function, and its own rule for transforms like autograd, and executors would have to reason about every Pytorch operator.
44 |
45 | Primitives are what let Thunder define operations without dramatically increasing its complexity, but the set of primitives must also be carefully chosen. Primitives serve two purposes:
46 |
47 | - They must be as simple and few in number as possible. A small set of simple operations is easier to analyze, transform, optimize, and execute than a large set of complicated operations.
48 | - They must be expressive enough to describe and facilitate the execution of deep learning and scientific computations.
49 |
50 | Since primitives are as simple as possible, they do not broadcast or type promote, and they typically do not have default arguments. For example, in the above trace the call to ``ltorch.sub`` decomposes into a broadcast and then the primitive for subtraction because broadcast is its own primitive.
51 |
52 | Thunder primitives are similar to the operations in JAX's ``jax.lax`` module, which is a wrapper around XLA's HLO operations.
53 |
54 | Because the prims are so simple and few, writing the decompositions of PyTorch operations directly in terms of primitives would be painstaking. Instead, Thunder has a “core language” of common deep learning operations, and Thunder's PyTorch decompositions typically call these core language operations or other PyTorch operations. Note that core language operations don't appear in traces for simplicity (they have no other use except producing decompositions and can't be executed directly).
55 |
--------------------------------------------------------------------------------
/docs/source/basic/overview.rst:
--------------------------------------------------------------------------------
1 | Thunder Overview
2 | ################
3 |
4 | This section introduces Thunder's core concepts and architecture. For more details, see :doc:`Inside thunder <../advanced/inside_thunder>`.
5 |
6 | Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*.
7 |
8 | This translation begins with::
9 |
10 | jitted_model = thunder.jit(my_module)
11 |
12 | or::
13 |
14 | jitted_fn = thunder.jit(my_function)
15 |
16 | When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a function that when called will jit compile a path through the original function given information about the inputs.
17 |
18 | When the jitted module or function is called::
19 |
20 | jitted_model(*args, **kwargs)
21 |
22 | or::
23 |
24 | jitted_fn(*args, **kwargs)
25 |
26 |
27 | As suggested above, Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the input, this is actually required to produce a trace. These traces are cached, so that if inputs of the same type, shape, etc are used again, the trace can be reused.
28 |
29 | Traces are generated by running the bytecode through a custom Python interpreter, which is itself implemented in Python. This interpreter has been extended to perform instructions in a different way compared to what standard CPython does. In particular, it constructs a trace of operations performed on tensors or numbers, and keeps track of the provenance of all objects in the program, whether they originated from inside the interpreter or outside.
30 |
31 | Much like other machine learning frameworks, Traces don't typically deal directly with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators. Instead, it records the operators along one path of the traceable function.
32 |
33 | If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven't yet put any time into optimizing it, and we think it consumes roughly 400x as much CPU time as CPython. However, the function only needs to be jitted once per equivalence class of inputs, and CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor's shape) and control flow conditions are not changed, we can rely on smart caching to immediately execute an optimized trace. The end result is a faster total execution time.
34 |
35 | Traces can be transformed (like for ``backward()``) and optimized (like by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the :doc:`thunder step by step ` section.
36 |
37 | To recap, the complete translation process is:
38 |
39 | - For PyTorch modules, a Thunder-optimized module is created from the original module.
40 | - For PyTorch functions, compilation produces a compiled function.
41 | - When the module or function is called, the trace is generated, swapping some inputs with “proxies”.
42 | - The trace is transformed and optimized to produce an execution trace.
43 | - The execution trace is converted into a Python function and called.
44 |
45 | As mentioned, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times.
46 |
--------------------------------------------------------------------------------
/docs/source/basic/sharp_edges.rst:
--------------------------------------------------------------------------------
1 | The sharp edges
2 | ###############
3 |
4 | This section describes features in the language that are not (yet) supported by Thunder, along with their workarounds. You might encounter these when compiling a module or function with Thunder. The examples should give you a good idea of how to change your program so that it works.
5 |
6 | Note that the fact that something is not supported today doesn't mean it won't be supported at some point in the future. Feel free to reach out to help us prioritize.
7 |
8 | Inplace operations
9 | ------------------
10 |
11 | Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon.
12 |
13 |
14 | Tensor subclasses
15 | -----------------
16 |
17 | Thunder currently supports Python data types and PyTorch tensors as inputs of functions and models.
18 |
19 | Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today.
20 |
21 |
22 | Tracing Python builtins, standard library operations and functions that call other languages
23 | --------------------------------------------------------------------------------------------
24 |
25 | Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed:
26 |
27 | 1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement.
28 | 2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>`
29 | 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will.
30 |
31 | ..
32 | Certain op-level behavior
33 | -------------------------
34 | 1. Ops which have not yet been added to Thunder. Please let us know if there’s missing operator support you would like to see and we will be happy to help.
35 | 2. Data dependent control flow (e.g. ``if x.any()``). Since Thunder generates traces of programs ahead of the actual execution, control flow depending on the values of tensors as opposed to their metadata cannot be handled by Thunder.
36 |
37 |
38 | Using Thunder-optimized Modules
39 | -------------------------------
40 |
41 | Compiling a module produces a Thunder-optimized module”. A Thunder-optimized module is less dynamic than the original module, which facilitates tracing and optimization. It has a reference to the original module, and it shares its parameters with it.
42 |
43 | While modifying the original model's parameters will reflect in the Thunder-optimized module, other changes to the original module will not. In particular:
44 |
45 | - Whether model is in ``train`` or ``eval`` mode is captured at compilation time and constant
46 | - The structure of the module is captured at compilation time, and changing the original module's structure will likely break the Thunder-optimized module
47 | - Non-parameter attributes of the module may or may not be captured at compile time and treated as constants
48 |
49 | Not all features of PyTorch modules are currently supported, either. Module hooks are not supported, and adding new module attributes in a module's ``forward()`` method is only partially supported.
50 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/examine.rst:
--------------------------------------------------------------------------------
1 | Using Examine
2 | #############
3 |
4 | We recommend using Thunder's ``examine()`` before compiling a function or a module.
5 |
6 | Thunder cannot run every PyTorch module, but you can quickly find out what is missing using ``examine()``.
7 |
8 | ``examine()`` not only determines if Thunder can compile the module, but provides a helpful report to use when filing an issue requesting support.
9 |
10 | You can run examine like this::
11 |
12 | from thunder.examine import examine
13 |
14 | model = MyModel(...)
15 | examine(model, *args, **kwargs)
16 |
17 | Where ``*args and **kwargs`` are valid inputs to the model. If examine determines that Thunder can run the module or function as expected, it will print::
18 |
19 | The function appears to be working as expected
20 |
21 | When ``examine`` encounters a module or function with one or more operators it doesn't support, it will specify the operators, like this::
22 |
23 | def foo(a):
24 | return torch.triu(a)
25 |
26 | import torch
27 | import thunder
28 | from thunder.examine import examine
29 |
30 | a = torch.full((2, 2), 1., device='cuda')
31 | examine(foo, a)
32 |
33 | Running the above will print::
34 |
35 | Found 1 distinct operations, of which 0 (0.0%) are supported
36 | Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
37 | _VariableFunctionsClass.triu of torch
38 |
39 | To recap, ``examine()`` lets you know if Thunder can run a module, and if it can't it will provide a report to file an issue asking for support.
40 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/hello_world.rst:
--------------------------------------------------------------------------------
1 | Hello World
2 | ###########
3 |
4 | Here is a simple example of how Thunder lets you compile and run PyTorch modules and functions::
5 |
6 | import torch
7 | import thunder
8 |
9 | def foo(a, b):
10 | return a + b
11 |
12 | jitted_foo = thunder.jit(foo)
13 |
14 | a = torch.full((2, 2), 1)
15 | b = torch.full((2, 2), 3)
16 |
17 | result = jitted_foo(a, b)
18 |
19 | print(result)
20 |
21 | # prints
22 | # tensor(
23 | # [[4, 4],
24 | # [4, 4]])
25 |
26 | The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of bigger PyTorch programs.
27 |
28 | Thunder currently understands a subset of PyTorch operations, and a subset of Python. It's adding support quickly, however. Reach out on the Thunder repo and open an issue there to easily get help if an operator is currently not supported.
29 |
--------------------------------------------------------------------------------
/docs/source/fundamentals/installation.rst:
--------------------------------------------------------------------------------
1 | Install Lightning Thunder
2 | #########################
3 |
4 | Minimal dependencies
5 | ====================
6 |
7 | Follow these instructions to install PyTorch, nvFuser, and finally Thunder.
8 |
9 | Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.5.x)::
10 |
11 | pip install --pre nvfuser-cu121-torch25
12 |
13 | cu121 can be replaced with cu118 depending on your CUDA version. NVFuser builds typically support the latest point release of PyTorch stable versions.
14 | For torch 2.5, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation
15 |
16 | You're all set with minimal dependencies, so you can follow `Install Thunder`_.
17 |
18 | Dependencies so far don't include additional optimized kernels for execution like OpenAI Triton, cuDNN Fusion, or Apex.
19 | These are described in the following section, `Optional dependencies`_.
20 |
21 | Optional dependencies
22 | =====================
23 |
24 | Install Apex
25 | ------------
26 |
27 | Thunder can use NVIDIA's Apex to accelerate some PyTorch operations. To install the Apex executor, first clone the Apex git repository and then execute the following command in the project's root directory::
28 |
29 | git clone https://github.com/NVIDIA/apex.git
30 | cd apex
31 |
32 | pip install -v --no-cache-dir --no-build-isolation --config-settings "--build-option=--xentropy" ./
33 |
34 | Install cuDNN
35 | -------------
36 |
37 | Thunder can use NVIDIA's cuDNN Python frontend bindings to accelerate some PyTorch operations. cuDNN backend is a pure-lib python package which needs to downloaded separately.
38 |
39 | pip install nvidia-cudnn-cu12
40 | pip install nvidia-cudnn-frontend
41 |
42 | You're all set, now follow `Install Thunder`_.
43 |
44 | Install OpenAI Triton
45 | ---------------------
46 |
47 | Thunder can easily integrate OpenAI Triton kernels. You can install Triton using::
48 |
49 | pip install triton
50 |
51 |
52 | Install Thunder
53 | ===============
54 |
55 | You can now install Thunder::
56 |
57 | pip install git+https://github.com/Lightning-AI/lightning-thunder.git
58 |
59 | Alternatively you can clone the Thunder repository and install locally::
60 |
61 | git clone https://github.com/Lightning-AI/lightning-thunder.git
62 | cd lightning-thunder
63 |
64 | pip install .
65 |
--------------------------------------------------------------------------------
/docs/source/intermediate/additional_executors.rst:
--------------------------------------------------------------------------------
1 | Additional executors
2 | ####################
3 |
4 | nvFuser and Pytorch are not the only executors available in Thunder today. Additional executors can be added to Thunder prior to compilation through a registration mechanism, which makes it easy to have specialized executors perform certain operations more efficiently.
5 |
6 | This section contains a list of all executors supported by PyTorch beyond nvFuser and PyTorch.
7 |
8 | Triton CrossEntropy Executor
9 | ============================
10 |
11 | The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton). It can be used like in the following example::
12 |
13 | import torch
14 | import thunder
15 | from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex
16 |
17 | def xentropy(logits, labels, weight, reduction, ignore_index):
18 | return thunder.torch.cross_entropy(
19 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index
20 | )
21 |
22 | jitted_xentropy = thunder.jit(
23 | xentropy,
24 | executors=[triton_cross_entropy_ex,]
25 | )
26 |
27 | device = 'cuda'
28 | dtype = torch.float32
29 |
30 | logits = torch.randn([2048, 50257], device=device, dtype=dtype)
31 | labels = torch.randint(0, 50257, [2048], device=device)
32 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False)
33 | reduction = "sum"
34 | ignore_index = labels[5].item()
35 |
36 | jitted_xentropy(logits, labels, weight, reduction, ignore_index)
37 | traces = thunder.last_traces(jitted_xentropy)
38 | print(traces[-1])
39 |
40 | This prints::
41 |
42 | # Constructed by Delete Last Used (took 0 milliseconds)
43 | import torch
44 | from thunder.executors.torchex import no_autocast
45 |
46 | @torch.no_grad()
47 | @no_autocast
48 | def computation(logits, labels, weight):
49 | # logits: "cuda:0 f32[2048, 50257]"
50 | # labels: "cuda:0 i64[2048]"
51 | # weight: "cuda:0 f32[50257]"
52 | t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]"
53 | del logits, labels, weight
54 | return t23
55 |
56 | As shown in the above trace, ``triton_crossentropy()`` is the one running the operation.
57 |
58 | Apex CrossEntropy Executor
59 | ==========================
60 |
61 | The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this::
62 |
63 | import torch
64 | import thunder
65 | from thunder.executors.apexex import apex_ex
66 |
67 | def xentropy(logits, labels):
68 | return thunder.torch.cross_entropy(
69 | logits, labels, reduction='mean', ignore_index=-1
70 | )
71 |
72 | jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,])
73 |
74 | device = 'cuda'
75 | dtype = torch.float32
76 |
77 | logits = torch.randn([2048, 50257], device=device, dtype=dtype)
78 | labels = torch.randint(0, 50257, [2048], device=device)
79 |
80 | jitted_xentropy(logits, labels)
81 | traces = thunder.last_traces(jitted_xentropy)
82 | print(traces[-1])
83 |
84 | This prints::
85 |
86 | # Constructed by Delete Last Used (took 0 milliseconds)
87 | import torch
88 | from thunder.executors.torchex import no_autocast
89 |
90 | @torch.no_grad()
91 | @no_autocast
92 | def computation(logits, labels):
93 | # logits: "cuda:0 f32[2048, 50257]"
94 | # labels: "cuda:0 i64[2048]"
95 | (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0)
96 | del logits, labels
97 | return t18
98 |
99 | showing that Apex is running the operation.
100 |
101 | cuDNN SDPA Executor
102 | ===================
103 |
104 | TODO RC1
105 |
106 | TransformerEngine Executor
107 | ==========================
108 |
109 | TODO RC1
110 |
--------------------------------------------------------------------------------
/docs/source/intermediate/ddp.rst:
--------------------------------------------------------------------------------
1 | Distributed Data Parallel (DDP)
2 | ###############################
3 |
4 | Thunder has its own Distributed Data Parallel (DDP) transform that we recommend using, although compiled modules also work with PyTorch's DDP transform.
5 |
6 | You can wrap a model in Thunder's ddp like this::
7 |
8 | from thunder.distributed import ddp
9 |
10 | model = MyModel()
11 | ddp_model = ddp(model)
12 | cmodel = thunder.jit(ddp_model)
13 |
14 | Specifying which rank to broadcast from is optional. ``ddp()`` will broadcast from the lowest rank in that group if ``broadcast_from`` is not specified.
15 |
16 | Thunder's ddp is compatible with PyTorch distributed runners like ``torchrun`` (https://pytorch.org/docs/stable/elastic/run.html).
17 |
18 | When using PyTorch's DDP, call DDP on the jitted module::
19 |
20 | from torch.nn.parallel import DistributedDataParallel as DDP
21 |
22 | model = MyModel()
23 | jitted_model = thunder.jit(model)
24 | ddp_model = DDP(jitted_model)
25 |
26 | The ability of Thunder to express distributed algorithms like DDP as a simple transform on the trace is one of Thunder's strengths and is being leveraged to quickly implement more elaborate distributed strategies, like Fully Sharded Data Parallel (FSDP).
27 |
--------------------------------------------------------------------------------
/docs/source/intermediate/whats_next.rst:
--------------------------------------------------------------------------------
1 | What's Next
2 | ###########
3 |
4 | Thunder is developing rapidly, and this section mentions some of what's happening. Please reach out (see Get Involved) if you're interested in one of these topics.
5 |
6 | Compiling the Training Loop
7 | ===========================
8 |
9 | Thunder currently supports compiling PyTorch modules - forward computation, loss calculation, backward computation -, but we plan to support compiling the entire training loop - forward computation, loss calculation, backward computation, and the optimizer step - for maximum performance.
10 |
11 | Dynamic Caching
12 | ===============
13 |
14 | Thunder currently supports either no caching or static caching, and static caching requires recompiling whenever a module is called with inputs with metadata different than past inputs. This can be overly strict. For example, adding two tensors with shape ``(5, 5)`` is essentially the same as adding two tensors with shape ``(10, 10)``. Dynamic caching will determine if the new metadata would result in a new trace or not, significantly reducing compilation time when training some models.
15 |
16 | Memory Layouts and Strides
17 | ==========================
18 |
19 | Thunder does not currently model any stride information on tensor proxies. In the future we will likely model some stride information, like memory layout (e.g. channels-last), to support integration with PyTorch programs that use memory layout, and to let executors use memory layout to inform kernel selection.
20 |
21 | Functional transforms: vmap and AMP
22 | ===================================
23 |
24 | Thunder already has early implementations of JAX's vmap transform and PyTorch's Automatic Mixed Precision (AMP) autocasting, and we're extending our support for these transforms so practitioners can easily apply a variety of composable transforms to PyTorch modules.
25 |
--------------------------------------------------------------------------------
/docs/source/reference/clang/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.clang
2 |
3 | thunder.clang
4 | =============
5 |
6 | Thunder Core Language
7 |
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
12 | maybe_convert_to_dtype
13 | device_put
14 | arange
15 | convolution
16 | full
17 | full_like
18 | uniform
19 | uniform_like
20 | diagonal
21 | expand
22 | flatten
23 | movedim
24 | reshape
25 | slice_in_dim
26 | squeeze
27 | transpose
28 | stride_order
29 | take
30 | index_add
31 | take_along_axis
32 | scatter_add
33 | unsqueeze
34 | cat
35 | stack
36 | compute_broadcast_shape
37 | matrix_transpose
38 | maybe_broadcast
39 |
40 | Unary
41 | ~~~~~
42 |
43 | .. autosummary::
44 | :toctree: generated/
45 |
46 | abs
47 | acos
48 | acosh
49 | asin
50 | asinh
51 | atan
52 | atanh
53 | bitwise_not
54 | ceil
55 | cos
56 | cosh
57 | erf
58 | erfc
59 | erfcinv
60 | erfinv
61 | exp
62 | exp2
63 | expm1
64 | floor
65 | isfinite
66 | lgamma
67 | log
68 | log10
69 | log1p
70 | log2
71 | ndtri
72 | neg
73 | reciprocal
74 | round
75 | rsqrt
76 | sigmoid
77 | sign
78 | signbit
79 | silu
80 | sin
81 | sinh
82 | sqrt
83 | tan
84 | tanh
85 | trunc
86 |
87 | Binary
88 | ~~~~~~
89 |
90 | .. autosummary::
91 | :toctree: generated/
92 |
93 | add
94 | atan2
95 | bitwise_and
96 | bitwise_or
97 | bitwise_xor
98 | copysign
99 | eq
100 | floor_divide
101 | fmod
102 | mod
103 | ge
104 | gt
105 | logical_and
106 | le
107 | lt
108 | mul
109 | ne
110 | nextafter
111 | pow
112 | remainder
113 | sub
114 | true_divide
115 |
116 | Conditional
117 | ~~~~~~~~~~~
118 |
119 | .. autosummary::
120 | :toctree: generated/
121 |
122 | where
123 |
--------------------------------------------------------------------------------
/docs/source/reference/common/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.common
2 |
3 | thunder.common
4 | ==============
5 |
6 | Common functions and classes for Thunder.
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | CACHE_OPTIONS
12 | CompileData
13 | CompileStats
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/baseutils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.baseutils
2 | :noindex:
3 |
4 | Baseutils
5 | ---------
6 |
7 | .. currentmodule:: thunder.core.baseutils
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
12 | check
13 | check_type
14 | ProxyInterface
15 | NumberProxyInterface
16 | TensorProxyInterface
17 | SymbolInterface
18 | BoundSymbolInterface
19 |
--------------------------------------------------------------------------------
/docs/source/reference/core/codeutils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.codeutils
2 |
3 | Codeutils
4 | ---------
5 |
6 | .. currentmodule:: thunder.core.codeutils
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | prettyprint
12 | SigInfo
13 | get_siginfo
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/devices.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.devices
2 |
3 | Devices
4 | -------
5 |
6 | .. currentmodule:: thunder.core.devices
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | DeviceType
12 | devicetype_string
13 | Device
14 | device_from_string
15 | to_device
16 |
--------------------------------------------------------------------------------
/docs/source/reference/core/dtypes.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.dtypes
2 |
3 | Dtypes
4 | ------
5 |
6 | ``dtype``\s such as ``float32`` are exposed to :mod:`thunder` like ``thunder.float32``.
7 |
8 | .. currentmodule:: thunder.core.dtypes
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 | :nosignatures:
13 |
14 | dtype
15 | to_dtype
16 | float32
17 | float16
18 | bfloat16
19 | float64
20 | int8
21 | int16
22 | int32
23 | int64
24 | uint8
25 | bool8
26 | complex32
27 | complex64
28 | complex128
29 |
--------------------------------------------------------------------------------
/docs/source/reference/core/functionalization.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.functionalization
2 |
3 | In-place Functionalization
4 | --------------------------
5 |
6 | .. currentmodule:: thunder.core.functionalization
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | functionalize_inplace_ops
12 |
--------------------------------------------------------------------------------
/docs/source/reference/core/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core
2 |
3 | thunder.core
4 | ============
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | baseutils
10 | codeutils
11 | devices
12 | dtypes
13 | langctxs
14 | prims
15 | proxies
16 | pytree
17 | rematerialization
18 | symbol
19 | trace
20 | transforms
21 | functionalization
22 |
--------------------------------------------------------------------------------
/docs/source/reference/core/langctxs.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.langctxs
2 |
3 | langctxs
4 | --------
5 |
6 | .. currentmodule:: thunder.core.langctxs
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | set_langctx
12 | get_langctx
13 | reset_langctx
14 | langctx
15 |
--------------------------------------------------------------------------------
/docs/source/reference/core/prims.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.prims
2 |
3 | .. todo Add other prims after resolving "Cannot resolve forward reference in type annotations of "thunder.core.prims.all_reduce": name 'Callable' is not defined"
4 |
5 | Prims
6 | -----
7 |
8 | .. currentmodule:: thunder.core.prims
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | PrimIDs
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/proxies.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.proxies
2 |
3 | Proxies
4 | -------
5 |
6 | .. currentmodule:: thunder.core.proxies
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Variable
12 | variableify
13 | unvariableify
14 | CollectionProxy
15 | pyval
16 | pytype
17 | ComplexProxy
18 | IntegerProxy
19 | FloatProxy
20 | FutureTensorProxy
21 | TensorProxy
22 | is_proxyable
23 | proxy
24 |
--------------------------------------------------------------------------------
/docs/source/reference/core/pytree.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.pytree
2 |
3 | PyTree
4 | ------
5 |
6 | .. list-table:: pytree API map
7 | :widths: 50 50
8 | :header-rows: 1
9 |
10 | * - thunder.core.pytree
11 | - pytorch pytree
12 | * - :func:`tree_map`
13 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map
14 | * - :func:`tree_flatten`
15 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_flatten
16 | * - :func:`tree_unflatten`
17 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_unflatten
18 |
--------------------------------------------------------------------------------
/docs/source/reference/core/rematerialization.rst:
--------------------------------------------------------------------------------
1 | .. module:: thundre.core.rematerialization
2 |
3 | Rematerialization
4 | -----------------
5 |
6 | .. currentmodule:: thunder.core.rematerialization
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | rematerialize
12 |
--------------------------------------------------------------------------------
/docs/source/reference/core/symbol.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.symbol
2 |
3 | Symbol
4 | ------
5 |
6 | .. currentmodule:: thunder.core.symbol
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Symbol
12 | BoundSymbol
13 | BoundSymbolRHS
14 |
--------------------------------------------------------------------------------
/docs/source/reference/core/trace.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.trace
2 |
3 | Trace
4 | -----
5 |
6 | .. currentmodule:: thunder.core.trace
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | TraceCtx
12 | from_trace
13 | set_tracectx
14 | get_tracectx
15 | reset_tracectx
16 |
--------------------------------------------------------------------------------
/docs/source/reference/core/transforms.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.core.transforms
2 |
3 | Transforms
4 | ----------
5 |
6 | .. currentmodule:: thunder.core.transforms
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Node
12 | bsym_list_to_dag
13 | toposort_bsym_dag
14 | insert_inplace
15 | replace_inplace
16 | VISIT_TYPE
17 | visitor_transform
18 |
--------------------------------------------------------------------------------
/docs/source/reference/distributed/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.distributed
2 |
3 | thunder.distributed
4 | ===================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | ddp
10 | fsdp
11 | FSDPType
12 | FSDPBucketingStrategy
13 | set_skip_data_parallel_grad_sync
14 | reset_skip_data_parallel_grad_sync
15 | get_skip_data_parallel_grad_sync
16 | skip_data_parallel_grad_sync
17 | column_parallel
18 | row_parallel
19 |
--------------------------------------------------------------------------------
/docs/source/reference/dynamo/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.dynamo
2 |
3 | thunder.dynamo
4 | ==============
5 |
6 | .. autosummary::
7 | :toctree:
8 |
9 | ThunderCompiler
10 |
--------------------------------------------------------------------------------
/docs/source/reference/examine/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.examine
2 |
3 | thunder.examine
4 | ===============
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | examine
10 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/apexex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.apexex
2 |
3 | APEX Executor
4 | -------------
5 |
6 | Executor of `NVIDIA/apex `_.
7 |
8 | .. currentmodule:: thunder.executors.apexex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | apex_entropy_available
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/cudnnex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.cudnnex
2 |
3 | cuDNN Executor
4 | --------------
5 |
6 | Executor of `NVIDIA/cudnn-frontend `_.
7 |
8 | .. currentmodule:: thunder.executors.cudnnex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | cudnn_ex
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors
2 |
3 | thunder.executors
4 | =================
5 |
6 | .. currentmodule:: thunder.executors
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | get_nvfuser_executor
12 | get_torch_executor
13 | nvfuser_available
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 |
18 | apexex
19 | cudnnex
20 | nvfuserex
21 | passes
22 | pythonex
23 | torch_compile
24 | torchex
25 | triton_crossentropy
26 | utils
27 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/nvfuserex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.nvfuserex
2 |
3 | nvFuser Executor
4 | ----------------
5 |
6 | An executor dispatches operators to `NVIDIA/Fuser `_.
7 |
8 | .. currentmodule:: thunder.executors.nvfuserex
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | .. fixme(crcrpar) add methods and classes
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/passes.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.passes
2 |
3 | Optimization Passes
4 | -------------------
5 |
6 | .. currentmodule:: thunder.executors.passes
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | transform_for_execution
12 | update_fusion_call_ctx
13 | del_last_used
14 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/pythonex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.pythonex
2 |
3 |
4 | Python Executor
5 | ---------------
6 |
7 | .. currentmodule:: thunder.executors.pythonex
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/torch_compile.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.torch_compile
2 |
3 | Torch Compile Executor
4 | ----------------------
5 |
6 | .. currentmodule:: thunder.executors.torch_compile
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | make_compiled
12 | TorchCompileExecutor
13 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/torchex.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.torchex
2 |
3 |
4 | PyTorch Executor
5 | ----------------
6 |
7 | .. currentmodule:: thunder.executors.torch_autograd
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/triton_crossentropy.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.triton_crossentropy
2 |
3 | Triton Executor
4 | ---------------
5 |
6 | Executor of `openai/triton `_.
7 |
8 | .. currentmodule:: thunder.executors.triton_crossentropy
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
--------------------------------------------------------------------------------
/docs/source/reference/executors/utils.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.executors.utils
2 |
3 | Utils
4 | -----
5 |
6 | .. currentmodule:: thunder.executors.utils
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 |
11 | Region
12 |
--------------------------------------------------------------------------------
/docs/source/reference/extend/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.extend
2 |
3 | thunder.extend
4 | ==============
5 |
6 | .. autosummary::
7 | :toctree:
8 |
9 | register_executor
10 | deregister_executor
11 | get_all_executors
12 | get_default_executors
13 | get_always_executors
14 | get_executor
15 | set_default_executors
16 | set_always_executors
17 | add_default_executor
18 | add_always_executor
19 | remove_default_executor
20 | remove_always_executor
21 |
22 | Executor
23 |
--------------------------------------------------------------------------------
/docs/source/reference/recipes/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.recipes
2 |
3 | thunder.recipes
4 | ==================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | BaseRecipe
10 | HFTransformers
11 |
--------------------------------------------------------------------------------
/docs/source/reference/thunder.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder
2 |
3 | thunder
4 | =======
5 |
6 |
7 | Compiling functions and modules
8 | -------------------------------
9 |
10 |
11 | .. autosummary::
12 | :toctree: generated/
13 |
14 | jit
15 |
16 |
17 | Querying information on compiled functions and modules
18 | ------------------------------------------------------
19 |
20 |
21 | .. autosummary::
22 | :toctree: generated/
23 |
24 | DebugOptions
25 | compile_data
26 | compile_stats
27 | last_traces
28 | last_backward_traces
29 | last_prologue_traces
30 | cache_option
31 | cache_hits
32 | cache_misses
33 | list_transforms
34 | last_interpreted_instructions
35 | last_interpreter_log
36 | last_compile_options
37 | ..
38 | compile
39 | grad
40 |
41 | JITed Model wrapper
42 | -------------------
43 |
44 | .. autoclass:: ThunderModule
45 | :members: no_sync
46 | :exclude-members: forward
47 |
--------------------------------------------------------------------------------
/docs/source/reference/torch/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.torch
2 |
3 | thunder.torch
4 | -------------
5 |
6 | A PyTorch dialect in thunder.
7 |
8 | .. currentmodule:: thunder.torch
9 |
10 | .. autosummary::
11 | :toctree: generated/
12 |
13 | torchsymbol
14 | size
15 | to_float
16 |
17 | Under Construction
18 |
19 | .. fixme "Cannot resolve forward reference in type annotations of "thunder.torch.abs": name 'Callable' is not defined"
20 |
21 | Operators
22 | ~~~~~~~~~
23 |
24 | Unary
25 | ~~~~~
26 |
27 | Binary
28 | ~~~~~~
29 |
30 | Conditional
31 | ~~~~~~~~~~~
32 |
33 | Tensor Creation
34 | ~~~~~~~~~~~~~~~
35 |
36 | Shape Operation
37 | ~~~~~~~~~~~~~~~
38 |
--------------------------------------------------------------------------------
/docs/source/reference/transforms/index.rst:
--------------------------------------------------------------------------------
1 | .. module:: thunder.transforms
2 |
3 | thunder.transforms
4 | ==================
5 |
6 | .. autosummary::
7 | :toctree: generated/
8 |
9 | MaterializationTransform
10 | ConstantFolding
11 |
--------------------------------------------------------------------------------
/examples/quickstart/hf_bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | # model_name = "bert-large-uncased"
11 | model_name = "bert-base-uncased"
12 |
13 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
14 |
15 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
16 |
17 | with torch.device(device):
18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
19 | model.requires_grad_(False)
20 | model.eval()
21 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
22 | model.to(device)
23 |
24 | inp = tokenizer(["Hello world!"], return_tensors="pt")
25 |
26 | print(f"Eager: {benchmark(model, **inp):.2f}ms")
27 |
28 | thunder_model = thunder.compile(
29 | model,
30 | recipe="hf-transformers",
31 | )
32 |
33 | print(f"Thunder: {benchmark(thunder_model, **inp):.2f}ms")
34 |
35 | if torch.cuda.is_available():
36 | thunder_model = thunder.compile(model, plugins="reduce-overhead")
37 |
38 | print(f"Thunder with 'reduce-overhead': {benchmark(thunder_model, **inp):.2f}ms")
39 |
40 |
41 | if __name__ == "__main__":
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 |
44 | main()
45 |
--------------------------------------------------------------------------------
/examples/quickstart/hf_llm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark_n
7 |
8 |
9 | def main():
10 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
11 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
12 | # model_name = "meta-llama/Llama-3.1-8B"
13 | model_name = "meta-llama/Llama-3.2-1B"
14 | # model_name = "Qwen/Qwen2.5-7B-Instruct"
15 | # model_name = "microsoft/phi-4"
16 |
17 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
18 |
19 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
20 |
21 | with torch.device(device):
22 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
23 | model.requires_grad_(False)
24 | model.eval()
25 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
26 | model.to(device)
27 |
28 | inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt")
29 |
30 | def generate(model, inp, cache=None):
31 | out = model.generate(**inp, do_sample=False, cache_implementation=cache, max_new_tokens=100)
32 | print(tokenizer.decode(out[0].tolist()))
33 |
34 | print("\nGenerating with PyTorch eager:")
35 | eager_time = benchmark_n(2, generate, model, inp)
36 |
37 | thunder_model = thunder.compile(
38 | model,
39 | recipe="hf-transformers",
40 | )
41 |
42 | print("\nGenerating with Thunder:")
43 | thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static")
44 |
45 | print(f"\nEager: {eager_time:.2f}ms")
46 | print(f"Thunder: {thunder_time:.2f}ms")
47 |
48 |
49 | if __name__ == "__main__":
50 | torch.backends.cuda.matmul.allow_tf32 = True
51 |
52 | main()
53 |
--------------------------------------------------------------------------------
/examples/quickstart/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import thunder
5 |
6 |
7 | def main():
8 | model = nn.Sequential(
9 | nn.Linear(2048, 4096),
10 | nn.ReLU(),
11 | nn.Linear(4096, 64)
12 | )
13 |
14 | thunder_model = thunder.compile(model)
15 | x = torch.randn(64, 2048)
16 | y = thunder_model(x)
17 |
18 | print(thunder_model)
19 |
20 | print(thunder.last_traces(thunder_model)[-1])
21 |
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
/examples/quickstart/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | accelerate
3 | nvfuser-cu128-torch27
4 |
--------------------------------------------------------------------------------
/examples/quickstart/vit_hf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import ViTForImageClassification
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
11 |
12 | with torch.device(device):
13 | model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float32)
14 | model.requires_grad_(False)
15 | model.eval()
16 | # apparently, Transformers 4.51.3 does not instantiate models on the default device
17 | model.to(device)
18 |
19 | inp = torch.randn(128, 3, 224, 224)
20 |
21 | out = model(inp)
22 |
23 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None)
24 |
25 | thunder_out = thunder_model(inp)
26 |
27 | torch.testing.assert_close(out, thunder_out, atol=1e-2, rtol=1e-2)
28 |
29 | print(f"Eager: {benchmark(model, inp):.2f}ms")
30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms")
31 |
32 |
33 | if __name__ == "__main__":
34 | torch.set_float32_matmul_precision('high')
35 |
36 | main()
37 |
--------------------------------------------------------------------------------
/examples/quickstart/vit_tv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models as models
3 |
4 | import thunder
5 |
6 | from thunder.dev_utils.benchmark import benchmark
7 |
8 |
9 | def main():
10 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
11 |
12 | with torch.device(device):
13 | model = models.vit_b_16()
14 | model.requires_grad_(False)
15 | model.eval()
16 |
17 | inp = torch.randn(128, 3, 224, 224)
18 |
19 | out = model(inp)
20 |
21 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None)
22 |
23 | thunder_out = thunder_model(inp)
24 |
25 | # print(thunder.last_traces(thunder_model)[-1])
26 |
27 | torch.testing.assert_close(out, thunder_out)
28 |
29 | print(f"Eager: {benchmark(model, inp):.2f}ms")
30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms")
31 |
32 |
33 | if __name__ == "__main__":
34 | torch.set_float32_matmul_precision('high')
35 |
36 | main()
37 |
--------------------------------------------------------------------------------
/notebooks/.ignore.ci:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/notebooks/.ignore.ci
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/base.txt
2 |
--------------------------------------------------------------------------------
/requirements/base.txt:
--------------------------------------------------------------------------------
1 | torch >=2.3.0
2 | looseversion ==1.3.0
3 | lightning-utilities >=0.7.0
4 | numpy
5 | networkx >= 3.3
6 | optree >=0.12.1
7 | opt_einsum >= 3.3.0
8 | # todo: teporarl pin for `NameError: name '_C' is not defined`
9 | mpmath <1.4.0
10 | # Support for 3.12
11 | dill >=0.3.8
12 |
--------------------------------------------------------------------------------
/requirements/devel.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | -r test.txt
3 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | sphinx ==5.3.0
2 | myst-parser ==1.0.0
3 | nbsphinx ~=0.9.7
4 | ipython[all] ~=8.37.0
5 | pandoc ==2.4
6 | docutils >=0.16
7 | sphinxcontrib-fulltoc @ git+https://github.com/t-vi/sphinx-contrib-fulltoc@master
8 | sphinxcontrib-mockautodoc
9 |
10 | sphinx-autodoc-typehints ==1.23.0
11 | sphinx-paramlinks ==0.6.0
12 | sphinx-togglebutton ==0.3.2
13 | sphinx-copybutton ==0.5.2
14 | snowballstemmer < 4
15 |
16 | # installed from S3 location and fetched in advance
17 | lai-sphinx-theme
18 | # alternative /back-up (old) theme
19 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip
20 |
--------------------------------------------------------------------------------
/requirements/notebooks.txt:
--------------------------------------------------------------------------------
1 | ipython[all] ~=8.37.0
2 | numpy
3 | liger-kernel == 0.4.0
4 | cuda-python
5 | litgpt == 0.5.1
6 |
--------------------------------------------------------------------------------
/requirements/test.txt:
--------------------------------------------------------------------------------
1 | coverage ~=7.8.2
2 | pytest ==8.3.5
3 | pytest-benchmark ==5.1.0
4 | pytest-timeout ==2.4.0
5 | pytest-cov ==6.1.1
6 | pytest-xdist ==3.7.0
7 | pytest-random-order ==1.1.1
8 | pytest-timestamper ==0.0.10
9 | graphviz ==0.20.3
10 | fdm ==0.5.0
11 | expecttest ==0.3.0 # for test_ddp.py
12 | hypothesis ~=6.133.0 # for test_ddp.py
13 | numpy
14 | einops # for test_einops.py
15 | litgpt==0.4.11 # for the model definition in tests and benchmarks
16 | absl-py # thunder/benchmarks/test_benchmark_litgpt.py
17 | pandas # thunder/benchmarks/test_benchmark_litgpt.py
18 | xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
19 | jsonargparse # thunder/benchmarks/benchmark_litgpt.py
20 | bitsandbytes==0.42.0 # fixed version!
21 | transformers==4.52.4 # for test_networks.py
22 |
23 | # Installs JAX on Linux and MacOS
24 | jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation
25 | jax; sys_platform == 'linux' or sys_platform == 'darwin' # for test_ops.py
26 |
27 | asvdb @ git+https://github.com/rapidsai/asvdb.git
28 | asv >=0.6.4
29 |
--------------------------------------------------------------------------------
/scripts/bisect_nvfuser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 |
5 |
6 | class Runner:
7 | def __init__(self, command_and_args: list[str], verbose: bool):
8 | self._command_and_args = command_and_args
9 | self._verbose = verbose
10 |
11 | def env_var_for_fuel(self) -> str:
12 | return "NVFUSER_OPTIMIZATION_FUEL"
13 |
14 | def run(self, fuel: int) -> int:
15 | os.environ[self.env_var_for_fuel()] = str(fuel)
16 | print(f"Running {self._command_and_args} with {fuel} units of fuel...")
17 | result = subprocess.run(
18 | self._command_and_args,
19 | check=False,
20 | stdout=(None if self._verbose else subprocess.DEVNULL),
21 | stderr=(None if self._verbose else subprocess.DEVNULL),
22 | )
23 | print(f"Command returned {result.returncode} with {fuel} units of fuel.")
24 | return result.returncode
25 |
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser(
29 | description="""
30 | A script that bisects nvFuser's optimization fuel to isolate a compiler bug.
31 |
32 | See `thunder.extend.FusionExecutor.get_fuel/set_fuel` for what optimization fuel is and how it is implemented.
33 |
34 | This script finds the minimum unit of optimization fuel between `lower_bound` and `upper_bound` that causes `command_and_args` to fail (i.e. exit with non-zero). The user will then run `NVFUSER_OPTIMIZATION_FUEL= ` as a minimal reproducer to identify the nvFusion that triggers the failure. Likely, the last nvFusion generated is the culprit.
35 | """
36 | )
37 | parser.add_argument("lower_bound", type=int, help="the lower bound for bisecting")
38 | parser.add_argument("upper_bound", type=int, help="the upper bound for bisecting")
39 | parser.add_argument("command_and_args", type=str, nargs=argparse.REMAINDER, help="the command and args to run")
40 | parser.add_argument("--verbose", action="store_true", help="whether to show stdout/stderr of the subprocess")
41 | args = parser.parse_args()
42 |
43 | runner = Runner(args.command_and_args, args.verbose)
44 |
45 | # The cornercases are not needed for the correctness of binary search. They are for catching errors earlier when the user specified a wrong lower/upper bound.
46 | if (exitcode := runner.run(args.lower_bound)) != 0:
47 | print("No need to bisect. Command failed with the lower bound.")
48 | exit(0)
49 |
50 | if (exitcode := runner.run(args.upper_bound)) == 0:
51 | print("Bisecting failed. Command passed with the upper bound.")
52 | exit(1)
53 |
54 | # Find the smallest fuel that fails `command_and_args`.
55 | low = args.lower_bound + 1 # +1 because we know `lower_bound` passed.
56 | high = args.upper_bound
57 | while low < high:
58 | mid = (low + high) // 2
59 | exitcode = runner.run(mid)
60 | if exitcode == 0:
61 | low = mid + 1
62 | else:
63 | high = mid
64 | assert low == high
65 |
66 | print(f"Bisecting succeeded. Run the following command as a minimal reproducer:")
67 | print(f" {runner.env_var_for_fuel()}={low} {' '.join(args.command_and_args)}")
68 | print(f"The last nvFusion likely triggered the failure.")
69 |
--------------------------------------------------------------------------------
/scripts/validate_build.py:
--------------------------------------------------------------------------------
1 | """Check that `scripts/build_from_source.sh` has produced a correct build."""
2 |
3 | import os
4 |
5 | BUILD_DIR = os.path.abspath(os.path.join(os.path.split(__file__)[0], "build"))
6 |
7 |
8 | def main():
9 | import torch
10 |
11 | expected = os.path.join(BUILD_DIR, "pytorch", "torch", "__init__.py")
12 | actual = os.path.abspath(torch.__file__)
13 | assert expected == actual, f"{expected} vs. {actual}"
14 |
15 | import nvfuser
16 | from looseversion import LooseVersion
17 |
18 | assert LooseVersion(nvfuser.version()) >= LooseVersion("0.0.6"), nvfuser.version()
19 |
20 | import thunder
21 |
22 | expected = os.path.abspath(os.path.join(BUILD_DIR, "..", "..", "thunder", "__init__.py"))
23 | assert expected == thunder.__file__, f"{expected} vs. {thunder.__file__}"
24 |
25 | print("Build looks good!")
26 |
27 |
28 | if __name__ == "__main__":
29 | main()
30 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from importlib.util import module_from_spec, spec_from_file_location
4 |
5 | from packaging.requirements import Requirement as parse_requirements
6 | from setuptools import find_packages, setup
7 |
8 |
9 | _PATH_ROOT = os.path.dirname(__file__)
10 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements")
11 | # check if os env. variable is set to convert version to nightly
12 | _CONVERT_VERSION = int(os.environ.get("CONVERT_VERSION2NIGHTLY", 0))
13 |
14 |
15 | def _load_py_module(fname, pkg="thunder"):
16 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname))
17 | py = module_from_spec(spec)
18 | spec.loader.exec_module(py)
19 | return py
20 |
21 |
22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list[str]:
23 | """Parse requirements file -> list of strings"""
24 | reqs: list[str] = []
25 | with open(os.path.join(path_dir, file_name)) as f:
26 | for line in f:
27 | if line and not line.startswith("#"):
28 | reqs.append(line.strip())
29 | # Filter out requirements referring to local paths or specific URLs (if any)
30 | reqs = [str(parse_requirements(r)) for r in reqs if "@" not in r and "://" not in r]
31 | return reqs
32 |
33 |
34 | def convert_version2nightly(about_file: str = "thunder/__about__.py") -> None:
35 | """Load the actual version and convert it to the nightly version."""
36 | from datetime import datetime
37 |
38 | # load the about file
39 | with open(about_file) as fo:
40 | lines = fo.readlines()
41 | idx = None
42 | # find the line with version
43 | for i, ln in enumerate(lines):
44 | if ln.startswith("__version__"):
45 | idx = i
46 | break
47 | if idx is None:
48 | raise ValueError("The version is not found in the `__about__.py` file.")
49 | # parse the version from variable assignment
50 | version = lines[idx].split("=")[1].strip().strip('"')
51 | # parse X.Y.Z version and prune any suffix
52 | vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
53 | # create timestamp YYYYMMDD
54 | timestamp = datetime.now().strftime("%Y%m%d")
55 | version = f"{'.'.join(vers.groups())}.dev{timestamp}"
56 | # print the new version
57 | lines[idx] = f'__version__ = "{version}"\n'
58 | # dump updated lines
59 | with open(about_file, "w") as fo:
60 | fo.writelines(lines)
61 |
62 |
63 | def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
64 | """Load readme as decribtion."""
65 | path_readme = os.path.join(path_dir, "README.md")
66 | with open(path_readme, encoding="utf-8") as fp:
67 | text = fp.read()
68 | # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
69 | github_source_url = os.path.join(homepage, "raw", version)
70 | # replace relative repository path to absolute link to the release
71 | # do not replace all "docs" as in the readme we replace some other sources with particular path to docs
72 | text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")
73 | return text
74 |
75 |
76 | if _CONVERT_VERSION:
77 | convert_version2nightly()
78 |
79 | about = _load_py_module("__about__.py")
80 |
81 | setup(
82 | version=about.__version__,
83 | packages=find_packages(exclude=["thunder/tests", "docs"]),
84 | long_description=_load_readme_description(
85 | path_dir=_PATH_ROOT,
86 | homepage=about.__homepage__,
87 | version=about.__version__,
88 | ),
89 | long_description_content_type="text/markdown",
90 | install_requires=_load_requirements(_PATH_REQUIRES, file_name="base.txt"),
91 | include_package_data=True,
92 | zip_safe=False,
93 | )
94 |
--------------------------------------------------------------------------------
/thunder/__about__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.4.dev0"
2 | __author__ = "Lightning-AI et al"
3 | __author_email__ = "community@lightning.ai"
4 | __copyright__ = f"2024 {__author__}"
5 | __homepage__ = "https://github.com/Lightning-AI/lightning-thunder"
6 | __docs__ = "Lightning Thunder project."
7 |
8 |
9 | __all__ = [
10 | "__author__",
11 | "__author_email__",
12 | "__copyright__",
13 | "__docs__",
14 | "__homepage__",
15 | "__version__",
16 | ]
17 |
--------------------------------------------------------------------------------
/thunder/benchmarks/einsum.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from functools import partial, wraps
3 | from collections.abc import Sequence
4 |
5 | import pytest
6 | import torch
7 | import thunder
8 |
9 | from thunder.benchmarks import EinsumBenchmark
10 | from thunder.benchmarks.targets import (
11 | make_setup,
12 | fwd_executors,
13 | fwd_executor_ids,
14 | grad_executors,
15 | grad_executors_ids,
16 | thunder_gradv1,
17 | thunder_torchcompile_gradv1,
18 | )
19 |
20 |
21 | def _instantiate_benchmark_env(
22 | shapes: Sequence[Sequence[int]],
23 | equation: str,
24 | executor: Callable,
25 | device: str = "cuda:0",
26 | dtype: thunder.dtypes.dtype = thunder.bfloat16,
27 | requires_grad: bool = False,
28 | ) -> Callable:
29 | bench: Benchmark = EinsumBenchmark(
30 | shapes=shapes, equation=equation, device=device, dtype=dtype, requires_grad=requires_grad
31 | )
32 |
33 | setup = make_setup(bench)
34 | fn = executor(bench)
35 |
36 | return setup, fn
37 |
38 |
39 | _test_size_map = {
40 | "small": 16,
41 | "medium": 128,
42 | "large": 512,
43 | }
44 |
45 |
46 | def single_dim_contraction_cases(size: str = "small", left_broadcasts: None | bool = None):
47 | n = _test_size_map[size]
48 |
49 | lhs_shape = [2, n, n]
50 | rhs_shape = [n, n]
51 |
52 | if left_broadcasts is not None:
53 | if left_broadcasts:
54 | lhs_shape[-1] = 1
55 | else:
56 | rhs_shape[0] = 1
57 |
58 | return (lhs_shape, rhs_shape), "bij,jk"
59 |
60 |
61 | def multidim_contraction_cases(size: str = "small", left_broadcasts: None | bool = None):
62 | n = _test_size_map[size]
63 |
64 | lhs_shape = [2, 8, n, n]
65 | rhs_shape = [2, n, 8, n]
66 |
67 | if left_broadcasts is not None:
68 | if left_broadcasts:
69 | lhs_shape[-1] = 1
70 | else:
71 | rhs_shape[-1] = 1
72 |
73 | return (lhs_shape, rhs_shape), "bijk,bklj->bli"
74 |
75 |
76 | class TestEinsumBenchmarks:
77 | @pytest.mark.parametrize(
78 | "executor,",
79 | fwd_executors,
80 | ids=fwd_executor_ids,
81 | )
82 | @pytest.mark.parametrize(
83 | "size,",
84 | _test_size_map.keys(),
85 | )
86 | @pytest.mark.parametrize(
87 | "sample_gen,", (single_dim_contraction_cases, multidim_contraction_cases), ids=("singledim", "multidim")
88 | )
89 | @pytest.mark.parametrize("left_broadcasts,", (None, True, False), ids=("", "left_broadcasts", "right_broadcasts"))
90 | def test_einsum_fwd(
91 | self, benchmark, executor: None | Callable, size: str, sample_gen: Callable, left_broadcasts: None | bool
92 | ):
93 | setup, fn = _instantiate_benchmark_env(
94 | *sample_gen(size, left_broadcasts), executor=executor, requires_grad=False
95 | )
96 | benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)
97 |
98 | @pytest.mark.parametrize(
99 | "executor,",
100 | [ge for ge in grad_executors if ge not in (thunder_gradv1, thunder_torchcompile_gradv1)],
101 | ids=[gei for gei in grad_executors_ids if gei not in ("thunder-gradv1", "thunder+torchcompile_cat-gradv1")],
102 | )
103 | @pytest.mark.parametrize(
104 | "size,",
105 | _test_size_map.keys(),
106 | )
107 | @pytest.mark.parametrize(
108 | "sample_gen,", (single_dim_contraction_cases, multidim_contraction_cases), ids=("singledim", "multidim")
109 | )
110 | @pytest.mark.parametrize(
111 | "left_broadcasts,",
112 | # False/right_broadcasts is disabled because of
113 | # https://github.com/NVIDIA/Fuser/issues/1590.
114 | # TODO: update once the issue is fixed.
115 | (None, True),
116 | ids=("", "left_broadcasts"),
117 | )
118 | def test_einsum_grad(
119 | self, benchmark, executor: None | Callable, size: str, sample_gen: Callable, left_broadcasts: None | bool
120 | ):
121 | setup, fn = _instantiate_benchmark_env(
122 | *sample_gen(size, left_broadcasts), executor=executor, requires_grad=True
123 | )
124 | benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)
125 |
--------------------------------------------------------------------------------
/thunder/clang/langctx.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from collections.abc import Callable, Sequence
3 |
4 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
5 | from thunder.core.pytree import tree_flatten
6 | from thunder.core.proxies import TensorProxy, NumberProxy
7 |
8 | #
9 | # Creates and registers the torch language context
10 | #
11 | # NOTE That this is done separately from the definition of thunder.torch operations, because the
12 | # language context must be available before those operations are defined
13 |
14 | _method_name_to_fn_map: dict[str, Callable] = {}
15 |
16 |
17 | # Creates and registers the core language language context
18 | class ClangCtx(LanguageContext):
19 | def __init__(self):
20 | super().__init__("core")
21 |
22 | def has_method(self, id: str) -> bool:
23 | return id in _method_name_to_fn_map
24 |
25 | def get_method(self, id: str, *args, **kwargs) -> Callable:
26 | # Note: concrete implmenetations should only raise AttributeError or
27 | # return None for "missing" methods as the proxies will
28 | # route __getattr__ to here and hasattr relies on __getattr__
29 | # throwing AttributeError (only) when the attribute does
30 | # not exist.
31 | inps, _ = tree_flatten((args, kwargs))
32 |
33 | has_proxy_input: bool = False
34 | for x in inps:
35 | if isinstance(x, TensorProxy) or isinstance(x, NumberProxy):
36 | has_proxy_input = True
37 | break
38 |
39 | if has_proxy_input:
40 | method: None | Callable = _method_name_to_fn_map.get(id, None)
41 | if method is None:
42 | raise AttributeError(f"The {self.name} language context has no method {id}")
43 | return method
44 |
45 | # has_proxy_input is False
46 | # Defers to the primitive language context when there are no tensor inputs=
47 | # (the primitive language context handles operations on numbers)
48 | primsctx: LanguageContext = resolve_language(Languages.PRIMS)
49 | if not primsctx.has_method(id):
50 | raise AttributeError(
51 | f"Attempting to call method {id} in the core language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
52 | )
53 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
54 | return prim_method
55 |
56 |
57 | clangctx = ClangCtx()
58 | register_langctx(Languages.CLANG, clangctx)
59 |
60 |
61 | # Registers a method with the torch language context
62 | def register_method(method_name: str, method: Callable, /) -> None:
63 | _method_name_to_fn_map[method_name] = method
64 |
--------------------------------------------------------------------------------
/thunder/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/core/__init__.py
--------------------------------------------------------------------------------
/thunder/core/compile_data.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from contextvars import ContextVar
3 | from typing import Any
4 |
5 | from thunder.core.options import CACHE_OPTIONS
6 |
7 | #
8 | # Setting and querying the "compile_data_and_stats" context variable, which contains
9 | # a tuple of (CompileData, CompileStats) objects for the current trace.
10 | #
11 |
12 | _compile_data = ContextVar("compile_data", default=(None, None))
13 |
14 |
15 | #
16 | # Setting, getting, and resetting the context variable
17 | #
18 |
19 |
20 | # NOTE This just acquires the compile data part of the context var's tuple
21 | def get_compile_data() -> None | Any:
22 | """Returns the current compile data."""
23 | cd, cs = _compile_data.get()
24 |
25 | return cd
26 |
27 |
28 | def set_compile_data_and_stats(cd, cs, /):
29 | """Sets the current compile data.
30 |
31 | This is used to pass compile data to functions that are called during compilation.
32 | """
33 | token = _compile_data.set((cd, cs))
34 | return token
35 |
36 |
37 | def reset_compile_data_and_stats(token, /):
38 | """Resets the compile data."""
39 | _compile_data.reset(token)
40 |
41 |
42 | @contextmanager
43 | def compile_data_and_stats(cd, cs, /):
44 | """Sets the current compile data for the duration of the context."""
45 | token = set_compile_data_and_stats(cd, cs)
46 | try:
47 | yield
48 | finally:
49 | reset_compile_data_and_stats(token)
50 |
51 |
52 | #
53 | # Query helpers
54 | #
55 |
56 |
57 | def get_compile_option(option: str, description: str, /) -> None | Any:
58 | cd, cs = _compile_data.get()
59 |
60 | if cd is None or cs is None:
61 | return None
62 |
63 | # See NOTE Different categories of compile options in thunder/__init__.py
64 | cs.last_compile_reasons[option].append(description)
65 | return cd.compile_options.get(option, None)
66 |
67 |
68 | # Whether or not the caching option uses symbolic values
69 | def get_cache_option() -> CACHE_OPTIONS:
70 | cd = get_compile_data()
71 | if cd is None:
72 | return CACHE_OPTIONS.CONSTANT_VALUES
73 | return cd.cache_option
74 |
75 |
76 | # TODO RC1 Remove the try (hack for when operating outside of this contextvar being set)
77 | def using_symbolic_values() -> bool:
78 | try:
79 | return get_cache_option() is CACHE_OPTIONS.SYMBOLIC_VALUES
80 | except:
81 | return False
82 |
83 |
84 | def using_jit() -> bool:
85 | try:
86 | cd, cs = _compile_data.get()
87 | return cd.using_jit
88 | except:
89 | return False
90 |
--------------------------------------------------------------------------------
/thunder/core/profile.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | import contextlib
3 | import functools
4 | import os
5 | import warnings
6 |
7 | import torch
8 |
9 |
10 | _ENABLED = os.getenv("THUNDER_ANNOTATE_TRACES") in ("1", "y", "Y")
11 |
12 | # However, nvtx is incredibly cheap so we no longer bother requiring the
13 | # environment variable.
14 | try:
15 | import nvtx
16 | except ImportError:
17 | if _ENABLED:
18 | msg = "Requested nvtx but the package is not available."
19 | msg += "\nUse `pip install -m pip install nvtx`."
20 | warnings.warn(msg)
21 | _ENABLED = False
22 | raise
23 |
24 |
25 | def profiling_enabled() -> bool:
26 | return _ENABLED
27 |
28 |
29 | @contextlib.contextmanager
30 | def add_markers(msg: str) -> None:
31 | if not profiling_enabled():
32 | yield
33 | return
34 |
35 | assert "\n" not in msg, msg # Both NVTX and JSON forbid newlines
36 | assert '"' not in msg, msg # The PyTorch profiler does not properly escape quotations
37 |
38 | with torch.profiler.record_function(msg):
39 | torch.cuda.nvtx.range_push(msg)
40 | try:
41 | yield
42 |
43 | finally:
44 | torch.cuda.nvtx.range_pop()
45 |
46 |
47 | # The main interface to profiling something. Generally used as a decorator:
48 | # @thunder.core.profile.annotate_for_profile("foo")
49 | # def foo(...): ...
50 | # but alternatively as a `with` context:
51 | # with thunder.core.profile.annotate_for_profile("name for a block of code"):
52 | # # ... code ...
53 | annotate_for_profile: Callable[[str], None] = None
54 |
55 |
56 | if _ENABLED:
57 | annotate_for_profile = functools.partial(nvtx.annotate, domain="thunder")
58 | else:
59 |
60 | class _no_annotate(contextlib.nullcontext):
61 | """
62 | A profiling decorator that does nothing.
63 | """
64 |
65 | def __init__(self, *args, **kwargs):
66 | super().__init__()
67 |
68 | def __call__(self, fqn):
69 | return fqn
70 |
71 | annotate_for_profile = _no_annotate
72 |
--------------------------------------------------------------------------------
/thunder/dev_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/dev_utils/__init__.py
--------------------------------------------------------------------------------
/thunder/dev_utils/benchmark.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 |
5 |
6 | def benchmark_n(n, model_or_fn, /, *args, **kwargs):
7 | for _ in range(n):
8 | _ = model_or_fn(*args, **kwargs)
9 | start_event = torch.cuda.Event(enable_timing=True)
10 | end_event = torch.cuda.Event(enable_timing=True)
11 | torch.cuda.synchronize()
12 | start_event.record()
13 | for _ in range(n):
14 | _ = model_or_fn(*args, **kwargs)
15 | end_event.record()
16 | torch.cuda.synchronize()
17 | return start_event.elapsed_time(end_event) / n
18 |
19 |
20 | benchmark = partial(benchmark_n, 10)
21 |
--------------------------------------------------------------------------------
/thunder/dev_utils/check_trace.py:
--------------------------------------------------------------------------------
1 | import thunder
2 |
3 | CHECK_VERSION = 3
4 |
5 |
6 | def check_subsymbols(parent_bsym):
7 | if parent_bsym.sym.is_prim:
8 | # assert that there are no subsymbols?
9 | return
10 | known_proxies = {a.name: a for a in parent_bsym.flat_proxy_args}
11 | for bsym in parent_bsym.subsymbols:
12 | for a in bsym.flat_proxy_args:
13 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args"
14 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args"
15 | for o in bsym.flat_proxy_outs:
16 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs"
17 | known_proxies[o.name] = o
18 | check_subsymbols(bsym)
19 | for o in parent_bsym.flat_proxy_outs:
20 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {parent_bsym}"
21 |
22 |
23 | def check_trace(trace, *, version=CHECK_VERSION):
24 | """checks a trace for consistency
25 |
26 | The check is versioned for the benefit of CI and other automated testing.
27 |
28 | As a user, don't pass a version to get all implemented checks.
29 |
30 | If you add new checks and do not fix all newly detected inconsistencies,
31 | bump the CHECK_VERSION and make your tests only apply to this latest version.
32 |
33 | Please do file issues for things that fail with the latest versions so we can
34 | catch up.
35 | """
36 | # TODO:
37 | # - args vs. unpack trivial
38 | # - args vs. flat_args in return
39 | known_proxies = {}
40 | for bsym in trace.bound_symbols:
41 | if (version >= 3) and bsym.sym == thunder.core.prims.unpack_sequence:
42 | coll = bsym.args[0].collection()
43 | assert len(coll) == len(bsym.output), f"unpack collection length mismatch {bsym}"
44 | for c, o in zip(coll, bsym.output):
45 | if o is None: # unused outputs
46 | continue
47 | if isinstance(c, thunder.Proxy):
48 | assert c is o, f"mismatch in unpack collection: {c} {o} {bsym}"
49 |
50 | for a in bsym.flat_proxy_args:
51 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args"
52 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args"
53 | for o in bsym.flat_proxy_outs:
54 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs"
55 | known_proxies[o.name] = o
56 | check_subsymbols(bsym)
57 |
58 | tr = thunder.core.trace.from_trace(trace)
59 | with thunder.core.trace.tracectx(tr):
60 | for bsym in trace.bound_symbols:
61 | if bsym.sym.name.startswith("unpack") or bsym.sym.name in {"pack_buffer"}:
62 | continue
63 | res = bsym.sym(*bsym.args, **bsym.kwargs)
64 |
65 | def check_shape(x, y):
66 | if isinstance(x, thunder.TensorProxy) and y is not None: # formal output can be none if unused
67 | assert x.shape == y.shape, f"shape of proxy {y.name} recomputes to {x.shape} incorrectly in {bsym}"
68 | return x
69 |
70 | thunder.core.utils.safe_map_flat(check_shape, res, bsym.output)
71 |
72 |
73 | class CheckedListOfTraces(list):
74 | def __init__(self, *args, **kwargs):
75 | super().__init__(*args, **kwargs)
76 | cd = thunder.core.compile_data.get_compile_data()
77 | if cd.debug_options.check_traces is True:
78 | self._check_version = CHECK_VERSION
79 | elif not cd.debug_options.check_traces:
80 | self._check_version = 0
81 | else:
82 | self._check_version = cd.debug_options.check_traces
83 |
84 | def append(self, trace):
85 | check_trace(trace, version=self._check_version)
86 | super().append(trace)
87 |
88 | def extend(self, traces):
89 | for tr in traces:
90 | check_trace(tr, version=self._check_version)
91 | super().extend(traces)
92 |
--------------------------------------------------------------------------------
/thunder/dev_utils/nvtx_profile_transform.py:
--------------------------------------------------------------------------------
1 | from thunder.core.trace import TraceCtx as Trace, from_trace, TraceProvenance
2 | from thunder.dev_utils.utils import NON_COMPUTATION_PRIMS
3 | from thunder.extend import OperatorExecutor
4 | import time
5 | import torch
6 | import thunder
7 |
8 |
9 | class Timer:
10 | def __init__(self):
11 | self.start_time_ns = None
12 | self.end_time_ns = None
13 |
14 | def __enter__(self):
15 | self.start_time_ns = time.perf_counter_ns()
16 | return self
17 |
18 | def __exit__(self, *args):
19 | self.end_time_ns = time.perf_counter_ns()
20 |
21 | def get_elapsed_time_in_ms(self):
22 | elapsed_time_ns = self.end_time_ns - self.start_time_ns
23 | return elapsed_time_ns // int(1e6)
24 |
25 |
26 | nvtx_profiler_ex = OperatorExecutor("nvtx_profiler_ex")
27 |
28 |
29 | def nvtx_push_impl(msg):
30 | torch.cuda.nvtx.range_push(msg)
31 |
32 |
33 | def nvtx_pop_impl():
34 | torch.cuda.nvtx.range_pop()
35 |
36 |
37 | # Symbols for profiling.
38 | nvtx_push = nvtx_profiler_ex.register_operator("nvtx_range_push", meta=lambda msg: None, fn=nvtx_push_impl)
39 | nvtx_pop = nvtx_profiler_ex.register_operator("nvtx_range_pop", meta=lambda: None, fn=nvtx_pop_impl)
40 |
41 |
42 | class NvtxProfileTransform(thunder.core.transforms.Transform):
43 | def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace:
44 | with Timer() as timer:
45 | profile_trace = from_trace(trace)
46 |
47 | for bound_symbol in trace.bound_symbols:
48 | if bound_symbol.sym.id in NON_COMPUTATION_PRIMS:
49 | profile_trace.bound_symbols.append(bound_symbol)
50 | continue
51 |
52 | # Add nvtx range for the symbol.
53 | profile_trace.bound_symbols.append(
54 | nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None)
55 | )
56 | profile_trace.bound_symbols.append(bound_symbol)
57 | profile_trace.bound_symbols.append(nvtx_pop.bind(output=None))
58 |
59 | profile_trace.set_provenance(
60 | TraceProvenance(f"NVTX Profile Transform (took {timer.get_elapsed_time_in_ms()} milliseconds)")
61 | )
62 | return profile_trace
63 |
--------------------------------------------------------------------------------
/thunder/distributed/tensor_parallel/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.distributed
2 |
3 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType
4 |
5 | if torch.distributed.is_available():
6 | from thunder.distributed.tensor_parallel.column_wise import column_parallel
7 | from thunder.distributed.tensor_parallel.row_wise import row_parallel
8 | else:
9 | column_parallel = None
10 | row_parallel = None
11 |
12 |
13 | __all__ = [
14 | "TensorParallelLayerType",
15 | "column_parallel",
16 | "row_parallel",
17 | ]
18 |
--------------------------------------------------------------------------------
/thunder/distributed/tensor_parallel/optimize_comm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 |
4 | from thunder.core import utils
5 | from thunder.core.proxies import TensorProxy
6 | from thunder.core.proxies import variableify
7 | from thunder.core.trace import from_trace
8 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType
9 |
10 | if TYPE_CHECKING:
11 | from thunder.core.trace import TraceCtx
12 | from thunder.core.symbol import BoundSymbol
13 | from thunder.core.proxies import ProxyInterface
14 | from thunder.core.trace import VariableInterface
15 |
16 |
17 | __all__ = [
18 | "remove_redundant_comms",
19 | ]
20 |
21 |
22 | def _get_tensor_parallel_layer_type(bsym: BoundSymbol) -> TensorParallelLayerType:
23 | value = bsym.flat_args[2]
24 | utils.check_type(value, TensorParallelLayerType)
25 | return value
26 |
27 |
28 | def remove_redundant_comms(trace: TraceCtx) -> TraceCtx:
29 | """Remove redundant paris of column-wise linear postprocessing and row-wise linear preprocessing.
30 |
31 | Args:
32 | trace: A trace modified by both of :func:`~thunder.distributed.tensor_parallel.column_parallel`
33 | and :func:`~thunder.distributed.tensor_parallel.row_parallel`.
34 | """
35 | from thunder.distributed import prims as dist_prims
36 |
37 | current_column_wise_parallel_linear_bsym: BoundSymbol | None = None
38 |
39 | interesting_pairs: list[tuple[BoundSymbol, BoundSymbol]] = []
40 | bsym_to_idx: dict[BoundSymbol, int] = {}
41 | idx_to_bsym: dict[int, BoundSymbol] = {}
42 | new_bsyms: list[BoundSymbol] = []
43 |
44 | bsym: BoundSymbol
45 | for idx, bsym in enumerate(trace.bound_symbols):
46 | bsym_to_idx[bsym] = idx
47 | idx_to_bsym[idx] = bsym
48 | new_bsyms.append(bsym)
49 | match bsym.sym.id:
50 | case dist_prims.PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT:
51 | match _get_tensor_parallel_layer_type(bsym):
52 | case TensorParallelLayerType.ROW_PARALLEL_LINEAR:
53 | if current_column_wise_parallel_linear_bsym is not None:
54 | interesting_pairs.append((current_column_wise_parallel_linear_bsym, bsym))
55 | case _:
56 | if current_column_wise_parallel_linear_bsym is not None:
57 | current_column_wise_parallel_linear_bsym = None
58 | case dist_prims.PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT:
59 | match _get_tensor_parallel_layer_type(bsym):
60 | case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR:
61 | current_column_wise_parallel_linear_bsym = bsym
62 | case _:
63 | pass
64 |
65 | new_trace = from_trace(trace)
66 | new_trace.bound_symbols = new_bsyms
67 |
68 | if interesting_pairs:
69 | indices_to_filter = []
70 | # For column-parallel linear: postprocessed -> col-parallel linear output
71 | # For row-parallel linear: preprocessed -> col-parallel linear output
72 | swap_map: dict[VariableInterface, ProxyInterface] = {}
73 | for col_postprocess_bsym, row_preprocess_bsym in interesting_pairs:
74 |
75 | col_liinear_output: TensorProxy = col_postprocess_bsym.flat_proxy_args[0]
76 | utils.check_type(col_liinear_output, TensorProxy)
77 |
78 | row_linear_input: TensorProxy = row_preprocess_bsym.flat_proxy_outs[0]
79 | utils.check_type(row_linear_input, TensorProxy)
80 |
81 | # TODO(crcrpar): Better to make sure that between column-wise parallel linear row-wise parallel linear,
82 | # the existing bsyms are elementwise.
83 | if col_liinear_output.shape == row_linear_input.shape:
84 | indices_to_filter.extend([bsym_to_idx[col_postprocess_bsym], bsym_to_idx[row_preprocess_bsym]])
85 |
86 | swap_map[variableify(col_postprocess_bsym.flat_proxy_outs[0])] = col_liinear_output
87 |
88 | orig_row_linear_input: TensorProxy = row_preprocess_bsym.flat_proxy_args[0]
89 | swap_map[variableify(row_linear_input)] = orig_row_linear_input
90 |
91 | indices_to_filter = set(indices_to_filter)
92 | new_bsyms: list[BoundSymbol] = []
93 | for idx, bsym in enumerate(trace.bound_symbols):
94 | if idx in indices_to_filter:
95 | continue
96 | new_bsyms.append(bsym.from_bsym_swap_proxies(swap_map=swap_map, skip_output=True))
97 | new_trace.bound_symbols = new_bsyms
98 |
99 | return new_trace
100 |
--------------------------------------------------------------------------------
/thunder/distributed/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | if torch.distributed.is_available():
4 | from .ddp import apply_bucketing_to_grad_allreduce
5 | from .fsdp import FSDPCommBucketing
6 | else:
7 | apply_bucketing_to_grad_allreduce = None
8 | FSDPCommBucketing = None
9 |
10 | __all__ = [
11 | "apply_bucketing_to_grad_allreduce",
12 | "FSDPCommBucketing",
13 | ]
14 |
--------------------------------------------------------------------------------
/thunder/dynamo/__init__.py:
--------------------------------------------------------------------------------
1 | from thunder.dynamo.compiler import ThunderCompiler, thunderfx, thunder_profile, thunder_optimize
2 |
3 |
4 | __all__ = ["ThunderCompiler", "thunderfx", "thunder_profile", "thunder_optimize"]
5 |
--------------------------------------------------------------------------------
/thunder/dynamo/repro_script_template.py:
--------------------------------------------------------------------------------
1 | FXGRAPH_CLASS_NAME = "DynamoModule"
2 | INPUTS_NAME = "inputs"
3 | CALLABLE_NAME = "model"
4 | COMPILED_CALLABLE_NAME = "compiled_model"
5 |
6 | pytest_benchmark_multi_exe_code_template = '''
7 | # NOTE: This script requires `pytest-benchmark==4.0.0` to be installed.
8 | # To execute the script, run `pytest {graph_name}_benchmark.py --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type`
9 | # To check the peak allocated CUDA memory, use --benchmark-json=json_file_name and look at the "max_allocated_memory_MB" field in the json file
10 | # To run tests for a specific compute_type, use the pytest `-k` option.
11 | # For example:
12 | # - `-k "forward"` will run only the forward pass.
13 | #
14 | # Available options:
15 | # - compute_type: "forward", "backward"
16 |
17 | import pytest
18 | from thunder.benchmarks.targets import parametrize_compute_type_only_training, benchmark_for_compute_type, ComputeType
19 | {torch_import_str}
20 | {import_str}
21 |
22 | # NOTE: The reproducer function has already been processed by TorchDynamo.
23 | # If we let it go through TorchDynamo again, it could be segmented further.
24 | # To avoid this, we directly use Inductor here.
25 | # See issue https://github.com/Lightning-AI/lightning-thunder/issues/1521
26 | def torch_inductor(fn, inputs):
27 | from torch._inductor import compile as inductor_compile
28 | from torch.fx import symbolic_trace
29 |
30 | fx_graph = symbolic_trace(fn)
31 | return inductor_compile(fx_graph, inputs)
32 |
33 | {executors}
34 | {executor_names}
35 |
36 | {dynamo_module}
37 |
38 | @pytest.mark.parametrize(
39 | "executor",
40 | executors,
41 | ids=executor_names,
42 | )
43 | {compute_type_decorator}
44 | def test_{graph_name}(benchmark, executor, compute_type):
45 | {inputs}
46 |
47 | model = DynamoModule()
48 | if executor is None:
49 | compiled_model = model
50 | elif executor == torch_inductor:
51 | compiled_model = executor(model, inputs)
52 | else:
53 | compiled_model = executor(model)
54 | {call_benchmark}
55 |
56 | """
57 | Environment information get from `torch.utils.collect_env.get_pretty_env_info()`:
58 | {torch_env}
59 |
60 | Versions of Thunder related libraries:
61 | {thunder_pkgs}
62 |
63 | {extra_comment_str}
64 | """
65 | '''
66 |
67 |
68 | bsym_torch_compile_repro_template = '''
69 | """
70 | {extra_comment_str}
71 | """
72 | {python_func}
73 |
74 | from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
75 | import thunder.examine
76 |
77 | inputs = {inputs}
78 |
79 | jfn = thunder.jit({func_name})
80 | jfn(*inputs)
81 |
82 | trc = thunder.last_traces(jfn)[-1]
83 | fusion_symbols = thunder.examine.get_fusion_symbols(trc)
84 | assert len(fusion_symbols) == 1
85 | bsym = fusion_symbols[0]
86 |
87 | # NOTE: The nvFusion function cannot be compiled directly using `torch.compile`.
88 | # It must first be processed by Thunder into BoundSymbols and compiled with `make_torch_compile_callable`.
89 | # Additionally, it's recommended to visually verify that `bsym` matches the
90 | # `nvFusion` function above by printing it using `print(bsym)`.
91 | torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
92 | '''
93 |
94 | repro_bench_code_template = f"""
95 | {{import_str}}
96 |
97 | {{dynamo_module}}
98 | def test_{{graph_name}}():
99 | {{inputs}}
100 |
101 | model = {FXGRAPH_CLASS_NAME}()
102 | """
103 |
104 | main_code = """
105 | if __name__ == "__main__":
106 | test_{graph_name}()
107 | """
108 |
109 | comment_str_template = '''
110 | """
111 | Environment information get from `torch.utils.collect_env.get_pretty_env_info()`:
112 | {torch_env}
113 |
114 | Versions of Thunder related libraries:
115 | {thunder_pkgs}
116 |
117 | {extra_comment_str}
118 | """
119 | '''
120 |
--------------------------------------------------------------------------------
/thunder/executors/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any, List, Tuple
2 | from collections.abc import Sequence
3 |
4 | import thunder.executors.passes as passes
5 |
6 | import thunder.extend as extend
7 |
8 |
9 | # NOTE The executors submodule depends on the extend submodule
10 |
11 | __all__ = [
12 | "passes",
13 | "get_torch_executor",
14 | "get_nvfuser_executor",
15 | "nvfuser_available",
16 | ]
17 |
18 |
19 | def get_nvfuser_executor() -> None | extend.Executor:
20 | return extend.get_executor("nvfuser")
21 |
22 |
23 | def get_torch_executor() -> extend.Executor:
24 | return extend.get_executor("torch")
25 |
26 |
27 | def nvfuser_available() -> bool:
28 | return get_nvfuser_executor() is not None
29 |
--------------------------------------------------------------------------------
/thunder/executors/apex_fused_rms_norm_impl.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | import math
3 |
4 | import torch
5 |
6 | import thunder
7 | from thunder.core.proxies import TensorProxy, AnyProxy
8 | from thunder.core.transforms import get_grad, put_grads
9 | from thunder.executors.utils import Context, set_saved_tensors
10 | from thunder.torch import TensorLike
11 | from thunder.core.compile_data import get_compile_option
12 | from thunder.executors.apexex import apex_ex
13 |
14 |
15 | APEX_FUSED_NORMS_AVAILABLE = True
16 | try:
17 | # Fused layer norm is only importable if torch.distributed is available
18 | # https://github.com/NVIDIA/apex/issues/1853
19 | from torch.distributed import is_available
20 |
21 | if not is_available():
22 | raise ImportError
23 | import fused_layer_norm_cuda
24 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
25 | except ImportError:
26 | APEX_FUSED_NORMS_AVAILABLE = False
27 |
28 |
29 | def apex_fused_norms_available() -> bool:
30 | return APEX_FUSED_NORMS_AVAILABLE
31 |
32 |
33 | def apex_fused_rms_norm_forward_affine_meta(
34 | input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float
35 | ):
36 | output_or_input = TensorProxy(like=input)
37 | weight = TensorProxy(like=input, shape=normalized_shape)
38 | unnormalized_dims = len(input.shape) - len(normalized_shape)
39 | invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),))
40 | return TensorProxy(like=input), invvar
41 |
42 |
43 | def apex_fused_rms_norm_backward_affine_meta(
44 | grad_output: TensorLike,
45 | invvar: TensorLike,
46 | input_or_output: TensorLike,
47 | normalized_shape: Sequence[int],
48 | weight_,
49 | eps: float,
50 | memory_efficient: bool,
51 | ):
52 | return TensorProxy(like=grad_output), TensorProxy(like=weight_)
53 |
54 |
55 | # Create a new symbol and register lookaside only if import is available.
56 | if apex_fused_norms_available():
57 | apex_fused_rms_norm_forward_affine = apex_ex.register_operator(
58 | "apex_fused_rms_norm_forward_affine",
59 | meta=apex_fused_rms_norm_forward_affine_meta,
60 | fn=fused_layer_norm_cuda.rms_forward_affine,
61 | replaces=fused_layer_norm_cuda.rms_forward_affine,
62 | )
63 |
64 | apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator(
65 | "apex_fused_rms_norm_forward_affine_mixed_dtypes",
66 | meta=apex_fused_rms_norm_forward_affine_meta,
67 | fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
68 | replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
69 | )
70 |
71 | apex_fused_rms_norm_backward_affine = apex_ex.register_operator(
72 | "apex_fused_rms_norm_backward_affine",
73 | meta=apex_fused_rms_norm_backward_affine_meta,
74 | fn=fused_layer_norm_cuda.rms_backward_affine,
75 | replaces=fused_layer_norm_cuda.rms_backward_affine,
76 | )
77 |
--------------------------------------------------------------------------------
/thunder/executors/apexex.py:
--------------------------------------------------------------------------------
1 | from thunder.extend import OperatorExecutor, register_executor
2 |
3 |
4 | # TODO Does apex have a version this should use?
5 | apex_ex = OperatorExecutor("apex", version="0.1")
6 | register_executor(apex_ex)
7 |
8 | import thunder.executors.apex_entropyex_impl
9 | import thunder.executors.apex_fused_rms_norm_impl as apex_fused_rms_norm_impl
10 |
11 | from thunder.executors.apex_entropyex_impl import apex_entropy_available
12 | from thunder.executors.apex_fused_rms_norm_impl import apex_fused_norms_available
13 |
14 | __all__ = ["apex_ex", "apex_entropy_available", "apex_fused_norms_available"]
15 |
--------------------------------------------------------------------------------
/thunder/executors/nvfuserex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from looseversion import LooseVersion
3 |
4 | from thunder.extend import FusionExecutor
5 |
6 | __all__ = ["nvfuser_version", "required_nvfuser_version", "nvfuser_available", "nvfuserex"]
7 |
8 |
9 | #
10 | # Functions for detecting nvFuser and its version
11 | #
12 | def nvfuser_version() -> LooseVersion | None:
13 | # Short-circuits if CUDA isn't available
14 | if not torch.cuda.is_available():
15 | return None
16 |
17 | try:
18 | import nvfuser
19 |
20 | if hasattr(nvfuser, "version"):
21 | return LooseVersion(nvfuser.version())
22 |
23 | # NOTE: This import of nvFuser may or may not have version info
24 | return LooseVersion("0.0.0")
25 | except ImportError:
26 | pass
27 |
28 | # NOTE This occurs when nvFuser couldn't be imported
29 | return None
30 |
31 |
32 | def required_nvfuser_version() -> LooseVersion:
33 | return LooseVersion("0.2.8")
34 |
35 |
36 | def nvfuser_available() -> bool:
37 | v = nvfuser_version()
38 | if v is None:
39 | return False
40 |
41 | required = required_nvfuser_version()
42 | if v < required:
43 | import warnings
44 |
45 | msg = f"Your nvfuser installation is out of date. Thunder requires version {required}, but found version {v}."
46 | warnings.warn(msg)
47 | return False
48 | return True
49 |
50 |
51 | nvfuserex: None | FusionExecutor = None
52 | if nvfuser_available():
53 | import thunder.executors.nvfuserex_impl as impl
54 |
55 | nvfuserex = impl.ex
56 |
--------------------------------------------------------------------------------
/thunder/executors/transformer_engine_v2ex.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from lightning_utilities.core.imports import package_available
4 |
5 | from thunder import Transform
6 | from thunder.extend import StatefulExecutor
7 |
8 | __all__ = ["transformer_engine_v2_ex", "TransformerEngineTransformV2"]
9 |
10 | transformer_engine_v2_ex: None | StatefulExecutor = None
11 | TransformerEngineTransformV2: None | Transform = None
12 |
13 | if package_available("transformer_engine"):
14 | import thunder.executors.transformer_engine_v2ex_impl as impl
15 |
16 | transformer_engine_v2_ex = impl.transformer_engine_v2_ex
17 | TransformerEngineTransformV2 = impl.TransformerEngineTransformV2
18 |
19 | else:
20 | warnings.warn("transformer_engine module not found!")
21 |
--------------------------------------------------------------------------------
/thunder/executors/triton_crossentropy.py:
--------------------------------------------------------------------------------
1 | from thunder.executors import triton_utils
2 | from thunder.extend import OperatorExecutor
3 |
4 | triton_version: None | str = triton_utils.triton_version()
5 |
6 | triton_ex: None | OperatorExecutor = None
7 | if triton_version is not None:
8 | try:
9 | from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex
10 |
11 | triton_ex = impl_ex
12 | except Exception:
13 | import warnings
14 |
15 | warnings.warn("triton is present but cannot be initialized")
16 | triton_version = None
17 |
--------------------------------------------------------------------------------
/thunder/executors/triton_utils.py:
--------------------------------------------------------------------------------
1 | from looseversion import LooseVersion
2 | from lightning_utilities.core.imports import package_available
3 |
4 |
5 | def triton_version() -> None | str:
6 | if not package_available("triton"):
7 | return None
8 |
9 | import triton
10 |
11 | return triton.__version__
12 |
13 |
14 | def is_triton_version_at_least(minimum_version: str) -> bool:
15 | version = triton_version()
16 |
17 | if version is None:
18 | return False
19 |
20 | return LooseVersion(version) >= minimum_version
21 |
--------------------------------------------------------------------------------
/thunder/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | from numbers import Number
2 | from typing import Any
3 | from collections.abc import Callable
4 |
5 | import numpy as np
6 |
7 | from thunder.core.langctx import langctx, Languages
8 | from thunder.numpy.langctx import register_method
9 |
10 | from thunder.core.proxies import TensorProxy
11 | from thunder.core.symbol import Symbol
12 | import thunder.clang as clang
13 |
14 |
15 | #
16 | # NumPy operator definitions
17 | #
18 | # NOTE NumPy language support is demonstrative. PRs extending it are welcome!
19 |
20 |
21 | # Decorator that sets the language context and constructs a Symbol for each function
22 | class npsymbol:
23 | def __init__(self, *, method_name: None | str = None):
24 | self.method_name: None | str = method_name
25 |
26 | def __call__(self, fn: Callable) -> Symbol:
27 | _fn = langctx(Languages.NUMPY)(fn)
28 | # TODO: register _fn as opaque with the interpreter or do this in jit_ext?
29 | sym = Symbol(name=fn.__name__, meta=_fn)
30 |
31 | if self.method_name is not None:
32 | register_method(self.method_name, _fn)
33 |
34 | return sym
35 |
36 |
37 | #
38 | # Tensor properties
39 | #
40 |
41 |
42 | # NOTE Named `compute_len` so that it doesn't conflict with built-in `len`
43 | def compute_len(a: TensorProxy, /) -> int:
44 | return a.shape[0]
45 |
46 |
47 | register_method("len", compute_len)
48 |
49 |
50 | def size(a: TensorProxy, /) -> int:
51 | return a.numel
52 |
53 |
54 | register_method("size", size)
55 |
56 |
57 | #
58 | # Elementwise binary operators
59 | #
60 |
61 |
62 | # TODO Create a factory that adds ufunc support to elementwise operations
63 | npsymbol(method_name="add")
64 |
65 |
66 | def add(a: Number | TensorProxy, b: Number | TensorProxy, /, *, where: None | Number | TensorProxy = None):
67 | result = clang.add(a, b)
68 | if where is not None:
69 | return clang.where(where, result, a)
70 | return result
71 |
--------------------------------------------------------------------------------
/thunder/numpy/langctx.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from collections.abc import Callable, Sequence
3 |
4 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
5 | from thunder.core.pytree import tree_flatten
6 | from thunder.core.proxies import TensorProxy
7 |
8 | #
9 | # Creates and registers the torch language context
10 | #
11 | # NOTE That this is done separately from the definition of thunder.torch operations, because the
12 | # language context must be available before those operations are defined
13 |
14 | _method_name_to_fn_map: dict[str, Callable] = {}
15 |
16 |
17 | # Creates and registers the core language language context
18 | class NumPyCtx(LanguageContext):
19 | def __init__(self):
20 | super().__init__("numpy")
21 |
22 | def has_method(self, id: str) -> bool:
23 | return id in _method_name_to_fn_map
24 |
25 | def get_method(self, id: str, *args, **kwargs) -> Callable:
26 | # Note: concrete implementations should only raise AttributeError or
27 | # return None for "missing" methods as the proxies will
28 | # route __getattr__ to here and hasattr relies on __getattr__
29 | # throwing AttributeError (only) when the attribute does
30 | # not exist.
31 | inps, _ = tree_flatten((args, kwargs))
32 |
33 | has_tensor_input: bool = False
34 | for x in inps:
35 | if isinstance(x, TensorProxy):
36 | has_tensor_input = True
37 | break
38 |
39 | if has_tensor_input:
40 | method: None | Callable = _method_name_to_fn_map.get(id, None)
41 | if method is None:
42 | raise AttributeError(f"The {self.name} language context has no method {id}")
43 | return method
44 |
45 | # has_tensor_input is False
46 | # Defers to the primitive language context when there are no tensor inputs=
47 | # (the primitive language context handles operations on numbers)
48 | primsctx: LanguageContext = resolve_language(Languages.PRIMS)
49 | if not primsctx.has_method(id):
50 | raise AttributeError(
51 | f"Attempting to call method {id} in the numpy language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
52 | )
53 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
54 | return prim_method
55 |
56 |
57 | numpyctx = NumPyCtx()
58 | register_langctx(Languages.NUMPY, numpyctx)
59 | register_langctx("numpy", numpyctx)
60 |
61 |
62 | # Registers a method with the torch language context
63 | def register_method(method_name: str, method: Callable, /) -> None:
64 | _method_name_to_fn_map[method_name] = method
65 |
--------------------------------------------------------------------------------
/thunder/plugins/__init__.py:
--------------------------------------------------------------------------------
1 | from thunder.plugins.distributed import DDP, FSDP
2 | from thunder.plugins.quantization import QuantizeInt4
3 | from thunder.plugins.fp8 import FP8
4 | from thunder.plugins.reduce_overhead import ReduceOverhead
5 |
6 | names_to_plugins = {
7 | "ddp": DDP,
8 | "fsdp": FSDP,
9 | "quantize-int4": QuantizeInt4,
10 | "fp8": FP8,
11 | "reduce-overhead": ReduceOverhead,
12 | }
13 |
14 |
15 | def get_plugin(name):
16 | return names_to_plugins.get(name)
17 |
18 |
19 | def get_plugin_names():
20 | return list(names_to_plugins.keys())
21 |
22 |
23 | def register_plugin(name, cls):
24 | names_to_plugins[name] = cls
25 |
--------------------------------------------------------------------------------
/thunder/plugins/fp8.py:
--------------------------------------------------------------------------------
1 | from thunder import Plugin
2 |
3 |
4 | class FP8(Plugin):
5 | def setup_executors(self):
6 | from thunder.executors.transformer_engineex import transformer_engine_ex
7 |
8 | return [transformer_engine_ex]
9 |
--------------------------------------------------------------------------------
/thunder/plugins/quantization.py:
--------------------------------------------------------------------------------
1 | from thunder import Plugin
2 |
3 |
4 | class QuantizeInt4(Plugin):
5 | def setup_transforms(self):
6 | from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit
7 |
8 | return [BitsAndBytesLinearQuant4bit()]
9 |
10 | def setup_executors(self):
11 | from thunder.transforms.quantization import get_bitsandbytes_executor
12 |
13 | return [get_bitsandbytes_executor()]
14 |
--------------------------------------------------------------------------------
/thunder/plugins/reduce_overhead.py:
--------------------------------------------------------------------------------
1 | from thunder import Plugin
2 | from thunder.core.recipe import Plugin, PluginPolicy
3 | from thunder.transforms.cudagraph import CUDAGraphTransform
4 |
5 |
6 | class ReduceOverhead(Plugin):
7 | policy = PluginPolicy.POST
8 |
9 | def setup_transforms(self):
10 | return [CUDAGraphTransform()]
11 |
--------------------------------------------------------------------------------
/thunder/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/py.typed
--------------------------------------------------------------------------------
/thunder/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Type
2 |
3 | from thunder.recipes.base import BaseRecipe
4 | from thunder.recipes.hf_transformers import HFTransformers
5 |
6 |
7 | names_to_recipes: dict[str : type[Any]] = {
8 | "default": BaseRecipe,
9 | "hf-transformers": HFTransformers,
10 | }
11 |
12 |
13 | def get_recipe_class(name: str) -> type[Any]:
14 | return names_to_recipes.get(name)
15 |
16 |
17 | def get_recipes() -> list[str]:
18 | return list(names_to_recipes.keys())
19 |
20 |
21 | def register_recipe(name: str, cls: type[Any]):
22 | if name == "auto":
23 | raise ValueError("Recipe name 'auto' is reserved.")
24 | names_to_recipes[name] = cls
25 |
--------------------------------------------------------------------------------
/thunder/tests/README.md:
--------------------------------------------------------------------------------
1 | Testing is important!
2 |
3 | Most Thunder operators are tested by automatically generated tests that use data from the operator's "OpInfo", defined
4 | in `opinfos.py`. These tests typically compare Thunder's operator behavior with another framework, using the
5 | "sample inputs" defined by its OpInfo. Each OpInfo should define a sample input generator and at least one “reference” implementation in PyTorch or JAX (and in the future, NumPy).
6 |
7 | When an operator is very similar across the three language levels (user, core, and primitive), like cos, only one variation is typically tested (the core language variant is preferred).
8 | However, there are also cases where testing a torch language operation or primitive operation directly makes sense.
9 |
10 | Operator tests are autogenerated in `test_ops.py`. To run the tests for a particular operator, use pytest’s -k option.
11 | For example, to run the tests for `cos`, for example, the command would be
12 |
13 | ```bash
14 | pytest test_ops.py -k cos -v
15 | ```
16 |
17 | This will run tests for Thunder’s different executors, supported dtypes, and supported devicetypes.
18 |
19 | The OpInfo tests can also be run from Python, and this can be extremely useful for debugging.
20 | To do that, find the generated test name you’d like to run (often by invoking the tests from the command line and observing their names) and call it. Here’s a sample program:
21 |
22 | ```py
23 | import traceback
24 | import thunder.tests.test_ops as to
25 |
26 | e, exc_info, snippet, opinfo, devicetype, dtype, args, kwargs = to.test_core_vs_torch_consistency_cos_nvFuser_CUDA_float32()
27 |
28 | traceback.print_exception(*exc_info)
29 | ```
30 |
31 | If the test fails, it will return information about the failure, including error information and the arguments that caused the failure.
32 | In the above sample the traceback information is printed. If the test succeeds then it will return nothing.
33 |
34 | It can be a little tricky to remember all the components a test returns, but you can
35 | always catch and print the return value to better understand what's available.
36 |
--------------------------------------------------------------------------------
/thunder/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/tests/__init__.py
--------------------------------------------------------------------------------
/thunder/tests/bf16.py:
--------------------------------------------------------------------------------
1 | import thunder.core
2 | import torch
3 |
4 |
5 | def device_supports_bf16(device: None | str | torch.device | thunder.core.devices.Device, /) -> bool:
6 | """Torch has its own `torch.cuda.is_bf16_supported()`, but the contract of
7 | this API changed in December 2023 to be "can bf16 be represented?", which
8 | is true even for devices that existed before bf16 was invented.
9 |
10 | The contract for this API is that the device implements bf16 operations."""
11 | if not torch.cuda.is_available():
12 | return False
13 |
14 | dev: torch.device = thunder.core.devices.to_torch_device(device)
15 |
16 | if dev.type != "cuda":
17 | return False
18 |
19 | cuda_major: int
20 | cuda_minor: int
21 | cuda_major, cuda_minor = torch.cuda.get_device_capability(dev)
22 | return cuda_major >= 8
23 |
--------------------------------------------------------------------------------
/thunder/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pytest_benchmark
3 | from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
4 |
5 |
6 | @pytest.hookimpl(hookwrapper=True)
7 | def pytest_benchmark_group_stats(config, benchmarks, group_by):
8 | """
9 | The function customize the behavior for ThunderCompilerGraphBenchmarking.
10 | The custom grouping function is only invoked when the `--benchmark-group-by`
11 | option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'.
12 | For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`.
13 |
14 | Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats
15 | """
16 | prefix = "graph-by-graph:"
17 | outcome = yield
18 | if group_by.startswith(prefix):
19 | group_by = group_by[len(prefix) :]
20 | for bench in benchmarks:
21 | if bench["params"] is None:
22 | bench["params"] = {}
23 | # The benchs with the same `params`` share the same dict
24 | # We need to create a deepcopy of the original dictionary to add parameters specific to each graph.
25 | else:
26 | bench["params"] = bench["params"].copy()
27 | if bench["param"] is None:
28 | bench["param"] = ""
29 |
30 | name = bench["name"]
31 | gid, module_name, ex = name.split("-")[-3:]
32 | # Add the "GraphID", "SplitModuleName","executor" as params in benchmark
33 | gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
34 | bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex})
35 | bench["param"] += f"-{gid}-{module_name}-{ex}"
36 |
37 | result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by)
38 | outcome.force_result(result)
39 |
40 |
41 | def pytest_collection_modifyitems(items):
42 | items.sort(key=lambda item: item.name)
43 |
--------------------------------------------------------------------------------
/thunder/tests/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/tests/distributed/__init__.py
--------------------------------------------------------------------------------
/thunder/tests/distributed/modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import ClassVar, TYPE_CHECKING
3 |
4 | import torch.nn as nn
5 |
6 | from thunder.core import utils
7 |
8 | if TYPE_CHECKING:
9 | import torch
10 |
11 |
12 | __all__ = [
13 | "ParallelMLP",
14 | ]
15 |
16 |
17 | class ParallelMLP(nn.Module):
18 | """Simplified version of Megatron/NeMo's ParallelMLP.
19 |
20 | Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61
21 | """
22 |
23 | COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",)
24 | ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",)
25 |
26 | SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh")
27 |
28 | def __init__(
29 | self,
30 | hidden_size: int,
31 | ffn_hidden_size: int | None = None,
32 | bias: bool = True,
33 | gelu_approximate: str = "none",
34 | ) -> None:
35 | utils.check(
36 | gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX,
37 | lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}",
38 | )
39 | if ffn_hidden_size is None:
40 | ffn_hidden_size = 4 * hidden_size
41 |
42 | super().__init__()
43 | self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias)
44 | self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
45 | self.gelu = nn.GELU(approximate=gelu_approximate)
46 |
47 | def forward(self, x: torch.Tensor) -> torch.Tensor:
48 | four_h = self.gelu(self.dense_h_to_4h(x))
49 | h = self.dense_4h_to_h(four_h)
50 | return h
51 |
--------------------------------------------------------------------------------
/thunder/tests/hf_bart_self_attn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This is an abbreviated version of the HuggingFace Bart Encoder Self-Attention
16 | # module. The block was editted to remove non-essential code for
17 | # self-attention.
18 |
19 | import torch
20 | from torch import nn
21 |
22 |
23 | class BartAttention(nn.Module):
24 | """Multi-headed attention from 'Attention Is All You Need' paper"""
25 |
26 | def __init__(
27 | self,
28 | embed_dim: int,
29 | num_heads: int,
30 | dropout: float = 0.0,
31 | bias: bool = True,
32 | ):
33 | super().__init__()
34 | self.embed_dim = embed_dim
35 | self.num_heads = num_heads
36 | self.dropout = dropout
37 | self.head_dim = embed_dim // num_heads
38 |
39 | if (self.head_dim * num_heads) != self.embed_dim:
40 | raise ValueError(
41 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
42 | f" and `num_heads`: {num_heads})."
43 | )
44 | self.scaling = self.head_dim**-0.5
45 |
46 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
47 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
48 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
49 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
50 |
51 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
52 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
53 |
54 | def forward(
55 | self,
56 | hidden_states: torch.Tensor,
57 | attention_mask: torch.Tensor,
58 | ) -> torch.Tensor:
59 | """Input shape: Batch x Time x Channel"""
60 | bsz, tgt_len, _ = hidden_states.size()
61 |
62 | # get query proj
63 | query_states = self.q_proj(hidden_states) * self.scaling
64 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
65 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
66 |
67 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
68 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
69 | key_states = key_states.view(*proj_shape)
70 | value_states = value_states.view(*proj_shape)
71 |
72 | src_len = key_states.size(1)
73 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
74 |
75 | if attention_mask is not None:
76 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
77 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
78 |
79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
80 |
81 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
82 |
83 | attn_output = torch.bmm(attn_probs, value_states)
84 |
85 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
86 | attn_output = attn_output.transpose(1, 2)
87 |
88 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
89 | # partitioned aross GPUs when using tensor-parallelism.
90 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
91 |
92 | attn_output = self.out_proj(attn_output)
93 |
94 | return attn_output
95 |
--------------------------------------------------------------------------------
/thunder/tests/litgpt_model.py:
--------------------------------------------------------------------------------
1 | """Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | configs = [
8 | # diverse sample of configs FOR TESTING that cover all major checkpoints variants architecturally but with reduced
9 | # size
10 | dict(name="gpt-neox-like", block_size=128, n_layer=2, n_embd=64, n_head=4, padding_multiple=8),
11 | dict(
12 | name="llama1-like",
13 | block_size=128,
14 | vocab_size=320,
15 | padding_multiple=64,
16 | n_layer=2,
17 | n_head=4,
18 | n_embd=64,
19 | rotary_percentage=1.0,
20 | parallel_residual=False,
21 | bias=False,
22 | norm_class_name="RMSNorm",
23 | norm_eps=1e-6,
24 | mlp_class_name="LLaMAMLP",
25 | intermediate_size=1376,
26 | ),
27 | dict(
28 | name="long-context-like",
29 | block_size=512,
30 | vocab_size=320,
31 | padding_multiple=64,
32 | n_layer=2,
33 | n_head=4,
34 | n_embd=64,
35 | rotary_percentage=1.0,
36 | parallel_residual=False,
37 | bias=False,
38 | norm_class_name="RMSNorm",
39 | mlp_class_name="LLaMAMLP",
40 | intermediate_size=11008,
41 | rope_condense_ratio=4,
42 | ),
43 | dict(
44 | name="llama2-like",
45 | vocab_size=320,
46 | padding_multiple=64,
47 | n_layer=2,
48 | n_head=4,
49 | n_embd=64,
50 | rotary_percentage=1.0,
51 | parallel_residual=False,
52 | bias=False,
53 | norm_class_name="RMSNorm",
54 | mlp_class_name="LLaMAMLP",
55 | intermediate_size=1376,
56 | ),
57 | dict(
58 | name="falcon-7b-like",
59 | block_size=128,
60 | padded_vocab_size=254,
61 | n_layer=2,
62 | n_head=7,
63 | n_embd=448,
64 | rotary_percentage=1.0,
65 | n_query_groups=1,
66 | bias=False,
67 | shared_attention_norm=True,
68 | ),
69 | dict(
70 | name="falcon-40b-like",
71 | block_size=128,
72 | padded_vocab_size=508,
73 | n_layer=2,
74 | n_head=64,
75 | n_embd=256,
76 | rotary_percentage=1.0,
77 | n_query_groups=4,
78 | bias=False,
79 | ),
80 | dict(
81 | name="codellama2-like",
82 | block_size=1024,
83 | vocab_size=2001,
84 | padding_multiple=16,
85 | n_layer=2,
86 | n_head=4,
87 | n_embd=64,
88 | rotary_percentage=1.0,
89 | parallel_residual=False,
90 | bias=False,
91 | norm_class_name="RMSNorm",
92 | norm_eps=1e-05,
93 | mlp_class_name="LLaMAMLP",
94 | intermediate_size=1376,
95 | rope_base=1000000,
96 | ),
97 | dict(
98 | name="mixtral-like",
99 | block_size=512,
100 | padded_vocab_size=500,
101 | n_layer=2,
102 | n_head=64,
103 | n_embd=256,
104 | rotary_percentage=1.0,
105 | n_query_groups=8,
106 | parallel_residual=False,
107 | bias=False,
108 | norm_class_name="RMSNorm",
109 | norm_eps=1e-05,
110 | mlp_class_name="LLaMAMoE",
111 | intermediate_size=224,
112 | rope_base=1000000,
113 | n_expert=8,
114 | n_expert_per_token=2,
115 | ),
116 | ]
117 |
118 | name_to_config = {config["name"]: config for config in configs}
119 |
120 |
121 | import litgpt
122 |
123 | # add the testing configurations
124 | litgpt.config.name_to_config.update(name_to_config)
125 | name_to_config.update(litgpt.config.name_to_config)
126 |
127 | # manually expose for backwards compatibility
128 | Config = litgpt.Config
129 | GPT = litgpt.GPT
130 | RMSNorm = litgpt.model.RMSNorm
131 | CausalSelfAttention = litgpt.model.CausalSelfAttention
132 | LLaMAMLP = litgpt.model.LLaMAMLP
133 | build_rope_cache = litgpt.model.build_rope_cache
134 | apply_rope = litgpt.model.apply_rope
135 | Block = litgpt.model.Block
136 |
--------------------------------------------------------------------------------
/thunder/tests/make_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import thunder
3 |
4 |
5 | def make_tensor(
6 | *shape: int | torch.Size | list[int] | tuple[int, ...],
7 | dtype: torch.dtype,
8 | device: str | torch.device | thunder.devices.Device,
9 | low: float | None = None,
10 | high: float | None = None,
11 | requires_grad: bool = False,
12 | noncontiguous: bool = False,
13 | exclude_zero: bool = False,
14 | memory_format: torch.memory_format | None = None,
15 | ) -> torch.Tensor:
16 | r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
17 | values uniformly drawn from ``[low, high)``.
18 | Calls torch.testing.make_tensor and optionally torch.Tensor.fill_ to allow for low == high, as
19 | torch.testing.make_tensor enforces low < high.
20 |
21 | Args:
22 | shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
23 | dtype (:class:`torch.dtype`): The data type of the returned tensor.
24 | device (Union[str, torch.device, thunder.devices.Device]): The device of the returned tensor.
25 | low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
26 | clamped to the least representable finite value of the given dtype. When ``None`` (default),
27 | this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
28 | high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is
29 | clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value
30 | is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
31 | requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
32 | noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
33 | ignored if the constructed tensor has fewer than two elements.
34 | exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
35 | depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
36 | point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
37 | :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
38 | whose real and imaginary parts are both the smallest positive normal number representable by the complex
39 | type. Default ``False``.
40 | memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Incompatible
41 | with :attr:`noncontiguous`.
42 | Raises:
43 | ValueError: if ``requires_grad=True`` is passed for integral `dtype`
44 | ValueError: If ``low > high``.
45 | ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
46 | TypeError: If :attr:`dtype` isn't supported by this function.
47 | """
48 | if isinstance(device, thunder.devices.Device):
49 | device = device.device_str()
50 |
51 | fill_value = None
52 | if low is not None and low == high:
53 | fill_value = low
54 | low = None
55 | high = None
56 |
57 | t = torch.testing.make_tensor(
58 | *shape,
59 | dtype=dtype,
60 | device=device,
61 | low=low,
62 | high=high,
63 | requires_grad=requires_grad,
64 | noncontiguous=noncontiguous,
65 | exclude_zero=exclude_zero,
66 | memory_format=memory_format,
67 | )
68 |
69 | if fill_value is not None:
70 | t.fill_(fill_value)
71 |
72 | return t
73 |
74 |
75 | def make_tensor_like(a, **kwargs):
76 | # type: (torch.Tensor) -> torch.Tensor
77 | """Returns a tensor with the same properties as the given tensor.
78 |
79 | Args:
80 | a (torch.Tensor): The tensor to copy properties from.
81 | kwargs (dict): Additional properties for `make_tensor`.
82 |
83 | Returns:
84 | torch.Tensor: A tensor with the same properties as :attr:`a`.
85 | """
86 | kwargs = kwargs | dict(device=a.device, dtype=a.dtype, requires_grad=a.requires_grad)
87 | return make_tensor(a.shape, **kwargs)
88 |
--------------------------------------------------------------------------------
/thunder/tests/module_example.py:
--------------------------------------------------------------------------------
1 | def returns_two():
2 | return 2
3 |
4 |
5 | def _returns_three():
6 | return 3
7 |
8 |
9 | def returns_five():
10 | return 5
11 |
--------------------------------------------------------------------------------
/thunder/tests/test_apex_fused_norms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.testing import assert_close
4 |
5 | fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda")
6 |
7 | from torch.distributed import is_available
8 | from thunder.executors.apexex import apex_ex
9 | import thunder
10 |
11 |
12 | # See https://github.com/NVIDIA/apex/issues/1853
13 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
14 | @pytest.mark.parametrize("requires_grad", [True, False])
15 | @pytest.mark.parametrize("memory_efficient", [True, False])
16 | def test_apex_fused_rms_norm(requires_grad, memory_efficient):
17 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
18 |
19 | def fn(x, weight, normalized_shape, eps):
20 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)
21 |
22 | device = "cuda"
23 | normalized_shape = (3, 2)
24 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad)
25 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad)
26 | eps = 1e-5
27 |
28 | expected = fn(x, weight, normalized_shape, eps)
29 | jfn = thunder.jit(fn, executors=[apex_ex])
30 | actual = jfn(x, weight, normalized_shape, eps)
31 |
32 | assert_close(actual, expected)
33 |
34 | if requires_grad:
35 | grad_output = torch.rand_like(actual)
36 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output)
37 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output)
38 |
39 | assert_close(actual_grad, expected_grad)
40 |
41 |
42 | # See https://github.com/NVIDIA/apex/issues/1853
43 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
44 | @pytest.mark.parametrize("requires_grad", [True, False])
45 | @pytest.mark.parametrize("memory_efficient", [True, False])
46 | def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient):
47 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
48 |
49 | def fn(x, weight, normalized_shape, eps):
50 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)
51 |
52 | device = "cuda"
53 | normalized_shape = (3, 2)
54 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad)
55 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad)
56 | eps = 1e-5
57 |
58 | expected = fn(x, weight, normalized_shape, eps)
59 | jfn = thunder.jit(fn, executors=())
60 | actual = jfn(x, weight, normalized_shape, eps)
61 |
62 | assert_close(actual, expected)
63 |
64 | if requires_grad:
65 | grad_output = torch.rand_like(actual)
66 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output)
67 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output)
68 |
69 | assert_close(actual_grad, expected_grad)
70 |
--------------------------------------------------------------------------------
/thunder/tests/test_check_trace.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.dev_utils.check_trace import check_trace
3 | import torch
4 | import pytest
5 |
6 |
7 | def test_missing_symbol():
8 | def fn(a, b):
9 | c = a + b
10 | d = 2 * c
11 | return d
12 |
13 | jfn = thunder.jit(fn)
14 |
15 | a = torch.randn(2, 2)
16 | b = torch.randn(2, 2)
17 |
18 | jfn(a, b)
19 |
20 | tr = thunder.last_traces(jfn)[-1]
21 | check_trace(tr)
22 |
23 | del tr.bound_symbols[-3]
24 |
25 | with pytest.raises(AssertionError, match="unknown proxy"):
26 | check_trace(tr)
27 |
28 |
29 | def test_debug_option():
30 | class BrokenTransform(thunder.core.transform_common.Transform):
31 | def transform_traces_pre_prologue(self, pro, comp, epi, **kwargs):
32 | new_comp = thunder.core.trace.from_trace(comp)
33 | new_comp.bound_symbols = comp.bound_symbols[:]
34 | del new_comp.bound_symbols[2]
35 | return pro, new_comp, epi
36 |
37 | def fn(a, b):
38 | c = a + b
39 | d = 2 * c
40 | return d
41 |
42 | a = torch.randn(2, 2)
43 | b = torch.randn(2, 2)
44 |
45 | # works
46 | jfn = thunder.jit(fn, debug_options=thunder.DebugOptions(check_traces=True))
47 | jfn(a, b)
48 |
49 | # broken with nice error
50 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), debug_options=thunder.DebugOptions(check_traces=True))
51 |
52 | with pytest.raises(AssertionError, match="unknown proxy"):
53 | jfn(a, b)
54 |
55 | # broken with less nice error
56 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), executors=())
57 | with pytest.raises(UnboundLocalError, match="cannot access local|referenced before assignment"):
58 | jfn(a, b)
59 |
--------------------------------------------------------------------------------
/thunder/tests/test_examine.py:
--------------------------------------------------------------------------------
1 | import thunder.examine
2 | import torch
3 | import pytest
4 | from thunder.executors import nvfuser_available
5 |
6 |
7 | def test_examine_fn():
8 | def foo(x):
9 | x[0] = 5 * x[1]
10 |
11 | x = torch.ones(2, 2)
12 | thunder.examine.examine(foo, x)
13 |
14 |
15 | def test_examine_jfn():
16 | def foo(x):
17 | x[0] = 5 * x[1]
18 |
19 | jfoo = thunder.jit(foo)
20 | x = torch.ones(2, 2)
21 | thunder.examine.examine(jfoo, x)
22 |
23 |
24 | def test_examine_noncallable(capsys):
25 | x = torch.ones(2, 2)
26 | y = torch.ones(2, 2)
27 | thunder.examine.examine(x, y)
28 | captured = capsys.readouterr()
29 | assert "expected `fn` to be a callable" in captured.out
30 |
--------------------------------------------------------------------------------
/thunder/tests/test_examine_memory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 |
5 | import thunder
6 | from thunder.core.pytree import tree_map
7 | import thunder.torch as ltorch
8 | from thunder.examine.memory_calculation import get_alloc_memory
9 |
10 | from thunder.tests.framework import requiresCUDA, TorchExecutor
11 | from thunder.tests.make_tensor import make_tensor
12 |
13 |
14 | def measure_memory_usage(trace):
15 | torch.cuda.reset_peak_memory_stats()
16 | before = torch.cuda.memory_stats().get("requested_bytes.all.current", 0)
17 |
18 | def make_tensor_like_torch_dtype(p):
19 | return make_tensor(p.shape, dtype=ltorch.to_torch_dtype(p.dtype), device=p.device)
20 |
21 | args, kwargs = tree_map(make_tensor_like_torch_dtype, (trace.args, trace.kwargs))
22 | output = trace.python_callable()(*args, **kwargs)
23 |
24 | after = torch.cuda.memory_stats()["requested_bytes.all.current"]
25 | peak = torch.cuda.memory_stats()["requested_bytes.all.peak"]
26 |
27 | return {"peak": peak - before, "current": after - before, "output": output}
28 |
29 |
30 | def measure_fw_and_bw_memory_usage(fw_trace, bw_trace):
31 | fw_results = measure_memory_usage(fw_trace)
32 | bw_results = measure_memory_usage(bw_trace)
33 |
34 | return {f"fw_{k}": v for k, v in fw_results.items()} | {f"bw_{k}": v for k, v in bw_results.items()}
35 |
36 |
37 | # TODO: Test for nvFuserExecutor
38 | # nvFuserExecutor is skipped for now, because nvFuser and eager execution treat allocation and broadcast differently.
39 | # In the future, we need to update get_alloc_memory to support nvFuser and update tests accordingly.
40 | @requiresCUDA
41 | def test_view_ops():
42 | def test(func, *shapes):
43 | inputs = [make_tensor(shape, dtype=torch.float32, device="cuda", requires_grad=True) for shape in shapes]
44 | cfunc = TorchExecutor.make_callable(func, disable_preprocessing=False)
45 | cfunc(*inputs)
46 |
47 | fw_trace = thunder.last_traces(cfunc)[-1]
48 | bw_trace = thunder.last_backward_traces(cfunc)[-1]
49 | max_mem_fw = get_alloc_memory(fw_trace)
50 | max_mem_bw = get_alloc_memory(bw_trace)
51 |
52 | result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace)
53 | assert max_mem_fw[0] == result["fw_peak"]
54 | assert sum(max_mem_fw[1].values()) == result["fw_current"]
55 | assert max_mem_bw[0] == result["bw_peak"]
56 | assert sum(max_mem_bw[1].values()) == result["bw_current"]
57 |
58 | def foo(a, b): # [4] [4]
59 | a_1 = torch.unsqueeze(a, 0) # [1,4]
60 | b_2 = torch.unsqueeze(b, 0) # [1,4]
61 | return (a_1 + b_2,)
62 |
63 | test(foo, (4,), (4,))
64 |
65 | def bar(a, b): # [4] [2,2]
66 | a_1 = torch.unsqueeze(a, 0) # [1,4]
67 | a_2 = torch.unsqueeze(a_1, 1) # [1,1,4]
68 | a_3 = a_2.expand(2, 3, 4) # [2,3,4]
69 |
70 | b_1 = torch.reshape(b, (4,)) # [4]
71 | b_2 = torch.unsqueeze(b_1, 0) # [1,4]
72 | b_3 = torch.unsqueeze(b_2, 1) # [1,1,4]
73 | b_4 = b_3.expand(2, 3, 4) # [2,3,4]
74 |
75 | result1 = a_2 + b_3
76 | result2 = b_4 + a_3
77 | return result1, result2
78 |
79 | test(bar, (4,), (2, 2))
80 |
81 | def bar1(a, b, c): # [4], [1,4,4], [4,1,4]
82 | a_1 = torch.unsqueeze(a, 0) # [1,4]
83 | a_2 = torch.unsqueeze(a_1, 1) # [1,1,4]
84 | a_3 = a_2.expand(1, 4, 4)
85 | a_4 = a_2.expand(4, 1, 4)
86 | return b + a_3, c + a_4
87 |
88 | test(bar1, (4,), (1, 4, 4), (4, 1, 4))
89 |
90 | def bar2(a, b): # [5,2], [2,2]
91 | a_1, a_2, a_3 = torch.split(a, 2)
92 | c = a_1 + b
93 | d = a + a
94 | return c, d, a_2, a_3 # We have to use all the outputs of torch.split due to #1043
95 |
96 | test(bar2, (5, 2), (2, 2))
97 |
98 |
99 | @requiresCUDA
100 | def test_nanogpt_block():
101 | import thunder.tests.nanogpt_model as nanogpt_model
102 |
103 | config = nanogpt_model.GPTConfig(dropout=0)
104 | block = nanogpt_model.Block(config).to(dtype=torch.float32, device="cuda")
105 | cblock = TorchExecutor.make_callable(block)
106 | inp = make_tensor((2, config.block_size, config.n_embd), dtype=torch.float32, device="cuda", requires_grad=True)
107 | cblock(inp)
108 |
109 | fw_trace = thunder.last_traces(cblock)[-1]
110 | bw_trace = thunder.last_backward_traces(cblock)[-1]
111 | max_mem_fw = get_alloc_memory(fw_trace)
112 | max_mem_bw = get_alloc_memory(bw_trace)
113 |
114 | # Actual memory usage may vary depending on hardware and cuBLAS settings.
115 | # We are checking the estimated memory against a fixed value for consistency.
116 | assert max_mem_fw[0] == 381754368
117 | assert sum(max_mem_fw[1].values()) == 375462912
118 | assert max_mem_bw[0] == 641097728
119 | assert sum(max_mem_bw[1].values()) == 440474624
120 |
--------------------------------------------------------------------------------
/thunder/tests/test_pythonex.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import math
3 |
4 | import thunder
5 |
6 |
7 | def _run_cache_symbolic_values(fn, ref_fn, *args):
8 | jit_fn = thunder.jit(fn, cache="symbolic values")
9 | out = jit_fn(*args)
10 |
11 | out_ref = ref_fn(*args)
12 | assert out == out_ref
13 |
14 |
15 | def test_fmod():
16 | def foo(a, b):
17 | return a % b
18 |
19 | _run_cache_symbolic_values(foo, foo, 2.0, 1.3)
20 |
21 |
22 | def test_bitwise_or():
23 | def foo(a, b):
24 | return a | b
25 |
26 | _run_cache_symbolic_values(foo, foo, 3, 5)
27 |
28 |
29 | def test_bitwise_and():
30 | def foo(a, b):
31 | return a & b
32 |
33 | _run_cache_symbolic_values(foo, foo, 3, 5)
34 |
35 |
36 | def test_bitwise_xor():
37 | def foo(a, b):
38 | return a ^ b
39 |
40 | _run_cache_symbolic_values(foo, foo, 3, 5)
41 |
42 |
43 | def test_math_atan2():
44 | def foo(a, b):
45 | # TODO: calling through math.atan2 bakes in constant, this needs to be investigated.
46 | return thunder.clang.atan2(a, b)
47 |
48 | # NOTE: we have thunder.clang in foo, which cannot be run with non-proxy
49 | _run_cache_symbolic_values(foo, math.atan2, 2.0, 1.3)
50 |
51 |
52 | def test_math_fmod():
53 | def foo(a, b):
54 | return thunder.clang.fmod(a, b)
55 |
56 | _run_cache_symbolic_values(foo, math.fmod, 2.0, 1.3)
57 |
--------------------------------------------------------------------------------
/thunder/tests/test_reductions.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torch.testing import assert_close, make_tensor
5 |
6 | import thunder
7 | import thunder.torch as ttorch
8 | from thunder.tests.framework import instantiate
9 |
10 | import pytest
11 |
12 |
13 | # TODO: convert these tests to OpInfo generated tests
14 |
15 |
16 | @instantiate(dtypes=(thunder.float32,))
17 | def test_torch_var(executor, device, dtype):
18 | # Tests passing all arguments as function inputs
19 | def foo(a, dim, *, keepdim=False, correction=1):
20 | return ttorch.var(a, dim, keepdim=keepdim, correction=correction)
21 |
22 | traced_foo = executor.make_callable(foo)
23 |
24 | tdtype = ttorch.to_torch_dtype(dtype)
25 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
26 |
27 | # Full reduction
28 | thunder_result = traced_foo(a, [0, 1])
29 | torch_result = torch.var(a, [0, 1])
30 | assert_close(thunder_result, torch_result)
31 |
32 | # Reduce along dim 1
33 | thunder_result = traced_foo(a, [1])
34 | torch_result = torch.var(a, [1])
35 | assert_close(thunder_result, torch_result)
36 |
37 | # Specifying the correction
38 | thunder_result = traced_foo(a, [1], correction=2)
39 | torch_result = torch.var(a, [1], correction=2)
40 | assert_close(thunder_result, torch_result)
41 |
42 | # Specifying keepdim
43 | thunder_result = traced_foo(a, [1], keepdim=True, correction=2)
44 | torch_result = torch.var(a, [1], keepdim=True, correction=2)
45 | assert_close(thunder_result, torch_result)
46 |
47 | # Tests passing arguments as constants
48 | def foo(a):
49 | return ttorch.var(a, [0, 1], keepdim=True, correction=2)
50 |
51 | traced_foo = executor.make_callable(foo)
52 |
53 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
54 |
55 | thunder_result = traced_foo(a)
56 | torch_result = torch.var(a, [0, 1], keepdim=True, correction=2)
57 | assert_close(thunder_result, torch_result)
58 |
59 |
60 | @instantiate(dtypes=(thunder.float32,))
61 | def test_torch_mean(executor, device, dtype):
62 | def foo(a, dim=None, keepdim=False, *, dtype=None):
63 | return ttorch.mean(a, dim, keepdim, dtype=dtype)
64 |
65 | traced_foo = executor.make_callable(foo)
66 |
67 | tdtype = ttorch.to_torch_dtype(dtype)
68 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
69 |
70 | # Full reduction
71 | thunder_result = traced_foo(a, [0, 1])
72 | torch_result = torch.mean(a, [0, 1])
73 | assert_close(thunder_result, torch_result)
74 |
75 | # Reduce along dim 1
76 | thunder_result = traced_foo(a, [1])
77 | torch_result = torch.mean(a, [1])
78 | assert_close(thunder_result, torch_result)
79 |
80 | # Reduce with () dims
81 | thunder_result = traced_foo(a, ())
82 | torch_result = torch.mean(a, ())
83 | assert_close(thunder_result, torch_result)
84 |
85 |
86 | @instantiate(dtypes=(thunder.float32,))
87 | def test_var_mean(executor, device, dtype):
88 | def foo(a, dim=None, keepdim=False, *, correction=1):
89 | return ttorch.var_mean(a, dim, keepdim=keepdim, correction=correction)
90 |
91 | traced_foo = executor.make_callable(foo)
92 |
93 | tdtype = ttorch.to_torch_dtype(dtype)
94 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
95 |
96 | # Full reduction
97 | thunder_result = traced_foo(a, [0, 1])
98 | torch_result = torch.var_mean(a, [0, 1])
99 | assert_close(thunder_result, torch_result)
100 |
101 | # Reduce along dim 1
102 | thunder_result = traced_foo(a, [1])
103 | torch_result = torch.var_mean(a, [1])
104 | assert_close(thunder_result, torch_result)
105 |
106 | # Tests passing arguments as constants
107 | def foo(a):
108 | return ttorch.var_mean(a, [0, 1], keepdim=True, correction=2)
109 |
110 | traced_foo = executor.make_callable(foo)
111 |
112 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype)
113 |
114 | thunder_result = traced_foo(a)
115 | torch_result = torch.var_mean(a, [0, 1], keepdim=True, correction=2)
116 | assert_close(thunder_result, torch_result)
117 |
--------------------------------------------------------------------------------
/thunder/tests/test_shape_ops.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | import torch
3 | from thunder.tests.framework import JAX_AVAILABLE
4 |
5 | if JAX_AVAILABLE:
6 | pass
7 |
8 |
9 | def test_pad_cast_value_itof():
10 | """
11 | Pad should cast the given value to the type of tensor and pad that value.
12 | """
13 |
14 | def fqn():
15 | x = torch.tensor([2, 3], dtype=torch.int32)
16 | y = torch.nn.functional.pad(x, pad=(1, 2), value=6.4)
17 | return y
18 |
19 | th_fqn = thunder.jit(fqn)
20 | v = th_fqn()
21 | assert v[0] == 6
22 | assert v[1] == 2
23 | assert v[2] == 3
24 | assert v[3] == 6
25 | assert v[4] == 6
26 |
27 |
28 | def test_pad_cast_value_ftoi():
29 | """
30 | Pad should cast the given value to the type of tensor and pad that value.
31 | """
32 |
33 | def fqn():
34 | x = torch.tensor([2.4, 3.8])
35 | y = torch.nn.functional.pad(x, pad=(1, 2), value=1)
36 | return y
37 |
38 | th_fqn = thunder.jit(fqn)
39 | v = th_fqn()
40 | assert v[0] == 1.0
41 | assert v[1] == 2.4
42 | assert v[2] == 3.8
43 | assert v[3] == 1.0
44 | assert v[4] == 1.0
45 |
--------------------------------------------------------------------------------
/thunder/tests/test_triton_ce.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from numbers import Number
3 |
4 | import pytest
5 |
6 | import torch
7 | from torch.testing import assert_close
8 |
9 | import thunder
10 |
11 |
12 | from thunder.tests.opinfos import get_opinfo
13 | from thunder.tests.framework import instantiate, requiresCUDA, requiresTriton, run_snippet, IN_CI
14 | from thunder.executors import triton_utils
15 |
16 |
17 | from lightning_utilities.core.imports import package_available
18 |
19 | TRITON_AVAILABLE = package_available("triton")
20 |
21 | # Requires triton 2.1 or greater
22 | triton: None | Any = None
23 | min_triton_version = "2.1"
24 | if triton_utils.is_triton_version_at_least(min_triton_version):
25 | from thunder.executors.triton_crossentropy import triton_ex
26 |
27 |
28 | # NOTE This test modifies the global executor map, so it technically should not
29 | # be run in parallel with other tests
30 | @pytest.mark.parametrize(
31 | "dtype",
32 | [torch.float16, torch.bfloat16, torch.float32, torch.float64],
33 | ids=("float16", "bfloat16", "float32", "float64"),
34 | )
35 | @pytest.mark.parametrize("device,", ["cuda"])
36 | @requiresCUDA
37 | @requiresTriton
38 | def test_triton_cross_entropy(device, dtype):
39 | if IN_CI:
40 | pytest.skip("Currently these tests are skipped in CI for speed")
41 |
42 | logits = torch.randn([2048, 50257], device=device, dtype=dtype)
43 | labels = torch.randint(0, 50257, [2048], device=device)
44 | reduction = "sum"
45 | ignore_index = labels[5].item()
46 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False)
47 | expected = torch.nn.functional.cross_entropy(
48 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index
49 | )
50 |
51 | def test(logits, labels, weight, reduction, ignore_index):
52 | return thunder.torch.cross_entropy(
53 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index
54 | )
55 |
56 | ctest = thunder.jit(test, executors=[triton_ex])
57 | actual = ctest(logits, labels, weight, reduction, ignore_index)
58 | torch.testing.assert_close(actual, expected)
59 | last_trace = thunder.last_traces(ctest)[-1]
60 | assert any(bsym.sym.name == "triton_crossentropy" for bsym in last_trace.bound_symbols)
61 |
62 |
63 | def snippet_torch_consistency(op, torch_op, sample):
64 | thunder_result = op(*sample.args, **sample.kwargs)
65 | torch_result = torch_op(*sample.args, **sample.kwargs)
66 | torch.cuda.synchronize()
67 |
68 | # Sets atol and rtol to looser tolerances than assert_close's defaults
69 | atol: Number = 1e-1
70 | rtol: Number = 1.3e-6
71 | assert_close(thunder_result, torch_result, equal_nan=True, atol=atol, rtol=rtol)
72 |
73 |
74 | @pytest.mark.parametrize(
75 | "dtype",
76 | [torch.float16, torch.bfloat16, torch.float32, torch.float64],
77 | ids=("float16", "bfloat16", "float32", "float64"),
78 | )
79 | @pytest.mark.parametrize("device,", ["cuda"])
80 | @requiresCUDA
81 | @requiresTriton
82 | def test_triton_cross_entropy_vs_torch_consistency(device, dtype):
83 | if IN_CI:
84 | pytest.skip("Currently these tests are skipped in CI for speed")
85 | if dtype == torch.float16 or dtype == torch.bfloat16:
86 | pytest.skip("Currently skipping float16 and bfloat16 due to numerical accuracy")
87 |
88 | opinfo = get_opinfo("cross_entropy")
89 |
90 | def foo(*args, **kwargs):
91 | return torch.nn.functional.cross_entropy(*args, **kwargs)
92 |
93 | ce = thunder.jit(foo, executors=[triton_ex])
94 |
95 | # NOTE reference inputs take a long time to run in CI, so this uses sample inputs in CI
96 | # opinfo.reference_inputs if not IN_CI else opinfo.sample_inputs
97 | # reference inputs for cross_entropy contains cases not implemented in Thunder
98 | input_generator = opinfo.reference_inputs
99 |
100 | for sample in input_generator(device=device, dtype=dtype, requires_grad=False):
101 | result = run_snippet(
102 | snippet_torch_consistency,
103 | opinfo,
104 | device,
105 | dtype,
106 | ce,
107 | opinfo.torch_reference,
108 | sample,
109 | )
110 | if result is not None:
111 | return result
112 |
--------------------------------------------------------------------------------
/thunder/torch/langctx.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
4 | from thunder.core.pytree import tree_flatten
5 | from thunder.core.proxies import TensorProxy
6 |
7 | #
8 | # Creates and registers the torch language context
9 | #
10 | # NOTE That this is done separately from the definition of thunder.torch operations, because the
11 | # language context must be available before those operations are defined
12 |
13 | _method_name_to_fn_map: dict[str, Callable] = {}
14 | _property_name_to_fn_map: dict[str, Callable] = {}
15 |
16 |
17 | class TorchCtx(LanguageContext):
18 | def __init__(self):
19 | super().__init__("torch")
20 |
21 | def has_method(self, id: str) -> bool:
22 | return id in _method_name_to_fn_map or id in _property_name_to_fn_map
23 |
24 | def get_method(self, id: str, *args, **kwargs) -> Callable:
25 | # Note: concrete implmenetations should only raise AttributeError or
26 | # return None for "missing" methods as the proxies will
27 | # route __getattr__ to here and hasattr relies on __getattr__
28 | # throwing AttributeError (only) when the attribute does
29 | # not exist.
30 | inps, _ = tree_flatten((args, kwargs))
31 |
32 | has_tensor_input: bool = False
33 | for x in inps:
34 | if isinstance(x, TensorProxy):
35 | has_tensor_input = True
36 | break
37 | if has_tensor_input:
38 | method: None | Callable = _method_name_to_fn_map.get(id, None)
39 | prop: None | Callable = _property_name_to_fn_map.get(id, None)
40 | if method is None and prop is None:
41 | raise AttributeError(f"The {self.name} language context has no method or attribute {id}")
42 | if method:
43 | return method
44 | else:
45 | return prop(inps[0])
46 |
47 | # has_tensor_input is False
48 | # Defers to the CLANG language context when there are no tensor inputs=
49 | # (the clang language context handles operations on NumberProxies and Numbers)
50 | primsctx: LanguageContext = resolve_language(Languages.CLANG)
51 | if not primsctx.has_method(id):
52 | raise AttributeError(
53 | f"Attempting to call method {id} in the torch language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method"
54 | )
55 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs)
56 | return prim_method
57 |
58 |
59 | torchctx = TorchCtx()
60 | register_langctx(Languages.TORCH, torchctx)
61 | register_langctx("torch", torchctx)
62 |
63 |
64 | # Registers a method with the torch language context
65 | def register_method(method_name: str, method: Callable, /) -> None:
66 | _method_name_to_fn_map[method_name] = method
67 |
68 |
69 | def register_property(property_name, property) -> None:
70 | _property_name_to_fn_map[property_name] = property
71 |
--------------------------------------------------------------------------------
/thunder/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .constant_folding import ConstantFolding
2 | from .materialization import MaterializationTransform
3 | from .qlora import LORATransform
4 | from .prune_prologue_checks import PrunePrologueChecks
5 | from .extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
6 |
7 |
8 | __all__ = [
9 | "ConstantFolding",
10 | "LORATransform",
11 | "MaterializationTransform",
12 | "PrunePrologueChecks",
13 | "ExtractionOnlyPrologueTransform",
14 | ]
15 |
--------------------------------------------------------------------------------
/thunder/transforms/extraction_only_prologue_transform.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.core.trace import from_trace
3 | from thunder.core.proxies import ProxyTag
4 |
5 |
6 | class ExtractionOnlyPrologueTransform(thunder.Transform):
7 | def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
8 | new_prologue_trace = from_trace(prologue_trace)
9 | new_bsyms = []
10 |
11 | for bsym in prologue_trace.bound_symbols:
12 | # NOTE - We assume TensorProxy's tagged with `STATIC_MEMORY_LOCATION` to
13 | # be Parameters or Buffer. It should be safe to disable check for
14 | # tensors we deem to be static.
15 | if (
16 | bsym.sym.id == thunder.prims.PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA
17 | and ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags
18 | ):
19 | continue
20 |
21 | new_bsyms.append(bsym)
22 |
23 | new_prologue_trace.bound_symbols = new_bsyms
24 |
25 | new_prologue_trace.set_provenance("Extraction only prologue pass")
26 | return new_prologue_trace, computation_trace, epilogue_trace
27 |
--------------------------------------------------------------------------------
/thunder/transforms/prune_prologue_checks.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.core import prims as prims
3 |
4 |
5 | class PrunePrologueChecks(thunder.core.transform_common.Transform):
6 | """A Transform to prune Prologue checks
7 |
8 | This transform removes prologue checks and can be applied when the user
9 | controls the environment enough to ensure that these checks would always
10 | succeed. By default, we remove all checks that use the module.
11 | """
12 |
13 | def __init__(self, prune_all_checks=False):
14 | self.prune_all_checks = prune_all_checks
15 |
16 | def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):
17 | def is_check(bsym):
18 | return bsym.sym in {
19 | prims.check_tensor_shape_and_metadata,
20 | prims.check_string_value,
21 | prims.check_number_type_and_value,
22 | prims.check_len,
23 | prims.check_literal_like,
24 | }
25 |
26 | if not self.prune_all_checks:
27 | bsyms_to_skip = set()
28 | module_member_names = set()
29 |
30 | for bsym in prologue_trace.bound_symbols:
31 | if bsym.sym in {prims.unpack_trivial, prims.unpack_cache_info, prims.python_return}:
32 | # These don't have inputs but need to be skipped to not trigger false positives
33 | # python_return may have no inputs
34 | continue
35 | if bsym.sym is prims.unpack_function_obj:
36 | for o in bsym.flat_proxy_outs:
37 | module_member_names.add(o.name)
38 | continue
39 |
40 | input_names = {i.name in module_member_names for i in bsym.flat_proxy_args}
41 | if all(input_names):
42 | # This has the special case of no proxy inputs, which is the case for unpack_function_obj,
43 | # the root of module_member_names
44 | assert input_names, f"unexpected symbol {bsym.sym.name} without inputs"
45 | for o in bsym.flat_proxy_outs:
46 | module_member_names.add(o.name)
47 | if is_check(bsym):
48 | bsyms_to_skip.add(bsym)
49 |
50 | def should_skip_bsym(bsym):
51 | return bsym in bsyms_to_skip
52 |
53 | else:
54 | should_skip_bsym = is_check
55 |
56 | new_prologue_trace = thunder.core.trace.from_trace(prologue_trace)
57 | for bsym in prologue_trace.bound_symbols:
58 | if not should_skip_bsym(bsym):
59 | new_prologue_trace.bound_symbols.append(bsym.from_bsym())
60 |
61 | new_prologue_trace = thunder.core.transform_common.dce(new_prologue_trace)
62 | new_prologue_trace.set_provenance(thunder.core.trace.TraceProvenance(f"{self.__class__.__name__}"))
63 |
64 | return new_prologue_trace, compute_trace, epilogue_trace
65 |
--------------------------------------------------------------------------------
/thunder/transforms/utils.py:
--------------------------------------------------------------------------------
1 | import thunder
2 | from thunder.core import utils
3 | from thunder.core import prims
4 | from thunder.core.trace import TraceCtx
5 | from thunder.core.pytree import tree_map
6 |
7 |
8 | def get_checks(prologue_trace):
9 | # returns a dictionary mapping model param names to (check bsym, get param bsym
10 | check_dict = {}
11 | prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace)
12 | for bsym in prologue_trace.bound_symbols:
13 | if bsym.sym == prims.unpack_parameter or bsym.sym == prims.unpack_buffer:
14 | param_thunder_module, param_name = bsym.args
15 | checks = [
16 | bsym2 for bsym2 in prologue_consumers[bsym.output] if bsym2.sym == prims.check_tensor_shape_and_metadata
17 | ]
18 | assert (
19 | len(checks) == 1
20 | ), f"expected each parameter and buffer to have exactly one checker, but {bsym.output} has {len(checks)}"
21 | assert isinstance(param_name, str)
22 | check_dict[param_name] = (checks[0], bsym)
23 | return check_dict
24 |
25 |
26 | def add_trace_output(trace, output, subindex: int | None = None):
27 | ret_node = trace.bound_symbols[-1]
28 | assert ret_node.sym == prims.python_return
29 | assert len(ret_node.args) == 1
30 | (ret_args,) = ret_node.args
31 |
32 | if subindex is None:
33 | ret_args = (*ret_args, output)
34 | else:
35 | assert isinstance(ret_args[subindex], tuple)
36 | ret_args = (*ret_args[:subindex], (*ret_args[subindex], output), *ret_args[subindex + 1 :])
37 |
38 | ret_node.args = (ret_args,)
39 |
40 |
41 | def trace_with_replaced_proxy_metadata(trace: TraceCtx, proxy_replacement_metadata) -> TraceCtx:
42 | t = TraceCtx(trace.fn)
43 |
44 | proxymap: dict[str, thunder.Proxy] = {}
45 |
46 | def map_proxy(p):
47 | if isinstance(p, thunder.Proxy):
48 | return proxymap[p.name]
49 | return p
50 |
51 | def create_proxy(p):
52 | if isinstance(p, thunder.Proxy):
53 | if p.name in proxymap: # happens with subsymbols
54 | return proxymap[p.name]
55 | with thunder.core.trace.tracectx(t):
56 | np = p.replace(**proxy_replacement_metadata.get(p.name, {}))
57 | proxymap[p.name] = np
58 | return np
59 | return p
60 |
61 | def process_bound_symbols(src_bound_symbols, target_bound_symbols):
62 | for bsym in src_bound_symbols:
63 | new_args = tree_map(map_proxy, bsym.args)
64 | new_kwargs = tree_map(map_proxy, bsym.kwargs)
65 | new_output = tree_map(create_proxy, bsym.output)
66 | new_bsym = bsym.from_bsym(output=new_output, args=new_args, kwargs=new_kwargs, subsymbols=[])
67 | target_bound_symbols.append(new_bsym)
68 | if len(bsym.subsymbols) > 0:
69 | process_bound_symbols(bsym.subsymbols, new_bsym.subsymbols)
70 |
71 | process_bound_symbols(trace.bound_symbols, t.bound_symbols)
72 |
73 | t.args = tree_map(map_proxy, trace.args)
74 | t.kwargs = tree_map(map_proxy, trace.kwargs)
75 | t._siginfo = trace._siginfo
76 | return t
77 |
--------------------------------------------------------------------------------