├── .azure ├── docker-build.yml ├── gpu-tests.yml ├── notebook-runs.yml ├── remove-torch-lines.sh └── sanity-check.sh ├── .codecov.yml ├── .git-blame-ignore-revs ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ ├── feature_request.md │ └── program_coverage.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── labeling-config.yml ├── lightning-probot.yml ├── load-quickstart-report.py ├── run-benchmark-as-lit-jobs.py ├── run-quickstart-as-lit-jobs.py └── workflows │ ├── auto-cc.yml │ ├── ci-benchmark.yml │ ├── ci-checks.yml │ ├── ci-quickstart.yml │ ├── ci-testing.yml │ ├── docs-build.yml │ ├── label-conflicts.yml │ ├── labeler.yml │ ├── release-nightly.yml │ ├── release-pypi.yml │ └── stale-branches.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── dockers ├── README.md ├── ubuntu-cuda │ └── Dockerfile ├── with-apex │ └── Dockerfile └── with-dev │ └── Dockerfile ├── docs ├── .build_docs.sh ├── .readthedocs.yaml ├── Makefile ├── make.bat └── source │ ├── _static │ ├── copybutton.js │ └── images │ │ ├── LightningThunderDarkModewByline.png │ │ ├── LightningThunderLightModewByline.png │ │ ├── how_it_works.png │ │ ├── icon.svg │ │ ├── lightning_thunder_lightmode_nobyline.png │ │ ├── logo-large.svg │ │ ├── logo-small.svg │ │ ├── logo.png │ │ ├── logo.svg │ │ ├── normalized_training_throughput_zero2.png │ │ ├── pretrain_perf.png │ │ └── training_throughput_single.png │ ├── _templates │ └── theme_variables.jinja │ ├── advanced │ ├── contrib_thunder.rst │ ├── extending.rst │ └── inside_thunder.rst │ ├── basic │ ├── faq.rst │ ├── inspecting_traces.rst │ ├── mlp_mnist.rst │ ├── overview.rst │ └── sharp_edges.rst │ ├── conf.py │ ├── fundamentals │ ├── examine.rst │ ├── hello_world.rst │ └── installation.rst │ ├── index.rst │ ├── intermediate │ ├── additional_executors.rst │ ├── benchmarking.rst │ ├── ddp.rst │ └── whats_next.rst │ └── reference │ ├── clang │ └── index.rst │ ├── common │ └── index.rst │ ├── core │ ├── baseutils.rst │ ├── codeutils.rst │ ├── devices.rst │ ├── dtypes.rst │ ├── functionalization.rst │ ├── index.rst │ ├── langctxs.rst │ ├── prims.rst │ ├── proxies.rst │ ├── pytree.rst │ ├── rematerialization.rst │ ├── symbol.rst │ ├── trace.rst │ └── transforms.rst │ ├── distributed │ └── index.rst │ ├── dynamo │ └── index.rst │ ├── examine │ └── index.rst │ ├── executors │ ├── apexex.rst │ ├── cudnnex.rst │ ├── index.rst │ ├── nvfuserex.rst │ ├── passes.rst │ ├── pythonex.rst │ ├── torch_compile.rst │ ├── torchex.rst │ ├── triton_crossentropy.rst │ └── utils.rst │ ├── extend │ └── index.rst │ ├── recipes │ └── index.rst │ ├── thunder.rst │ ├── torch │ └── index.rst │ └── transforms │ └── index.rst ├── examples └── quickstart │ ├── hf_bert.py │ ├── hf_llm.py │ ├── mlp.py │ ├── requirements.txt │ ├── vit_hf.py │ └── vit_tv.py ├── notebooks ├── .ignore.ci ├── adding_custom_operator.ipynb ├── adding_custom_operator_backward.ipynb ├── dev_tutorials │ ├── extend.ipynb │ ├── fsdp_tutorial.ipynb │ └── thunder-add-vjp-rule.md ├── extend_thunder_with_cuda_python.ipynb ├── hello_world_thunderfx.ipynb ├── liger_kernel.ipynb ├── thunder_trace_intro.ipynb ├── writing_a_trace_transform_cpu_offloading.ipynb └── zero_to_thunder.ipynb ├── pyproject.toml ├── requirements.txt ├── requirements ├── base.txt ├── devel.txt ├── docs.txt ├── notebooks.txt └── test.txt ├── scripts ├── bisect_nvfuser.py ├── build_from_source.sh └── validate_build.py ├── setup.py └── thunder ├── __about__.py ├── __init__.py ├── 
benchmarks ├── __init__.py ├── benchmark_hf.py ├── benchmark_litgpt.py ├── conftest.py ├── distributed.py ├── einsum.py ├── targets.py ├── test_benchmark_litgpt.py └── utils.py ├── clang ├── __init__.py └── langctx.py ├── common.py ├── core ├── __init__.py ├── baseutils.py ├── codeutils.py ├── compile_data.py ├── devices.py ├── dtypes.py ├── functionalization.py ├── interpreter.py ├── jit_ext.py ├── langctxs.py ├── module.py ├── options.py ├── patterns.py ├── prims.py ├── profile.py ├── proxies.py ├── pytree.py ├── recipe.py ├── rematerialization.py ├── symbol.py ├── trace.py ├── trace_interpreter.py ├── transform_common.py ├── transforms.py ├── update_aliases.py ├── utils.py └── vjp_utils.py ├── dev_utils ├── __init__.py ├── benchmark.py ├── check_trace.py ├── debug_transform.py ├── nvtx_profile_transform.py └── utils.py ├── distributed ├── __init__.py ├── bucketing.py ├── checkpoint.py ├── prims.py ├── tensor_parallel │ ├── __init__.py │ ├── column_wise.py │ ├── common.py │ ├── optimize_comm.py │ └── row_wise.py ├── transforms │ ├── __init__.py │ ├── ddp.py │ ├── ddp_v2.py │ ├── fsdp.py │ └── fsdp_v2.py └── utils.py ├── dynamo ├── __init__.py ├── benchmark_utils.py ├── compiler.py ├── compiler_graph_benchmark.py ├── report.py ├── repro_script_template.py ├── splitter.py └── utils.py ├── examine ├── __init__.py └── memory_calculation.py ├── executors ├── __init__.py ├── apex_entropyex_impl.py ├── apex_fused_rms_norm_impl.py ├── apexex.py ├── cudnn_layernormex.py ├── cudnnex.py ├── data_dependent_partition.py ├── fa3ex.py ├── nvfuserex.py ├── nvfuserex_impl.py ├── passes.py ├── pythonex.py ├── sdpaex.py ├── torch_autograd.py ├── torch_compile.py ├── torchex.py ├── transformer_engine_v2ex.py ├── transformer_engine_v2ex_impl.py ├── transformer_engineex.py ├── triton_crossentropy.py ├── triton_crossentropy_impl.py ├── triton_utils.py └── utils.py ├── extend └── __init__.py ├── numpy ├── __init__.py └── langctx.py ├── plugins ├── __init__.py ├── distributed.py ├── fp8.py ├── quantization.py └── reduce_overhead.py ├── py.typed ├── recipes ├── __init__.py ├── base.py └── hf_transformers.py ├── tests ├── README.md ├── __init__.py ├── bf16.py ├── conftest.py ├── distributed │ ├── __init__.py │ ├── helper.py │ ├── modules.py │ ├── test_checkpoint.py │ ├── test_ddp.py │ ├── test_fsdp.py │ ├── test_ops.py │ └── test_tensor_parallel.py ├── framework.py ├── hf_bart_self_attn.py ├── litgpt_model.py ├── llama2_model.py ├── make_tensor.py ├── module_example.py ├── nanogpt_model.py ├── opinfos.py ├── test_apex_cross_entropy_executor.py ├── test_apex_fused_norms.py ├── test_auto_register_torchops.py ├── test_autocast.py ├── test_check_trace.py ├── test_core.py ├── test_cudnn_executor.py ├── test_dynamo.py ├── test_einops.py ├── test_elementwise.py ├── test_examine.py ├── test_examine_memory.py ├── test_extend.py ├── test_fa3_executor.py ├── test_grad.py ├── test_inplace_copy.py ├── test_inplace_functionalization.py ├── test_interpreter.py ├── test_jit_general.py ├── test_networks.py ├── test_nvfuser.py ├── test_nvfuser_remat.py ├── test_ops.py ├── test_patterns.py ├── test_pythonex.py ├── test_randomness.py ├── test_recipes.py ├── test_reductions.py ├── test_sdpaex_executor.py ├── test_shape_ops.py ├── test_torch_compile_executor.py ├── test_transformer_engine_executor.py ├── test_transformer_engine_v2_executor.py ├── test_transforms.py ├── test_triton_ce.py └── test_update_aliases.py ├── torch ├── __init__.py ├── default_torch_ops.py └── langctx.py └── transforms ├── __init__.py ├── autocast.py ├── 
autodiff.py ├── constant_folding.py ├── cudagraph.py ├── extraction_only_prologue_transform.py ├── materialization.py ├── prune_prologue_checks.py ├── qlora.py ├── quantization.py └── utils.py /.azure/docker-build.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | tags: 3 | include: ["*"] 4 | branches: 5 | include: ["main"] 6 | paths: 7 | include: 8 | - ".azure/docker-build.yml" 9 | - "dockers/**" 10 | - "requirements.txt" 11 | - "requirements/*.txt" 12 | - "setup.py" 13 | exclude: 14 | - "*.md" 15 | - "**/*.md" 16 | 17 | pr: 18 | branches: 19 | include: ["*"] 20 | paths: 21 | include: 22 | - ".azure/docker-build.yml" 23 | - "dockers/**" 24 | - "requirements.txt" 25 | - "requirements/*.txt" 26 | - "setup.py" 27 | exclude: 28 | - "*.md" 29 | - "**/*.md" 30 | 31 | schedules: 32 | - cron: "0 */2 * * *" 33 | displayName: rebuild dockers for CI every 2 hours 34 | branches: 35 | include: ["main"] 36 | 37 | jobs: 38 | - job: build_push 39 | strategy: 40 | matrix: 41 | "cuda 12.6 | torch 2.7.1 | cudnn FE v1.10.0": 42 | { CUDA_VERSION: "12.6.3", TORCH_VERSION: "2.7.1", TRITON_VERSION: "3.3.1", CUDNN_FRONTEND_VERSION: "1.10.0" } 43 | "cuda 12.6 | torch nightly | cudnn FE v1.10.0": 44 | { CUDA_VERSION: "12.6.3", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.10.0" } 45 | #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found 46 | # how much time to give 'run always even if cancelled tasks' before stopping them 47 | cancelTimeoutInMinutes: "2" 48 | timeoutInMinutes: "95" 49 | variables: 50 | UBUNTU_VERSION: "24.04" 51 | PYTHON_VERSION: "3.10" 52 | imageRepository: "pytorchlightning/lightning-thunder" 53 | imageTag: "ubuntu$(UBUNTU_VERSION)-cuda$(CUDA_VERSION)-cudnn-fe$(CUDNN_FRONTEND_VERSION)-py$(PYTHON_VERSION)-pt_${TORCH_VERSION/v/}" 54 | pool: "lit-rtx-3090" 55 | workspace: 56 | clean: all 57 | steps: 58 | - bash: | 59 | set -e 60 | echo $imageTag 61 | nvidia-smi 62 | docker image build \ 63 | -t $(imageRepository):$(imageTag) \ 64 | -f "dockers/ubuntu-cuda/Dockerfile" \ 65 | --build-arg UBUNTU_VERSION="$(UBUNTU_VERSION)" \ 66 | --build-arg CUDA_VERSION="$(CUDA_VERSION)" \ 67 | --build-arg CUDNN_FRONTEND_VERSION="v$(CUDNN_FRONTEND_VERSION)" \ 68 | --build-arg PYTHON_VERSION="$(PYTHON_VERSION)" \ 69 | --build-arg TORCH_VERSION="$(TORCH_VERSION)" \ 70 | --build-arg TORCH_INSTALL="$(TORCH_INSTALL)" \ 71 | --build-arg TRITON_VERSION="$(TRITON_VERSION)" \ 72 | . --no-cache 73 | timeoutInMinutes: "95" 74 | displayName: "Build base image" 75 | 76 | - bash: | 77 | docker image build \ 78 | -t $(imageRepository):$(imageTag)-apex \ 79 | -f "dockers/with-apex/Dockerfile" \ 80 | --build-arg BASE_IMAGE_TAG="$(imageTag)" \ 81 | . --no-cache 82 | timeoutInMinutes: "25" 83 | displayName: "Build Apex image" 84 | 85 | - bash: | 86 | docker image build \ 87 | -t $(imageRepository):$(imageTag)-dev \ 88 | -f "dockers/with-dev/Dockerfile" \ 89 | --build-arg BASE_IMAGE_TAG="$(imageTag)-apex" \ 90 | . --no-cache 91 | timeoutInMinutes: "25" 92 | displayName: "Build Dev image" 93 | 94 | - bash: | 95 | docker image ls | grep $(imageRepository) 96 | # drop pt from requirements so not to interfere with the existing one 97 | bash .azure/remove-torch-lines.sh requirements/base.txt 98 | mv .azure azure 99 | docker run --rm --gpus=all -v .:/workspace $(imageRepository):$(imageTag)-dev \ 100 | bash -c "cd /workspace && ls -lh . && \ 101 | pip install -q . 
&& \ 102 | bash azure/sanity-check.sh" 103 | timeoutInMinutes: "5" 104 | displayName: "Sanity check" 105 | 106 | - bash: | 107 | set -e 108 | echo $(imageRepository):$(imageTag) 109 | echo $(DOCKERHUB_PAT) | docker login --username $(DOCKERHUB_USER) --password-stdin 110 | docker push $(imageRepository):$(imageTag)-dev 111 | condition: ne(variables['Build.Reason'], 'PullRequest') 112 | timeoutInMinutes: "35" 113 | displayName: "Push base image" 114 | -------------------------------------------------------------------------------- /.azure/notebook-runs.yml: -------------------------------------------------------------------------------- 1 | trigger: 2 | tags: 3 | include: ["*"] 4 | branches: 5 | include: 6 | - "main" 7 | - "release/*" 8 | - "refs/tags/*" 9 | 10 | pr: 11 | branches: 12 | include: ["*"] 13 | 14 | jobs: 15 | - job: jupyter 16 | strategy: 17 | matrix: 18 | "notebooks w/ torch 2.7": 19 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.7.1-dev" 20 | "notebooks w/ torch-nightly": 21 | docker-image: "ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_main-dev" 22 | # how long to run the job before automatically cancelling 23 | timeoutInMinutes: "45" 24 | # how much time to give 'run always even if cancelled tasks' before stopping them 25 | cancelTimeoutInMinutes: "2" 26 | pool: "lit-rtx-3090" 27 | variables: 28 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0"; print(gpus)' ) 29 | TORCH_HOME: "/var/tmp/torch" 30 | PIP_CACHE_DIR: "/var/tmp/pip" 31 | container: 32 | image: "pytorchlightning/lightning-thunder:$(docker-image)" 33 | options: "--gpus=all --shm-size=16g -v /var/tmp:/var/tmp" 34 | workspace: 35 | clean: all 36 | steps: 37 | - bash: | 38 | echo $(DEVICES) 39 | lspci | egrep 'VGA|3D' 40 | whereis nvidia 41 | nvidia-smi 42 | which python && which pip 43 | python --version 44 | pip --version 45 | pip list 46 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" 47 | displayName: "Image info & NVIDIA" 48 | 49 | - bash: | 50 | # drop pt from requirements so not to interfere with the existing one 51 | bash .azure/remove-torch-lines.sh requirements/base.txt 52 | cat requirements/base.txt 53 | # double check on test requirements 54 | pip install -r requirements/notebooks.txt 55 | # install this package 56 | python setup.py develop 57 | displayName: "Install package & ..." 58 | 59 | - bash: | 60 | set -ex 61 | bash .azure/sanity-check.sh 62 | displayName: "Sanity check / details" 63 | 64 | - bash: | 65 | set -ex 66 | # list all notebooks in this folder 67 | find . 
-name "*.ipynb" > all.txt 68 | # drop all "./" from beginning of each line 69 | sed -i 's/^\.\///' all.txt 70 | # filter out the ones that are listed in .ignore.ci 71 | grep -Fxv -f .ignore.ci all.txt > ci.txt 72 | # iterate over all listed notebooks and execute them with jupyter 73 | while read -r line; do 74 | echo "Processing $line" 75 | jupyter execute $line --timeout=300 76 | done <<< $(cat ci.txt) 77 | workingDirectory: "notebooks/" 78 | displayName: "Execute notebooks" 79 | -------------------------------------------------------------------------------- /.azure/remove-torch-lines.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get the name of the file to process 4 | filename=$1 5 | 6 | # Create a temporary file to store the filtered lines 7 | tempfile=$(mktemp) 8 | 9 | # Loop through the original file and remove lines including the word "torch" 10 | while read -r line; do 11 | if [[ ! "$line" =~ "torch" ]]; then 12 | echo "$line" >> "$tempfile" 13 | fi 14 | done < "$filename" 15 | 16 | # Move the temporary file to the original file 17 | mv "$tempfile" "$filename" 18 | -------------------------------------------------------------------------------- /.azure/sanity-check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -ex 4 | pip list 5 | python -c "import torch ; assert torch.cuda.is_available(), 'missing GPU support!'" 6 | python -c "import torch ; v = torch.__version__ ; assert str(v).startswith('2'), v" 7 | python -c "from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'" 8 | python -c "from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'" 9 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | # https://codecov.readme.io/v1.0/docs/commit-status 20 | project: 21 | default: 22 | informational: true 23 | target: 99% # specify the target coverage for each commit status 24 | threshold: 30% # allow this little decrease on project 25 | # https://github.com/codecov/support/wiki/Filtering-Branches 26 | # branches: main 27 | if_ci_failed: error 28 | # https://github.com/codecov/support/wiki/Patch-Status 29 | patch: 30 | default: 31 | informational: true 32 | target: 50% # specify the target "X%" coverage to hit 33 | # threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: false 51 | 
-------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Move optimize_allreduce_in_ddp_backward to thunder.distributed.transforms namespace #2087 2 | 7a3d5d4e64da010eea42210a065960e3568870ac 3 | # Split `test_ddp` into `test_ddp`, `test_fsdp` and `test_ops` (#772) 4 | b63e8713141044ed8c184f2f4c8305048d477319 5 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | 3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, 4 | # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request. 5 | 6 | # Thank you, our previous code owners for their service: 7 | # @carmocca 8 | 9 | * @mruberry @lantiga @t-vi 10 | 11 | # CI/CD and configs 12 | /.azure/ @borda @lantiga @t-vi 13 | /.github/ @borda @lantiga @t-vi 14 | /dockers/ @borda @lantiga @t-vi 15 | Makefile @borda @lantiga @t-vi 16 | *.yml @borda @lantiga @t-vi 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | assignees: '' 6 | --- 7 | 8 | *Note*: If you have a model or program that is not supported yet but should be, please use the program coverage template. 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ### To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. Go to '...' 19 | 1. Run '....' 20 | 1. Scroll down to '....' 21 | 1. See error 22 | 23 | 24 | 25 | #### Code sample 26 | 27 | 29 | 30 | ### Expected behavior 31 | 32 | 33 | 34 | ### Environment 35 | 36 | - PyTorch Version (e.g., 1.0): 37 | - OS (e.g., Linux): 38 | - How you installed PyTorch (`conda`, `pip`, source): 39 | - Build command you used (if compiling from source): 40 | - Python version: 41 | - CUDA/cuDNN version: 42 | - GPU models and configuration: 43 | - Any other relevant information: 44 | 45 | ### Additional context 46 | 47 | 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos and doc fixes, please go ahead and: 12 | 13 | 1. Create an issue. 14 | 1. Fix the typo. 15 | 1. Submit a PR. 16 | 17 | Thanks! 
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/program_coverage.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Program Coverage 3 | about: Expand the programs / models Thunder can process 4 | title: '' 5 | labels: program-coverage 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Model / language coverage 10 | 11 | 12 | 13 | ### Pitch 14 | 15 | 16 | 17 | ### Alternatives / Potential work-arounds 18 | 19 | 24 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 |
2 | Before submitting 3 | 4 | - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) 5 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section? 6 | - [ ] Did you make sure to update the docs? 7 | - [ ] Did you write any new necessary tests? 8 | 9 |
10 | 11 | ## What does this PR do? 12 | 13 | Fixes # (issue). 14 | 15 | ## PR review 16 | 17 | Anyone in the community is free to review the PR once the tests have passed. 18 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Basic dependabot.yml file with 2 | # minimum configuration for two package managers 3 | 4 | version: 2 5 | updates: 6 | # Enable version updates for python 7 | - package-ecosystem: "pip" 8 | # Look for a `requirements` in the `root` directory 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | pull-request-branch-name: 13 | separator: "-" 14 | # Allow up to 5 open pull requests for pip dependencies 15 | open-pull-requests-limit: 5 16 | 17 | # Enable version updates for GitHub Actions 18 | - package-ecosystem: "github-actions" 19 | directory: "/" 20 | schedule: 21 | interval: "monthly" 22 | pull-request-branch-name: 23 | separator: "-" 24 | # Allow up to 5 open pull requests for GitHub Actions 25 | open-pull-requests-limit: 5 26 | groups: 27 | GHA-updates: 28 | patterns: 29 | - "*" 30 | -------------------------------------------------------------------------------- /.github/labeling-config.yml: -------------------------------------------------------------------------------- 1 | documentation: 2 | - changed-files: 3 | - any-glob-to-any-file: 4 | - docs/**/* 5 | - dev_tutorials/* 6 | 7 | "ci": 8 | - changed-files: 9 | - any-glob-to-any-file: 10 | - .azure/* 11 | - .github/* 12 | - .github/workflows/* 13 | - dockers/**/* 14 | 15 | "docker": 16 | - changed-files: 17 | - any-glob-to-any-file: 18 | - dockers/**/* 19 | - .azure/docker-build.yml 20 | 21 | "install": 22 | - changed-files: 23 | - any-glob-to-any-file: 24 | - setup.py 25 | - pyproject.toml 26 | - MANIFEST.in 27 | 28 | "dependencies": 29 | - changed-files: 30 | - any-glob-to-any-file: 31 | - requirements/* 32 | - requirements.txt 33 | -------------------------------------------------------------------------------- /.github/lightning-probot.yml: -------------------------------------------------------------------------------- 1 | tracking_issue: 72 2 | -------------------------------------------------------------------------------- /.github/load-quickstart-report.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def main(report_path: str = "quickstart_report.json"): 6 | """Load and print the quickstart report.""" 7 | with open(report_path) as fp: 8 | report = json.load(fp) 9 | 10 | success_count = sum(status == "completed" for status in report.values()) 11 | overall_status = "🟢" if success_count == len(report) else "⛔" 12 | print(f"Quickstart report {overall_status} with {success_count} out of {len(report)} successful:") 13 | # Sort so that entries with status "success" (or "completed") are last 14 | for name, status in sorted(report.items(), key=lambda x: x[1] == "completed"): 15 | status_icon = "✔️" if status == "completed" else "❌" 16 | print(f"{status_icon} {name}") 17 | 18 | 19 | if __name__ == "__main__": 20 | # optional path to the report file 21 | system_args = sys.argv[1:] 22 | main_args = {"report_path": system_args[0]} if len(system_args) > 0 else {} 23 | main(**main_args) 24 | 
-------------------------------------------------------------------------------- /.github/run-benchmark-as-lit-jobs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import sys 4 | import time 5 | from datetime import datetime 6 | 7 | from lightning_sdk import Studio, Job, Machine, Status 8 | 9 | 10 | def main(gh_run_id: str = ""): 11 | if not gh_run_id: 12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S") 13 | print("Creating studio...") 14 | s = Studio(f"thunder-benchmark-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True) 15 | 16 | print("Uploading package and benchmark script...") 17 | s.upload_folder("dist", remote_path="dist") 18 | pkg_path = glob.glob("dist/*.whl")[0] 19 | s.upload_file("thunder/benchmarks/benchmark_hf.py", remote_path="benchmarks/benchmark_hf.py") 20 | 21 | print("Starting studio...") 22 | s.start() 23 | print("Installing Thunder and dependencies...") 24 | s.run(f"pip install {pkg_path} -U transformers 'numpy<2.0' 'nvfuser_cu128_torch27==0.2.27.dev20250501'") 25 | 26 | print("Running HF benchmark script...") 27 | job = Job.run( 28 | name=f"benchmark-run{gh_run_id}", 29 | command="pip list && python benchmarks/benchmark_hf.py", 30 | studio=s, 31 | machine=Machine.L40S, 32 | interruptible=True, 33 | ) 34 | 35 | print("Stopping studio...") 36 | s.stop() 37 | 38 | print("Waiting for job to finish...") 39 | job.wait() 40 | status = str(job.status).lower() 41 | print(f"[{job.status}]\t {job.name}") 42 | 43 | report = {"benchmark_hf.py": status} 44 | with open("benchmark_hf_report.json", "w") as fp: 45 | json.dump(report, fp, indent=4) 46 | 47 | if job.status != Status.Completed: 48 | print("=" * 80) 49 | print(f"===== benchmark_hf.py -> {job.status} =====") 50 | print("=" * 80) 51 | print(job.logs) 52 | print("=" * 80) 53 | time.sleep(3) 54 | raise RuntimeError(f"Benchmark HF job {job.status}") 55 | # clean up 56 | job.delete() 57 | s.delete() 58 | 59 | 60 | if __name__ == "__main__": 61 | # parse command line arguments 62 | args = sys.argv[1:] 63 | main(*args) 64 | -------------------------------------------------------------------------------- /.github/run-quickstart-as-lit-jobs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from datetime import datetime 4 | import glob 5 | import os.path 6 | 7 | from lightning_sdk import Studio, Job, Machine, Status 8 | 9 | 10 | def main(gh_run_id: str = ""): 11 | if not gh_run_id: 12 | gh_run_id = datetime.now().strftime("%Y-%m-%d|%H:%M:%S") 13 | print("Creating studio...") 14 | s = Studio(f"thunder-quickstarts-run{gh_run_id}", "oss-thunder", org="lightning-ai", create_ok=True) 15 | print("Uploading package and scripts...") 16 | s.upload_folder("dist", remote_path="dist") 17 | pkg_path = glob.glob("dist/*.whl")[0] 18 | s.upload_folder("examples/quickstart", remote_path="quickstart") 19 | 20 | print("Starting studio...") 21 | s.start() 22 | print("Installing Thunder and other requirements...") 23 | s.run(f"pip install {pkg_path} -U -r quickstart/requirements.txt") 24 | 25 | ls_quickstart = glob.glob("examples/quickstart/*.py") 26 | print("Found quickstart scripts:", ls_quickstart) 27 | 28 | print("Running quickstart scripts...") 29 | jobs = { 30 | os.path.basename(script): Job.run( 31 | name=f"ci-run{gh_run_id}_{script}", 32 | command=f"pip list && python quickstart/{os.path.basename(script)}", 33 | studio=s, 34 | machine=Machine.L40S, 35 | interruptible=True, 36 | ) 
37 | for script in ls_quickstart 38 | } 39 | 40 | print("Stopping studio...") 41 | s.stop() 42 | 43 | print("Waiting for jobs to finish...") 44 | report, failures = {}, {} 45 | for name, job in jobs.items(): 46 | job.wait() 47 | print(f"[{job.status}]\t {job.name}") 48 | report[name] = str(job.status).lower() 49 | if job.status != Status.Completed: 50 | failures[name] = job.logs 51 | else: # clean up successful jobs 52 | job.delete() 53 | 54 | with open("quickstart_report.json", "w") as fp: 55 | json.dump(report, fp, indent=4) 56 | 57 | print("Showing logs of failed jobs...") 58 | separator = "=" * 80 59 | for name, logs in failures.items(): 60 | offset = "=" * (80 - 5 - 2 - len(name)) 61 | print(f"{separator}\n===== {name} {offset}\n{separator}") 62 | print(logs) 63 | print(separator + "\n" * 5) 64 | 65 | assert not failures 66 | 67 | print("Cleaning up...") 68 | s.delete() 69 | 70 | 71 | if __name__ == "__main__": 72 | # parse command line arguments 73 | args = sys.argv[1:] 74 | main(*args) 75 | -------------------------------------------------------------------------------- /.github/workflows/auto-cc.yml: -------------------------------------------------------------------------------- 1 | name: Probot 2 | 3 | on: 4 | issues: 5 | types: [labeled] 6 | # should use `pull_request_target` but it's blocked by 7 | # https://github.com/probot/probot/issues/1635 8 | # so this job will not run on forks until the above is fixed 9 | pull_request: 10 | types: [labeled, ready_for_review] 11 | 12 | jobs: 13 | auto-cc: 14 | runs-on: ubuntu-latest 15 | if: github.event_name == 'issue' || github.event.pull_request.draft == false 16 | timeout-minutes: 5 17 | steps: 18 | - uses: Lightning-AI/probot@v5 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | with: 22 | job: auto-cc 23 | -------------------------------------------------------------------------------- /.github/workflows/ci-benchmark.yml: -------------------------------------------------------------------------------- 1 | name: Benchmark Models 2 | 3 | on: 4 | workflow_dispatch: {} 5 | pull_request: 6 | paths: 7 | - ".github/workflows/benchmark-hf.yml" 8 | - ".github/run-benchmark-as-lit-jobs.py" 9 | push: # enable tracing which commit to master broke it... 
10 | branches: [main] 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: false 15 | 16 | defaults: 17 | run: 18 | shell: bash 19 | 20 | jobs: 21 | launcher-benchmark: 22 | runs-on: "ubuntu-22.04" 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: "3.10" 30 | 31 | - name: Build Thunder Package 32 | run: | 33 | pip install -U build 34 | python -m build --sdist --wheel --outdir dist/ 35 | ls -l dist/ 36 | 37 | - name: Launch Benchmark Job in Lightning Studio 38 | env: 39 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 40 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 41 | run: | 42 | pip install lightning_sdk -U -q 43 | python .github/run-benchmark-as-lit-jobs.py ${{ github.run_id }} 44 | 45 | - name: Post Slack Notification 46 | if: always() && github.event_name != 'pull_request' 47 | uses: act10ns/slack@v2 48 | with: 49 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }} 50 | status: ${{ job.status }} 51 | message: | 52 | *Benchmark Triggered Manually* - [${{ job.status }}] 53 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} 54 | -------------------------------------------------------------------------------- /.github/workflows/ci-checks.yml: -------------------------------------------------------------------------------- 1 | name: General checks 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: {} 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 10 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} 11 | 12 | jobs: 13 | check-schema: 14 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main 15 | with: 16 | azure-dir: ".azure" 17 | 18 | check-package: 19 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main 20 | with: 21 | actions-ref: main 22 | import-name: "thunder" 23 | artifact-name: dist-packages-${{ github.sha }} 24 | testing-matrix: | 25 | { 26 | "os": ["ubuntu-latest", "macOS-latest", "windows-latest"], 27 | "python-version": ["3.10", "3.11"] 28 | } 29 | -------------------------------------------------------------------------------- /.github/workflows/ci-quickstart.yml: -------------------------------------------------------------------------------- 1 | name: CI quickstart 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | paths: 8 | - ".github/workflows/ci-quickstart.yml" 9 | - ".github/run-quickstart-as-lit-jobs.py" 10 | - ".github/load-quickstart-report.py" 11 | - "examples/quickstart/*" 12 | workflow_dispatch: {} 13 | schedule: 14 | - cron: "0 0 * * *" # every midnight 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 18 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 19 | 20 | defaults: 21 | run: 22 | shell: bash 23 | 24 | jobs: 25 | launcher-quickstart: 26 | runs-on: "ubuntu-22.04" 27 | #timeout-minutes: 55 28 | steps: 29 | - uses: actions/checkout@v4 30 | - uses: actions/setup-python@v5 31 | with: 32 | python-version: "3.10" 33 | 34 | - name: Build package 35 | run: | 36 | pip install -U build 37 | python -m build --sdist --wheel --outdir dist/ 38 | ls -l dist/ 39 | 40 | - name: Run scripts in jobs 41 | env: 42 | LIGHTNING_USER_ID: ${{ secrets.LIGHTNING_USER_ID }} 43 | LIGHTNING_API_KEY: ${{ secrets.LIGHTNING_API_KEY }} 44 | run: | 45 | pip install lightning_sdk -U -q 46 | 
python .github/run-quickstart-as-lit-jobs.py ${{ github.run_id }} 47 | 48 | - name: Load report 49 | if: always() 50 | id: load-report 51 | run: | 52 | report=$(python .github/load-quickstart-report.py) 53 | echo "REPORT<<EOF" >> $GITHUB_ENV 54 | echo "$report" >> $GITHUB_ENV 55 | echo "EOF" >> $GITHUB_ENV 56 | 57 | - uses: act10ns/slack@v2 58 | if: always() && github.event_name != 'pull_request' && steps.load-report.conclusion == 'success' 59 | with: 60 | webhook-url: ${{ secrets.SLACK_WEBHOOK_URL }} 61 | status: ${{ job.status }} 62 | message: | 63 | *Quickstart CI* - [${{ job.status }}] 64 | ${{ env.REPORT }} 65 | ref: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} 66 | -------------------------------------------------------------------------------- /.github/workflows/docs-build.yml: -------------------------------------------------------------------------------- 1 | name: "Build (& deploy) Docs" 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | types: [opened, reopened, ready_for_review, synchronize] 8 | workflow_dispatch: {} 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 12 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 13 | 14 | defaults: 15 | run: 16 | shell: bash 17 | 18 | jobs: 19 | docs-make: 20 | if: github.event.pull_request.draft == false 21 | runs-on: ubuntu-22.04 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | target: ["html", "doctest", "linkcheck"] 26 | env: 27 | ARTIFACT_DAYS: 0 28 | PYPI_LOCAL_DIR: "pypi_pkgs/" 29 | steps: 30 | - uses: actions/checkout@v4 31 | - uses: actions/setup-python@v5 32 | with: 33 | python-version: "3.10" 34 | 35 | - name: Pull sphinx template 36 | run: make get-sphinx-theme 37 | - name: Install pandoc 38 | timeout-minutes: 5 39 | run: sudo apt-get install -y pandoc 40 | - name: Install package & dependencies 41 | timeout-minutes: 20 42 | run: pip install . 
-U -r requirements/docs.txt 43 | 44 | - name: Make ${{ matrix.target }} 45 | working-directory: docs/ 46 | # allow failing link check and doctest if you run with dispatch 47 | continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }} 48 | run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" 49 | 50 | - name: Keep artifact 51 | if: github.event_name == 'pull_request' 52 | run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV 53 | - name: Upload built docs 54 | if: ${{ matrix.target == 'html' }} 55 | uses: actions/upload-artifact@v4 56 | with: 57 | name: docs-html-${{ github.sha }} 58 | path: docs/build/html/ 59 | retention-days: ${{ env.ARTIFACT_DAYS }} 60 | 61 | deploy-docs: 62 | needs: docs-make 63 | if: github.repository_owner == 'Lightning-AI' && github.event_name == 'push' 64 | runs-on: ubuntu-latest 65 | env: 66 | GCP_TARGET: "gs://lightning-docs-thunder" 67 | steps: 68 | - uses: actions/download-artifact@v4 69 | with: 70 | name: docs-html-${{ github.sha }} 71 | path: docs/build/html/ 72 | 73 | - name: Authenticate to Google Cloud 74 | uses: google-github-actions/auth@v2 75 | with: 76 | credentials_json: ${{ secrets.GCS_SA_KEY }} 77 | 78 | - name: Setup gcloud 79 | uses: google-github-actions/setup-gcloud@v2 80 | with: 81 | project_id: ${{ secrets.GCS_PROJECT }} 82 | 83 | # Uploading docs to GCS, so they can be served on lightning.ai 84 | #- name: Upload docs/thunder/stable to GCS 🪣 85 | # if: startsWith(github.ref, 'refs/heads/release/') 86 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/stable 87 | 88 | # Uploading docs to GCS, so they can be served on lightning.ai 89 | - name: Upload docs/thunder/latest to GCS 🪣 90 | if: github.ref == 'refs/heads/main' 91 | run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest 92 | 93 | # Uploading docs to GCS, so they can be served on lightning.ai 94 | #- name: Upload docs/thunder/release to GCS 🪣 95 | # if: startsWith(github.ref, 'refs/tags/') 96 | # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ github.ref_name }} 97 | 98 | # Uploading docs as archive to GCS, so they can be as backup 99 | #- name: Upload docs as archive to GCS 🪣 100 | # if: startsWith(github.ref, 'refs/tags/') 101 | # working-directory: docs/build 102 | # run: | 103 | # zip ${{ github.ref_name }}.zip -r html/ 104 | # gsutil cp ${{ github.ref_name }}.zip ${GCP_TARGET} 105 | -------------------------------------------------------------------------------- /.github/workflows/label-conflicts.yml: -------------------------------------------------------------------------------- 1 | name: Label merge conflicts 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request_target: 7 | types: ["synchronize", "reopened", "opened"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | triage-conflicts: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd 18 | with: 19 | CONFLICT_LABEL_NAME: "has conflicts" 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | MAX_RETRIES: 3 22 | WAIT_MS: 5000 23 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: "Pull Request Labeler" 2 | on: [pull_request_target] 3 | 4 | jobs: 5 | triage-prs: 6 | permissions: 7 | contents: read 8 | pull-requests: write 9 
| runs-on: ubuntu-latest 10 | steps: 11 | # Uploads repository content to the runner 12 | - uses: actions/checkout@v4 13 | - uses: actions/labeler@v5 14 | with: 15 | # The path to the label configuration file. 16 | configuration-path: .github/labeling-config.yml 17 | # Whether removing labels when matching files are reverted or no longer changed by the PR 18 | sync-labels: true 19 | -------------------------------------------------------------------------------- /.github/workflows/release-nightly.yml: -------------------------------------------------------------------------------- 1 | name: Nightly packages 2 | 3 | on: 4 | pull_request: # this shall test only the part of workflow before publishing 5 | branches: [main, "release/*"] 6 | types: [opened, reopened, ready_for_review, synchronize] 7 | paths: 8 | - ".github/workflows/release-nightly.yml" 9 | schedule: 10 | - cron: "0 0 * * 0" # on Sundays 11 | workflow_dispatch: {} 12 | 13 | defaults: 14 | run: 15 | shell: bash 16 | 17 | jobs: 18 | releasing-nightly: 19 | runs-on: ubuntu-22.04 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.10" 25 | 26 | - name: Install dependencies 27 | run: python -m pip install --user --upgrade setuptools wheel packaging 28 | - name: Build package 29 | env: 30 | CONVERT_VERSION2NIGHTLY: "1" 31 | run: python setup.py sdist bdist_wheel 32 | 33 | # We do this, since failures on test.pypi aren't that bad 34 | - name: Publish to Test PyPI 35 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' 36 | uses: pypa/gh-action-pypi-publish@v1.12.4 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.test_pypi_password }} 40 | repository_url: https://test.pypi.org/legacy/ 41 | 42 | - name: Publish distribution 📦 to PyPI 43 | if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' 44 | uses: pypa/gh-action-pypi-publish@v1.12.4 45 | with: 46 | user: __token__ 47 | password: ${{ secrets.pypi_password }} 48 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the main branch 5 | push: 6 | branches: [main] 7 | release: 8 | types: [published] 9 | 10 | # based on https://github.com/pypa/gh-action-pypi-publish 11 | 12 | jobs: 13 | releasing-pypi: 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install dependencies 22 | run: python -m pip install --user --upgrade setuptools wheel packaging 23 | - name: Build 24 | run: python setup.py sdist bdist_wheel 25 | 26 | # We do this, since failures on test.pypi aren't that bad 27 | - name: Publish to Test PyPI 28 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 29 | uses: pypa/gh-action-pypi-publish@v1.12.4 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.test_pypi_password }} 33 | repository_url: https://test.pypi.org/legacy/ 34 | 35 | - name: Publish distribution 📦 to PyPI 36 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 37 | uses: pypa/gh-action-pypi-publish@v1.12.4 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.pypi_password }} 41 | 
-------------------------------------------------------------------------------- /.github/workflows/stale-branches.yaml: -------------------------------------------------------------------------------- 1 | name: Delete abandoned branches 2 | 3 | on: 4 | pull_request: 5 | branches: ["main"] 6 | paths: 7 | - ".github/workflows/stale-branches.yaml" 8 | # Run daily at midnight 9 | schedule: 10 | - cron: "0 0 * * *" 11 | # Allow workflow to be manually run from the GitHub UI 12 | workflow_dispatch: 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | cleanup-old-branches: 20 | runs-on: ubuntu-latest 21 | name: Satisfy my repo CDO 22 | env: 23 | DRY_RUN: no 24 | steps: 25 | - name: Set dry run 26 | if: ${{ github.event_name == 'pull_request' }} 27 | run: echo "DRY_RUN=yes" >> $GITHUB_ENV 28 | - name: Delete those pesky dead branches 29 | uses: phpdocker-io/github-actions-delete-abandoned-branches@v2 30 | id: delete_stuff 31 | with: 32 | github_token: ${{ github.token }} 33 | last_commit_age_days: 90 34 | ignore_branches: main 35 | # Disable dry run and actually get stuff deleted 36 | # For a PR, always perform dry run 37 | dry_run: ${{ env.DRY_RUN }} 38 | 39 | - name: Get output 40 | run: "echo 'Deleted branches: ${{ steps.delete_stuff.outputs.deleted_branches }}'" 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # Python venv 30 | bin/ 31 | lib64 32 | pyvenv.cfg 33 | share/ 34 | etc/ 35 | include/ 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | docs/source/api/ 68 | docs/source/*.md 69 | docs/source/notebooks/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # VSCode 78 | .vscode/ 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 88 | __pypackages__/ 89 | 90 | # Celery stuff 91 | celerybeat-schedule 92 | celerybeat.pid 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # PyCharm 125 | .idea/ 126 | 127 | # Lightning logs 128 | lightning_logs 129 | *.gz 130 | .DS_Store 131 | .*_submit.py 132 | 133 | # Editor temporary file 134 | *.swn 135 | *.swo 136 | *.swp 137 | *.swm 138 | *~ 139 | 140 | # Build artifacts 141 | scripts/build 142 | 143 | # Profiler traces 144 | benchmarks/traces 145 | 146 | # benchmark results 147 | .results 148 | 149 | .ruff_cache/ 150 | 151 | # come CI artifacts 152 | notebooks/all.txt 153 | notebooks/ci.txt 154 | 155 | quickstart_report.json 156 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | # submodules: true 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v5.0.0 12 | hooks: 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | exclude: README.md 16 | - id: check-case-conflict 17 | - id: check-yaml 18 | - id: check-toml 19 | - id: check-json 20 | - id: check-added-large-files 21 | args: ["--maxkb=400", "--enforce-all"] 22 | exclude: notebooks 23 | - id: check-docstring-first 24 | - id: detect-private-key 25 | 26 | - repo: https://github.com/asottile/pyupgrade 27 | rev: v3.20.0 28 | hooks: 29 | - id: pyupgrade 30 | args: ["--py310-plus"] 31 | name: Upgrade code 32 | exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py" 33 | 34 | - repo: https://github.com/codespell-project/codespell 35 | rev: v2.4.1 36 | hooks: 37 | - id: codespell 38 | additional_dependencies: [tomli] 39 | #args: ["--write-changes"] # uncomment if you want to get automatic fixing 40 | 41 | - repo: https://github.com/psf/black 42 | rev: 25.1.0 43 | hooks: 44 | - id: black 45 | name: Black code 46 | exclude: "examples" 47 | 48 | - repo: https://github.com/executablebooks/mdformat 49 | rev: 0.7.22 50 | hooks: 51 | - id: mdformat 52 | additional_dependencies: 53 | - mdformat-gfm 54 | - mdformat-black 55 | - mdformat_frontmatter 56 | exclude: "examples" 57 | 58 | - repo: https://github.com/sphinx-contrib/sphinx-lint 59 | rev: v1.0.0 60 | hooks: 61 | - id: sphinx-lint 62 | 63 | - repo: https://github.com/asottile/yesqa 64 | rev: v1.5.0 65 | hooks: 66 | - id: yesqa 67 | 68 | # - repo: https://github.com/charliermarsh/ruff-pre-commit 69 | # rev: v0.0.270 70 | # hooks: 71 | # - id: ruff 72 | # args: ["--fix"] 73 | 74 | - repo: https://github.com/pre-commit/mirrors-prettier 75 | rev: v3.1.0 76 | hooks: 77 | - id: prettier 78 | # https://prettier.io/docs/en/options.html#print-width 79 | files: \.(json|yml|yaml|toml) 80 | args: ["--print-width=120"] 81 | -------------------------------------------------------------------------------- /MANIFEST.in: 
-------------------------------------------------------------------------------- 1 | recursive-exclude __pycache__ *.py[cod] *.orig 2 | 3 | # Include the README and CHANGELOG 4 | include *.md 5 | recursive-include pl_sandbox *.md 6 | 7 | # Include the license file 8 | include LICENSE 9 | 10 | # Exclude build configs 11 | exclude *.toml 12 | exclude *.svg 13 | exclude *.yml 14 | exclude *.yaml 15 | 16 | # exclude tests from package 17 | recursive-exclude thunder/tests * 18 | recursive-exclude site * 19 | exclude thunder/tests 20 | 21 | # Exclude the documentation files 22 | recursive-exclude docs * 23 | exclude docs 24 | 25 | # Include the Requirements 26 | recursive-include requirements *.txt 27 | 28 | # Exclude Makefile 29 | exclude Makefile 30 | 31 | prune .git 32 | prune .github 33 | prune notebook* 34 | prune temp* 35 | prune test* 36 | prune benchmark* 37 | prune examples* 38 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean docs 2 | 3 | # assume you have installed the needed packages 4 | export SPHINX_MOCK_REQUIREMENTS=0 5 | 6 | test: clean 7 | pip install -q -r requirements.txt -r requirements/test.txt 8 | 9 | # use this to run tests 10 | python -m coverage run --source thunder -m pytest thunder tests -v 11 | python -m coverage report 12 | 13 | get-sphinx-theme: 14 | pip install -q awscli 15 | mkdir -p dist/ 16 | aws s3 sync --no-sign-request s3://sphinx-packages/ dist/ 17 | pip install lai-sphinx-theme -f dist/ 18 | 19 | docs: clean get-sphinx-theme 20 | pip install -e . --quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html 21 | cd docs ; python -m sphinx -b html -W --keep-going source build 22 | 23 | clean: 24 | # clean all temp runs 25 | rm -rf .mypy_cache 26 | rm -rf .pytest_cache 27 | rm -rf ./docs/build 28 | rm -rf ./docs/source/**/generated 29 | rm -rf ./docs/source/api 30 | rm -rf _ckpt_* 31 | -------------------------------------------------------------------------------- /dockers/README.md: -------------------------------------------------------------------------------- 1 | # Docker images 2 | 3 | ## Build images from Dockerfiles 4 | 5 | You can build the images on your own, but note that it takes a lot of time, so be prepared. 6 | 7 | ```bash 8 | # build with specific arguments 9 | docker image build -t lightning:ubuntu-cuda-py3.10-cuda12.1.1 -f dockers/ubuntu-cuda/Dockerfile --build-arg "CUDA_VERSION=12.1.1" . 10 | ``` 11 | 12 | To run your docker image, use 13 | 14 | ```bash 15 | docker image list 16 | docker run --rm -it pytorch-lightning:ubuntu-cuda-py3.10-cuda11.7.0 bash 17 | ``` 18 | 19 | ## Run docker image with GPUs 20 | 21 | To run the docker image with access to your GPUs, you need to install the NVIDIA Container Toolkit: 22 | 23 | ```bash 24 | # Add the package repositories 25 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 26 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - 27 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 28 | 29 | sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit 30 | sudo systemctl restart docker 31 | ``` 32 | 33 | and later run the docker image with `--gpus=all`. 
For example, 34 | 35 | ```bash 36 | docker run --rm -it --gpus=all pytorchlightning/lightning:ubuntu-cuda-py3.10-cuda12.1.0 37 | ``` 38 | -------------------------------------------------------------------------------- /dockers/with-apex/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ARG BASE_IMAGE_TAG 16 | 17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG} 18 | 19 | ARG APEX_CHECKOUT="master" 20 | 21 | SHELL ["/bin/bash", "-c"] 22 | 23 | RUN \ 24 | # building Apex from source 25 | pip install "pip>=23.1" packaging && \ 26 | git clone https://github.com/NVIDIA/apex && \ 27 | cd apex && \ 28 | git checkout ${APEX_CHECKOUT} && \ 29 | # https://github.com/NVIDIA/apex#linux 30 | pip install -v \ 31 | --disable-pip-version-check \ 32 | --no-cache-dir \ 33 | --no-build-isolation \ 34 | --config-settings "--build-option=--xentropy" \ 35 | . && \ 36 | cd .. && \ 37 | rm -rf apex 38 | 39 | RUN \ 40 | # Show what we have 41 | pip --version && \ 42 | pip list && \ 43 | python -c "import apex" 44 | -------------------------------------------------------------------------------- /dockers/with-dev/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ARG BASE_IMAGE_TAG 16 | 17 | FROM pytorchlightning/lightning-thunder:${BASE_IMAGE_TAG} 18 | 19 | SHELL ["/bin/bash", "-c"] 20 | 21 | COPY requirements/ requirements/ 22 | 23 | RUN \ 24 | ls -lh requirements/ && \ 25 | CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ 26 | MAKEFLAGS="-j$(( $(nproc) / 4 ))" && \ 27 | pip install -U -r requirements/test.txt && \ 28 | rm -rf requirements/ 29 | 30 | RUN \ 31 | # Show what we have 32 | pip --version && \ 33 | pip list 34 | -------------------------------------------------------------------------------- /docs/.build_docs.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | make html --debug --jobs $(nproc) 3 | -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx 10 | sphinx: 11 | fail_on_warning: true 12 | configuration: docs/conf.py 13 | 14 | build: 15 | os: "ubuntu-22.04" 16 | tools: 17 | python: "3.10" 18 | commands: 19 | - printenv 20 | - pwd ; pip install -q py-tree ; py-tree . 21 | - make docs 22 | - mkdir -p _readthedocs ; mv docs/build _readthedocs/html 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = python $(shell which sphinx-build) 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ 2 | $(document).ready(function() { 3 | /* Add a [>>>] button on the top-right corner of code samples to hide 4 | * the >>> and ... prompts and the output and thus make the code 5 | * copyable. */ 6 | var div = $('.highlight-python .highlight,' + 7 | '.highlight-python3 .highlight,' + 8 | '.highlight-pycon .highlight,' + 9 | '.highlight-default .highlight'); 10 | var pre = div.find('pre'); 11 | 12 | // get the styles from the current theme 13 | pre.parent().parent().css('position', 'relative'); 14 | var hide_text = 'Hide the prompts and output'; 15 | var show_text = 'Show the prompts and output'; 16 | var border_width = pre.css('border-top-width'); 17 | var border_style = pre.css('border-top-style'); 18 | var border_color = pre.css('border-top-color'); 19 | var button_styles = { 20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 21 | 'border-color': border_color, 'border-style': border_style, 22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 24 | 'border-radius': '0 3px 0 0' 25 | } 26 | 27 | // create and add the button to all the code blocks that contain >>> 28 | div.each(function(index) { 29 | var jthis = $(this); 30 | if (jthis.find('.gp').length > 0) { 31 | var button = $('>>>'); 32 | button.css(button_styles) 33 | button.attr('title', hide_text); 34 | button.data('hidden', 'false'); 35 | jthis.prepend(button); 36 | } 37 | // tracebacks (.gt) contain bare text elements that need to be 38 | // wrapped in a span to work with .nextUntil() (see later) 39 | jthis.find('pre:has(.gt)').contents().filter(function() { 40 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 41 | }).wrap(''); 42 | }); 43 | 44 | // define the behavior of the button when it's clicked 45 | $('.copybutton').click(function(e){ 46 | e.preventDefault(); 47 | var button = $(this); 48 | if (button.data('hidden') === 'false') { 49 | // hide the code output 50 | button.parent().find('.go, .gp, .gt').hide(); 51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 52 | button.css('text-decoration', 'line-through'); 53 | button.attr('title', show_text); 54 | button.data('hidden', 'true'); 55 | } else { 56 | // show the code output 57 | button.parent().find('.go, .gp, .gt').show(); 58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 59 | button.css('text-decoration', 'none'); 60 | button.attr('title', hide_text); 61 | button.data('hidden', 'false'); 62 | } 63 | }); 64 | }); 65 | -------------------------------------------------------------------------------- /docs/source/_static/images/LightningThunderDarkModewByline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/LightningThunderDarkModewByline.png -------------------------------------------------------------------------------- /docs/source/_static/images/LightningThunderLightModewByline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/LightningThunderLightModewByline.png -------------------------------------------------------------------------------- /docs/source/_static/images/how_it_works.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/how_it_works.png -------------------------------------------------------------------------------- /docs/source/_static/images/lightning_thunder_lightmode_nobyline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/lightning_thunder_lightmode_nobyline.png -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/normalized_training_throughput_zero2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/normalized_training_throughput_zero2.png -------------------------------------------------------------------------------- /docs/source/_static/images/pretrain_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/pretrain_perf.png -------------------------------------------------------------------------------- /docs/source/_static/images/training_throughput_single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/docs/source/_static/images/training_throughput_single.png -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/Lightning-AI/lightning-thunder', 3 | 'github_issues': 'https://github.com/Lightning-AI/lightning-thunder/issues', 4 | 'contributing': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/Lightning-AI/lightning-thunder/blob/master/governance.md', 6 | 'docs': 'https://lightning-thunder.rtfd.io/en/latest', 7 | 'twitter': 'https://twitter.com/LightningAI', 8 | 'discuss': 
'https://pytorch-lightning.slack.com', 9 | 'tutorials': 'https://lightning-thunder.readthedocs.io/en/latest/#tutorials', 10 | 'previous_pytorch_versions': 'https://lightning-thunder.rtfd.io/en/latest/', 11 | 'home': 'https://lightning-thunder.rtfd.io/en/latest/', 12 | 'get_started': 'https://lightning-thunder.readthedocs.io/en/latest/introduction_guide.html', 13 | 'features': 'https://lightning-thunder.rtfd.io/en/latest/', 14 | 'blog': 'https://www.Lightning-AI.ai/blog', 15 | 'resources': 'https://lightning-thunder.readthedocs.io/en/latest/#community-examples', 16 | 'support': 'https://lightning-thunder.rtfd.io/en/latest/', 17 | } 18 | -%} 19 | -------------------------------------------------------------------------------- /docs/source/advanced/extending.rst: -------------------------------------------------------------------------------- 1 | Extending Thunder 2 | ################# 3 | 4 | .. 5 | TODO RC1: update using the extend API 6 | 7 | This section describes how to add an executor to Thunder for a PyTorch operation. 8 | 9 | First, define a Python function with the same signature as the targeted operation, and have it call your implementation. For example, the Apex executor for ``torch.nn.functional.cross_entropy`` might define its implementation like:: 10 | 11 | import torch 12 | import xentropy_cuda 13 | 14 | def apex_xentropy( 15 | a: torch.Tensor, # a is an actual PyTorch tensor 16 | target, 17 | weight=None, 18 | size_average=None, 19 | ignore_index=-100, 20 | reduce=None, 21 | reduction="mean", 22 | label_smoothing=0.0, 23 | ): 24 | losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float) 25 | 26 | When this implementation is used it will be called with actual PyTorch tensors, and not with proxies. 27 | 28 | Next, define a “checker” function with the same signature as the targeted operation that returns True if your operation can execute the targeted operation and False otherwise. Checkers, unlike the implementations, are called with proxies, and not actual PyTorch tensors, because they're called at optimization time. The purpose of a checker function is to let executors target only specific inputs to an operation, and defer to another executor on other inputs. 
29 | 30 | A checker function for the Apex executor might look like:: 31 | 32 | from thunder.core.proxies import TensorProxy 33 | 34 | def apex_xentropy_checker( 35 | a: TensorProxy, # a is a proxy 36 | target, 37 | weight=None, 38 | size_average=None, 39 | ignore_index=-100, 40 | reduce=None, 41 | reduction="mean", 42 | label_smoothing=0.0, 43 | ): 44 | # Apex's xentropy only supports "sum", "mean" or "none" reductions 45 | if reduction not in ["sum", "mean", "none"]: 46 | return False 47 | 48 | return True 49 | 50 | Create a mapping from the name of the PyTorch operation to your replacement implementation's name, its checker, and its implementation:: 51 | 52 | _op_to_xentropy = { 53 | "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy), 54 | } 55 | 56 | Then define a registration function that practitioners can call to access your executor:: 57 | 58 | def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None: 59 | from thunder.executors import add_operator_executor 60 | 61 | return add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors) 62 | 63 | You can test your executor by registering it, compiling a function that calls the targeted operator, and then verifying that your operation is called (by inspecting the execution trace) and producing the correct output. A good example of this is the tests for the Apex executor. 64 | -------------------------------------------------------------------------------- /docs/source/advanced/inside_thunder.rst: -------------------------------------------------------------------------------- 1 | Inside Thunder 2 | ############## 3 | 4 | This section elaborates on the design of some of Thunder's internals. 5 | 6 | Bytecode interpretation 7 | ======================= 8 | 9 | Thunder's interpreter works by: 10 | 11 | 1. Disassembling the PyTorch module or function into CPython bytecode 12 | 2. Interpreting the bytecode using an extended Python interpreter 13 | 3. Generating a sequential trace of operations on tensors and numbers 14 | 15 | Representing Operations 16 | ======================= 17 | 18 | Thunder supports a subset of PyTorch's operators (see ``thunder.torch.__init__.py`` for which operators are supported). 19 | Thunder has to define its own versions of PyTorch operators so it knows how to compile and trace them properly. This section details how Thunder represents operations. 
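A trace like the one discussed next can be produced and inspected directly. The following is a minimal sketch (illustrative only; it assumes a CUDA device is available, and the exact trace text will differ from the nanoGPT snippet below) that jits a small function with ``thunder.jit`` and prints the initial trace recorded for it with ``thunder.last_traces``::

    import torch
    import thunder

    def fn(x):
        return torch.nn.functional.softmax(x, dim=-1)

    jitted_fn = thunder.jit(fn)
    x = torch.randn(8, 12, 64, 64, device="cuda", dtype=torch.float16)
    jitted_fn(x)

    # last_traces returns the sequence of traces for the last execution;
    # the first entry is the initial trace, where decompositions appear as comments
    print(thunder.last_traces(jitted_fn)[0])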
20 | 21 | Let's start by looking at how ``torch.nn.functional.softmax`` appears in a trace of the nanoGPT (https://github.com/karpathy/nanoGPT) model:: 22 | 23 | t63 = ltorch.softmax(t52, dim=-1) # t63: "cuda:0 f16[8, 12, 64, 64]" 24 | # t53 = prims.convert_element_type(t52, dtypes.float32) # t53: "cuda:0 f32[8, 12, 64, 64]" 25 | # t55 = ltorch.amax(t53, -1, keepdim=True) # t55: "cuda:0 f32[8, 12, 64, 1]" 26 | # t54 = prims.amax(t53, (3,)) # t54: "cuda:0 f32[8, 12, 64]" 27 | # t55 = prims.broadcast_in_dim(t54, [8, 12, 64, 1], [0, 1, 2]) # t55: "cuda:0 f32[8, 12, 64, 1]" 28 | # t57 = ltorch.sub(t53, t55) # t57: "cuda:0 f32[8, 12, 64, 64]" 29 | # t56 = prims.broadcast_in_dim(t55, [8, 12, 64, 64], (0, 1, 2, 3)) # t56: "cuda:0 f32[8, 12, 64, 64]" 30 | # t57 = prims.sub(t53, t56) # t57: "cuda:0 f32[8, 12, 64, 64]" 31 | # t58 = ltorch.exp(t57) # t58: "cuda:0 f32[8, 12, 64, 64]" 32 | # t58 = prims.exp(t57) # t58: "cuda:0 f32[8, 12, 64, 64]" 33 | # t60 = ltorch.sum(t58, -1, keepdim=True) # t60: "cuda:0 f32[8, 12, 64, 1]" 34 | # t59 = prims.sum(t58, (3,)) # t59: "cuda:0 f32[8, 12, 64]" 35 | # t60 = prims.broadcast_in_dim(t59, [8, 12, 64, 1], [0, 1, 2]) # t60: "cuda:0 f32[8, 12, 64, 1]" 36 | # t62 = ltorch.true_divide(t58, t60) # t62: "cuda:0 f32[8, 12, 64, 64]" 37 | # t61 = prims.broadcast_in_dim(t60, [8, 12, 64, 64], (0, 1, 2, 3)) # t61: "cuda:0 f32[8, 12, 64, 64]" 38 | # t62 = prims.div(t58, t61) # t62: "cuda:0 f32[8, 12, 64, 64]" 39 | # t63 = prims.convert_element_type(t62, dtypes.float16) # t63: "cuda:0 f16[8, 12, 64, 64]" 40 | 41 | Instead of the original operation we see a call to the corresponding ``thunder.torch`` operation, then a comment that describes the decomposition of this operation into other ``thunder.torch`` calls (identified by ``ltorch`` in the snippet above) and ``thunder.core.prims`` calls, with the calls to the *primitives* defined in ``thunder.core.prims`` being “terminal” — they decompose into nothing. 42 | 43 | Every Thunder operation can be decomposed into one or more primitives, and these decompositions are essential to trace, transform, and optimize them. For example, every primitive operation defines a “meta function” that maps proxy inputs to proxy outputs. When tracing, some inputs, like PyTorch tensors, are replaced with proxies. We know what the proxy output of operations like ``torch.softmax`` would be by decomposing it into primitives, essentially giving it an implicit meta function. If operations weren't defined in terms of primitives, then each operation would require its own meta function, and its own rule for transforms like autograd, and executors would have to reason about every Pytorch operator. 44 | 45 | Primitives are what let Thunder define operations without dramatically increasing its complexity, but the set of primitives must also be carefully chosen. Primitives serve two purposes: 46 | 47 | - They must be as simple and few in number as possible. A small set of simple operations is easier to analyze, transform, optimize, and execute than a large set of complicated operations. 48 | - They must be expressive enough to describe and facilitate the execution of deep learning and scientific computations. 49 | 50 | Since primitives are as simple as possible, they do not broadcast or type promote, and they typically do not have default arguments. For example, in the above trace the call to ``ltorch.sub`` decomposes into a broadcast and then the primitive for subtraction because broadcast is its own primitive. 
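To make the role of meta functions concrete, the following is a minimal sketch (not Thunder's actual implementation) of what a meta function for an elementwise primitive might look like. It operates purely on proxy metadata and enforces the strictness described above::

    from thunder.core.proxies import TensorProxy

    def add_meta(a: TensorProxy, b: TensorProxy) -> TensorProxy:
        # primitives are deliberately strict: no broadcasting, no type promotion
        assert a.shape == b.shape
        assert a.dtype == b.dtype
        assert a.device == b.device
        # the output is described purely by metadata; no tensor data is touched
        return TensorProxy(like=a)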
51 | 52 | Thunder primitives are similar to the operations in JAX's ``jax.lax`` module, which is a wrapper around XLA's HLO operations. 53 | 54 | Because the prims are so simple and few, writing the decompositions of PyTorch operations directly in terms of primitives would be painstaking. Instead, Thunder has a “core language” of common deep learning operations, and Thunder's PyTorch decompositions typically call these core language operations or other PyTorch operations. Note that core language operations don't appear in traces for simplicity (they have no other use except producing decompositions and can't be executed directly). 55 | -------------------------------------------------------------------------------- /docs/source/basic/overview.rst: -------------------------------------------------------------------------------- 1 | Thunder Overview 2 | ################ 3 | 4 | This section introduces Thunder's core concepts and architecture. For more details, see :doc:`Inside thunder <../advanced/inside_thunder>`. 5 | 6 | Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focusing on tensor operations. The format we've chosen, like other deep learning compilers, is a sequence of operations called a program *trace*. 7 | 8 | This translation begins with:: 9 | 10 | jitted_model = thunder.jit(my_module) 11 | 12 | or:: 13 | 14 | jitted_fn = thunder.jit(my_function) 15 | 16 | When given a module, the call to ``thunder.jit()`` returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the :doc:`Train a MLP on MNIST ` example), and when given a function it returns a function that when called will jit compile a path through the original function given information about the inputs. 17 | 18 | When the jitted module or function is called:: 19 | 20 | jitted_model(*args, **kwargs) 21 | 22 | or:: 23 | 24 | jitted_fn(*args, **kwargs) 25 | 26 | 27 | As suggested above, Thunder begins reviewing the module's or function's Python bytecode and the input. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the input, this is actually required to produce a trace. These traces are cached, so that if inputs of the same type, shape, etc are used again, the trace can be reused. 28 | 29 | Traces are generated by running the bytecode through a custom Python interpreter, which is itself implemented in Python. This interpreter has been extended to perform instructions in a different way compared to what standard CPython does. In particular, it constructs a trace of operations performed on tensors or numbers, and keeps track of the provenance of all objects in the program, whether they originated from inside the interpreter or outside. 30 | 31 | Much like other machine learning frameworks, Traces don't typically deal directly with PyTorch tensors, but with *proxies* that only have metadata like shape, device, dtype, and whether the tensor requires grad or not. As such, during interpretation for trace generation, the execution of the program doesn't perform any computation on accelerators. Instead, it records the operators along one path of the traceable function. 
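The caching behavior described above can be observed directly. The sketch below is a minimal example (the exact counter values assume the jitted function is called exactly twice with tensors of identical metadata) that uses ``thunder.cache_misses`` and ``thunder.cache_hits`` to show that the second call reuses the cached trace instead of interpreting the function again::

    import torch
    import thunder

    def fn(x):
        return x * 2 + 1

    jitted_fn = thunder.jit(fn)

    jitted_fn(torch.randn(4, 4))   # first call: interpret, trace, optimize, compile
    jitted_fn(torch.randn(4, 4))   # same metadata: the cached execution trace is reused

    print(thunder.cache_misses(jitted_fn))  # expected: 1
    print(thunder.cache_hits(jitted_fn))    # expected: 1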
32 | 33 | If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven't yet put any time into optimizing it, and we think it consumes roughly 400x as much CPU time as CPython. However, the function only needs to be jitted once per equivalence class of inputs, and CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor's shape) and control flow conditions are not changed, we can rely on smart caching to immediately execute an optimized trace. The end result is a faster total execution time. 34 | 35 | Traces can be transformed (like for ``backward()``) and optimized (like by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an *execution trace*. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the :doc:`thunder step by step ` section. 36 | 37 | To recap, the complete translation process is: 38 | 39 | - For PyTorch modules, a Thunder-optimized module is created from the original module. 40 | - For PyTorch functions, compilation produces a compiled function. 41 | - When the module or function is called, the trace is generated, swapping some inputs with “proxies”. 42 | - The trace is transformed and optimized to produce an execution trace. 43 | - The execution trace is converted into a Python function and called. 44 | 45 | As mentioned, this translation process is often slow - it takes tens of seconds for nanoGPT's (https://github.com/karpathy/nanoGPT) largest configuration - so Thunder's performance model expects relatively few of these translations and then a lot of uses of the result. This corresponds with many training and inference patterns, where the same program is executed many times. 46 | -------------------------------------------------------------------------------- /docs/source/basic/sharp_edges.rst: -------------------------------------------------------------------------------- 1 | The sharp edges 2 | ############### 3 | 4 | This section describes features in the language that are not (yet) supported by Thunder, along with their workarounds. You might encounter these when compiling a module or function with Thunder. The examples should give you a good idea of how to change your program so that it works. 5 | 6 | Note that the fact that something is not supported today doesn't mean it won't be supported at some point in the future. Feel free to reach out to help us prioritize. 7 | 8 | Inplace operations 9 | ------------------ 10 | 11 | Inplace PyTorch operations like `t.add_(1.0)` are not supported in Thunder yet. Support for inplace operations is coming soon. 12 | 13 | 14 | Tensor subclasses 15 | ----------------- 16 | 17 | Thunder currently supports Python data types and PyTorch tensors as inputs of functions and models. 18 | 19 | Subclasses of these types, e.g. lazy tensors, nested tensors, or sparse tensors are not supported today. 20 | 21 | 22 | Tracing Python builtins, standard library operations and functions that call other languages 23 | -------------------------------------------------------------------------------------------- 24 | 25 | Calling a Python builtin, standard library operation, or a function that calls into another language is safe to trace, so long as the following rules are observed: 26 | 27 | 1. 
The function should not have side effects. For example, calling ``print()`` will execute the ``print()`` function while tracing, but since it's not a Thunder operation it will not appear in a trace, and so future cached executions will not execute the ``print()`` statement. 28 | 2. The function must not manipulate tensor data or metadata. Since the operation won't appear in a trace, these manipulations won't be repeated by Thunder, and may even cause a crash while tracing. To implement such operations, see :doc:`Adding Custom Operators <../notebooks/adding_custom_operator>` 29 | 3. The function must not produce different results across invocations. Again, since the operation won't appear in traces, Thunder cannot replicate an operation that produces different results when it's invoked, like ``random.random()`` will. 30 | 31 | .. 32 | Certain op-level behavior 33 | ------------------------- 34 | 1. Ops which have not yet been added to Thunder. Please let us know if there’s missing operator support you would like to see and we will be happy to help. 35 | 2. Data dependent control flow (e.g. ``if x.any()``). Since Thunder generates traces of programs ahead of the actual execution, control flow depending on the values of tensors as opposed to their metadata cannot be handled by Thunder. 36 | 37 | 38 | Using Thunder-optimized Modules 39 | ------------------------------- 40 | 41 | Compiling a module produces a Thunder-optimized module”. A Thunder-optimized module is less dynamic than the original module, which facilitates tracing and optimization. It has a reference to the original module, and it shares its parameters with it. 42 | 43 | While modifying the original model's parameters will reflect in the Thunder-optimized module, other changes to the original module will not. In particular: 44 | 45 | - Whether model is in ``train`` or ``eval`` mode is captured at compilation time and constant 46 | - The structure of the module is captured at compilation time, and changing the original module's structure will likely break the Thunder-optimized module 47 | - Non-parameter attributes of the module may or may not be captured at compile time and treated as constants 48 | 49 | Not all features of PyTorch modules are currently supported, either. Module hooks are not supported, and adding new module attributes in a module's ``forward()`` method is only partially supported. 50 | -------------------------------------------------------------------------------- /docs/source/fundamentals/examine.rst: -------------------------------------------------------------------------------- 1 | Using Examine 2 | ############# 3 | 4 | We recommend using Thunder's ``examine()`` before compiling a function or a module. 5 | 6 | Thunder cannot run every PyTorch module, but you can quickly find out what is missing using ``examine()``. 7 | 8 | ``examine()`` not only determines if Thunder can compile the module, but provides a helpful report to use when filing an issue requesting support. 9 | 10 | You can run examine like this:: 11 | 12 | from thunder.examine import examine 13 | 14 | model = MyModel(...) 15 | examine(model, *args, **kwargs) 16 | 17 | Where ``*args and **kwargs`` are valid inputs to the model. 
If examine determines that Thunder can run the module or function as expected, it will print:: 18 | 19 | The function appears to be working as expected 20 | 21 | When ``examine`` encounters a module or function with one or more operators it doesn't support, it will specify the operators, like this:: 22 | 23 | def foo(a): 24 | return torch.triu(a) 25 | 26 | import torch 27 | import thunder 28 | from thunder.examine import examine 29 | 30 | a = torch.full((2, 2), 1., device='cuda') 31 | examine(foo, a) 32 | 33 | Running the above will print:: 34 | 35 | Found 1 distinct operations, of which 0 (0.0%) are supported 36 | Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new 37 | _VariableFunctionsClass.triu of torch 38 | 39 | To recap, ``examine()`` lets you know if Thunder can run a module, and if it can't it will provide a report to file an issue asking for support. 40 | -------------------------------------------------------------------------------- /docs/source/fundamentals/hello_world.rst: -------------------------------------------------------------------------------- 1 | Hello World 2 | ########### 3 | 4 | Here is a simple example of how Thunder lets you compile and run PyTorch modules and functions:: 5 | 6 | import torch 7 | import thunder 8 | 9 | def foo(a, b): 10 | return a + b 11 | 12 | jitted_foo = thunder.jit(foo) 13 | 14 | a = torch.full((2, 2), 1) 15 | b = torch.full((2, 2), 3) 16 | 17 | result = jitted_foo(a, b) 18 | 19 | print(result) 20 | 21 | # prints 22 | # tensor( 23 | # [[4, 4], 24 | # [4, 4]]) 25 | 26 | The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of bigger PyTorch programs. 27 | 28 | Thunder currently understands a subset of PyTorch operations, and a subset of Python. It's adding support quickly, however. Reach out on the Thunder repo and open an issue there to easily get help if an operator is currently not supported. 29 | -------------------------------------------------------------------------------- /docs/source/fundamentals/installation.rst: -------------------------------------------------------------------------------- 1 | Install Lightning Thunder 2 | ######################### 3 | 4 | Minimal dependencies 5 | ==================== 6 | 7 | Follow these instructions to install PyTorch, nvFuser, and finally Thunder. 8 | 9 | Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.5.x):: 10 | 11 | pip install --pre nvfuser-cu121-torch25 12 | 13 | cu121 can be replaced with cu118 depending on your CUDA version. NVFuser builds typically support the latest point release of PyTorch stable versions. 14 | For torch 2.5, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation 15 | 16 | You're all set with minimal dependencies, so you can follow `Install Thunder`_. 17 | 18 | Dependencies so far don't include additional optimized kernels for execution like OpenAI Triton, cuDNN Fusion, or Apex. 19 | These are described in the following section, `Optional dependencies`_. 20 | 21 | Optional dependencies 22 | ===================== 23 | 24 | Install Apex 25 | ------------ 26 | 27 | Thunder can use NVIDIA's Apex to accelerate some PyTorch operations. 
To install the Apex executor, first clone the Apex git repository and then execute the following command in the project's root directory:: 28 | 29 | git clone https://github.com/NVIDIA/apex.git 30 | cd apex 31 | 32 | pip install -v --no-cache-dir --no-build-isolation --config-settings "--build-option=--xentropy" ./ 33 | 34 | Install cuDNN 35 | ------------- 36 | 37 | Thunder can use NVIDIA's cuDNN Python frontend bindings to accelerate some PyTorch operations. The cuDNN backend is a pure-Python package which needs to be downloaded separately:: 38 | 39 | pip install nvidia-cudnn-cu12 40 | pip install nvidia-cudnn-frontend 41 | 42 | You're all set; now follow `Install Thunder`_. 43 | 44 | Install OpenAI Triton 45 | --------------------- 46 | 47 | Thunder can easily integrate OpenAI Triton kernels. You can install Triton using:: 48 | 49 | pip install triton 50 | 51 | 52 | Install Thunder 53 | =============== 54 | 55 | You can now install Thunder:: 56 | 57 | pip install git+https://github.com/Lightning-AI/lightning-thunder.git 58 | 59 | Alternatively you can clone the Thunder repository and install locally:: 60 | 61 | git clone https://github.com/Lightning-AI/lightning-thunder.git 62 | cd lightning-thunder 63 | 64 | pip install . 65 | -------------------------------------------------------------------------------- /docs/source/intermediate/additional_executors.rst: -------------------------------------------------------------------------------- 1 | Additional executors 2 | #################### 3 | 4 | nvFuser and PyTorch are not the only executors available in Thunder today. Additional executors can be added to Thunder prior to compilation through a registration mechanism, which makes it easy to have specialized executors perform certain operations more efficiently. 5 | 6 | This section lists the executors supported by Thunder beyond nvFuser and PyTorch. 7 | 8 | Triton CrossEntropy Executor 9 | ============================ 10 | 11 | The Triton CrossEntropy executor can execute ``torch.cross_entropy()`` using an optimized kernel written in OpenAI Triton (https://github.com/openai/triton).
It can be used like in the following example:: 12 | 13 | import torch 14 | import thunder 15 | from thunder.executors.triton_crossentropy import triton_ex as triton_cross_entropy_ex 16 | 17 | def xentropy(logits, labels, weight, reduction, ignore_index): 18 | return thunder.torch.cross_entropy( 19 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index 20 | ) 21 | 22 | jitted_xentropy = thunder.jit( 23 | xentropy, 24 | executors=[triton_cross_entropy_ex,] 25 | ) 26 | 27 | device = 'cuda' 28 | dtype = torch.float32 29 | 30 | logits = torch.randn([2048, 50257], device=device, dtype=dtype) 31 | labels = torch.randint(0, 50257, [2048], device=device) 32 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False) 33 | reduction = "sum" 34 | ignore_index = labels[5].item() 35 | 36 | jitted_xentropy(logits, labels, weight, reduction, ignore_index) 37 | traces = thunder.last_traces(jitted_xentropy) 38 | print(traces[-1]) 39 | 40 | This prints:: 41 | 42 | # Constructed by Delete Last Used (took 0 milliseconds) 43 | import torch 44 | from thunder.executors.torchex import no_autocast 45 | 46 | @torch.no_grad() 47 | @no_autocast 48 | def computation(logits, labels, weight): 49 | # logits: "cuda:0 f32[2048, 50257]" 50 | # labels: "cuda:0 i64[2048]" 51 | # weight: "cuda:0 f32[50257]" 52 | t23 = triton_crossentropy(logits, labels, weight, None, 45279, None, 'sum', 0.0) # t23: "cuda:0 f32[]" 53 | del logits, labels, weight 54 | return t23 55 | 56 | As shown in the above trace, ``triton_crossentropy()`` is the one running the operation. 57 | 58 | Apex CrossEntropy Executor 59 | ========================== 60 | 61 | The Apex CrossEntropy executor can execute ``torch.cross_entropy()`` through an optimized kernel, like this:: 62 | 63 | import torch 64 | import thunder 65 | from thunder.executors.apexex import apex_ex 66 | 67 | def xentropy(logits, labels): 68 | return thunder.torch.cross_entropy( 69 | logits, labels, reduction='mean', ignore_index=-1 70 | ) 71 | 72 | jitted_xentropy = thunder.jit(xentropy, executors=[apex_ex,]) 73 | 74 | device = 'cuda' 75 | dtype = torch.float32 76 | 77 | logits = torch.randn([2048, 50257], device=device, dtype=dtype) 78 | labels = torch.randint(0, 50257, [2048], device=device) 79 | 80 | jitted_xentropy(logits, labels) 81 | traces = thunder.last_traces(jitted_xentropy) 82 | print(traces[-1]) 83 | 84 | This prints:: 85 | 86 | # Constructed by Delete Last Used (took 0 milliseconds) 87 | import torch 88 | from thunder.executors.torchex import no_autocast 89 | 90 | @torch.no_grad() 91 | @no_autocast 92 | def computation(logits, labels): 93 | # logits: "cuda:0 f32[2048, 50257]" 94 | # labels: "cuda:0 i64[2048]" 95 | (t18, _) = apex_cross_entropy(logits, labels, 'mean', 0.0) 96 | del logits, labels 97 | return t18 98 | 99 | showing that Apex is running the operation. 100 | 101 | cuDNN SDPA Executor 102 | =================== 103 | 104 | TODO RC1 105 | 106 | TransformerEngine Executor 107 | ========================== 108 | 109 | TODO RC1 110 | -------------------------------------------------------------------------------- /docs/source/intermediate/ddp.rst: -------------------------------------------------------------------------------- 1 | Distributed Data Parallel (DDP) 2 | ############################### 3 | 4 | Thunder has its own Distributed Data Parallel (DDP) transform that we recommend using, although compiled modules also work with PyTorch's DDP transform. 
5 | 6 | You can wrap a model in Thunder's ddp like this:: 7 | 8 | from thunder.distributed import ddp 9 | 10 | model = MyModel() 11 | ddp_model = ddp(model) 12 | cmodel = thunder.jit(ddp_model) 13 | 14 | Specifying which rank to broadcast from is optional. ``ddp()`` will broadcast from the lowest rank in that group if ``broadcast_from`` is not specified. 15 | 16 | Thunder's ddp is compatible with PyTorch distributed runners like ``torchrun`` (https://pytorch.org/docs/stable/elastic/run.html). 17 | 18 | When using PyTorch's DDP, call DDP on the jitted module:: 19 | 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | 22 | model = MyModel() 23 | jitted_model = thunder.jit(model) 24 | ddp_model = DDP(jitted_model) 25 | 26 | The ability of Thunder to express distributed algorithms like DDP as a simple transform on the trace is one of Thunder's strengths and is being leveraged to quickly implement more elaborate distributed strategies, like Fully Sharded Data Parallel (FSDP). 27 | -------------------------------------------------------------------------------- /docs/source/intermediate/whats_next.rst: -------------------------------------------------------------------------------- 1 | What's Next 2 | ########### 3 | 4 | Thunder is developing rapidly, and this section mentions some of what's happening. Please reach out (see Get Involved) if you're interested in one of these topics. 5 | 6 | Compiling the Training Loop 7 | =========================== 8 | 9 | Thunder currently supports compiling PyTorch modules - forward computation, loss calculation, backward computation -, but we plan to support compiling the entire training loop - forward computation, loss calculation, backward computation, and the optimizer step - for maximum performance. 10 | 11 | Dynamic Caching 12 | =============== 13 | 14 | Thunder currently supports either no caching or static caching, and static caching requires recompiling whenever a module is called with inputs with metadata different than past inputs. This can be overly strict. For example, adding two tensors with shape ``(5, 5)`` is essentially the same as adding two tensors with shape ``(10, 10)``. Dynamic caching will determine if the new metadata would result in a new trace or not, significantly reducing compilation time when training some models. 15 | 16 | Memory Layouts and Strides 17 | ========================== 18 | 19 | Thunder does not currently model any stride information on tensor proxies. In the future we will likely model some stride information, like memory layout (e.g. channels-last), to support integration with PyTorch programs that use memory layout, and to let executors use memory layout to inform kernel selection. 20 | 21 | Functional transforms: vmap and AMP 22 | =================================== 23 | 24 | Thunder already has early implementations of JAX's vmap transform and PyTorch's Automatic Mixed Precision (AMP) autocasting, and we're extending our support for these transforms so practitioners can easily apply a variety of composable transforms to PyTorch modules. 25 | -------------------------------------------------------------------------------- /docs/source/reference/clang/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.clang 2 | 3 | thunder.clang 4 | ============= 5 | 6 | Thunder Core Language 7 | 8 | 9 | .. 
autosummary:: 10 | :toctree: generated/ 11 | 12 | maybe_convert_to_dtype 13 | device_put 14 | arange 15 | convolution 16 | full 17 | full_like 18 | uniform 19 | uniform_like 20 | diagonal 21 | expand 22 | flatten 23 | movedim 24 | reshape 25 | slice_in_dim 26 | squeeze 27 | transpose 28 | stride_order 29 | take 30 | index_add 31 | take_along_axis 32 | scatter_add 33 | unsqueeze 34 | cat 35 | stack 36 | compute_broadcast_shape 37 | matrix_transpose 38 | maybe_broadcast 39 | 40 | Unary 41 | ~~~~~ 42 | 43 | .. autosummary:: 44 | :toctree: generated/ 45 | 46 | abs 47 | acos 48 | acosh 49 | asin 50 | asinh 51 | atan 52 | atanh 53 | bitwise_not 54 | ceil 55 | cos 56 | cosh 57 | erf 58 | erfc 59 | erfcinv 60 | erfinv 61 | exp 62 | exp2 63 | expm1 64 | floor 65 | isfinite 66 | lgamma 67 | log 68 | log10 69 | log1p 70 | log2 71 | ndtri 72 | neg 73 | reciprocal 74 | round 75 | rsqrt 76 | sigmoid 77 | sign 78 | signbit 79 | silu 80 | sin 81 | sinh 82 | sqrt 83 | tan 84 | tanh 85 | trunc 86 | 87 | Binary 88 | ~~~~~~ 89 | 90 | .. autosummary:: 91 | :toctree: generated/ 92 | 93 | add 94 | atan2 95 | bitwise_and 96 | bitwise_or 97 | bitwise_xor 98 | copysign 99 | eq 100 | floor_divide 101 | fmod 102 | mod 103 | ge 104 | gt 105 | logical_and 106 | le 107 | lt 108 | mul 109 | ne 110 | nextafter 111 | pow 112 | remainder 113 | sub 114 | true_divide 115 | 116 | Conditional 117 | ~~~~~~~~~~~ 118 | 119 | .. autosummary:: 120 | :toctree: generated/ 121 | 122 | where 123 | -------------------------------------------------------------------------------- /docs/source/reference/common/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.common 2 | 3 | thunder.common 4 | ============== 5 | 6 | Common functions and classes for Thunder. 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | CACHE_OPTIONS 12 | CompileData 13 | CompileStats 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/baseutils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.baseutils 2 | :noindex: 3 | 4 | Baseutils 5 | --------- 6 | 7 | .. currentmodule:: thunder.core.baseutils 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | 12 | check 13 | check_type 14 | ProxyInterface 15 | NumberProxyInterface 16 | TensorProxyInterface 17 | SymbolInterface 18 | BoundSymbolInterface 19 | -------------------------------------------------------------------------------- /docs/source/reference/core/codeutils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.codeutils 2 | 3 | Codeutils 4 | --------- 5 | 6 | .. currentmodule:: thunder.core.codeutils 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | prettyprint 12 | SigInfo 13 | get_siginfo 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/devices.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.devices 2 | 3 | Devices 4 | ------- 5 | 6 | .. currentmodule:: thunder.core.devices 7 | 8 | .. 
autosummary:: 9 | :toctree: generated/ 10 | 11 | DeviceType 12 | devicetype_string 13 | Device 14 | device_from_string 15 | to_device 16 | -------------------------------------------------------------------------------- /docs/source/reference/core/dtypes.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.dtypes 2 | 3 | Dtypes 4 | ------ 5 | 6 | ``dtype``\s such as ``float32`` are exposed to :mod:`thunder` like ``thunder.float32``. 7 | 8 | .. currentmodule:: thunder.core.dtypes 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | :nosignatures: 13 | 14 | dtype 15 | to_dtype 16 | float32 17 | float16 18 | bfloat16 19 | float64 20 | int8 21 | int16 22 | int32 23 | int64 24 | uint8 25 | bool8 26 | complex32 27 | complex64 28 | complex128 29 | -------------------------------------------------------------------------------- /docs/source/reference/core/functionalization.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.functionalization 2 | 3 | In-place Functionalization 4 | -------------------------- 5 | 6 | .. currentmodule:: thunder.core.functionalization 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | functionalize_inplace_ops 12 | -------------------------------------------------------------------------------- /docs/source/reference/core/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core 2 | 3 | thunder.core 4 | ============ 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | baseutils 10 | codeutils 11 | devices 12 | dtypes 13 | langctxs 14 | prims 15 | proxies 16 | pytree 17 | rematerialization 18 | symbol 19 | trace 20 | transforms 21 | functionalization 22 | -------------------------------------------------------------------------------- /docs/source/reference/core/langctxs.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.langctxs 2 | 3 | langctxs 4 | -------- 5 | 6 | .. currentmodule:: thunder.core.langctxs 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | set_langctx 12 | get_langctx 13 | reset_langctx 14 | langctx 15 | -------------------------------------------------------------------------------- /docs/source/reference/core/prims.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.prims 2 | 3 | .. todo Add other prims after resolving "Cannot resolve forward reference in type annotations of "thunder.core.prims.all_reduce": name 'Callable' is not defined" 4 | 5 | Prims 6 | ----- 7 | 8 | .. currentmodule:: thunder.core.prims 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | PrimIDs 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/proxies.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.proxies 2 | 3 | Proxies 4 | ------- 5 | 6 | .. currentmodule:: thunder.core.proxies 7 | 8 | .. 
autosummary:: 9 | :toctree: generated/ 10 | 11 | Variable 12 | variableify 13 | unvariableify 14 | CollectionProxy 15 | pyval 16 | pytype 17 | ComplexProxy 18 | IntegerProxy 19 | FloatProxy 20 | FutureTensorProxy 21 | TensorProxy 22 | is_proxyable 23 | proxy 24 | -------------------------------------------------------------------------------- /docs/source/reference/core/pytree.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.pytree 2 | 3 | PyTree 4 | ------ 5 | 6 | .. list-table:: pytree API map 7 | :widths: 50 50 8 | :header-rows: 1 9 | 10 | * - thunder.core.pytree 11 | - pytorch pytree 12 | * - :func:`tree_map` 13 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_map 14 | * - :func:`tree_flatten` 15 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_flatten 16 | * - :func:`tree_unflatten` 17 | - See https://optree.readthedocs.io/en/latest/ops.html#optree.tree_unflatten 18 | -------------------------------------------------------------------------------- /docs/source/reference/core/rematerialization.rst: -------------------------------------------------------------------------------- 1 | .. module:: thundre.core.rematerialization 2 | 3 | Rematerialization 4 | ----------------- 5 | 6 | .. currentmodule:: thunder.core.rematerialization 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | rematerialize 12 | -------------------------------------------------------------------------------- /docs/source/reference/core/symbol.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.symbol 2 | 3 | Symbol 4 | ------ 5 | 6 | .. currentmodule:: thunder.core.symbol 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Symbol 12 | BoundSymbol 13 | BoundSymbolRHS 14 | -------------------------------------------------------------------------------- /docs/source/reference/core/trace.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.trace 2 | 3 | Trace 4 | ----- 5 | 6 | .. currentmodule:: thunder.core.trace 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | TraceCtx 12 | from_trace 13 | set_tracectx 14 | get_tracectx 15 | reset_tracectx 16 | -------------------------------------------------------------------------------- /docs/source/reference/core/transforms.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.core.transforms 2 | 3 | Transforms 4 | ---------- 5 | 6 | .. currentmodule:: thunder.core.transforms 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Node 12 | bsym_list_to_dag 13 | toposort_bsym_dag 14 | insert_inplace 15 | replace_inplace 16 | VISIT_TYPE 17 | visitor_transform 18 | -------------------------------------------------------------------------------- /docs/source/reference/distributed/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.distributed 2 | 3 | thunder.distributed 4 | =================== 5 | 6 | .. 
autosummary:: 7 | :toctree: generated/ 8 | 9 | ddp 10 | fsdp 11 | FSDPType 12 | FSDPBucketingStrategy 13 | set_skip_data_parallel_grad_sync 14 | reset_skip_data_parallel_grad_sync 15 | get_skip_data_parallel_grad_sync 16 | skip_data_parallel_grad_sync 17 | column_parallel 18 | row_parallel 19 | -------------------------------------------------------------------------------- /docs/source/reference/dynamo/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.dynamo 2 | 3 | thunder.dynamo 4 | ============== 5 | 6 | .. autosummary:: 7 | :toctree: 8 | 9 | ThunderCompiler 10 | -------------------------------------------------------------------------------- /docs/source/reference/examine/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.examine 2 | 3 | thunder.examine 4 | =============== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | examine 10 | -------------------------------------------------------------------------------- /docs/source/reference/executors/apexex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.apexex 2 | 3 | APEX Executor 4 | ------------- 5 | 6 | Executor of `NVIDIA/apex `_. 7 | 8 | .. currentmodule:: thunder.executors.apexex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | apex_entropy_available 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/cudnnex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.cudnnex 2 | 3 | cuDNN Executor 4 | -------------- 5 | 6 | Executor of `NVIDIA/cudnn-frontend `_. 7 | 8 | .. currentmodule:: thunder.executors.cudnnex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | cudnn_ex 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors 2 | 3 | thunder.executors 4 | ================= 5 | 6 | .. currentmodule:: thunder.executors 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | get_nvfuser_executor 12 | get_torch_executor 13 | nvfuser_available 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | 18 | apexex 19 | cudnnex 20 | nvfuserex 21 | passes 22 | pythonex 23 | torch_compile 24 | torchex 25 | triton_crossentropy 26 | utils 27 | -------------------------------------------------------------------------------- /docs/source/reference/executors/nvfuserex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.nvfuserex 2 | 3 | nvFuser Executor 4 | ---------------- 5 | 6 | An executor dispatches operators to `NVIDIA/Fuser `_. 7 | 8 | .. currentmodule:: thunder.executors.nvfuserex 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | .. fixme(crcrpar) add methods and classes 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/passes.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.passes 2 | 3 | Optimization Passes 4 | ------------------- 5 | 6 | .. currentmodule:: thunder.executors.passes 7 | 8 | .. 
autosummary:: 9 | :toctree: generated/ 10 | 11 | transform_for_execution 12 | update_fusion_call_ctx 13 | del_last_used 14 | -------------------------------------------------------------------------------- /docs/source/reference/executors/pythonex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.pythonex 2 | 3 | 4 | Python Executor 5 | --------------- 6 | 7 | .. currentmodule:: thunder.executors.pythonex 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | -------------------------------------------------------------------------------- /docs/source/reference/executors/torch_compile.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.torch_compile 2 | 3 | Torch Compile Executor 4 | ---------------------- 5 | 6 | .. currentmodule:: thunder.executors.torch_compile 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | make_compiled 12 | TorchCompileExecutor 13 | -------------------------------------------------------------------------------- /docs/source/reference/executors/torchex.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.torchex 2 | 3 | 4 | PyTorch Executor 5 | ---------------- 6 | 7 | .. currentmodule:: thunder.executors.torch_autograd 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | -------------------------------------------------------------------------------- /docs/source/reference/executors/triton_crossentropy.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.triton_crossentropy 2 | 3 | Triton Executor 4 | --------------- 5 | 6 | Executor of `openai/triton `_. 7 | 8 | .. currentmodule:: thunder.executors.triton_crossentropy 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | -------------------------------------------------------------------------------- /docs/source/reference/executors/utils.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.executors.utils 2 | 3 | Utils 4 | ----- 5 | 6 | .. currentmodule:: thunder.executors.utils 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | 11 | Region 12 | -------------------------------------------------------------------------------- /docs/source/reference/extend/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.extend 2 | 3 | thunder.extend 4 | ============== 5 | 6 | .. autosummary:: 7 | :toctree: 8 | 9 | register_executor 10 | deregister_executor 11 | get_all_executors 12 | get_default_executors 13 | get_always_executors 14 | get_executor 15 | set_default_executors 16 | set_always_executors 17 | add_default_executor 18 | add_always_executor 19 | remove_default_executor 20 | remove_always_executor 21 | 22 | Executor 23 | -------------------------------------------------------------------------------- /docs/source/reference/recipes/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.recipes 2 | 3 | thunder.recipes 4 | ================== 5 | 6 | .. autosummary:: 7 | :toctree: generated/ 8 | 9 | BaseRecipe 10 | HFTransformers 11 | -------------------------------------------------------------------------------- /docs/source/reference/thunder.rst: -------------------------------------------------------------------------------- 1 | .. 
module:: thunder 2 | 3 | thunder 4 | ======= 5 | 6 | 7 | Compiling functions and modules 8 | ------------------------------- 9 | 10 | 11 | .. autosummary:: 12 | :toctree: generated/ 13 | 14 | jit 15 | 16 | 17 | Querying information on compiled functions and modules 18 | ------------------------------------------------------ 19 | 20 | 21 | .. autosummary:: 22 | :toctree: generated/ 23 | 24 | DebugOptions 25 | compile_data 26 | compile_stats 27 | last_traces 28 | last_backward_traces 29 | last_prologue_traces 30 | cache_option 31 | cache_hits 32 | cache_misses 33 | list_transforms 34 | last_interpreted_instructions 35 | last_interpreter_log 36 | last_compile_options 37 | .. 38 | compile 39 | grad 40 | 41 | JITed Model wrapper 42 | ------------------- 43 | 44 | .. autoclass:: ThunderModule 45 | :members: no_sync 46 | :exclude-members: forward 47 | -------------------------------------------------------------------------------- /docs/source/reference/torch/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.torch 2 | 3 | thunder.torch 4 | ------------- 5 | 6 | A PyTorch dialect in thunder. 7 | 8 | .. currentmodule:: thunder.torch 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | torchsymbol 14 | size 15 | to_float 16 | 17 | Under Construction 18 | 19 | .. fixme "Cannot resolve forward reference in type annotations of "thunder.torch.abs": name 'Callable' is not defined" 20 | 21 | Operators 22 | ~~~~~~~~~ 23 | 24 | Unary 25 | ~~~~~ 26 | 27 | Binary 28 | ~~~~~~ 29 | 30 | Conditional 31 | ~~~~~~~~~~~ 32 | 33 | Tensor Creation 34 | ~~~~~~~~~~~~~~~ 35 | 36 | Shape Operation 37 | ~~~~~~~~~~~~~~~ 38 | -------------------------------------------------------------------------------- /docs/source/reference/transforms/index.rst: -------------------------------------------------------------------------------- 1 | .. module:: thunder.transforms 2 | 3 | thunder.transforms 4 | ================== 5 | 6 | .. 
autosummary:: 7 | :toctree: generated/ 8 | 9 | MaterializationTransform 10 | ConstantFolding 11 | -------------------------------------------------------------------------------- /examples/quickstart/hf_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | # model_name = "bert-large-uncased" 11 | model_name = "bert-base-uncased" 12 | 13 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 14 | 15 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 16 | 17 | with torch.device(device): 18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 19 | model.requires_grad_(False) 20 | model.eval() 21 | # apparently, Transformers 4.51.3 does not instantiate models on the default device 22 | model.to(device) 23 | 24 | inp = tokenizer(["Hello world!"], return_tensors="pt") 25 | 26 | print(f"Eager: {benchmark(model, **inp):.2f}ms") 27 | 28 | thunder_model = thunder.compile( 29 | model, 30 | recipe="hf-transformers", 31 | ) 32 | 33 | print(f"Thunder: {benchmark(thunder_model, **inp):.2f}ms") 34 | 35 | if torch.cuda.is_available(): 36 | thunder_model = thunder.compile(model, plugins="reduce-overhead") 37 | 38 | print(f"Thunder with 'reduce-overhead': {benchmark(thunder_model, **inp):.2f}ms") 39 | 40 | 41 | if __name__ == "__main__": 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | 44 | main() 45 | -------------------------------------------------------------------------------- /examples/quickstart/hf_llm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark_n 7 | 8 | 9 | def main(): 10 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" 11 | # model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" 12 | # model_name = "meta-llama/Llama-3.1-8B" 13 | model_name = "meta-llama/Llama-3.2-1B" 14 | # model_name = "Qwen/Qwen2.5-7B-Instruct" 15 | # model_name = "microsoft/phi-4" 16 | 17 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 18 | 19 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 20 | 21 | with torch.device(device): 22 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) 23 | model.requires_grad_(False) 24 | model.eval() 25 | # apparently, Transformers 4.51.3 does not instantiate models on the default device 26 | model.to(device) 27 | 28 | inp = tokenizer(["Hello world! 
Here's a long story"], return_tensors="pt") 29 | 30 | def generate(model, inp, cache=None): 31 | out = model.generate(**inp, do_sample=False, cache_implementation=cache, max_new_tokens=100) 32 | print(tokenizer.decode(out[0].tolist())) 33 | 34 | print("\nGenerating with PyTorch eager:") 35 | eager_time = benchmark_n(2, generate, model, inp) 36 | 37 | thunder_model = thunder.compile( 38 | model, 39 | recipe="hf-transformers", 40 | ) 41 | 42 | print("\nGenerating with Thunder:") 43 | thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static") 44 | 45 | print(f"\nEager: {eager_time:.2f}ms") 46 | print(f"Thunder: {thunder_time:.2f}ms") 47 | 48 | 49 | if __name__ == "__main__": 50 | torch.backends.cuda.matmul.allow_tf32 = True 51 | 52 | main() 53 | -------------------------------------------------------------------------------- /examples/quickstart/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import thunder 5 | 6 | 7 | def main(): 8 | model = nn.Sequential( 9 | nn.Linear(2048, 4096), 10 | nn.ReLU(), 11 | nn.Linear(4096, 64) 12 | ) 13 | 14 | thunder_model = thunder.compile(model) 15 | x = torch.randn(64, 2048) 16 | y = thunder_model(x) 17 | 18 | print(thunder_model) 19 | 20 | print(thunder.last_traces(thunder_model)[-1]) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /examples/quickstart/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | accelerate 3 | nvfuser-cu128-torch27 4 | -------------------------------------------------------------------------------- /examples/quickstart/vit_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ViTForImageClassification 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 11 | 12 | with torch.device(device): 13 | model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float32) 14 | model.requires_grad_(False) 15 | model.eval() 16 | # apparently, Transformers 4.51.3 does not instantiate models on the default device 17 | model.to(device) 18 | 19 | inp = torch.randn(128, 3, 224, 224) 20 | 21 | out = model(inp) 22 | 23 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None) 24 | 25 | thunder_out = thunder_model(inp) 26 | 27 | torch.testing.assert_close(out, thunder_out, atol=1e-2, rtol=1e-2) 28 | 29 | print(f"Eager: {benchmark(model, inp):.2f}ms") 30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms") 31 | 32 | 33 | if __name__ == "__main__": 34 | torch.set_float32_matmul_precision('high') 35 | 36 | main() 37 | -------------------------------------------------------------------------------- /examples/quickstart/vit_tv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | 4 | import thunder 5 | 6 | from thunder.dev_utils.benchmark import benchmark 7 | 8 | 9 | def main(): 10 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 11 | 12 | with torch.device(device): 13 | model = models.vit_b_16() 14 | model.requires_grad_(False) 15 | model.eval() 16 | 17 | inp = torch.randn(128, 3, 
224, 224) 18 | 19 | out = model(inp) 20 | 21 | thunder_model = thunder.compile(model, plugins="reduce-overhead" if torch.cuda.is_available() else None) 22 | 23 | thunder_out = thunder_model(inp) 24 | 25 | # print(thunder.last_traces(thunder_model)[-1]) 26 | 27 | torch.testing.assert_close(out, thunder_out) 28 | 29 | print(f"Eager: {benchmark(model, inp):.2f}ms") 30 | print(f"Thunder: {benchmark(thunder_model, inp):.2f}ms") 31 | 32 | 33 | if __name__ == "__main__": 34 | torch.set_float32_matmul_precision('high') 35 | 36 | main() 37 | -------------------------------------------------------------------------------- /notebooks/.ignore.ci: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/notebooks/.ignore.ci -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/base.txt 2 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | torch >=2.3.0 2 | looseversion ==1.3.0 3 | lightning-utilities >=0.7.0 4 | numpy 5 | networkx >= 3.3 6 | optree >=0.12.1 7 | opt_einsum >= 3.3.0 8 | # todo: temporary pin for `NameError: name '_C' is not defined` 9 | mpmath <1.4.0 10 | # Support for 3.12 11 | dill >=0.3.8 12 | -------------------------------------------------------------------------------- /requirements/devel.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | -r test.txt 3 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | sphinx ==5.3.0 2 | myst-parser ==1.0.0 3 | nbsphinx ~=0.9.7 4 | ipython[all] ~=8.37.0 5 | pandoc ==2.4 6 | docutils >=0.16 7 | sphinxcontrib-fulltoc @ git+https://github.com/t-vi/sphinx-contrib-fulltoc@master 8 | sphinxcontrib-mockautodoc 9 | 10 | sphinx-autodoc-typehints ==1.23.0 11 | sphinx-paramlinks ==0.6.0 12 | sphinx-togglebutton ==0.3.2 13 | sphinx-copybutton ==0.5.2 14 | snowballstemmer < 4 15 | 16 | # installed from S3 location and fetched in advance 17 | lai-sphinx-theme 18 | # alternative /back-up (old) theme 19 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip 20 | -------------------------------------------------------------------------------- /requirements/notebooks.txt: -------------------------------------------------------------------------------- 1 | ipython[all] ~=8.37.0 2 | numpy 3 | liger-kernel == 0.4.0 4 | cuda-python 5 | litgpt == 0.5.1 6 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | coverage ~=7.8.2 2 | pytest ==8.3.5 3 | pytest-benchmark ==5.1.0 4 | pytest-timeout ==2.4.0 5 | pytest-cov ==6.1.1 6 | pytest-xdist ==3.7.0 7 | pytest-random-order ==1.1.1 8 | pytest-timestamper ==0.0.10 9 | graphviz ==0.20.3 10 | fdm ==0.5.0 11 | expecttest ==0.3.0 # for test_ddp.py 12 | hypothesis ~=6.133.0 # for test_ddp.py 13 | numpy 14 | einops # for test_einops.py 15 | litgpt==0.4.11 # for the model definition in tests and benchmarks 16 | absl-py # 
thunder/benchmarks/test_benchmark_litgpt.py 17 | pandas # thunder/benchmarks/test_benchmark_litgpt.py 18 | xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py 19 | jsonargparse # thunder/benchmarks/benchmark_litgpt.py 20 | bitsandbytes==0.42.0 # fixed version! 21 | transformers==4.52.4 # for test_networks.py 22 | 23 | # Installs JAX on Linux and MacOS 24 | jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation 25 | jax; sys_platform == 'linux' or sys_platform == 'darwin' # for test_ops.py 26 | 27 | asvdb @ git+https://github.com/rapidsai/asvdb.git 28 | asv >=0.6.4 29 | -------------------------------------------------------------------------------- /scripts/bisect_nvfuser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | 6 | class Runner: 7 | def __init__(self, command_and_args: list[str], verbose: bool): 8 | self._command_and_args = command_and_args 9 | self._verbose = verbose 10 | 11 | def env_var_for_fuel(self) -> str: 12 | return "NVFUSER_OPTIMIZATION_FUEL" 13 | 14 | def run(self, fuel: int) -> int: 15 | os.environ[self.env_var_for_fuel()] = str(fuel) 16 | print(f"Running {self._command_and_args} with {fuel} units of fuel...") 17 | result = subprocess.run( 18 | self._command_and_args, 19 | check=False, 20 | stdout=(None if self._verbose else subprocess.DEVNULL), 21 | stderr=(None if self._verbose else subprocess.DEVNULL), 22 | ) 23 | print(f"Command returned {result.returncode} with {fuel} units of fuel.") 24 | return result.returncode 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser( 29 | description=""" 30 | A script that bisects nvFuser's optimization fuel to isolate a compiler bug. 31 | 32 | See `thunder.extend.FusionExecutor.get_fuel/set_fuel` for what optimization fuel is and how it is implemented. 33 | 34 | This script finds the minimum unit of optimization fuel between `lower_bound` and `upper_bound` that causes `command_and_args` to fail (i.e. exit with non-zero). The user will then run `NVFUSER_OPTIMIZATION_FUEL= ` as a minimal reproducer to identify the nvFusion that triggers the failure. Likely, the last nvFusion generated is the culprit. 35 | """ 36 | ) 37 | parser.add_argument("lower_bound", type=int, help="the lower bound for bisecting") 38 | parser.add_argument("upper_bound", type=int, help="the upper bound for bisecting") 39 | parser.add_argument("command_and_args", type=str, nargs=argparse.REMAINDER, help="the command and args to run") 40 | parser.add_argument("--verbose", action="store_true", help="whether to show stdout/stderr of the subprocess") 41 | args = parser.parse_args() 42 | 43 | runner = Runner(args.command_and_args, args.verbose) 44 | 45 | # The cornercases are not needed for the correctness of binary search. They are for catching errors earlier when the user specified a wrong lower/upper bound. 46 | if (exitcode := runner.run(args.lower_bound)) != 0: 47 | print("No need to bisect. Command failed with the lower bound.") 48 | exit(0) 49 | 50 | if (exitcode := runner.run(args.upper_bound)) == 0: 51 | print("Bisecting failed. Command passed with the upper bound.") 52 | exit(1) 53 | 54 | # Find the smallest fuel that fails `command_and_args`. 55 | low = args.lower_bound + 1 # +1 because we know `lower_bound` passed. 
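    # Binary-search invariant (assuming failures are monotonic in fuel): every fuel value
    # below `low` is known or assumed to pass, while `high` is known to fail, so the loop
    # below converges on the smallest failing amount of fuel.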
56 | high = args.upper_bound 57 | while low < high: 58 | mid = (low + high) // 2 59 | exitcode = runner.run(mid) 60 | if exitcode == 0: 61 | low = mid + 1 62 | else: 63 | high = mid 64 | assert low == high 65 | 66 | print(f"Bisecting succeeded. Run the following command as a minimal reproducer:") 67 | print(f" {runner.env_var_for_fuel()}={low} {' '.join(args.command_and_args)}") 68 | print(f"The last nvFusion likely triggered the failure.") 69 | -------------------------------------------------------------------------------- /scripts/validate_build.py: -------------------------------------------------------------------------------- 1 | """Check that `scripts/build_from_source.sh` has produced a correct build.""" 2 | 3 | import os 4 | 5 | BUILD_DIR = os.path.abspath(os.path.join(os.path.split(__file__)[0], "build")) 6 | 7 | 8 | def main(): 9 | import torch 10 | 11 | expected = os.path.join(BUILD_DIR, "pytorch", "torch", "__init__.py") 12 | actual = os.path.abspath(torch.__file__) 13 | assert expected == actual, f"{expected} vs. {actual}" 14 | 15 | import nvfuser 16 | from looseversion import LooseVersion 17 | 18 | assert LooseVersion(nvfuser.version()) >= LooseVersion("0.0.6"), nvfuser.version() 19 | 20 | import thunder 21 | 22 | expected = os.path.abspath(os.path.join(BUILD_DIR, "..", "..", "thunder", "__init__.py")) 23 | assert expected == thunder.__file__, f"{expected} vs. {thunder.__file__}" 24 | 25 | print("Build looks good!") 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from importlib.util import module_from_spec, spec_from_file_location 4 | 5 | from packaging.requirements import Requirement as parse_requirements 6 | from setuptools import find_packages, setup 7 | 8 | 9 | _PATH_ROOT = os.path.dirname(__file__) 10 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements") 11 | # check if os env. 
variable is set to convert version to nightly 12 | _CONVERT_VERSION = int(os.environ.get("CONVERT_VERSION2NIGHTLY", 0)) 13 | 14 | 15 | def _load_py_module(fname, pkg="thunder"): 16 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname)) 17 | py = module_from_spec(spec) 18 | spec.loader.exec_module(py) 19 | return py 20 | 21 | 22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt") -> list[str]: 23 | """Parse requirements file -> list of strings""" 24 | reqs: list[str] = [] 25 | with open(os.path.join(path_dir, file_name)) as f: 26 | for line in f: 27 | if line and not line.startswith("#"): 28 | reqs.append(line.strip()) 29 | # Filter out requirements referring to local paths or specific URLs (if any) 30 | reqs = [str(parse_requirements(r)) for r in reqs if "@" not in r and "://" not in r] 31 | return reqs 32 | 33 | 34 | def convert_version2nightly(about_file: str = "thunder/__about__.py") -> None: 35 | """Load the actual version and convert it to the nightly version.""" 36 | from datetime import datetime 37 | 38 | # load the about file 39 | with open(about_file) as fo: 40 | lines = fo.readlines() 41 | idx = None 42 | # find the line with version 43 | for i, ln in enumerate(lines): 44 | if ln.startswith("__version__"): 45 | idx = i 46 | break 47 | if idx is None: 48 | raise ValueError("The version is not found in the `__about__.py` file.") 49 | # parse the version from variable assignment 50 | version = lines[idx].split("=")[1].strip().strip('"') 51 | # parse X.Y.Z version and prune any suffix 52 | vers = re.match(r"(\d+)\.(\d+)\.(\d+).*", version) 53 | # create timestamp YYYYMMDD 54 | timestamp = datetime.now().strftime("%Y%m%d") 55 | version = f"{'.'.join(vers.groups())}.dev{timestamp}" 56 | # print the new version 57 | lines[idx] = f'__version__ = "{version}"\n' 58 | # dump updated lines 59 | with open(about_file, "w") as fo: 60 | fo.writelines(lines) 61 | 62 | 63 | def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: 64 | """Load readme as decribtion.""" 65 | path_readme = os.path.join(path_dir, "README.md") 66 | with open(path_readme, encoding="utf-8") as fp: 67 | text = fp.read() 68 | # https://github.com/Lightning-AI/lightning-thunder/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png 69 | github_source_url = os.path.join(homepage, "raw", version) 70 | # replace relative repository path to absolute link to the release 71 | # do not replace all "docs" as in the readme we replace some other sources with particular path to docs 72 | text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") 73 | return text 74 | 75 | 76 | if _CONVERT_VERSION: 77 | convert_version2nightly() 78 | 79 | about = _load_py_module("__about__.py") 80 | 81 | setup( 82 | version=about.__version__, 83 | packages=find_packages(exclude=["thunder/tests", "docs"]), 84 | long_description=_load_readme_description( 85 | path_dir=_PATH_ROOT, 86 | homepage=about.__homepage__, 87 | version=about.__version__, 88 | ), 89 | long_description_content_type="text/markdown", 90 | install_requires=_load_requirements(_PATH_REQUIRES, file_name="base.txt"), 91 | include_package_data=True, 92 | zip_safe=False, 93 | ) 94 | -------------------------------------------------------------------------------- /thunder/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.4.dev0" 2 | __author__ = "Lightning-AI et al" 3 | 
__author_email__ = "community@lightning.ai" 4 | __copyright__ = f"2024 {__author__}" 5 | __homepage__ = "https://github.com/Lightning-AI/lightning-thunder" 6 | __docs__ = "Lightning Thunder project." 7 | 8 | 9 | __all__ = [ 10 | "__author__", 11 | "__author_email__", 12 | "__copyright__", 13 | "__docs__", 14 | "__homepage__", 15 | "__version__", 16 | ] 17 | -------------------------------------------------------------------------------- /thunder/benchmarks/einsum.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from functools import partial, wraps 3 | from collections.abc import Sequence 4 | 5 | import pytest 6 | import torch 7 | import thunder 8 | 9 | from thunder.benchmarks import EinsumBenchmark 10 | from thunder.benchmarks.targets import ( 11 | make_setup, 12 | fwd_executors, 13 | fwd_executor_ids, 14 | grad_executors, 15 | grad_executors_ids, 16 | thunder_gradv1, 17 | thunder_torchcompile_gradv1, 18 | ) 19 | 20 | 21 | def _instantiate_benchmark_env( 22 | shapes: Sequence[Sequence[int]], 23 | equation: str, 24 | executor: Callable, 25 | device: str = "cuda:0", 26 | dtype: thunder.dtypes.dtype = thunder.bfloat16, 27 | requires_grad: bool = False, 28 | ) -> Callable: 29 | bench: Benchmark = EinsumBenchmark( 30 | shapes=shapes, equation=equation, device=device, dtype=dtype, requires_grad=requires_grad 31 | ) 32 | 33 | setup = make_setup(bench) 34 | fn = executor(bench) 35 | 36 | return setup, fn 37 | 38 | 39 | _test_size_map = { 40 | "small": 16, 41 | "medium": 128, 42 | "large": 512, 43 | } 44 | 45 | 46 | def single_dim_contraction_cases(size: str = "small", left_broadcasts: None | bool = None): 47 | n = _test_size_map[size] 48 | 49 | lhs_shape = [2, n, n] 50 | rhs_shape = [n, n] 51 | 52 | if left_broadcasts is not None: 53 | if left_broadcasts: 54 | lhs_shape[-1] = 1 55 | else: 56 | rhs_shape[0] = 1 57 | 58 | return (lhs_shape, rhs_shape), "bij,jk" 59 | 60 | 61 | def multidim_contraction_cases(size: str = "small", left_broadcasts: None | bool = None): 62 | n = _test_size_map[size] 63 | 64 | lhs_shape = [2, 8, n, n] 65 | rhs_shape = [2, n, 8, n] 66 | 67 | if left_broadcasts is not None: 68 | if left_broadcasts: 69 | lhs_shape[-1] = 1 70 | else: 71 | rhs_shape[-1] = 1 72 | 73 | return (lhs_shape, rhs_shape), "bijk,bklj->bli" 74 | 75 | 76 | class TestEinsumBenchmarks: 77 | @pytest.mark.parametrize( 78 | "executor,", 79 | fwd_executors, 80 | ids=fwd_executor_ids, 81 | ) 82 | @pytest.mark.parametrize( 83 | "size,", 84 | _test_size_map.keys(), 85 | ) 86 | @pytest.mark.parametrize( 87 | "sample_gen,", (single_dim_contraction_cases, multidim_contraction_cases), ids=("singledim", "multidim") 88 | ) 89 | @pytest.mark.parametrize("left_broadcasts,", (None, True, False), ids=("", "left_broadcasts", "right_broadcasts")) 90 | def test_einsum_fwd( 91 | self, benchmark, executor: None | Callable, size: str, sample_gen: Callable, left_broadcasts: None | bool 92 | ): 93 | setup, fn = _instantiate_benchmark_env( 94 | *sample_gen(size, left_broadcasts), executor=executor, requires_grad=False 95 | ) 96 | benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1) 97 | 98 | @pytest.mark.parametrize( 99 | "executor,", 100 | [ge for ge in grad_executors if ge not in (thunder_gradv1, thunder_torchcompile_gradv1)], 101 | ids=[gei for gei in grad_executors_ids if gei not in ("thunder-gradv1", "thunder+torchcompile_cat-gradv1")], 102 | ) 103 | @pytest.mark.parametrize( 104 | "size,", 105 | _test_size_map.keys(), 106 | ) 107 | 
@pytest.mark.parametrize( 108 | "sample_gen,", (single_dim_contraction_cases, multidim_contraction_cases), ids=("singledim", "multidim") 109 | ) 110 | @pytest.mark.parametrize( 111 | "left_broadcasts,", 112 | # False/right_broadcasts is disabled because of 113 | # https://github.com/NVIDIA/Fuser/issues/1590. 114 | # TODO: update once the issue is fixed. 115 | (None, True), 116 | ids=("", "left_broadcasts"), 117 | ) 118 | def test_einsum_grad( 119 | self, benchmark, executor: None | Callable, size: str, sample_gen: Callable, left_broadcasts: None | bool 120 | ): 121 | setup, fn = _instantiate_benchmark_env( 122 | *sample_gen(size, left_broadcasts), executor=executor, requires_grad=True 123 | ) 124 | benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1) 125 | -------------------------------------------------------------------------------- /thunder/clang/langctx.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from collections.abc import Callable, Sequence 3 | 4 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 5 | from thunder.core.pytree import tree_flatten 6 | from thunder.core.proxies import TensorProxy, NumberProxy 7 | 8 | # 9 | # Creates and registers the torch language context 10 | # 11 | # NOTE That this is done separately from the definition of thunder.torch operations, because the 12 | # language context must be available before those operations are defined 13 | 14 | _method_name_to_fn_map: dict[str, Callable] = {} 15 | 16 | 17 | # Creates and registers the core language language context 18 | class ClangCtx(LanguageContext): 19 | def __init__(self): 20 | super().__init__("core") 21 | 22 | def has_method(self, id: str) -> bool: 23 | return id in _method_name_to_fn_map 24 | 25 | def get_method(self, id: str, *args, **kwargs) -> Callable: 26 | # Note: concrete implmenetations should only raise AttributeError or 27 | # return None for "missing" methods as the proxies will 28 | # route __getattr__ to here and hasattr relies on __getattr__ 29 | # throwing AttributeError (only) when the attribute does 30 | # not exist. 
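        # Flatten args and kwargs to look for proxy inputs: calls with proxy inputs resolve
        # against this context's method table, while number-only calls fall through to the
        # prims language context below.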
31 | inps, _ = tree_flatten((args, kwargs)) 32 | 33 | has_proxy_input: bool = False 34 | for x in inps: 35 | if isinstance(x, TensorProxy) or isinstance(x, NumberProxy): 36 | has_proxy_input = True 37 | break 38 | 39 | if has_proxy_input: 40 | method: None | Callable = _method_name_to_fn_map.get(id, None) 41 | if method is None: 42 | raise AttributeError(f"The {self.name} language context has no method {id}") 43 | return method 44 | 45 | # has_proxy_input is False 46 | # Defers to the primitive language context when there are no tensor inputs= 47 | # (the primitive language context handles operations on numbers) 48 | primsctx: LanguageContext = resolve_language(Languages.PRIMS) 49 | if not primsctx.has_method(id): 50 | raise AttributeError( 51 | f"Attempting to call method {id} in the core language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method" 52 | ) 53 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs) 54 | return prim_method 55 | 56 | 57 | clangctx = ClangCtx() 58 | register_langctx(Languages.CLANG, clangctx) 59 | 60 | 61 | # Registers a method with the torch language context 62 | def register_method(method_name: str, method: Callable, /) -> None: 63 | _method_name_to_fn_map[method_name] = method 64 | -------------------------------------------------------------------------------- /thunder/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/core/__init__.py -------------------------------------------------------------------------------- /thunder/core/compile_data.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from contextvars import ContextVar 3 | from typing import Any 4 | 5 | from thunder.core.options import CACHE_OPTIONS 6 | 7 | # 8 | # Setting and querying the "compile_data_and_stats" context variable, which contains 9 | # a tuple of (CompileData, CompileStats) objects for the current trace. 10 | # 11 | 12 | _compile_data = ContextVar("compile_data", default=(None, None)) 13 | 14 | 15 | # 16 | # Setting, getting, and resetting the context variable 17 | # 18 | 19 | 20 | # NOTE This just acquires the compile data part of the context var's tuple 21 | def get_compile_data() -> None | Any: 22 | """Returns the current compile data.""" 23 | cd, cs = _compile_data.get() 24 | 25 | return cd 26 | 27 | 28 | def set_compile_data_and_stats(cd, cs, /): 29 | """Sets the current compile data. 30 | 31 | This is used to pass compile data to functions that are called during compilation. 
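    Returns a `contextvars` token that should later be passed to `reset_compile_data_and_stats`
    to restore the previous value.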
32 | """ 33 | token = _compile_data.set((cd, cs)) 34 | return token 35 | 36 | 37 | def reset_compile_data_and_stats(token, /): 38 | """Resets the compile data.""" 39 | _compile_data.reset(token) 40 | 41 | 42 | @contextmanager 43 | def compile_data_and_stats(cd, cs, /): 44 | """Sets the current compile data for the duration of the context.""" 45 | token = set_compile_data_and_stats(cd, cs) 46 | try: 47 | yield 48 | finally: 49 | reset_compile_data_and_stats(token) 50 | 51 | 52 | # 53 | # Query helpers 54 | # 55 | 56 | 57 | def get_compile_option(option: str, description: str, /) -> None | Any: 58 | cd, cs = _compile_data.get() 59 | 60 | if cd is None or cs is None: 61 | return None 62 | 63 | # See NOTE Different categories of compile options in thunder/__init__.py 64 | cs.last_compile_reasons[option].append(description) 65 | return cd.compile_options.get(option, None) 66 | 67 | 68 | # Whether or not the caching option uses symbolic values 69 | def get_cache_option() -> CACHE_OPTIONS: 70 | cd = get_compile_data() 71 | if cd is None: 72 | return CACHE_OPTIONS.CONSTANT_VALUES 73 | return cd.cache_option 74 | 75 | 76 | # TODO RC1 Remove the try (hack for when operating outside of this contextvar being set) 77 | def using_symbolic_values() -> bool: 78 | try: 79 | return get_cache_option() is CACHE_OPTIONS.SYMBOLIC_VALUES 80 | except: 81 | return False 82 | 83 | 84 | def using_jit() -> bool: 85 | try: 86 | cd, cs = _compile_data.get() 87 | return cd.using_jit 88 | except: 89 | return False 90 | -------------------------------------------------------------------------------- /thunder/core/profile.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import contextlib 3 | import functools 4 | import os 5 | import warnings 6 | 7 | import torch 8 | 9 | 10 | _ENABLED = os.getenv("THUNDER_ANNOTATE_TRACES") in ("1", "y", "Y") 11 | 12 | # However, nvtx is incredibly cheap so we no longer bother requiring the 13 | # environment variable. 14 | try: 15 | import nvtx 16 | except ImportError: 17 | if _ENABLED: 18 | msg = "Requested nvtx but the package is not available." 19 | msg += "\nUse `pip install -m pip install nvtx`." 20 | warnings.warn(msg) 21 | _ENABLED = False 22 | raise 23 | 24 | 25 | def profiling_enabled() -> bool: 26 | return _ENABLED 27 | 28 | 29 | @contextlib.contextmanager 30 | def add_markers(msg: str) -> None: 31 | if not profiling_enabled(): 32 | yield 33 | return 34 | 35 | assert "\n" not in msg, msg # Both NVTX and JSON forbid newlines 36 | assert '"' not in msg, msg # The PyTorch profiler does not properly escape quotations 37 | 38 | with torch.profiler.record_function(msg): 39 | torch.cuda.nvtx.range_push(msg) 40 | try: 41 | yield 42 | 43 | finally: 44 | torch.cuda.nvtx.range_pop() 45 | 46 | 47 | # The main interface to profiling something. Generally used as a decorator: 48 | # @thunder.core.profile.annotate_for_profile("foo") 49 | # def foo(...): ... 50 | # but alternatively as a `with` context: 51 | # with thunder.core.profile.annotate_for_profile("name for a block of code"): 52 | # # ... code ... 53 | annotate_for_profile: Callable[[str], None] = None 54 | 55 | 56 | if _ENABLED: 57 | annotate_for_profile = functools.partial(nvtx.annotate, domain="thunder") 58 | else: 59 | 60 | class _no_annotate(contextlib.nullcontext): 61 | """ 62 | A profiling decorator that does nothing. 
63 | """ 64 | 65 | def __init__(self, *args, **kwargs): 66 | super().__init__() 67 | 68 | def __call__(self, fqn): 69 | return fqn 70 | 71 | annotate_for_profile = _no_annotate 72 | -------------------------------------------------------------------------------- /thunder/dev_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/dev_utils/__init__.py -------------------------------------------------------------------------------- /thunder/dev_utils/benchmark.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | 6 | def benchmark_n(n, model_or_fn, /, *args, **kwargs): 7 | for _ in range(n): 8 | _ = model_or_fn(*args, **kwargs) 9 | start_event = torch.cuda.Event(enable_timing=True) 10 | end_event = torch.cuda.Event(enable_timing=True) 11 | torch.cuda.synchronize() 12 | start_event.record() 13 | for _ in range(n): 14 | _ = model_or_fn(*args, **kwargs) 15 | end_event.record() 16 | torch.cuda.synchronize() 17 | return start_event.elapsed_time(end_event) / n 18 | 19 | 20 | benchmark = partial(benchmark_n, 10) 21 | -------------------------------------------------------------------------------- /thunder/dev_utils/check_trace.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | 3 | CHECK_VERSION = 3 4 | 5 | 6 | def check_subsymbols(parent_bsym): 7 | if parent_bsym.sym.is_prim: 8 | # assert that there are no subsymbols? 9 | return 10 | known_proxies = {a.name: a for a in parent_bsym.flat_proxy_args} 11 | for bsym in parent_bsym.subsymbols: 12 | for a in bsym.flat_proxy_args: 13 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args" 14 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args" 15 | for o in bsym.flat_proxy_outs: 16 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs" 17 | known_proxies[o.name] = o 18 | check_subsymbols(bsym) 19 | for o in parent_bsym.flat_proxy_outs: 20 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {parent_bsym}" 21 | 22 | 23 | def check_trace(trace, *, version=CHECK_VERSION): 24 | """checks a trace for consistency 25 | 26 | The check is versioned for the benefit of CI and other automated testing. 27 | 28 | As a user, don't pass a version to get all implemented checks. 29 | 30 | If you add new checks and do not fix all newly detected inconsistencies, 31 | bump the CHECK_VERSION and make your tests only apply to this latest version. 32 | 33 | Please do file issues for things that fail with the latest versions so we can 34 | catch up. 35 | """ 36 | # TODO: 37 | # - args vs. unpack trivial 38 | # - args vs. 
flat_args in return 39 | known_proxies = {} 40 | for bsym in trace.bound_symbols: 41 | if (version >= 3) and bsym.sym == thunder.core.prims.unpack_sequence: 42 | coll = bsym.args[0].collection() 43 | assert len(coll) == len(bsym.output), f"unpack collection length mismatch {bsym}" 44 | for c, o in zip(coll, bsym.output): 45 | if o is None: # unused outputs 46 | continue 47 | if isinstance(c, thunder.Proxy): 48 | assert c is o, f"mismatch in unpack collection: {c} {o} {bsym}" 49 | 50 | for a in bsym.flat_proxy_args: 51 | assert a.name in known_proxies, f"unknown proxy {a.name} is used in {bsym} args" 52 | assert known_proxies[a.name] is a, f"proxy name collision {a.name} in {bsym} args" 53 | for o in bsym.flat_proxy_outs: 54 | assert known_proxies.get(o.name, o) is o, f"known proxy or proxy name collision {o.name} in {bsym} outputs" 55 | known_proxies[o.name] = o 56 | check_subsymbols(bsym) 57 | 58 | tr = thunder.core.trace.from_trace(trace) 59 | with thunder.core.trace.tracectx(tr): 60 | for bsym in trace.bound_symbols: 61 | if bsym.sym.name.startswith("unpack") or bsym.sym.name in {"pack_buffer"}: 62 | continue 63 | res = bsym.sym(*bsym.args, **bsym.kwargs) 64 | 65 | def check_shape(x, y): 66 | if isinstance(x, thunder.TensorProxy) and y is not None: # formal output can be none if unused 67 | assert x.shape == y.shape, f"shape of proxy {y.name} recomputes to {x.shape} incorrectly in {bsym}" 68 | return x 69 | 70 | thunder.core.utils.safe_map_flat(check_shape, res, bsym.output) 71 | 72 | 73 | class CheckedListOfTraces(list): 74 | def __init__(self, *args, **kwargs): 75 | super().__init__(*args, **kwargs) 76 | cd = thunder.core.compile_data.get_compile_data() 77 | if cd.debug_options.check_traces is True: 78 | self._check_version = CHECK_VERSION 79 | elif not cd.debug_options.check_traces: 80 | self._check_version = 0 81 | else: 82 | self._check_version = cd.debug_options.check_traces 83 | 84 | def append(self, trace): 85 | check_trace(trace, version=self._check_version) 86 | super().append(trace) 87 | 88 | def extend(self, traces): 89 | for tr in traces: 90 | check_trace(tr, version=self._check_version) 91 | super().extend(traces) 92 | -------------------------------------------------------------------------------- /thunder/dev_utils/nvtx_profile_transform.py: -------------------------------------------------------------------------------- 1 | from thunder.core.trace import TraceCtx as Trace, from_trace, TraceProvenance 2 | from thunder.dev_utils.utils import NON_COMPUTATION_PRIMS 3 | from thunder.extend import OperatorExecutor 4 | import time 5 | import torch 6 | import thunder 7 | 8 | 9 | class Timer: 10 | def __init__(self): 11 | self.start_time_ns = None 12 | self.end_time_ns = None 13 | 14 | def __enter__(self): 15 | self.start_time_ns = time.perf_counter_ns() 16 | return self 17 | 18 | def __exit__(self, *args): 19 | self.end_time_ns = time.perf_counter_ns() 20 | 21 | def get_elapsed_time_in_ms(self): 22 | elapsed_time_ns = self.end_time_ns - self.start_time_ns 23 | return elapsed_time_ns // int(1e6) 24 | 25 | 26 | nvtx_profiler_ex = OperatorExecutor("nvtx_profiler_ex") 27 | 28 | 29 | def nvtx_push_impl(msg): 30 | torch.cuda.nvtx.range_push(msg) 31 | 32 | 33 | def nvtx_pop_impl(): 34 | torch.cuda.nvtx.range_pop() 35 | 36 | 37 | # Symbols for profiling. 
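# `register_operator` turns each helper into a thunder Symbol whose meta function produces no
# proxies and whose implementation simply pushes or pops an NVTX range at runtime.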
38 | nvtx_push = nvtx_profiler_ex.register_operator("nvtx_range_push", meta=lambda msg: None, fn=nvtx_push_impl) 39 | nvtx_pop = nvtx_profiler_ex.register_operator("nvtx_range_pop", meta=lambda: None, fn=nvtx_pop_impl) 40 | 41 | 42 | class NvtxProfileTransform(thunder.core.transforms.Transform): 43 | def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace: 44 | with Timer() as timer: 45 | profile_trace = from_trace(trace) 46 | 47 | for bound_symbol in trace.bound_symbols: 48 | if bound_symbol.sym.id in NON_COMPUTATION_PRIMS: 49 | profile_trace.bound_symbols.append(bound_symbol) 50 | continue 51 | 52 | # Add nvtx range for the symbol. 53 | profile_trace.bound_symbols.append( 54 | nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None) 55 | ) 56 | profile_trace.bound_symbols.append(bound_symbol) 57 | profile_trace.bound_symbols.append(nvtx_pop.bind(output=None)) 58 | 59 | profile_trace.set_provenance( 60 | TraceProvenance(f"NVTX Profile Transform (took {timer.get_elapsed_time_in_ms()} milliseconds)") 61 | ) 62 | return profile_trace 63 | -------------------------------------------------------------------------------- /thunder/distributed/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.distributed 2 | 3 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType 4 | 5 | if torch.distributed.is_available(): 6 | from thunder.distributed.tensor_parallel.column_wise import column_parallel 7 | from thunder.distributed.tensor_parallel.row_wise import row_parallel 8 | else: 9 | column_parallel = None 10 | row_parallel = None 11 | 12 | 13 | __all__ = [ 14 | "TensorParallelLayerType", 15 | "column_parallel", 16 | "row_parallel", 17 | ] 18 | -------------------------------------------------------------------------------- /thunder/distributed/tensor_parallel/optimize_comm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | 4 | from thunder.core import utils 5 | from thunder.core.proxies import TensorProxy 6 | from thunder.core.proxies import variableify 7 | from thunder.core.trace import from_trace 8 | from thunder.distributed.tensor_parallel.common import TensorParallelLayerType 9 | 10 | if TYPE_CHECKING: 11 | from thunder.core.trace import TraceCtx 12 | from thunder.core.symbol import BoundSymbol 13 | from thunder.core.proxies import ProxyInterface 14 | from thunder.core.trace import VariableInterface 15 | 16 | 17 | __all__ = [ 18 | "remove_redundant_comms", 19 | ] 20 | 21 | 22 | def _get_tensor_parallel_layer_type(bsym: BoundSymbol) -> TensorParallelLayerType: 23 | value = bsym.flat_args[2] 24 | utils.check_type(value, TensorParallelLayerType) 25 | return value 26 | 27 | 28 | def remove_redundant_comms(trace: TraceCtx) -> TraceCtx: 29 | """Remove redundant pairs of column-wise linear postprocessing and row-wise linear preprocessing. 30 | 31 | Args: 32 | trace: A trace modified by both :func:`~thunder.distributed.tensor_parallel.column_parallel` 33 | and :func:`~thunder.distributed.tensor_parallel.row_parallel`. 
34 | """ 35 | from thunder.distributed import prims as dist_prims 36 | 37 | current_column_wise_parallel_linear_bsym: BoundSymbol | None = None 38 | 39 | interesting_pairs: list[tuple[BoundSymbol, BoundSymbol]] = [] 40 | bsym_to_idx: dict[BoundSymbol, int] = {} 41 | idx_to_bsym: dict[int, BoundSymbol] = {} 42 | new_bsyms: list[BoundSymbol] = [] 43 | 44 | bsym: BoundSymbol 45 | for idx, bsym in enumerate(trace.bound_symbols): 46 | bsym_to_idx[bsym] = idx 47 | idx_to_bsym[idx] = bsym 48 | new_bsyms.append(bsym) 49 | match bsym.sym.id: 50 | case dist_prims.PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT: 51 | match _get_tensor_parallel_layer_type(bsym): 52 | case TensorParallelLayerType.ROW_PARALLEL_LINEAR: 53 | if current_column_wise_parallel_linear_bsym is not None: 54 | interesting_pairs.append((current_column_wise_parallel_linear_bsym, bsym)) 55 | case _: 56 | if current_column_wise_parallel_linear_bsym is not None: 57 | current_column_wise_parallel_linear_bsym = None 58 | case dist_prims.PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT: 59 | match _get_tensor_parallel_layer_type(bsym): 60 | case TensorParallelLayerType.COLUMN_PARALLEL_LINEAR: 61 | current_column_wise_parallel_linear_bsym = bsym 62 | case _: 63 | pass 64 | 65 | new_trace = from_trace(trace) 66 | new_trace.bound_symbols = new_bsyms 67 | 68 | if interesting_pairs: 69 | indices_to_filter = [] 70 | # For column-parallel linear: postprocessed -> col-parallel linear output 71 | # For row-parallel linear: preprocessed -> col-parallel linear output 72 | swap_map: dict[VariableInterface, ProxyInterface] = {} 73 | for col_postprocess_bsym, row_preprocess_bsym in interesting_pairs: 74 | 75 | col_liinear_output: TensorProxy = col_postprocess_bsym.flat_proxy_args[0] 76 | utils.check_type(col_liinear_output, TensorProxy) 77 | 78 | row_linear_input: TensorProxy = row_preprocess_bsym.flat_proxy_outs[0] 79 | utils.check_type(row_linear_input, TensorProxy) 80 | 81 | # TODO(crcrpar): Better to make sure that between column-wise parallel linear row-wise parallel linear, 82 | # the existing bsyms are elementwise. 
83 | if col_liinear_output.shape == row_linear_input.shape: 84 | indices_to_filter.extend([bsym_to_idx[col_postprocess_bsym], bsym_to_idx[row_preprocess_bsym]]) 85 | 86 | swap_map[variableify(col_postprocess_bsym.flat_proxy_outs[0])] = col_liinear_output 87 | 88 | orig_row_linear_input: TensorProxy = row_preprocess_bsym.flat_proxy_args[0] 89 | swap_map[variableify(row_linear_input)] = orig_row_linear_input 90 | 91 | indices_to_filter = set(indices_to_filter) 92 | new_bsyms: list[BoundSymbol] = [] 93 | for idx, bsym in enumerate(trace.bound_symbols): 94 | if idx in indices_to_filter: 95 | continue 96 | new_bsyms.append(bsym.from_bsym_swap_proxies(swap_map=swap_map, skip_output=True)) 97 | new_trace.bound_symbols = new_bsyms 98 | 99 | return new_trace 100 | -------------------------------------------------------------------------------- /thunder/distributed/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if torch.distributed.is_available(): 4 | from .ddp import apply_bucketing_to_grad_allreduce 5 | from .fsdp import FSDPCommBucketing 6 | else: 7 | apply_bucketing_to_grad_allreduce = None 8 | FSDPCommBucketing = None 9 | 10 | __all__ = [ 11 | "apply_bucketing_to_grad_allreduce", 12 | "FSDPCommBucketing", 13 | ] 14 | -------------------------------------------------------------------------------- /thunder/dynamo/__init__.py: -------------------------------------------------------------------------------- 1 | from thunder.dynamo.compiler import ThunderCompiler, thunderfx, thunder_profile, thunder_optimize 2 | 3 | 4 | __all__ = ["ThunderCompiler", "thunderfx", "thunder_profile", "thunder_optimize"] 5 | -------------------------------------------------------------------------------- /thunder/dynamo/repro_script_template.py: -------------------------------------------------------------------------------- 1 | FXGRAPH_CLASS_NAME = "DynamoModule" 2 | INPUTS_NAME = "inputs" 3 | CALLABLE_NAME = "model" 4 | COMPILED_CALLABLE_NAME = "compiled_model" 5 | 6 | pytest_benchmark_multi_exe_code_template = ''' 7 | # NOTE: This script requires `pytest-benchmark==4.0.0` to be installed. 8 | # To execute the script, run `pytest {graph_name}_benchmark.py --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type` 9 | # To check the peak allocated CUDA memory, use --benchmark-json=json_file_name and look at the "max_allocated_memory_MB" field in the json file 10 | # To run tests for a specific compute_type, use the pytest `-k` option. 11 | # For example: 12 | # - `-k "forward"` will run only the forward pass. 13 | # 14 | # Available options: 15 | # - compute_type: "forward", "backward" 16 | 17 | import pytest 18 | from thunder.benchmarks.targets import parametrize_compute_type_only_training, benchmark_for_compute_type, ComputeType 19 | {torch_import_str} 20 | {import_str} 21 | 22 | # NOTE: The reproducer function has already been processed by TorchDynamo. 23 | # If we let it go through TorchDynamo again, it could be segmented further. 24 | # To avoid this, we directly use Inductor here. 
25 | # See issue https://github.com/Lightning-AI/lightning-thunder/issues/1521 26 | def torch_inductor(fn, inputs): 27 | from torch._inductor import compile as inductor_compile 28 | from torch.fx import symbolic_trace 29 | 30 | fx_graph = symbolic_trace(fn) 31 | return inductor_compile(fx_graph, inputs) 32 | 33 | {executors} 34 | {executor_names} 35 | 36 | {dynamo_module} 37 | 38 | @pytest.mark.parametrize( 39 | "executor", 40 | executors, 41 | ids=executor_names, 42 | ) 43 | {compute_type_decorator} 44 | def test_{graph_name}(benchmark, executor, compute_type): 45 | {inputs} 46 | 47 | model = DynamoModule() 48 | if executor is None: 49 | compiled_model = model 50 | elif executor == torch_inductor: 51 | compiled_model = executor(model, inputs) 52 | else: 53 | compiled_model = executor(model) 54 | {call_benchmark} 55 | 56 | """ 57 | Environment information get from `torch.utils.collect_env.get_pretty_env_info()`: 58 | {torch_env} 59 | 60 | Versions of Thunder related libraries: 61 | {thunder_pkgs} 62 | 63 | {extra_comment_str} 64 | """ 65 | ''' 66 | 67 | 68 | bsym_torch_compile_repro_template = ''' 69 | """ 70 | {extra_comment_str} 71 | """ 72 | {python_func} 73 | 74 | from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable 75 | import thunder.examine 76 | 77 | inputs = {inputs} 78 | 79 | jfn = thunder.jit({func_name}) 80 | jfn(*inputs) 81 | 82 | trc = thunder.last_traces(jfn)[-1] 83 | fusion_symbols = thunder.examine.get_fusion_symbols(trc) 84 | assert len(fusion_symbols) == 1 85 | bsym = fusion_symbols[0] 86 | 87 | # NOTE: The nvFusion function cannot be compiled directly using `torch.compile`. 88 | # It must first be processed by Thunder into BoundSymbols and compiled with `make_torch_compile_callable`. 89 | # Additionally, it's recommended to visually verify that `bsym` matches the 90 | # `nvFusion` function above by printing it using `print(bsym)`. 
91 | torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs) 92 | ''' 93 | 94 | repro_bench_code_template = f""" 95 | {{import_str}} 96 | 97 | {{dynamo_module}} 98 | def test_{{graph_name}}(): 99 | {{inputs}} 100 | 101 | model = {FXGRAPH_CLASS_NAME}() 102 | """ 103 | 104 | main_code = """ 105 | if __name__ == "__main__": 106 | test_{graph_name}() 107 | """ 108 | 109 | comment_str_template = ''' 110 | """ 111 | Environment information get from `torch.utils.collect_env.get_pretty_env_info()`: 112 | {torch_env} 113 | 114 | Versions of Thunder related libraries: 115 | {thunder_pkgs} 116 | 117 | {extra_comment_str} 118 | """ 119 | ''' 120 | -------------------------------------------------------------------------------- /thunder/executors/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, List, Tuple 2 | from collections.abc import Sequence 3 | 4 | import thunder.executors.passes as passes 5 | 6 | import thunder.extend as extend 7 | 8 | 9 | # NOTE The executors submodule depends on the extend submodule 10 | 11 | __all__ = [ 12 | "passes", 13 | "get_torch_executor", 14 | "get_nvfuser_executor", 15 | "nvfuser_available", 16 | ] 17 | 18 | 19 | def get_nvfuser_executor() -> None | extend.Executor: 20 | return extend.get_executor("nvfuser") 21 | 22 | 23 | def get_torch_executor() -> extend.Executor: 24 | return extend.get_executor("torch") 25 | 26 | 27 | def nvfuser_available() -> bool: 28 | return get_nvfuser_executor() is not None 29 | -------------------------------------------------------------------------------- /thunder/executors/apex_fused_rms_norm_impl.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | import math 3 | 4 | import torch 5 | 6 | import thunder 7 | from thunder.core.proxies import TensorProxy, AnyProxy 8 | from thunder.core.transforms import get_grad, put_grads 9 | from thunder.executors.utils import Context, set_saved_tensors 10 | from thunder.torch import TensorLike 11 | from thunder.core.compile_data import get_compile_option 12 | from thunder.executors.apexex import apex_ex 13 | 14 | 15 | APEX_FUSED_NORMS_AVAILABLE = True 16 | try: 17 | # Fused layer norm is only importable if torch.distributed is available 18 | # https://github.com/NVIDIA/apex/issues/1853 19 | from torch.distributed import is_available 20 | 21 | if not is_available(): 22 | raise ImportError 23 | import fused_layer_norm_cuda 24 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction 25 | except ImportError: 26 | APEX_FUSED_NORMS_AVAILABLE = False 27 | 28 | 29 | def apex_fused_norms_available() -> bool: 30 | return APEX_FUSED_NORMS_AVAILABLE 31 | 32 | 33 | def apex_fused_rms_norm_forward_affine_meta( 34 | input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float 35 | ): 36 | output_or_input = TensorProxy(like=input) 37 | weight = TensorProxy(like=input, shape=normalized_shape) 38 | unnormalized_dims = len(input.shape) - len(normalized_shape) 39 | invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),)) 40 | return TensorProxy(like=input), invvar 41 | 42 | 43 | def apex_fused_rms_norm_backward_affine_meta( 44 | grad_output: TensorLike, 45 | invvar: TensorLike, 46 | input_or_output: TensorLike, 47 | normalized_shape: Sequence[int], 48 | weight_, 49 | eps: float, 50 | memory_efficient: bool, 51 | ): 52 | return 
TensorProxy(like=grad_output), TensorProxy(like=weight_) 53 | 54 | 55 | # Create a new symbol and register lookaside only if import is available. 56 | if apex_fused_norms_available(): 57 | apex_fused_rms_norm_forward_affine = apex_ex.register_operator( 58 | "apex_fused_rms_norm_forward_affine", 59 | meta=apex_fused_rms_norm_forward_affine_meta, 60 | fn=fused_layer_norm_cuda.rms_forward_affine, 61 | replaces=fused_layer_norm_cuda.rms_forward_affine, 62 | ) 63 | 64 | apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator( 65 | "apex_fused_rms_norm_forward_affine_mixed_dtypes", 66 | meta=apex_fused_rms_norm_forward_affine_meta, 67 | fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes, 68 | replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes, 69 | ) 70 | 71 | apex_fused_rms_norm_backward_affine = apex_ex.register_operator( 72 | "apex_fused_rms_norm_backward_affine", 73 | meta=apex_fused_rms_norm_backward_affine_meta, 74 | fn=fused_layer_norm_cuda.rms_backward_affine, 75 | replaces=fused_layer_norm_cuda.rms_backward_affine, 76 | ) 77 | -------------------------------------------------------------------------------- /thunder/executors/apexex.py: -------------------------------------------------------------------------------- 1 | from thunder.extend import OperatorExecutor, register_executor 2 | 3 | 4 | # TODO Does apex have a version this should use? 5 | apex_ex = OperatorExecutor("apex", version="0.1") 6 | register_executor(apex_ex) 7 | 8 | import thunder.executors.apex_entropyex_impl 9 | import thunder.executors.apex_fused_rms_norm_impl as apex_fused_rms_norm_impl 10 | 11 | from thunder.executors.apex_entropyex_impl import apex_entropy_available 12 | from thunder.executors.apex_fused_rms_norm_impl import apex_fused_norms_available 13 | 14 | __all__ = ["apex_ex", "apex_entropy_available", "apex_fused_norms_available"] 15 | -------------------------------------------------------------------------------- /thunder/executors/nvfuserex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from looseversion import LooseVersion 3 | 4 | from thunder.extend import FusionExecutor 5 | 6 | __all__ = ["nvfuser_version", "required_nvfuser_version", "nvfuser_available", "nvfuserex"] 7 | 8 | 9 | # 10 | # Functions for detecting nvFuser and its version 11 | # 12 | def nvfuser_version() -> LooseVersion | None: 13 | # Short-circuits if CUDA isn't available 14 | if not torch.cuda.is_available(): 15 | return None 16 | 17 | try: 18 | import nvfuser 19 | 20 | if hasattr(nvfuser, "version"): 21 | return LooseVersion(nvfuser.version()) 22 | 23 | # NOTE: This import of nvFuser may or may not have version info 24 | return LooseVersion("0.0.0") 25 | except ImportError: 26 | pass 27 | 28 | # NOTE This occurs when nvFuser couldn't be imported 29 | return None 30 | 31 | 32 | def required_nvfuser_version() -> LooseVersion: 33 | return LooseVersion("0.2.8") 34 | 35 | 36 | def nvfuser_available() -> bool: 37 | v = nvfuser_version() 38 | if v is None: 39 | return False 40 | 41 | required = required_nvfuser_version() 42 | if v < required: 43 | import warnings 44 | 45 | msg = f"Your nvfuser installation is out of date. Thunder requires version {required}, but found version {v}." 
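        # Warning (rather than raising) lets Thunder treat an outdated nvFuser as unavailable,
        # so `nvfuserex` below is left as None and other executors can still be used.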
46 | warnings.warn(msg) 47 | return False 48 | return True 49 | 50 | 51 | nvfuserex: None | FusionExecutor = None 52 | if nvfuser_available(): 53 | import thunder.executors.nvfuserex_impl as impl 54 | 55 | nvfuserex = impl.ex 56 | -------------------------------------------------------------------------------- /thunder/executors/transformer_engine_v2ex.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from lightning_utilities.core.imports import package_available 4 | 5 | from thunder import Transform 6 | from thunder.extend import StatefulExecutor 7 | 8 | __all__ = ["transformer_engine_v2_ex", "TransformerEngineTransformV2"] 9 | 10 | transformer_engine_v2_ex: None | StatefulExecutor = None 11 | TransformerEngineTransformV2: None | Transform = None 12 | 13 | if package_available("transformer_engine"): 14 | import thunder.executors.transformer_engine_v2ex_impl as impl 15 | 16 | transformer_engine_v2_ex = impl.transformer_engine_v2_ex 17 | TransformerEngineTransformV2 = impl.TransformerEngineTransformV2 18 | 19 | else: 20 | warnings.warn("transformer_engine module not found!") 21 | -------------------------------------------------------------------------------- /thunder/executors/triton_crossentropy.py: -------------------------------------------------------------------------------- 1 | from thunder.executors import triton_utils 2 | from thunder.extend import OperatorExecutor 3 | 4 | triton_version: None | str = triton_utils.triton_version() 5 | 6 | triton_ex: None | OperatorExecutor = None 7 | if triton_version is not None: 8 | try: 9 | from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex 10 | 11 | triton_ex = impl_ex 12 | except Exception: 13 | import warnings 14 | 15 | warnings.warn("triton is present but cannot be initialized") 16 | triton_version = None 17 | -------------------------------------------------------------------------------- /thunder/executors/triton_utils.py: -------------------------------------------------------------------------------- 1 | from looseversion import LooseVersion 2 | from lightning_utilities.core.imports import package_available 3 | 4 | 5 | def triton_version() -> None | str: 6 | if not package_available("triton"): 7 | return None 8 | 9 | import triton 10 | 11 | return triton.__version__ 12 | 13 | 14 | def is_triton_version_at_least(minimum_version: str) -> bool: 15 | version = triton_version() 16 | 17 | if version is None: 18 | return False 19 | 20 | return LooseVersion(version) >= minimum_version 21 | -------------------------------------------------------------------------------- /thunder/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Any 3 | from collections.abc import Callable 4 | 5 | import numpy as np 6 | 7 | from thunder.core.langctx import langctx, Languages 8 | from thunder.numpy.langctx import register_method 9 | 10 | from thunder.core.proxies import TensorProxy 11 | from thunder.core.symbol import Symbol 12 | import thunder.clang as clang 13 | 14 | 15 | # 16 | # NumPy operator definitions 17 | # 18 | # NOTE NumPy language support is demonstrative. PRs extending it are welcome! 
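# A new operator would typically follow the same pattern as `add` below: decorate a function with
# `npsymbol`, implement it with `thunder.clang` operations, and optionally register it as a method
# in the numpy language context via `method_name`. Illustrative sketch only (assumes a matching
# `clang.mul` operation exists):
#
#     @npsymbol(method_name="mul")
#     def mul(a: Number | TensorProxy, b: Number | TensorProxy, /):
#         return clang.mul(a, b)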
19 | 20 | 21 | # Decorator that sets the language context and constructs a Symbol for each function 22 | class npsymbol: 23 | def __init__(self, *, method_name: None | str = None): 24 | self.method_name: None | str = method_name 25 | 26 | def __call__(self, fn: Callable) -> Symbol: 27 | _fn = langctx(Languages.NUMPY)(fn) 28 | # TODO: register _fn as opaque with the interpreter or do this in jit_ext? 29 | sym = Symbol(name=fn.__name__, meta=_fn) 30 | 31 | if self.method_name is not None: 32 | register_method(self.method_name, _fn) 33 | 34 | return sym 35 | 36 | 37 | # 38 | # Tensor properties 39 | # 40 | 41 | 42 | # NOTE Named `compute_len` so that it doesn't conflict with built-in `len` 43 | def compute_len(a: TensorProxy, /) -> int: 44 | return a.shape[0] 45 | 46 | 47 | register_method("len", compute_len) 48 | 49 | 50 | def size(a: TensorProxy, /) -> int: 51 | return a.numel 52 | 53 | 54 | register_method("size", size) 55 | 56 | 57 | # 58 | # Elementwise binary operators 59 | # 60 | 61 | 62 | # TODO Create a factory that adds ufunc support to elementwise operations 63 | npsymbol(method_name="add") 64 | 65 | 66 | def add(a: Number | TensorProxy, b: Number | TensorProxy, /, *, where: None | Number | TensorProxy = None): 67 | result = clang.add(a, b) 68 | if where is not None: 69 | return clang.where(where, result, a) 70 | return result 71 | -------------------------------------------------------------------------------- /thunder/numpy/langctx.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from collections.abc import Callable, Sequence 3 | 4 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 5 | from thunder.core.pytree import tree_flatten 6 | from thunder.core.proxies import TensorProxy 7 | 8 | # 9 | # Creates and registers the torch language context 10 | # 11 | # NOTE That this is done separately from the definition of thunder.torch operations, because the 12 | # language context must be available before those operations are defined 13 | 14 | _method_name_to_fn_map: dict[str, Callable] = {} 15 | 16 | 17 | # Creates and registers the core language language context 18 | class NumPyCtx(LanguageContext): 19 | def __init__(self): 20 | super().__init__("numpy") 21 | 22 | def has_method(self, id: str) -> bool: 23 | return id in _method_name_to_fn_map 24 | 25 | def get_method(self, id: str, *args, **kwargs) -> Callable: 26 | # Note: concrete implementations should only raise AttributeError or 27 | # return None for "missing" methods as the proxies will 28 | # route __getattr__ to here and hasattr relies on __getattr__ 29 | # throwing AttributeError (only) when the attribute does 30 | # not exist. 
31 | inps, _ = tree_flatten((args, kwargs)) 32 | 33 | has_tensor_input: bool = False 34 | for x in inps: 35 | if isinstance(x, TensorProxy): 36 | has_tensor_input = True 37 | break 38 | 39 | if has_tensor_input: 40 | method: None | Callable = _method_name_to_fn_map.get(id, None) 41 | if method is None: 42 | raise AttributeError(f"The {self.name} language context has no method {id}") 43 | return method 44 | 45 | # has_tensor_input is False 46 | # Defers to the primitive language context when there are no tensor inputs= 47 | # (the primitive language context handles operations on numbers) 48 | primsctx: LanguageContext = resolve_language(Languages.PRIMS) 49 | if not primsctx.has_method(id): 50 | raise AttributeError( 51 | f"Attempting to call method {id} in the numpy language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method" 52 | ) 53 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs) 54 | return prim_method 55 | 56 | 57 | numpyctx = NumPyCtx() 58 | register_langctx(Languages.NUMPY, numpyctx) 59 | register_langctx("numpy", numpyctx) 60 | 61 | 62 | # Registers a method with the torch language context 63 | def register_method(method_name: str, method: Callable, /) -> None: 64 | _method_name_to_fn_map[method_name] = method 65 | -------------------------------------------------------------------------------- /thunder/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from thunder.plugins.distributed import DDP, FSDP 2 | from thunder.plugins.quantization import QuantizeInt4 3 | from thunder.plugins.fp8 import FP8 4 | from thunder.plugins.reduce_overhead import ReduceOverhead 5 | 6 | names_to_plugins = { 7 | "ddp": DDP, 8 | "fsdp": FSDP, 9 | "quantize-int4": QuantizeInt4, 10 | "fp8": FP8, 11 | "reduce-overhead": ReduceOverhead, 12 | } 13 | 14 | 15 | def get_plugin(name): 16 | return names_to_plugins.get(name) 17 | 18 | 19 | def get_plugin_names(): 20 | return list(names_to_plugins.keys()) 21 | 22 | 23 | def register_plugin(name, cls): 24 | names_to_plugins[name] = cls 25 | -------------------------------------------------------------------------------- /thunder/plugins/fp8.py: -------------------------------------------------------------------------------- 1 | from thunder import Plugin 2 | 3 | 4 | class FP8(Plugin): 5 | def setup_executors(self): 6 | from thunder.executors.transformer_engineex import transformer_engine_ex 7 | 8 | return [transformer_engine_ex] 9 | -------------------------------------------------------------------------------- /thunder/plugins/quantization.py: -------------------------------------------------------------------------------- 1 | from thunder import Plugin 2 | 3 | 4 | class QuantizeInt4(Plugin): 5 | def setup_transforms(self): 6 | from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit 7 | 8 | return [BitsAndBytesLinearQuant4bit()] 9 | 10 | def setup_executors(self): 11 | from thunder.transforms.quantization import get_bitsandbytes_executor 12 | 13 | return [get_bitsandbytes_executor()] 14 | -------------------------------------------------------------------------------- /thunder/plugins/reduce_overhead.py: -------------------------------------------------------------------------------- 1 | from thunder import Plugin 2 | from thunder.core.recipe import Plugin, PluginPolicy 3 | from thunder.transforms.cudagraph import CUDAGraphTransform 4 | 5 | 6 | class ReduceOverhead(Plugin): 7 | policy = PluginPolicy.POST 8 | 
9 | def setup_transforms(self): 10 | return [CUDAGraphTransform()] 11 | -------------------------------------------------------------------------------- /thunder/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/py.typed -------------------------------------------------------------------------------- /thunder/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type 2 | 3 | from thunder.recipes.base import BaseRecipe 4 | from thunder.recipes.hf_transformers import HFTransformers 5 | 6 | 7 | names_to_recipes: dict[str : type[Any]] = { 8 | "default": BaseRecipe, 9 | "hf-transformers": HFTransformers, 10 | } 11 | 12 | 13 | def get_recipe_class(name: str) -> type[Any]: 14 | return names_to_recipes.get(name) 15 | 16 | 17 | def get_recipes() -> list[str]: 18 | return list(names_to_recipes.keys()) 19 | 20 | 21 | def register_recipe(name: str, cls: type[Any]): 22 | if name == "auto": 23 | raise ValueError("Recipe name 'auto' is reserved.") 24 | names_to_recipes[name] = cls 25 | -------------------------------------------------------------------------------- /thunder/tests/README.md: -------------------------------------------------------------------------------- 1 | Testing is important! 2 | 3 | Most Thunder operators are tested by automatically generated tests that use data from the operator's "OpInfo", defined 4 | in `opinfos.py`. These tests typically compare Thunder's operator behavior with another framework, using the 5 | "sample inputs" defined by its OpInfo. Each OpInfo should define a sample input generator and at least one “reference” implementation in PyTorch or JAX (and in the future, NumPy). 6 | 7 | When an operator is very similar across the three language levels (user, core, and primitive), like cos, only one variation is typically tested (the core language variant is preferred). 8 | However, there are also cases where testing a torch language operation or primitive operation directly makes sense. 9 | 10 | Operator tests are autogenerated in `test_ops.py`. To run the tests for a particular operator, use pytest’s -k option. 11 | For example, to run the tests for `cos`, for example, the command would be 12 | 13 | ```bash 14 | pytest test_ops.py -k cos -v 15 | ``` 16 | 17 | This will run tests for Thunder’s different executors, supported dtypes, and supported devicetypes. 18 | 19 | The OpInfo tests can also be run from Python, and this can be extremely useful for debugging. 20 | To do that, find the generated test name you’d like to run (often by invoking the tests from the command line and observing their names) and call it. Here’s a sample program: 21 | 22 | ```py 23 | import traceback 24 | import thunder.tests.test_ops as to 25 | 26 | e, exc_info, snippet, opinfo, devicetype, dtype, args, kwargs = to.test_core_vs_torch_consistency_cos_nvFuser_CUDA_float32() 27 | 28 | traceback.print_exception(*exc_info) 29 | ``` 30 | 31 | If the test fails, it will return information about the failure, including error information and the arguments that caused the failure. 32 | In the above sample the traceback information is printed. If the test succeeds then it will return nothing. 33 | 34 | It can be a little tricky to remember all the components a test returns, but you can 35 | always catch and print the return value to better understand what's available. 
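For instance, a minimal sketch of that pattern, reusing the generated test name from the example above (the exact names available on your machine depend on the installed executors, devices, and dtypes):

```py
import thunder.tests.test_ops as to

# A generated OpInfo test returns nothing on success and a tuple describing the failure otherwise
result = to.test_core_vs_torch_consistency_cos_nvFuser_CUDA_float32()
if result is None:
    print("test passed")
else:
    print(result)  # inspect the components (exception info, snippet, opinfo, device, dtype, args, kwargs)
```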
36 | -------------------------------------------------------------------------------- /thunder/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/tests/__init__.py -------------------------------------------------------------------------------- /thunder/tests/bf16.py: -------------------------------------------------------------------------------- 1 | import thunder.core 2 | import torch 3 | 4 | 5 | def device_supports_bf16(device: None | str | torch.device | thunder.core.devices.Device, /) -> bool: 6 | """Torch has its own `torch.cuda.is_bf16_supported()`, but the contract of 7 | this API changed in December 2023 to be "can bf16 be represented?", which 8 | is true even for devices that existed before bf16 was invented. 9 | 10 | The contract for this API is that the device implements bf16 operations.""" 11 | if not torch.cuda.is_available(): 12 | return False 13 | 14 | dev: torch.device = thunder.core.devices.to_torch_device(device) 15 | 16 | if dev.type != "cuda": 17 | return False 18 | 19 | cuda_major: int 20 | cuda_minor: int 21 | cuda_major, cuda_minor = torch.cuda.get_device_capability(dev) 22 | return cuda_major >= 8 23 | -------------------------------------------------------------------------------- /thunder/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_benchmark 3 | from thunder.dynamo.compiler_graph_benchmark import GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS 4 | 5 | 6 | @pytest.hookimpl(hookwrapper=True) 7 | def pytest_benchmark_group_stats(config, benchmarks, group_by): 8 | """ 9 | The function customize the behavior for ThunderCompilerGraphBenchmarking. 10 | The custom grouping function is only invoked when the `--benchmark-group-by` 11 | option is set to 'graph-by-graph:param:GraphID,param:SplitModuleName'. 12 | For an example, refer to the comment section in `ThunderCompilerGraphBenchmarking`. 13 | 14 | Reference: https://pytest-benchmark.readthedocs.io/en/latest/hooks.html#pytest_benchmark.hookspec.pytest_benchmark_group_stats 15 | """ 16 | prefix = "graph-by-graph:" 17 | outcome = yield 18 | if group_by.startswith(prefix): 19 | group_by = group_by[len(prefix) :] 20 | for bench in benchmarks: 21 | if bench["params"] is None: 22 | bench["params"] = {} 23 | # The benchs with the same `params`` share the same dict 24 | # We need to create a deepcopy of the original dictionary to add parameters specific to each graph. 
25 | else: 26 | bench["params"] = bench["params"].copy() 27 | if bench["param"] is None: 28 | bench["param"] = "" 29 | 30 | name = bench["name"] 31 | gid, module_name, ex = name.split("-")[-3:] 32 | # Add the "GraphID", "SplitModuleName","executor" as params in benchmark 33 | gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS 34 | bench["params"].update({gid_key: gid, module_name_key: module_name, ex_key: ex}) 35 | bench["param"] += f"-{gid}-{module_name}-{ex}" 36 | 37 | result = pytest_benchmark.plugin.pytest_benchmark_group_stats(config, benchmarks, group_by) 38 | outcome.force_result(result) 39 | 40 | 41 | def pytest_collection_modifyitems(items): 42 | items.sort(key=lambda item: item.name) 43 | -------------------------------------------------------------------------------- /thunder/tests/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/8bd802c89988e6668f5328d3927deda1b4e9e706/thunder/tests/distributed/__init__.py -------------------------------------------------------------------------------- /thunder/tests/distributed/modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import ClassVar, TYPE_CHECKING 3 | 4 | import torch.nn as nn 5 | 6 | from thunder.core import utils 7 | 8 | if TYPE_CHECKING: 9 | import torch 10 | 11 | 12 | __all__ = [ 13 | "ParallelMLP", 14 | ] 15 | 16 | 17 | class ParallelMLP(nn.Module): 18 | """Simplified version of Megatron/NeMo's ParallelMLP. 19 | 20 | Ref: https://github.com/NVIDIA/NeMo/blob/95ca2f4/nemo/collections/nlp/modules/common/megatron/mlp.py#L61 21 | """ 22 | 23 | COLUMN_WISE: ClassVar[tuple[str]] = ("dense_h_to_4h",) 24 | ROW_WISE: ClassVar[tuple[str]] = ("dense_4h_to_h",) 25 | 26 | SUPPORTED_GELU_APPROX: ClassVar[tuple[str, str]] = ("none", "tanh") 27 | 28 | def __init__( 29 | self, 30 | hidden_size: int, 31 | ffn_hidden_size: int | None = None, 32 | bias: bool = True, 33 | gelu_approximate: str = "none", 34 | ) -> None: 35 | utils.check( 36 | gelu_approximate in ParallelMLP.SUPPORTED_GELU_APPROX, 37 | lambda: f"Invalid {gelu_approximate}, supported are {ParallelMLP.SUPPORTED_GELU_APPROX}", 38 | ) 39 | if ffn_hidden_size is None: 40 | ffn_hidden_size = 4 * hidden_size 41 | 42 | super().__init__() 43 | self.dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=bias) 44 | self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=bias) 45 | self.gelu = nn.GELU(approximate=gelu_approximate) 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | four_h = self.gelu(self.dense_h_to_4h(x)) 49 | h = self.dense_4h_to_h(four_h) 50 | return h 51 | -------------------------------------------------------------------------------- /thunder/tests/hf_bart_self_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This is an abbreviated version of the HuggingFace Bart Encoder Self-Attention 16 | # module. The block was editted to remove non-essential code for 17 | # self-attention. 18 | 19 | import torch 20 | from torch import nn 21 | 22 | 23 | class BartAttention(nn.Module): 24 | """Multi-headed attention from 'Attention Is All You Need' paper""" 25 | 26 | def __init__( 27 | self, 28 | embed_dim: int, 29 | num_heads: int, 30 | dropout: float = 0.0, 31 | bias: bool = True, 32 | ): 33 | super().__init__() 34 | self.embed_dim = embed_dim 35 | self.num_heads = num_heads 36 | self.dropout = dropout 37 | self.head_dim = embed_dim // num_heads 38 | 39 | if (self.head_dim * num_heads) != self.embed_dim: 40 | raise ValueError( 41 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 42 | f" and `num_heads`: {num_heads})." 43 | ) 44 | self.scaling = self.head_dim**-0.5 45 | 46 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 47 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 48 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 49 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 50 | 51 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 52 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 53 | 54 | def forward( 55 | self, 56 | hidden_states: torch.Tensor, 57 | attention_mask: torch.Tensor, 58 | ) -> torch.Tensor: 59 | """Input shape: Batch x Time x Channel""" 60 | bsz, tgt_len, _ = hidden_states.size() 61 | 62 | # get query proj 63 | query_states = self.q_proj(hidden_states) * self.scaling 64 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 65 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 66 | 67 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 68 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 69 | key_states = key_states.view(*proj_shape) 70 | value_states = value_states.view(*proj_shape) 71 | 72 | src_len = key_states.size(1) 73 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 74 | 75 | if attention_mask is not None: 76 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 77 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 78 | 79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 80 | 81 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 82 | 83 | attn_output = torch.bmm(attn_probs, value_states) 84 | 85 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 86 | attn_output = attn_output.transpose(1, 2) 87 | 88 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 89 | # partitioned aross GPUs when using tensor-parallelism. 
90 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 91 | 92 | attn_output = self.out_proj(attn_output) 93 | 94 | return attn_output 95 | -------------------------------------------------------------------------------- /thunder/tests/litgpt_model.py: -------------------------------------------------------------------------------- 1 | """Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | configs = [ 8 | # diverse sample of configs FOR TESTING that cover all major checkpoints variants architecturally but with reduced 9 | # size 10 | dict(name="gpt-neox-like", block_size=128, n_layer=2, n_embd=64, n_head=4, padding_multiple=8), 11 | dict( 12 | name="llama1-like", 13 | block_size=128, 14 | vocab_size=320, 15 | padding_multiple=64, 16 | n_layer=2, 17 | n_head=4, 18 | n_embd=64, 19 | rotary_percentage=1.0, 20 | parallel_residual=False, 21 | bias=False, 22 | norm_class_name="RMSNorm", 23 | norm_eps=1e-6, 24 | mlp_class_name="LLaMAMLP", 25 | intermediate_size=1376, 26 | ), 27 | dict( 28 | name="long-context-like", 29 | block_size=512, 30 | vocab_size=320, 31 | padding_multiple=64, 32 | n_layer=2, 33 | n_head=4, 34 | n_embd=64, 35 | rotary_percentage=1.0, 36 | parallel_residual=False, 37 | bias=False, 38 | norm_class_name="RMSNorm", 39 | mlp_class_name="LLaMAMLP", 40 | intermediate_size=11008, 41 | rope_condense_ratio=4, 42 | ), 43 | dict( 44 | name="llama2-like", 45 | vocab_size=320, 46 | padding_multiple=64, 47 | n_layer=2, 48 | n_head=4, 49 | n_embd=64, 50 | rotary_percentage=1.0, 51 | parallel_residual=False, 52 | bias=False, 53 | norm_class_name="RMSNorm", 54 | mlp_class_name="LLaMAMLP", 55 | intermediate_size=1376, 56 | ), 57 | dict( 58 | name="falcon-7b-like", 59 | block_size=128, 60 | padded_vocab_size=254, 61 | n_layer=2, 62 | n_head=7, 63 | n_embd=448, 64 | rotary_percentage=1.0, 65 | n_query_groups=1, 66 | bias=False, 67 | shared_attention_norm=True, 68 | ), 69 | dict( 70 | name="falcon-40b-like", 71 | block_size=128, 72 | padded_vocab_size=508, 73 | n_layer=2, 74 | n_head=64, 75 | n_embd=256, 76 | rotary_percentage=1.0, 77 | n_query_groups=4, 78 | bias=False, 79 | ), 80 | dict( 81 | name="codellama2-like", 82 | block_size=1024, 83 | vocab_size=2001, 84 | padding_multiple=16, 85 | n_layer=2, 86 | n_head=4, 87 | n_embd=64, 88 | rotary_percentage=1.0, 89 | parallel_residual=False, 90 | bias=False, 91 | norm_class_name="RMSNorm", 92 | norm_eps=1e-05, 93 | mlp_class_name="LLaMAMLP", 94 | intermediate_size=1376, 95 | rope_base=1000000, 96 | ), 97 | dict( 98 | name="mixtral-like", 99 | block_size=512, 100 | padded_vocab_size=500, 101 | n_layer=2, 102 | n_head=64, 103 | n_embd=256, 104 | rotary_percentage=1.0, 105 | n_query_groups=8, 106 | parallel_residual=False, 107 | bias=False, 108 | norm_class_name="RMSNorm", 109 | norm_eps=1e-05, 110 | mlp_class_name="LLaMAMoE", 111 | intermediate_size=224, 112 | rope_base=1000000, 113 | n_expert=8, 114 | n_expert_per_token=2, 115 | ), 116 | ] 117 | 118 | name_to_config = {config["name"]: config for config in configs} 119 | 120 | 121 | import litgpt 122 | 123 | # add the testing configurations 124 | litgpt.config.name_to_config.update(name_to_config) 125 | name_to_config.update(litgpt.config.name_to_config) 126 | 127 | # manually expose for backwards compatibility 128 | Config = litgpt.Config 129 | GPT = litgpt.GPT 130 | RMSNorm = litgpt.model.RMSNorm 131 | CausalSelfAttention = litgpt.model.CausalSelfAttention 132 | LLaMAMLP = litgpt.model.LLaMAMLP 133 
| build_rope_cache = litgpt.model.build_rope_cache 134 | apply_rope = litgpt.model.apply_rope 135 | Block = litgpt.model.Block 136 | -------------------------------------------------------------------------------- /thunder/tests/make_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import thunder 3 | 4 | 5 | def make_tensor( 6 | *shape: int | torch.Size | list[int] | tuple[int, ...], 7 | dtype: torch.dtype, 8 | device: str | torch.device | thunder.devices.Device, 9 | low: float | None = None, 10 | high: float | None = None, 11 | requires_grad: bool = False, 12 | noncontiguous: bool = False, 13 | exclude_zero: bool = False, 14 | memory_format: torch.memory_format | None = None, 15 | ) -> torch.Tensor: 16 | r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with 17 | values uniformly drawn from ``[low, high)``. 18 | Calls torch.testing.make_tensor and optionally torch.Tensor.fill_ to allow for low == high, as 19 | torch.testing.make_tensor enforces low < high. 20 | 21 | Args: 22 | shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor. 23 | dtype (:class:`torch.dtype`): The data type of the returned tensor. 24 | device (Union[str, torch.device, thunder.devices.Device]): The device of the returned tensor. 25 | low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is 26 | clamped to the least representable finite value of the given dtype. When ``None`` (default), 27 | this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``. 28 | high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is 29 | clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value 30 | is determined based on the :attr:`dtype` (see the table above). Default: ``None``. 31 | requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``. 32 | noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is 33 | ignored if the constructed tensor has fewer than two elements. 34 | exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value 35 | depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating 36 | point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the 37 | :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number 38 | whose real and imaginary parts are both the smallest positive normal number representable by the complex 39 | type. Default ``False``. 40 | memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Incompatible 41 | with :attr:`noncontiguous`. 42 | Raises: 43 | ValueError: if ``requires_grad=True`` is passed for integral `dtype` 44 | ValueError: If ``low > high``. 45 | ValueError: If either :attr:`low` or :attr:`high` is ``nan``. 46 | TypeError: If :attr:`dtype` isn't supported by this function. 
47 | """ 48 | if isinstance(device, thunder.devices.Device): 49 | device = device.device_str() 50 | 51 | fill_value = None 52 | if low is not None and low == high: 53 | fill_value = low 54 | low = None 55 | high = None 56 | 57 | t = torch.testing.make_tensor( 58 | *shape, 59 | dtype=dtype, 60 | device=device, 61 | low=low, 62 | high=high, 63 | requires_grad=requires_grad, 64 | noncontiguous=noncontiguous, 65 | exclude_zero=exclude_zero, 66 | memory_format=memory_format, 67 | ) 68 | 69 | if fill_value is not None: 70 | t.fill_(fill_value) 71 | 72 | return t 73 | 74 | 75 | def make_tensor_like(a, **kwargs): 76 | # type: (torch.Tensor) -> torch.Tensor 77 | """Returns a tensor with the same properties as the given tensor. 78 | 79 | Args: 80 | a (torch.Tensor): The tensor to copy properties from. 81 | kwargs (dict): Additional properties for `make_tensor`. 82 | 83 | Returns: 84 | torch.Tensor: A tensor with the same properties as :attr:`a`. 85 | """ 86 | kwargs = kwargs | dict(device=a.device, dtype=a.dtype, requires_grad=a.requires_grad) 87 | return make_tensor(a.shape, **kwargs) 88 | -------------------------------------------------------------------------------- /thunder/tests/module_example.py: -------------------------------------------------------------------------------- 1 | def returns_two(): 2 | return 2 3 | 4 | 5 | def _returns_three(): 6 | return 3 7 | 8 | 9 | def returns_five(): 10 | return 5 11 | -------------------------------------------------------------------------------- /thunder/tests/test_apex_fused_norms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.testing import assert_close 4 | 5 | fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda") 6 | 7 | from torch.distributed import is_available 8 | from thunder.executors.apexex import apex_ex 9 | import thunder 10 | 11 | 12 | # See https://github.com/NVIDIA/apex/issues/1853 13 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available") 14 | @pytest.mark.parametrize("requires_grad", [True, False]) 15 | @pytest.mark.parametrize("memory_efficient", [True, False]) 16 | def test_apex_fused_rms_norm(requires_grad, memory_efficient): 17 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction 18 | 19 | def fn(x, weight, normalized_shape, eps): 20 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient) 21 | 22 | device = "cuda" 23 | normalized_shape = (3, 2) 24 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad) 25 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad) 26 | eps = 1e-5 27 | 28 | expected = fn(x, weight, normalized_shape, eps) 29 | jfn = thunder.jit(fn, executors=[apex_ex]) 30 | actual = jfn(x, weight, normalized_shape, eps) 31 | 32 | assert_close(actual, expected) 33 | 34 | if requires_grad: 35 | grad_output = torch.rand_like(actual) 36 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output) 37 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output) 38 | 39 | assert_close(actual_grad, expected_grad) 40 | 41 | 42 | # See https://github.com/NVIDIA/apex/issues/1853 43 | @pytest.mark.skipif(not is_available(), reason="torch.distributed is not available") 44 | @pytest.mark.parametrize("requires_grad", [True, False]) 45 | @pytest.mark.parametrize("memory_efficient", [True, False]) 46 | def 
test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient): 47 | from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction 48 | 49 | def fn(x, weight, normalized_shape, eps): 50 | return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient) 51 | 52 | device = "cuda" 53 | normalized_shape = (3, 2) 54 | x = torch.randn(4, 5, *normalized_shape, device=device, requires_grad=requires_grad) 55 | weight = torch.randn(*normalized_shape, device=device, requires_grad=requires_grad) 56 | eps = 1e-5 57 | 58 | expected = fn(x, weight, normalized_shape, eps) 59 | jfn = thunder.jit(fn, executors=()) 60 | actual = jfn(x, weight, normalized_shape, eps) 61 | 62 | assert_close(actual, expected) 63 | 64 | if requires_grad: 65 | grad_output = torch.rand_like(actual) 66 | actual_grad = torch.autograd.grad(actual, [x, weight], grad_output) 67 | expected_grad = torch.autograd.grad(expected, [x, weight], grad_output) 68 | 69 | assert_close(actual_grad, expected_grad) 70 | -------------------------------------------------------------------------------- /thunder/tests/test_check_trace.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.dev_utils.check_trace import check_trace 3 | import torch 4 | import pytest 5 | 6 | 7 | def test_missing_symbol(): 8 | def fn(a, b): 9 | c = a + b 10 | d = 2 * c 11 | return d 12 | 13 | jfn = thunder.jit(fn) 14 | 15 | a = torch.randn(2, 2) 16 | b = torch.randn(2, 2) 17 | 18 | jfn(a, b) 19 | 20 | tr = thunder.last_traces(jfn)[-1] 21 | check_trace(tr) 22 | 23 | del tr.bound_symbols[-3] 24 | 25 | with pytest.raises(AssertionError, match="unknown proxy"): 26 | check_trace(tr) 27 | 28 | 29 | def test_debug_option(): 30 | class BrokenTransform(thunder.core.transform_common.Transform): 31 | def transform_traces_pre_prologue(self, pro, comp, epi, **kwargs): 32 | new_comp = thunder.core.trace.from_trace(comp) 33 | new_comp.bound_symbols = comp.bound_symbols[:] 34 | del new_comp.bound_symbols[2] 35 | return pro, new_comp, epi 36 | 37 | def fn(a, b): 38 | c = a + b 39 | d = 2 * c 40 | return d 41 | 42 | a = torch.randn(2, 2) 43 | b = torch.randn(2, 2) 44 | 45 | # works 46 | jfn = thunder.jit(fn, debug_options=thunder.DebugOptions(check_traces=True)) 47 | jfn(a, b) 48 | 49 | # broken with nice error 50 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), debug_options=thunder.DebugOptions(check_traces=True)) 51 | 52 | with pytest.raises(AssertionError, match="unknown proxy"): 53 | jfn(a, b) 54 | 55 | # broken with less nice error 56 | jfn = thunder.jit(fn, transforms=(BrokenTransform(),), executors=()) 57 | with pytest.raises(UnboundLocalError, match="cannot access local|referenced before assignment"): 58 | jfn(a, b) 59 | -------------------------------------------------------------------------------- /thunder/tests/test_examine.py: -------------------------------------------------------------------------------- 1 | import thunder.examine 2 | import torch 3 | import pytest 4 | from thunder.executors import nvfuser_available 5 | 6 | 7 | def test_examine_fn(): 8 | def foo(x): 9 | x[0] = 5 * x[1] 10 | 11 | x = torch.ones(2, 2) 12 | thunder.examine.examine(foo, x) 13 | 14 | 15 | def test_examine_jfn(): 16 | def foo(x): 17 | x[0] = 5 * x[1] 18 | 19 | jfoo = thunder.jit(foo) 20 | x = torch.ones(2, 2) 21 | thunder.examine.examine(jfoo, x) 22 | 23 | 24 | def test_examine_noncallable(capsys): 25 | x = torch.ones(2, 2) 26 | y = torch.ones(2, 2) 27 | 
thunder.examine.examine(x, y) 28 | captured = capsys.readouterr() 29 | assert "expected `fn` to be a callable" in captured.out 30 | -------------------------------------------------------------------------------- /thunder/tests/test_examine_memory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | 5 | import thunder 6 | from thunder.core.pytree import tree_map 7 | import thunder.torch as ltorch 8 | from thunder.examine.memory_calculation import get_alloc_memory 9 | 10 | from thunder.tests.framework import requiresCUDA, TorchExecutor 11 | from thunder.tests.make_tensor import make_tensor 12 | 13 | 14 | def measure_memory_usage(trace): 15 | torch.cuda.reset_peak_memory_stats() 16 | before = torch.cuda.memory_stats().get("requested_bytes.all.current", 0) 17 | 18 | def make_tensor_like_torch_dtype(p): 19 | return make_tensor(p.shape, dtype=ltorch.to_torch_dtype(p.dtype), device=p.device) 20 | 21 | args, kwargs = tree_map(make_tensor_like_torch_dtype, (trace.args, trace.kwargs)) 22 | output = trace.python_callable()(*args, **kwargs) 23 | 24 | after = torch.cuda.memory_stats()["requested_bytes.all.current"] 25 | peak = torch.cuda.memory_stats()["requested_bytes.all.peak"] 26 | 27 | return {"peak": peak - before, "current": after - before, "output": output} 28 | 29 | 30 | def measure_fw_and_bw_memory_usage(fw_trace, bw_trace): 31 | fw_results = measure_memory_usage(fw_trace) 32 | bw_results = measure_memory_usage(bw_trace) 33 | 34 | return {f"fw_{k}": v for k, v in fw_results.items()} | {f"bw_{k}": v for k, v in bw_results.items()} 35 | 36 | 37 | # TODO: Test for nvFuserExecutor 38 | # nvFuserExecutor is skipped for now, because nvFuser and eager execution treat allocation and broadcast differently. 39 | # In the future, we need to update get_alloc_memory to support nvFuser and update tests accordingly. 
40 | @requiresCUDA 41 | def test_view_ops(): 42 | def test(func, *shapes): 43 | inputs = [make_tensor(shape, dtype=torch.float32, device="cuda", requires_grad=True) for shape in shapes] 44 | cfunc = TorchExecutor.make_callable(func, disable_preprocessing=False) 45 | cfunc(*inputs) 46 | 47 | fw_trace = thunder.last_traces(cfunc)[-1] 48 | bw_trace = thunder.last_backward_traces(cfunc)[-1] 49 | max_mem_fw = get_alloc_memory(fw_trace) 50 | max_mem_bw = get_alloc_memory(bw_trace) 51 | 52 | result = measure_fw_and_bw_memory_usage(fw_trace, bw_trace) 53 | assert max_mem_fw[0] == result["fw_peak"] 54 | assert sum(max_mem_fw[1].values()) == result["fw_current"] 55 | assert max_mem_bw[0] == result["bw_peak"] 56 | assert sum(max_mem_bw[1].values()) == result["bw_current"] 57 | 58 | def foo(a, b): # [4] [4] 59 | a_1 = torch.unsqueeze(a, 0) # [1,4] 60 | b_2 = torch.unsqueeze(b, 0) # [1,4] 61 | return (a_1 + b_2,) 62 | 63 | test(foo, (4,), (4,)) 64 | 65 | def bar(a, b): # [4] [2,2] 66 | a_1 = torch.unsqueeze(a, 0) # [1,4] 67 | a_2 = torch.unsqueeze(a_1, 1) # [1,1,4] 68 | a_3 = a_2.expand(2, 3, 4) # [2,3,4] 69 | 70 | b_1 = torch.reshape(b, (4,)) # [4] 71 | b_2 = torch.unsqueeze(b_1, 0) # [1,4] 72 | b_3 = torch.unsqueeze(b_2, 1) # [1,1,4] 73 | b_4 = b_3.expand(2, 3, 4) # [2,3,4] 74 | 75 | result1 = a_2 + b_3 76 | result2 = b_4 + a_3 77 | return result1, result2 78 | 79 | test(bar, (4,), (2, 2)) 80 | 81 | def bar1(a, b, c): # [4], [1,4,4], [4,1,4] 82 | a_1 = torch.unsqueeze(a, 0) # [1,4] 83 | a_2 = torch.unsqueeze(a_1, 1) # [1,1,4] 84 | a_3 = a_2.expand(1, 4, 4) 85 | a_4 = a_2.expand(4, 1, 4) 86 | return b + a_3, c + a_4 87 | 88 | test(bar1, (4,), (1, 4, 4), (4, 1, 4)) 89 | 90 | def bar2(a, b): # [5,2], [2,2] 91 | a_1, a_2, a_3 = torch.split(a, 2) 92 | c = a_1 + b 93 | d = a + a 94 | return c, d, a_2, a_3 # We have to use all the outputs of torch.split due to #1043 95 | 96 | test(bar2, (5, 2), (2, 2)) 97 | 98 | 99 | @requiresCUDA 100 | def test_nanogpt_block(): 101 | import thunder.tests.nanogpt_model as nanogpt_model 102 | 103 | config = nanogpt_model.GPTConfig(dropout=0) 104 | block = nanogpt_model.Block(config).to(dtype=torch.float32, device="cuda") 105 | cblock = TorchExecutor.make_callable(block) 106 | inp = make_tensor((2, config.block_size, config.n_embd), dtype=torch.float32, device="cuda", requires_grad=True) 107 | cblock(inp) 108 | 109 | fw_trace = thunder.last_traces(cblock)[-1] 110 | bw_trace = thunder.last_backward_traces(cblock)[-1] 111 | max_mem_fw = get_alloc_memory(fw_trace) 112 | max_mem_bw = get_alloc_memory(bw_trace) 113 | 114 | # Actual memory usage may vary depending on hardware and cuBLAS settings. 115 | # We are checking the estimated memory against a fixed value for consistency. 
116 | assert max_mem_fw[0] == 381754368 117 | assert sum(max_mem_fw[1].values()) == 375462912 118 | assert max_mem_bw[0] == 641097728 119 | assert sum(max_mem_bw[1].values()) == 440474624 120 | -------------------------------------------------------------------------------- /thunder/tests/test_pythonex.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import math 3 | 4 | import thunder 5 | 6 | 7 | def _run_cache_symbolic_values(fn, ref_fn, *args): 8 | jit_fn = thunder.jit(fn, cache="symbolic values") 9 | out = jit_fn(*args) 10 | 11 | out_ref = ref_fn(*args) 12 | assert out == out_ref 13 | 14 | 15 | def test_fmod(): 16 | def foo(a, b): 17 | return a % b 18 | 19 | _run_cache_symbolic_values(foo, foo, 2.0, 1.3) 20 | 21 | 22 | def test_bitwise_or(): 23 | def foo(a, b): 24 | return a | b 25 | 26 | _run_cache_symbolic_values(foo, foo, 3, 5) 27 | 28 | 29 | def test_bitwise_and(): 30 | def foo(a, b): 31 | return a & b 32 | 33 | _run_cache_symbolic_values(foo, foo, 3, 5) 34 | 35 | 36 | def test_bitwise_xor(): 37 | def foo(a, b): 38 | return a ^ b 39 | 40 | _run_cache_symbolic_values(foo, foo, 3, 5) 41 | 42 | 43 | def test_math_atan2(): 44 | def foo(a, b): 45 | # TODO: calling through math.atan2 bakes in constant, this needs to be investigated. 46 | return thunder.clang.atan2(a, b) 47 | 48 | # NOTE: we have thunder.clang in foo, which cannot be run with non-proxy 49 | _run_cache_symbolic_values(foo, math.atan2, 2.0, 1.3) 50 | 51 | 52 | def test_math_fmod(): 53 | def foo(a, b): 54 | return thunder.clang.fmod(a, b) 55 | 56 | _run_cache_symbolic_values(foo, math.fmod, 2.0, 1.3) 57 | -------------------------------------------------------------------------------- /thunder/tests/test_reductions.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch.testing import assert_close, make_tensor 5 | 6 | import thunder 7 | import thunder.torch as ttorch 8 | from thunder.tests.framework import instantiate 9 | 10 | import pytest 11 | 12 | 13 | # TODO: convert these tests to OpInfo generated tests 14 | 15 | 16 | @instantiate(dtypes=(thunder.float32,)) 17 | def test_torch_var(executor, device, dtype): 18 | # Tests passing all arguments as function inputs 19 | def foo(a, dim, *, keepdim=False, correction=1): 20 | return ttorch.var(a, dim, keepdim=keepdim, correction=correction) 21 | 22 | traced_foo = executor.make_callable(foo) 23 | 24 | tdtype = ttorch.to_torch_dtype(dtype) 25 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 26 | 27 | # Full reduction 28 | thunder_result = traced_foo(a, [0, 1]) 29 | torch_result = torch.var(a, [0, 1]) 30 | assert_close(thunder_result, torch_result) 31 | 32 | # Reduce along dim 1 33 | thunder_result = traced_foo(a, [1]) 34 | torch_result = torch.var(a, [1]) 35 | assert_close(thunder_result, torch_result) 36 | 37 | # Specifying the correction 38 | thunder_result = traced_foo(a, [1], correction=2) 39 | torch_result = torch.var(a, [1], correction=2) 40 | assert_close(thunder_result, torch_result) 41 | 42 | # Specifying keepdim 43 | thunder_result = traced_foo(a, [1], keepdim=True, correction=2) 44 | torch_result = torch.var(a, [1], keepdim=True, correction=2) 45 | assert_close(thunder_result, torch_result) 46 | 47 | # Tests passing arguments as constants 48 | def foo(a): 49 | return ttorch.var(a, [0, 1], keepdim=True, correction=2) 50 | 51 | traced_foo = executor.make_callable(foo) 52 | 53 | a = 
torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 54 | 55 | thunder_result = traced_foo(a) 56 | torch_result = torch.var(a, [0, 1], keepdim=True, correction=2) 57 | assert_close(thunder_result, torch_result) 58 | 59 | 60 | @instantiate(dtypes=(thunder.float32,)) 61 | def test_torch_mean(executor, device, dtype): 62 | def foo(a, dim=None, keepdim=False, *, dtype=None): 63 | return ttorch.mean(a, dim, keepdim, dtype=dtype) 64 | 65 | traced_foo = executor.make_callable(foo) 66 | 67 | tdtype = ttorch.to_torch_dtype(dtype) 68 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 69 | 70 | # Full reduction 71 | thunder_result = traced_foo(a, [0, 1]) 72 | torch_result = torch.mean(a, [0, 1]) 73 | assert_close(thunder_result, torch_result) 74 | 75 | # Reduce along dim 1 76 | thunder_result = traced_foo(a, [1]) 77 | torch_result = torch.mean(a, [1]) 78 | assert_close(thunder_result, torch_result) 79 | 80 | # Reduce with () dims 81 | thunder_result = traced_foo(a, ()) 82 | torch_result = torch.mean(a, ()) 83 | assert_close(thunder_result, torch_result) 84 | 85 | 86 | @instantiate(dtypes=(thunder.float32,)) 87 | def test_var_mean(executor, device, dtype): 88 | def foo(a, dim=None, keepdim=False, *, correction=1): 89 | return ttorch.var_mean(a, dim, keepdim=keepdim, correction=correction) 90 | 91 | traced_foo = executor.make_callable(foo) 92 | 93 | tdtype = ttorch.to_torch_dtype(dtype) 94 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 95 | 96 | # Full reduction 97 | thunder_result = traced_foo(a, [0, 1]) 98 | torch_result = torch.var_mean(a, [0, 1]) 99 | assert_close(thunder_result, torch_result) 100 | 101 | # Reduce along dim 1 102 | thunder_result = traced_foo(a, [1]) 103 | torch_result = torch.var_mean(a, [1]) 104 | assert_close(thunder_result, torch_result) 105 | 106 | # Tests passing arguments as constants 107 | def foo(a): 108 | return ttorch.var_mean(a, [0, 1], keepdim=True, correction=2) 109 | 110 | traced_foo = executor.make_callable(foo) 111 | 112 | a = torch.testing.make_tensor((4, 4), device=device, dtype=tdtype) 113 | 114 | thunder_result = traced_foo(a) 115 | torch_result = torch.var_mean(a, [0, 1], keepdim=True, correction=2) 116 | assert_close(thunder_result, torch_result) 117 | -------------------------------------------------------------------------------- /thunder/tests/test_shape_ops.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | import torch 3 | from thunder.tests.framework import JAX_AVAILABLE 4 | 5 | if JAX_AVAILABLE: 6 | pass 7 | 8 | 9 | def test_pad_cast_value_itof(): 10 | """ 11 | Pad should cast the given value to the type of tensor and pad that value. 12 | """ 13 | 14 | def fqn(): 15 | x = torch.tensor([2, 3], dtype=torch.int32) 16 | y = torch.nn.functional.pad(x, pad=(1, 2), value=6.4) 17 | return y 18 | 19 | th_fqn = thunder.jit(fqn) 20 | v = th_fqn() 21 | assert v[0] == 6 22 | assert v[1] == 2 23 | assert v[2] == 3 24 | assert v[3] == 6 25 | assert v[4] == 6 26 | 27 | 28 | def test_pad_cast_value_ftoi(): 29 | """ 30 | Pad should cast the given value to the type of tensor and pad that value. 
31 | """ 32 | 33 | def fqn(): 34 | x = torch.tensor([2.4, 3.8]) 35 | y = torch.nn.functional.pad(x, pad=(1, 2), value=1) 36 | return y 37 | 38 | th_fqn = thunder.jit(fqn) 39 | v = th_fqn() 40 | assert v[0] == 1.0 41 | assert v[1] == 2.4 42 | assert v[2] == 3.8 43 | assert v[3] == 1.0 44 | assert v[4] == 1.0 45 | -------------------------------------------------------------------------------- /thunder/tests/test_triton_ce.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from numbers import Number 3 | 4 | import pytest 5 | 6 | import torch 7 | from torch.testing import assert_close 8 | 9 | import thunder 10 | 11 | 12 | from thunder.tests.opinfos import get_opinfo 13 | from thunder.tests.framework import instantiate, requiresCUDA, requiresTriton, run_snippet, IN_CI 14 | from thunder.executors import triton_utils 15 | 16 | 17 | from lightning_utilities.core.imports import package_available 18 | 19 | TRITON_AVAILABLE = package_available("triton") 20 | 21 | # Requires triton 2.1 or greater 22 | triton: None | Any = None 23 | min_triton_version = "2.1" 24 | if triton_utils.is_triton_version_at_least(min_triton_version): 25 | from thunder.executors.triton_crossentropy import triton_ex 26 | 27 | 28 | # NOTE This test modifies the global executor map, so it technically should not 29 | # be run in parallel with other tests 30 | @pytest.mark.parametrize( 31 | "dtype", 32 | [torch.float16, torch.bfloat16, torch.float32, torch.float64], 33 | ids=("float16", "bfloat16", "float32", "float64"), 34 | ) 35 | @pytest.mark.parametrize("device,", ["cuda"]) 36 | @requiresCUDA 37 | @requiresTriton 38 | def test_triton_cross_entropy(device, dtype): 39 | if IN_CI: 40 | pytest.skip("Currently these tests are skipped in CI for speed") 41 | 42 | logits = torch.randn([2048, 50257], device=device, dtype=dtype) 43 | labels = torch.randint(0, 50257, [2048], device=device) 44 | reduction = "sum" 45 | ignore_index = labels[5].item() 46 | weight = torch.rand(50257, device=device, dtype=dtype, requires_grad=False) 47 | expected = torch.nn.functional.cross_entropy( 48 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index 49 | ) 50 | 51 | def test(logits, labels, weight, reduction, ignore_index): 52 | return thunder.torch.cross_entropy( 53 | logits, labels, weight=weight, reduction=reduction, ignore_index=ignore_index 54 | ) 55 | 56 | ctest = thunder.jit(test, executors=[triton_ex]) 57 | actual = ctest(logits, labels, weight, reduction, ignore_index) 58 | torch.testing.assert_close(actual, expected) 59 | last_trace = thunder.last_traces(ctest)[-1] 60 | assert any(bsym.sym.name == "triton_crossentropy" for bsym in last_trace.bound_symbols) 61 | 62 | 63 | def snippet_torch_consistency(op, torch_op, sample): 64 | thunder_result = op(*sample.args, **sample.kwargs) 65 | torch_result = torch_op(*sample.args, **sample.kwargs) 66 | torch.cuda.synchronize() 67 | 68 | # Sets atol and rtol to looser tolerances than assert_close's defaults 69 | atol: Number = 1e-1 70 | rtol: Number = 1.3e-6 71 | assert_close(thunder_result, torch_result, equal_nan=True, atol=atol, rtol=rtol) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "dtype", 76 | [torch.float16, torch.bfloat16, torch.float32, torch.float64], 77 | ids=("float16", "bfloat16", "float32", "float64"), 78 | ) 79 | @pytest.mark.parametrize("device,", ["cuda"]) 80 | @requiresCUDA 81 | @requiresTriton 82 | def test_triton_cross_entropy_vs_torch_consistency(device, dtype): 83 | if IN_CI: 84 | 
pytest.skip("Currently these tests are skipped in CI for speed") 85 | if dtype == torch.float16 or dtype == torch.bfloat16: 86 | pytest.skip("Currently skipping float16 and bfloat16 due to numerical accuracy") 87 | 88 | opinfo = get_opinfo("cross_entropy") 89 | 90 | def foo(*args, **kwargs): 91 | return torch.nn.functional.cross_entropy(*args, **kwargs) 92 | 93 | ce = thunder.jit(foo, executors=[triton_ex]) 94 | 95 | # NOTE reference inputs take a long time to run in CI, so this uses sample inputs in CI 96 | # opinfo.reference_inputs if not IN_CI else opinfo.sample_inputs 97 | # reference inputs for cross_entropy contains cases not implemented in Thunder 98 | input_generator = opinfo.reference_inputs 99 | 100 | for sample in input_generator(device=device, dtype=dtype, requires_grad=False): 101 | result = run_snippet( 102 | snippet_torch_consistency, 103 | opinfo, 104 | device, 105 | dtype, 106 | ce, 107 | opinfo.torch_reference, 108 | sample, 109 | ) 110 | if result is not None: 111 | return result 112 | -------------------------------------------------------------------------------- /thunder/torch/langctx.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language 4 | from thunder.core.pytree import tree_flatten 5 | from thunder.core.proxies import TensorProxy 6 | 7 | # 8 | # Creates and registers the torch language context 9 | # 10 | # NOTE That this is done separately from the definition of thunder.torch operations, because the 11 | # language context must be available before those operations are defined 12 | 13 | _method_name_to_fn_map: dict[str, Callable] = {} 14 | _property_name_to_fn_map: dict[str, Callable] = {} 15 | 16 | 17 | class TorchCtx(LanguageContext): 18 | def __init__(self): 19 | super().__init__("torch") 20 | 21 | def has_method(self, id: str) -> bool: 22 | return id in _method_name_to_fn_map or id in _property_name_to_fn_map 23 | 24 | def get_method(self, id: str, *args, **kwargs) -> Callable: 25 | # Note: concrete implmenetations should only raise AttributeError or 26 | # return None for "missing" methods as the proxies will 27 | # route __getattr__ to here and hasattr relies on __getattr__ 28 | # throwing AttributeError (only) when the attribute does 29 | # not exist. 
30 | inps, _ = tree_flatten((args, kwargs)) 31 | 32 | has_tensor_input: bool = False 33 | for x in inps: 34 | if isinstance(x, TensorProxy): 35 | has_tensor_input = True 36 | break 37 | if has_tensor_input: 38 | method: None | Callable = _method_name_to_fn_map.get(id, None) 39 | prop: None | Callable = _property_name_to_fn_map.get(id, None) 40 | if method is None and prop is None: 41 | raise AttributeError(f"The {self.name} language context has no method or attribute {id}") 42 | if method: 43 | return method 44 | else: 45 | return prop(inps[0]) 46 | 47 | # has_tensor_input is False 48 | # Defers to the CLANG language context when there are no tensor inputs= 49 | # (the clang language context handles operations on NumberProxies and Numbers) 50 | primsctx: LanguageContext = resolve_language(Languages.CLANG) 51 | if not primsctx.has_method(id): 52 | raise AttributeError( 53 | f"Attempting to call method {id} in the torch language context, but it has no tensor inputs and the primitive language context (which handles numbers) doesn't have the method" 54 | ) 55 | prim_method: Callable = primsctx.get_method(id, *args, **kwargs) 56 | return prim_method 57 | 58 | 59 | torchctx = TorchCtx() 60 | register_langctx(Languages.TORCH, torchctx) 61 | register_langctx("torch", torchctx) 62 | 63 | 64 | # Registers a method with the torch language context 65 | def register_method(method_name: str, method: Callable, /) -> None: 66 | _method_name_to_fn_map[method_name] = method 67 | 68 | 69 | def register_property(property_name, property) -> None: 70 | _property_name_to_fn_map[property_name] = property 71 | -------------------------------------------------------------------------------- /thunder/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant_folding import ConstantFolding 2 | from .materialization import MaterializationTransform 3 | from .qlora import LORATransform 4 | from .prune_prologue_checks import PrunePrologueChecks 5 | from .extraction_only_prologue_transform import ExtractionOnlyPrologueTransform 6 | 7 | 8 | __all__ = [ 9 | "ConstantFolding", 10 | "LORATransform", 11 | "MaterializationTransform", 12 | "PrunePrologueChecks", 13 | "ExtractionOnlyPrologueTransform", 14 | ] 15 | -------------------------------------------------------------------------------- /thunder/transforms/extraction_only_prologue_transform.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.core.trace import from_trace 3 | from thunder.core.proxies import ProxyTag 4 | 5 | 6 | class ExtractionOnlyPrologueTransform(thunder.Transform): 7 | def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs): 8 | new_prologue_trace = from_trace(prologue_trace) 9 | new_bsyms = [] 10 | 11 | for bsym in prologue_trace.bound_symbols: 12 | # NOTE - We assume TensorProxy's tagged with `STATIC_MEMORY_LOCATION` to 13 | # be Parameters or Buffer. It should be safe to disable check for 14 | # tensors we deem to be static. 
15 | if ( 16 | bsym.sym.id == thunder.prims.PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA 17 | and ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags 18 | ): 19 | continue 20 | 21 | new_bsyms.append(bsym) 22 | 23 | new_prologue_trace.bound_symbols = new_bsyms 24 | 25 | new_prologue_trace.set_provenance("Extraction only prologue pass") 26 | return new_prologue_trace, computation_trace, epilogue_trace 27 | -------------------------------------------------------------------------------- /thunder/transforms/prune_prologue_checks.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.core import prims as prims 3 | 4 | 5 | class PrunePrologueChecks(thunder.core.transform_common.Transform): 6 | """A Transform to prune Prologue checks 7 | 8 | This transform removes prologue checks and can be applied when the user 9 | controls the environment enough to ensure that these checks would always 10 | succeed. By default, we remove all checks that use the module. 11 | """ 12 | 13 | def __init__(self, prune_all_checks=False): 14 | self.prune_all_checks = prune_all_checks 15 | 16 | def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs): 17 | def is_check(bsym): 18 | return bsym.sym in { 19 | prims.check_tensor_shape_and_metadata, 20 | prims.check_string_value, 21 | prims.check_number_type_and_value, 22 | prims.check_len, 23 | prims.check_literal_like, 24 | } 25 | 26 | if not self.prune_all_checks: 27 | bsyms_to_skip = set() 28 | module_member_names = set() 29 | 30 | for bsym in prologue_trace.bound_symbols: 31 | if bsym.sym in {prims.unpack_trivial, prims.unpack_cache_info, prims.python_return}: 32 | # These don't have inputs but need to be skipped to not trigger false positives 33 | # python_return may have no inputs 34 | continue 35 | if bsym.sym is prims.unpack_function_obj: 36 | for o in bsym.flat_proxy_outs: 37 | module_member_names.add(o.name) 38 | continue 39 | 40 | input_names = {i.name in module_member_names for i in bsym.flat_proxy_args} 41 | if all(input_names): 42 | # This has the special case of no proxy inputs, which is the case for unpack_function_obj, 43 | # the root of module_member_names 44 | assert input_names, f"unexpected symbol {bsym.sym.name} without inputs" 45 | for o in bsym.flat_proxy_outs: 46 | module_member_names.add(o.name) 47 | if is_check(bsym): 48 | bsyms_to_skip.add(bsym) 49 | 50 | def should_skip_bsym(bsym): 51 | return bsym in bsyms_to_skip 52 | 53 | else: 54 | should_skip_bsym = is_check 55 | 56 | new_prologue_trace = thunder.core.trace.from_trace(prologue_trace) 57 | for bsym in prologue_trace.bound_symbols: 58 | if not should_skip_bsym(bsym): 59 | new_prologue_trace.bound_symbols.append(bsym.from_bsym()) 60 | 61 | new_prologue_trace = thunder.core.transform_common.dce(new_prologue_trace) 62 | new_prologue_trace.set_provenance(thunder.core.trace.TraceProvenance(f"{self.__class__.__name__}")) 63 | 64 | return new_prologue_trace, compute_trace, epilogue_trace 65 | -------------------------------------------------------------------------------- /thunder/transforms/utils.py: -------------------------------------------------------------------------------- 1 | import thunder 2 | from thunder.core import utils 3 | from thunder.core import prims 4 | from thunder.core.trace import TraceCtx 5 | from thunder.core.pytree import tree_map 6 | 7 | 8 | def get_checks(prologue_trace): 9 | # returns a dictionary mapping model param names to (check bsym, get param bsym 10 | check_dict = 
{} 11 | prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace) 12 | for bsym in prologue_trace.bound_symbols: 13 | if bsym.sym == prims.unpack_parameter or bsym.sym == prims.unpack_buffer: 14 | param_thunder_module, param_name = bsym.args 15 | checks = [ 16 | bsym2 for bsym2 in prologue_consumers[bsym.output] if bsym2.sym == prims.check_tensor_shape_and_metadata 17 | ] 18 | assert ( 19 | len(checks) == 1 20 | ), f"expected each parameter and buffer to have exactly one checker, but {bsym.output} has {len(checks)}" 21 | assert isinstance(param_name, str) 22 | check_dict[param_name] = (checks[0], bsym) 23 | return check_dict 24 | 25 | 26 | def add_trace_output(trace, output, subindex: int | None = None): 27 | ret_node = trace.bound_symbols[-1] 28 | assert ret_node.sym == prims.python_return 29 | assert len(ret_node.args) == 1 30 | (ret_args,) = ret_node.args 31 | 32 | if subindex is None: 33 | ret_args = (*ret_args, output) 34 | else: 35 | assert isinstance(ret_args[subindex], tuple) 36 | ret_args = (*ret_args[:subindex], (*ret_args[subindex], output), *ret_args[subindex + 1 :]) 37 | 38 | ret_node.args = (ret_args,) 39 | 40 | 41 | def trace_with_replaced_proxy_metadata(trace: TraceCtx, proxy_replacement_metadata) -> TraceCtx: 42 | t = TraceCtx(trace.fn) 43 | 44 | proxymap: dict[str, thunder.Proxy] = {} 45 | 46 | def map_proxy(p): 47 | if isinstance(p, thunder.Proxy): 48 | return proxymap[p.name] 49 | return p 50 | 51 | def create_proxy(p): 52 | if isinstance(p, thunder.Proxy): 53 | if p.name in proxymap: # happens with subsymbols 54 | return proxymap[p.name] 55 | with thunder.core.trace.tracectx(t): 56 | np = p.replace(**proxy_replacement_metadata.get(p.name, {})) 57 | proxymap[p.name] = np 58 | return np 59 | return p 60 | 61 | def process_bound_symbols(src_bound_symbols, target_bound_symbols): 62 | for bsym in src_bound_symbols: 63 | new_args = tree_map(map_proxy, bsym.args) 64 | new_kwargs = tree_map(map_proxy, bsym.kwargs) 65 | new_output = tree_map(create_proxy, bsym.output) 66 | new_bsym = bsym.from_bsym(output=new_output, args=new_args, kwargs=new_kwargs, subsymbols=[]) 67 | target_bound_symbols.append(new_bsym) 68 | if len(bsym.subsymbols) > 0: 69 | process_bound_symbols(bsym.subsymbols, new_bsym.subsymbols) 70 | 71 | process_bound_symbols(trace.bound_symbols, t.bound_symbols) 72 | 73 | t.args = tree_map(map_proxy, trace.args) 74 | t.kwargs = tree_map(map_proxy, trace.kwargs) 75 | t._siginfo = trace._siginfo 76 | return t 77 | --------------------------------------------------------------------------------