├── .azure
│   ├── docker-build.yml
│   ├── gpu-coverage.yml
│   ├── gpu-tests.yml
│   └── notebook-runs.yml
├── .codecov.yml
├── .git-blame-ignore-revs
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── documentation.md
│   │   ├── feature_request.md
│   │   └── program_coverage.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   ├── labeling-config.yml
│   ├── lightning-probot.yml
│   ├── load-jit-coverage-report.py
│   ├── load-quickstart-report.py
│   ├── run-benchmark-as-lit-jobs.py
│   ├── run-jit-coverage-as-lit-studios.py
│   ├── run-quickstart-as-lit-jobs.py
│   └── workflows
│       ├── auto-cc.yml
│       ├── ci-benchmark.yml
│       ├── ci-checks.yml
│       ├── ci-jit-coverage.yml
│       ├── ci-quickstart.yml
│       ├── ci-testing.yml
│       ├── docs-build.yml
│       ├── label-conflicts.yml
│       ├── labeler.yml
│       ├── release-nightly.yml
│       ├── release-pypi.yml
│       └── stale-branches.yaml
├── .gitignore
├── .lightning
│   └── workflows
│       ├── all-tests.yaml
│       ├── notebooks.yaml
│       └── transformer-engine.yaml
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── dockers
│   ├── README.md
│   ├── ubuntu-cuda
│   │   └── Dockerfile
│   ├── with-apex
│   │   └── Dockerfile
│   └── with-dev
│       └── Dockerfile
├── docs
│   ├── .build_docs.sh
│   ├── .readthedocs.yaml
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── _static
│       │   ├── copybutton.js
│       │   └── images
│       │       ├── LightningThunderDarkModewByline.png
│       │       ├── LightningThunderLightModewByline.png
│       │       ├── how_it_works.png
│       │       ├── icon.svg
│       │       ├── lightning_thunder_lightmode_nobyline.png
│       │       ├── logo-large.svg
│       │       ├── logo-small.svg
│       │       ├── logo.png
│       │       ├── logo.svg
│       │       ├── normalized_training_throughput_zero2.png
│       │       ├── pretrain_perf.png
│       │       └── training_throughput_single.png
│       ├── _templates
│       │   └── theme_variables.jinja
│       ├── advanced
│       │   ├── contrib_thunder.rst
│       │   ├── extending.rst
│       │   └── inside_thunder.rst
│       ├── basic
│       │   ├── faq.rst
│       │   ├── inspecting_traces.rst
│       │   ├── mlp_mnist.rst
│       │   ├── overview.rst
│       │   └── sharp_edges.rst
│       ├── conf.py
│       ├── fundamentals
│       │   ├── examine.rst
│       │   ├── hello_world.rst
│       │   └── installation.rst
│       ├── index.rst
│       ├── intermediate
│       │   ├── additional_executors.rst
│       │   ├── benchmarking.rst
│       │   ├── ddp.rst
│       │   └── whats_next.rst
│       └── reference
│           ├── clang
│           │   └── index.rst
│           ├── common
│           │   └── index.rst
│           ├── core
│           │   ├── baseutils.rst
│           │   ├── codeutils.rst
│           │   ├── devices.rst
│           │   ├── dtypes.rst
│           │   ├── index.rst
│           │   ├── langctxs.rst
│           │   ├── prims.rst
│           │   ├── proxies.rst
│           │   ├── pytree.rst
│           │   ├── rematerialization.rst
│           │   ├── symbol.rst
│           │   ├── trace.rst
│           │   └── transforms.rst
│           ├── distributed
│           │   └── index.rst
│           ├── dynamo
│           │   └── index.rst
│           ├── examine
│           │   └── index.rst
│           ├── executors
│           │   ├── apexex.rst
│           │   ├── cudnnex.rst
│           │   ├── index.rst
│           │   ├── nvfuserex.rst
│           │   ├── passes.rst
│           │   ├── pythonex.rst
│           │   ├── torch_compile.rst
│           │   ├── torchex.rst
│           │   ├── triton_crossentropy.rst
│           │   └── utils.rst
│           ├── extend
│           │   └── index.rst
│           ├── plugins
│           │   └── index.rst
│           ├── recipes
│           │   └── index.rst
│           ├── thunder.rst
│           ├── torch
│           │   └── index.rst
│           └── transforms
│               └── index.rst
├── examples
│   ├── coverage
│   │   ├── all.txt
│   │   ├── jit_coverage_hf.py
│   │   └── requirements.txt
│   └── quickstart
│       ├── hf_bert.py
│       ├── hf_llm.py
│       ├── mlp.py
│       ├── requirements.txt
│       ├── vit_hf.py
│       └── vit_tv.py
├── notebooks
│   ├── .ignore.ci
│   ├── adding_custom_operator.ipynb
│   ├── adding_custom_operator_backward.ipynb
│   ├── dev_tutorials
│   │   ├── extend.ipynb
│   │   ├── fsdp_tutorial.ipynb
│   │   └── thunder-add-vjp-rule.md
│   ├── extend_thunder_with_cuda_python.ipynb
│   ├── hello_world_thunderfx.ipynb
│   ├── liger_kernel.ipynb
│   ├── thunder_trace_intro.ipynb
│   ├── writing_a_trace_transform_cpu_offloading.ipynb
│   └── zero_to_thunder.ipynb
├── pyproject.toml
├── requirements.txt
├── requirements
│   ├── base.txt
│   ├── coverage.txt
│   ├── devel.txt
│   ├── docs.txt
│   ├── notebooks.txt
│   └── test.txt
├── scripts
│   ├── bisect_nvfuser.py
│   ├── build_from_source.sh
│   ├── remove-torch-lines.sh
│   ├── sanity-check.sh
│   └── validate_build.py
├── setup.py
└── thunder
    ├── __about__.py
    ├── __init__.py
    ├── benchmarks
    │   ├── __init__.py
    │   ├── benchmark_hf.py
    │   ├── benchmark_litgpt.py
    │   ├── benchmark_peft.py
    │   ├── conftest.py
    │   ├── distributed.py
    │   ├── einsum.py
    │   ├── targets.py
    │   ├── test_benchmark_litgpt.py
    │   └── utils.py
    ├── clang
    │   ├── __init__.py
    │   ├── langctx.py
    │   └── utils.py
    ├── common.py
    ├── core
    │   ├── __init__.py
    │   ├── baseutils.py
    │   ├── codeutils.py
    │   ├── compile_data.py
    │   ├── devices.py
    │   ├── dtypes.py
    │   ├── interpreter.py
    │   ├── jit_ext.py
    │   ├── langctxs.py
    │   ├── module.py
    │   ├── options.py
    │   ├── patterns.py
    │   ├── prims.py
    │   ├── profile.py
    │   ├── proxies.py
    │   ├── pytree.py
    │   ├── recipe.py
    │   ├── rematerialization.py
    │   ├── symbol.py
    │   ├── trace.py
    │   ├── trace_interpreter.py
    │   ├── transform_common.py
    │   ├── transforms.py
    │   ├── update_aliases.py
    │   ├── utils.py
    │   └── vjp_utils.py
    ├── dev_utils
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── check_trace.py
    │   ├── debug_transform.py
    │   ├── nvtx_profile_transform.py
    │   └── utils.py
    ├── distributed
    │   ├── __init__.py
    │   ├── bucketing.py
    │   ├── checkpoint.py
    │   ├── prims.py
    │   ├── tensor_parallel
    │   │   ├── __init__.py
    │   │   ├── column_wise.py
    │   │   ├── common.py
    │   │   ├── optimize_comm.py
    │   │   └── row_wise.py
    │   ├── transforms
    │   │   ├── __init__.py
    │   │   ├── ddp.py
    │   │   ├── ddp_v2.py
    │   │   ├── fsdp.py
    │   │   └── fsdp_v2.py
    │   └── utils.py
    ├── dynamo
    │   ├── __init__.py
    │   ├── benchmark_utils.py
    │   ├── compiler.py
    │   ├── compiler_graph_benchmark.py
    │   ├── report.py
    │   ├── repro_script_template.py
    │   ├── splitter.py
    │   └── utils.py
    ├── examine
    │   ├── __init__.py
    │   └── memory_calculation.py
    ├── executors
    │   ├── __init__.py
    │   ├── apex_entropyex_impl.py
    │   ├── apex_fused_rms_norm_impl.py
    │   ├── apexex.py
    │   ├── cudnn_layernormex.py
    │   ├── cudnn_sdpa.py
    │   ├── cudnnex.py
    │   ├── custom_op_ex.py
    │   ├── data_dependent_partition.py
    │   ├── fa3ex.py
    │   ├── nvfuserex.py
    │   ├── nvfuserex_impl.py
    │   ├── passes.py
    │   ├── pythonex.py
    │   ├── sdpaex.py
    │   ├── torch_autograd.py
    │   ├── torch_compile.py
    │   ├── torchex.py
    │   ├── transformer_engine_v1ex.py
    │   ├── transformer_engineex.py
    │   ├── transformer_engineex_impl.py
    │   ├── triton_crossentropy.py
    │   ├── triton_crossentropy_impl.py
    │   ├── triton_utils.py
    │   └── utils.py
    ├── extend
    │   └── __init__.py
    ├── numpy
    │   ├── __init__.py
    │   └── langctx.py
    ├── plugins
    │   ├── __init__.py
    │   ├── distributed.py
    │   ├── fp8.py
    │   ├── quantization.py
    │   └── reduce_overhead.py
    ├── py.typed
    ├── recipes
    │   ├── __init__.py
    │   ├── base.py
    │   └── hf_transformers.py
    ├── tests
    │   ├── README.md
    │   ├── __init__.py
    │   ├── bf16.py
    │   ├── conftest.py
    │   ├── coverage_tests
    │   │   ├── __init__.py
    │   │   ├── test_coverage_hf_diffusers.py
    │   │   └── test_coverage_trace.py
    │   ├── distributed
    │   │   ├── __init__.py
    │   │   ├── helper.py
    │   │   ├── modules.py
    │   │   ├── test_checkpoint.py
    │   │   ├── test_ddp.py
    │   │   ├── test_dtensor.py
    │   │   ├── test_fsdp.py
    │   │   ├── test_ops.py
    │   │   └── test_tensor_parallel.py
    │   ├── framework.py
    │   ├── hf_bart_self_attn.py
    │   ├── litgpt_model.py
    │   ├── llama2_model.py
    │   ├── make_tensor.py
    │   ├── module_example.py
    │   ├── nanogpt_model.py
    │   ├── opinfos.py
    │   ├── test_apex_cross_entropy_executor.py
    │   ├── test_apex_fused_norms.py
    │   ├── test_auto_register_torchops.py
    │   ├── test_autocast.py
    │   ├── test_check_trace.py
    │   ├── test_core.py
    │   ├── test_cudnn_executor.py
    │   ├── test_dynamo.py
    │   ├── test_einops.py
    │   ├── test_elementwise.py
    │   ├── test_examine.py
    │   ├── test_examine_memory.py
    │   ├── test_extend.py
    │   ├── test_fa3_executor.py
    │   ├── test_grad.py
    │   ├── test_inplace_copy.py
    │   ├── test_interpreter.py
    │   ├── test_jit_general.py
    │   ├── test_networks.py
    │   ├── test_nvfuser.py
    │   ├── test_nvfuser_remat.py
    │   ├── test_ops.py
    │   ├── test_patterns.py
    │   ├── test_pythonex.py
    │   ├── test_randomness.py
    │   ├── test_recipes.py
    │   ├── test_reductions.py
    │   ├── test_sdpaex_executor.py
    │   ├── test_shape_ops.py
    │   ├── test_torch_compile_executor.py
    │   ├── test_torch_library_custom_op.py
    │   ├── test_torch_library_custom_op_with_lists.py
    │   ├── test_transformer_engine_executor.py
    │   ├── test_transformer_engine_v1_executor.py
    │   ├── test_transforms.py
    │   ├── test_triton_ce.py
    │   ├── test_update_aliases.py
    │   └── utils.py
    ├── torch
    │   ├── __init__.py
    │   ├── custom_op.py
    │   ├── default_torch_ops.py
    │   ├── experimental
    │   │   ├── __init__.py
    │   │   ├── dtensor_codeutils.py
    │   │   ├── dtensor_proxy.py
    │   │   ├── dtensor_torch_and_prims.py
    │   │   └── dtensor_utils.py
    │   └── langctx.py
    └── transforms
        ├── __init__.py
        ├── autocast.py
        ├── autodiff.py
        ├── constant_folding.py
        ├── cudagraph.py
        ├── extraction_only_prologue_transform.py
        ├── materialization.py
        ├── prune_prologue_checks.py
        ├── qlora.py
        ├── quantization.py
        └── utils.py
/.azure/gpu-coverage.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | tags: 3 | include: ["*"] 4 | paths: 5 | include: 6 | - ".azure/gpu-coverage.yml" 7 | - "requirements/coverage.txt" 8 | - "thunder/tests/coverage/**" 9 | branches: 10 | include: 11 | - "main" 12 | - "release/*" 13 | - "refs/tags/*" 14 | 15 | pr: 16 | branches: 17 | include: ["*"] 18 | 19 | jobs: 20 | - job: coverage 21 | strategy: 22 | matrix: 23 | "w/ torch 2.7.1": 24 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.7.1-dev" 25 | # how much time to give 'run always even if cancelled tasks' before stopping them 26 | cancelTimeoutInMinutes: "2" 27 | pool: "lit-rtx-3090" 28 | variables: 29 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' ) 30 | TORCH_HOME: "/var/tmp/torch" 31 | PIP_CACHE_DIR: "/var/tmp/pip" 32 | PYTHONHASHSEED: "0" 33 | NCCL_DEBUG: "INFO" 34 | ALLOW_COVERAGE_TRACE: "1" 35 | container: 36 | image: "pytorchlightning/lightning-thunder:$(docker-image)" 37 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp" 38 | workspace: 39 | clean: all 40 | steps: 41 | - bash: | 42 | echo $(DEVICES) 43 | lspci | egrep 'VGA|3D' 44 | dpkg-query -W -f='${Package} ${Version}\n' libnccl2 libnccl-dev 45 | whereis nvidia 46 | nvidia-smi 47 | which python && which pip 48 | python --version 49 | pip --version 50 | pip list 51 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" 52 | displayName: "Image info & NVIDIA" 53 | 54 | - bash: | 55 | set -ex 56 | # drop pt from requirements so not to interfere with the existing one 57 | bash scripts/remove-torch-lines.sh requirements/base.txt 58 | cat requirements/base.txt 59 | 60 | # double check on test requirements 61 | pip install -U -r requirements/base.txt -r requirements/coverage.txt 62 | 63 | # https://docs.codecov.com/docs/codecov-uploader 64 | curl -Os https://uploader.codecov.io/latest/linux/codecov 65 | chmod +x codecov 66 | 67 | # install this package 68 | python setup.py develop 69 | displayName: "Install package & ..."
70 | 71 | - bash: bash scripts/sanity-check.sh 72 | displayName: "Sanity check / details" 73 | 74 | - bash: | 75 | PYTHONPATH=$(pwd)/thunder/tests pytest thunder/tests/coverage_tests 76 | timeoutInMinutes: "45" 77 | displayName: "Testing: coverage_tests" 78 | -------------------------------------------------------------------------------- /.azure/notebook-runs.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | tags: 3 | include: ["*"] 4 | branches: 5 | include: 6 | - "main" 7 | - "release/*" 8 | - "refs/tags/*" 9 | 10 | pr: 11 | branches: 12 | include: ["*"] 13 | 14 | jobs: 15 | - job: jupyter 16 | strategy: 17 | matrix: 18 | "notebooks w/ torch 2.8": 19 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev" 20 | "notebooks w/ torch-nightly": 21 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev" 22 | # how long to run the job before automatically cancelling 23 | timeoutInMinutes: "45" 24 | # how much time to give 'run always even if cancelled tasks' before stopping them 25 | cancelTimeoutInMinutes: "2" 26 | pool: "lit-rtx-3090" 27 | variables: 28 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' ) 29 | TORCH_HOME: "/var/tmp/torch" 30 | PIP_CACHE_DIR: "/var/tmp/pip" 31 | container: 32 | image: "pytorchlightning/lightning-thunder:$(docker-image)" 33 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp" 34 | workspace: 35 | clean: all 36 | steps: 37 | - bash: | 38 | echo $(DEVICES) 39 | lspci | egrep 'VGA|3D' 40 | whereis nvidia 41 | nvidia-smi 42 | which python && which pip 43 | python --version 44 | pip --version 45 | pip list 46 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" 47 | displayName: "Image info & NVIDIA" 48 | 49 | - bash: | 50 | set -ex 51 | # drop pt from requirements so not to interfere with the existing one 52 | bash scripts/remove-torch-lines.sh requirements/base.txt 53 | cat requirements/base.txt 54 | pip install -U -r requirements/notebooks.txt 55 | # install this package 56 | python setup.py develop 57 | # double check on test requirements 58 | echo "Install special requirements for notebooks" 59 | displayName: "Install package & ..." 60 | 61 | - bash: | 62 | set -ex 63 | pip list 64 | bash scripts/sanity-check.sh 65 | displayName: "Sanity check / details" 66 | 67 | - bash: | 68 | set -ex 69 | # list all notebooks in this folder 70 | find . 
-name "*.ipynb" > all.txt 71 | # drop all "./" from beginning of each line 72 | sed -i 's/^\.\///' all.txt 73 | # filter out the ones that are listed in .ignore.ci 74 | grep -Fxv -f .ignore.ci all.txt > ci.txt 75 | # iterate over all listed notebooks and execute them with jupyter 76 | while read -r line; do 77 | echo "Processing $line" 78 | jupyter execute $line --timeout=300 79 | done <<< $(cat ci.txt) 80 | workingDirectory: "notebooks/" 81 | displayName: "Execute notebooks" 82 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | # https://codecov.readme.io/v1.0/docs/commit-status 20 | project: 21 | default: 22 | informational: true 23 | target: 99% # specify the target coverage for each commit status 24 | threshold: 30% # allow this little decrease on project 25 | # https://github.com/codecov/support/wiki/Filtering-Branches 26 | # branches: main 27 | if_ci_failed: error 28 | # https://github.com/codecov/support/wiki/Patch-Status 29 | patch: 30 | default: 31 | informational: true 32 | target: 50% # specify the target "X%" coverage to hit 33 | # threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: false 51 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Move optimize_allreduce_in_ddp_backward to thunder.distributed.transforms namespace #2087 2 | 7a3d5d4e64da010eea42210a065960e3568870ac 3 | # Split `test_ddp` into `test_ddp`, `test_fsdp` and `test_ops` (#772) 4 | b63e8713141044ed8c184f2f4c8305048d477319 5 | # Enable `ruff-check` in pre-commit (#2192) 6 | ae51153db2e623e8b9fb3609de3a6ad976addd0a 7 | # Enable ruff format in pre-commit (#2142) 8 | 0c6f9a91f1ac955bd5c1087ae26d120d7ab184a3 9 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | 3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, 4 | # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request. 
5 | 6 | # Thank you, our previous code owners for their service: 7 | # @carmocca @borda 8 | 9 | * @mruberry @lantiga @t-vi @KaelanDt 10 | 11 | # CI/CD and configs 12 | /.azure/ @KaelanDt @lantiga @t-vi 13 | /.github/ @KaelanDt @lantiga @t-vi 14 | /.lightning/ @KaelanDt @lantiga @t-vi 15 | /dockers/ @KaelanDt @lantiga @t-vi 16 | Makefile @KaelanDt @lantiga @t-vi 17 | *.yml @KaelanDt @lantiga @t-vi 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | assignees: '' 6 | --- 7 | 8 | *Note*: If you have a model or program that is not supported yet but should be, please use the program coverage template. 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ### To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. Go to '...' 19 | 1. Run '....' 20 | 1. Scroll down to '....' 21 | 1. See error 22 | 23 | 24 | 25 | #### Code sample 26 | 27 | 29 | 30 | ### Expected behavior 31 | 32 | 33 | 34 | ### Environment 35 | 36 | - PyTorch Version (e.g., 1.0): 37 | - OS (e.g., Linux): 38 | - How you installed PyTorch (`conda`, `pip`, source): 39 | - Build command you used (if compiling from source): 40 | - Python version: 41 | - CUDA/cuDNN version: 42 | - GPU models and configuration: 43 | - Any other relevant information: 44 | 45 | ### Additional context 46 | 47 | 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos and doc fixes, please go ahead and: 12 | 13 | 1. Create an issue. 14 | 1. Fix the typo. 15 | 1. Submit a PR. 16 | 17 | Thanks! 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/program_coverage.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Program Coverage 3 | about: Expand the programs / models Thunder can process 4 | title: '' 5 | labels: program-coverage 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Model / language coverage 10 | 11 | 12 | 13 | ### Pitch 14 | 15 | 16 | 17 | ### Alternatives / Potential work-arounds 18 | 19 | 24 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 |
2 | Before submitting 3 | 4 | - [ ] Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements) 5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section? 6 | - [ ] Did you make sure to update the docs? 7 | - [ ] Did you write any new necessary tests? 8 | 9 |
10 | 11 | ## What does this PR do? 12 | 13 | Fixes # (issue). 14 | 15 | ## PR review 16 | 17 | Anyone in the community is free to review the PR once the tests have passed. 18 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Basic dependabot.yml file with 2 | # minimum configuration for two package managers 3 | 4 | version: 2 5 | updates: 6 | # Enable version updates for python 7 | - package-ecosystem: "pip" 8 | # Look for a `requirements` in the `root` directory 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | pull-request-branch-name: 13 | separator: "-" 14 | # Allow up to 5 open pull requests for pip dependencies 15 | open-pull-requests-limit: 5 16 | 17 | # Enable version updates for GitHub Actions 18 | - package-ecosystem: "github-actions" 19 | directory: "/" 20 | schedule: 21 | interval: "monthly" 22 | pull-request-branch-name: 23 | separator: "-" 24 | # Allow up to 5 open pull requests for GitHub Actions 25 | open-pull-requests-limit: 5 26 | groups: 27 | GHA-updates: 28 | patterns: 29 | - "*" 30 | -------------------------------------------------------------------------------- /.github/labeling-config.yml: -------------------------------------------------------------------------------- 1 | documentation: 2 | - changed-files: 3 | - any-glob-to-any-file: 4 | - docs/**/* 5 | - dev_tutorials/* 6 | 7 | "ci": 8 | - changed-files: 9 | - any-glob-to-any-file: 10 | - .azure/* 11 | - .github/* 12 | - .github/workflows/* 13 | - .lightning/workflows/* 14 | - dockers/**/* 15 | 16 | "docker": 17 | - changed-files: 18 | - any-glob-to-any-file: 19 | - dockers/**/* 20 | - .azure/docker-build.yml 21 | 22 | "install": 23 | - changed-files: 24 | - any-glob-to-any-file: 25 | - setup.py 26 | - pyproject.toml 27 | - MANIFEST.in 28 | 29 | "dependencies": 30 | - changed-files: 31 | - any-glob-to-any-file: 32 | - requirements/* 33 | - requirements.txt 34 | -------------------------------------------------------------------------------- /.github/lightning-probot.yml: -------------------------------------------------------------------------------- 1 | tracking_issue: 72 2 | -------------------------------------------------------------------------------- /.github/load-jit-coverage-report.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def main(report_path: str = "jit_coverage_report.json"): 6 | """Load and print the jit coverage report.""" 7 | with open(report_path) as fp: 8 | report = json.load(fp) 9 | 10 | all_count = len(report) 11 | success = [el for el in report if el["status"] == "[SUCCESS]"] 12 | skipped = [el for el in report if el["status"] == "[SKIPPED]"] 13 | failure = [el for el in report if el["status"] == "[FAILURE]"] 14 | 15 | print("thunder.jit coverage report") 16 | print(f"🟢 {len(success)}/{all_count} [SUCCESS]") 17 | print(f"🟡 {len(skipped)}/{all_count} [SKIPPED]") 18 | print(f"⛔ {len(failure)}/{all_count} [FAILURE]") 19 | 20 | 21 | if __name__ == "__main__": 22 | # optional path to the report file 23 | system_args = sys.argv[1:] 24 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {} 25 | main(**main_args) 26 | 
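# Usage sketch (the report path argument is optional and the file name below is illustrative;
# without an argument the script falls back to jit_coverage_report.json):
#   python .github/load-jit-coverage-report.py jit_coverage_report.json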
-------------------------------------------------------------------------------- /.github/load-quickstart-report.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def main(report_path: str = "quickstart_report.json"): 6 | """Load and print the quickstart report.""" 7 | with open(report_path) as fp: 8 | report = json.load(fp) 9 | 10 | success_count = sum(status == "completed" for status in report.values()) 11 | overall_status = "🟢" if success_count == len(report) else "⛔" 12 | print(f"Quickstart report {overall_status} with {success_count} out of {len(report)} successful:") 13 | # Sort so that entries with status "success" (or "completed") are last 14 | for name, status in sorted(report.items(), key=lambda x: x[1] == "completed"): 15 | status_icon = "✔️" if status == "completed" else "❌" 16 | print(f"{status_icon} {name}") 17 | 18 | 19 | if __name__ == "__main__": 20 | # optional path to the report file 21 | system_args = sys.argv[1:] 22 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {} 23 | main(**main_args) 24 | -------------------------------------------------------------------------------- /.github/run-benchmark-as-lit-jobs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import sys 4 | import time 5 | from datetime import datetime 6 | 7 | from lightning_sdk import Studio, Job, Machine, Status 8 | 9 | 10 | def main(gh_run_id: str = ""): 11 | if not gh_run_id: 12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S") 13 | print("Creating studio...") 14 | s = Studio(f"thunder-benchmark-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True) 15 | 16 | print("Uploading package and benchmark script...") 17 | s.upload_folder("dist", remote_path="dist") 18 | pkg_path = glob.glob("dist/*.whl")[0] 19 | s.upload_file("thunder/benchmarks/benchmark_hf.py", remote_path="benchmarks/benchmark_hf.py") 20 | 21 | print("Starting studio...") 22 | s.start() 23 | print("Installing Thunder and dependencies...") 24 | s.run( 25 | f"pip install {pkg_path} -U transformers==4.52.4 nvidia-cudnn-frontend 'numpy<2.0' 'nvfuser_cu128_torch27==0.2.27.dev20250615'" 26 | ) 27 | 28 | print("Running HF benchmark script...") 29 | job = Job.run( 30 | name=f"benchmark-run{gh_run_id}", 31 | command="pip list && python benchmarks/benchmark_hf.py", 32 | studio=s, 33 | machine=Machine.L40S, 34 | interruptible=True, 35 | ) 36 | 37 | print("Stopping studio...") 38 | s.stop() 39 | 40 | print("Waiting for job to finish...") 41 | job.wait() 42 | status = str(job.status).lower() 43 | print(f"[{job.status}]\t {job.name}") 44 | 45 | report = {"benchmark_hf.py": status} 46 | with open("benchmark_hf_report.json", "w") as fp: 47 | json.dump(report, fp, indent=4) 48 | 49 | if job.status != Status.Completed: 50 | print("=" * 80) 51 | print(f"===== benchmark_hf.py -> {job.status} =====") 52 | print("=" * 80) 53 | print(job.logs) 54 | print("=" * 80) 55 | time.sleep(3) 56 | raise RuntimeError(f"Benchmark HF job {job.status}") 57 | # clean up 58 | job.delete() 59 | s.delete() 60 | 61 | 62 | if __name__ == "__main__": 63 | # parse command line arguments 64 | args = sys.argv[1:] 65 | main(*args) 66 | -------------------------------------------------------------------------------- /.github/run-jit-coverage-as-lit-studios.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from datetime import datetime 3 | 
import glob 4 | import os 5 | import json 6 | 7 | from lightning_sdk import Studio, Machine 8 | from tqdm import tqdm 9 | 10 | 11 | def main(gh_run_id: str = ""): 12 | if not gh_run_id: 13 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S") 14 | 15 | batches = [] 16 | chunk_size = 16 17 | with open("examples/coverage/all.txt") as f: 18 | lines = [el.strip() for el in f.readlines()] 19 | chunks = [lines[i : i + chunk_size] for i in range(0, len(lines), chunk_size)] 20 | for i, chunk_lines in enumerate(chunks): 21 | filename = f"{i:03d}.txt" 22 | with open(f"examples/coverage/{filename}", "w") as f: 23 | f.writelines([el + "\n" for el in chunk_lines]) 24 | batches.append((filename, chunk_lines)) 25 | 26 | print("Creating studio...") 27 | s = Studio(f"thunder-jit-coverage-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True) 28 | 29 | print("Uploading package and scripts...") 30 | s.upload_folder("dist", remote_path="dist") 31 | pkg_path = glob.glob("dist/*.whl")[0] 32 | s.upload_folder("examples/coverage", remote_path="coverage") 33 | 34 | print("Starting studio...") 35 | s.start(machine=Machine.L40S, interruptible=False) 36 | 37 | print("Installing Thunder and other requirements...") 38 | s.run(f"pip install {pkg_path} -U -r coverage/requirements.txt") 39 | 40 | hf_token = os.environ["HF_TOKEN"] 41 | print("Running thunder.jit coverage...") 42 | for filename, models in tqdm(batches, unit_scale=chunk_size): 43 | s.run( 44 | f"HF_TOKEN={hf_token} python coverage/jit_coverage_hf.py --models-file coverage/{filename} --output-dir data" 45 | ) 46 | 47 | print("Aggregating results...") 48 | s.run("python coverage/jit_coverage_hf.py --output-dir data --results-file data.json") 49 | 50 | data_json = s.run("cat data.json") 51 | data = json.loads(data_json) 52 | success = [el for el in data if el["status"] == "[SUCCESS]"] 53 | skipped = [el for el in data if el["status"] == "[SKIPPED]"] 54 | failure = [el for el in data if el["status"] == "[FAILURE]"] 55 | 56 | for el in success: 57 | print(f"🟢 [SUCCESS] {el['model']}") 58 | 59 | for el in skipped: 60 | print(f"🟡 [SKIPPED] {el['model']}") 61 | print(f" Error: {el['last']}") 62 | 63 | for el in failure: 64 | print(f"⛔️ [FAILURE] {el['model']}") 65 | print(f" Error: {el['last']}") 66 | 67 | with open("jit_coverage_report.json", "w") as f: 68 | f.write(data_json) 69 | 70 | print("Stopping studio...") 71 | s.stop() 72 | s.delete() 73 | 74 | 75 | if __name__ == "__main__": 76 | # parse command line arguments 77 | args = sys.argv[1:] 78 | main(*args) 79 | -------------------------------------------------------------------------------- /.github/run-quickstart-as-lit-jobs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from datetime import datetime 4 | import glob 5 | import os.path 6 | 7 | from lightning_sdk import Studio, Job, Machine, Status 8 | 9 | 10 | def main(gh_run_id: str = ""): 11 | if not gh_run_id: 12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S") 13 | print("Creating studio...") 14 | s = Studio(f"thunder-quickstarts-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True) 15 | print("Uploading package and scripts...") 16 | s.upload_folder("dist", remote_path="dist") 17 | pkg_path = glob.glob("dist/*.whl")[0] 18 | s.upload_folder("examples/quickstart", remote_path="quickstart") 19 | 20 | print("Starting studio...") 21 | s.start() 22 | print("Installing Thunder and other requirements...") 23 | s.run(f"pip install {pkg_path} -U -r 
quickstart/requirements.txt") 24 | 25 | ls_quickstart = glob.glob("examples/quickstart/*.py") 26 | print("Found quickstart scripts:", ls_quickstart) 27 | 28 | print("Running quickstart scripts...") 29 | jobs = { 30 | os.path.basename(script): Job.run( 31 | name=f"ci-run{gh_run_id}_{script}", 32 | command=f"pip list && python quickstart/{os.path.basename(script)}", 33 | studio=s, 34 | machine=Machine.L40S, 35 | interruptible=True, 36 | ) 37 | for script in ls_quickstart 38 | } 39 | 40 | print("Stopping studio...") 41 | s.stop() 42 | 43 | print("Waiting for jobs to finish...") 44 | report, failures = {}, {} 45 | for name, job in jobs.items(): 46 | job.wait() 47 | print(f"[{job.status}]\t {job.name}") 48 | report[name] = str(job.status).lower() 49 | if job.status != Status.Completed: 50 | failures[name] = job.logs 51 | else: # clean up successful jobs 52 | job.delete() 53 | 54 | with open("quickstart_report.json", "w") as fp: 55 | json.dump(report, fp, indent=4) 56 | 57 | print("Showing logs of failed jobs...") 58 | separator = "=" * 80 59 | for name, logs in failures.items(): 60 | offset = "=" * (80 - 5 - 2 - len(name)) 61 | print(f"{separator}\n===== {name} {offset}\n{separator}") 62 | print(logs) 63 | print(separator + "\n" * 5) 64 | 65 | assert not failures 66 | 67 | print("Cleaning up...") 68 | s.delete() 69 | 70 | 71 | if __name__ == "__main__": 72 | # parse command line arguments 73 | args = sys.argv[1:] 74 | main(*args) 75 | -------------------------------------------------------------------------------- /.github/workflows/auto-cc.yml: -------------------------------------------------------------------------------- 1 | name: Probot 2 | 3 | on: 4 | issues: 5 | types: [labeled] 6 | # should use `pull_request_target` but it's blocked by 7 | # https://github.com/probot/probot/issues/1635 8 | # so this job will not run on forks until the above is fixed 9 | pull_request: 10 | types: [labeled, ready_for_review] 11 | 12 | jobs: 13 | auto-cc: 14 | runs-on: ubuntu-latest 15 | if: github.event_name == 'issue' || github.event.pull_request.draft == false 16 | timeout-minutes: 5 17 | steps: 18 | - uses: Lightning-AI/probot@v5 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | with: 22 | job: auto-cc 23 | -------------------------------------------------------------------------------- /.github/workflows/ci-benchmark.yml: -------------------------------------------------------------------------------- 1 | name: Benchmark Models 2 | 3 | on: 4 | workflow_dispatch: {} 5 | pull_request: 6 | paths: 7 | - ".github/workflows/benchmark-hf.yml" 8 | - ".github/run-benchmark-as-lit-jobs.py" 9 | push: # enable tracing which commit to master broke it... 
10 | branches: [main] 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: false 15 | 16 | defaults: 17 | run: 18 | shell: bash 19 | 20 | jobs: 21 | launcher-benchmark: 22 | runs-on: "ubuntu-22.04" 23 | 24 | steps: 25 | - uses: actions/checkout@v5 26 | 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: "3.10" 30 | 31 | - name: Build Thunder Package 32 | run: | 33 | pip install -U build 34 | python -m build --sdist --wheel --outdir dist/ 35 | ls -l dist/ 36 | 37 | - name: Launch Benchmark Job in Lightning Studio 38 | env: 39 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 40 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 41 | run: | 42 | pip install lightning_sdk -U -q 43 | python .github/run-benchmark-as-lit-jobs.py ${{ github.run_id }} 44 | 45 | - name: Post Slack Notification 46 | if: always() && github.event_name != 'pull_request' 47 | uses: act10ns/slack@v2 48 | with: 49 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }} 50 | status: ${{ job.status }} 51 | message: | 52 | *Benchmark Triggered Manually* - [${{ job.status }}] 53 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} 54 | -------------------------------------------------------------------------------- /.github/workflows/ci-checks.yml: -------------------------------------------------------------------------------- 1 | name: General checks 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: {} 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 10 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} 11 | 12 | jobs: 13 | check-schema: 14 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main 15 | with: 16 | azure-dir: ".azure" 17 | 18 | check-package: 19 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main 20 | with: 21 | actions-ref: main 22 | import-name: "thunder" 23 | artifact-name: dist-packages-${{ github.sha }} 24 | testing-matrix: | 25 | { 26 | "os": ["ubuntu-latest", "macOS-latest", "windows-latest"], 27 | "python-version": ["3.10", "3.11"] 28 | } 29 | -------------------------------------------------------------------------------- /.github/workflows/ci-jit-coverage.yml: -------------------------------------------------------------------------------- 1 | name: CI jit coverage 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | paths: 8 | - ".github/workflows/ci-jit-coverage.yml" 9 | - ".github/run-jit-coverage-as-lit-studios.py" 10 | - ".github/load-jit-coverage-report.py" 11 | - "examples/coverage/*" 12 | workflow_dispatch: {} 13 | schedule: 14 | - cron: "0 0 * * *" # every midnight 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 19 | 20 | defaults: 21 | run: 22 | shell: bash 23 | 24 | jobs: 25 | launcher-jit-coverage: 26 | runs-on: "ubuntu-22.04" 27 | #timeout-minutes: 55 28 | steps: 29 | - uses: actions/checkout@v5 30 | - uses: actions/setup-python@v5 31 | with: 32 | python-version: "3.10" 33 | 34 | - name: Build package 35 | run: | 36 | pip install -U build 37 | python -m build --sdist --wheel --outdir dist/ 38 | ls -l dist/ 39 | 40 | - name: Run thunder.jit coverage 41 | env: 42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 44 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 45 
| run: | 46 | pip install lightning_sdk tqdm -U -q 47 | python .github/run-jit-coverage-as-lit-studios.py ${{ github.run_id }} 48 | 49 | - name: Load report 50 | if: always() 51 | id: load-report 52 | run: | 53 | report=$(python .github/load-jit-coverage-report.py) 54 | echo "REPORT<<EOF" >> $GITHUB_ENV 55 | echo "$report" >> $GITHUB_ENV 56 | echo "EOF" >> $GITHUB_ENV 57 | 58 | - uses: act10ns/slack@v2 59 | # if: always() && github.event_name != 'pull_request' 60 | with: 61 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }} 62 | status: ${{ job.status }} 63 | message: | 64 | *thunder.jit coverage CI* - [${{ job.status }}] 65 | ${{ env.REPORT }} 66 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} 67 | -------------------------------------------------------------------------------- /.github/workflows/ci-quickstart.yml: -------------------------------------------------------------------------------- 1 | name: CI quickstart 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | paths: 8 | - ".github/workflows/ci-quickstart.yml" 9 | - ".github/run-quickstart-as-lit-jobs.py" 10 | - ".github/load-quickstart-report.py" 11 | - "examples/quickstart/*" 12 | workflow_dispatch: {} 13 | schedule: 14 | - cron: "0 0 * * *" # every midnight 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 19 | 20 | defaults: 21 | run: 22 | shell: bash 23 | 24 | jobs: 25 | launcher-quickstart: 26 | runs-on: "ubuntu-22.04" 27 | #timeout-minutes: 55 28 | steps: 29 | - uses: actions/checkout@v5 30 | - uses: actions/setup-python@v5 31 | with: 32 | python-version: "3.10" 33 | 34 | - name: Build package 35 | run: | 36 | pip install -U build 37 | python -m build --sdist --wheel --outdir dist/ 38 | ls -l dist/ 39 | 40 | - name: Run scripts in jobs 41 | env: 42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 44 | run: | 45 | pip install lightning_sdk -U -q 46 | python .github/run-quickstart-as-lit-jobs.py ${{ github.run_id }} 47 | 48 | - name: Load report 49 | if: always() 50 | id: load-report 51 | run: | 52 | report=$(python .github/load-quickstart-report.py) 53 | echo "REPORT<<EOF" >> $GITHUB_ENV 54 | echo "$report" >> $GITHUB_ENV 55 | echo "EOF" >> $GITHUB_ENV 56 | 57 | - uses: act10ns/slack@v2 58 | if: always() && github.event_name != 'pull_request' && steps.load-report.conclusion == 'success' 59 | with: 60 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }} 61 | status: ${{ job.status }} 62 | message: | 63 | *Quickstart CI* - [${{ job.status }}] 64 | ${{ env.REPORT }} 65 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} 66 | -------------------------------------------------------------------------------- /.github/workflows/docs-build.yml: -------------------------------------------------------------------------------- 1 | name: "Build (& deploy) Docs" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, reopened, ready_for_review, synchronize] 8 | workflow_dispatch: {} 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 12 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 13 | 14 | defaults: 15 | run: 16 | shell: bash 17 | 18 | jobs: 19 | docs-make: 20 | if: github.event.pull_request.draft == false 21 | runs-on: ubuntu-22.04 22 | strategy: 23 | fail-fast: false 24 | matrix: 25
target: ["html", "doctest", "linkcheck"] 26 | env: 27 | ARTIFACT_DAYS: 0 28 | PYPI_LOCAL_DIR: "pypi_pkgs/" 29 | steps: 30 | - uses: actions/checkout@v5 31 | - uses: actions/setup-python@v5 32 | with: 33 | python-version: "3.10" 34 | 35 | - name: Pull sphinx template 36 | run: make get-sphinx-theme 37 | - name: Install pandoc 38 | timeout-minutes: 5 39 | run: sudo apt-get install -y pandoc 40 | - name: Install package & dependencies 41 | timeout-minutes: 20 42 | run: pip install . -U -r requirements/docs.txt 43 | 44 | - name: Make ${{ matrix.target }} 45 | working-directory: docs/ 46 | # allow failing link check and doctest if you run with dispatch 47 | continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }} 48 | run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" 49 | 50 | - name: Keep artifact 51 | if: github.event_name == 'pull_request' 52 | run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV 53 | - name: Upload built docs 54 | if: ${{ matrix.target == 'html' }} 55 | uses: actions/upload-artifact@v4 56 | with: 57 | name: docs-html-${{ github.sha }} 58 | path: docs/build/html/ 59 | retention-days: ${{ env.ARTIFACT_DAYS }} 60 | 61 | deploy-docs: 62 | needs: docs-make 63 | if: github.repository_owner == 'Lightning-AI' && github.event_name == 'push' 64 | runs-on: ubuntu-latest 65 | env: 66 | GCP_TARGET: "gs://lightning-docs-thunder" 67 | steps: 68 | - uses: actions/download-artifact@v5 69 | with: 70 | name: docs-html-${{ github.sha }} 71 | path: docs/build/html/ 72 | 73 | - name: Authenticate to Google Cloud 74 | uses: google-github-actions/auth@v3 75 | with: 76 | credentials_json: ${{ secrets.GCS_SA_KEY }} 77 | 78 | - name: Setup gcloud 79 | uses: google-github-actions/setup-gcloud@v3 80 | with: 81 | project_id: ${{ secrets.GCS_PROJECT }} 82 | 83 | # Uploading docs to GCS, so they can be served on lightning.ai 84 | #- name: Upload docs/thunder/stable to GCS 🪣 85 | # if: startsWith(github.ref, 'refs/heads/release/') 86 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/stable 87 | 88 | # Uploading docs to GCS, so they can be served on lightning.ai 89 | - name: Upload docs/thunder/latest to GCS 🪣 90 | if: github.ref == 'refs/heads/main' 91 | run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest 92 | 93 | # Uploading docs to GCS, so they can be served on lightning.ai 94 | #- name: Upload docs/thunder/release to GCS 🪣 95 | # if: startsWith(github.ref, 'refs/tags/') 96 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ github.ref_name }} 97 | 98 | # Uploading docs as archive to GCS, so they can be as backup 99 | #- name: Upload docs as archive to GCS 🪣 100 | # if: startsWith(github.ref, 'refs/tags/') 101 | # working-directory: docs/build 102 | # run: | 103 | # zip ${{ github.ref_name }}.zip -r html/ 104 | # gsutil cp ${{ github.ref_name }}.zip ${GCP_TARGET} 105 | -------------------------------------------------------------------------------- /.github/workflows/label-conflicts.yml: -------------------------------------------------------------------------------- 1 | name: Label merge conflicts 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request_target: 7 | types: ["synchronize", "reopened", "opened"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | triage-conflicts: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd 18 | 
with: 19 | CONFLICT_LABEL_NAME: "has conflicts" 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | MAX_RETRIES: 3 22 | WAIT_MS: 5000 23 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: "Pull Request Labeler" 2 | on: [pull_request_target] 3 | 4 | jobs: 5 | triage-prs: 6 | permissions: 7 | contents: read 8 | pull-requests: write 9 | runs-on: ubuntu-latest 10 | steps: 11 | # Uploads repository content to the runner 12 | - uses: actions/checkout@v5 13 | - uses: actions/labeler@v5 14 | with: 15 | # The path to the label configuration file. 16 | configuration-path: .github/labeling-config.yml 17 | # Whether removing labels when matching files are reverted or no longer changed by the PR 18 | sync-labels: true 19 | -------------------------------------------------------------------------------- /.github/workflows/release-nightly.yml: -------------------------------------------------------------------------------- 1 | name: Nightly packages 2 | 3 | on: 4 | pull_request: # this shall test only the part of workflow before publishing 5 | branches: [main, "release/*"] 6 | types: [opened, reopened, ready_for_review, synchronize] 7 | paths: 8 | - ".github/workflows/release-nightly.yml" 9 | schedule: 10 | - cron: "0 0 * * 0" # on Sundays 11 | workflow_dispatch: {} 12 | 13 | defaults: 14 | run: 15 | shell: bash 16 | 17 | jobs: 18 | releasing-nightly: 19 | runs-on: ubuntu-22.04 20 | steps: 21 | - uses: actions/checkout@v5 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.10" 25 | 26 | - name: Install dependencies 27 | run: python -m pip install --user --upgrade setuptools wheel packaging 28 | - name: Build package 29 | env: 30 | CONVERT_VERSION2NIGHTLY: "1" 31 | run: python setup.py sdist bdist_wheel 32 | 33 | # We do this, since failures on test.pypi aren't that bad 34 | - name: Publish to Test PyPI 35 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' 36 | uses: pypa/gh-action-pypi-publish@v1.12.4 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.test_pypi_password }} 40 | repository_url: https://test.pypi.org/legacy/ 41 | 42 | - name: Publish distribution 📦 to PyPI 43 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' 44 | uses: pypa/gh-action-pypi-publish@v1.12.4 45 | with: 46 | user: __token__ 47 | password: ${{ secrets.pypi_password }} 48 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the main branch 5 | push: 6 | branches: [main] 7 | release: 8 | types: [published] 9 | 10 | # based on https://github.com/pypa/gh-action-pypi-publish 11 | 12 | jobs: 13 | releasing-pypi: 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v5 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: python -m pip install --user --upgrade setuptools wheel packaging 23 | - name: Build 24 | run: python setup.py sdist bdist_wheel 25 | 26 | # We do this, since failures on test.pypi aren't that bad 27 | - name: Publish to Test PyPI 28 | if: startsWith(github.event.ref, 
'refs/tags') || github.event_name == 'release' 29 | uses: pypa/gh-action-pypi-publish@v1.12.4 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.test_pypi_password }} 33 | repository_url: https://test.pypi.org/legacy/ 34 | 35 | - name: Publish distribution 📦 to PyPI 36 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 37 | uses: pypa/gh-action-pypi-publish@v1.12.4 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.pypi_password }} 41 | -------------------------------------------------------------------------------- /.github/workflows/stale-branches.yaml: -------------------------------------------------------------------------------- 1 | name: Delete abandoned branches 2 | 3 | on: 4 | pull_request: 5 | branches: ["main"] 6 | paths: 7 | - ".github/workflows/stale-branches.yaml" 8 | # Run daily at midnight 9 | schedule: 10 | - cron: "0 0 * * *" 11 | # Allow workflow to be manually run from the GitHub UI 12 | workflow_dispatch: 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | cleanup-old-branches: 20 | runs-on: ubuntu-latest 21 | name: Satisfy my repo CDO 22 | env: 23 | DRY_RUN: no 24 | steps: 25 | - name: Set dry run 26 | if: ${{ github.event_name == 'pull_request' }} 27 | run: echo "DRY_RUN=yes" >> $GITHUB_ENV 28 | - name: Delete those pesky dead branches 29 | uses: phpdocker-io/github-actions-delete-abandoned-branches@v2 30 | id: delete_stuff 31 | with: 32 | github_token: ${{ github.token }} 33 | last_commit_age_days: 90 34 | ignore_branches: main 35 | # Disable dry run and actually get stuff deleted 36 | # For a PR, always perform dry run 37 | dry_run: ${{ env.DRY_RUN }} 38 | 39 | - name: Get output 40 | run: "echo 'Deleted branches: ${{ steps.delete_stuff.outputs.deleted_branches }}'" 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # Python venv 30 | bin/ 31 | lib64 32 | pyvenv.cfg 33 | share/ 34 | etc/ 35 | include/ 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | docs/source/api/ 68 | docs/source/*.md 69 | docs/source/notebooks/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # VSCode 78 | .vscode/ 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 88 | __pypackages__/ 89 | 90 | # Celery stuff 91 | celerybeat-schedule 92 | celerybeat.pid 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # PyCharm 125 | .idea/ 126 | 127 | # Lightning logs 128 | lightning_logs 129 | *.gz 130 | .DS_Store 131 | .*_submit.py 132 | 133 | # Editor temporary file 134 | *.swn 135 | *.swo 136 | *.swp 137 | *.swm 138 | *~ 139 | 140 | # Build artifacts 141 | scripts/build 142 | 143 | # Profiler traces 144 | benchmarks/traces 145 | 146 | # benchmark results 147 | .results 148 | 149 | .ruff_cache/ 150 | 151 | # come CI artifacts 152 | notebooks/all.txt 153 | notebooks/ci.txt 154 | 155 | quickstart_report.json 156 | -------------------------------------------------------------------------------- /.lightning/workflows/all-tests.yaml: -------------------------------------------------------------------------------- 1 | trigger: 2 | push: 3 | branches: ["main"] 4 | pull_request: 5 | branches: ["main"] 6 | 7 | timeout: "60" # minutes 8 | interruptible: False 9 | parametrize: 10 | matrix: 11 | image: 12 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev" 13 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev" 14 | testing: ["main", "ops", "grads"] 15 | machine: ["L4"] 16 | exclude: [] 17 | include: 18 | - image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev" 19 | testing: "distributed" 20 | machine: "L4_X_2" 21 | - image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev" 22 | testing: "distributed" 23 | machine: "L4_X_2" 24 | 25 | env: 26 | CI: "true" # skip some tests with CI 27 | NCCL_DEBUG: "INFO" 28 | NCCL_IGNORE_DISABLED_P2P: "1" 29 | TORCH_VERSION: "2.7.1" 30 | CUDA_LAUNCH_BLOCKING: "1" # for debugging purposes, to get better stack traces 31 | 32 | run: | 33 | whereis nvidia 34 | nvidia-smi 35 | python --version 36 | pip --version 37 | pip list 38 | set -ex 39 | 40 | # drop pt from requirements so not to interfere with the existing one 41 | bash scripts/remove-torch-lines.sh requirements/base.txt 42 | cat requirements/base.txt 43 | 44 | # double check on test requirements 45 | pip install -U -r requirements/base.txt -r requirements/test.txt 46 | 47 | # https://docs.codecov.com/docs/codecov-uploader 48 | curl -Os https://uploader.codecov.io/latest/linux/codecov 49 | chmod +x codecov 50 | 51 | # install this package 52 | python setup.py develop 53 | 54 | bash scripts/sanity-check.sh 55 | 56 | if [ "${testing}" == "main" ]; then 57 | coverage run --source thunder -m \ 58 | pytest thunder/tests/ \ 59 | -m "not standalone" \ 60 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \ 61 | --random-order-seed=42 \ 62 | --durations=250 \ 63 | --timeout=360 \ 64 | --numprocesses=9 \ 65 | --ignore=thunder/tests/distributed --ignore=thunder/tests/test_networks.py \ 66 | --ignore=thunder/tests/test_ops.py --ignore=thunder/tests/test_grad.py 67 | coverage run --source thunder -m \ 68 | pytest \ 69 | thunder/tests/test_networks.py \ 70 | -m "not standalone" \ 71 | -v 
--durations=0 \ 72 | --random-order-seed=42 \ 73 | --numprocesses=3 74 | elif [ "${testing}" == "ops" ]; then 75 | coverage run --source thunder -m \ 76 | pytest thunder/tests/test_ops.py \ 77 | -m "not standalone" \ 78 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \ 79 | --random-order-seed=42 \ 80 | --durations=250 \ 81 | --timeout=240 \ 82 | --numprocesses=9 83 | elif [ "${testing}" == "grads" ]; then 84 | coverage run --source thunder -m \ 85 | pytest thunder/tests/test_grad.py \ 86 | -m "not standalone" \ 87 | -v --datefmt="%Y%m%d-%H:%M:%S.%f" \ 88 | --random-order-seed=42 \ 89 | --durations=250 \ 90 | --timeout=360 \ 91 | --numprocesses=9 92 | elif [ "${testing}" == "distributed" ]; then 93 | pytest thunder/tests/distributed \ 94 | -v --durations=0 \ 95 | --random-order-seed=42 96 | else 97 | echo "Unknown testing type: ${testing}" 98 | exit 1 99 | fi 100 | 101 | # TODO: compile coverage results 102 | #python -m coverage report 103 | #python -m coverage xml 104 | # upload to codecov 105 | # TODO: add >> --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) 106 | #./codecov --flags=gpu,pytest,${testing} --name="GPU-coverage" --env=linux 107 | -------------------------------------------------------------------------------- /.lightning/workflows/notebooks.yaml: -------------------------------------------------------------------------------- 1 | trigger: 2 | push: 3 | branches: ["main"] 4 | pull_request: 5 | branches: ["main"] 6 | 7 | timeout: "50" # minutes 8 | machine: "L4" 9 | interruptible: False 10 | parametrize: 11 | matrix: 12 | image: 13 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev" 14 | - "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev" 15 | exclude: [] 16 | include: [] 17 | 18 | run: | 19 | whereis nvidia 20 | nvidia-smi 21 | python --version 22 | pip --version 23 | pip list 24 | set -ex 25 | 26 | # drop pt from requirements so not to interfere with the existing one 27 | bash scripts/remove-torch-lines.sh requirements/base.txt 28 | cat requirements/base.txt 29 | # double check on test requirements 30 | pip install -q -U -r requirements/base.txt -r requirements/notebooks.txt 31 | # install this package 32 | python setup.py develop 33 | 34 | bash scripts/sanity-check.sh 35 | 36 | # list all notebooks in this folder 37 | cd notebooks/ 38 | find . 
-name "*.ipynb" > all.txt 39 | # drop all "./" from beginning of each line 40 | sed -i 's/^\.\///' all.txt 41 | # filter out the ones that are listed in .ignore.ci 42 | grep -Fxv -f .ignore.ci all.txt > ci.txt 43 | # iterate over all listed notebooks and execute them with jupyter 44 | while read -r line; do 45 | echo "Processing $line" 46 | jupyter execute $line --timeout=300 47 | done <<< $(cat ci.txt) 48 | -------------------------------------------------------------------------------- /.lightning/workflows/transformer-engine.yaml: -------------------------------------------------------------------------------- 1 | trigger: 2 | push: 3 | branches: ["main"] 4 | pull_request: 5 | branches: ["main"] 6 | 7 | timeout: "30" # minutes 8 | machine: "L4" 9 | interruptible: False 10 | image: "pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.8.0-dev" 11 | parametrize: 12 | matrix: 13 | test_file: 14 | - test_transformer_engine_executor.py 15 | - test_transformer_engine_v1_executor.py 16 | 17 | run: | 18 | whereis nvidia 19 | nvidia-smi 20 | pip list 21 | set -ex 22 | 23 | # conda install -c conda-forge libstdcxx-ng 24 | # sudo apt install libstdc++6 libstdc++-*-dev 25 | pip install . -U -q -r requirements/test.txt 26 | # Need to explicitly point to cudnn.h as it is installed at a non-standard location 27 | # Ref: https://github.com/NVIDIA/TransformerEngine/issues/918#issuecomment-2187703769 28 | CPLUS_INCLUDE_PATH="/usr/local/lib/python3.10/dist-packages/nvidia/cudnn/include/" pip install --no-build-isolation 'transformer_engine[pytorch]' 29 | pip list # for debugging purposes 30 | pytest thunder/tests/${test_file} -v -rs 31 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | # submodules: true 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v6.0.0 12 | hooks: 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | exclude: README.md 16 | - id: check-case-conflict 17 | - id: check-yaml 18 | - id: check-toml 19 | - id: check-json 20 | - id: check-added-large-files 21 | args: ["--maxkb=400", "--enforce-all"] 22 | exclude: notebooks 23 | - id: check-docstring-first 24 | - id: detect-private-key 25 | 26 | - repo: https://github.com/asottile/pyupgrade 27 | rev: v3.20.0 28 | hooks: 29 | - id: pyupgrade 30 | args: ["--py310-plus"] 31 | name: Upgrade code 32 | exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py" 33 | 34 | - repo: https://github.com/codespell-project/codespell 35 | rev: v2.4.1 36 | hooks: 37 | - id: codespell 38 | additional_dependencies: [tomli] 39 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing 40 | 41 | - repo: https://github.com/astral-sh/ruff-pre-commit 42 | rev: v0.13.0 43 | hooks: 44 | - id: ruff-check 45 | args: ["--fix"] 46 | - id: ruff-format 47 | types_or: [python] 48 | exclude: "examples" 49 | 50 | - repo: https://github.com/executablebooks/mdformat 51 | rev: 0.7.22 52 | hooks: 53 | - id: mdformat 54 | additional_dependencies: 55 | - mdformat-gfm 56 | - mdformat-black 57 | - mdformat_frontmatter 58 | exclude: "examples" 59 | 60 | - repo: https://github.com/sphinx-contrib/sphinx-lint 61 | rev: v1.0.0 62 | hooks: 63 | - id: sphinx-lint 64 | 65 | - 
repo: https://github.com/JoC0de/pre-commit-prettier 66 | rev: b3e25fa39aa676c36bc18eb9eae6f26d9bb63f39 # v3.6.2 using SHA as tags are not persistent 67 | hooks: 68 | - id: prettier 69 | files: \.(json|yml|yaml|toml) 70 | # https://prettier.io/docs/en/options.html#print-width 71 | args: ["--print-width=120"] 72 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude __pycache__ *.py[cod] *.orig 2 | 3 | # Include the README and CHANGELOG 4 | include *.md 5 | recursive-include thunder *.md 6 | 7 | # Include the license file 8 | include LICENSE 9 | 10 | # Exclude build configs 11 | exclude *.toml 12 | exclude *.svg 13 | exclude *.yml 14 | exclude *.yaml 15 | 16 | # exclude tests from package 17 | recursive-exclude thunder/tests * 18 | recursive-exclude site * 19 | exclude thunder/tests 20 | 21 | # Exclude the documentation files 22 | recursive-exclude docs * 23 | exclude docs 24 | 25 | # Include the Requirements 26 | recursive-include requirements *.txt 27 | 28 | # Exclude Makefile 29 | exclude Makefile 30 | 31 | prune .git 32 | prune .github 33 | prune notebook* 34 | prune temp* 35 | prune test* 36 | prune benchmark* 37 | prune examples* 38 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean docs 2 | 3 | # assumes you have installed the needed packages 4 | export SPHINX_MOCK_REQUIREMENTS=0 5 | 6 | test: clean 7 | pip install -q -r requirements.txt -r requirements/test.txt 8 | 9 | # use this to run tests 10 | python -m coverage run --source thunder -m pytest thunder/tests -v 11 | python -m coverage report 12 | 13 | get-sphinx-theme: 14 | pip install -q awscli 15 | mkdir -p dist/ 16 | aws s3 sync --no-sign-request s3://sphinx-packages/ dist/ 17 | pip install lai-sphinx-theme -f dist/ 18 | 19 | docs: clean get-sphinx-theme 20 | pip install -e . --quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html 21 | cd docs ; python -m sphinx -b html -W --keep-going source build 22 | 23 | clean: 24 | # clean all temp runs 25 | rm -rf .mypy_cache 26 | rm -rf .pytest_cache 27 | rm -rf ./docs/build 28 | rm -rf ./docs/source/**/generated 29 | rm -rf ./docs/source/api 30 | rm -rf _ckpt_* 31 | -------------------------------------------------------------------------------- /dockers/README.md: -------------------------------------------------------------------------------- 1 | # Docker images 2 | 3 | ## Build images from Dockerfiles 4 | 5 | You can build the images yourself; note that this takes a long time, so be prepared. 6 | 7 | ```bash 8 | # build with specific arguments 9 | docker image build -t lightning:ubuntu-cuda-py3.10-cuda12.1.1 -f dockers/ubuntu-cuda/Dockerfile --build-arg "CUDA_VERSION=12.1.1" . 10 | ``` 11 | 12 | To run your docker image, use 13 | 14 | ```bash 15 | docker image list 16 | docker run --rm -it lightning:ubuntu-cuda-py3.10-cuda12.1.1 bash 17 | ``` 18 | 19 | ## Run docker image with GPUs 20 | 21 | To run a docker image with access to your GPUs, you first need to install the NVIDIA Container Toolkit: 22 | 23 | ```bash 24 | # Add the package repositories 25 | distribution=$(.
/etc/os-release;echo $ID$VERSION_ID) 26 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - 27 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 28 | 29 | sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit 30 | sudo systemctl restart docker 31 | ``` 32 | 33 | and later run the docker image with `--gpus=all`. For example, 34 | 35 | ```bash 36 | docker run --rm -it --gpus=all pytorchlightning/lightning:ubuntu-cuda-py3.10-cuda12.1.0 37 | ``` 38 | -------------------------------------------------------------------------------- /dockers/with-apex/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ARG BASE_IMAGE_TAG 16 | 17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG} 18 | 19 | ARG APEX_CHECKOUT="master" 20 | 21 | SHELL ["/bin/bash", "-c"] 22 | 23 | RUN \ 24 | # building Apex from source 25 | pip install "pip>=23.1" packaging && \ 26 | git clone https://github.com/NVIDIA/apex && \ 27 | cd apex && \ 28 | git checkout ${APEX_CHECKOUT} && \ 29 | # https://github.com/NVIDIA/apex#linux 30 | pip install -v \ 31 | --disable-pip-version-check \ 32 | --no-cache-dir \ 33 | --no-build-isolation \ 34 | --config-settings "--build-option=--xentropy" \ 35 | . && \ 36 | cd .. && \ 37 | rm -rf apex 38 | 39 | RUN \ 40 | # Show what we have 41 | pip --version && \ 42 | pip list && \ 43 | python -c "import apex" 44 | -------------------------------------------------------------------------------- /dockers/with-dev/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
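14 | 15 | # Dev image: extends the base Thunder image with the development and test requirements preinstalled.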
16 | 17 | ARG BASE_IMAGE_TAG 18 | 19 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG} 20 | 21 | SHELL ["/bin/bash", "-c"] 22 | 23 | COPY requirements/ requirements/ 24 | 25 | RUN \ 26 | ls -lh requirements/ && \ 27 | CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ 28 | MAKEFLAGS="-j$(( $(nproc) / 4 ))" && \ 29 | pip install -U -r requirements/test.txt && \ 30 | rm -rf requirements/ 31 | 32 | RUN \ 33 | # Show what we have 34 | pip --version && \ 35 | pip list 36 | -------------------------------------------------------------------------------- /docs/.build_docs.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | make html --debug --jobs $(nproc) 3 | -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx 10 | sphinx: 11 | fail_on_warning: true 12 | configuration: docs/source/conf.py 13 | 14 | build: 15 | os: "ubuntu-22.04" 16 | tools: 17 | python: "3.10" 18 | commands: 19 | - printenv 20 | - pwd ; pip install -q py-tree ; py-tree . 21 | - make docs 22 | - mkdir -p _readthedocs ; mv docs/build _readthedocs/html 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = python $(shell which sphinx-build) 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ 2 | $(document).ready(function() { 3 | /* Add a [>>>] button on the top-right corner of code samples to hide 4 | * the >>> and ... prompts and the output and thus make the code 5 | * copyable. */ 6 | var div = $('.highlight-python .highlight,' + 7 | '.highlight-python3 .highlight,' + 8 | '.highlight-pycon .highlight,' + 9 | '.highlight-default .highlight'); 10 | var pre = div.find('pre'); 11 | 12 | // get the styles from the current theme 13 | pre.parent().parent().css('position', 'relative'); 14 | var hide_text = 'Hide the prompts and output'; 15 | var show_text = 'Show the prompts and output'; 16 | var border_width = pre.css('border-top-width'); 17 | var border_style = pre.css('border-top-style'); 18 | var border_color = pre.css('border-top-color'); 19 | var button_styles = { 20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 21 | 'border-color': border_color, 'border-style': border_style, 22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 24 | 'border-radius': '0 3px 0 0' 25 | } 26 | 27 | // create and add the button to all the code blocks that contain >>> 28 | div.each(function(index) { 29 | var jthis = $(this); 30 | if (jthis.find('.gp').length > 0) { 31 | var button = $('<span class="copybutton">&gt;&gt;&gt;</span>'); 32 | button.css(button_styles) 33 | button.attr('title', hide_text); 34 | button.data('hidden', 'false'); 35 | jthis.prepend(button); 36 | } 37 | // tracebacks (.gt) contain bare text elements that need to be 38 | // wrapped in a span to work with .nextUntil() (see later) 39 | jthis.find('pre:has(.gt)').contents().filter(function() { 40 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 41 | }).wrap('<span>'); 42 | }); 43 | 44 | // define the behavior of the button when it's clicked 45 | $('.copybutton').click(function(e){ 46 | e.preventDefault(); 47 | var button = $(this); 48 | if (button.data('hidden') === 'false') { 49 | // hide the code output 50 | button.parent().find('.go, .gp, .gt').hide(); 51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 52 | button.css('text-decoration', 'line-through'); 53 | button.attr('title', show_text); 54 | button.data('hidden', 'true'); 55 | } else { 56 | // show the code output 57 | button.parent().find('.go, .gp, .gt').show(); 58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 59 | button.css('text-decoration', 'none'); 60 | button.attr('title', hide_text); 61 | button.data('hidden', 'false'); 62 | } 63 | }); 64 | }); 65 | -------------------------------------------------------------------------------- /docs/source/_static/images/LightningThunderDarkModewByline.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/LightningThunderDarkModewByline.png -------------------------------------------------------------------------------- /docs/source/_static/images/LightningThunderLightModewByline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/LightningThunderLightModewByline.png -------------------------------------------------------------------------------- /docs/source/_static/images/how_it_works.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/how_it_works.png -------------------------------------------------------------------------------- /docs/source/_static/images/lightning_thunder_lightmode_nobyline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/normalized_training_throughput_zero2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/normalized_training_throughput_zero2.png -------------------------------------------------------------------------------- /docs/source/_static/images/pretrain_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/pretrain_perf.png -------------------------------------------------------------------------------- /docs/source/_static/images/training_throughput_single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/docs/source/_static/images/training_throughput_single.png -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/Lightning-AI/lightning-thunder', 3 | 'github_issues': 'https://github.com/Lightning-AI/lightning-thunder/issues', 4 | 'contributing': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/governance.md', 6 | 'docs': 'https://lightning-thunder.rtfd.io/en/latest', 7 | 'twitter': 'https://twitter.com/LightningAI', 8 | 'discuss': 
'https://pytorch-lightning.slack.com', 9 | 'tutorials': 'https://lightning-thunder.readthedocs.io/en/latest/#tutorials', 10 | 'previous_pytorch_versions': 'https://lightning-thunder.rtfd.io/en/latest/', 11 | 'home': 'https://lightning-thunder.rtfd.io/en/latest/', 12 | 'get_started': 'https://lightning-thunder.readthedocs.io/en/latest/introduction_guide.html', 13 | 'features': 'https://lightning-thunder.rtfd.io/en/latest/', 14 | 'blog': 'https://www.Lightning-AI.ai/blog', 15 | 'resources': 'https://lightning-thunder.readthedocs.io/en/latest/#community-examples', 16 | 'support': 'https://lightning-thunder.rtfd.io/en/latest/', 17 | } 18 | -%} 19 | -------------------------------------------------------------------------------- /docs/source/advanced/extending.rst: -------------------------------------------------------------------------------- 1 | Extending Thunder 2 | ################# 3 | 4 | .. 5 | TODO RC1: update using the extend API 6 | 7 | This section describes how to add an executor to Thunder for a PyTorch operation. 8 | 9 | First, define a Python function with the same signature as the targeted operation, and have it call your implementation. For example, the Apex executor for ``torch.nn.functional.cross_entropy`` might define its implementation like:: 10 | 11 | import torch 12 | import xentropy_cuda 13 | 14 | def apex_xentropy( 15 | a: torch.Tensor, # a is an actual PyTorch tensor 16 | target, 17 | weight=None, 18 | size_average=None, 19 | ignore_index=-100, 20 | reduce=None, 21 | reduction="mean", 22 | label_smoothing=0.0, 23 | ): 24 | half_to_float = a.dtype == torch.float16 # whether the kernel should return float32 losses for float16 inputs 25 | losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float) 26 | return losses 27 | 28 | When this implementation is used it will be called with actual PyTorch tensors, and not with proxies. 29 | 30 | Next, define a “checker” function with the same signature as the targeted operation that returns True if your implementation can execute the targeted operation and False otherwise. Checkers, unlike the implementations, are called with proxies, and not actual PyTorch tensors, because they're called at optimization time. The purpose of a checker function is to let executors target only specific inputs to an operation, and defer to another executor on other inputs.
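In other words, when the checker returns ``False`` for a particular set of inputs, Thunder leaves that call to the remaining executors.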
31 | 32 | A checker function for the Apex executor might look like:: 33 | 34 | from thunder.core.proxies import TensorProxy 35 | 36 | def apex_xentropy_checker( 37 | a: TensorProxy, # a is a proxy 38 | target, 39 | weight=None, 40 | size_average=None, 41 | ignore_index=-100, 42 | reduce=None, 43 | reduction="mean", 44 | label_smoothing=0.0, 45 | ): 46 | # Apex's xentropy only supports "sum", "mean" or "none" reductions 47 | if reduction not in ["sum", "mean", "none"]: 48 | return False 49 | 50 | return True 51 | 52 | Create a mapping from the name of the PyTorch operation to your replacement implementation's name, its checker, and its implementation:: 53 | 54 | _op_to_xentropy = { 55 | "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy), 56 | } 57 | 58 | Then define a registration function that practitioners can call to access your executor:: 59 | 60 | def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None: 61 | from thunder.executors import add_operator_executor 62 | 63 | return add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors) 64 | 65 | You can test your executor by registering it, compiling a function that calls the targeted operator, and then verifying that your operation is called (by inspecting the execution trace) and produces the correct output. A good example of this is the tests for the Apex executor. 66 | -------------------------------------------------------------------------------- /docs/source/basic/sharp_edges.rst: -------------------------------------------------------------------------------- 1 | The sharp edges 2 | ############### 3 | 4 | This section describes features in the language that are not (yet) supported by Thunder, along with their workarounds. You might encounter these when compiling a module or function with Thunder. The examples should give you a good idea of how to change your program so that it works. 5 | 6 | Note that the fact that something is not supported today doesn't mean it won't be supported at some point in the future. Feel free to reach out to help us prioritize. 7 | 8 | Inplace operations 9 | ------------------ 10 | 11 | Inplace PyTorch operations like ``t.add_(1.0)`` are not supported in Thunder yet. Support for inplace operations is coming soon. 12 | 13 | 14 | Tensor subclasses 15 | ----------------- 16 | 17 | Thunder currently supports Python data types and PyTorch tensors as inputs of functions and models. 18 | 19 | Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors, are not supported today. 20 | 21 | 22 | Tracing Python builtins, standard library operations and functions that call other languages 23 | -------------------------------------------------------------------------------------------- 24 | 25 | Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed: 26 | 27 | 1. The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. 28 | 2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing.
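For example, a plain Python helper that directly modifies a tensor's ``data`` attribute falls into this category.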
To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>` 29 | 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will. 30 | 31 | .. 32 | Certain op-level behavior 33 | ------------------------- 34 | 1. Ops which have not yet been added to Thunder. Please let us know if there’s missing operator support you would like to see and we will be happy to help. 35 | 2. Data dependent control flow (e.g. ``if x.any()``). Since Thunder generates traces of programs ahead of the actual execution, control flow depending on the values of tensors as opposed to their metadata cannot be handled by Thunder. 36 | 37 | 38 | Using Thunder-optimized Modules 39 | ------------------------------- 40 | 41 | Compiling a module produces a "Thunder-optimized module". A Thunder-optimized module is less dynamic than the original module, which facilitates tracing and optimization. It has a reference to the original module, and it shares its parameters with it. 42 | 43 | While modifying the original model's parameters will be reflected in the Thunder-optimized module, other changes to the original module will not. In particular: 44 | 45 | - Whether the model is in ``train`` or ``eval`` mode is captured at compilation time and remains constant 46 | - The structure of the module is captured at compilation time, and changing the original module's structure will likely break the Thunder-optimized module 47 | - Non-parameter attributes of the module may or may not be captured at compile time and treated as constants 48 | 49 | Not all features of PyTorch modules are currently supported, either. Module hooks are not supported, and adding new module attributes in a module's ``forward()`` method is only partially supported. 50 | -------------------------------------------------------------------------------- /docs/source/fundamentals/examine.rst: -------------------------------------------------------------------------------- 1 | Using Examine 2 | ############# 3 | 4 | We recommend using Thunder's ``examine()`` before compiling a function or a module. 5 | 6 | Thunder cannot run every PyTorch module, but you can quickly find out what is missing using ``examine()``. 7 | 8 | ``examine()`` not only determines if Thunder can compile the module, but provides a helpful report to use when filing an issue requesting support. 9 | 10 | You can run examine like this:: 11 | 12 | from thunder.examine import examine 13 | 14 | model = MyModel(...) 15 | examine(model, *args, **kwargs) 16 | 17 | Where ``*args`` and ``**kwargs`` are valid inputs to the model.
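For instance, if the model's ``forward()`` takes a single batch of images, the call would be ``examine(model, images)``.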
If examine determines that Thunder can run the module or function as expected, it will print:: 18 | 19 | The function appears to be working as expected 20 | 21 | When ``examine`` encounters a module or function with one or more operators it doesn't support, it will specify the operators, like this:: 22 | 23 | def foo(a): 24 | return torch.triu(a) 25 | 26 | import torch 27 | import thunder 28 | from thunder.examine import examine 29 | 30 | a = torch.full((2, 2), 1., device='cuda') 31 | examine(foo, a) 32 | 33 | Running the above will print:: 34 | 35 | Found 1 distinct operations, of which 0 (0.0%) are supported 36 | Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new 37 | _VariableFunctionsClass.triu of torch 38 | 39 | To recap, ``examine()`` lets you know if Thunder can run a module, and if it can't, it provides a report you can use when filing an issue asking for support. 40 | -------------------------------------------------------------------------------- /docs/source/fundamentals/hello_world.rst: -------------------------------------------------------------------------------- 1 | Hello World 2 | ########### 3 | 4 | Here is a simple example of how Thunder lets you compile and run PyTorch modules and functions:: 5 | 6 | import torch 7 | import thunder 8 | 9 | def foo(a, b): 10 | return a + b 11 | 12 | jitted_foo = thunder.jit(foo) 13 | 14 | a = torch.full((2, 2), 1) 15 | b = torch.full((2, 2), 3) 16 | 17 | result = jitted_foo(a, b) 18 | 19 | print(result) 20 | 21 | # prints 22 | # tensor( 23 | # [[4, 4], 24 | # [4, 4]]) 25 | 26 | The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of bigger PyTorch programs. 27 | 28 | Thunder currently understands a subset of PyTorch operations and a subset of Python, but support is growing quickly. If an operator you need is not yet supported, open an issue on the Thunder repo to get help. 29 | -------------------------------------------------------------------------------- /docs/source/fundamentals/installation.rst: -------------------------------------------------------------------------------- 1 | Install Lightning Thunder 2 | ######################### 3 | 4 | Minimal dependencies 5 | ==================== 6 | 7 | Follow these instructions to install PyTorch, nvFuser, and finally Thunder. 8 | 9 | Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.5.x):: 10 | 11 | pip install --pre nvfuser-cu121-torch25 12 | 13 | ``cu121`` can be replaced with ``cu118`` depending on your CUDA version. nvFuser builds typically support the latest point release of PyTorch stable versions. 14 | For torch 2.5, ``cu124`` is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation 15 | 16 | You're all set with minimal dependencies, so you can follow `Install Thunder`_. 17 | 18 | Dependencies so far don't include additional optimized kernels for execution like OpenAI Triton, cuDNN Fusion, or Apex. 19 | These are described in the following section, `Optional dependencies`_. 20 | 21 | Optional dependencies 22 | ===================== 23 | 24 | Install Apex 25 | ------------ 26 | 27 | Thunder can use NVIDIA's Apex to accelerate some PyTorch operations.
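In particular, Apex provides an optimized cross-entropy kernel, which backs Thunder's Apex cross-entropy executor.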
To install the Apex executor, first clone the Apex git repository and then execute the following command in the project's root directory:: 28 | 29 | git clone https://github.com/NVIDIA/apex.git 30 | cd apex 31 | 32 | pip install -v --no-cache-dir --no-build-isolation --config-settings "--build-option=--xentropy" ./ 33 | 34 | Install cuDNN 35 | ------------- 36 | 37 | Thunder can use NVIDIA's cuDNN Python frontend bindings to accelerate some PyTorch operations. The cuDNN backend is a separate Python package, which needs to be installed separately:: 38 | 39 | pip install nvidia-cudnn-cu12 40 | pip install nvidia-cudnn-frontend 41 | 42 | You're all set; now follow `Install Thunder`_. 43 | 44 | Install OpenAI Triton 45 | --------------------- 46 | 47 | Thunder can easily integrate OpenAI Triton kernels. You can install Triton using:: 48 | 49 | pip install triton 50 | 51 | 52 | Install Thunder 53 | =============== 54 | 55 | You can now install Thunder:: 56 | 57 | pip install git+https://github.com/Lightning-AI/lightning-thunder.git 58 | 59 | Alternatively, you can clone the Thunder repository and install locally:: 60 | 61 | git clone https://github.com/Lightning-AI/lightning-thunder.git 62 | cd lightning-thunder 63 | 64 | pip install . 65 | -------------------------------------------------------------------------------- /docs/source/intermediate/additional_executors.rst: -------------------------------------------------------------------------------- 1 | Additional executors 2 | #################### 3 | 4 | nvFuser and PyTorch are not the only executors available in Thunder today. Additional executors can be added to Thunder prior to compilation through a registration mechanism, which makes it easy to have specialized executors perform certain operations more efficiently. 5 | 6 | This section lists the executors supported by Thunder beyond nvFuser and PyTorch. 7 | 8 | Triton CrossEntropy Executor 9 | ============================ 10 | 11 | The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton).
It can be used like in the following example:: 12 | 13 | import torch 14 | import thunder 15 | from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex 16 | 17 | def xentropy(logits, labels, weight, reduction, ignore_index): 18 | return thunder.torch.cross_entropy( 19 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index 20 | ) 21 | 22 | jitted_xentropy = thunder.jit( 23 | xentropy, 24 | executors=[triton_cross_entropy_ex,] 25 | ) 26 | 27 | device = 'cuda' 28 | dtype = torch.float32 29 | 30 | logits = torch.randn([2048, 50257], device=device, dtype=dtype) 31 | labels = torch.randint(0, 50257, [2048], device=device) 32 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False) 33 | reduction = "sum" 34 | ignore_index = labels[5].item() 35 | 36 | jitted_xentropy(logits, labels, weight, reduction, ignore_index) 37 | traces = thunder.last_traces(jitted_xentropy) 38 | print(traces[-1]) 39 | 40 | This prints:: 41 | 42 | # Constructed by Delete Last Used (took 0 milliseconds) 43 | import torch 44 | from thunder.executors.torchex import no_autocast 45 | 46 | @torch.no_grad() 47 | @no_autocast 48 | def computation(logits, labels, weight): 49 | # logits: "cuda:0 f32[2048, 50257]" 50 | # labels: "cuda:0 i64[2048]" 51 | # weight: "cuda:0 f32[50257]" 52 | t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]" 53 | del logits, labels, weight 54 | return t23 55 | 56 | As shown in the above trace, ``triton_crossentropy()`` is the one running the operation. 57 | 58 | Apex CrossEntropy Executor 59 | ========================== 60 | 61 | The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this:: 62 | 63 | import torch 64 | import thunder 65 | from thunder.executors.apexex import apex_ex 66 | 67 | def xentropy(logits, labels): 68 | return thunder.torch.cross_entropy( 69 | logits, labels, reduction='mean', ignore_index=-1 70 | ) 71 | 72 | jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,]) 73 | 74 | device = 'cuda' 75 | dtype = torch.float32 76 | 77 | logits = torch.randn([2048, 50257], device=device, dtype=dtype) 78 | labels = torch.randint(0, 50257, [2048], device=device) 79 | 80 | jitted_xentropy(logits, labels) 81 | traces = thunder.last_traces(jitted_xentropy) 82 | print(traces[-1]) 83 | 84 | This prints:: 85 | 86 | # Constructed by Delete Last Used (took 0 milliseconds) 87 | import torch 88 | from thunder.executors.torchex import no_autocast 89 | 90 | @torch.no_grad() 91 | @no_autocast 92 | def computation(logits, labels): 93 | # logits: "cuda:0 f32[2048, 50257]" 94 | # labels: "cuda:0 i64[2048]" 95 | (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0) 96 | del logits, labels 97 | return t18 98 | 99 | showing that Apex is running the operation. 100 | 101 | cuDNN SDPA Executor 102 | =================== 103 | 104 | TODO RC1 105 | 106 | TransformerEngine Executor 107 | ========================== 108 | 109 | TODO RC1 110 | -------------------------------------------------------------------------------- /docs/source/intermediate/ddp.rst: -------------------------------------------------------------------------------- 1 | Distributed Data Parallel (DDP) 2 | ############################### 3 | 4 | Thunder has its own Distributed Data Parallel (DDP) transform that we recommend using, although compiled modules also work with PyTorch's DDP transform. 
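Like standard DDP, the transform replicates the module across processes and keeps gradients synchronized by inserting all-reduce operations into the backward trace.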
5 | 6 | You can wrap a model in Thunder's ddp like this:: 7 | 8 | from thunder.distributed import ddp 9 | 10 | model = MyModel() 11 | ddp_model = ddp(model) 12 | jitted_model = thunder.jit(ddp_model) 13 | 14 | Specifying which rank to broadcast from is optional. ``ddp()`` will broadcast from the lowest rank in the process group if ``broadcast_from`` is not specified. 15 | 16 | Thunder's ddp is compatible with PyTorch distributed runners like ``torchrun`` (https://pytorch.org/docs/stable/elastic/run.html). 17 | 18 | When using PyTorch's DDP, call DDP on the jitted module:: 19 | 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | 22 | model = MyModel() 23 | jitted_model = thunder.jit(model) 24 | ddp_model = DDP(jitted_model) 25 | 26 | The ability of Thunder to express distributed algorithms like DDP as a simple transform on the trace is one of Thunder's strengths and is being leveraged to quickly implement more elaborate distributed strategies, like Fully Sharded Data Parallel (FSDP). 27 | -------------------------------------------------------------------------------- /docs/source/intermediate/whats_next.rst: -------------------------------------------------------------------------------- 1 | What's Next 2 | ########### 3 | 4 | Thunder is developing rapidly, and this section mentions some of what's happening. Please reach out (see Get Involved) if you're interested in one of these topics. 5 | 6 | Compiling the Training Loop 7 | =========================== 8 | 9 | Thunder currently supports compiling PyTorch modules (forward computation, loss calculation, and backward computation), but we plan to support compiling the entire training loop (forward computation, loss calculation, backward computation, and the optimizer step) for maximum performance. 10 | 11 | Dynamic Caching 12 | =============== 13 | 14 | Thunder currently supports either no caching or static caching, and static caching requires recompiling whenever a module is called with inputs whose metadata differs from that of past inputs. This can be overly strict. For example, adding two tensors with shape ``(5, 5)`` is essentially the same as adding two tensors with shape ``(10, 10)``. Dynamic caching will determine if the new metadata would result in a new trace or not, significantly reducing compilation time when training some models. 15 | 16 | Memory Layouts and Strides 17 | ========================== 18 | 19 | Thunder does not currently model any stride information on tensor proxies. In the future we will likely model some stride information, like memory layout (e.g. channels-last), to support integration with PyTorch programs that use memory layout, and to let executors use memory layout to inform kernel selection. 20 | 21 | Functional transforms: vmap and AMP 22 | =================================== 23 | 24 | Thunder already has early implementations of JAX's vmap transform and PyTorch's Automatic Mixed Precision (AMP) autocasting, and we're extending our support for these transforms so practitioners can easily apply a variety of composable transforms to PyTorch modules. 25 | -------------------------------------------------------------------------------- /docs/source/reference/clang/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.clang 2 | 3 | thunder.clang 4 | ============= 5 | 6 | Thunder Core Language 7 | 8 | 9 | ..
autosummary:: 10 | :toctree: generated/ 11 | 12 | maybe_convert_to_dtype 13 | device_put 14 | arange 15 | convolution 16 | full 17 | full_like 18 | uniform 19 | uniform_like 20 | diagonal 21 | expand 22 | flatten 23 | movedim 24 | reshape 25 | slice_in_dim 26 | squeeze 27 | transpose 28 | stride_order 29 | take 30 | index_add 31 | take_along_axis 32 | scatter_add 33 | unsqueeze 34 | cat 35 | stack 36 | compute_broadcast_shape 37 | matrix_transpose 38 | maybe_broadcast 39 | 40 | Unary 41 | ~~~~~ 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | 46 | abs 47 | acos 48 | acosh 49 | asin 50 | asinh 51 | atan 52 | atanh 53 | bitwise_not 54 | ceil 55 | cos 56 | cosh 57 | erf 58 | erfc 59 | erfcinv 60 | erfinv 61 | exp 62 | exp2 63 | expm1 64 | floor 65 | isfinite 66 | lgamma 67 | log 68 | log10 69 | log1p 70 | log2 71 | ndtri 72 | neg 73 | reciprocal 74 | round 75 | rsqrt 76 | sigmoid 77 | sign 78 | signbit 79 | silu 80 | sin 81 | sinh 82 | sqrt 83 | tan 84 | tanh 85 | trunc 86 | 87 | Binary 88 | ~~~~~~ 89 | 90 | .. autosummary:: 91 | :toctree: generated/ 92 | 93 | add 94 | atan2 95 | bitwise_and 96 | bitwise_or 97 | bitwise_xor 98 | copysign 99 | eq 100 | floor_divide 101 | fmod 102 | mod 103 | ge 104 | gt 105 | logical_and 106 | le 107 | lt 108 | mul 109 | ne 110 | nextafter 111 | pow 112 | remainder 113 | sub 114 | true_divide 115 | 116 | Conditional 117 | ~~~~~~~~~~~ 118 | 119 | .. autosummary:: 120 | :toctree: generated/ 121 | 122 | where 123 | -------------------------------------------------------------------------------- /docs/source/reference/common/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.common 2 | 3 | thunder.common 4 | ============== 5 | 6 | Common functions and classes for Thunder. 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | CACHE_OPTIONS 12 | CompileData 13 | CompileStats 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/baseutils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.baseutils 2 | :noindex: 3 | 4 | Baseutils 5 | --------- 6 | 7 | .. currentmodule:: thunder.core.baseutils 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | 12 | check 13 | check_type 14 | ProxyInterface 15 | NumberProxyInterface 16 | TensorProxyInterface 17 | SymbolInterface 18 | BoundSymbolInterface 19 | -------------------------------------------------------------------------------- /docs/source/reference/core/codeutils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.codeutils 2 | 3 | Codeutils 4 | --------- 5 | 6 | .. currentmodule:: thunder.core.codeutils 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | prettyprint 12 | SigInfo 13 | get_siginfo 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/devices.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.devices 2 | 3 | Devices 4 | ------- 5 | 6 | .. currentmodule:: thunder.core.devices 7 | 8 | .. 
autosummary:: 9 | :toctree: generated/ 10 | 11 | DeviceType 12 | devicetype_string 13 | Device 14 | device_from_string 15 | to_device 16 | -------------------------------------------------------------------------------- /docs/source/reference/core/dtypes.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.dtypes 2 | 3 | Dtypes 4 | ------ 5 | 6 | ``dtype``\s such as ``float32`` are exposed to :mod:`thunder` like ``thunder.float32``. 7 | 8 | .. currentmodule:: thunder.core.dtypes 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | :nosignatures: 13 | 14 | dtype 15 | to_dtype 16 | float32 17 | float16 18 | bfloat16 19 | float64 20 | int8 21 | int16 22 | int32 23 | int64 24 | uint8 25 | bool8 26 | complex32 27 | complex64 28 | complex128 29 | -------------------------------------------------------------------------------- /docs/source/reference/core/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core 2 | 3 | thunder.core 4 | ============ 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | baseutils 10 | codeutils 11 | devices 12 | dtypes 13 | langctxs 14 | prims 15 | proxies 16 | pytree 17 | rematerialization 18 | symbol 19 | trace 20 | transforms 21 | -------------------------------------------------------------------------------- /docs/source/reference/core/langctxs.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.langctxs 2 | 3 | langctxs 4 | -------- 5 | 6 | .. currentmodule:: thunder.core.langctxs 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | set_langctx 12 | get_langctx 13 | reset_langctx 14 | langctx 15 | -------------------------------------------------------------------------------- /docs/source/reference/core/prims.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.prims 2 | 3 | .. todo Add other prims after resolving "Cannot resolve forward reference in type annotations of "thunder.core.prims.all_reduce": name 'Callable' is not defined" 4 | 5 | Prims 6 | ----- 7 | 8 | .. currentmodule:: thunder.core.prims 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | PrimIDs 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/proxies.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.proxies 2 | 3 | Proxies 4 | ------- 5 | 6 | .. currentmodule:: thunder.core.proxies 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Variable 12 | variableify 13 | unvariableify 14 | CollectionProxy 15 | pyval 16 | pytype 17 | ComplexProxy 18 | IntegerProxy 19 | FloatProxy 20 | FutureTensorProxy 21 | TensorProxy 22 | is_proxyable 23 | proxy 24 | -------------------------------------------------------------------------------- /docs/source/reference/core/pytree.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.pytree 2 | 3 | PyTree 4 | ------ 5 | 6 | .. 
list-table:: pytree API map 7 | :widths: 50 50 8 | :header-rows: 1 9 | 10 | * - thunder.core.pytree 11 | - PyTorch pytree 12 | * - :func:`tree_map` 13 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map 14 | * - :func:`tree_flatten` 15 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_flatten 16 | * - :func:`tree_unflatten` 17 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_unflatten 18 | -------------------------------------------------------------------------------- /docs/source/reference/core/rematerialization.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.rematerialization 2 | 3 | Rematerialization 4 | ----------------- 5 | 6 | .. currentmodule:: thunder.core.rematerialization 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | rematerialize 12 | -------------------------------------------------------------------------------- /docs/source/reference/core/symbol.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.symbol 2 | 3 | Symbol 4 | ------ 5 | 6 | .. currentmodule:: thunder.core.symbol 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Symbol 12 | BoundSymbol 13 | BoundSymbolRHS 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/trace.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.trace 2 | 3 | Trace 4 | ----- 5 | 6 | .. currentmodule:: thunder.core.trace 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | TraceCtx 12 | from_trace 13 | set_tracectx 14 | get_tracectx 15 | reset_tracectx 16 | -------------------------------------------------------------------------------- /docs/source/reference/core/transforms.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.transforms 2 | 3 | Transforms 4 | ---------- 5 | 6 | .. currentmodule:: thunder.core.transforms 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Node 12 | bsym_list_to_dag 13 | toposort_bsym_dag 14 | insert_inplace 15 | replace_inplace 16 | VISIT_TYPE 17 | visitor_transform 18 | -------------------------------------------------------------------------------- /docs/source/reference/distributed/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.distributed 2 | 3 | thunder.distributed 4 | =================== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | ddp 10 | fsdp 11 | FSDPType 12 | FSDPBucketingStrategy 13 | set_skip_data_parallel_grad_sync 14 | reset_skip_data_parallel_grad_sync 15 | get_skip_data_parallel_grad_sync 16 | skip_data_parallel_grad_sync 17 | column_parallel 18 | row_parallel 19 | -------------------------------------------------------------------------------- /docs/source/reference/dynamo/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.dynamo 2 | 3 | thunder.dynamo 4 | ============== 5 | 6 | .. autosummary:: 7 | :toctree: 8 | 9 | ThunderCompiler 10 | thunderfx 11 | ThunderFXCompiledObject 12 | -------------------------------------------------------------------------------- /docs/source/reference/examine/index.rst: -------------------------------------------------------------------------------- 1 | ..
module:: thunder.examine 2 | 3 | thunder.examine 4 | =============== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | examine 10 | -------------------------------------------------------------------------------- /docs/source/reference/executors/apexex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.apexex 2 | 3 | APEX Executor 4 | ------------- 5 | 6 | Executor of `NVIDIA/apex <https://github.com/NVIDIA/apex>`_. 7 | 8 | .. currentmodule:: thunder.executors.apexex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | apex_entropy_available 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/cudnnex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.cudnnex 2 | 3 | cuDNN Executor 4 | -------------- 5 | 6 | Executor of `NVIDIA/cudnn-frontend <https://github.com/NVIDIA/cudnn-frontend>`_. 7 | 8 | .. currentmodule:: thunder.executors.cudnnex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | cudnn_ex 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors 2 | 3 | thunder.executors 4 | ================= 5 | 6 | .. currentmodule:: thunder.executors 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | get_nvfuser_executor 12 | get_torch_executor 13 | nvfuser_available 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | 18 | apexex 19 | cudnnex 20 | nvfuserex 21 | passes 22 | pythonex 23 | torch_compile 24 | torchex 25 | triton_crossentropy 26 | utils 27 | -------------------------------------------------------------------------------- /docs/source/reference/executors/nvfuserex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.nvfuserex 2 | 3 | nvFuser Executor 4 | ---------------- 5 | 6 | An executor that dispatches operators to `NVIDIA/Fuser <https://github.com/NVIDIA/Fuser>`_. 7 | 8 | .. currentmodule:: thunder.executors.nvfuserex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | .. fixme(crcrpar) add methods and classes 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/passes.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.passes 2 | 3 | Optimization Passes 4 | ------------------- 5 | 6 | .. currentmodule:: thunder.executors.passes 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | transform_for_execution 12 | update_fusion_call_ctx 13 | del_last_used 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/pythonex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.pythonex 2 | 3 | 4 | Python Executor 5 | --------------- 6 | 7 | .. currentmodule:: thunder.executors.pythonex 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | -------------------------------------------------------------------------------- /docs/source/reference/executors/torch_compile.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.torch_compile 2 | 3 | Torch Compile Executor 4 | ---------------------- 5 | 6 | .. currentmodule:: thunder.executors.torch_compile 7 | 8 | ..
autosummary:: 9 | :toctree: generated/ 10 | 11 | make_compiled 12 | TorchCompileExecutor 13 | -------------------------------------------------------------------------------- /docs/source/reference/executors/torchex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.torchex 2 | 3 | 4 | PyTorch Executor 5 | ---------------- 6 | 7 | .. currentmodule:: thunder.executors.torch_autograd 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | -------------------------------------------------------------------------------- /docs/source/reference/executors/triton_crossentropy.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.triton_crossentropy 2 | 3 | Triton Executor 4 | --------------- 5 | 6 | Executor of `openai/triton <https://github.com/openai/triton>`_. 7 | 8 | .. currentmodule:: thunder.executors.triton_crossentropy 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | -------------------------------------------------------------------------------- /docs/source/reference/executors/utils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.utils 2 | 3 | Utils 4 | ----- 5 | 6 | .. currentmodule:: thunder.executors.utils 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Region 12 | -------------------------------------------------------------------------------- /docs/source/reference/extend/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.extend 2 | 3 | thunder.extend 4 | ============== 5 | 6 | .. autosummary:: 7 | :toctree: 8 | 9 | register_executor 10 | deregister_executor 11 | get_all_executors 12 | get_default_executors 13 | get_always_executors 14 | get_executor 15 | set_default_executors 16 | set_always_executors 17 | add_default_executor 18 | add_always_executor 19 | remove_default_executor 20 | remove_always_executor 21 | 22 | Executor 23 | -------------------------------------------------------------------------------- /docs/source/reference/plugins/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.plugins 2 | 3 | thunder.plugins 4 | ================== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | DDP 10 | FSDP 11 | QuantizeInt4 12 | FP8 13 | ReduceOverhead 14 | -------------------------------------------------------------------------------- /docs/source/reference/recipes/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.recipes 2 | 3 | thunder.recipes 4 | ================== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | BaseRecipe 10 | HFTransformers 11 | -------------------------------------------------------------------------------- /docs/source/reference/thunder.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder 2 | 3 | thunder 4 | ======= 5 | 6 | 7 | Compiling functions and modules 8 | ------------------------------- 9 | 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | jit 15 | 16 | 17 | Querying information on compiled functions and modules 18 | ------------------------------------------------------ 19 | 20 | 21 | ..
autosummary:: 22 | :toctree: generated/ 23 | 24 | DebugOptions 25 | compile_data 26 | compile_stats 27 | last_traces 28 | last_backward_traces 29 | last_prologue_traces 30 | cache_option 31 | cache_hits 32 | cache_misses 33 | list_transforms 34 | last_interpreted_instructions 35 | last_interpreter_log 36 | last_compile_options 37 | .. 38 | compile 39 | grad 40 | 41 | JITed Model wrapper 42 | ------------------- 43 | 44 | .. autoclass:: ThunderModule 45 | :members: no_sync 46 | :exclude-members: forward 47 | -------------------------------------------------------------------------------- /docs/source/reference/torch/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.torch 2 | 3 | thunder.torch 4 | ------------- 5 | 6 | A PyTorch dialect in thunder. 7 | 8 | .. currentmodule:: thunder.torch 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | torchsymbol 14 | size 15 | to_float 16 | _register_custom_op 17 | 18 | Under Construction 19 | 20 | .. fixme "Cannot resolve forward reference in type annotations of "thunder.torch.abs": name 'Callable' is not defined" 21 | 22 | Operators 23 | ~~~~~~~~~ 24 | 25 | Unary 26 | ~~~~~ 27 | 28 | Binary 29 | ~~~~~~ 30 | 31 | Conditional 32 | ~~~~~~~~~~~ 33 | 34 | Tensor Creation 35 | ~~~~~~~~~~~~~~~ 36 | 37 | Shape Operation 38 | ~~~~~~~~~~~~~~~ 39 | -------------------------------------------------------------------------------- /docs/source/reference/transforms/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.transforms 2 | 3 | thunder.transforms 4 | ================== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | MaterializationTransform 10 | ConstantFolding 11 | -------------------------------------------------------------------------------- /examples/coverage/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.52.4 2 | accelerate 3 | nvfuser-cu128-torch27 4 | nvidia-cudnn-frontend 5 | einops 6 | tiktoken 7 | -------------------------------------------------------------------------------- /examples/quickstart/hf_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | # model_name = "bert-large-uncased" 11 | model_name = "bert-base-uncased" 12 | 13 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 14 | 15 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 16 | 17 | with torch.device(device): 18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 19 | model.requires_grad_(False) 20 | model.eval() 21 | # apparently, Transformers 4.51.3 does not instantiate models on the default device 22 | model.to(device) 23 | 24 | inp = tokenizer(["Hello world!"], return_tensors="pt") 25 | 26 | print(f"Eager: {benchmark(model, **inp):.2f}ms") 27 | 28 | thunder_model = thunder.compile( 29 | model, 30 | recipe="hf-transformers", 31 | ) 32 | 33 | print(f"Thunder: {benchmark(thunder_model, **inp):.2f}ms") 34 | 35 | if torch.cuda.is_available(): 36 | thunder_model = thunder.compile(model, plugins="reduce-overhead") 37 | 38 | print(f"Thunder with 'reduce-overhead': {benchmark(thunder_model, **inp):.2f}ms") 39 | 40 | 41 | if __name__ == "__main__": 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | 44 | 
main() 45 | -------------------------------------------------------------------------------- /examples/quickstart/hf_llm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark_n 7 | 8 | 9 | def main(): 10 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" 11 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" 12 | # model_name = "meta-llama/Llama-3.1-8B" 13 | model_name = "meta-llama/Llama-3.2-1B" 14 | # model_name = "Qwen/Qwen2.5-7B-Instruct" 15 | # model_name = "microsoft/phi-4" 16 | 17 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 18 | 19 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 20 | 21 | with torch.device(device): 22 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 23 | model.requires_grad_(False) 24 | model.eval() 25 | # apparently, Transformers 4.51.3 does not instantiate models on the default device 26 | model.to(device) 27 | 28 | inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt") 29 | 30 | def generate(model, inp, cache=None): 31 | out = model.generate(**inp, do_sample=False, cache_implementation=cache, max_new_tokens=100) 32 | print(tokenizer.decode(out[0].tolist())) 33 | 34 | print("\nGenerating with PyTorch eager:") 35 | eager_time = benchmark_n(2, generate, model, inp) 36 | 37 | thunder_model = thunder.compile( 38 | model, 39 | recipe="hf-transformers", 40 | ) 41 | 42 | print("\nGenerating with Thunder:") 43 | thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static") 44 | 45 | print(f"\nEager: {eager_time:.2f}ms") 46 | print(f"Thunder: {thunder_time:.2f}ms") 47 | 48 | 49 | if __name__ == "__main__": 50 | torch.backends.cuda.matmul.allow_tf32 = True 51 | 52 | main() 53 | -------------------------------------------------------------------------------- /examples/quickstart/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import thunder 5 | 6 | 7 | def main(): 8 | model = nn.Sequential( 9 | nn.Linear(2048, 4096), 10 | nn.ReLU(), 11 | nn.Linear(4096, 64) 12 | ) 13 | 14 | thunder_model = thunder.compile(model) 15 | x = torch.randn(64, 2048) 16 | y = thunder_model(x) 17 | 18 | print(thunder_model) 19 | 20 | print(thunder.last_traces(thunder_model)[-1]) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /examples/quickstart/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.55.4 2 | accelerate 3 | nvfuser-cu128-torch27 4 | nvidia-cudnn-frontend 5 | -------------------------------------------------------------------------------- /examples/quickstart/vit_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ViTForImageClassification 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 11 | 12 | with torch.device(device): 13 | model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float32) 14 | model.requires_grad_(False) 15 | model.eval() 16 | # apparently, Transformers 4.51.3 does 
not instantiate models on the default device 17 | model.to(device) 18 | 19 | inp = torch.randn(128, 3, 224, 224) 20 | 21 | out = model(inp) 22 | 23 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None) 24 | 25 | thunder_out = thunder_model(inp) 26 | 27 | torch.testing.assert_close(out, thunder_out, atol=1e-2, rtol=1e-2) 28 | 29 | print(f"Eager: {benchmark(model, inp):.2f}ms") 30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms") 31 | 32 | 33 | if __name__ == "__main__": 34 | torch.set_float32_matmul_precision('high') 35 | 36 | main() 37 | -------------------------------------------------------------------------------- /examples/quickstart/vit_tv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 11 | 12 | with torch.device(device): 13 | model = models.vit_b_16() 14 | model.requires_grad_(False) 15 | model.eval() 16 | 17 | inp = torch.randn(128, 3, 224, 224) 18 | 19 | out = model(inp) 20 | 21 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None) 22 | 23 | thunder_out = thunder_model(inp) 24 | 25 | # print(thunder.last_traces(thunder_model)[-1]) 26 | 27 | torch.testing.assert_close(out, thunder_out) 28 | 29 | print(f"Eager: {benchmark(model, inp):.2f}ms") 30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms") 31 | 32 | 33 | if __name__ == "__main__": 34 | torch.set_float32_matmul_precision('high') 35 | 36 | main() 37 | -------------------------------------------------------------------------------- /notebooks/.ignore.ci: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/notebooks/.ignore.ci -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/base.txt 2 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | torch >=2.7.1 2 | looseversion ==1.3.0 3 | lightning-utilities >0.7.0 4 | numpy 5 | networkx >=3.3 6 | optree >=0.12.1 7 | opt_einsum >=3.3.0 8 | # todo: temporary pin for `NameError: name '_C' is not defined` 9 | mpmath <1.4.0 10 | # Support for 3.12 11 | dill >=0.3.8 12 | -------------------------------------------------------------------------------- /requirements/coverage.txt: -------------------------------------------------------------------------------- 1 | coverage ~=7.9.1 2 | pytest ==8.4.2 3 | pytest-cov ==6.2.1 4 | pytest-benchmark ==5.1.0 5 | transformers ==4.52.4 6 | lightning_sdk 7 | diffusers==0.35.1 8 | accelerate 9 | bitsandbytes==0.48.0 10 | -------------------------------------------------------------------------------- /requirements/devel.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | -r base.txt 3 | -r test.txt 4 | 5 | datasets>=3.5.0 6 | peft>=0.15.2 7 | 8 | torchvision 9 | torchaudio 10 | -------------------------------------------------------------------------------- /requirements/docs.txt:
-------------------------------------------------------------------------------- 1 | sphinx ==5.3.0 2 | myst-parser ==1.0.0 3 | nbsphinx ~=0.9.7 4 | ipython[all] ~=8.37.0 5 | pandoc ==2.4 6 | docutils >=0.16 7 | sphinxcontrib-fulltoc @ git+https://github.com/t-vi/sphinx-contrib-fulltoc@master 8 | sphinxcontrib-mockautodoc 9 | 10 | sphinx-autodoc-typehints ==1.23.0 11 | sphinx-paramlinks ==0.6.0 12 | sphinx-togglebutton ==0.3.2 13 | sphinx-copybutton ==0.5.2 14 | snowballstemmer < 4 15 | 16 | # installed from S3 location and fetched in advance 17 | lai-sphinx-theme 18 | # alternative /back-up (old) theme 19 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip 20 | -------------------------------------------------------------------------------- /requirements/notebooks.txt: -------------------------------------------------------------------------------- 1 | # just reuse the base requirements 2 | -r base.txt 3 | 4 | ipython[all] ~=8.37.0 5 | numpy 6 | liger-kernel == 0.4.0 7 | cuda-python >=12.6.1, <14.0.0 8 | litgpt == 0.5.1 9 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | coverage ~=7.9.1 2 | pytest ==8.4.2 3 | pytest-benchmark ==5.1.0 4 | pytest-timeout ==2.4.0 5 | pytest-cov ==6.2.1 6 | pytest-xdist ==3.8.0 7 | pytest-random-order ==1.2.0 8 | pytest-timestamper ==0.0.10 9 | graphviz ==0.21 10 | fdm ==0.5.0 11 | expecttest ==0.3.0 # for test_ddp.py 12 | hypothesis ~=6.136.6 # for test_ddp.py 13 | numpy 14 | einops # for test_einops.py 15 | litgpt==0.5.0 # for the model definition in tests and benchmarks # todo: need update to latest 16 | absl-py # thunder/benchmarks/test_benchmark_litgpt.py 17 | pandas # thunder/benchmarks/test_benchmark_litgpt.py 18 | xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py 19 | jsonargparse # thunder/benchmarks/benchmark_litgpt.py 20 | bitsandbytes==0.48.0; 'arm' not in platform_machine and 'aarch' not in platform_machine 21 | bitsandbytes>=0.42,<0.43; 'arm' in platform_machine or 'aarch' in platform_machine 22 | transformers==4.52.4 # for test_networks.py 23 | diffusers==0.35.1 # for test_networks.py 24 | accelerate # for test_networks.py 25 | 26 | asvdb @ git+https://github.com/rapidsai/asvdb.git 27 | asv >=0.6.4 28 | -------------------------------------------------------------------------------- /scripts/bisect_nvfuser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | 6 | class Runner: 7 | def __init__(self, command_and_args: list[str], verbose: bool): 8 | self._command_and_args = command_and_args 9 | self._verbose = verbose 10 | 11 | def env_var_for_fuel(self) -> str: 12 | return "NVFUSER_OPTIMIZATION_FUEL" 13 | 14 | def run(self, fuel: int) -> int: 15 | os.environ[self.env_var_for_fuel()] = str(fuel) 16 | print(f"Running {self._command_and_args} with {fuel} units of fuel...") 17 | result = subprocess.run( 18 | self._command_and_args, 19 | check=False, 20 | stdout=(None if self._verbose else subprocess.DEVNULL), 21 | stderr=(None if self._verbose else subprocess.DEVNULL), 22 | ) 23 | print(f"Command returned {result.returncode} with {fuel} units of fuel.") 24 | return result.returncode 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser( 29 | description=""" 30 | A script that bisects nvFuser's optimization fuel to isolate a 
compiler bug. 31 | 32 | See `thunder.extend.FusionExecutor.get_fuel/set_fuel` for what optimization fuel is and how it is implemented. 33 | 34 | This script finds the minimum unit of optimization fuel between `lower_bound` and `upper_bound` that causes `command_and_args` to fail (i.e. exit with non-zero). The user will then run `NVFUSER_OPTIMIZATION_FUEL=<fuel> <command_and_args>` as a minimal reproducer to identify the nvFusion that triggers the failure. Likely, the last nvFusion generated is the culprit. 35 | """ 36 | ) 37 | parser.add_argument("lower_bound", type=int, help="the lower bound for bisecting") 38 | parser.add_argument("upper_bound", type=int, help="the upper bound for bisecting") 39 | parser.add_argument("command_and_args", type=str, nargs=argparse.REMAINDER, help="the command and args to run") 40 | parser.add_argument("--verbose", action="store_true", help="whether to show stdout/stderr of the subprocess") 41 | args = parser.parse_args() 42 | 43 | runner = Runner(args.command_and_args, args.verbose) 44 | 45 | # The corner cases are not needed for the correctness of the binary search. They are for catching errors earlier when the user specifies a wrong lower/upper bound. 46 | if (exitcode := runner.run(args.lower_bound)) != 0: 47 | print("No need to bisect. Command failed with the lower bound.") 48 | exit(0) 49 | 50 | if (exitcode := runner.run(args.upper_bound)) == 0: 51 | print("Bisecting failed. Command passed with the upper bound.") 52 | exit(1) 53 | 54 | # Find the smallest fuel that fails `command_and_args`. 55 | low = args.lower_bound + 1 # +1 because we know `lower_bound` passed. 56 | high = args.upper_bound 57 | while low < high: 58 | mid = (low + high) // 2 59 | exitcode = runner.run(mid) 60 | if exitcode == 0: 61 | low = mid + 1 62 | else: 63 | high = mid 64 | assert low == high 65 | 66 | print("Bisecting succeeded. Run the following command as a minimal reproducer:") 67 | print(f" {runner.env_var_for_fuel()}={low} {' '.join(args.command_and_args)}") 68 | print("The last nvFusion likely triggered the failure.") 69 | -------------------------------------------------------------------------------- /scripts/remove-torch-lines.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get the name of the file to process 4 | filename=$1 5 | 6 | # Create a temporary file to store the filtered lines 7 | tempfile=$(mktemp) 8 | 9 | # Loop through the original file and remove lines containing the word "torch" 10 | while read -r line; do 11 | if [[ !
"$line" =~ "torch" ]]; then 12 | echo "$line" >> "$tempfile" 13 | fi 14 | done < "$filename" 15 | 16 | # Move the temporary file to the original file 17 | mv "$tempfile" "$filename" 18 | -------------------------------------------------------------------------------- /scripts/sanity-check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | pip list 5 | python -c "import torch ; assert torch.cuda.is_available(), 'missing GPU support!'" 6 | python -c "import torch ; v = torch.__version__ ; assert str(v).startswith('2'), v" 7 | python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'" 8 | python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'" 9 | -------------------------------------------------------------------------------- /scripts/validate_build.py: -------------------------------------------------------------------------------- 1 | """Check that `scripts/build_from_source.sh` has produced a correct build.""" 2 | 3 | import os 4 | 5 | BUILD_DIR = os.path.abspath(os.path.join(os.path.split(__file__)[0], "build")) 6 | 7 | 8 | def main(): 9 | import torch 10 | 11 | expected = os.path.join(BUILD_DIR, "pytorch", "torch", "__init__.py") 12 | actual = os.path.abspath(torch.__file__) 13 | assert expected == actual, f"{expected} vs. {actual}" 14 | 15 | import nvfuser 16 | from looseversion import LooseVersion 17 | 18 | assert LooseVersion(nvfuser.version()) >= LooseVersion("0.0.6"), nvfuser.version() 19 | 20 | import thunder 21 | 22 | expected = os.path.abspath(os.path.join(BUILD_DIR, "..", "..", "thunder", "__init__.py")) 23 | assert expected == thunder.__file__, f"{expected} vs. {thunder.__file__}" 24 | 25 | print("Build looks good!") 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from importlib.util import module_from_spec, spec_from_file_location 4 | 5 | from packaging.requirements import Requirement as parse_requirements 6 | from setuptools import find_packages, setup 7 | 8 | 9 | _PATH_ROOT = os.path.dirname(__file__) 10 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements") 11 | # check if os env. 
variable is set to convert version to nightly 12 | _CONVERT_VERSION = int(os.environ.get("CONVERT_VERSION2NIGHTLY", 0)) 13 | 14 | 15 | def _load_py_module(fname, pkg="thunder"): 16 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname)) 17 | py = module_from_spec(spec) 18 | spec.loader.exec_module(py) 19 | return py 20 | 21 | 22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list[str]: 23 | """Parse requirements file -> list of strings""" 24 | reqs: list[str] = [] 25 | with open(os.path.join(path_dir, file_name)) as f: 26 | for line in f: 27 | if line and not line.startswith("#"): 28 | reqs.append(line.strip()) 29 | # Filter out requirements referring to local paths or specific URLs (if any) 30 | reqs = [str(parse_requirements(r)) for r in reqs if "@" not in r and "://" not in r] 31 | return reqs 32 | 33 | 34 | def convert_version2nightly(about_file: str = "thunder/__about__.py") -> None: 35 | """Load the actual version and convert it to the nightly version.""" 36 | from datetime import datetime 37 | 38 | # load the about file 39 | with open(about_file) as fo: 40 | lines = fo.readlines() 41 | idx = None 42 | # find the line with version 43 | for i, ln in enumerate(lines): 44 | if ln.startswith("__version__"): 45 | idx = i 46 | break 47 | if idx is None: 48 | raise ValueError("The version is not found in the `__about__.py` file.") 49 | # parse the version from variable assignment 50 | version = lines[idx].split("=")[1].strip().strip('"') 51 | # parse X.Y.Z version and prune any suffix 52 | vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version) 53 | # create timestamp YYYYMMDD 54 | timestamp = datetime.now().strftime("%Y%m%d") 55 | version = f"{'.'.join(vers.groups())}.dev{timestamp}" 56 | # set the new version 57 | lines[idx] = f'__version__ = "{version}"\n' 58 | # dump updated lines 59 | with open(about_file, "w") as fo: 60 | fo.writelines(lines) 61 | 62 | 63 | def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: 64 | """Load readme as description.""" 65 | path_readme = os.path.join(path_dir, "README.md") 66 | with open(path_readme, encoding="utf-8") as fp: 67 | text = fp.read() 68 | # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png 69 | github_source_url = os.path.join(homepage, "raw", version) 70 | # replace relative repository paths with absolute links to the release 71 | # do not replace all "docs" as in the readme we replace some other sources with particular path to docs 72 | text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") 73 | return text 74 | 75 | 76 | if _CONVERT_VERSION: 77 | convert_version2nightly() 78 | 79 | about = _load_py_module("__about__.py") 80 | 81 | setup( 82 | version=about.__version__, 83 | packages=find_packages(exclude=["thunder/tests", "docs"]), 84 | long_description=_load_readme_description( 85 | path_dir=_PATH_ROOT, 86 | homepage=about.__homepage__, 87 | version=about.__version__, 88 | ), 89 | long_description_content_type="text/markdown", 90 | install_requires=_load_requirements(_PATH_REQUIRES, file_name="base.txt"), 91 | include_package_data=True, 92 | zip_safe=False, 93 | ) 94 | -------------------------------------------------------------------------------- /thunder/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.6.dev0" 2 | __author__ = "Lightning-AI et al" 3 |
__author_email__ = "community@lightning.ai" 4 | __copyright__ = f"2024, 2025 {__author__}" 5 | __homepage__ = "https://github.com/Lightning-AI/lightning-thunder" 6 | __docs__ = "Lightning Thunder project." 7 | 8 | 9 | __all__ = [ 10 | "__author__", 11 | "__author_email__", 12 | "__copyright__", 13 | "__docs__", 14 | "__homepage__", 15 | "__version__", 16 | ] 17 | -------------------------------------------------------------------------------- /thunder/clang/langctx.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 4 | from thunder.core.pytree import tree_flatten 5 | from thunder.core.proxies import TensorProxy, NumberProxy 6 | 7 | # 8 | # Creates and registers the core (clang) language context 9 | # 10 | # NOTE That this is done separately from the definition of thunder.clang operations, because the 11 | # language context must be available before those operations are defined 12 | 13 | _method_name_to_fn_map: dict[str, Callable] = {} 14 | 15 | 16 | # Creates and registers the core language language context 17 | class ClangCtx(LanguageContext): 18 | def __init__(self): 19 | super().__init__("core") 20 | 21 | def has_method(self, id: str) -> bool: 22 | return id in _method_name_to_fn_map 23 | 24 | def get_method(self, id: str, *args, **kwargs) -> Callable: 25 | # Note: concrete implementations should only raise AttributeError or 26 | # return None for "missing" methods as the proxies will 27 | # route __getattr__ to here and hasattr relies on __getattr__ 28 | # throwing AttributeError (only) when the attribute does 29 | # not exist. 30 | inps, _ = tree_flatten((args, kwargs)) 31 | 32 | has_proxy_input: bool = False 33 | for x in inps: 34 | if isinstance(x, TensorProxy) or isinstance(x, NumberProxy): 35 | has_proxy_input = True 36 | break 37 | 38 | if has_proxy_input: 39 | method: None | Callable = _method_name_to_fn_map.get(id, None) 40 | if method is None: 41 | raise AttributeError(f"The {self.name} language context has no method {id}") 42 | return method 43 | 44 | # has_proxy_input is False 45 | # Defers to the primitive language context when there are no proxy inputs 46 | # (the primitive language context handles operations on numbers) 47 | primsctx: LanguageContext = resolve_language(Languages.PRIMS) 48 | if not primsctx.has_method(id): 49 | raise AttributeError( 50 | f"Attempting to call method {id} in the core language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method" 51 | ) 52 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs) 53 | return prim_method 54 | 55 | 56 | clangctx = ClangCtx() 57 | register_langctx(Languages.CLANG, clangctx) 58 | 59 | 60 | # Registers a method with the core language context 61 | def register_method(method_name: str, method: Callable, /) -> None: 62 | _method_name_to_fn_map[method_name] = method 63 | -------------------------------------------------------------------------------- /thunder/clang/utils.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from collections.abc import Sequence 3 | from collections.abc import Callable 4 | 5 | from thunder.core import utils 6 | import thunder.core.dtypes as dtypes 7 | from thunder.core.symbol import Symbol 8 | 9 | from thunder.core.proxies import ( 10 | NumberProxy, 11 | TensorProxy, 12 | )
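# NOTE (editor's addition, conceptual summary only -- not original source): the
# helpers below implement the usual elementwise type-promotion pattern. For a
# unary op this is, schematically:
#   computation_dtype, result_dtype = utils.elementwise_type_promotion(a, ...)
#   out = convert(prim(convert(a, computation_dtype)), result_dtype)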
13 | 14 | 15 | def create_maybe_convert_to_dtype_with_prim(conversion_prim: Symbol): 16 | assert isinstance(conversion_prim, Symbol) 17 | 18 | def maybe_convert_to_dtype(a, dtype, *, enforce_safe_casting=False): 19 | """If a has the same dtype as the given dtype, returns a unmodified. 20 | 21 | Otherwise returns a converted to the given dtype. 22 | """ 23 | 24 | utils.check(utils.is_dtype(dtype), lambda: f"Unknown dtype {dtype}!") 25 | 26 | if isinstance(a, Sequence): 27 | return tuple(maybe_convert_to_dtype(x, dtype) for x in a) 28 | if isinstance(a, TensorProxy): 29 | # Translates numbertypes to dtypes 30 | if dtypes.is_numbertype(dtype): 31 | dtype = dtypes.numbertype_to_dtype(dtype) 32 | elif isinstance(a, (Number, NumberProxy)): 33 | # NOTE This allows conversions like (5, float32) -> 5., which is a little odd 34 | dtype = utils.dtype_to_numbertype(dtype) 35 | else: 36 | raise ValueError( 37 | f"Trying to convert the type of the data of an unknown object {a} of {type(a)} that is neither a tensor, number, or sequence!" 38 | ) 39 | 40 | if not utils.are_same_dtypes(a, dtype): 41 | if enforce_safe_casting: 42 | utils.check( 43 | utils.can_safe_cast_to(cast_from=utils.to_dtype(a), cast_to=dtype), 44 | lambda: f"Can't safe cast from a={a} with dtype {utils.to_dtype(a)} to {dtype}!", 45 | ) 46 | 47 | return conversion_prim(a, dtype) 48 | 49 | return a 50 | 51 | return maybe_convert_to_dtype 52 | 53 | 54 | # TODO Add supported dtypes 55 | def _elementwise_unary_wrapper( 56 | a, 57 | *, 58 | prim, 59 | type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 60 | dtype_conversion_fn: Callable[[TensorProxy | NumberProxy, dtypes.dtype], TensorProxy | NumberProxy], 61 | ): 62 | computation_dtype, result_dtype = utils.elementwise_type_promotion(a, type_promotion_kind=type_promotion_kind) 63 | 64 | a = dtype_conversion_fn(a, computation_dtype) 65 | result = prim(a) 66 | result = dtype_conversion_fn(result, result_dtype) 67 | 68 | return result 69 | -------------------------------------------------------------------------------- /thunder/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/core/__init__.py -------------------------------------------------------------------------------- /thunder/core/compile_data.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from contextvars import ContextVar 3 | from typing import Any 4 | 5 | from thunder.core.options import CACHE_OPTIONS 6 | 7 | # 8 | # Setting and querying the "compile_data_and_stats" context variable, which contains 9 | # a tuple of (CompileData, CompileStats) objects for the current trace. 10 | # 11 | 12 | _compile_data = ContextVar("compile_data", default=(None, None)) 13 | 14 | 15 | # 16 | # Setting, getting, and resetting the context variable 17 | # 18 | 19 | 20 | # NOTE This just acquires the compile data part of the context var's tuple 21 | def get_compile_data() -> None | Any: 22 | """Returns the current compile data.""" 23 | cd, cs = _compile_data.get() 24 | 25 | return cd 26 | 27 | 28 | def set_compile_data_and_stats(cd, cs, /): 29 | """Sets the current compile data. 30 | 31 | This is used to pass compile data to functions that are called during compilation.
32 | """ 33 | token = _compile_data.set((cd, cs)) 34 | return token 35 | 36 | 37 | def reset_compile_data_and_stats(token, /): 38 | """Resets the compile data.""" 39 | _compile_data.reset(token) 40 | 41 | 42 | @contextmanager 43 | def compile_data_and_stats(cd, cs, /): 44 | """Sets the current compile data for the duration of the context.""" 45 | token = set_compile_data_and_stats(cd, cs) 46 | try: 47 | yield 48 | finally: 49 | reset_compile_data_and_stats(token) 50 | 51 | 52 | # 53 | # Query helpers 54 | # 55 | 56 | 57 | def get_compile_option(option: str, description: str, /) -> None | Any: 58 | cd, cs = _compile_data.get() 59 | 60 | if cd is None or cs is None: 61 | return None 62 | 63 | # See NOTE Different categories of compile options in thunder/__init__.py 64 | cs.last_compile_reasons[option].append(description) 65 | return cd.compile_options.get(option, None) 66 | 67 | 68 | # Whether or not the caching option uses symbolic values 69 | def get_cache_option() -> CACHE_OPTIONS: 70 | cd = get_compile_data() 71 | if cd is None: 72 | return CACHE_OPTIONS.CONSTANT_VALUES 73 | return cd.cache_option 74 | 75 | 76 | # TODO RC1 Remove the try (hack for when operating outside of this contextvar being set) 77 | def using_symbolic_values() -> bool: 78 | try: 79 | return get_cache_option() is CACHE_OPTIONS.SYMBOLIC_VALUES 80 | except: 81 | return False 82 | 83 | 84 | def using_jit() -> bool: 85 | try: 86 | cd, cs = _compile_data.get() 87 | return cd.using_jit 88 | except: 89 | return False 90 | -------------------------------------------------------------------------------- /thunder/core/profile.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import contextlib 3 | import functools 4 | import os 5 | import warnings 6 | 7 | import torch 8 | 9 | 10 | _ENABLED = os.getenv("THUNDER_ANNOTATE_TRACES") in ("1", "y", "Y") 11 | 12 | # However, nvtx is incredibly cheap so we no longer bother requiring the 13 | # environment variable. 14 | try: 15 | import nvtx 16 | except ImportError: 17 | if _ENABLED: 18 | msg = "Requested nvtx but the package is not available." 19 | msg += "\nUse `python -m pip install nvtx`." 20 | warnings.warn(msg) 21 | _ENABLED = False 22 | raise 23 | 24 | 25 | def profiling_enabled() -> bool: 26 | return _ENABLED 27 | 28 | 29 | @contextlib.contextmanager 30 | def add_markers(msg: str) -> None: 31 | if not profiling_enabled(): 32 | yield 33 | return 34 | 35 | assert "\n" not in msg, msg  # Both NVTX and JSON forbid newlines 36 | assert '"' not in msg, msg  # The PyTorch profiler does not properly escape quotations 37 | 38 | with torch.profiler.record_function(msg): 39 | torch.cuda.nvtx.range_push(msg) 40 | try: 41 | yield 42 | 43 | finally: 44 | torch.cuda.nvtx.range_pop() 45 | 46 | 47 | # The main interface to profiling something. Generally used as a decorator: 48 | # @thunder.core.profile.annotate_for_profile("foo") 49 | # def foo(...): ... 50 | # but alternatively as a `with` context: 51 | # with thunder.core.profile.annotate_for_profile("name for a block of code"): 52 | # # ... code ... 53 | annotate_for_profile: Callable[[str], None] = None 54 | 55 | 56 | if _ENABLED: 57 | annotate_for_profile = functools.partial(nvtx.annotate, domain="thunder") 58 | else: 59 | 60 | class _no_annotate(contextlib.nullcontext): 61 | """ 62 | A profiling decorator that does nothing.
63 | """ 64 | 65 | def __init__(self, *args, **kwargs): 66 | super().__init__() 67 | 68 | def __call__(self, fqn): 69 | return fqn 70 | 71 | annotate_for_profile = _no_annotate 72 | -------------------------------------------------------------------------------- /thunder/dev_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/dev_utils/__init__.py -------------------------------------------------------------------------------- /thunder/dev_utils/benchmark.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | 6 | def benchmark_n(n, model_or_fn, /, *args, **kwargs): 7 | for _ in range(n): 8 | _ = model_or_fn(*args, **kwargs) 9 | start_event = torch.cuda.Event(enable_timing=True) 10 | end_event = torch.cuda.Event(enable_timing=True) 11 | torch.cuda.synchronize() 12 | start_event.record() 13 | for _ in range(n): 14 | _ = model_or_fn(*args, **kwargs) 15 | end_event.record() 16 | torch.cuda.synchronize() 17 | return start_event.elapsed_time(end_event) / n 18 | 19 | 20 | benchmark = partial(benchmark_n, 10) 21 | -------------------------------------------------------------------------------- /thunder/dev_utils/check_trace.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | 3 | CHECK_VERSION = 3 4 | 5 | 6 | def check_subsymbols(parent_bsym): 7 | if parent_bsym.sym.is_prim: 8 | # assert that there are no subsymbols? 9 | return 10 | known_proxies = {a.name: a for a in parent_bsym.flat_proxy_args} 11 | for bsym in parent_bsym.subsymbols: 12 | for a in bsym.flat_proxy_args: 13 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args" 14 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args" 15 | for o in bsym.flat_proxy_outs: 16 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs" 17 | known_proxies[o.name] = o 18 | check_subsymbols(bsym) 19 | for o in parent_bsym.flat_proxy_outs: 20 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {parent_bsym}" 21 | 22 | 23 | def check_trace(trace, *, version=CHECK_VERSION): 24 | """checks a trace for consistency 25 | 26 | The check is versioned for the benefit of CI and other automated testing. 27 | 28 | As a user, don't pass a version to get all implemented checks. 29 | 30 | If you add new checks and do not fix all newly detected inconsistencies, 31 | bump the CHECK_VERSION and make your tests only apply to this latest version. 32 | 33 | Please do file issues for things that fail with the latest versions so we can 34 | catch up. 35 | """ 36 | # TODO: 37 | # - args vs. unpack trivial 38 | # - args vs. 
flat_args in return 39 | known_proxies = {} 40 | for bsym in trace.bound_symbols: 41 | if (version >= 3) and bsym.sym == thunder.core.prims.unpack_sequence: 42 | coll = bsym.args[0].collection() 43 | assert len(coll) == len(bsym.output), f"unpack collection length mismatch {bsym}" 44 | for c, o in zip(coll, bsym.output): 45 | if o is None: # unused outputs 46 | continue 47 | if isinstance(c, thunder.Proxy): 48 | assert c is o, f"mismatch in unpack collection: {c} {o} {bsym}" 49 | 50 | for a in bsym.flat_proxy_args: 51 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args" 52 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args" 53 | for o in bsym.flat_proxy_outs: 54 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs" 55 | known_proxies[o.name] = o 56 | check_subsymbols(bsym) 57 | 58 | tr = thunder.core.trace.from_trace(trace) 59 | with thunder.core.trace.tracectx(tr): 60 | for bsym in trace.bound_symbols: 61 | if bsym.sym.name.startswith("unpack") or bsym.sym.name in {"pack_buffer"}: 62 | continue 63 | res = bsym.sym(*bsym.args, **bsym.kwargs) 64 | 65 | def check_shape(x, y): 66 | if isinstance(x, thunder.TensorProxy) and y is not None: # formal output can be none if unused 67 | assert x.shape == y.shape, f"shape of proxy {y.name} recomputes to {x.shape} incorrectly in {bsym}" 68 | return x 69 | 70 | thunder.core.utils.safe_map_flat(check_shape, res, bsym.output) 71 | 72 | 73 | class CheckedListOfTraces(list): 74 | def __init__(self, *args, **kwargs): 75 | super().__init__(*args, **kwargs) 76 | cd = thunder.core.compile_data.get_compile_data() 77 | if cd.debug_options.check_traces is True: 78 | self._check_version = CHECK_VERSION 79 | elif not cd.debug_options.check_traces: 80 | self._check_version = 0 81 | else: 82 | self._check_version = cd.debug_options.check_traces 83 | 84 | def append(self, trace): 85 | check_trace(trace, version=self._check_version) 86 | super().append(trace) 87 | 88 | def extend(self, traces): 89 | for tr in traces: 90 | check_trace(tr, version=self._check_version) 91 | super().extend(traces) 92 | -------------------------------------------------------------------------------- /thunder/dev_utils/nvtx_profile_transform.py: -------------------------------------------------------------------------------- 1 | from thunder.core.trace import TraceCtx as Trace, from_trace, TraceProvenance 2 | from thunder.dev_utils.utils import NON_COMPUTATION_PRIMS 3 | from thunder.extend import OperatorExecutor 4 | import time 5 | import torch 6 | import thunder 7 | 8 | 9 | class Timer: 10 | def __init__(self): 11 | self.start_time_ns = None 12 | self.end_time_ns = None 13 | 14 | def __enter__(self): 15 | self.start_time_ns = time.perf_counter_ns() 16 | return self 17 | 18 | def __exit__(self, *args): 19 | self.end_time_ns = time.perf_counter_ns() 20 | 21 | def get_elapsed_time_in_ms(self): 22 | elapsed_time_ns = self.end_time_ns - self.start_time_ns 23 | return elapsed_time_ns // int(1e6) 24 | 25 | 26 | nvtx_profiler_ex = OperatorExecutor("nvtx_profiler_ex") 27 | 28 | 29 | def nvtx_push_impl(msg): 30 | torch.cuda.nvtx.range_push(msg) 31 | 32 | 33 | def nvtx_pop_impl(): 34 | torch.cuda.nvtx.range_pop() 35 | 36 | 37 | # Symbols for profiling. 
38 | nvtx_push = nvtx_profiler_ex.register_operator("nvtx_range_push", meta=lambda msg: None, fn=nvtx_push_impl) 39 | nvtx_pop = nvtx_profiler_ex.register_operator("nvtx_range_pop", meta=lambda: None, fn=nvtx_pop_impl) 40 | 41 | 42 | class NvtxProfileTransform(thunder.core.transforms.Transform): 43 | def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace: 44 | with Timer() as timer: 45 | profile_trace = from_trace(trace) 46 | 47 | for bound_symbol in trace.bound_symbols: 48 | if bound_symbol.sym.id in NON_COMPUTATION_PRIMS: 49 | profile_trace.bound_symbols.append(bound_symbol) 50 | continue 51 | 52 | # Add nvtx range for the symbol. 53 | profile_trace.bound_symbols.append( 54 | nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None) 55 | ) 56 | profile_trace.bound_symbols.append(bound_symbol) 57 | profile_trace.bound_symbols.append(nvtx_pop.bind(output=None)) 58 | 59 | profile_trace.set_provenance( 60 | TraceProvenance(f"NVTX Profile Transform (took {timer.get_elapsed_time_in_ms()} milliseconds)") 61 | ) 62 | return profile_trace 63 | -------------------------------------------------------------------------------- /thunder/distributed/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.distributed 2 | 3 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType 4 | 5 | if torch.distributed.is_available(): 6 | from thunder.distributed.tensor_parallel.column_wise import column_parallel 7 | from thunder.distributed.tensor_parallel.row_wise import row_parallel 8 | else: 9 | column_parallel = None 10 | row_parallel = None 11 | 12 | 13 | __all__ = [ 14 | "TensorParallelLayerType", 15 | "column_parallel", 16 | "row_parallel", 17 | ] 18 | -------------------------------------------------------------------------------- /thunder/distributed/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if torch.distributed.is_available(): 4 | from .ddp import apply_bucketing_to_grad_allreduce 5 | from .fsdp import FSDPCommBucketing 6 | else: 7 | apply_bucketing_to_grad_allreduce = None 8 | FSDPCommBucketing = None 9 | 10 | __all__ = [ 11 | "apply_bucketing_to_grad_allreduce", 12 | "FSDPCommBucketing", 13 | ] 14 | -------------------------------------------------------------------------------- /thunder/dynamo/__init__.py: -------------------------------------------------------------------------------- 1 | from thunder.dynamo.compiler import ( 2 | ThunderCompiler, 3 | ThunderFXCompiledObject, 4 | thunder_optimize, 5 | thunder_profile, 6 | thunderfx, 7 | ) 8 | 9 | 10 | __all__ = [ 11 | "ThunderCompiler", 12 | "ThunderFXCompiledObject", 13 | "thunder_optimize", 14 | "thunder_profile", 15 | "thunderfx", 16 | ] 17 | -------------------------------------------------------------------------------- /thunder/dynamo/repro_script_template.py: -------------------------------------------------------------------------------- 1 | FXGRAPH_CLASS_NAME = "DynamoModule" 2 | INPUTS_NAME = "inputs" 3 | CALLABLE_NAME = "model" 4 | COMPILED_CALLABLE_NAME = "compiled_model" 5 | THUNDER_IMPORT_STRS = """ 6 | from thunder.dev_utils import nvtx_profile_transform 7 | """ 8 | 9 | pytest_benchmark_multi_exe_code_template = f''' 10 | # NOTE: This script requires `pytest-benchmark==4.0.0` to be installed. 
11 | # To execute the script, run `pytest {{graph_name}}_benchmark.py --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type` 12 | # To check the peak allocated CUDA memory, use --benchmark-json=json_file_name and look at the "max_allocated_memory_MB" field in the json file 13 | # To run tests for a specific compute_type, use the pytest `-k` option. 14 | # For example: 15 | # - `-k "forward"` will run only the forward pass. 16 | # 17 | # Available options: 18 | # - compute_type: "forward", "backward" 19 | 20 | import pytest 21 | from thunder.benchmarks.targets import parametrize_compute_type_only_training, benchmark_for_compute_type, ComputeType 22 | {{torch_import_str}} 23 | {{import_str}} 24 | {THUNDER_IMPORT_STRS} 25 | 26 | # NOTE: The reproducer function has already been processed by TorchDynamo. 27 | # If we let it go through TorchDynamo again, it could be segmented further. 28 | # To avoid this, we directly use Inductor here. 29 | # See issue https://github.com/Lightning-AI/lightning-thunder/issues/1521 30 | def torch_inductor(fn, inputs): 31 | from torch._inductor import compile as inductor_compile 32 | from torch.fx import symbolic_trace 33 | 34 | fx_graph = symbolic_trace(fn) 35 | return inductor_compile(fx_graph, inputs) 36 | 37 | {{executors}} 38 | {{executor_names}} 39 | 40 | {{dynamo_module}} 41 | 42 | @pytest.mark.parametrize( 43 | "executor", 44 | executors, 45 | ids=executor_names, 46 | ) 47 | {{compute_type_decorator}} 48 | def test_{{graph_name}}(benchmark, executor, compute_type): 49 | {{inputs}} 50 | 51 | model = DynamoModule() 52 | if executor is None: 53 | compiled_model = model 54 | elif executor == torch_inductor: 55 | compiled_model = executor(model, inputs) 56 | else: 57 | compiled_model = executor(model) 58 | {{call_benchmark}} 59 | 60 | """ 61 | Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`: 62 | {{torch_env}} 63 | 64 | Versions of Thunder related libraries: 65 | {{thunder_pkgs}} 66 | 67 | {{extra_comment_str}} 68 | """ 69 | ''' 70 | 71 | 72 | bsym_torch_compile_repro_template = ''' 73 | """ 74 | {extra_comment_str} 75 | """ 76 | {python_func} 77 | 78 | from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable 79 | import thunder.examine 80 | 81 | inputs = {inputs} 82 | 83 | jfn = thunder.jit({func_name}) 84 | jfn(*inputs) 85 | 86 | trc = thunder.last_traces(jfn)[-1] 87 | fusion_symbols = thunder.examine.get_fusion_symbols(trc) 88 | assert len(fusion_symbols) == 1 89 | bsym = fusion_symbols[0] 90 | 91 | # NOTE: The nvFusion function cannot be compiled directly using `torch.compile`. 92 | # It must first be processed by Thunder into BoundSymbols and compiled with `make_torch_compile_callable`. 93 | # Additionally, it's recommended to visually verify that `bsym` matches the 94 | # `nvFusion` function above by printing it using `print(bsym)`.
95 | torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs) 96 | ''' 97 | 98 | repro_bench_code_template = f""" 99 | {{import_str}} 100 | {THUNDER_IMPORT_STRS} 101 | 102 | {{dynamo_module}} 103 | def test_{{graph_name}}(): 104 | {{inputs}} 105 | 106 | model = {FXGRAPH_CLASS_NAME}() 107 | """ 108 | 109 | main_code = """ 110 | if __name__ == "__main__": 111 | test_{graph_name}() 112 | """ 113 | 114 | comment_str_template = ''' 115 | """ 116 | Environment information obtained from `torch.utils.collect_env.get_pretty_env_info()`: 117 | {torch_env} 118 | 119 | Versions of Thunder related libraries: 120 | {thunder_pkgs} 121 | 122 | {extra_comment_str} 123 | """ 124 | ''' 125 | -------------------------------------------------------------------------------- /thunder/executors/__init__.py: -------------------------------------------------------------------------------- 1 | import thunder.executors.passes as passes 2 | 3 | import thunder.extend as extend 4 | 5 | 6 | # NOTE The executors submodule depends on the extend submodule 7 | 8 | __all__ = [ 9 | "passes", 10 | "get_torch_executor", 11 | "get_nvfuser_executor", 12 | "nvfuser_available", 13 | ] 14 | 15 | 16 | def get_nvfuser_executor() -> None | extend.Executor: 17 | return extend.get_executor("nvfuser") 18 | 19 | 20 | def get_torch_executor() -> extend.Executor: 21 | return extend.get_executor("torch") 22 | 23 | 24 | def nvfuser_available() -> bool: 25 | return get_nvfuser_executor() is not None 26 | -------------------------------------------------------------------------------- /thunder/executors/apex_fused_rms_norm_impl.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | import math 3 | 4 | 5 | from thunder.core.proxies import TensorProxy 6 | from thunder.torch import TensorLike 7 | from thunder.executors.apexex import apex_ex 8 | 9 | 10 | APEX_FUSED_NORMS_AVAILABLE = True 11 | try: 12 | # Fused layer norm is only importable if torch.distributed is available 13 | # https://github.com/NVIDIA/apex/issues/1853 14 | from torch.distributed import is_available 15 | 16 | if not is_available(): 17 | raise ImportError 18 | import fused_layer_norm_cuda 19 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction  # noqa: F401 20 | except ImportError: 21 | APEX_FUSED_NORMS_AVAILABLE = False 22 | 23 | 24 | def apex_fused_norms_available() -> bool: 25 | return APEX_FUSED_NORMS_AVAILABLE 26 | 27 | 28 | def apex_fused_rms_norm_forward_affine_meta( 29 | input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float 30 | ): 31 | unnormalized_dims = len(input.shape) - len(normalized_shape) 32 | invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),)) 33 | return TensorProxy(like=input), invvar 34 | 35 | 36 | def apex_fused_rms_norm_backward_affine_meta( 37 | grad_output: TensorLike, 38 | invvar: TensorLike, 39 | input_or_output: TensorLike, 40 | normalized_shape: Sequence[int], 41 | weight_, 42 | eps: float, 43 | memory_efficient: bool, 44 | ): 45 | return TensorProxy(like=grad_output), TensorProxy(like=weight_) 46 | 47 | 48 | # Create a new symbol and register lookaside only if import is available.
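# NOTE (editor's addition, conceptual summary only -- not original source):
# `replaces=<original callable>` registers a lookaside along with each operator,
# so calls to the raw `fused_layer_norm_cuda` functions inside jitted code are
# rerouted to these symbols; the `meta` runs during tracing, the `fn` at runtime.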
49 | if apex_fused_norms_available(): 50 | apex_fused_rms_norm_forward_affine = apex_ex.register_operator( 51 | "apex_fused_rms_norm_forward_affine", 52 | meta=apex_fused_rms_norm_forward_affine_meta, 53 | fn=fused_layer_norm_cuda.rms_forward_affine, 54 | replaces=fused_layer_norm_cuda.rms_forward_affine, 55 | ) 56 | 57 | apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator( 58 | "apex_fused_rms_norm_forward_affine_mixed_dtypes", 59 | meta=apex_fused_rms_norm_forward_affine_meta, 60 | fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes, 61 | replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes, 62 | ) 63 | 64 | apex_fused_rms_norm_backward_affine = apex_ex.register_operator( 65 | "apex_fused_rms_norm_backward_affine", 66 | meta=apex_fused_rms_norm_backward_affine_meta, 67 | fn=fused_layer_norm_cuda.rms_backward_affine, 68 | replaces=fused_layer_norm_cuda.rms_backward_affine, 69 | ) 70 | -------------------------------------------------------------------------------- /thunder/executors/apexex.py: -------------------------------------------------------------------------------- 1 | from thunder.extend import OperatorExecutor, register_executor 2 | 3 | 4 | # TODO Does apex have a version this should use? 5 | apex_ex = OperatorExecutor("apex", version="0.1") 6 | register_executor(apex_ex) 7 | 8 | 9 | from thunder.executors.apex_entropyex_impl import apex_entropy_available 10 | from thunder.executors.apex_fused_rms_norm_impl import apex_fused_norms_available 11 | 12 | __all__ = ["apex_ex", "apex_entropy_available", "apex_fused_norms_available"] 13 | -------------------------------------------------------------------------------- /thunder/executors/cudnnex.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from looseversion import LooseVersion 4 | from thunder.extend import OperatorExecutor 5 | 6 | __all__ = ["cudnn_version", "required_cudnn_version", "cudnn_available", "cudnn_ex", "torch_to_cudnn_dtype"] 7 | 8 | 9 | # 10 | # Functions for detecting cudnn and its version 11 | # 12 | def cudnn_version() -> LooseVersion | None: 13 | try: 14 | import cudnn 15 | 16 | if hasattr(cudnn, "__version__"): 17 | return LooseVersion(cudnn.__version__) 18 | 19 | # NOTE: This import of cudnn may or may not have version info 20 | return LooseVersion("0.0.0") 21 | except ImportError: 22 | pass 23 | 24 | # NOTE This occurs when cudnn couldn't be imported 25 | return None 26 | 27 | 28 | def required_cudnn_version() -> LooseVersion: 29 | # History of versions: 30 | # Using 1.3.0+ because it works better with other libraries (e.g. torch) that also build on top of cudnn 31 | # Using 1.5.0+ because it handles exception with unsupported graphs better 32 | # Using 1.5.1 because of a compatibility fix 33 | # Using 1.5.2 to allow stride 0 34 | return LooseVersion("1.5.2") 35 | 36 | 37 | def cudnn_available() -> bool: 38 | v = cudnn_version() 39 | if v is None: 40 | return False 41 | required = required_cudnn_version() 42 | if v < required: 43 | import warnings 44 | 45 | msg = f"Your cuDNN installation is out of date. Thunder requires version {required}, but found version {v}." 
46 | warnings.warn(msg) 47 | return False 48 | return True 49 | 50 | 51 | cudnn_ex: None | OperatorExecutor = None 52 | torch_to_cudnn_dtype: None | Callable = None 53 | cudnn = None 54 | 55 | if cudnn_available(): 56 | import thunder.executors.cudnn_sdpa as sdpa_impl 57 | 58 | torch_to_cudnn_dtype = sdpa_impl.torch_to_cudnn_dtype 59 | cudnn_ex = sdpa_impl.cudnn_ex 60 | -------------------------------------------------------------------------------- /thunder/executors/custom_op_ex.py: -------------------------------------------------------------------------------- 1 | """Executor for `torch.library.custom_op` operators""" 2 | 3 | from thunder.extend import OperatorExecutor 4 | 5 | 6 | __all__ = [ 7 | "custom_op_ex", 8 | ] 9 | 10 | 11 | custom_op_ex = OperatorExecutor("custom_op") 12 | -------------------------------------------------------------------------------- /thunder/executors/nvfuserex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from looseversion import LooseVersion 3 | 4 | from thunder.extend import FusionExecutor 5 | 6 | __all__ = ["nvfuser_version", "required_nvfuser_version", "nvfuser_available", "nvfuserex"] 7 | 8 | 9 | # 10 | # Functions for detecting nvFuser and its version 11 | # 12 | def nvfuser_version() -> LooseVersion | None: 13 | # Short-circuits if CUDA isn't available 14 | if not torch.cuda.is_available(): 15 | return None 16 | 17 | try: 18 | import nvfuser 19 | 20 | if hasattr(nvfuser, "version"): 21 | return LooseVersion(nvfuser.version()) 22 | 23 | # NOTE: This import of nvFuser may or may not have version info 24 | return LooseVersion("0.0.0") 25 | except ImportError: 26 | pass 27 | 28 | # NOTE This occurs when nvFuser couldn't be imported 29 | return None 30 | 31 | 32 | def required_nvfuser_version() -> LooseVersion: 33 | return LooseVersion("0.2.8") 34 | 35 | 36 | def nvfuser_available() -> bool: 37 | v = nvfuser_version() 38 | if v is None: 39 | return False 40 | 41 | required = required_nvfuser_version() 42 | if v < required: 43 | import warnings 44 | 45 | msg = f"Your nvfuser installation is out of date. Thunder requires version {required}, but found version {v}." 
46 | warnings.warn(msg) 47 | return False 48 | return True 49 | 50 | 51 | nvfuserex: None | FusionExecutor = None 52 | if nvfuser_available(): 53 | import thunder.executors.nvfuserex_impl as impl 54 | 55 | nvfuserex = impl.ex 56 | -------------------------------------------------------------------------------- /thunder/executors/transformer_engineex.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from collections.abc import Callable 4 | 5 | from lightning_utilities.core.imports import package_available 6 | 7 | from thunder import Transform 8 | from thunder.extend import StatefulExecutor 9 | from thunder.core.trace import TraceCtx 10 | 11 | import torch 12 | 13 | __all__ = ["transformer_engine_ex", "TransformerEngineTransform", "_te_activation_checkpointing_transform"] 14 | 15 | transformer_engine_ex: None | StatefulExecutor = None 16 | TransformerEngineTransform: None | Transform = None 17 | _te_activation_checkpointing_transform: None | Callable[[TraceCtx], TraceCtx] = None 18 | 19 | if torch.cuda.is_available(): 20 | if package_available("transformer_engine"): 21 | import thunder.executors.transformer_engineex_impl as impl 22 | 23 | transformer_engine_ex = impl.transformer_engine_ex 24 | TransformerEngineTransform = impl.TransformerEngineTransform 25 | _te_activation_checkpointing_transform = impl._te_activation_checkpointing_transform 26 | 27 | else: 28 | warnings.warn("transformer_engine module not found!") 29 | -------------------------------------------------------------------------------- /thunder/executors/triton_crossentropy.py: -------------------------------------------------------------------------------- 1 | from thunder.executors import triton_utils 2 | from thunder.extend import OperatorExecutor 3 | 4 | import torch 5 | 6 | triton_version: None | str = triton_utils.triton_version() 7 | 8 | triton_ex: None | OperatorExecutor = None 9 | 10 | if torch.cuda.is_available(): 11 | if triton_version is not None: 12 | try: 13 | from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex 14 | 15 | triton_ex = impl_ex 16 | except Exception: 17 | import warnings 18 | 19 | warnings.warn("triton is present but cannot be initialized") 20 | triton_version = None 21 | -------------------------------------------------------------------------------- /thunder/executors/triton_utils.py: -------------------------------------------------------------------------------- 1 | from looseversion import LooseVersion 2 | from lightning_utilities.core.imports import package_available 3 | 4 | 5 | def triton_version() -> None | str: 6 | if not package_available("triton"): 7 | return None 8 | 9 | import triton 10 | 11 | return triton.__version__ 12 | 13 | 14 | def is_triton_version_at_least(minimum_version: str) -> bool: 15 | version = triton_version() 16 | 17 | if version is None: 18 | return False 19 | 20 | return LooseVersion(version) >= minimum_version 21 | -------------------------------------------------------------------------------- /thunder/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from collections.abc import Callable 3 | 4 | from thunder.core.langctxs import langctx, Languages 5 | from thunder.numpy.langctx import register_method 6 | 7 | from thunder.core.proxies import TensorProxy 8 | from thunder.core.symbol import Symbol 9 | import thunder.clang as clang 10 | 11 | 12 | # 13 | # NumPy operator definitions 14 | # 15 | # NOTE
NumPy language support is demonstrative. PRs extending it are welcome! 16 | 17 | 18 | # Decorator that sets the language context and constructs a Symbol for each function 19 | class npsymbol: 20 | def __init__(self, *, method_name: None | str = None): 21 | self.method_name: None | str = method_name 22 | 23 | def __call__(self, fn: Callable) -> Symbol: 24 | _fn = langctx(Languages.NUMPY)(fn) 25 | # TODO: register _fn as opaque with the interpreter or do this in jit_ext? 26 | sym = Symbol(name=fn.__name__, meta=_fn) 27 | 28 | if self.method_name is not None: 29 | register_method(self.method_name, _fn) 30 | 31 | return sym 32 | 33 | 34 | # 35 | # Tensor properties 36 | # 37 | 38 | 39 | # NOTE Named `compute_len` so that it doesn't conflict with built-in `len` 40 | def compute_len(a: TensorProxy, /) -> int: 41 | return a.shape[0] 42 | 43 | 44 | register_method("len", compute_len) 45 | 46 | 47 | def size(a: TensorProxy, /) -> int: 48 | return a.numel 49 | 50 | 51 | register_method("size", size) 52 | 53 | 54 | # 55 | # Elementwise binary operators 56 | # 57 | 58 | 59 | # TODO Create a factory that adds ufunc support to elementwise operations 60 | @npsymbol(method_name="add") 61 | 62 | 63 | def add(a: Number | TensorProxy, b: Number | TensorProxy, /, *, where: None | Number | TensorProxy = None): 64 | result = clang.add(a, b) 65 | if where is not None: 66 | return clang.where(where, result, a) 67 | return result 68 | -------------------------------------------------------------------------------- /thunder/numpy/langctx.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 4 | from thunder.core.pytree import tree_flatten 5 | from thunder.core.proxies import TensorProxy 6 | 7 | # 8 | # Creates and registers the numpy language context 9 | # 10 | # NOTE That this is done separately from the definition of thunder.numpy operations, because the 11 | # language context must be available before those operations are defined 12 | 13 | _method_name_to_fn_map: dict[str, Callable] = {} 14 | 15 | 16 | # Creates and registers the NumPy language context 17 | class NumPyCtx(LanguageContext): 18 | def __init__(self): 19 | super().__init__("numpy") 20 | 21 | def has_method(self, id: str) -> bool: 22 | return id in _method_name_to_fn_map 23 | 24 | def get_method(self, id: str, *args, **kwargs) -> Callable: 25 | # Note: concrete implementations should only raise AttributeError or 26 | # return None for "missing" methods as the proxies will 27 | # route __getattr__ to here and hasattr relies on __getattr__ 28 | # throwing AttributeError (only) when the attribute does 29 | # not exist.
30 | inps, _ = tree_flatten((args, kwargs)) 31 | 32 | has_tensor_input: bool = False 33 | for x in inps: 34 | if isinstance(x, TensorProxy): 35 | has_tensor_input = True 36 | break 37 | 38 | if has_tensor_input: 39 | method: None | Callable = _method_name_to_fn_map.get(id, None) 40 | if method is None: 41 | raise AttributeError(f"The {self.name} language context has no method {id}") 42 | return method 43 | 44 | # has_tensor_input is False 45 | # Defers to the primitive language context when there are no tensor inputs 46 | # (the primitive language context handles operations on numbers) 47 | primsctx: LanguageContext = resolve_language(Languages.PRIMS) 48 | if not primsctx.has_method(id): 49 | raise AttributeError( 50 | f"Attempting to call method {id} in the numpy language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method" 51 | ) 52 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs) 53 | return prim_method 54 | 55 | 56 | numpyctx = NumPyCtx() 57 | register_langctx(Languages.NUMPY, numpyctx) 58 | register_langctx("numpy", numpyctx) 59 | 60 | 61 | # Registers a method with the NumPy language context 62 | def register_method(method_name: str, method: Callable, /) -> None: 63 | _method_name_to_fn_map[method_name] = method 64 | -------------------------------------------------------------------------------- /thunder/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from thunder.plugins.distributed import DDP, FSDP 2 | from thunder.plugins.quantization import QuantizeInt4 3 | from thunder.plugins.fp8 import FP8 4 | from thunder.plugins.reduce_overhead import ReduceOverhead 5 | 6 | names_to_plugins = { 7 | "ddp": DDP, 8 | "fsdp": FSDP, 9 | "quantize-int4": QuantizeInt4, 10 | "fp8": FP8, 11 | "reduce-overhead": ReduceOverhead, 12 | } 13 | 14 | 15 | def get_plugin(name): 16 | return names_to_plugins.get(name) 17 | 18 | 19 | def get_plugin_names(): 20 | return list(names_to_plugins.keys()) 21 | 22 | 23 | def register_plugin(name, cls): 24 | names_to_plugins[name] = cls 25 | -------------------------------------------------------------------------------- /thunder/plugins/fp8.py: -------------------------------------------------------------------------------- 1 | from thunder import Plugin 2 | 3 | 4 | class FP8(Plugin): 5 | """ 6 | Plugin for enabling FP8 precision via NVIDIA Transformer Engine, allowing higher-throughput matrix operations in FP8. 7 | 8 | See `lightning-thunder/thunder/executors/transformer_engine_v1ex.py` for implementation details. 9 | """ 10 | 11 | def setup_executors(self): 12 | """ 13 | Imports the Transformer Engine executor. 14 | 15 | Returns: 16 | list[Executor]: A list containing the Transformer Engine executor. 17 | 18 | """ 19 | from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex 20 | 21 | return [transformer_engine_v1_ex] 22 | -------------------------------------------------------------------------------- /thunder/plugins/quantization.py: -------------------------------------------------------------------------------- 1 | from thunder import Plugin 2 | 3 | 4 | class QuantizeInt4(Plugin): 5 | """ 6 | Plugin for 4-bit integer quantization using BitsAndBytes. 7 | 8 | This plugin applies a 4-bit linear quantization transform to 9 | model weights, reducing memory footprint and improving 10 | throughput for both training and inference. 
11 | 12 | See https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/functional.py#L889 for more details. 13 | """ 14 | 15 | def setup_transforms(self): 16 | """ 17 | Fetches the BitsAndBytes quantization transform. 18 | 19 | Returns: 20 | list[Transform]: A list containing the BitsAndBytes quantization transform. 21 | """ 22 | 23 | from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit 24 | 25 | return [BitsAndBytesLinearQuant4bit()] 26 | 27 | def setup_executors(self): 28 | """ 29 | Fetches the BitsAndBytes quantization executor. 30 | 31 | Returns: 32 | list[Executor]: A list containing the BitsAndBytes quantization executor. 33 | 34 | """ 35 | from thunder.transforms.quantization import get_bitsandbytes_executor 36 | 37 | return [get_bitsandbytes_executor()] 38 | -------------------------------------------------------------------------------- /thunder/plugins/reduce_overhead.py: -------------------------------------------------------------------------------- 1 | from thunder.core.recipe import Plugin, PluginPolicy 2 | from thunder.transforms.cudagraph import CUDAGraphTransform 3 | 4 | 5 | class ReduceOverhead(Plugin): 6 | """ 7 | Plugin to enable CUDA Graphs and reduce CPU overhead. 8 | """ 9 | 10 | policy = PluginPolicy.POST 11 | 12 | def setup_transforms(self): 13 | """ 14 | Fetches the CUDAGraph transform. 15 | 16 | Returns: 17 | list[Transform]: A list containing the CUDAGraph transform. 18 | """ 19 | return [CUDAGraphTransform()] 20 | -------------------------------------------------------------------------------- /thunder/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/py.typed -------------------------------------------------------------------------------- /thunder/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from thunder.recipes.base import BaseRecipe 4 | from thunder.recipes.hf_transformers import HFTransformers 5 | 6 | 7 | names_to_recipes: dict[str, type[Any]] = { 8 | "default": BaseRecipe, 9 | "hf-transformers": HFTransformers, 10 | } 11 | 12 | 13 | def get_recipe_class(name: str) -> type[Any]: 14 | return names_to_recipes.get(name) 15 | 16 | 17 | def get_recipes() -> list[str]: 18 | return list(names_to_recipes.keys()) 19 | 20 | 21 | def register_recipe(name: str, cls: type[Any]): 22 | if name == "auto": 23 | raise ValueError("Recipe name 'auto' is reserved.") 24 | names_to_recipes[name] = cls 25 | -------------------------------------------------------------------------------- /thunder/tests/README.md: -------------------------------------------------------------------------------- 1 | Testing is important! 2 | 3 | Most Thunder operators are tested by automatically generated tests that use data from the operator's "OpInfo", defined 4 | in `opinfos.py`. These tests typically compare Thunder's operator behavior with another framework, using the 5 | "sample inputs" defined by its OpInfo. Each OpInfo should define a sample input generator and "reference" implementation in PyTorch (and in the future, NumPy). 6 | 7 | When an operator is very similar across the three language levels (user, core, and primitive), like cos, only one variation is typically tested (the core language variant is preferred). 
8 | However, there are also cases where testing a torch language operation or primitive operation directly makes sense. 9 | 10 | Operator tests are autogenerated in `test_ops.py`. To run the tests for a particular operator, use pytest's -k option. 11 | For example, to run the tests for `cos`, the command would be 12 | 13 | ```bash 14 | pytest test_ops.py -k cos -v 15 | ``` 16 | 17 | This will run tests for Thunder's different executors, supported dtypes, and supported devicetypes. 18 | -------------------------------------------------------------------------------- /thunder/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/__init__.py -------------------------------------------------------------------------------- /thunder/tests/bf16.py: -------------------------------------------------------------------------------- 1 | import thunder.core 2 | import torch 3 | 4 | 5 | def device_supports_bf16(device: None | str | torch.device | thunder.core.devices.Device, /) -> bool: 6 | """Torch has its own `torch.cuda.is_bf16_supported()`, but the contract of 7 | this API changed in December 2023 to be "can bf16 be represented?", which 8 | is true even for devices that existed before bf16 was invented. 9 | 10 | The contract for this API is that the device implements bf16 operations.""" 11 | if not torch.cuda.is_available(): 12 | return False 13 | 14 | dev: torch.device = thunder.core.devices.to_torch_device(device) 15 | 16 | if dev.type != "cuda": 17 | return False 18 | 19 | cuda_major: int 20 | cuda_minor: int 21 | cuda_major, cuda_minor = torch.cuda.get_device_capability(dev) 22 | return cuda_major >= 8 23 | -------------------------------------------------------------------------------- /thunder/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_benchmark 3 | from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS 4 | 5 | import torch 6 | 7 | try: 8 | import nvfuser 9 | except ImportError: 10 | nvfuser = None 11 | 12 | 13 | @pytest.fixture(autouse=True) 14 | def gpu_memory(request): 15 | if torch.cuda.is_available(): 16 | torch.cuda.empty_cache() 17 | torch.cuda.reset_peak_memory_stats() 18 | gpu_mem_before = torch.cuda.max_memory_allocated() / 2**30 19 | yield 20 | if torch.cuda.is_available(): 21 | gpu_mem = torch.cuda.max_memory_allocated() / 2**30 22 | gpu_mem_limit = request.config.getoption("--gpu-mem-limit") 23 | gpu_mem_use = gpu_mem - gpu_mem_before 24 | if gpu_mem_limit: 25 | assert gpu_mem_use <= gpu_mem_limit, ( 26 | f"test needs {gpu_mem_use:.2f}GB VRAM, only {gpu_mem_limit:.2f}GB allowed" 27 | ) 28 | torch.cuda.empty_cache() 29 | torch.cuda.reset_peak_memory_stats() 30 | 31 | 32 | @pytest.fixture(autouse=True) 33 | def test_cleanup(request): 34 | yield 35 | if nvfuser is not None: 36 | nvfuser.FusionCache.reset() 37 | 38 | 39 | @pytest.hookimpl(hookwrapper=True) 40 | def pytest_benchmark_group_stats(config, benchmarks, group_by): 41 | """ 42 | This function customizes the grouping behavior for ThunderCompilerGraphBenchmarking. 43 | The custom grouping function is only invoked when the `--benchmark-group-by` 44 | option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'. 45 | For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`. 
46 | 47 | Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats 48 | """ 49 | prefix = "graph-by-graph:" 50 | outcome = yield 51 | if group_by.startswith(prefix): 52 | group_by = group_by[len(prefix) :] 53 | for bench in benchmarks: 54 | if bench["params"] is None: 55 | bench["params"] = {} 56 | # Benchmarks with the same `params` share the same dict, so we copy the 57 | # original dictionary before adding parameters specific to each graph. 58 | else: 59 | bench["params"] = bench["params"].copy() 60 | if bench["param"] is None: 61 | bench["param"] = "" 62 | 63 | name = bench["name"] 64 | gid, module_name, ex = name.split("-")[-3:] 65 | # Add "GraphID", "SplitModuleName", and "executor" as params in the benchmark 66 | gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS 67 | bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex}) 68 | bench["param"] += f"-{gid}-{module_name}-{ex}" 69 | 70 | result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by) 71 | outcome.force_result(result) 72 | 73 | 74 | def pytest_collection_modifyitems(items): 75 | items.sort(key=lambda item: item.name) 76 | 77 | 78 | def pytest_addoption(parser): 79 | parser.addoption("--gpu-mem-limit", type=float) 80 | -------------------------------------------------------------------------------- /thunder/tests/coverage_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/coverage_tests/__init__.py -------------------------------------------------------------------------------- /thunder/tests/coverage_tests/test_coverage_hf_diffusers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pytest 4 | 5 | 6 | if os.getenv("ALLOW_COVERAGE_TRACE") != "1": 7 | pytest.skip("Skipping test_coverage_hf_diffusers.py in regular CI", allow_module_level=True) 8 | 9 | hf_diffusers_unet2d_condition_model_ids = [ 10 | "runwayml/stable-diffusion-v1-5", 11 | "CompVis/stable-diffusion-v1-4", 12 | "ionet-official/bc8-alpha", 13 | "stabilityai/sd-turbo", 14 | "runwayml/stable-diffusion-inpainting", 15 | "stabilityai/stable-diffusion-xl-base-1.0", 16 | "stabilityai/stable-diffusion-xl-refiner-1.0", 17 | "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", 18 | ] 19 | 20 | from thunder.tests.framework import requiresCUDA 21 | 22 | 23 | @requiresCUDA 24 | @pytest.mark.parametrize("model_id", hf_diffusers_unet2d_condition_model_ids) 25 | def test_hf_diffusers(model_id): 26 | from thunder.dynamo import thunderfx 27 | from diffusers import UNet2DConditionModel 28 | 29 | unet_config = UNet2DConditionModel.load_config(model_id, subfolder="unet", torch_dtype=torch.bfloat16) 30 | unet = UNet2DConditionModel(unet_config) 31 | in_channels = unet.config.in_channels 32 | cross_attention_dim = unet.config.cross_attention_dim 33 | addition_embed_type = unet.config.addition_embed_type 34 | 35 | sample_size = 4 36 | batch_size = 1 37 | seq_length = 4 38 | 39 | if "xl" in model_id: 40 | time_ids_dim = 6 41 | text_embeds_dim = 4 42 | if "refiner" in model_id: 43 | time_ids_dim = 2 44 | text_embeds_dim = 4 45 | else: 46 | time_ids_dim = None 47 | text_embeds_dim = None 48 | 49 | input_shape = (batch_size, in_channels, sample_size, sample_size) 50 | hidden_states_shape = 
(batch_size, seq_length, cross_attention_dim) 51 | 52 | unet = unet.to("cuda", dtype=torch.bfloat16).requires_grad_(True) 53 | compiled_model = thunderfx(unet) 54 | 55 | def make_inputs(dtype=torch.bfloat16): 56 | added_cond_kwargs = {} 57 | with torch.device("cuda"): 58 | input = torch.randn(input_shape, dtype=dtype) 59 | hidden_states = torch.randn(hidden_states_shape, dtype=dtype) 60 | timestep = torch.ones(batch_size, dtype=torch.long) 61 | if addition_embed_type is not None: 62 | assert text_embeds_dim is not None and time_ids_dim is not None 63 | time_ids_shape = (batch_size, time_ids_dim) 64 | text_embeds_shape = (batch_size, text_embeds_dim) 65 | added_cond_kwargs["time_ids"] = torch.randn(time_ids_shape, device="cuda", dtype=dtype) 66 | added_cond_kwargs["text_embeds"] = torch.randn(text_embeds_shape, device="cuda", dtype=dtype) 67 | return (input, timestep, hidden_states), {"added_cond_kwargs": added_cond_kwargs} 68 | 69 | compiled_args, compiled_kwargs = make_inputs(torch.bfloat16) 70 | compiled_output = compiled_model(*compiled_args, **compiled_kwargs) 71 | 72 | ref_output = unet(*compiled_args, **compiled_kwargs) 73 | 74 | ref_output = ref_output.sample 75 | compiled_output = compiled_output.sample 76 | 77 | torch.testing.assert_close(compiled_output, ref_output, rtol=1e-2, atol=2e-1) 78 | 79 | # TODO: Currently fails, needs investigation https://github.com/Lightning-AI/lightning-thunder/issues/2153 80 | # loss_grad = torch.randn_like(compiled_output) 81 | # grads_ref = torch.autograd.grad(ref_output, unet.parameters(), grad_outputs=loss_grad) 82 | # grads_compiled = torch.autograd.grad(compiled_output, unet.parameters(), grad_outputs=loss_grad) 83 | # torch.testing.assert_close(grads_ref, grads_compiled, rtol=1e-1, atol=1e-1) 84 | -------------------------------------------------------------------------------- /thunder/tests/coverage_tests/test_coverage_trace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import thunder 3 | import os 4 | import pytest 5 | 6 | 7 | if os.getenv("ALLOW_COVERAGE_TRACE") != "1": 8 | pytest.skip("Skipping test_coverage_trace.py in regular CI", allow_module_level=True) 9 | 10 | 11 | from transformers import ( 12 | AutoConfig, 13 | AutoModel, 14 | AutoModelForCausalLM, 15 | AutoModelForSeq2SeqLM, 16 | AutoModelForImageClassification, 17 | ) 18 | from thunder.tests.test_core import run_prologue 19 | 20 | 21 | MODEL_LIST = [ 22 | "gpt2", 23 | "bert-base-uncased", 24 | "google/reformer-enwik8", 25 | "facebook/bart-large", 26 | "t5-base", 27 | "xlnet-base-cased", 28 | "facebook/dinov2-small", 29 | "albert-base-v2", 30 | "google/electra-base-discriminator", 31 | "facebook/opt-1.3b", 32 | "google/vit-base-patch16-224", 33 | ] 34 | 35 | MODEL_TASKS = {"openai/clip-vit-large-patch14": "clip", "facebook/dinov2-small": "vision"} 36 | 37 | 38 | def get_model_class(model_name, config): 39 | task = MODEL_TASKS.get(model_name, "text") 40 | if task == "vision": 41 | return AutoModelForImageClassification 42 | elif config.architectures and "CausalLM" in config.architectures[0]: 43 | return AutoModelForCausalLM 44 | elif config.architectures and "Seq2SeqLM" in config.architectures[0]: 45 | return AutoModelForSeq2SeqLM 46 | else: 47 | return AutoModel 48 | 49 | 50 | # custom input depending on model task 51 | def get_dummy_input(model_name, config): 52 | task = MODEL_TASKS.get(model_name, "text") 53 | 54 | if task == "vision": 55 | return {"pixel_values": torch.randn(1, 3, 224, 224, device="cpu", 
dtype=torch.float32)} 56 | else: 57 | return {"input_ids": torch.randint(0, 1000, (1, 16), device="cpu")} 58 | 59 | 60 | @pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2436") 61 | @pytest.mark.parametrize("model_name", MODEL_LIST) 62 | def test_model_trace(model_name): 63 | print(f"\n=== Testing {model_name} ===") 64 | 65 | config = AutoConfig.from_pretrained(model_name) 66 | model_class = get_model_class(model_name, config) 67 | model = model_class.from_config(config).to("meta") 68 | input_sample = get_dummy_input(model_name, config) 69 | 70 | jmodel = thunder.jit(model) 71 | ce, pro_to_comp, pro_to_epi = run_prologue(jmodel, **input_sample) 72 | 73 | print(f"[SUCCESS] {model_name} Trace acquired!") 74 | -------------------------------------------------------------------------------- /thunder/tests/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/tests/distributed/__init__.py -------------------------------------------------------------------------------- /thunder/tests/distributed/modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import ClassVar, TYPE_CHECKING 3 | 4 | import torch.nn as nn 5 | 6 | from thunder.core import utils 7 | 8 | if TYPE_CHECKING: 9 | import torch 10 | 11 | 12 | __all__ = [ 13 | "ParallelMLP", 14 | ] 15 | 16 | 17 | class ParallelMLP(nn.Module): 18 | """Simplified version of Megatron/NeMo's ParallelMLP. 19 | 20 | Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61 21 | """ 22 | 23 | COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",) 24 | ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",) 25 | 26 | SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh") 27 | 28 | def __init__( 29 | self, 30 | hidden_size: int, 31 | ffn_hidden_size: int | None = None, 32 | bias: bool = True, 33 | gelu_approximate: str = "none", 34 | ) -> None: 35 | utils.check( 36 | gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX, 37 | lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}", 38 | ) 39 | if ffn_hidden_size is None: 40 | ffn_hidden_size = 4 * hidden_size 41 | 42 | super().__init__() 43 | self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias) 44 | self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias) 45 | self.gelu = nn.GELU(approximate=gelu_approximate) 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | four_h = self.gelu(self.dense_h_to_4h(x)) 49 | h = self.dense_4h_to_h(four_h) 50 | return h 51 | -------------------------------------------------------------------------------- /thunder/tests/hf_bart_self_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This is an abbreviated version of the HuggingFace Bart Encoder Self-Attention 16 | # module. The block was edited to remove non-essential code for 17 | # self-attention. 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class BartAttention(nn.Module): 24 | """Multi-headed attention from 'Attention Is All You Need' paper""" 25 | 26 | def __init__( 27 | self, 28 | embed_dim: int, 29 | num_heads: int, 30 | dropout: float = 0.0, 31 | bias: bool = True, 32 | ): 33 | super().__init__() 34 | self.embed_dim = embed_dim 35 | self.num_heads = num_heads 36 | self.dropout = dropout 37 | self.head_dim = embed_dim // num_heads 38 | 39 | if (self.head_dim * num_heads) != self.embed_dim: 40 | raise ValueError( 41 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 42 | f" and `num_heads`: {num_heads})." 43 | ) 44 | self.scaling = self.head_dim**-0.5 45 | 46 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 47 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 48 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 49 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 50 | 51 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 52 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 53 | 54 | def forward( 55 | self, 56 | hidden_states: torch.Tensor, 57 | attention_mask: torch.Tensor, 58 | ) -> torch.Tensor: 59 | """Input shape: Batch x Time x Channel""" 60 | bsz, tgt_len, _ = hidden_states.size() 61 | 62 | # get query proj 63 | query_states = self.q_proj(hidden_states) * self.scaling 64 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 65 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 66 | 67 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 68 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 69 | key_states = key_states.view(*proj_shape) 70 | value_states = value_states.view(*proj_shape) 71 | 72 | src_len = key_states.size(1) 73 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 74 | 75 | if attention_mask is not None: 76 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 77 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 78 | 79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 80 | 81 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 82 | 83 | attn_output = torch.bmm(attn_probs, value_states) 84 | 85 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 86 | attn_output = attn_output.transpose(1, 2) 87 | 88 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 89 | # partitioned across GPUs when using tensor-parallelism. 
90 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 91 | 92 | attn_output = self.out_proj(attn_output) 93 | 94 | return attn_output 95 | -------------------------------------------------------------------------------- /thunder/tests/litgpt_model.py: -------------------------------------------------------------------------------- 1 | """Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py""" 2 | 3 | configs = [ 4 | # diverse sample of configs FOR TESTING that cover all major checkpoints variants architecturally but with reduced 5 | # size 6 | dict(name="gpt-neox-like", block_size=128, n_layer=2, n_embd=64, n_head=4, padding_multiple=8), 7 | dict( 8 | name="llama1-like", 9 | block_size=128, 10 | vocab_size=320, 11 | padding_multiple=64, 12 | n_layer=2, 13 | n_head=4, 14 | n_embd=64, 15 | rotary_percentage=1.0, 16 | parallel_residual=False, 17 | bias=False, 18 | norm_class_name="RMSNorm", 19 | norm_eps=1e-6, 20 | mlp_class_name="LLaMAMLP", 21 | intermediate_size=1376, 22 | ), 23 | dict( 24 | name="long-context-like", 25 | block_size=512, 26 | vocab_size=320, 27 | padding_multiple=64, 28 | n_layer=2, 29 | n_head=4, 30 | n_embd=64, 31 | rotary_percentage=1.0, 32 | parallel_residual=False, 33 | bias=False, 34 | norm_class_name="RMSNorm", 35 | mlp_class_name="LLaMAMLP", 36 | intermediate_size=11008, 37 | rope_condense_ratio=4, 38 | ), 39 | dict( 40 | name="llama2-like", 41 | vocab_size=320, 42 | padding_multiple=64, 43 | n_layer=2, 44 | n_head=4, 45 | n_embd=64, 46 | rotary_percentage=1.0, 47 | parallel_residual=False, 48 | bias=False, 49 | norm_class_name="RMSNorm", 50 | mlp_class_name="LLaMAMLP", 51 | intermediate_size=1376, 52 | ), 53 | dict( 54 | name="falcon-7b-like", 55 | block_size=128, 56 | padded_vocab_size=254, 57 | n_layer=2, 58 | n_head=7, 59 | n_embd=448, 60 | rotary_percentage=1.0, 61 | n_query_groups=1, 62 | bias=False, 63 | shared_attention_norm=True, 64 | ), 65 | dict( 66 | name="falcon-40b-like", 67 | block_size=128, 68 | padded_vocab_size=508, 69 | n_layer=2, 70 | n_head=64, 71 | n_embd=256, 72 | rotary_percentage=1.0, 73 | n_query_groups=4, 74 | bias=False, 75 | ), 76 | dict( 77 | name="codellama2-like", 78 | block_size=1024, 79 | vocab_size=2001, 80 | padding_multiple=16, 81 | n_layer=2, 82 | n_head=4, 83 | n_embd=64, 84 | rotary_percentage=1.0, 85 | parallel_residual=False, 86 | bias=False, 87 | norm_class_name="RMSNorm", 88 | norm_eps=1e-05, 89 | mlp_class_name="LLaMAMLP", 90 | intermediate_size=1376, 91 | rope_base=1000000, 92 | ), 93 | dict( 94 | name="mixtral-like", 95 | block_size=512, 96 | padded_vocab_size=500, 97 | n_layer=2, 98 | n_head=64, 99 | n_embd=256, 100 | rotary_percentage=1.0, 101 | n_query_groups=8, 102 | parallel_residual=False, 103 | bias=False, 104 | norm_class_name="RMSNorm", 105 | norm_eps=1e-05, 106 | mlp_class_name="LLaMAMoE", 107 | intermediate_size=224, 108 | rope_base=1000000, 109 | n_expert=8, 110 | n_expert_per_token=2, 111 | ), 112 | ] 113 | 114 | name_to_config = {config["name"]: config for config in configs} 115 | 116 | 117 | import litgpt 118 | 119 | # add the testing configurations 120 | litgpt.config.name_to_config.update(name_to_config) 121 | name_to_config.update(litgpt.config.name_to_config) 122 | 123 | # manually expose for backwards compatibility 124 | Config = litgpt.Config 125 | GPT = litgpt.GPT 126 | RMSNorm = litgpt.model.RMSNorm 127 | CausalSelfAttention = litgpt.model.CausalSelfAttention 128 | LLaMAMLP = litgpt.model.LLaMAMLP 129 | build_rope_cache = litgpt.model.build_rope_cache 130 | 
apply_rope = litgpt.model.apply_rope 131 | Block = litgpt.model.Block 132 | -------------------------------------------------------------------------------- /thunder/tests/module_example.py: -------------------------------------------------------------------------------- 1 | def returns_two(): 2 | return 2 3 | 4 | 5 | def _returns_three(): 6 | return 3 7 | 8 | 9 | def returns_five(): 10 | return 5 11 | -------------------------------------------------------------------------------- /thunder/tests/test_apex_fused_norms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.testing import assert_close 4 | 5 | fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda") 6 | 7 | from torch.distributed import is_available 8 | from thunder.executors.apexex import apex_ex 9 | import thunder 10 | 11 | 12 | # See https://github.com/NVIDIA/apex/issues/1853 13 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available") 14 | @pytest.mark.parametrize("requires_grad", [True, False]) 15 | @pytest.mark.parametrize("memory_efficient", [True, False]) 16 | def test_apex_fused_rms_norm(requires_grad, memory_efficient): 17 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction 18 | 19 | def fn(x, weight, normalized_shape, eps): 20 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient) 21 | 22 | device = "cuda" 23 | normalized_shape = (3, 2) 24 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad) 25 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad) 26 | eps = 1e-5 27 | 28 | expected = fn(x, weight, normalized_shape, eps) 29 | jfn = thunder.jit(fn, executors=[apex_ex]) 30 | actual = jfn(x, weight, normalized_shape, eps) 31 | 32 | assert_close(actual, expected) 33 | 34 | if requires_grad: 35 | grad_output = torch.rand_like(actual) 36 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output) 37 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output) 38 | 39 | assert_close(actual_grad, expected_grad) 40 | 41 | 42 | # See https://github.com/NVIDIA/apex/issues/1853 43 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available") 44 | @pytest.mark.parametrize("requires_grad", [True, False]) 45 | @pytest.mark.parametrize("memory_efficient", [True, False]) 46 | def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient): 47 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction 48 | 49 | def fn(x, weight, normalized_shape, eps): 50 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient) 51 | 52 | device = "cuda" 53 | normalized_shape = (3, 2) 54 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad) 55 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad) 56 | eps = 1e-5 57 | 58 | expected = fn(x, weight, normalized_shape, eps) 59 | jfn = thunder.jit(fn, executors=()) 60 | actual = jfn(x, weight, normalized_shape, eps) 61 | 62 | assert_close(actual, expected) 63 | 64 | if requires_grad: 65 | grad_output = torch.rand_like(actual) 66 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output) 67 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output) 68 | 69 | assert_close(actual_grad, expected_grad) 70 | 
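# NOTE These tests exercise apex's fused RMSNorm twice: once through the apex executor and once relying on Thunder's auto-registration of torch autograd functions. Both require a CUDA device and an apex build that provides `fused_layer_norm_cuda`, and they can be run in isolation with `pytest thunder/tests/test_apex_fused_norms.py -v`.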
-------------------------------------------------------------------------------- /thunder/tests/test_check_trace.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.dev_utils.check_trace import check_trace 3 | import torch 4 | import pytest 5 | 6 | 7 | def test_missing_symbol(): 8 | def fn(a, b): 9 | c = a + b 10 | d = 2 * c 11 | return d 12 | 13 | jfn = thunder.jit(fn) 14 | 15 | a = torch.randn(2, 2) 16 | b = torch.randn(2, 2) 17 | 18 | jfn(a, b) 19 | 20 | tr = thunder.last_traces(jfn)[-1] 21 | check_trace(tr) 22 | 23 | del tr.bound_symbols[-3] 24 | 25 | with pytest.raises(AssertionError, match="unknown proxy"): 26 | check_trace(tr) 27 | 28 | 29 | def test_debug_option(): 30 | class BrokenTransform(thunder.core.transform_common.Transform): 31 | def transform_traces_pre_prologue(self, pro, comp, epi, **kwargs): 32 | new_comp = thunder.core.trace.from_trace(comp) 33 | new_comp.bound_symbols = comp.bound_symbols[:] 34 | del new_comp.bound_symbols[2] 35 | return pro, new_comp, epi 36 | 37 | def fn(a, b): 38 | c = a + b 39 | d = 2 * c 40 | return d 41 | 42 | a = torch.randn(2, 2) 43 | b = torch.randn(2, 2) 44 | 45 | # works 46 | jfn = thunder.jit(fn, debug_options=thunder.DebugOptions(check_traces=True)) 47 | jfn(a, b) 48 | 49 | # broken with nice error 50 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), debug_options=thunder.DebugOptions(check_traces=True)) 51 | 52 | with pytest.raises(AssertionError, match="unknown proxy"): 53 | jfn(a, b) 54 | 55 | # broken with less nice error 56 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), executors=()) 57 | with pytest.raises(UnboundLocalError, match="cannot access local|referenced before assignment"): 58 | jfn(a, b) 59 | -------------------------------------------------------------------------------- /thunder/tests/test_examine.py: -------------------------------------------------------------------------------- 1 | import thunder.examine 2 | import torch 3 | 4 | 5 | def test_examine_fn(): 6 | def foo(x): 7 | x[0] = 5 * x[1] 8 | 9 | x = torch.ones(2, 2) 10 | thunder.examine.examine(foo, x) 11 | 12 | 13 | def test_examine_jfn(): 14 | def foo(x): 15 | x[0] = 5 * x[1] 16 | 17 | jfoo = thunder.jit(foo) 18 | x = torch.ones(2, 2) 19 | thunder.examine.examine(jfoo, x) 20 | 21 | 22 | def test_examine_noncallable(capsys): 23 | x = torch.ones(2, 2) 24 | y = torch.ones(2, 2) 25 | thunder.examine.examine(x, y) 26 | captured = capsys.readouterr() 27 | assert "expected `fn` to be a callable" in captured.out 28 | -------------------------------------------------------------------------------- /thunder/tests/test_pythonex.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import thunder 4 | 5 | 6 | def _run_cache_symbolic_values(fn, ref_fn, *args): 7 | jit_fn = thunder.jit(fn, cache="symbolic values") 8 | out = jit_fn(*args) 9 | 10 | out_ref = ref_fn(*args) 11 | assert out == out_ref 12 | 13 | 14 | def test_fmod(): 15 | def foo(a, b): 16 | return a % b 17 | 18 | _run_cache_symbolic_values(foo, foo, 2.0, 1.3) 19 | 20 | 21 | def test_bitwise_or(): 22 | def foo(a, b): 23 | return a | b 24 | 25 | _run_cache_symbolic_values(foo, foo, 3, 5) 26 | 27 | 28 | def test_bitwise_and(): 29 | def foo(a, b): 30 | return a & b 31 | 32 | _run_cache_symbolic_values(foo, foo, 3, 5) 33 | 34 | 35 | def test_bitwise_xor(): 36 | def foo(a, b): 37 | return a ^ b 38 | 39 | _run_cache_symbolic_values(foo, foo, 3, 5) 40 | 41 | 42 | def test_math_atan2(): 
43 | def foo(a, b): 44 | # TODO: calling through math.atan2 bakes in constant, this needs to be investigated. 45 | return thunder.clang.atan2(a, b) 46 | 47 | # NOTE: we have thunder.clang in foo, which cannot be run with non-proxy 48 | _run_cache_symbolic_values(foo, math.atan2, 2.0, 1.3) 49 | 50 | 51 | def test_math_fmod(): 52 | def foo(a, b): 53 | return thunder.clang.fmod(a, b) 54 | 55 | _run_cache_symbolic_values(foo, math.fmod, 2.0, 1.3) 56 | -------------------------------------------------------------------------------- /thunder/tests/test_reductions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing import assert_close 3 | 4 | import thunder 5 | import thunder.torch as ttorch 6 | from thunder.tests.framework import instantiate 7 | 8 | 9 | # TODO: convert these tests to OpInfo generated tests 10 | 11 | 12 | @instantiate(dtypes=(thunder.float32,)) 13 | def test_torch_var(executor, device, dtype): 14 | # Tests passing all arguments as function inputs 15 | def foo(a, dim, *, keepdim=False, correction=1): 16 | return ttorch.var(a, dim, keepdim=keepdim, correction=correction) 17 | 18 | traced_foo = executor.make_callable(foo) 19 | 20 | tdtype = ttorch.to_torch_dtype(dtype) 21 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 22 | 23 | # Full reduction 24 | thunder_result = traced_foo(a, [0, 1]) 25 | torch_result = torch.var(a, [0, 1]) 26 | assert_close(thunder_result, torch_result) 27 | 28 | # Reduce along dim 1 29 | thunder_result = traced_foo(a, [1]) 30 | torch_result = torch.var(a, [1]) 31 | assert_close(thunder_result, torch_result) 32 | 33 | # Specifying the correction 34 | thunder_result = traced_foo(a, [1], correction=2) 35 | torch_result = torch.var(a, [1], correction=2) 36 | assert_close(thunder_result, torch_result) 37 | 38 | # Specifying keepdim 39 | thunder_result = traced_foo(a, [1], keepdim=True, correction=2) 40 | torch_result = torch.var(a, [1], keepdim=True, correction=2) 41 | assert_close(thunder_result, torch_result) 42 | 43 | # Tests passing arguments as constants 44 | def foo(a): 45 | return ttorch.var(a, [0, 1], keepdim=True, correction=2) 46 | 47 | traced_foo = executor.make_callable(foo) 48 | 49 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 50 | 51 | thunder_result = traced_foo(a) 52 | torch_result = torch.var(a, [0, 1], keepdim=True, correction=2) 53 | assert_close(thunder_result, torch_result) 54 | 55 | 56 | @instantiate(dtypes=(thunder.float32,)) 57 | def test_torch_mean(executor, device, dtype): 58 | def foo(a, dim=None, keepdim=False, *, dtype=None): 59 | return ttorch.mean(a, dim, keepdim, dtype=dtype) 60 | 61 | traced_foo = executor.make_callable(foo) 62 | 63 | tdtype = ttorch.to_torch_dtype(dtype) 64 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 65 | 66 | # Full reduction 67 | thunder_result = traced_foo(a, [0, 1]) 68 | torch_result = torch.mean(a, [0, 1]) 69 | assert_close(thunder_result, torch_result) 70 | 71 | # Reduce along dim 1 72 | thunder_result = traced_foo(a, [1]) 73 | torch_result = torch.mean(a, [1]) 74 | assert_close(thunder_result, torch_result) 75 | 76 | # Reduce with () dims 77 | thunder_result = traced_foo(a, ()) 78 | torch_result = torch.mean(a, ()) 79 | assert_close(thunder_result, torch_result) 80 | 81 | 82 | @instantiate(dtypes=(thunder.float32,)) 83 | def test_var_mean(executor, device, dtype): 84 | def foo(a, dim=None, keepdim=False, *, correction=1): 85 | return ttorch.var_mean(a, dim, 
keepdim=keepdim, correction=correction) 86 | 87 | traced_foo = executor.make_callable(foo) 88 | 89 | tdtype = ttorch.to_torch_dtype(dtype) 90 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 91 | 92 | # Full reduction 93 | thunder_result = traced_foo(a, [0, 1]) 94 | torch_result = torch.var_mean(a, [0, 1]) 95 | assert_close(thunder_result, torch_result) 96 | 97 | # Reduce along dim 1 98 | thunder_result = traced_foo(a, [1]) 99 | torch_result = torch.var_mean(a, [1]) 100 | assert_close(thunder_result, torch_result) 101 | 102 | # Tests passing arguments as constants 103 | def foo(a): 104 | return ttorch.var_mean(a, [0, 1], keepdim=True, correction=2) 105 | 106 | traced_foo = executor.make_callable(foo) 107 | 108 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 109 | 110 | thunder_result = traced_foo(a) 111 | torch_result = torch.var_mean(a, [0, 1], keepdim=True, correction=2) 112 | assert_close(thunder_result, torch_result) 113 | -------------------------------------------------------------------------------- /thunder/tests/test_shape_ops.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | import torch 3 | 4 | 5 | def test_pad_cast_value_itof(): 6 | """ 7 | Pad should cast the given value to the type of tensor and pad that value. 8 | """ 9 | 10 | def fqn(): 11 | x = torch.tensor([2, 3], dtype=torch.int32) 12 | y = torch.nn.functional.pad(x, pad=(1, 2), value=6.4) 13 | return y 14 | 15 | th_fqn = thunder.jit(fqn) 16 | v = th_fqn() 17 | assert v[0] == 6 18 | assert v[1] == 2 19 | assert v[2] == 3 20 | assert v[3] == 6 21 | assert v[4] == 6 22 | 23 | 24 | def test_pad_cast_value_ftoi(): 25 | """ 26 | Pad should cast the given value to the type of tensor and pad that value. 27 | """ 28 | 29 | def fqn(): 30 | x = torch.tensor([2.4, 3.8]) 31 | y = torch.nn.functional.pad(x, pad=(1, 2), value=1) 32 | return y 33 | 34 | th_fqn = thunder.jit(fqn) 35 | v = th_fqn() 36 | assert v[0] == 1.0 37 | assert v[1] == 2.4 38 | assert v[2] == 3.8 39 | assert v[3] == 1.0 40 | assert v[4] == 1.0 41 | -------------------------------------------------------------------------------- /thunder/tests/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import functools 4 | 5 | 6 | def is_output_differentiable(x): 7 | # grad_fn is set only if one of the input `requires_grad=True` 8 | # and the op is differentiable. 9 | # Example: 10 | # >>> x = torch.ones(3, requires_grad=True) 11 | # >>> y = torch.ones(3, requires_grad=False) 12 | # >>> (x + x).grad_fn # 13 | # >>> (y + y).grad_fn # None 14 | # >>> (y + x).grad_fn # 15 | # >>> (x < 1).grad_fn # None (non-differentiable op) 16 | # Op with differentiable and non-differentiable outputs. 17 | # >>> torch.topk(x, k=2) 18 | # torch.return_types.topk( 19 | # values=tensor([1., 1.], grad_fn=), 20 | # indices=tensor([0, 1])) 21 | # >>> torch.topk(torch.ones(3, requires_grad=False), k=2) 22 | # torch.return_types.topk( 23 | # values=tensor([1., 1.]), 24 | # indices=tensor([0, 1])) 25 | return x.grad_fn is not None or is_returning_self(x) 26 | 27 | 28 | def is_returning_self(x): 29 | if x.is_leaf and x.requires_grad: 30 | return True 31 | return False 32 | 33 | 34 | def filter_differentiable_outputs(outputs): 35 | if isinstance(outputs, torch.Tensor): 36 | # Otherwise `filter` below will 37 | # iterate over the Tensor data. 
38 | outputs = [outputs] 39 | 40 | return list(filter(is_output_differentiable, outputs)) 41 | 42 | 43 | def is_sm120_orsm121(): 44 | return torch.cuda.get_device_capability() in ((12, 1), (12, 0)) 45 | 46 | 47 | def skip_on_sm120_and_sm121(fn): 48 | @functools.wraps(fn) 49 | def wrapped_fn(*args, **kwargs): 50 | if is_sm120_orsm121(): 51 | pytest.skip("Skipped on SM120/121") 52 | else: 53 | fn(*args, **kwargs) 54 | 55 | return wrapped_fn 56 | -------------------------------------------------------------------------------- /thunder/torch/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/6720e82b18884a26bbadaffff2f33d836682cc9d/thunder/torch/experimental/__init__.py -------------------------------------------------------------------------------- /thunder/torch/experimental/dtensor_codeutils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | from thunder.torch.experimental.dtensor_utils import run_only_if_distributed_is_available 4 | 5 | if torch.distributed.is_available(): 6 | from torch.distributed.tensor._dtensor_spec import DTensorSpec 7 | 8 | 9 | @run_only_if_distributed_is_available 10 | def is_dtensor_spec(x: Any) -> bool: 11 | return isinstance(x, DTensorSpec) 12 | -------------------------------------------------------------------------------- /thunder/torch/langctx.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 4 | from thunder.core.pytree import tree_flatten 5 | from thunder.core.proxies import TensorProxy 6 | 7 | # 8 | # Creates and registers the torch language context 9 | # 10 | # NOTE That this is done separately from the definition of thunder.torch operations, because the 11 | # language context must be available before those operations are defined 12 | 13 | _method_name_to_fn_map: dict[str, Callable] = {} 14 | _property_name_to_fn_map: dict[str, Callable] = {} 15 | 16 | 17 | class TorchCtx(LanguageContext): 18 | def __init__(self): 19 | super().__init__("torch") 20 | 21 | def has_method(self, id: str) -> bool: 22 | return id in _method_name_to_fn_map or id in _property_name_to_fn_map 23 | 24 | def get_method(self, id: str, *args, **kwargs) -> Callable: 25 | # Note: concrete implementations should only raise AttributeError or 26 | # return None for "missing" methods as the proxies will 27 | # route __getattr__ to here and hasattr relies on __getattr__ 28 | # throwing AttributeError (only) when the attribute does 29 | # not exist. 
30 | inps, _ = tree_flatten((args, kwargs)) 31 | 32 | has_tensor_input: bool = False 33 | for x in inps: 34 | if isinstance(x, TensorProxy): 35 | has_tensor_input = True 36 | break 37 | if has_tensor_input: 38 | method: None | Callable = _method_name_to_fn_map.get(id, None) 39 | prop: None | Callable = _property_name_to_fn_map.get(id, None) 40 | if method is None and prop is None: 41 | raise AttributeError(f"The {self.name} language context has no method or attribute {id}") 42 | if method: 43 | return method 44 | else: 45 | return prop(inps[0]) 46 | 47 | # has_tensor_input is False 48 | # Defers to the CLANG language context when there are no tensor inputs 49 | # (the clang language context handles operations on NumberProxies and Numbers) 50 | clangctx: LanguageContext = resolve_language(Languages.CLANG) 51 | if not clangctx.has_method(id): 52 | raise AttributeError( 53 | f"Attempting to call method {id} in the torch language context, but it has no tensor inputs and the clang language context (which handles numbers) doesn't have the method" 54 | ) 55 | clang_method: Callable = clangctx.get_method(id, *args, **kwargs) 56 | return clang_method 57 | 58 | 59 | torchctx = TorchCtx() 60 | register_langctx(Languages.TORCH, torchctx) 61 | register_langctx("torch", torchctx) 62 | 63 | 64 | # Registers a method with the torch language context 65 | def register_method(method_name: str, method: Callable, /) -> None: 66 | _method_name_to_fn_map[method_name] = method 67 | 68 | 69 | def register_property(property_name, property) -> None: 70 | _property_name_to_fn_map[property_name] = property 71 | -------------------------------------------------------------------------------- /thunder/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant_folding import ConstantFolding 2 | from .materialization import MaterializationTransform 3 | from .qlora import LORATransform 4 | from .prune_prologue_checks import PrunePrologueChecks 5 | from .extraction_only_prologue_transform import ExtractionOnlyPrologueTransform 6 | 7 | 8 | __all__ = [ 9 | "ConstantFolding", 10 | "LORATransform", 11 | "MaterializationTransform", 12 | "PrunePrologueChecks", 13 | "ExtractionOnlyPrologueTransform", 14 | ] 15 | -------------------------------------------------------------------------------- /thunder/transforms/extraction_only_prologue_transform.py: -------------------------------------------------------------------------------- 1 | from thunder.core.prims import PrimIDs 2 | from thunder.core.proxies import ProxyTag 3 | from thunder.core.trace import from_trace 4 | from thunder.core.transform_common import Transform 5 | 6 | 7 | __all__ = [ 8 | "ExtractionOnlyPrologueTransform", 9 | ] 10 | 11 | 12 | class ExtractionOnlyPrologueTransform(Transform): 13 | """Exclude :func:`~thunder.core.prims.check_tensor_shape_and_metadata` from the prologue trace of ThunderCompiler. 14 | 15 | This transform is mainly used by :class:`~thunder.dynamo.ThunderCompiler` to remove checks on input tensors 16 | when they are :class:`torch.nn.Parameter`s or when none of them has a symbolic shape. 17 | 18 | Args: 19 | skip_check_on_input_tensors: If :obj:`True`, remove all checks from the prologue trace, since TorchDynamo's caching already guards the inputs. 20 | Otherwise, remove the checks of tensor proxies with ``ProxyTag.STATIC_MEMORY_LOCATION``. 
21 | """ 22 | 23 | def __init__(self, skip_check_on_input_tensors: bool = False): 24 | self.skip_check_on_input_tensors = skip_check_on_input_tensors 25 | 26 | def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs): 27 | new_prologue_trace = from_trace(prologue_trace) 28 | new_bsyms = [] 29 | 30 | for bsym in prologue_trace.bound_symbols: 31 | # NOTE - We assume TensorProxys tagged with `STATIC_MEMORY_LOCATION` are 32 | # Parameters or Buffers. It should be safe to disable checks for 33 | # tensors we deem to be static. 34 | if bsym.sym.id == PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA and ( 35 | self.skip_check_on_input_tensors or ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags 36 | ): 37 | continue 38 | 39 | new_bsyms.append(bsym) 40 | 41 | new_prologue_trace.bound_symbols = new_bsyms 42 | 43 | new_prologue_trace.set_provenance("Extraction only prologue pass") 44 | return new_prologue_trace, computation_trace, epilogue_trace 45 | -------------------------------------------------------------------------------- /thunder/transforms/prune_prologue_checks.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.core import prims as prims 3 | 4 | 5 | class PrunePrologueChecks(thunder.core.transform_common.Transform): 6 | """A Transform to prune Prologue checks 7 | 8 | This transform removes prologue checks and can be applied when the user 9 | controls the environment enough to ensure that these checks would always 10 | succeed. By default, we remove all checks that use the module. 11 | """ 12 | 13 | def __init__(self, prune_all_checks=False): 14 | self.prune_all_checks = prune_all_checks 15 | 16 | def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs): 17 | def is_check(bsym): 18 | return bsym.sym in { 19 | prims.check_tensor_shape_and_metadata, 20 | prims.check_string_value, 21 | prims.check_number_type_and_value, 22 | prims.check_len, 23 | prims.check_literal_like, 24 | } 25 | 26 | if not self.prune_all_checks: 27 | bsyms_to_skip = set() 28 | module_member_names = set() 29 | 30 | for bsym in prologue_trace.bound_symbols: 31 | if bsym.sym in {prims.unpack_trivial, prims.unpack_cache_info, prims.python_return}: 32 | # These don't have inputs but need to be skipped to not trigger false positives 33 | # python_return may have no inputs 34 | continue 35 | if bsym.sym is prims.unpack_function_obj: 36 | for o in bsym.flat_proxy_outs: 37 | module_member_names.add(o.name) 38 | continue 39 | 40 | inputs_are_module_members = {i.name in module_member_names for i in bsym.flat_proxy_args} 41 | if all(inputs_are_module_members): 42 | # all() of an empty set is True, but symbols without proxy inputs (like unpack_function_obj, 43 | # the root of module_member_names) were handled above, so reaching here without inputs is unexpected 44 | assert inputs_are_module_members, f"unexpected symbol {bsym.sym.name} without inputs" 45 | for o in bsym.flat_proxy_outs: 46 | module_member_names.add(o.name) 47 | if is_check(bsym): 48 | bsyms_to_skip.add(bsym) 49 | 50 | def should_skip_bsym(bsym): 51 | return bsym in bsyms_to_skip 52 | 53 | else: 54 | should_skip_bsym = is_check 55 | 56 | new_prologue_trace = thunder.core.trace.from_trace(prologue_trace) 57 | for bsym in prologue_trace.bound_symbols: 58 | if not should_skip_bsym(bsym): 59 | new_prologue_trace.bound_symbols.append(bsym.from_bsym()) 60 | 61 | new_prologue_trace = thunder.core.transform_common.dce(new_prologue_trace) 62 | 
new_prologue_trace.set_provenance(thunder.core.trace.TraceProvenance(f"{self.__class__.__name__}")) 63 | 64 | return new_prologue_trace, compute_trace, epilogue_trace 65 | -------------------------------------------------------------------------------- /thunder/transforms/utils.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.core import utils 3 | from thunder.core import prims 4 | from thunder.core.trace import TraceCtx 5 | from thunder.core.pytree import tree_map 6 | 7 | 8 | def get_checks(prologue_trace): 9 | # returns a dictionary mapping model param names to (check bsym, get param bsym) 10 | check_dict = {} 11 | prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace) 12 | for bsym in prologue_trace.bound_symbols: 13 | if bsym.sym == prims.unpack_parameter or bsym.sym == prims.unpack_buffer: 14 | param_thunder_module, param_name = bsym.args 15 | checks = [ 16 | bsym2 for bsym2 in prologue_consumers[bsym.output] if bsym2.sym == prims.check_tensor_shape_and_metadata 17 | ] 18 | assert len(checks) == 1, ( 19 | f"expected each parameter and buffer to have exactly one checker, but {bsym.output} has {len(checks)}" 20 | ) 21 | assert isinstance(param_name, str) 22 | check_dict[param_name] = (checks[0], bsym) 23 | return check_dict 24 | 25 | 26 | def add_trace_output(trace, output, subindex: int | None = None): 27 | ret_node = trace.bound_symbols[-1] 28 | assert ret_node.sym == prims.python_return 29 | assert len(ret_node.args) == 1 30 | (ret_args,) = ret_node.args 31 | 32 | if subindex is None: 33 | ret_args = (*ret_args, output) 34 | else: 35 | assert isinstance(ret_args[subindex], tuple) 36 | ret_args = (*ret_args[:subindex], (*ret_args[subindex], output), *ret_args[subindex + 1 :]) 37 | 38 | ret_node.args = (ret_args,) 39 | 40 | 41 | def trace_with_replaced_proxy_metadata(trace: TraceCtx, proxy_replacement_metadata) -> TraceCtx: 42 | t = TraceCtx(trace.fn) 43 | 44 | proxymap: dict[str, thunder.Proxy] = {} 45 | 46 | def map_proxy(p): 47 | if isinstance(p, thunder.Proxy): 48 | return proxymap[p.name] 49 | return p 50 | 51 | def create_proxy(p): 52 | if isinstance(p, thunder.Proxy): 53 | if p.name in proxymap: # happens with subsymbols 54 | return proxymap[p.name] 55 | with thunder.core.trace.tracectx(t): 56 | np = p.replace(**proxy_replacement_metadata.get(p.name, {})) 57 | proxymap[p.name] = np 58 | return np 59 | return p 60 | 61 | def process_bound_symbols(src_bound_symbols, target_bound_symbols): 62 | for bsym in src_bound_symbols: 63 | new_args = tree_map(map_proxy, bsym.args) 64 | new_kwargs = tree_map(map_proxy, bsym.kwargs) 65 | new_output = tree_map(create_proxy, bsym.output) 66 | new_bsym = bsym.from_bsym(output=new_output, args=new_args, kwargs=new_kwargs, subsymbols=[]) 67 | target_bound_symbols.append(new_bsym) 68 | if len(bsym.subsymbols) > 0: 69 | process_bound_symbols(bsym.subsymbols, new_bsym.subsymbols) 70 | 71 | process_bound_symbols(trace.bound_symbols, t.bound_symbols) 72 | 73 | t.args = tree_map(map_proxy, trace.args) 74 | t.kwargs = tree_map(map_proxy, trace.kwargs) 75 | t._siginfo = trace._siginfo 76 | return t 77 | --------------------------------------------------------------------------------
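A minimal usage sketch for the prologue-check transforms above (the `transforms=` argument to `thunder.jit` and `thunder.last_traces` are the same interfaces exercised in `thunder/tests/test_check_trace.py`; the function and variable names here are illustrative only):

```python
import torch
import thunder
from thunder.transforms import PrunePrologueChecks

def fn(a, b):
    return a + b

# Prune the prologue checks that derive from the compiled callable; this is
# appropriate only when the caller guarantees that inputs always match the
# shapes, dtypes, and devices seen at compile time.
jfn = thunder.jit(fn, transforms=(PrunePrologueChecks(),))
out = jfn(torch.randn(2, 2), torch.randn(2, 2))

# The resulting traces can be inspected as usual.
print(thunder.last_traces(jfn)[-1])
```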